diff --git a/.gitignore b/.gitignore index d11a504bdc56ee98b3d5a0c33f9f75d996e45567..be75938ec401b1d72fa54773c85191aaac7d7f35 100644 --- a/.gitignore +++ b/.gitignore @@ -6,7 +6,7 @@ node_modules /bazel-* /bazel_pip /tools/python_bin_path.sh -/tools/git/gen +/tensorflow/tools/git/gen /pip_test /_python_build *.pyc @@ -26,4 +26,11 @@ Podfile.lock /tensorflow/contrib/lite/gen/** /tensorflow/contrib/lite/examples/ios/simple/data/*.txt /tensorflow/contrib/lite/examples/ios/simple/data/*.tflite -xcuserdata/** \ No newline at end of file +xcuserdata/** + +# Android +.gradle +.idea +*.iml +local.properties +gradleBuild diff --git a/BUILD b/BUILD index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..4bf647e47aa56cff0b3fd5af7d5df99d8b70549b 100644 --- a/BUILD +++ b/BUILD @@ -0,0 +1,6 @@ +exports_files( + [ + "LICENSE", + "ACKNOWLEDGEMENTS", + ], +) diff --git a/CODEOWNERS b/CODEOWNERS index 57a4df40e651f45dc03493af631d73332e46c182..007a304c3e706ce968576ec8979c08f1a3bcc552 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,53 +1,53 @@ # NOTE: Disabled temporarily because it's too noisy on pushes. # Where component owners are known, add them here. 
-#tensorflow/core/platform/windows/* @mrry -#tensorflow/java/* @asimshankar -#tensorflow/tensorboard/* @jart @dandelionmane -#tensorflow/tools/docs/* @markdaoust +# /tensorflow/core/platform/windows/ @mrry +# /tensorflow/java/ @asimshankar +# /tensorflow/tensorboard/ @jart @dandelionmane +# /tensorflow/tools/docs/ @markdaoust # contrib -# NEED OWNER: tensorflow/contrib/avro/* -#tensorflow/contrib/batching/* @alextp @chrisolston -#tensorflow/contrib/bayesflow/* @ebrevdo @rsepassi @jvdillon -#tensorflow/contrib/boosted_trees/* @sshrdp @yk5 @nataliaponomareva -#tensorflow/contrib/cmake/* @mrry @benoitsteiner -#tensorflow/contrib/copy_graph/* @tucker @poxvoculi -#tensorflow/contrib/crf/* @kentonl -#tensorflow/contrib/data/* @mrry -#tensorflow/contrib/distributions/* @jvdillon @langmore @rsepassi -#tensorflow/contrib/factorization/* @agarwal-ashish @xavigonzalvo -#tensorflow/contrib/ffmpeg/* @fredbertsch -# NEED OWNER: tensorflow/contrib/framework/* -#tensorflow/contrib/graph_editor/* @purpledog -# NEED OWNER: tensorflow/contrib/grid_rnn/* -#tensorflow/contrib/hvx/* @satok16 -#tensorflow/contrib/integrate/* @shoyer -#tensorflow/contrib/kernel_methods/* @petrosmol -#tensorflow/contrib/ios_examples/* @petewarden -#tensorflow/contrib/labeled_tensor/* @shoyer -#tensorflow/contrib/layers/* @fchollet @martinwicke -#tensorflow/contrib/learn/* @martinwicke @ispirmustafa @alextp -#tensorflow/contrib/linalg/* @langmore -#tensorflow/contrib/linear_optimizer/* @petrosmol @andreasst @katsiapis -#tensorflow/contrib/lookup/* @ysuematsu @andreasst -#tensorflow/contrib/losses/* @alextp @ispirmustafa -#tensorflow/contrib/makefile/* @petewarden @satok16 @wolffg -#tensorflow/contrib/metrics/* @alextp @honkentuber @ispirmustafa -#tensorflow/contrib/nccl/* @cwhipkey @zheng-xq -#tensorflow/contrib/opt/* @strategist333 -#tensorflow/contrib/pi_examples/* @maciekcc -#tensorflow/contrib/quantization/* @petewarden @cwhipkey @keveman -#tensorflow/contrib/rnn/* @ebrevdo 
-#tensorflow/contrib/saved_model/* @nfiedel @sukritiramesh -#tensorflow/contrib/seq2seq/* @lukaszkaiser -#tensorflow/contrib/session_bundle/* @nfiedel @sukritiramesh -#tensorflow/contrib/slim/* @sguada @thenbasilmanran -#tensorflow/contrib/stateless/* @girving -#tensorflow/contrib/tensor_forest/* @gilberthendry @thomascolthurst -#tensorflow/contrib/testing/* @dandelionmane -#tensorflow/contrib/timeseries/* @allenlavoie -#tensorflow/contrib/tpu/* @frankchn @saeta @jhseu -#tensorflow/contrib/training/* @joel-shor @ebrevdo -#tensorflow/contrib/util/* @sherrym +# NEED OWNER: /tensorflow/contrib/avro/ +# /tensorflow/contrib/batching/ @alextp @chrisolston +# /tensorflow/contrib/bayesflow/ @ebrevdo @rsepassi @jvdillon +# /tensorflow/contrib/boosted_trees/ @sshrdp @yk5 @nataliaponomareva +# /tensorflow/contrib/cmake/ @mrry @benoitsteiner +# /tensorflow/contrib/copy_graph/ @tucker @poxvoculi +# /tensorflow/contrib/crf/ @kentonl +# /tensorflow/contrib/data/ @mrry +# /tensorflow/contrib/distributions/ @jvdillon @langmore @rsepassi +# /tensorflow/contrib/factorization/ @agarwal-ashish @xavigonzalvo +# /tensorflow/contrib/ffmpeg/ @fredbertsch +# NEED OWNER: /tensorflow/contrib/framework/ +# /tensorflow/contrib/graph_editor/ @purpledog +# NEED OWNER: /tensorflow/contrib/grid_rnn/ +# /tensorflow/contrib/hvx/ @satok16 +# /tensorflow/contrib/integrate/ @shoyer +# /tensorflow/contrib/kernel_methods/ @petrosmol +# /tensorflow/contrib/ios_examples/ @petewarden +# /tensorflow/contrib/labeled_tensor/ @shoyer +# /tensorflow/contrib/layers/ @fchollet @martinwicke +# /tensorflow/contrib/learn/ @martinwicke @ispirmustafa @alextp +# /tensorflow/contrib/linalg/ @langmore +# /tensorflow/contrib/linear_optimizer/ @petrosmol @andreasst @katsiapis +# /tensorflow/contrib/lookup/ @ysuematsu @andreasst +# /tensorflow/contrib/losses/ @alextp @ispirmustafa +# /tensorflow/contrib/makefile/ @petewarden @satok16 @wolffg +# /tensorflow/contrib/metrics/ @alextp @honkentuber @ispirmustafa +# 
/tensorflow/contrib/nccl/ @cwhipkey @zheng-xq +# /tensorflow/contrib/opt/ @strategist333 +# /tensorflow/contrib/pi_examples/ @maciekcc +# /tensorflow/contrib/quantization/ @petewarden @cwhipkey @keveman +# /tensorflow/contrib/rnn/ @ebrevdo +# /tensorflow/contrib/saved_model/ @nfiedel @sukritiramesh +# /tensorflow/contrib/seq2seq/ @lukaszkaiser +# /tensorflow/contrib/session_bundle/ @nfiedel @sukritiramesh +# /tensorflow/contrib/slim/ @sguada @thenbasilmanran +# /tensorflow/contrib/stateless/ @girving +# /tensorflow/contrib/tensor_forest/ @gilberthendry @thomascolthurst +# /tensorflow/contrib/testing/ @dandelionmane +# /tensorflow/contrib/timeseries/ @allenlavoie +# /tensorflow/contrib/tpu/ @frankchn @saeta @jhseu +# /tensorflow/contrib/training/ @joel-shor @ebrevdo +# /tensorflow/contrib/util/ @sherrym diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index ff11d131409b65880f16b80f9fe38dc39ac0e5fa..5fff9d05a1c589636bc9c711e6eb7cc4aba86b2f 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -67,4 +67,4 @@ If the Project Stewards receive a report alleging a violation of the Code of Con ## Attribution -This Code of Conduct is adapted from the Contributor Covenant, version 1.4, available at http://contributor-covenant.org/version/1/4, and includes some aspects of the Geek Feminism Code of Conduct and the Drupal Code of Conduct. +This Code of Conduct is adapted from the Contributor Covenant, version 1.4, available at https://contributor-covenant.org/version/1/4, and includes some aspects of the Geek Feminism Code of Conduct and the Drupal Code of Conduct. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 1b537ca73cc94e992e7537fe69c8d0cc8fd13102..3dad41a88c8212b7445c32f241d887306d3c19ad 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -8,8 +8,8 @@ We'd love to accept your patches! Before we can take them, we have to jump a cou Please fill out either the individual or corporate Contributor License Agreement (CLA). 
- * If you are an individual writing original source code and you're sure you own the intellectual property, then you'll need to sign an [individual CLA](http://code.google.com/legal/individual-cla-v1.0.html). - * If you work for a company that wants to allow you to contribute your work, then you'll need to sign a [corporate CLA](http://code.google.com/legal/corporate-cla-v1.0.html). + * If you are an individual writing original source code and you're sure you own the intellectual property, then you'll need to sign an [individual CLA](https://code.google.com/legal/individual-cla-v1.0.html). + * If you work for a company that wants to allow you to contribute your work, then you'll need to sign a [corporate CLA](https://code.google.com/legal/corporate-cla-v1.0.html). Follow either of the two links above to access the appropriate CLA and instructions for how to sign and return it. Once we receive it, we'll be able to accept your pull requests. @@ -20,6 +20,9 @@ Follow either of the two links above to access the appropriate CLA and instructi If you have improvements to TensorFlow, send us your pull requests! For those just getting started, Github has a [howto](https://help.github.com/articles/using-pull-requests/). +TensorFlow team members will be assigned to review your pull requests. Once the pull requests are approved and pass continuous integration checks, we will merge the pull requests. +For some pull requests, we will apply the patch for each pull request to our internal version control system first, and export the change out as a new commit later, at which point the original pull request will be closed. The commits in the pull request will be squashed into a single commit with the pull request creator as the author. These pull requests will be labeled as pending merge internally. 
+ If you want to contribute but you're not sure where to start, take a look at the [issues with the "contributions welcome" label](https://github.com/tensorflow/tensorflow/labels/stat%3Acontributions%20welcome). These are issues that we believe are particularly well suited for outside @@ -38,7 +41,7 @@ TensorFlow coding style. #### General guidelines and philosophy for contribution * Include unit tests when you contribute new features, as they help to - a) prove that your code works correctly, b) guard against future breaking + a) prove that your code works correctly, and b) guard against future breaking changes to lower the maintenance cost. * Bug fixes also generally require unit tests, because the presence of bugs usually indicates insufficient test coverage. @@ -48,7 +51,7 @@ TensorFlow coding style. non-backward-compatible API changes without a major release. Reviewers of your pull request will comment on any API compatibility issues. * When you contribute a new feature to TensorFlow, the maintenance burden is (by - default) transferred to the TensorFlow team. This means that benefit of + default) transferred to the TensorFlow team. This means that the benefit of the contribution must be compared against the cost of maintaining the feature. * Full new features (e.g., a new op implementing a cutting-edge algorithm) typically will live in @@ -65,8 +68,8 @@ Include a license at the top of new files. 
* [Java license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/java/src/main/java/org/tensorflow/Graph.java#L1) * [Go license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/go/operation.go#L1) * [Bash license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/ci_build/ci_sanity.sh#L2) -* [HTML license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tensorboard/dist/index.html#L2) -* [JavaScript/TypeScript license example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tensorboard/components/tf_backend/backend.ts#L1) +* [HTML license example](https://github.com/tensorflow/tensorboard/blob/master/tensorboard/components/tf_backend/tf-backend.html#L2) +* [JavaScript/TypeScript license example](https://github.com/tensorflow/tensorboard/blob/master/tensorboard/components/tf_backend/backend.ts#L1) Bazel BUILD files also need to include a license section, e.g., [BUILD example](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/BUILD#L61). @@ -114,7 +117,7 @@ pylint --rcfile=/tmp/pylintrc myfile.py * [Google Java Style Guide](https://google.github.io/styleguide/javaguide.html) * [Google JavaScript Style Guide](https://google.github.io/styleguide/jsguide.html) * [Google Shell Style Guide](https://google.github.io/styleguide/shell.xml) -* [Google Objective-C Style Guide](http://google.github.io/styleguide/objcguide.html) +* [Google Objective-C Style Guide](https://google.github.io/styleguide/objcguide.html) #### Running sanity check @@ -160,7 +163,7 @@ There are two ways to run TensorFlow unit tests. bazel test ${flags} //tensorflow/python/... ``` -2. Using [Docker](www.docker.com) and TensorFlow's CI scripts. +2. Using [Docker](https://www.docker.com) and TensorFlow's CI scripts. 
```bash # Install Docker first, then this will build and run cpu tests diff --git a/ISSUE_TEMPLATE.md b/ISSUE_TEMPLATE.md index 1a401997c649518766acb2ebb0dea1c128bd0ba4..2f3df7cda9cec29ed0c2266629022f0a22b37df9 100644 --- a/ISSUE_TEMPLATE.md +++ b/ISSUE_TEMPLATE.md @@ -4,7 +4,7 @@ https://stackoverflow.com/questions/tagged/tensorflow If you open a GitHub issue, here is our policy: -1. It must be a bug or a feature request. +1. It must be a bug, a feature request, or a significant problem with documentation (for small docs fixes please send a PR instead). 2. The form below must be filled out. 3. It shouldn't be a TensorBoard issue. Those go [here](https://github.com/tensorflow/tensorboard/issues). diff --git a/LICENSE b/LICENSE index 15ae42140452d32ccf929f59f7eca01a3c7b555f..4862420c0234f7542d4fe8f3520516b484a64aed 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,4 @@ -Copyright 2017 The TensorFlow Authors. All rights reserved. +Copyright 2018 The TensorFlow Authors. All rights reserved. Apache License Version 2.0, January 2004 diff --git a/README.md b/README.md index aff3427bddb307aea6d6c2466eac14c9edffcc32..916e5200b29841028652c861c49dbb3650baea3c 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ | **`Linux CPU`** | **`Linux GPU`** | **`Mac OS CPU`** | **`Windows CPU`** | **`Android`** | |-----------------|---------------------|------------------|-------------------|---------------| -| [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-cpu)](https://ci.tensorflow.org/job/tensorflow-master-cpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-linux-gpu)](https://ci.tensorflow.org/job/tensorflow-master-linux-gpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-mac)](https://ci.tensorflow.org/job/tensorflow-master-mac) | [![Build 
Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-win-cmake-py)](https://ci.tensorflow.org/job/tensorflow-master-win-cmake-py) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-android)](https://ci.tensorflow.org/job/tensorflow-master-android) | +| [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-cpu)](https://ci.tensorflow.org/job/tensorflow-master-cpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-linux-gpu)](https://ci.tensorflow.org/job/tensorflow-master-linux-gpu) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-mac)](https://ci.tensorflow.org/job/tensorflow-master-mac) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-win-cmake-py)](https://ci.tensorflow.org/job/tensorflow-master-win-cmake-py) | [![Build Status](https://ci.tensorflow.org/buildStatus/icon?job=tensorflow-master-android)](https://ci.tensorflow.org/job/tensorflow-master-android) [ ![Download](https://api.bintray.com/packages/google/tensorflow/tensorflow/images/download.svg) ](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) | **TensorFlow** is an open source software library for numerical computation using data flow graphs. The graph nodes represent mathematical operations, while @@ -27,10 +27,14 @@ guidelines](CONTRIBUTING.md). This project adheres to TensorFlow's uphold this code.** **We use [GitHub issues](https://github.com/tensorflow/tensorflow/issues) for -tracking requests and bugs. So please see +tracking requests and bugs. 
So please see [TensorFlow Discuss](https://groups.google.com/a/tensorflow.org/forum/#!forum/discuss) for general questions and discussion, and please direct specific questions to [Stack Overflow](https://stackoverflow.com/questions/tagged/tensorflow).** +The TensorFlow project strives to abide by generally accepted best practices in open-source software development: + +[![CII Best Practices](https://bestpractices.coreinfrastructure.org/projects/1486/badge)](https://bestpractices.coreinfrastructure.org/projects/1486) + ## Installation *See [Installing TensorFlow](https://www.tensorflow.org/get_started/os_setup.html) for instructions on how to install our release binaries or how to build from source.* @@ -46,11 +50,11 @@ packages on Linux, Mac, and Windows. **Individual whl files** -* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/)) / [Python 3.4](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp35-cp35m-linux_x86_64.whl) ([build 
history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=cpu-slave/)) -* Linux GPU: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/42/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/)) +* Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/)) / [Python 
3.4](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=cpu-slave/)) / [Python 3.6](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.6,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp36-cp36m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.6,label=cpu-slave/)) +* Linux GPU: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/42/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp34-cp34m-linux_x86_64.whl) ([build 
history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/)) / [Python 3.6](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.6,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp36-cp36m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.6,label=gpu-linux/)) * Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/)) -* Windows CPU-only: [Python 3.5 
64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly-1.head-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly-1.head-cp36-cp36m-win_amd64.whl) ([build history](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=36/)) -* Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly_gpu-1.head-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly_gpu-1.head-cp36-cp36m-win_amd64.whl) ([build history](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=36/)) +* Windows CPU-only: [Python 3.5 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly-1.head-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly-1.head-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=36/)) +* Windows GPU: [Python 3.5 
64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly_gpu-1.head-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly_gpu-1.head-cp36-cp36m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=36/)) * Android: [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/)) diff --git a/RELEASE.md b/RELEASE.md index e04bd3fc505d51ade9e9fa12c822cb695e90b4f3..0720a8c639f8ab87214b11f6a8092b432b916853 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,306 @@ +# Release 1.6.0 + +## Breaking Changes +* Prebuilt binaries are now built against CUDA 9.0 and cuDNN 7. +* Prebuilt binaries will use AVX instructions. This may break TF on older CPUs. + +## Major Features And Improvements +* New Optimizer internal API for non-slot variables. Descendants of AdamOptimizer that access _beta[12]_power will need to be updated. +* `tf.estimator.{FinalExporter,LatestExporter}` now export stripped SavedModels. This improves forward compatibility of the SavedModel. +* FFT support added to XLA CPU/GPU. + +## Bug Fixes and Other Changes +* Documentation updates: + * Added a second version of Getting Started, which is aimed at ML +newcomers. + * Clarified documentation on `resize_images.align_corners` parameter. + * Additional documentation for TPUs. +* Google Cloud Storage (GCS): + * Add client-side throttle. 
+ * Add a `FlushCaches()` method to the FileSystem interface, with an implementation for GcsFileSystem. +* Other: + * Add `tf.contrib.distributions.Kumaraswamy`. + * `RetryingFileSystem::FlushCaches()` calls the base FileSystem's `FlushCaches()`. + * Add auto_correlation to distributions. + * Add `tf.contrib.distributions.Autoregressive`. + * Add SeparableConv1D layer. + * Add convolutional Flipout layers. + * When both inputs of `tf.matmul` are bfloat16, it returns bfloat16, instead of float32. + * Added `tf.contrib.image.connected_components`. + * Add `tf.contrib.framework.CriticalSection` that allows atomic variable access. + * Output variance over trees predictions for classification tasks. + * For `pt` and `eval` commands, allow writing tensor values to filesystem as numpy files. + * gRPC: Propagate truncated errors (instead of returning gRPC internal error). + * Augment parallel_interleave to support 2 kinds of prefetching. + * Improved XLA support for C64-related ops log, pow, atan2, tanh. + * Add probabilistic convolutional layers. + +## API Changes +* Introducing prepare_variance boolean with default setting to False for backward compatibility. +* Move `layers_dense_variational_impl.py` to `layers_dense_variational.py`. + +## Known Bugs +* Using XLA:GPU with CUDA 9 and CUDA 9.1 results in garbage results and/or + `CUDA_ERROR_ILLEGAL_ADDRESS` failures. + + Google discovered in mid-December 2017 that the PTX-to-SASS compiler in CUDA 9 + and CUDA 9.1 sometimes does not properly compute the carry bit when + decomposing 64-bit address calculations with large offsets (e.g. `load [x + + large_constant]`) into 32-bit arithmetic in SASS. + + As a result, these versions of `ptxas` miscompile most XLA programs which use + more than 4GB of temp memory. This results in garbage results and/or + `CUDA_ERROR_ILLEGAL_ADDRESS` failures. + + A fix in CUDA 9.1.121 is expected in late February 2018. We do not expect a + fix for CUDA 9.0.x. 
Until the fix is available, the only workaround is to + [downgrade](https://developer.nvidia.com/cuda-toolkit-archive) to CUDA 8.0.x + or disable XLA:GPU. + + TensorFlow will print a warning if you use XLA:GPU with a known-bad version of + CUDA; see e00ba24c4038e7644da417ddc639169b6ea59122. + +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +4d55397500, Ag Ramesh, Aiden Scandella, Akimasa Kimura, Alex Rothberg, Allen Goodman, +amilioto, Andrei Costinescu, Andrei Nigmatulin, Anjum Sayed, Anthony Platanios, +Anush Elangovan, Armando Fandango, Ashish Kumar Ram, Ashwini Shukla, Ben, Bhavani Subramanian, +Brett Koonce, Carl Thomé, cclauss, Cesc, Changming Sun, Christoph Boeddeker, Clayne Robison, +Clemens Schulz, Clint (Woonhyuk Baek), codrut3, Cole Gerdemann, Colin Raffel, Daniel Trebbien, +Daniel Ylitalo, Daniel Zhang, Daniyar, Darjan Salaj, Dave Maclachlan, David Norman, Dong--Jian, +dongsamb, dssgsra, Edward H, eladweiss, elilienstein, Eric Lilienstein, error.d, Eunji Jeong, fanlu, +Florian Courtial, fo40225, Fred, Gregg Helt, Guozhong Zhuang, Hanchen Li, hsm207, hyunyoung2, +ImSheridan, Ishant Mrinal Haloi, Jacky Ko, Jay Young, Jean Flaherty, Jerome, JerrikEph, Jesse +Kinkead, jfaath, Jian Lin, jinghuangintel, Jiongyan Zhang, Joel Hestness, Joel Shor, Johnny Chan, +Julian Niedermeier, Julian Wolff, JxKing, K-W-W, Karl Lessard, Kasper Marstal, Keiji Ariyama, +Koan-Sin Tan, Loki Der Quaeler, Loo Rong Jie, Luke Schaefer, Lynn Jackson, ManHyuk, Matt Basta, +Matt Smith, Matthew Schulkind, Michael, michaelkhan3, Miguel Piedrafita, Mikalai Drabovich, +Mike Knapp, mjwen, mktozk, Mohamed Aly, Mohammad Ashraf Bhuiyan, Myungjoo Ham, Naman Bhalla, +Namrata-Ibm, Nathan Luehr, nathansilberman, Netzeband, Niranjan Hasabnis, Omar Aflak, Ozge +Yalcinkaya, Parth P Panchal, patrickzzy, Patryk Chrabaszcz, Paul Van Eck, Paweł Kapica, Peng Yu, +Philip Yang, Pierre Blondeau, Po-Hsien Chu, powderluv, Puyu Wang, Rajendra Arora, 
Rasmus, Renat +Idrisov, resec, Robin Richtsfeld, Ronald Eddy Jr, Sahil Singh, Sam Matzek, Sami Kama, sandipmgiri, +Santiago Castro, Sayed Hadi Hashemi, Scott Tseng, Sergii Khomenko, Shahid, Shengpeng Liu, Shreyash +Sharma, Shrinidhi Kl, Simone Cirillo, simsicon, Stanislav Levental, starsblinking, Stephen Lumenta, +Steven Hickson, Su Tang, Taehoon Lee, Takuya Wakisaka, Ted Chang, Ted Ying, Tijmen Verhulsdonck, +Timofey Kondrashov, vade, vaibhav, Valentin Khrulkov, vchigrin, Victor Costan, Viraj Navkal, +Vivek Rane, wagonhelm, Yan Facai (颜发才), Yanbo Liang, Yaroslav Bulatov, yegord, Yong Tang, +Yoni Tsafir, yordun, Yuan (Terry) Tang, Yuxin Wu, zhengdi, Zhengsheng Wei, 田传武 + +# Release 1.5.0 + +## Breaking Changes +* Prebuilt binaries are now built against CUDA 9.0 and cuDNN 7. +* Starting from the 1.6 release, our prebuilt binaries will use AVX instructions. + This may break TF on older CPUs. + +## Known Bugs +* Using XLA:GPU with CUDA 9 and CUDA 9.1 results in garbage results and/or + `CUDA_ERROR_ILLEGAL_ADDRESS` failures. + + Google discovered in mid-December 2017 that the PTX-to-SASS compiler in CUDA 9 + and CUDA 9.1 sometimes does not properly compute the carry bit when + decomposing 64-bit address calculations with large offsets (e.g. `load [x + + large_constant]`) into 32-bit arithmetic in SASS. + + As a result, these versions of `ptxas` miscompile most XLA programs which use + more than 4GB of temp memory. This results in garbage results and/or + `CUDA_ERROR_ILLEGAL_ADDRESS` failures. + + A fix in CUDA 9.1.121 is expected in late February 2018. We do not expect a + fix for CUDA 9.0.x. Until the fix is available, the only workaround is to + [downgrade](https://developer.nvidia.com/cuda-toolkit-archive) to CUDA 8.0.x + or disable XLA:GPU. + + TensorFlow will print a warning if you use XLA:GPU with a known-bad version of + CUDA; see e00ba24c4038e7644da417ddc639169b6ea59122. 
+ +## Major Features And Improvements +* [Eager execution](https://github.com/tensorflow/tensorflow/tree/r1.5/tensorflow/contrib/eager) + preview version is now available. +* [TensorFlow Lite](https://github.com/tensorflow/tensorflow/tree/r1.5/tensorflow/contrib/lite) + dev preview is now available. +* CUDA 9.0 and cuDNN 7 support. +* Accelerated Linear Algebra (XLA): + * Add `complex64` support to XLA compiler. + * `bfloat` support is now added to XLA infrastructure. + * Make `ClusterSpec` propagation work with XLA devices. + * Use a deterministic executor to generate XLA graph. +* `tf.contrib`: + * `tf.contrib.distributions`: + * Add `tf.contrib.distributions.Autoregressive`. + * Make `tf.contrib.distributions` QuadratureCompound classes support batch + * Infer `tf.contrib.distributions.RelaxedOneHotCategorical` `dtype` from arguments. + * Make `tf.contrib.distributions` quadrature family parameterized by + `quadrature_grid_and_prob` vs `quadrature_degree`. + * `auto_correlation` added to `tf.contrib.distributions` + * Add `tf.contrib.bayesflow.layers`, a collection of probabilistic (neural) layers. + * Add `tf.contrib.bayesflow.halton_sequence`. + * Add `tf.contrib.data.make_saveable_from_iterator`. + * Add `tf.contrib.data.shuffle_and_repeat`. + * Add new custom transformation: `tf.contrib.data.scan()`. + * `tf.contrib.distributions.bijectors`: + * Add `tf.contrib.distributions.bijectors.MaskedAutoregressiveFlow`. + * Add `tf.contrib.distributions.bijectors.Permute`. + * Add `tf.contrib.distributions.bijectors.Gumbel`. + * Add `tf.contrib.distributions.bijectors.Reshape`. + * Support shape inference (i.e., shapes containing -1) in the Reshape bijector. +* Add `streaming_precision_recall_at_equal_thresholds`, a method for computing + streaming precision and recall with `O(num_thresholds + size of predictions)` + time and space complexity. 
+* Change `RunConfig` default behavior to not set a random seed, making random + behavior independently random on distributed workers. We expect this to + generally improve training performance. Models that do rely on determinism + should set a random seed explicitly. +* Replaced the implementation of `tf.flags` with `absl.flags`. +* Add support for `CUBLAS_TENSOR_OP_MATH` in fp16 GEMM +* Add support for CUDA on NVIDIA Tegra devices + +## Bug Fixes and Other Changes +* Documentation updates: + * Clarified that you can only install TensorFlow on 64-bit machines. + * Added a short doc explaining how `Estimator`s save checkpoints. + * Add documentation for ops supported by the `tf2xla` bridge. + * Fix minor typos in the doc of `SpaceToDepth` and `DepthToSpace`. + * Updated documentation comments in `mfcc_mel_filterbank.h` and `mfcc.h` to + clarify that the input domain is squared magnitude spectra and the weighting + is done on linear magnitude spectra (sqrt of inputs). + * Change `tf.contrib.distributions` docstring examples to use `tfd` alias + rather than `ds`, `bs`. + * Fix docstring typos in `tf.distributions.bijectors.Bijector`. + * `tf.assert_equal` no longer raises `ValueError`. It now raises + `InvalidArgumentError`, as documented. + * Update Getting Started docs and API intro. +* Google Cloud Storage (GCS): + * Add userspace DNS caching for the GCS client. + * Customize request timeouts for the GCS filesystem. + * Improve GCS filesystem caching. +* Bug Fixes: + * Fix bug where partitioned integer variables got the wrong shapes. Before + this change, all partitions of an integer variable were initialized with + the shape of the unpartitioned variable; after this change they are + initialized correctly. + * Fix correctness bug in CPU and GPU implementations of Adadelta. + * Fix a bug in `import_meta_graph`'s handling of partitioned variables when + importing into a scope. WARNING: This may break loading checkpoints of + graphs with partitioned variables saved after using `import_meta_graph` with + a non-empty `import_scope` argument. + * Fix bug in offline debugger which prevented viewing events. 
+ * Added the `WorkerService.DeleteWorkerSession` method to the gRPC interface, + to fix a memory leak. Ensure that your master and worker servers are running + the same version of TensorFlow to avoid compatibility issues. + * Fix bug in peephole implementation of BlockLSTM cell. + * Fix bug by casting dtype of `log_det_jacobian` to match `log_prob` in + `TransformedDistribution`. + * Ensure `tf.distributions.Multinomial` doesn't underflow in `log_prob`. +* Other: + * Add necessary shape util support for bfloat16. + * Add a way to run ops using a step function to MonitoredSession. + * Add `DenseFlipout` probabilistic layer. + * A new flag `ignore_live_threads` is available on train. If set to `True`, it + will ignore threads that remain running when tearing down infrastructure + after successfully completing training, instead of throwing a RuntimeError. + * Restandardize `DenseVariational` as simpler template for other probabilistic + layers. + * `tf.data` now supports `tf.SparseTensor` components in dataset elements. + * It is now possible to iterate over `Tensor`s. + * Allow `SparseSegmentReduction` ops to have missing segment IDs. + * Modify custom export strategy to account for multidimensional sparse float + splits. + * `Conv2D`, `Conv2DBackpropInput`, `Conv2DBackpropFilter` now support arbitrary + dilations with GPU and cuDNNv6 support. + * `Estimator` now supports `Dataset`: `input_fn` can return a `Dataset` + instead of `Tensor`s. + * Add `RevBlock`, a memory-efficient implementation of reversible residual layers. + * Reduce BFCAllocator internal fragmentation. + * Add `cross_entropy` and `kl_divergence` to `tf.distributions.Distribution`. 
+ * Add `tf.nn.softmax_cross_entropy_with_logits_v2` which enables backprop + w.r.t. the labels. + * GPU back-end now uses `ptxas` to compile generated PTX. + * `BufferAssignment`'s protocol buffer dump is now deterministic. + * Change embedding op to use parallel version of `DynamicStitch`. + * Add support for sparse multidimensional feature columns. + * Speed up the case for sparse float columns that have only 1 value. + * Allow sparse float splits to support multivalent feature columns. + * Add `quantile` to `tf.distributions.TransformedDistribution`. + * Add `NCHW_VECT_C` support for `tf.depth_to_space` on GPU. + * Add `NCHW_VECT_C` support for `tf.space_to_depth` on GPU. + +## API Changes +* Rename `SqueezeDims` attribute to `Axis` in C++ API for Squeeze op. +* `Stream::BlockHostUntilDone` now returns Status rather than bool. +* Minor refactor: move stats files from `stochastic` to `common` and remove + `stochastic`. + +## Known Bugs +* Using XLA:GPU with CUDA 9 and CUDA 9.1 results in garbage results and/or + `CUDA_ILLEGAL_ADDRESS` failures. + + Google discovered in mid-December 2017 that the PTX-to-SASS compiler in CUDA 9 + and CUDA 9.1 sometimes does not properly compute the carry bit when + decomposing 64-bit address calculations with large offsets (e.g. `load [x + + large_constant]`) into 32-bit arithmetic in SASS. + + As a result, these versions of `ptxas` miscompile most XLA programs which use + more than 4GB of temp memory. This results in garbage results and/or + `CUDA_ERROR_ILLEGAL_ADDRESS` failures. + + A fix in CUDA 9.1.121 is expected in late February 2018. We do not expect a + fix for CUDA 9.0.x. Until the fix is available, the only workaround is to + [downgrade](https://developer.nvidia.com/cuda-toolkit-archive) to CUDA 8.0.x + or disable XLA:GPU. + + TensorFlow will print a warning if you use XLA:GPU with a known-bad version of + CUDA; see e00ba24c4038e7644da417ddc639169b6ea59122. 
+ +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +Adam Zahran, Ag Ramesh, Alan Lee, Alan Yee, Alex Sergeev, Alexander, Amir H. Jadidinejad, +Amy, Anastasios Doumoulakis, Andrei Costinescu, Andrei Nigmatulin, Anthony Platanios, +Anush Elangovan, arixlin, Armen Donigian, ArtëM Sobolev, Atlas7, Ben Barsdell, Bill Prin, +Bo Wang, Brett Koonce, Cameron Thomas, Carl Thomé, Cem Eteke, cglewis, Changming Sun, +Charles Shenton, Chi-Hung, Chris Donahue, Chris Filo Gorgolewski, Chris Hoyean Song, +Chris Tava, Christian Grail, Christoph Boeddeker, cinqS, Clayne Robison, codrut3, concerttttt, +CQY, Dan Becker, Dan Jarvis, Daniel Zhang, David Norman, dmaclach, Dmitry Trifonov, +Donggeon Lim, dongpilYu, Dr. Kashif Rasul, Edd Wilder-James, Eric Lv, fcharras, Felix Abecassis, +FirefoxMetzger, formath, FredZhang, Gaojin Cao, Gary Deer, Guenther Schmuelling, Hanchen Li, +Hanmin Qin, hannesa2, hyunyoung2, Ilya Edrenkin, Jackson Kontny, Jan, Javier Luraschi, +Jay Young, Jayaram Bobba, Jeff, Jeff Carpenter, Jeremy Sharpe, Jeroen BéDorf, Jimmy Jia, +Jinze Bai, Jiongyan Zhang, Joe Castagneri, Johan Ju, Josh Varty, Julian Niedermeier, +JxKing, Karl Lessard, Kb Sriram, Keven Wang, Koan-Sin Tan, Kyle Mills, lanhin, LevineHuang, +Loki Der Quaeler, Loo Rong Jie, Luke Iwanski, LáSzló Csomor, Mahdi Abavisani, Mahmoud Abuzaina, +ManHyuk, Marek ŠUppa, MathSquared, Mats Linander, Matt Wytock, Matthew Daley, Maximilian Bachl, +mdymczyk, melvyniandrag, Michael Case, Mike Traynor, miqlas, Namrata-Ibm, Nathan Luehr, +Nathan Van Doorn, Noa Ezra, Nolan Liu, Oleg Zabluda, opensourcemattress, Ouwen Huang, +Paul Van Eck, peisong, Peng Yu, PinkySan, pks, powderluv, Qiao Hai-Jun, Qiao Longfei, +Rajendra Arora, Ralph Tang, resec, Robin Richtsfeld, Rohan Varma, Ryohei Kuroki, SaintNazaire, +Samuel He, Sandeep Dcunha, sandipmgiri, Sang Han, scott, Scott Mudge, Se-Won Kim, Simon Perkins, +Simone Cirillo, Steffen Schmitz, Suvojit Manna, Sylvus, 
Taehoon Lee, Ted Chang, Thomas Deegan, +Till Hoffmann, Tim, Toni Kunic, Toon Verstraelen, Tristan Rice, Urs KöSter, Utkarsh Upadhyay, +Vish (Ishaya) Abrams, Winnie Tsang, Yan Chen, Yan Facai (颜发才), Yi Yang, Yong Tang, +Youssef Hesham, Yuan (Terry) Tang, Zhengsheng Wei, zxcqwe4906, 张志豪, 田传武 + +We are also grateful to all who filed issues or helped resolve them, asked and +answered questions, and were part of inspiring discussions. + +# Release 1.4.1 + +## Bug Fixes and Other Changes +* `LinearClassifier` fix. + +# Release 1.4.0 + +## Major Features And Improvements +* `tf.keras` is now part of the core TensorFlow API. +* [`tf.data`](http://tensorflow.org/programmers_guide/datasets) is now part of + the core TensorFlow API. + * The API is now subject to backwards compatibility guarantees. + # Release 1.4.0 ## Major Features And Improvements @@ -351,7 +654,7 @@ answered questions, and were part of inspiring discussions. * Fixed LIBXSMM integration. * Make decode_jpeg/decode_png/decode_gif handle all formats, since users frequently try to decode an image as the wrong type. * Improve implicit broadcasting lowering. -* Improving stability of GCS/Bigquery clients by a faster retrying of stale transmissions. +* Improving stability of GCS/BigQuery clients by a faster retrying of stale transmissions. * Remove OpKernelConstruction::op_def() as part of minimizing proto dependencies. * VectorLaplaceDiag distribution added. 
* Android demo no longer requires libtensorflow_demo.so to run (libtensorflow_inference.so still required) diff --git a/WORKSPACE b/WORKSPACE index b40913801ba8e3c8ee73f7ba69540b520ad698a6..1e38a9a8cd754886fc5232531816b875de0879a3 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -2,11 +2,11 @@ workspace(name = "org_tensorflow") http_archive( name = "io_bazel_rules_closure", - sha256 = "110fe68753413777944b473c25eed6368c4a0487cee23a7bac1b13cc49d3e257", - strip_prefix = "rules_closure-4af89ef1db659eb41f110df189b67d4cf14073e1", + sha256 = "6691c58a2cd30a86776dd9bb34898b041e37136f2dc7e24cadaeaf599c95c657", + strip_prefix = "rules_closure-08039ba8ca59f64248bb3b6ae016460fe9c9914f", urls = [ - "https://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/4af89ef1db659eb41f110df189b67d4cf14073e1.tar.gz", - "https://github.com/bazelbuild/rules_closure/archive/4af89ef1db659eb41f110df189b67d4cf14073e1.tar.gz", # 2017-08-28 + "https://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/08039ba8ca59f64248bb3b6ae016460fe9c9914f.tar.gz", + "https://github.com/bazelbuild/rules_closure/archive/08039ba8ca59f64248bb3b6ae016460fe9c9914f.tar.gz", # 2018-01-16 ], ) @@ -41,12 +41,12 @@ load("//tensorflow:workspace.bzl", "tf_workspace") tf_workspace() new_http_archive( - name = "inception5h", + name = "inception_v1", build_file = "models.BUILD", - sha256 = "d13569f6a98159de37e92e9c8ec4dae8f674fbf475f69fe6199b514f756d4364", + sha256 = "7efe12a8363f09bc24d7b7a450304a15655a57a7751929b2c1593a71183bb105", urls = [ - "http://storage.googleapis.com/download.tensorflow.org/models/inception5h.zip", - "http://download.tensorflow.org/models/inception5h.zip", + "http://storage.googleapis.com/download.tensorflow.org/models/inception_v1.zip", + "http://download.tensorflow.org/models/inception_v1.zip", ], ) diff --git a/configure.py b/configure.py index cf562bdee8ef288e4c2938f50e5c6366ce05ccff..3aa1a3e956c6a559b89cdeb593a96a95188c32ae 100644 --- a/configure.py +++ b/configure.py @@ -34,16 
+34,26 @@ except ImportError: _TF_BAZELRC = os.path.join(os.path.dirname(os.path.abspath(__file__)), '.tf_configure.bazelrc') -_DEFAULT_CUDA_VERSION = '8.0' -_DEFAULT_CUDNN_VERSION = '6' +_TF_WORKSPACE = os.path.join(os.path.dirname(os.path.abspath(__file__)), + 'WORKSPACE') +_DEFAULT_CUDA_VERSION = '9.0' +_DEFAULT_CUDNN_VERSION = '7' _DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,5.2' _DEFAULT_CUDA_PATH = '/usr/local/cuda' _DEFAULT_CUDA_PATH_LINUX = '/opt/cuda' _DEFAULT_CUDA_PATH_WIN = ('C:/Program Files/NVIDIA GPU Computing ' 'Toolkit/CUDA/v%s' % _DEFAULT_CUDA_VERSION) +_DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/x86_64-linux-gnu' _TF_OPENCL_VERSION = '1.2' _DEFAULT_COMPUTECPP_TOOLKIT_PATH = '/usr/local/computecpp' _DEFAULT_TRISYCL_INCLUDE_DIR = '/usr/local/triSYCL/include' +_SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15] + +_DEFAULT_PROMPT_ASK_ATTEMPTS = 10 + + +class UserInputError(Exception): + pass def is_windows(): @@ -158,7 +168,7 @@ def get_python_path(environ_cp, python_bin_path): try: library_paths = run_shell( [python_bin_path, '-c', - 'import site; print("\\n".join(site.getsitepackages()))']).split("\n") + 'import site; print("\\n".join(site.getsitepackages()))']).split('\n') except subprocess.CalledProcessError: library_paths = [run_shell( [python_bin_path, '-c', @@ -256,19 +266,6 @@ def reset_tf_configure_bazelrc(): f.write('import %workspace%/.tf_configure.bazelrc\n') -def run_gen_git_source(environ_cp): - """Run the gen_git_source to create links. - - The links are for bazel to track dependencies for git hash propagation. - - Args: - environ_cp: copy of the os.environ. - """ - cmd = '"%s" tensorflow/tools/git/gen_git_source.py --configure %s' % ( - environ_cp.get('PYTHON_BIN_PATH'), os.getcwd()) - os.system(cmd) - - def cleanup_makefile(): """Delete any leftover BUILD files from the Makefile build. @@ -301,11 +298,17 @@ def get_var(environ_cp, System". enabled_by_default: boolean for default behavior. 
question: optional string for how to ask for user input. - yes_reply: optionanl string for reply when feature is enabled. + yes_reply: optional string for reply when feature is enabled. no_reply: optional string for reply when feature is disabled. Returns: boolean value of the variable. + + Raises: + UserInputError: if an environment variable is set, but it cannot be + interpreted as a boolean indicator, assume that the user has made a + scripting error, and will continue to provide invalid input. + Raise the error to avoid infinitely looping. """ if not question: question = 'Do you wish to build TensorFlow with %s support?' % query_item @@ -323,6 +326,23 @@ def get_var(environ_cp, question += ' [y/N]: ' var = environ_cp.get(var_name) + if var is not None: + var_content = var.strip().lower() + true_strings = ('1', 't', 'true', 'y', 'yes') + false_strings = ('0', 'f', 'false', 'n', 'no') + if var_content in true_strings: + var = True + elif var_content in false_strings: + var = False + else: + raise UserInputError( + 'Environment variable %s must be set as a boolean indicator.\n' + 'The following are accepted as TRUE : %s.\n' + 'The following are accepted as FALSE: %s.\n' + 'Current value is %s.' % ( + var_name, ', '.join(true_strings), ', '.join(false_strings), + var)) + while var is None: user_input_origin = get_input(question) user_input = user_input_origin.strip().lower() @@ -391,7 +411,7 @@ def set_action_env_var(environ_cp, System". enabled_by_default: boolean for default behavior. question: optional string for how to ask for user input. - yes_reply: optionanl string for reply when feature is enabled. + yes_reply: optional string for reply when feature is enabled. no_reply: optional string for reply when feature is disabled. """ var = int( @@ -425,7 +445,7 @@ def convert_version_to_int(version): def check_bazel_version(min_version): - """Check installed bezel version is at least min_version. + """Check installed bazel version is at least min_version. 
Args: min_version: string for minimum bazel version. @@ -509,6 +529,21 @@ def set_tf_cuda_clang(environ_cp): no_reply=no_reply) +def set_tf_download_clang(environ_cp): + """Set TF_DOWNLOAD_CLANG action_env.""" + question = 'Do you want to download a fresh release of clang? (Experimental)' + yes_reply = 'Clang will be downloaded and used to compile tensorflow.' + no_reply = 'Clang will not be downloaded.' + set_action_env_var( + environ_cp, + 'TF_DOWNLOAD_CLANG', + None, + False, + question=question, + yes_reply=yes_reply, + no_reply=no_reply) + + def get_from_env_or_user_or_default(environ_cp, var_name, ask_for_var, var_default): """Get var_name either from env, or user or default. @@ -557,6 +592,219 @@ def set_clang_cuda_compiler_path(environ_cp): clang_cuda_compiler_path) +def prompt_loop_or_load_from_env( + environ_cp, + var_name, + var_default, + ask_for_var, + check_success, + error_msg, + suppress_default_error=False, + n_ask_attempts=_DEFAULT_PROMPT_ASK_ATTEMPTS +): + """Loop over user prompts for an ENV param until receiving a valid response. + + For the env param var_name, read from the environment or verify user input + until receiving valid input. When done, set var_name in the environ_cp to its + new value. + + Args: + environ_cp: (Dict) copy of the os.environ. + var_name: (String) string for name of environment variable, e.g. "TF_MYVAR". + var_default: (String) default value string. + ask_for_var: (String) string for how to ask for user input. + check_success: (Function) function that takes one argument and returns a + boolean. Should return True if the value provided is considered valid. May + contain a complex error message if error_msg does not provide enough + information. In that case, set suppress_default_error to True. + error_msg: (String) String with one and only one '%s'. Formatted with each + invalid response upon check_success(input) failure. 
+ suppress_default_error: (Bool) Suppress the above error message in favor of + one from the check_success function. + n_ask_attempts: (Integer) Number of times to query for valid input before + raising an error and quitting. + + Returns: + [String] The value of var_name after querying for input. + + Raises: + UserInputError: if a query has been attempted n_ask_attempts times without + success, assume that the user has made a scripting error, and will + continue to provide invalid input. Raise the error to avoid infinitely + looping. + """ + default = environ_cp.get(var_name) or var_default + full_query = '%s [Default is %s]: ' % ( + ask_for_var, + default, + ) + + for _ in range(n_ask_attempts): + val = get_from_env_or_user_or_default(environ_cp, + var_name, + full_query, + default) + if check_success(val): + break + if not suppress_default_error: + print(error_msg % val) + environ_cp[var_name] = '' + else: + raise UserInputError('Invalid %s setting was provided %d times in a row. ' + 'Assuming to be a scripting mistake.' 
% + (var_name, n_ask_attempts)) + + environ_cp[var_name] = val + return val + + +def create_android_ndk_rule(environ_cp): + """Set ANDROID_NDK_HOME and write Android NDK WORKSPACE rule.""" + if is_windows() or is_cygwin(): + default_ndk_path = cygpath('%s/Android/Sdk/ndk-bundle' % + environ_cp['APPDATA']) + elif is_macos(): + default_ndk_path = '%s/library/Android/Sdk/ndk-bundle' % environ_cp['HOME'] + else: + default_ndk_path = '%s/Android/Sdk/ndk-bundle' % environ_cp['HOME'] + + def valid_ndk_path(path): + return (os.path.exists(path) and + os.path.exists(os.path.join(path, 'source.properties'))) + + android_ndk_home_path = prompt_loop_or_load_from_env( + environ_cp, + var_name='ANDROID_NDK_HOME', + var_default=default_ndk_path, + ask_for_var='Please specify the home path of the Android NDK to use.', + check_success=valid_ndk_path, + error_msg=('The path %s or its child file "source.properties" ' + 'does not exist.') + ) + + write_android_ndk_workspace_rule(android_ndk_home_path) + + +def create_android_sdk_rule(environ_cp): + """Set Android variables and write Android SDK WORKSPACE rule.""" + if is_windows() or is_cygwin(): + default_sdk_path = cygpath('%s/Android/Sdk' % environ_cp['APPDATA']) + elif is_macos(): + default_sdk_path = '%s/library/Android/Sdk/ndk-bundle' % environ_cp['HOME'] + else: + default_sdk_path = '%s/Android/Sdk' % environ_cp['HOME'] + + def valid_sdk_path(path): + return (os.path.exists(path) and + os.path.exists(os.path.join(path, 'platforms')) and + os.path.exists(os.path.join(path, 'build-tools'))) + + android_sdk_home_path = prompt_loop_or_load_from_env( + environ_cp, + var_name='ANDROID_SDK_HOME', + var_default=default_sdk_path, + ask_for_var='Please specify the home path of the Android SDK to use.', + check_success=valid_sdk_path, + error_msg=('Either %s does not exist, or it does not contain the ' + 'subdirectories "platforms" and "build-tools".')) + + platforms = os.path.join(android_sdk_home_path, 'platforms') + api_levels = 
sorted(os.listdir(platforms)) + api_levels = [x.replace('android-', '') for x in api_levels] + + def valid_api_level(api_level): + return os.path.exists(os.path.join(android_sdk_home_path, + 'platforms', + 'android-' + api_level)) + + android_api_level = prompt_loop_or_load_from_env( + environ_cp, + var_name='ANDROID_API_LEVEL', + var_default=api_levels[-1], + ask_for_var=('Please specify the Android SDK API level to use. ' + '[Available levels: %s]') % api_levels, + check_success=valid_api_level, + error_msg='Android-%s is not present in the SDK path.') + + build_tools = os.path.join(android_sdk_home_path, 'build-tools') + versions = sorted(os.listdir(build_tools)) + + def valid_build_tools(version): + return os.path.exists(os.path.join(android_sdk_home_path, + 'build-tools', + version)) + + android_build_tools_version = prompt_loop_or_load_from_env( + environ_cp, + var_name='ANDROID_BUILD_TOOLS_VERSION', + var_default=versions[-1], + ask_for_var=('Please specify an Android build tools version to use. 
' + '[Available versions: %s]') % versions, + check_success=valid_build_tools, + error_msg=('The selected SDK does not have build-tools version %s ' + 'available.')) + + write_android_sdk_workspace_rule(android_sdk_home_path, + android_build_tools_version, + android_api_level) + + +def write_android_sdk_workspace_rule(android_sdk_home_path, + android_build_tools_version, + android_api_level): + print('Writing android_sdk_workspace rule.\n') + with open(_TF_WORKSPACE, 'a') as f: + f.write(""" +android_sdk_repository( + name="androidsdk", + api_level=%s, + path="%s", + build_tools_version="%s")\n +""" % (android_api_level, android_sdk_home_path, android_build_tools_version)) + + +def write_android_ndk_workspace_rule(android_ndk_home_path): + print('Writing android_ndk_workspace rule.') + ndk_api_level = check_ndk_level(android_ndk_home_path) + if int(ndk_api_level) not in _SUPPORTED_ANDROID_NDK_VERSIONS: + print('WARNING: The API level of the NDK in %s is %s, which is not ' + 'supported by Bazel (officially supported versions: %s). Please use ' + 'another version. 
Compiling Android targets may result in confusing ' + 'errors.\n' % (android_ndk_home_path, ndk_api_level, + _SUPPORTED_ANDROID_NDK_VERSIONS)) + with open(_TF_WORKSPACE, 'a') as f: + f.write(""" +android_ndk_repository( + name="androidndk", + path="%s", + api_level=%s)\n +""" % (android_ndk_home_path, ndk_api_level)) + + +def check_ndk_level(android_ndk_home_path): + """Check the revision number of an Android NDK path.""" + properties_path = '%s/source.properties' % android_ndk_home_path + if is_windows() or is_cygwin(): + properties_path = cygpath(properties_path) + with open(properties_path, 'r') as f: + filedata = f.read() + + revision = re.search(r'Pkg.Revision = (\d+)', filedata) + if revision: + return revision.group(1) + return None + + +def workspace_has_any_android_rule(): + """Check the WORKSPACE for existing android_*_repository rules.""" + with open(_TF_WORKSPACE, 'r') as f: + workspace = f.read() + has_any_rule = re.search(r'^android_[ns]dk_repository', + workspace, + re.MULTILINE) + return has_any_rule + + def set_gcc_host_compiler_path(environ_cp): """Set GCC_HOST_COMPILER_PATH.""" default_gcc_host_compiler_path = which('gcc') or '' @@ -566,24 +814,39 @@ def set_gcc_host_compiler_path(environ_cp): # os.readlink is only available in linux default_gcc_host_compiler_path = os.path.realpath(cuda_bin_symlink) - ask_gcc_path = ( - 'Please specify which gcc should be used by nvcc as the ' - 'host compiler. [Default is %s]: ') % default_gcc_host_compiler_path - while True: - gcc_host_compiler_path = get_from_env_or_user_or_default( - environ_cp, 'GCC_HOST_COMPILER_PATH', ask_gcc_path, - default_gcc_host_compiler_path) + gcc_host_compiler_path = prompt_loop_or_load_from_env( + environ_cp, + var_name='GCC_HOST_COMPILER_PATH', + var_default=default_gcc_host_compiler_path, + ask_for_var= + 'Please specify which gcc should be used by nvcc as the host compiler.', + check_success=os.path.exists, + error_msg='Invalid gcc path. 
%s cannot be found.', + ) - if os.path.exists(gcc_host_compiler_path): - break + write_action_env_to_bazelrc('GCC_HOST_COMPILER_PATH', gcc_host_compiler_path) - # Reset and retry - print('Invalid gcc path. %s cannot be found' % gcc_host_compiler_path) - environ_cp['GCC_HOST_COMPILER_PATH'] = '' - # Set GCC_HOST_COMPILER_PATH - environ_cp['GCC_HOST_COMPILER_PATH'] = gcc_host_compiler_path - write_action_env_to_bazelrc('GCC_HOST_COMPILER_PATH', gcc_host_compiler_path) +def reformat_version_sequence(version_str, sequence_count): + """Reformat the version string to have the given number of sequences. + + For example: + Given (7, 2) -> 7.0 + (7.0.1, 2) -> 7.0 + (5, 1) -> 5 + (5.0.3.2, 1) -> 5 + + Args: + version_str: String, the version string. + sequence_count: int, an integer. + Returns: + string, reformatted version string. + """ + v = version_str.split('.') + if len(v) < sequence_count: + v = v + (['0'] * (sequence_count - len(v))) + + return '.'.join(v[:sequence_count]) def set_tf_cuda_version(environ_cp): @@ -592,10 +855,11 @@ def set_tf_cuda_version(environ_cp): 'Please specify the CUDA SDK version you want to use, ' 'e.g. 7.0. [Leave empty to default to CUDA %s]: ') % _DEFAULT_CUDA_VERSION - while True: + for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): # Configure the Cuda SDK version to use. tf_cuda_version = get_from_env_or_user_or_default( environ_cp, 'TF_CUDA_VERSION', ask_cuda_version, _DEFAULT_CUDA_VERSION) + tf_cuda_version = reformat_version_sequence(str(tf_cuda_version), 2) # Find out where the CUDA toolkit is installed default_cuda_path = _DEFAULT_CUDA_PATH @@ -630,6 +894,11 @@ def set_tf_cuda_version(environ_cp): environ_cp['TF_CUDA_VERSION'] = '' environ_cp['CUDA_TOOLKIT_PATH'] = '' + else: + raise UserInputError('Invalid TF_CUDA_SETTING setting was provided %d ' + 'times in a row. Assuming to be a scripting mistake.' 
% + _DEFAULT_PROMPT_ASK_ATTEMPTS) + # Set CUDA_TOOLKIT_PATH and TF_CUDA_VERSION environ_cp['CUDA_TOOLKIT_PATH'] = cuda_toolkit_path write_action_env_to_bazelrc('CUDA_TOOLKIT_PATH', cuda_toolkit_path) @@ -643,10 +912,11 @@ def set_tf_cudnn_version(environ_cp): 'Please specify the cuDNN version you want to use. ' '[Leave empty to default to cuDNN %s.0]: ') % _DEFAULT_CUDNN_VERSION - while True: + for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): tf_cudnn_version = get_from_env_or_user_or_default( environ_cp, 'TF_CUDNN_VERSION', ask_cudnn_version, _DEFAULT_CUDNN_VERSION) + tf_cudnn_version = reformat_version_sequence(str(tf_cudnn_version) ,1) default_cudnn_path = environ_cp.get('CUDA_TOOLKIT_PATH') ask_cudnn_path = (r'Please specify the location where cuDNN %s library is ' @@ -702,6 +972,10 @@ def set_tf_cudnn_version(environ_cp): print('%s.%s' % (cudnn_path_from_ldconfig, tf_cudnn_version)) environ_cp['TF_CUDNN_VERSION'] = '' + else: + raise UserInputError('Invalid TF_CUDNN setting was provided %d ' + 'times in a row. Assuming to be a scripting mistake.' % + _DEFAULT_PROMPT_ASK_ATTEMPTS) # Set CUDNN_INSTALL_PATH and TF_CUDNN_VERSION environ_cp['CUDNN_INSTALL_PATH'] = cudnn_install_path @@ -710,6 +984,128 @@ def set_tf_cudnn_version(environ_cp): write_action_env_to_bazelrc('TF_CUDNN_VERSION', tf_cudnn_version) +def set_tf_tensorrt_install_path(environ_cp): + """Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION. + + Adapted from code contributed by Sami Kama (https://github.com/samikama). + + Args: + environ_cp: copy of the os.environ. + + Raises: + ValueError: if this method was called under non-Linux platform. + UserInputError: if user has provided invalid input multiple times. + """ + if not is_linux(): + raise ValueError('Currently TensorRT is only supported on Linux platform.') + + # Ask user whether to add TensorRT support. 
+ if str(int(get_var( + environ_cp, 'TF_NEED_TENSORRT', 'TensorRT', False))) != '1': + return + + for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): + ask_tensorrt_path = (r'Please specify the location where TensorRT is ' + 'installed. [Default is %s]:') % ( + _DEFAULT_TENSORRT_PATH_LINUX) + trt_install_path = get_from_env_or_user_or_default( + environ_cp, 'TENSORRT_INSTALL_PATH', ask_tensorrt_path, + _DEFAULT_TENSORRT_PATH_LINUX) + + # Result returned from "read" will be used unexpanded. That make "~" + # unusable. Going through one more level of expansion to handle that. + trt_install_path = os.path.realpath( + os.path.expanduser(trt_install_path)) + + def find_libs(search_path): + """Search for libnvinfer.so in "search_path".""" + fl = set() + if os.path.exists(search_path) and os.path.isdir(search_path): + fl.update([os.path.realpath(os.path.join(search_path, x)) + for x in os.listdir(search_path) if 'libnvinfer.so' in x]) + return fl + + possible_files = find_libs(trt_install_path) + possible_files.update(find_libs(os.path.join(trt_install_path, 'lib'))) + possible_files.update(find_libs(os.path.join(trt_install_path, 'lib64'))) + + def is_compatible(tensorrt_lib, cuda_ver, cudnn_ver): + """Check the compatibility between tensorrt and cudnn/cudart libraries.""" + ldd_bin = which('ldd') or '/usr/bin/ldd' + ldd_out = run_shell([ldd_bin, tensorrt_lib]).split(os.linesep) + cudnn_pattern = re.compile('.*libcudnn.so\\.?(.*) =>.*$') + cuda_pattern = re.compile('.*libcudart.so\\.?(.*) =>.*$') + cudnn = None + cudart = None + for line in ldd_out: + if 'libcudnn.so' in line: + cudnn = cudnn_pattern.search(line) + elif 'libcudart.so' in line: + cudart = cuda_pattern.search(line) + if cudnn and len(cudnn.group(1)): + cudnn = convert_version_to_int(cudnn.group(1)) + if cudart and len(cudart.group(1)): + cudart = convert_version_to_int(cudart.group(1)) + return (cudnn == cudnn_ver) and (cudart == cuda_ver) + + cuda_ver = convert_version_to_int(environ_cp['TF_CUDA_VERSION']) + 
cudnn_ver = convert_version_to_int(environ_cp['TF_CUDNN_VERSION']) + nvinfer_pattern = re.compile('.*libnvinfer.so.?(.*)$') + highest_ver = [0, None, None] + + for lib_file in possible_files: + if is_compatible(lib_file, cuda_ver, cudnn_ver): + ver_str = nvinfer_pattern.search(lib_file).group(1) + ver = convert_version_to_int(ver_str) if len(ver_str) else 0 + if ver > highest_ver[0]: + highest_ver = [ver, ver_str, lib_file] + if highest_ver[1] is not None: + trt_install_path = os.path.dirname(highest_ver[2]) + tf_tensorrt_version = highest_ver[1] + break + + # Try another alternative from ldconfig. + ldconfig_bin = which('ldconfig') or '/sbin/ldconfig' + ldconfig_output = run_shell([ldconfig_bin, '-p']) + search_result = re.search( + '.*libnvinfer.so\\.?([0-9.]*).* => (.*)', ldconfig_output) + if search_result: + libnvinfer_path_from_ldconfig = search_result.group(2) + if os.path.exists(libnvinfer_path_from_ldconfig): + if is_compatible(libnvinfer_path_from_ldconfig, cuda_ver, cudnn_ver): + trt_install_path = os.path.dirname(libnvinfer_path_from_ldconfig) + tf_tensorrt_version = search_result.group(1) + break + + # Reset and Retry + if len(possible_files): + print('TensorRT libraries found in one the following directories', + 'are not compatible with selected cuda and cudnn installations') + print(trt_install_path) + print(os.path.join(trt_install_path, 'lib')) + print(os.path.join(trt_install_path, 'lib64')) + if search_result: + print(libnvinfer_path_from_ldconfig) + else: + print('Invalid path to TensorRT. None of the following files can be found:') + print(trt_install_path) + print(os.path.join(trt_install_path, 'lib')) + print(os.path.join(trt_install_path, 'lib64')) + if search_result: + print(libnvinfer_path_from_ldconfig) + + else: + raise UserInputError('Invalid TF_TENSORRT setting was provided %d ' + 'times in a row. Assuming to be a scripting mistake.' 
% + _DEFAULT_PROMPT_ASK_ATTEMPTS) + + # Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION + environ_cp['TENSORRT_INSTALL_PATH'] = trt_install_path + write_action_env_to_bazelrc('TENSORRT_INSTALL_PATH', trt_install_path) + environ_cp['TF_TENSORRT_VERSION'] = tf_tensorrt_version + write_action_env_to_bazelrc('TF_TENSORRT_VERSION', tf_tensorrt_version) + + def get_native_cuda_compute_capabilities(environ_cp): """Get native cuda compute capabilities. @@ -810,90 +1206,83 @@ def set_other_cuda_vars(environ_cp): def set_host_cxx_compiler(environ_cp): """Set HOST_CXX_COMPILER.""" default_cxx_host_compiler = which('g++') or '' - ask_cxx_host_compiler = ( - 'Please specify which C++ compiler should be used as' - ' the host C++ compiler. [Default is %s]: ') % default_cxx_host_compiler - - while True: - host_cxx_compiler = get_from_env_or_user_or_default( - environ_cp, 'HOST_CXX_COMPILER', ask_cxx_host_compiler, - default_cxx_host_compiler) - if os.path.exists(host_cxx_compiler): - break - # Reset and retry - print('Invalid C++ compiler path. %s cannot be found' % host_cxx_compiler) - environ_cp['HOST_CXX_COMPILER'] = '' + host_cxx_compiler = prompt_loop_or_load_from_env( + environ_cp, + var_name='HOST_CXX_COMPILER', + var_default=default_cxx_host_compiler, + ask_for_var=('Please specify which C++ compiler should be used as the ' + 'host C++ compiler.'), + check_success=os.path.exists, + error_msg='Invalid C++ compiler path. %s cannot be found.', + ) - # Set HOST_CXX_COMPILER - environ_cp['HOST_CXX_COMPILER'] = host_cxx_compiler write_action_env_to_bazelrc('HOST_CXX_COMPILER', host_cxx_compiler) def set_host_c_compiler(environ_cp): """Set HOST_C_COMPILER.""" default_c_host_compiler = which('gcc') or '' - ask_c_host_compiler = ( - 'Please specify which C compiler should be used as the' - ' host C compiler. 
[Default is %s]: ') % default_c_host_compiler - - while True: - host_c_compiler = get_from_env_or_user_or_default( - environ_cp, 'HOST_C_COMPILER', ask_c_host_compiler, - default_c_host_compiler) - if os.path.exists(host_c_compiler): - break - # Reset and retry - print('Invalid C compiler path. %s cannot be found' % host_c_compiler) - environ_cp['HOST_C_COMPILER'] = '' + host_c_compiler = prompt_loop_or_load_from_env( + environ_cp, + var_name='HOST_C_COMPILER', + var_default=default_c_host_compiler, + ask_for_var=('Please specify which C compiler should be used as the host' + 'C compiler.'), + check_success=os.path.exists, + error_msg='Invalid C compiler path. %s cannot be found.', + ) - # Set HOST_C_COMPILER - environ_cp['HOST_C_COMPILER'] = host_c_compiler write_action_env_to_bazelrc('HOST_C_COMPILER', host_c_compiler) def set_computecpp_toolkit_path(environ_cp): """Set COMPUTECPP_TOOLKIT_PATH.""" - ask_computecpp_toolkit_path = ('Please specify the location where ComputeCpp ' - 'for SYCL %s is installed. [Default is %s]: ' - ) % (_TF_OPENCL_VERSION, - _DEFAULT_COMPUTECPP_TOOLKIT_PATH) - while True: - computecpp_toolkit_path = get_from_env_or_user_or_default( - environ_cp, 'COMPUTECPP_TOOLKIT_PATH', ask_computecpp_toolkit_path, - _DEFAULT_COMPUTECPP_TOOLKIT_PATH) + def toolkit_exists(toolkit_path): + """Check if a computecpp toolkit path is valid.""" if is_linux(): sycl_rt_lib_path = 'lib/libComputeCpp.so' else: sycl_rt_lib_path = '' - sycl_rt_lib_path_full = os.path.join(computecpp_toolkit_path, + sycl_rt_lib_path_full = os.path.join(toolkit_path, sycl_rt_lib_path) - if os.path.exists(sycl_rt_lib_path_full): - break + exists = os.path.exists(sycl_rt_lib_path_full) + if not exists: + print('Invalid SYCL %s library path. %s cannot be found' % + (_TF_OPENCL_VERSION, sycl_rt_lib_path_full)) + return exists - print('Invalid SYCL %s library path. 
%s cannot be found' % - (_TF_OPENCL_VERSION, sycl_rt_lib_path_full)) - environ_cp['COMPUTECPP_TOOLKIT_PATH'] = '' + computecpp_toolkit_path = prompt_loop_or_load_from_env( + environ_cp, + var_name='COMPUTECPP_TOOLKIT_PATH', + var_default=_DEFAULT_COMPUTECPP_TOOLKIT_PATH, + ask_for_var=( + 'Please specify the location where ComputeCpp for SYCL %s is ' + 'installed.' % _TF_OPENCL_VERSION), + check_success=toolkit_exists, + error_msg='Invalid SYCL compiler path. %s cannot be found.', + suppress_default_error=True) - # Set COMPUTECPP_TOOLKIT_PATH - environ_cp['COMPUTECPP_TOOLKIT_PATH'] = computecpp_toolkit_path write_action_env_to_bazelrc('COMPUTECPP_TOOLKIT_PATH', computecpp_toolkit_path) + def set_trisycl_include_dir(environ_cp): - """Set TRISYCL_INCLUDE_DIR""" + """Set TRISYCL_INCLUDE_DIR.""" + ask_trisycl_include_dir = ('Please specify the location of the triSYCL ' 'include directory. (Use --config=sycl_trisycl ' 'when building with Bazel) ' '[Default is %s]: ' - ) % (_DEFAULT_TRISYCL_INCLUDE_DIR) + ) % (_DEFAULT_TRISYCL_INCLUDE_DIR) + while True: trisycl_include_dir = get_from_env_or_user_or_default( - environ_cp, 'TRISYCL_INCLUDE_DIR', ask_trisycl_include_dir, - _DEFAULT_TRISYCL_INCLUDE_DIR) + environ_cp, 'TRISYCL_INCLUDE_DIR', ask_trisycl_include_dir, + _DEFAULT_TRISYCL_INCLUDE_DIR) if os.path.exists(trisycl_include_dir): break @@ -905,50 +1294,30 @@ def set_trisycl_include_dir(environ_cp): write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR', trisycl_include_dir) -def set_trisycl_include_dir(environ_cp): - """Set TRISYCL_INCLUDE_DIR.""" - ask_trisycl_include_dir = ('Please specify the location of the triSYCL ' - 'include directory. 
(Use --config=sycl_trisycl ' - 'when building with Bazel) ' - '[Default is %s]: ') % ( - _DEFAULT_TRISYCL_INCLUDE_DIR) - while True: - trisycl_include_dir = get_from_env_or_user_or_default( - environ_cp, 'TRISYCL_INCLUDE_DIR', ask_trisycl_include_dir, - _DEFAULT_TRISYCL_INCLUDE_DIR) - if os.path.exists(trisycl_include_dir): - break - - print('Invalid triSYCL include directory, %s cannot be found' % - (trisycl_include_dir)) - - # Set TRISYCL_INCLUDE_DIR - environ_cp['TRISYCL_INCLUDE_DIR'] = trisycl_include_dir - write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR', trisycl_include_dir) - def set_mpi_home(environ_cp): """Set MPI_HOME.""" + default_mpi_home = which('mpirun') or which('mpiexec') or '' default_mpi_home = os.path.dirname(os.path.dirname(default_mpi_home)) - ask_mpi_home = ('Please specify the MPI toolkit folder. [Default is %s]: ' - ) % default_mpi_home - while True: - mpi_home = get_from_env_or_user_or_default(environ_cp, 'MPI_HOME', - ask_mpi_home, default_mpi_home) - - if os.path.exists(os.path.join(mpi_home, 'include')) and os.path.exists( - os.path.join(mpi_home, 'lib')): - break - - print('Invalid path to the MPI Toolkit. %s or %s cannot be found' % - (os.path.join(mpi_home, 'include'), - os.path.exists(os.path.join(mpi_home, 'lib')))) - environ_cp['MPI_HOME'] = '' + def valid_mpi_path(mpi_home): + exists = (os.path.exists(os.path.join(mpi_home, 'include')) and + os.path.exists(os.path.join(mpi_home, 'lib'))) + if not exists: + print('Invalid path to the MPI Toolkit. 
%s or %s cannot be found' % + (os.path.join(mpi_home, 'include'), + os.path.exists(os.path.join(mpi_home, 'lib')))) + return exists - # Set MPI_HOME - environ_cp['MPI_HOME'] = str(mpi_home) + _ = prompt_loop_or_load_from_env( + environ_cp, + var_name='MPI_HOME', + var_default=default_mpi_home, + ask_for_var='Please specify the MPI toolkit folder.', + check_success=valid_mpi_path, + error_msg='', + suppress_default_error=True) def set_other_mpi_vars(environ_cp): @@ -983,47 +1352,25 @@ def set_other_mpi_vars(environ_cp): raise ValueError('Cannot find the MPI library file in %s/lib' % mpi_home) -def set_mkl(): - write_to_bazelrc('build:mkl --define using_mkl=true') - write_to_bazelrc('build:mkl -c opt') - print( - 'Add "--config=mkl" to your bazel command to build with MKL ' - 'support.\nPlease note that MKL on MacOS or windows is still not ' - 'supported.\nIf you would like to use a local MKL instead of ' - 'downloading, please set the environment variable \"TF_MKL_ROOT\" every ' - 'time before build.') - - -def set_monolithic(): - # Add --config=monolithic to your bazel command to use a mostly-static - # build and disable modular op registration support (this will revert to - # loading TensorFlow with RTLD_GLOBAL in Python). By default (without - # --config=monolithic), TensorFlow will build with a dependence on - # //tensorflow:libtensorflow_framework.so. - write_to_bazelrc('build:monolithic --define framework_shared_object=false') - # For projects which use TensorFlow as part of a Bazel build process, putting - # nothing in a bazelrc will default to a monolithic build. 
The following line - # opts in to modular op registration support by default: - write_to_bazelrc('build --define framework_shared_object=true') - - -def create_android_bazelrc_configs(): - # Flags for --config=android - write_to_bazelrc('build:android --crosstool_top=//external:android/crosstool') - write_to_bazelrc( - 'build:android --host_crosstool_top=@bazel_tools//tools/cpp:toolchain') - # Flags for --config=android_arm - write_to_bazelrc('build:android_arm --config=android') - write_to_bazelrc('build:android_arm --cpu=armeabi-v7a') - # Flags for --config=android_arm64 - write_to_bazelrc('build:android_arm64 --config=android') - write_to_bazelrc('build:android_arm64 --cpu=arm64-v8a') - - def set_grpc_build_flags(): write_to_bazelrc('build --define grpc_no_ares=true') +def set_windows_build_flags(): + if is_windows(): + # The non-monolithic build is not supported yet + write_to_bazelrc('build --config monolithic') + # Suppress warning messages + write_to_bazelrc('build --copt=-w --host_copt=-w') + # Output more verbose information when something goes wrong + write_to_bazelrc('build --verbose_failures') + + +def config_info_line(name, help_text): + """Helper function to print formatted help text for Bazel config options.""" + print('\t--config=%-12s\t# %s' % (name, help_text)) + + def main(): # Make a copy of os.environ to be clear when functions and getting and setting # environment variables. 
@@ -1034,20 +1381,22 @@ def main(): reset_tf_configure_bazelrc() cleanup_makefile() setup_python(environ_cp) - run_gen_git_source(environ_cp) if is_windows(): environ_cp['TF_NEED_S3'] = '0' environ_cp['TF_NEED_GCP'] = '0' environ_cp['TF_NEED_HDFS'] = '0' environ_cp['TF_NEED_JEMALLOC'] = '0' + environ_cp['TF_NEED_KAFKA'] = '0' environ_cp['TF_NEED_OPENCL_SYCL'] = '0' environ_cp['TF_NEED_COMPUTECPP'] = '0' environ_cp['TF_NEED_OPENCL'] = '0' environ_cp['TF_CUDA_CLANG'] = '0' + environ_cp['TF_NEED_TENSORRT'] = '0' if is_macos(): environ_cp['TF_NEED_JEMALLOC'] = '0' + environ_cp['TF_NEED_TENSORRT'] = '0' set_build_var(environ_cp, 'TF_NEED_JEMALLOC', 'jemalloc as malloc', 'with_jemalloc', True) @@ -1057,6 +1406,8 @@ def main(): 'with_hdfs_support', True, 'hdfs') set_build_var(environ_cp, 'TF_NEED_S3', 'Amazon S3 File System', 'with_s3_support', True, 's3') + set_build_var(environ_cp, 'TF_NEED_KAFKA', 'Apache Kafka Platform', + 'with_kafka_support', False, 'kafka') set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support', False, 'xla') set_build_var(environ_cp, 'TF_NEED_GDR', 'GDR', 'with_gdr_support', @@ -1079,12 +1430,27 @@ def main(): 'TF_CUDA_CONFIG_REPO' not in environ_cp): set_tf_cuda_version(environ_cp) set_tf_cudnn_version(environ_cp) + if is_linux(): + set_tf_tensorrt_install_path(environ_cp) set_tf_cuda_compute_capabilities(environ_cp) + if 'LD_LIBRARY_PATH' in environ_cp and environ_cp.get('LD_LIBRARY_PATH') != '1': + write_action_env_to_bazelrc('LD_LIBRARY_PATH', environ_cp.get('LD_LIBRARY_PATH')) set_tf_cuda_clang(environ_cp) if environ_cp.get('TF_CUDA_CLANG') == '1': - # Set up which clang we should use as the cuda / host compiler. - set_clang_cuda_compiler_path(environ_cp) + if not is_windows(): + # Ask if we want to download clang release while building. + set_tf_download_clang(environ_cp) + else: + # We use bazel's generated crosstool on Windows and there is no + # way to provide downloaded toolchain for that yet. 
+ # TODO(ibiryukov): Investigate using clang as a cuda compiler on + # Windows. + environ_cp['TF_DOWNLOAD_CLANG'] = '0' + + if environ_cp.get('TF_DOWNLOAD_CLANG') != '1': + # Set up which clang we should use as the cuda / host compiler. + set_clang_cuda_compiler_path(environ_cp) else: # Set up which gcc nvcc should use as the host compiler # No need to set this on Windows @@ -1099,9 +1465,29 @@ def main(): set_grpc_build_flags() set_cc_opt_flags(environ_cp) - set_mkl() - set_monolithic() - create_android_bazelrc_configs() + set_windows_build_flags() + + if workspace_has_any_android_rule(): + print('The WORKSPACE file has at least one of ["android_sdk_repository", ' + '"android_ndk_repository"] already set. Will not ask to help ' + 'configure the WORKSPACE. Please delete the existing rules to ' + 'activate the helper.\n') + else: + if get_var( + environ_cp, 'TF_SET_ANDROID_WORKSPACE', 'android workspace', + False, + ('Would you like to interactively configure ./WORKSPACE for ' + 'Android builds?'), + 'Searching for NDK and SDK installations.', + 'Not configuring the WORKSPACE for Android builds.'): + create_android_ndk_rule(environ_cp) + create_android_sdk_rule(environ_cp) + + print('Preconfigured Bazel build configs. You can use any of the below by ' + 'adding "--config=<>" to your build command. 
See tools/bazel.rc for ' + 'more details.') + config_info_line('mkl', 'Build with MKL support.') + config_info_line('monolithic', 'Config for mostly static monolithic build.') if __name__ == '__main__': main() diff --git a/tensorflow/BUILD b/tensorflow/BUILD index bfebe8a5678a2c0508b31f5dd898eac22186a072..dc995d231d3e591771f801e28024a76610cdba26 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -211,6 +211,12 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "with_kafka_support", + define_values = {"with_kafka_support": "true"}, + visibility = ["//visibility:public"], +) + # Crosses between platforms and file system libraries not supported on those # platforms due to limitations in nested select() statements. config_setting( @@ -364,11 +370,17 @@ config_setting( visibility = ["//visibility:public"], ) -# Make a dummy rule that we can change "default" in select statements to. -# to disable dependencies in copybara. config_setting( - name = "dummy_disabled_internal", - values = {"define": "with_dummy_disabled_internal=true"}, + name = "override_eigen_strong_inline", + values = {"define": "override_eigen_strong_inline=true"}, + visibility = ["//visibility:public"], +) + +# TODO(laigd): consider removing this option and make TensorRT enabled +# automatically when CUDA is enabled. 
+config_setting( + name = "with_tensorrt_support", + values = {"define": "with_tensorrt_support=true"}, visibility = ["//visibility:public"], ) @@ -378,6 +390,7 @@ package_group( "//learning/meta_rank/...", "//tensorflow/...", "//tensorflow_fold/llgtm/...", + "//third_party/py/tensor2tensor/...", ], ) @@ -409,6 +422,8 @@ filegroup( "//tensorflow/c:all_files", "//tensorflow/cc:all_files", "//tensorflow/cc/saved_model:all_files", + "//tensorflow/cc/saved_model/python:all_files", + "//tensorflow/cc/tools:all_files", "//tensorflow/compiler/aot:all_files", "//tensorflow/compiler/aot/tests:all_files", "//tensorflow/compiler/jit:all_files", @@ -427,6 +442,7 @@ filegroup( "//tensorflow/compiler/xla/client:all_files", "//tensorflow/compiler/xla/client/lib:all_files", "//tensorflow/compiler/xla/legacy_flags:all_files", + "//tensorflow/compiler/xla/python:all_files", "//tensorflow/compiler/xla/service:all_files", "//tensorflow/compiler/xla/service/cpu:all_files", "//tensorflow/compiler/xla/service/gpu:all_files", @@ -440,9 +456,6 @@ filegroup( "//tensorflow/contrib/all_reduce:all_files", "//tensorflow/contrib/android:all_files", "//tensorflow/contrib/batching:all_files", - "//tensorflow/contrib/batching/kernels:all_files", - "//tensorflow/contrib/batching/test_util:all_files", - "//tensorflow/contrib/batching/util:all_files", "//tensorflow/contrib/bayesflow:all_files", "//tensorflow/contrib/boosted_trees:all_files", "//tensorflow/contrib/boosted_trees/estimator_batch:all_files", @@ -452,6 +465,7 @@ filegroup( "//tensorflow/contrib/cloud:all_files", "//tensorflow/contrib/cloud/kernels:all_files", "//tensorflow/contrib/cluster_resolver:all_files", + "//tensorflow/contrib/coder:all_files", "//tensorflow/contrib/compiler:all_files", "//tensorflow/contrib/copy_graph:all_files", "//tensorflow/contrib/crf:all_files", @@ -461,11 +475,15 @@ filegroup( "//tensorflow/contrib/data/python/kernel_tests:all_files", "//tensorflow/contrib/data/python/ops:all_files", 
"//tensorflow/contrib/decision_trees/proto:all_files", + "//tensorflow/contrib/deprecated:all_files", "//tensorflow/contrib/distributions:all_files", + "//tensorflow/contrib/eager/proto:all_files", "//tensorflow/contrib/eager/python:all_files", "//tensorflow/contrib/estimator:all_files", "//tensorflow/contrib/factorization:all_files", + "//tensorflow/contrib/factorization/examples:all_files", "//tensorflow/contrib/factorization/kernels:all_files", + "//tensorflow/contrib/feature_column:all_files", "//tensorflow/contrib/ffmpeg:all_files", "//tensorflow/contrib/ffmpeg/default:all_files", "//tensorflow/contrib/framework:all_files", @@ -475,6 +493,7 @@ filegroup( "//tensorflow/contrib/graph_editor:all_files", "//tensorflow/contrib/grid_rnn:all_files", "//tensorflow/contrib/hooks:all_files", + "//tensorflow/contrib/hvx/clock_cycle_profiling:all_files", "//tensorflow/contrib/hvx/hvx_ops_support_checker:all_files", "//tensorflow/contrib/image:all_files", "//tensorflow/contrib/input_pipeline:all_files", @@ -492,6 +511,8 @@ filegroup( "//tensorflow/contrib/layers/kernels:all_files", "//tensorflow/contrib/learn:all_files", "//tensorflow/contrib/learn/python/learn/datasets:all_files", + "//tensorflow/contrib/legacy_seq2seq:all_files", + "//tensorflow/contrib/libsvm:all_files", "//tensorflow/contrib/linalg:all_files", "//tensorflow/contrib/linear_optimizer:all_files", "//tensorflow/contrib/lite:all_files", @@ -516,15 +537,23 @@ filegroup( "//tensorflow/contrib/lookup:all_files", "//tensorflow/contrib/losses:all_files", "//tensorflow/contrib/makefile:all_files", + "//tensorflow/contrib/memory_stats:all_files", "//tensorflow/contrib/meta_graph_transform:all_files", "//tensorflow/contrib/metrics:all_files", "//tensorflow/contrib/model_pruning:all_files", - "//tensorflow/contrib/mpi_collectives:all_files", - "//tensorflow/contrib/ndlstm:all_files", + "//tensorflow/contrib/model_pruning/examples/cifar10:all_files", + "//tensorflow/contrib/nccl:all_files", 
"//tensorflow/contrib/nearest_neighbor:all_files", "//tensorflow/contrib/nn:all_files", "//tensorflow/contrib/opt:all_files", + "//tensorflow/contrib/periodic_resample:all_files", "//tensorflow/contrib/predictor:all_files", + "//tensorflow/contrib/py2tf:all_files", + "//tensorflow/contrib/py2tf/converters:all_files", + "//tensorflow/contrib/py2tf/impl:all_files", + "//tensorflow/contrib/py2tf/pyct:all_files", + "//tensorflow/contrib/py2tf/pyct/static_analysis:all_files", + "//tensorflow/contrib/py2tf/utils:all_files", "//tensorflow/contrib/quantize:all_files", "//tensorflow/contrib/receptive_field:all_files", "//tensorflow/contrib/reduce_slice_ops:all_files", @@ -553,6 +582,7 @@ filegroup( "//tensorflow/contrib/tensor_forest/proto:all_files", "//tensorflow/contrib/tensorboard:all_files", "//tensorflow/contrib/tensorboard/db:all_files", + "//tensorflow/contrib/tensorrt:all_files", "//tensorflow/contrib/testing:all_files", "//tensorflow/contrib/text:all_files", "//tensorflow/contrib/tfprof:all_files", @@ -567,6 +597,7 @@ filegroup( "//tensorflow/contrib/util:all_files", "//tensorflow/contrib/verbs:all_files", "//tensorflow/core:all_files", + "//tensorflow/core/api_def:all_files", "//tensorflow/core/debug:all_files", "//tensorflow/core/distributed_runtime:all_files", "//tensorflow/core/distributed_runtime/rpc:all_files", @@ -577,6 +608,9 @@ filegroup( "//tensorflow/core/grappler/optimizers:all_files", "//tensorflow/core/grappler/utils:all_files", "//tensorflow/core/kernels:all_files", + "//tensorflow/core/kernels/batching_util:all_files", + "//tensorflow/core/kernels/data:all_files", + "//tensorflow/core/kernels/data/sql:all_files", "//tensorflow/core/kernels/fuzzing:all_files", "//tensorflow/core/kernels/hexagon:all_files", "//tensorflow/core/kernels/neon:all_files", @@ -591,6 +625,7 @@ filegroup( "//tensorflow/core/profiler/internal/advisor:all_files", "//tensorflow/core/util/ctc:all_files", "//tensorflow/core/util/tensor_bundle:all_files", + 
"//tensorflow/examples/adding_an_op:all_files", "//tensorflow/examples/android:all_files", "//tensorflow/examples/benchmark:all_files", "//tensorflow/examples/get_started/regression:all_files", @@ -598,10 +633,13 @@ filegroup( "//tensorflow/examples/image_retraining:all_files", "//tensorflow/examples/label_image:all_files", "//tensorflow/examples/learn:all_files", + "//tensorflow/examples/multibox_detector:all_files", "//tensorflow/examples/saved_model:all_files", "//tensorflow/examples/speech_commands:all_files", "//tensorflow/examples/tutorials/estimators:all_files", + "//tensorflow/examples/tutorials/layers:all_files", "//tensorflow/examples/tutorials/mnist:all_files", + "//tensorflow/examples/tutorials/monitors:all_files", "//tensorflow/examples/tutorials/word2vec:all_files", "//tensorflow/examples/wav_to_spectrogram:all_files", "//tensorflow/go:all_files", @@ -610,6 +648,7 @@ filegroup( "//tensorflow/java/src/main/native:all_files", "//tensorflow/python:all_files", "//tensorflow/python/data:all_files", + "//tensorflow/python/data/kernel_tests:all_files", "//tensorflow/python/data/ops:all_files", "//tensorflow/python/data/util:all_files", "//tensorflow/python/debug:all_files", @@ -623,6 +662,7 @@ filegroup( "//tensorflow/python/kernel_tests/random:all_files", "//tensorflow/python/ops/distributions:all_files", "//tensorflow/python/ops/linalg:all_files", + "//tensorflow/python/ops/losses:all_files", "//tensorflow/python/profiler:all_files", "//tensorflow/python/profiler/internal:all_files", "//tensorflow/python/saved_model:all_files", @@ -633,6 +673,7 @@ filegroup( "//tensorflow/tools/api/tests:all_files", "//tensorflow/tools/benchmark:all_files", "//tensorflow/tools/build_info:all_files", + "//tensorflow/tools/ci_build/gpu_build:all_files", "//tensorflow/tools/common:all_files", "//tensorflow/tools/compatibility:all_files", "//tensorflow/tools/dist_test/server:all_files", @@ -640,17 +681,20 @@ filegroup( "//tensorflow/tools/docker/notebooks:all_files", 
"//tensorflow/tools/docs:all_files", "//tensorflow/tools/git:all_files", + "//tensorflow/tools/graph_transforms:all_files", "//tensorflow/tools/mlpbtxt:all_files", "//tensorflow/tools/proto_text:all_files", "//tensorflow/tools/quantization:all_files", "//tensorflow/tools/test:all_files", "//tensorflow/user_ops:all_files", + "//third_party/eigen3:all_files", + "//third_party/fft2d:all_files", + "//third_party/flatbuffers:all_files", "//third_party/hadoop:all_files", - "//third_party/mpi:all_files", "//third_party/sycl:all_files", "//third_party/sycl/sycl:all_files", ], - visibility = [":__subpackages__"], + visibility = ["//visibility:public"], ) load( @@ -774,6 +818,7 @@ tf_cc_shared_object( "//tensorflow/cc:cc_ops", "//tensorflow/cc:client_session", "//tensorflow/cc:scope", + "//tensorflow/cc/profiler", "//tensorflow/core:tensorflow", ], ) diff --git a/tensorflow/SECURITY.md b/tensorflow/SECURITY.md new file mode 100644 index 0000000000000000000000000000000000000000..6ddac1f964dfba3afd240441e2a036bc24ee6d91 --- /dev/null +++ b/tensorflow/SECURITY.md @@ -0,0 +1,239 @@ +# Using TensorFlow Securely + +This document discusses how to safely deal with untrusted programs (models or +model parameters), and input data. Below, we also provide guidelines on how to +report vulnerabilities in TensorFlow. + +## TensorFlow models are programs + +TensorFlow's runtime system interprets and executes programs. What machine +learning practitioners term +[**models**](https://developers.google.com/machine-learning/glossary/#model) are +expressed as programs that TensorFlow executes. TensorFlow programs are encoded +as computation +[**graphs**](https://developers.google.com/machine-learning/glossary/#graph). +The model's parameters are often stored separately in **checkpoints**. + +At runtime, TensorFlow executes the computation graph using the parameters +provided. Note that the behavior of the computation graph may change +depending on the parameters provided. 
TensorFlow itself is not a sandbox. When +executing the computation graph, TensorFlow may read and write files, send and +receive data over the network, and even spawn additional processes. All these +tasks are performed with the permissions of the TensorFlow process. Allowing +for this flexibility makes for a powerful machine learning platform, +but it has implications for security. + +The computation graph may also accept **inputs**. Those inputs are the +data you supply to TensorFlow to train a model, or to use a model to run +inference on the data. + +**TensorFlow models are programs, and need to be treated as such from a security +perspective.** + +## Running untrusted models + +As a general rule: **Always** execute untrusted models inside a sandbox (e.g., +[nsjail](https://github.com/google/nsjail)). + +There are several ways in which a model could become untrusted. Obviously, if an +untrusted party supplies TensorFlow kernels, arbitrary code may be executed. +The same is true if the untrusted party provides Python code, such as the +Python code that generates TensorFlow graphs. + +Even if the untrusted party only supplies the serialized computation +graph (in the form of a `GraphDef`, `SavedModel`, or equivalent on-disk format), the +set of computation primitives available to TensorFlow is powerful enough that +you should assume that the TensorFlow process effectively executes arbitrary +code. One common solution is to whitelist only a few safe Ops. While this is +possible in theory, we still recommend you sandbox the execution. + +It depends on the computation graph whether a user-provided checkpoint is safe. +It is easily possible to create computation graphs in which malicious +checkpoints can trigger unsafe behavior. For example, consider a graph that +contains a `tf.cond` depending on the value of a `tf.Variable`. One branch of +the `tf.cond` is harmless, but the other is unsafe.
Since the `tf.Variable` is +stored in the checkpoint, whoever provides the checkpoint now has the ability to +trigger unsafe behavior, even though the graph is not under their control. + +In other words, graphs can contain vulnerabilities of their own. To allow users +to provide checkpoints to a model you run on their behalf (e.g., in order to +compare model quality for a fixed model architecture), you must carefully audit +your model, and we recommend you run the TensorFlow process in a sandbox. + +## Accepting untrusted inputs + +It is possible to write models that are secure in the sense that they can safely +process untrusted inputs assuming there are no bugs. There are two main reasons +to not rely on this: first, it is easy to write models which must not be exposed +to untrusted inputs, and second, there are bugs in any software system of +sufficient complexity. Letting users control inputs could allow them to trigger +bugs either in TensorFlow or in dependent libraries. + +In general, it is good practice to isolate parts of any system which are exposed +to untrusted (e.g., user-provided) inputs in a sandbox. + +A useful analogy to how any TensorFlow graph is executed is any interpreted +programming language, such as Python. While it is possible to write secure +Python code which can be exposed to user-supplied inputs (by, e.g., carefully +quoting and sanitizing input strings, size-checking input blobs, etc.), it is +very easy to write Python programs which are insecure. Even secure Python code +could be rendered insecure by a bug in the Python interpreter, or by a bug in a +Python library used (e.g., +[this one](https://www.cvedetails.com/cve/CVE-2017-12852/)). + +## Running a TensorFlow server + +TensorFlow is a platform for distributed computing, and as such there is a +TensorFlow server (`tf.train.Server`). **The TensorFlow server is meant for +internal communication only.
It is not built for use in an untrusted network.** + +For performance reasons, the default TensorFlow server does not include any +authorization protocol and sends messages unencrypted. It accepts connections +from anywhere, and executes the graphs it is sent without performing any checks. +Therefore, if you run a `tf.train.Server` in your network, anybody with +access to the network can execute what you should consider arbitrary code with +the privileges of the process running the `tf.train.Server`. + +When running distributed TensorFlow, you must isolate the network in which the +cluster lives. Cloud providers provide instructions for setting up isolated +networks, which are sometimes branded as "virtual private cloud." Refer to the +instructions for +[GCP](https://cloud.google.com/compute/docs/networks-and-firewalls) and +[AWS](https://aws.amazon.com/vpc/) for details. + +Note that `tf.train.Server` is different from the server created by +`tensorflow/serving` (the default binary for which is called `ModelServer`). +By default, `ModelServer` also has no built-in mechanism for authentication. +Connecting it to an untrusted network allows anyone on this network to run the +graphs known to the `ModelServer`. This means that an attacker may run +graphs using untrusted inputs as described above, but they would not be able to +execute arbitrary graphs. It is possible to safely expose a `ModelServer` +directly to an untrusted network, **but only if the graphs it is configured to +use have been carefully audited to be safe**. + +Similar to best practices for other servers, we recommend running any +`ModelServer` with appropriate privileges (i.e., using a separate user with +reduced permissions). In the spirit of defense in depth, we recommend +authenticating requests to any TensorFlow server connected to an untrusted +network, as well as sandboxing the server to minimize the adverse effects of +any breach.
+ +## Vulnerabilities in TensorFlow + +TensorFlow is a large and complex system. It also depends on a large set of +third-party libraries (e.g., `numpy`, `libjpeg-turbo`, PNG parsers, `protobuf`). +It is possible that TensorFlow or its dependent libraries contain +vulnerabilities that would allow triggering unexpected or dangerous behavior +with specially crafted inputs. + +### What is a vulnerability? + +Given TensorFlow's flexibility, it is possible to specify computation graphs +which exhibit unexpected or unwanted behaviors. The fact that TensorFlow models +can perform arbitrary computations means that they may read and write files, +communicate via the network, produce deadlocks and infinite loops, or run out +of memory. It is only when these behaviors are outside the specifications of the +operations involved that such behavior is a vulnerability. + +A `FileWriter` writing a file is not unexpected behavior and therefore is not a +vulnerability in TensorFlow. A `MatMul` allowing arbitrary binary code execution +**is** a vulnerability. + +This is more subtle from a system perspective. For example, it is easy to cause +a TensorFlow process to try to allocate more memory than available by specifying +a computation graph containing an ill-considered `tf.tile` operation. TensorFlow +should exit cleanly in this case (it would raise an exception in Python, or +return an error `Status` in C++). However, if the surrounding system is not +expecting the possibility, such behavior could be used in a denial of service +attack (or worse). Because TensorFlow behaves correctly, this is not a +vulnerability in TensorFlow (although it would be a vulnerability of this +hypothetical system). + +As a general rule, it is incorrect behavior for TensorFlow to access memory it +does not own, or to terminate in an unclean way. Bugs in TensorFlow that lead to +such behaviors constitute a vulnerability. + +One of the most critical parts of any system is input handling.
If malicious +input can trigger side effects or incorrect behavior, this is a bug, and likely +a vulnerability. + +### Reporting vulnerabilities + +Please email reports about any security related issues you find to +`security@tensorflow.org`. This mail is delivered to a small security team. Your +email will be acknowledged within one business day, and you'll receive a more +detailed response to your email within 7 days indicating the next steps in +handling your report. For critical problems, you may encrypt your report (see +below). + +Please use a descriptive subject line for your report email. After the initial +reply to your report, the security team will endeavor to keep you informed of +the progress being made towards a fix and announcement. + +If you believe that an existing (public) issue is security-related, please send +an email to `security@tensorflow.org`. The email should include the issue ID and +a short description of why it should be handled according to this security +policy. + +Once an issue is reported, TensorFlow uses the following disclosure process: + +* When a report is received, we confirm the issue and determine its severity. +* If we know of specific third-party services or software based on TensorFlow + that require mitigation before publication, those projects will be notified. +* An advisory is prepared (but not published) which details the problem and + steps for mitigation. +* Wherever possible, fixes are prepared for the last minor release of the two + latest major releases, as well as the master branch. We will attempt to + commit these fixes as soon as possible, and as close together as + possible. +* Patch releases are published for all fixed released versions, a + notification is sent to discuss@tensorflow.org, and the advisory is published. + +Past security advisories are listed below. We credit reporters for identifying +security issues, although we keep your name confidential if you request it. 
+ +#### Encryption key for `security@tensorflow.org` + +If your disclosure is extremely sensitive, you may choose to encrypt your +report using the key below. Please only use this for critical security +reports. + +``` +-----BEGIN PGP PUBLIC KEY BLOCK----- + +mQENBFpqdzwBCADTeAHLNEe9Vm77AxhmGP+CdjlY84O6DouOCDSq00zFYdIU/7aI +LjYwhEmDEvLnRCYeFGdIHVtW9YrVktqYE9HXVQC7nULU6U6cvkQbwHCdrjaDaylP +aJUXkNrrxibhx9YYdy465CfusAaZ0aM+T9DpcZg98SmsSml/HAiiY4mbg/yNVdPs +SEp/Ui4zdIBNNs6at2gGZrd4qWhdM0MqGJlehqdeUKRICE/mdedXwsWLM8AfEA0e +OeTVhZ+EtYCypiF4fVl/NsqJ/zhBJpCx/1FBI1Uf/lu2TE4eOS1FgmIqb2j4T+jY +e+4C8kGB405PAC0n50YpOrOs6k7fiQDjYmbNABEBAAG0LVRlbnNvckZsb3cgU2Vj +dXJpdHkgPHNlY3VyaXR5QHRlbnNvcmZsb3cub3JnPokBTgQTAQgAOBYhBEkvXzHm +gOJBnwP4Wxnef3wVoM2yBQJaanc8AhsDBQsJCAcCBhUKCQgLAgQWAgMBAh4BAheA +AAoJEBnef3wVoM2yNlkIAICqetv33MD9W6mPAXH3eon+KJoeHQHYOuwWfYkUF6CC +o+X2dlPqBSqMG3bFuTrrcwjr9w1V8HkNuzzOJvCm1CJVKaxMzPuXhBq5+DeT67+a +T/wK1L2R1bF0gs7Pp40W3np8iAFEh8sgqtxXvLGJLGDZ1Lnfdprg3HciqaVAiTum +HBFwszszZZ1wAnKJs5KVteFN7GSSng3qBcj0E0ql2nPGEqCVh+6RG/TU5C8gEsEf +3DX768M4okmFDKTzLNBm+l08kkBFt+P43rNK8dyC4PXk7yJa93SmS/dlK6DZ16Yw +2FS1StiZSVqygTW59rM5XNwdhKVXy2mf/RtNSr84gSi5AQ0EWmp3PAEIALInfBLR +N6fAUGPFj+K3za3PeD0fWDijlC9f4Ety/icwWPkOBdYVBn0atzI21thPRbfuUxfe +zr76xNNrtRRlbDSAChA1J5T86EflowcQor8dNC6fS+oHFCGeUjfEAm16P6mGTo0p +osdG2XnnTHOOEFbEUeWOwR/zT0QRaGGknoy2pc4doWcJptqJIdTl1K8xyBieik/b +nSoClqQdZJa4XA3H9G+F4NmoZGEguC5GGb2P9NHYAJ3MLHBHywZip8g9oojIwda+ +OCLL4UPEZ89cl0EyhXM0nIAmGn3Chdjfu3ebF0SeuToGN8E1goUs3qSE77ZdzIsR +BzZSDFrgmZH+uP0AEQEAAYkBNgQYAQgAIBYhBEkvXzHmgOJBnwP4Wxnef3wVoM2y +BQJaanc8AhsMAAoJEBnef3wVoM2yX4wIALcYZbQhSEzCsTl56UHofze6C3QuFQIH +J4MIKrkTfwiHlCujv7GASGU2Vtis5YEyOoMidUVLlwnebE388MmaJYRm0fhYq6lP +A3vnOCcczy1tbo846bRdv012zdUA+wY+mOITdOoUjAhYulUR0kiA2UdLSfYzbWwy +7Obq96Jb/cPRxk8jKUu2rqC/KDrkFDtAtjdIHh6nbbQhFuaRuWntISZgpIJxd8Bt +Gwi0imUVd9m9wZGuTbDGi6YTNk0GPpX5OMF5hjtM/objzTihSw9UN+65Y/oSQM81 +v//Fw6ZeY+HmRDFdirjD7wXtIuER4vqCryIqR6Xe9X8oJXz9L/Jhslc= +=CDME +-----END PGP PUBLIC KEY 
BLOCK----- +``` + +### Known vulnerabilities + +| Type | Versions affected | Reported by | Additional Information | +|------|:-----------------:|--------------------|------------------------| +| out of bounds read| <=1.4 | TenCent Blade Team | [issue report](https://github.com/tensorflow/tensorflow/issues/14959) | + diff --git a/tensorflow/__init__.py b/tensorflow/__init__.py index 083634bd7964b0c12e10a1f3c71be5eab597a6c4..78ad6aec19f3bbbfcb389012ac1577573b3e4901 100644 --- a/tensorflow/__init__.py +++ b/tensorflow/__init__.py @@ -21,7 +21,7 @@ from __future__ import division from __future__ import print_function # pylint: disable=wildcard-import -from tensorflow.python import * +from tensorflow.python import * # pylint: disable=redefined-builtin # pylint: enable=wildcard-import from tensorflow.python.util.lazy_loader import LazyLoader diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index ef7eb5a4d16b29aecc34f33cb41dd7cf9450c5f2..9060c58c1395f07eff0ccef7bd430b3402f8c826 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -6,6 +6,7 @@ licenses(["notice"]) # Apache 2.0 load( "//tensorflow:tensorflow.bzl", "tf_cc_test", + "tf_cuda_cc_test", "tf_copts", "tf_cuda_library", "tf_custom_op_library", @@ -26,6 +27,18 @@ filegroup( visibility = ["//tensorflow:__subpackages__"], ) +filegroup( + name = "srcs", + srcs = glob( + [ + "*.cc", + "*.h", + ], + exclude = ["*test*"], + ), + visibility = ["//visibility:public"], +) + tf_cuda_library( name = "c_api_internal", srcs = ["c_api.h"], @@ -42,6 +55,7 @@ tf_cuda_library( "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:op_gen_lib", ], }), ) @@ -73,10 +87,17 @@ tf_cuda_library( "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", + "//tensorflow/core:op_gen_lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", ], + }) + select({ + "//tensorflow:with_xla_support": [ +
"//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/jit", + ], + "//conditions:default": [], }), ) @@ -121,15 +142,21 @@ tf_cuda_library( testonly = 1, srcs = ["c_test_util.cc"], hdrs = ["c_test_util.h"], + visibility = [ + "//learning/brain:__subpackages__", + "//tensorflow:__subpackages__", + ], deps = [ ":c_api", + "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:session_options", "//tensorflow/core:test", ], ) -tf_cc_test( +tf_cuda_cc_test( name = "c_api_test", size = "small", srcs = ["c_api_test.cc"], diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index bb41f92306b413d610bf115d144b15faa568ee14..85f1d1639b4d09f2de77d326481a86ec246270d0 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/cc/framework/scope_internal.h" #include "tensorflow/cc/ops/while_loop.h" #include "tensorflow/cc/saved_model/loader.h" +#include "tensorflow/core/framework/op_gen_lib.h" #endif #include "tensorflow/c/c_api_internal.h" #include "tensorflow/core/common_runtime/device_mgr.h" @@ -63,6 +64,7 @@ using tensorflow::AllocationDescription; using tensorflow::DataType; using tensorflow::Graph; using tensorflow::GraphDef; +using tensorflow::mutex_lock; using tensorflow::NameRangeMap; using tensorflow::NameRangesForNode; using tensorflow::NewSession; @@ -76,6 +78,7 @@ using tensorflow::RunMetadata; using tensorflow::RunOptions; using tensorflow::Session; using tensorflow::Status; +using tensorflow::string; using tensorflow::Tensor; using tensorflow::TensorBuffer; using tensorflow::TensorId; @@ -86,8 +89,6 @@ using tensorflow::error::Code; using tensorflow::errors::FailedPrecondition; using tensorflow::errors::InvalidArgument; using tensorflow::gtl::ArraySlice; -using tensorflow::mutex_lock; -using tensorflow::string; using tensorflow::strings::StrCat; extern "C" { @@ -108,6 
+109,10 @@ TF_Status* TF_NewStatus() { return new TF_Status; } void TF_DeleteStatus(TF_Status* s) { delete s; } void TF_SetStatus(TF_Status* s, TF_Code code, const char* msg) { + if (code == TF_OK) { + s->status = Status::OK(); + return; + } s->status = Status(static_cast(code), tensorflow::StringPiece(msg)); } @@ -194,11 +199,11 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims, reinterpret_cast(data) % EIGEN_MAX_ALIGN_BYTES != 0) { // TF_STRING and TF_RESOURCE tensors have a different representation in // TF_Tensor than they do in tensorflow::Tensor. So a copy here is a waste - // (any alignement requirements will be taken care of by TF_TensorToTensor + // (any alignment requirements will be taken care of by TF_TensorToTensor // and TF_TensorFromTensor). // - // Other types have the same represntation, so copy only if it is safe to do - // so. + // Other types have the same representation, so copy only if it is safe to + // do so. buf->data_ = allocate_tensor("TF_NewTensor", len); std::memcpy(buf->data_, data, len); buf->deallocator_ = deallocate_buffer; @@ -210,7 +215,13 @@ TF_Tensor* TF_NewTensor(TF_DataType dtype, const int64_t* dims, int num_dims, buf->deallocator_ = deallocator; buf->deallocator_arg_ = deallocator_arg; } - return new TF_Tensor{dtype, TensorShape(dimvec), buf}; + TF_Tensor* ret = new TF_Tensor{dtype, TensorShape(dimvec), buf}; + size_t elem_size = TF_DataTypeSize(dtype); + if (elem_size > 0 && len < (elem_size * ret->shape.num_elements())) { + delete ret; + return nullptr; + } + return ret; } TF_Tensor* TF_TensorMaybeMove(TF_Tensor* tensor) { @@ -383,12 +394,11 @@ void TF_Reset_Helper(const TF_SessionOptions* opt, const char** containers, // be less than the total node count. Status ValidateNoCycles(const Graph& g) { // TODO(nolivia): check this on a subset of the graph instead of all of it. - int total_num_nodes = g.num_node_ids(); // A node is ready when all of its inputs have been visited. 
std::vector ready; - std::vector pending_count(total_num_nodes, 0); + std::vector pending_count(g.num_node_ids(), 0); - for (int i = 0; i < total_num_nodes; ++i) { + for (int i = 0; i < g.num_node_ids(); ++i) { const Node* n = g.FindNodeId(i); if (n == nullptr) continue; pending_count[i] = n->in_edges().size(); @@ -421,7 +431,7 @@ Status ValidateNoCycles(const Graph& g) { } } - if (processed < total_num_nodes) { + if (processed < g.num_nodes()) { std::vector nodes_in_cycle; for (int i = 0; i < pending_count.size() && nodes_in_cycle.size() < 3; ++i) { @@ -430,7 +440,7 @@ Status ValidateNoCycles(const Graph& g) { } } return errors::InvalidArgument( - "Graph is invalid, contains a cycle with ", total_num_nodes - processed, + "Graph is invalid, contains a cycle with ", g.num_nodes() - processed, " nodes, including: ", str_util::Join(nodes_in_cycle, ", ")); } return Status::OK(); @@ -580,6 +590,7 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, status->status = InvalidArgument( "invalid string tensor encoding (string #", i, " of ", srcarray.size(), "): ", status->status.error_message()); + delete[] base; return nullptr; } dst += consumed; @@ -589,6 +600,7 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, status->status = InvalidArgument( "invalid string tensor encoding (decoded ", (dst - base), " bytes, but the tensor is encoded in ", size, " bytes"); + delete[] base; return nullptr; } @@ -625,6 +637,73 @@ Status MessageToBuffer(const tensorflow::protobuf::Message& in, return Status::OK(); } +void RecordMutation(TF_Graph* graph, const TF_Operation& op, + const char* mutation_type) + EXCLUSIVE_LOCKS_REQUIRED(graph->mu) { + // If any session has already run this node_id, mark this session as + // unrunnable. + for (auto it : graph->sessions) { + if (it.first->last_num_graph_nodes > op.node.id()) { + it.second = FailedPrecondition( + "Operation '", op.node.DebugString(), "' was changed by ", + mutation_type, + " after it was run by a session. 
Nodes can be mutated " + "only before they are executed by a session. Either don't modify " + "nodes after running them or create a new session."); + } + } +} + +namespace { + +// Helper method that creates a shape handle for a shape described by dims. +tensorflow::shape_inference::ShapeHandle ShapeHandleFromDims( + tensorflow::shape_inference::InferenceContext* ic, int num_dims, + const int64_t* dims) { + if (num_dims != -1) { + std::vector dim_vec; + dim_vec.reserve(num_dims); + for (int i = 0; i < num_dims; ++i) { + dim_vec.push_back(ic->MakeDim(dims[i])); + } + return ic->MakeShape(dim_vec); + } else { + return ic->UnknownShape(); + } +} + +} // namespace + +void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output, + int num_shapes_and_types, + const int64_t** shapes, + const int* ranks, + const TF_DataType* types, + TF_Status* status) { + Node* node = &output.oper->node; + + mutex_lock l(graph->mu); + tensorflow::shape_inference::InferenceContext* ic = + graph->refiner.GetContext(node); + if (ic == nullptr) { + status->status = + InvalidArgument("Node ", node->name(), " was not found in the graph"); + return; + } + + auto shape_and_type_vec = + std::vector( + num_shapes_and_types); + for (int i = 0; i < num_shapes_and_types; ++i) { + tensorflow::shape_inference::ShapeHandle shape_handle = + ShapeHandleFromDims(ic, ranks[i], shapes[i]); + shape_and_type_vec[i] = tensorflow::shape_inference::ShapeAndType( + shape_handle, static_cast(types[i])); + } + + ic->set_output_handle_shapes_and_types(output.index, shape_and_type_vec); +} + // Helpers for loading a TensorFlow plugin (a .so file). 
Status LoadLibrary(const char* library_filename, void** result, const void** buf, size_t* len); @@ -858,6 +937,7 @@ int TF_DeviceListCount(const TF_DeviceList* list) { status->status = InvalidArgument("index out of bounds"); \ return err_val; \ } \ + status->status = Status::OK(); \ return list->response[index].accessor; \ } @@ -930,7 +1010,6 @@ void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output, Node* node = &output.oper->node; mutex_lock l(graph->mu); - // Set the shape. tensorflow::shape_inference::InferenceContext* ic = graph->refiner.GetContext(node); if (ic == nullptr) { @@ -938,18 +1017,8 @@ void TF_GraphSetTensorShape(TF_Graph* graph, TF_Output output, InvalidArgument("Node ", node->name(), " was not found in the graph"); return; } - - tensorflow::shape_inference::ShapeHandle new_shape; - if (num_dims != -1) { - std::vector dim_vec; - dim_vec.reserve(num_dims); - for (int i = 0; i < num_dims; ++i) { - dim_vec.push_back(ic->MakeDim(dims[i])); - } - new_shape = ic->MakeShape(dim_vec); - } else { - new_shape = ic->UnknownShape(); - } + tensorflow::shape_inference::ShapeHandle new_shape = + tensorflow::ShapeHandleFromDims(ic, num_dims, dims); status->status = graph->refiner.SetShape(node, output.index, new_shape); } @@ -1143,6 +1212,13 @@ void TF_SetAttrTypeList(TF_OperationDescription* desc, const char* attr_name, reinterpret_cast(values), num_values)); } +void TF_SetAttrFuncName(TF_OperationDescription* desc, const char* attr_name, + const char* value, size_t length) { + tensorflow::NameAttrList func_name; + func_name.set_name(std::string(value, value + length)); + desc->node_builder.Attr(attr_name, func_name); +} + void TF_SetAttrShape(TF_OperationDescription* desc, const char* attr_name, const int64_t* dims, int num_dims) { PartialTensorShape shape; @@ -1404,7 +1480,13 @@ int TF_OperationOutputConsumers(TF_Output oper_out, TF_Input* consumers, } int TF_OperationNumControlInputs(TF_Operation* oper) { - return oper->node.in_edges().size() - 
oper->node.num_inputs(); + int count = 0; + for (const auto* edge : oper->node.in_edges()) { + if (edge->IsControlEdge() && !edge->src()->IsSource()) { + ++count; + } + } + return count; } int TF_OperationGetControlInputs(TF_Operation* oper, @@ -1412,7 +1494,7 @@ int TF_OperationGetControlInputs(TF_Operation* oper, int max_control_inputs) { int count = 0; for (const auto* edge : oper->node.in_edges()) { - if (edge->IsControlEdge()) { + if (edge->IsControlEdge() && !edge->src()->IsSource()) { if (count < max_control_inputs) { control_inputs[count] = ToOperation(edge->src()); } @@ -1425,7 +1507,7 @@ int TF_OperationGetControlInputs(TF_Operation* oper, int TF_OperationNumControlOutputs(TF_Operation* oper) { int count = 0; for (const auto* edge : oper->node.out_edges()) { - if (edge->IsControlEdge()) { + if (edge->IsControlEdge() && !edge->dst()->IsSink()) { ++count; } } @@ -1437,7 +1519,7 @@ int TF_OperationGetControlOutputs(TF_Operation* oper, int max_control_outputs) { int count = 0; for (const auto* edge : oper->node.out_edges()) { - if (edge->IsControlEdge()) { + if (edge->IsControlEdge() && !edge->dst()->IsSink()) { if (count < max_control_outputs) { control_outputs[count] = ToOperation(edge->dst()); } @@ -1745,7 +1827,6 @@ void TF_OperationToNodeDef(TF_Operation* oper, TF_Buffer* output_node_def, TF_Graph::TF_Graph() : graph(tensorflow::OpRegistry::Global()), refiner(graph.versions().producer(), graph.op_registry()), - num_sessions(0), delete_requested(false), parent(nullptr), parent_inputs(nullptr) {} @@ -1755,7 +1836,7 @@ TF_Graph* TF_NewGraph() { return new TF_Graph; } void TF_DeleteGraph(TF_Graph* g) { g->mu.lock(); g->delete_requested = true; - const bool del = g->num_sessions == 0; + const bool del = g->sessions.empty(); g->mu.unlock(); if (del) delete g; } @@ -1835,6 +1916,16 @@ void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts, opts->opts.prefix = prefix; } +void TF_ImportGraphDefOptionsSetUniquifyNames(TF_ImportGraphDefOptions* opts, 
+ unsigned char uniquify_names) { + opts->opts.uniquify_names = uniquify_names; +} + +void TF_ImportGraphDefOptionsSetUniquifyPrefix(TF_ImportGraphDefOptions* opts, + unsigned char uniquify_prefix) { + opts->opts.uniquify_prefix = uniquify_prefix; +} + void TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions* opts, const char* src_name, int src_index, TF_Output dst) { @@ -1892,12 +1983,12 @@ void TF_ImportGraphDefResultsReturnOperations(TF_ImportGraphDefResults* results, *opers = results->return_nodes.data(); } -void TF_ImportGraphDefResultsUnusedInputMappings( - TF_ImportGraphDefResults* results, int* num_unused_input_mappings, +void TF_ImportGraphDefResultsMissingUnusedInputMappings( + TF_ImportGraphDefResults* results, int* num_missing_unused_input_mappings, const char*** src_names, int** src_indexes) { - *num_unused_input_mappings = results->unused_key_names.size(); - *src_names = results->unused_key_names.data(); - *src_indexes = results->unused_key_indexes.data(); + *num_missing_unused_input_mappings = results->missing_unused_key_names.size(); + *src_names = results->missing_unused_key_names.data(); + *src_indexes = results->missing_unused_key_indexes.data(); } void TF_DeleteImportGraphDefResults(TF_ImportGraphDefResults* results) { @@ -1937,18 +2028,21 @@ static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def, tf_results->return_nodes[i] = ToOperation(results.return_nodes[i]); } - // Populate unused map keys - DCHECK(tf_results->unused_key_names.empty()); - DCHECK(tf_results->unused_key_indexes.empty()); - DCHECK(tf_results->unused_key_names_data.empty()); - tf_results->unused_key_names.resize(results.unused_input_map_keys.size()); - tf_results->unused_key_indexes.resize(results.unused_input_map_keys.size()); - for (int i = 0; i < results.unused_input_map_keys.size(); ++i) { - TensorId id = results.unused_input_map_keys[i]; - tf_results->unused_key_names_data.push_back(id.first.ToString()); - tf_results->unused_key_names[i] = 
- tf_results->unused_key_names_data.back().c_str(); - tf_results->unused_key_indexes[i] = id.second; + // Populate missing unused map keys + DCHECK(tf_results->missing_unused_key_names.empty()); + DCHECK(tf_results->missing_unused_key_indexes.empty()); + DCHECK(tf_results->missing_unused_key_names_data.empty()); + + size_t size = results.missing_unused_input_map_keys.size(); + tf_results->missing_unused_key_names.resize(size); + tf_results->missing_unused_key_indexes.resize(size); + + for (int i = 0; i < size; ++i) { + TensorId id = results.missing_unused_input_map_keys[i]; + tf_results->missing_unused_key_names_data.push_back(id.first.ToString()); + tf_results->missing_unused_key_names[i] = + tf_results->missing_unused_key_names_data.back().c_str(); + tf_results->missing_unused_key_indexes[i] = id.second; } } @@ -2060,7 +2154,7 @@ Status CopyGraph(Graph* src_graph, Graph* dst_graph, opts.return_tensors.push_back(ToTensorId(nodes_to_return[i])); } - // TOOD(skyewm): change to OutputTensor + // TODO(skyewm): change to OutputTensor tensorflow::ImportGraphDefResults results; TF_RETURN_IF_ERROR( ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &results)); @@ -2325,11 +2419,12 @@ TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt, Session* session; status->status = NewSession(opt->options, &session); if (status->status.ok()) { + TF_Session* new_session = new TF_Session(session, graph); if (graph != nullptr) { mutex_lock l(graph->mu); - graph->num_sessions += 1; + graph->sessions[new_session] = Status::OK(); } - return new TF_Session(session, graph); + return new_session; } else { DCHECK_EQ(nullptr, session); return nullptr; @@ -2393,7 +2488,7 @@ TF_Session* TF_LoadSessionFromSavedModel( TF_Session* session = new TF_Session(bundle.session.release(), graph); - graph->num_sessions += 1; + graph->sessions[session] = Status::OK(); session->last_num_graph_nodes = graph->graph.num_node_ids(); return session; #endif // __ANDROID__ @@ -2408,8 +2503,8 @@ 
void TF_DeleteSession(TF_Session* s, TF_Status* status) { TF_Graph* const graph = s->graph; if (graph != nullptr) { graph->mu.lock(); - graph->num_sessions -= 1; - const bool del = graph->delete_requested && graph->num_sessions == 0; + graph->sessions.erase(s); + const bool del = graph->delete_requested && graph->sessions.empty(); graph->mu.unlock(); if (del) delete graph; } @@ -2425,6 +2520,13 @@ static bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) { mutex_lock session_lock(session->mu); session->graph->mu.lock(); const Graph& graph = session->graph->graph; + + status->status = session->graph->sessions[session]; + if (!status->status.ok()) { + session->graph->mu.unlock(); + return false; + } + const auto num_nodes = graph.num_node_ids(); if (session->last_num_graph_nodes < num_nodes) { status->status = tensorflow::ValidateNoCycles(session->graph->graph); @@ -2580,4 +2682,54 @@ void TF_SessionPRun(TF_Session* session, const char* handle, output_values, target_names, nullptr, status); } +TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, TF_Status* status) { + tensorflow::OpList op_list; + if (!op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length)) { + status->status = InvalidArgument("Unparseable OpList"); + return nullptr; + } + status->status = Status::OK(); + return new TF_ApiDefMap(op_list); +} + +void TF_DeleteApiDefMap(TF_ApiDefMap* apimap) { delete apimap; } + +void TF_ApiDefMapPut(TF_ApiDefMap* api_def_map, const char* text, + size_t text_len, TF_Status* status) { +#ifdef __ANDROID__ + status->status = tensorflow::errors::Unimplemented( + "ApiDefMap is not supported in Android."); +#else + mutex_lock l(api_def_map->lock); + if (api_def_map->update_docs_called) { + status->status = FailedPrecondition( + "TF_ApiDefMapPut cannot be called after TF_ApiDefMapGet has been " + "called."); + return; + } + string api_def_text(text, text_len); + status->status = api_def_map->api_def_map.LoadApiDef(api_def_text); +#endif // 
__ANDROID__ +} + +TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, const char* name, + size_t name_len, TF_Status* status) { +#ifdef __ANDROID__ + status->status = tensorflow::errors::Unimplemented( + "ApiDefMap is not supported in Android."); + return nullptr; +#else + mutex_lock l(api_def_map->lock); + if (!api_def_map->update_docs_called) { + api_def_map->api_def_map.UpdateDocs(); + api_def_map->update_docs_called = true; + } + string name_str(name, name_len); + const auto* api_def = api_def_map->api_def_map.GetApiDef(name_str); + + TF_Buffer* ret = TF_NewBuffer(); + status->status = MessageToBuffer(*api_def, ret); + return ret; +#endif // __ANDROID__ +} } // end extern "C" diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index bb569d67fcbcec29e9494236abd79b3e40db91cd..ad592ef70961ef427bfe9fd322a82bd64df7f9f1 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -226,6 +226,10 @@ typedef struct TF_Tensor TF_Tensor; // (*deallocator)(data, len, deallocator_arg) // Clients must provide a custom deallocator function so they can pass in // memory managed by something like numpy. +// +// May return NULL (and invoke the deallocator) if the provided data buffer +// (data, len) is inconsistent with a tensor of the given TF_DataType +// and the shape specified by (dims, num_dims). TF_CAPI_EXPORT extern TF_Tensor* TF_NewTensor( TF_DataType, const int64_t* dims, int num_dims, void* data, size_t len, void (*deallocator)(void* data, size_t len, void* arg), @@ -511,6 +515,11 @@ TF_CAPI_EXPORT extern void TF_SetAttrTypeList(TF_OperationDescription* desc, const char* attr_name, const TF_DataType* values, int num_values); +// Set a 'func' attribute to the specified name. +// `value` must point to a string of length `length` bytes. +TF_CAPI_EXPORT extern void TF_SetAttrFuncName(TF_OperationDescription* desc, + const char* attr_name, + const char* value, size_t length); // Set `num_dims` to -1 to represent "unknown rank".
Otherwise, // `dims` points to an array of length `num_dims`. `dims[i]` must be @@ -889,6 +898,20 @@ TF_CAPI_EXPORT extern void TF_DeleteImportGraphDefOptions( TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetPrefix( TF_ImportGraphDefOptions* opts, const char* prefix); +// Set whether to uniquify imported operation names. If true, imported operation +// names will be modified if their name already exists in the graph. If false, +// conflicting names will be treated as an error. Note that this option has no +// effect if a prefix is set, since the prefix will guarantee all names are +// unique. Defaults to false. +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetUniquifyNames( + TF_ImportGraphDefOptions* opts, unsigned char uniquify_names); + +// If true, the specified prefix will be modified if it already exists as an +// operation name or prefix in the graph. If false, a conflicting prefix will be +// treated as an error. This option has no effect if no prefix is specified. +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsSetUniquifyPrefix( + TF_ImportGraphDefOptions* opts, unsigned char uniquify_prefix); + // Set any imported nodes with input `src_name:src_index` to have that input // replaced with `dst`. `src_name` refers to a node in the graph to be imported, // `dst` references a node already existing in the graph being imported into. @@ -948,16 +971,16 @@ TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsReturnOperations( TF_ImportGraphDefResults* results, int* num_opers, TF_Operation*** opers); // Fetches any input mappings requested via -// TF_ImportGraphDefOptionsAddInputMapping() that weren't used as input to any -// node in the imported graph def. The number of fetched mappings is returned in -// `num_unused_input_mappings`. The array of each mapping's source node name is -// returned in `src_names`, and the array of each mapping's source index is -// returned in `src_indexes`. 
+// TF_ImportGraphDefOptionsAddInputMapping() that didn't appear in the GraphDef +// and weren't used as input to any node in the imported graph def. The number +// of fetched mappings is returned in `num_missing_unused_input_mappings`. The +// array of each mapping's source node name is returned in `src_names`, and the +// array of each mapping's source index is returned in `src_indexes`. // // `*src_names`, `*src_indexes`, and the memory backing each string in // `src_names` are owned by and have the lifetime of `results`. -TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsUnusedInputMappings( - TF_ImportGraphDefResults* results, int* num_unused_input_mappings, +TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsMissingUnusedInputMappings( + TF_ImportGraphDefResults* results, int* num_missing_unused_input_mappings, const char*** src_names, int** src_indexes); // Deletes a results object returned by TF_GraphImportGraphDefWithResults(). @@ -1015,6 +1038,23 @@ TF_CAPI_EXPORT extern void TF_GraphCopyFunction(TF_Graph* g, const TF_Function* grad, TF_Status* status); +// Returns the number of TF_Functions registered in `g`. +TF_CAPI_EXPORT extern int TF_GraphNumFunctions(TF_Graph* g); + +// Fills in `funcs` with the TF_Function* registered in `g`. +// `funcs` must point to an array of TF_Function* of length at least +// `max_func`. In usual usage, max_func should be set to the result of +// TF_GraphNumFunctions(g). In this case, all the functions registered in +// `g` will be returned. Else, an unspecified subset. +// +// If successful, returns the number of TF_Function* successfully set in +// `funcs` and sets status to OK. The caller takes ownership of +// all the returned TF_Functions. They must be deleted with TF_DeleteFunction. +// On error, returns 0, sets status to the encountered error, and the contents +// of funcs will be undefined. 
+TF_CAPI_EXPORT extern int TF_GraphGetFunctions(TF_Graph* g, TF_Function** funcs, + int max_func, TF_Status* status); + // Note: The following function may fail on very large protos in the future. TF_CAPI_EXPORT extern void TF_OperationToNodeDef(TF_Operation* oper, @@ -1247,11 +1287,12 @@ TF_CAPI_EXPORT extern void TF_DeleteFunction(TF_Function* func); typedef struct TF_Session TF_Session; -// Return a new execution session with the associated graph, or NULL on error. +// Return a new execution session with the associated graph, or NULL on +// error. Does not take ownership of any input parameters. // -// *graph must be a valid graph (not deleted or nullptr). This function will -// prevent the graph from being deleted until TF_DeleteSession() is called. -// Does not take ownership of opts. +// *`graph` must be a valid graph (not deleted or nullptr). `graph` will be +// kept alive for the lifetime of the returned TF_Session. New nodes can still +// be added to `graph` after this call. TF_CAPI_EXPORT extern TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opts, TF_Status* status); @@ -1504,6 +1545,49 @@ TF_CAPI_EXPORT extern void TF_DeleteLibraryHandle(TF_Library* lib_handle); // in this address space. TF_CAPI_EXPORT extern TF_Buffer* TF_GetAllOpList(); +// TF_ApiDefMap encapsulates a collection of API definitions for an operation. +// +// This object maps the name of a TensorFlow operation to a description of the +// API to generate for it, as defined by the ApiDef protocol buffer ( +// https://www.tensorflow.org/code/tensorflow/core/framework/api_def.proto) +// +// The ApiDef messages are typically used to generate convenience wrapper +// functions for TensorFlow operations in various language bindings. +typedef struct TF_ApiDefMap TF_ApiDefMap; + +// Creates a new TF_ApiDefMap instance. +// +// Params: +// op_list_buffer - TF_Buffer instance containing serialized OpList +// protocol buffer.
(See +// https://www.tensorflow.org/code/tensorflow/core/framework/op_def.proto +// for the OpList proto definition). +// status - Set to OK on success and an appropriate error on failure. +TF_CAPI_EXPORT extern TF_ApiDefMap* TF_NewApiDefMap(TF_Buffer* op_list_buffer, + TF_Status* status); + +// Deallocates a TF_ApiDefMap. +TF_CAPI_EXPORT extern void TF_DeleteApiDefMap(TF_ApiDefMap* apimap); + +// Add ApiDefs to the map. +// +// `text` corresponds to a text representation of an ApiDefs protocol message. +// (https://www.tensorflow.org/code/tensorflow/core/framework/api_def.proto). +// +// The provided ApiDefs will be merged with existing ones in the map, with +// precedence given to the newly added version in case of conflicts with +// previous calls to TF_ApiDefMapPut. +TF_CAPI_EXPORT extern void TF_ApiDefMapPut(TF_ApiDefMap* api_def_map, + const char* text, size_t text_len, + TF_Status* status); + +// Returns a serialized ApiDef protocol buffer for the TensorFlow operation +// named `name`. +TF_CAPI_EXPORT extern TF_Buffer* TF_ApiDefMapGet(TF_ApiDefMap* api_def_map, + const char* name, + size_t name_len, + TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc index dcb818b88b6fca460852beb6e948d2eb6964f663..384e6c8cb97022264c5327da5ca5861057608fbe 100644 --- a/tensorflow/c/c_api_function.cc +++ b/tensorflow/c/c_api_function.cc @@ -44,8 +44,12 @@ class NodeNameMapping { public: NodeNameMapping() = default; - // Normalize the input/output name and make it unique. - string GetIOName(const string& name); + // Normalize the input name and make it unique. This is the same as the + // function for output, except that it adds a name mapping for the name. + string GetInputName(const string& name); + + // Normalize the output name and make it unique. + string GetOutputName(const string& name); // Make the node name unique.
string Uniquify(const string& name); @@ -68,7 +72,7 @@ class NodeNameMapping { // This is a superset of values in name_mapping_. std::unordered_set used_names_; // Mapping from original node name from the graph to the normalized - // and uniqified version of it. + // and uniquified version of it. std::unordered_map name_mapping_; }; @@ -107,7 +111,13 @@ string NodeNameMapping::UniquifyHelper(const string& name) const { } } -string NodeNameMapping::GetIOName(const string& name) { +string NodeNameMapping::GetInputName(const string& name) { + const string& input_name = GetOutputName(name); + name_mapping_[name] = input_name; + return input_name; +} + +string NodeNameMapping::GetOutputName(const string& name) { const string& input_name = UniquifyHelper(Normalize(name)); // Record that we used this name, but don't add it to name_mapping_ // since this name is not for a node. @@ -214,10 +224,11 @@ Status FillFunctionBody( // Add control inputs. for (const Edge* edge : control_edges) { - // Add this control input only if the src node is in the body. + // Add this control input only if the src node is in the body or a part of + // the inputs. const string normalized = node_names.Lookup(edge->src()->name()); // If we did not find a name for the source of control edge, this - // source must be outside of the body. Raise an error. + // source must be outside of the body, and not an input. Raise an error. if (normalized.empty()) { return InvalidArgument( "The source of control edge ", edge->DebugString(), @@ -226,12 +237,17 @@ Status FillFunctionBody( } node_def->add_input(strings::StrCat("^", normalized)); } + + // A function is stateful if any of its nodes are stateful. + if (node->op_def().is_stateful()) { + fdef->mutable_signature()->set_is_stateful(true); + } } return Status::OK(); } // Graph to FunctionDef conversion. This code is closely modeled on the Python -// code in third_party/tensorflow/python/framework/function.py. 
+// code in tensorflow/python/framework/function.py. Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, bool append_hash_to_fn_name, const std::vector& body_nodes, @@ -274,7 +290,7 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, TF_RETURN_IF_ERROR(node_names.UseOutputName(output_names[i])); argdef->set_name(output_names[i]); } else { - argdef->set_name(node_names.GetIOName(node->name())); + argdef->set_name(node_names.GetOutputName(node->name())); } } @@ -284,7 +300,7 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, int idx = inputs[i].index; OpDef::ArgDef* argdef = fdef->mutable_signature()->add_input_arg(); argdef->set_type(node->output_type(idx)); - const string& input_name = node_names.GetIOName(node->name()); + const string& input_name = node_names.GetInputName(node->name()); argdef->set_name(input_name); tensor_renaming[strings::StrCat(node->name(), ":", idx)] = input_name; } @@ -307,7 +323,7 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, TF_RETURN_IF_ERROR( NameRangesForNode(*node, node->op_def(), nullptr, &output_ranges)); for (const auto& output : output_ranges) { - const string& output_name = output.first; + const StringPiece& output_name = output.first; int index_start = output.second.first; int index_end = output.second.second; for (int i = index_start; i < index_end; ++i) { @@ -462,7 +478,7 @@ Status ComputeBodyNodes( return Status::OK(); } -} // anonymous namespace +} // namespace } // namespace tensorflow using tensorflow::Node; @@ -543,6 +559,28 @@ void TF_GraphCopyFunction(TF_Graph* g, const TF_Function* func, status->status = g->graph.AddFunctionLibrary(fdef_lib); } +int TF_GraphNumFunctions(TF_Graph* g) { + tensorflow::mutex_lock l(g->mu); + return g->graph.flib_def().num_functions(); +} + +int TF_GraphGetFunctions(TF_Graph* g, TF_Function** funcs, int max_func, + TF_Status* status) { + tensorflow::FunctionDefLibrary lib; + { + tensorflow::mutex_lock 
l(g->mu); + lib = g->graph.flib_def().ToProto(); + } + const auto len = std::min(max_func, static_cast(lib.function_size())); + for (int i = 0; i < len; ++i) { + TF_Function* func = new TF_Function(); + func->fdef = lib.function(i); + funcs[i] = func; + } + status->status = tensorflow::Status::OK(); + return len; +} + void TF_FunctionToFunctionDef(TF_Function* func, TF_Buffer* output_func_def, TF_Status* status) { status->status = MessageToBuffer(func->fdef, output_func_def); diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc index d5580b658992413ae6f9cb79ef88751ee28ce465..7ca50119eafe299b307f06c555aec1388e7e82e2 100644 --- a/tensorflow/c/c_api_function_test.cc +++ b/tensorflow/c/c_api_function_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" @@ -330,6 +331,11 @@ class CApiFunctionTest : public ::testing::Test { << "Failed to find expected edge " << e.ToString() << " in fdef: " << fdef.DebugString(); } + for (const EdgeSpec& e : c_edges) { + ASSERT_TRUE(a_edges.find(e) != a_edges.end()) + << "Failed to find expected control edge " << e.ToString() + << " in fdef: " << fdef.DebugString(); + } // If caller specified all edges, check that we have seen all if (is_exact_edges) { @@ -979,7 +985,7 @@ TEST_F(CApiFunctionTest, ControlDependency) { VerifyFDef( {"add_0", "scalar"}, M({{"feed1"}, {"feed2"}}), M({{"add"}}), {{"feed1", "add_0:0"}, {"feed2", "add_0:1"}, {"add_0:sum:0", "add"}}, - {{"scalar", "add_0"}}); + {{"^scalar", "add_0:2"}}); } TEST_F(CApiFunctionTest, ControlDependencyOutsideOfBody) { @@ -1022,12 +1028,17 @@ TEST_F(CApiFunctionTest, ControlDependencyOutsideOfBody_FromInputNode) { 
TF_Operation* add = AddWithCtrlDependency(feed1, feed2, func_graph_, feed1, s_); EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - Define(-1, {}, {feed1, feed2}, {add}, {}, true); - EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); - EXPECT_EQ(string("The source of control edge [id=3 feed1:-1 -> add:-1] " - "is not in the body. Encountered while creating " - "function 'MyFunc'"), - string(TF_Message(s_))); + Define(-1, {}, {feed1, feed2}, {add}, {}); + + // Use, run, and verify + TF_Operation* two = ScalarConst(2, host_graph_, s_); + TF_Operation* func_feed = Placeholder(host_graph_, s_); + TF_Operation* func_op = Use({two, func_feed}); + Run({{func_feed, Int32Tensor(3)}}, func_op, 2 + 3); + VerifyFDef( + {"add_0"}, M({{"feed1"}, {"feed2"}}), M({{"add"}}), + {{"feed1", "add_0:0"}, {"feed2", "add_0:1"}, {"add_0:sum:0", "add"}}, + {{"^feed1", "add_0:2"}}); } TEST_F(CApiFunctionTest, DuplicateInputsAreNotAllowed) { @@ -1462,7 +1473,11 @@ TEST_F(CApiFunctionTest, AppendHash) { /*append_hash=*/true); tensorflow::FunctionDef fdef; ASSERT_TRUE(GetFunctionDef(func_, &fdef)); +#if (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__) + ASSERT_EQ(string("func_name_base_ZpgUD4x8oqk"), fdef.signature().name()); +#else ASSERT_EQ(string("func_name_base_qaJ8jA8UmGY"), fdef.signature().name()); +#endif } TEST_F(CApiFunctionTest, GetOpDef) { @@ -1482,9 +1497,124 @@ TEST_F(CApiFunctionTest, GetOpDef) { EXPECT_EQ(op_def.name(), func_name_); EXPECT_EQ(op_def.input_arg_size(), 1); EXPECT_EQ(op_def.output_arg_size(), 1); + EXPECT_FALSE(op_def.is_stateful()); + + TF_DeleteBuffer(buffer); +} + +void DefineStatefulFunction(const char* name, TF_Function** func) { + std::unique_ptr func_graph( + TF_NewGraph(), TF_DeleteGraph); + std::unique_ptr s(TF_NewStatus(), + TF_DeleteStatus); + + TF_Tensor* tensor_shape = Int32Tensor({37, 1}); + TF_Operation* shape = Const(tensor_shape, func_graph.get(), s.get(), "shape"); + TF_Operation* random = + RandomUniform(shape, TF_FLOAT, func_graph.get(), s.get()); + + 
TF_Output inputs[] = {}; + TF_Output outputs[] = {{random, 0}}; + *func = TF_GraphToFunction(func_graph.get(), name, /*append_hash=*/false, -1, + /*opers=*/nullptr, 0, inputs, 1, outputs, + /*output_names=*/nullptr, + /*opts=*/nullptr, "", s.get()); + ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get()); + ASSERT_NE(*func, nullptr); + TF_DeleteTensor(tensor_shape); +} + +TEST_F(CApiFunctionTest, StatefulOpDef) { + DefineStatefulFunction(func_name_, &func_); + TF_GraphCopyFunction(host_graph_, func_, nullptr, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Test we can retrieve function OpDef from graph + TF_Buffer* buffer = TF_NewBuffer(); + TF_GraphGetOpDef(host_graph_, func_name_, buffer, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Sanity check returned OpDef + string data(static_cast(buffer->data), buffer->length); + OpDef op_def; + op_def.ParseFromString(data); + EXPECT_EQ(op_def.name(), func_name_); + EXPECT_EQ(op_def.input_arg_size(), 0); + EXPECT_EQ(op_def.output_arg_size(), 1); + EXPECT_TRUE(op_def.is_stateful()); TF_DeleteBuffer(buffer); } +void AssertEqual(TF_Function* f1, TF_Function* f2) { + string s1, s2; + tensorflow::FunctionDef fdef1, fdef2; + ASSERT_TRUE(GetFunctionDef(f1, &fdef1)); + ASSERT_TRUE(GetFunctionDef(f2, &fdef2)); + SerializeToStringDeterministic(fdef1, &s1); + SerializeToStringDeterministic(fdef2, &s2); + ASSERT_EQ(s1, s2); +} + +string GetName(TF_Function* func) { + tensorflow::FunctionDef fdef; + GetFunctionDef(func, &fdef); + return fdef.signature().name(); +} + +TEST_F(CApiFunctionTest, GetFunctionsFromGraph) { + TF_Function* funcs[2]; + + // Get functions from empty graph + EXPECT_EQ(TF_GraphNumFunctions(host_graph_), 0); + TF_GraphGetFunctions(host_graph_, nullptr, 0, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Define a function and add it to host_graph_ + TF_Function* func0; + DefineFunction("FooFunc0", &func0); + TF_GraphCopyFunction(host_graph_, func0, 
nullptr, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Get this function from host_graph_ + EXPECT_EQ(TF_GraphNumFunctions(host_graph_), 1); + EXPECT_EQ(TF_GraphGetFunctions(host_graph_, funcs, 0, s_), 0); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + EXPECT_EQ(TF_GraphGetFunctions(host_graph_, funcs, 1, s_), 1); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + AssertEqual(func0, funcs[0]); + TF_DeleteFunction(funcs[0]); + EXPECT_EQ(TF_GraphGetFunctions(host_graph_, funcs, 2, s_), 1); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + AssertEqual(func0, funcs[0]); + TF_DeleteFunction(funcs[0]); + + // Define a second function + TF_Function* func1; + DefineFunction("FooFunc1", &func1); + TF_GraphCopyFunction(host_graph_, func1, nullptr, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Get both function from host_graph_ + EXPECT_EQ(TF_GraphNumFunctions(host_graph_), 2); + EXPECT_EQ(TF_GraphGetFunctions(host_graph_, funcs, 0, s_), 0); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + EXPECT_EQ(TF_GraphGetFunctions(host_graph_, funcs, 2, s_), 2); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + if (GetName(funcs[0]) == GetName(func0)) { + AssertEqual(func0, funcs[0]); + AssertEqual(func1, funcs[1]); + } else { + AssertEqual(func0, funcs[1]); + AssertEqual(func1, funcs[0]); + } + + TF_DeleteFunction(funcs[0]); + TF_DeleteFunction(funcs[1]); + + TF_DeleteFunction(func0); + TF_DeleteFunction(func1); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index bb04e01beec931a8ea66d0855eec9625d3a6a5ab..91667056e0eeb224b4b8a034766f11a123cd1a03 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -24,6 +24,9 @@ limitations under the License. 
#include #include +#ifndef __ANDROID__ +#include "tensorflow/core/framework/op_gen_lib.h" +#endif #include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" @@ -81,12 +84,20 @@ struct TF_Graph { std::unordered_map name_map GUARDED_BY(mu); - // TF_Graph may only / must be deleted when - // num_sessions == 0 && delete_requested == true - - // num_sessions incremented by TF_NewSession, and decremented by + // The keys of this map are all the active sessions using this graph. + // Each value is the current "runnability" status of the corresponding + // session. Under normal conditions all statuses are Status::OK(), but + // if some operation is mutated after it was run by a session (this + // is detected in RecordMutation function), that session is no longer + // safe to run. Its status will contain the error that will be returned + // to the user, should she try running this session. + // + // Sessions are added to this map in TF_NewSession, and removed in // TF_DeleteSession. - int num_sessions GUARDED_BY(mu); + // TF_Graph may only / must be deleted when + // sessions.size() == 0 && delete_requested == true + tensorflow::gtl::FlatMap sessions + GUARDED_BY(mu); bool delete_requested GUARDED_BY(mu); // set true by TF_DeleteGraph // Used to link graphs contained in TF_WhileParams to the parent graph that @@ -135,11 +146,11 @@ struct TF_ImportGraphDefOptions { struct TF_ImportGraphDefResults { std::vector return_tensors; std::vector return_nodes; - std::vector unused_key_names; - std::vector unused_key_indexes; + std::vector missing_unused_key_names; + std::vector missing_unused_key_indexes; - // Backing memory for unused_key_names values. - std::list unused_key_names_data; + // Backing memory for missing_unused_key_names values. 
+ std::list missing_unused_key_names_data; }; struct TF_DeviceList { @@ -150,6 +161,22 @@ struct TF_Function { tensorflow::FunctionDef fdef; }; +struct TF_ApiDefMap { + explicit TF_ApiDefMap(const tensorflow::OpList& op_list) + : +#ifndef __ANDROID__ + api_def_map(op_list), +#endif + update_docs_called(false) { + } + +#ifndef __ANDROID__ + tensorflow::ApiDefMap api_def_map GUARDED_BY(lock); +#endif + bool update_docs_called GUARDED_BY(lock); + tensorflow::mutex lock; +}; + namespace tensorflow { class TensorCApi { @@ -167,6 +194,24 @@ TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status); Status MessageToBuffer(const tensorflow::protobuf::Message& in, TF_Buffer* out); +// Set the shapes and types of the output's handle. +// +// The lengths of the arrays pointed to by `shapes`, `ranks`, and `types` must +// all be equal to `num_shapes_and_types`. If `ranks[i] != -1`, (i.e., if the +// rank is known), then it must be equal to the length of `shapes[i]`; if +// `ranks[i] == -1`, then `shapes[i]` may be nullptr. +// +// TODO(akshayka): Implement a corresponding getter method. +void TF_GraphSetOutputHandleShapesAndTypes(TF_Graph* graph, TF_Output output, + int num_shapes_and_types, + const int64_t** shapes, + const int* ranks, + const TF_DataType* types, + TF_Status* status); + +void RecordMutation(TF_Graph* graph, const TF_Operation& op, + const char* mutation_type); + } // end namespace tensorflow #endif // TENSORFLOW_C_C_API_INTERNAL_H_ diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index 6ec1db8ccfdb713f330b708e604bd4b502ff7202..028f146be31790b211e546978302e81afe26b231 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/cc/saved_model/tag_constants.h" #include "tensorflow/core/example/example.pb.h" #include "tensorflow/core/example/feature.pb.h" +#include "tensorflow/core/framework/api_def.pb.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/graph.pb_text.h" #include "tensorflow/core/framework/node_def.pb_text.h" @@ -56,6 +57,52 @@ static void ExpectHasSubstr(StringPiece s, StringPiece expected) { << "'" << s << "' does not contain '" << expected << "'"; } +// Returns the GPU device name if there is one (with arbitrary tie breaking if +// there are more than one), or "" otherwise. +string GPUDeviceName(TF_Session* session) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TF_Status* s = status.get(); + std::unique_ptr list( + TF_SessionListDevices(session, s), TF_DeleteDeviceList); + TF_DeviceList* device_list = list.get(); + + CHECK_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + const int num_devices = TF_DeviceListCount(device_list); + LOG(INFO) << "There are " << num_devices << " devices."; + for (int i = 0; i < num_devices; ++i) { + const char* device_name = TF_DeviceListName(device_list, i, s); + CHECK_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + const char* device_type = TF_DeviceListType(device_list, i, s); + CHECK_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + LOG(INFO) << "Device " << i << " has name " << device_name << ", type " + << device_type; + if (string(device_type) == DEVICE_GPU) { + return device_name; + } + } + // No GPU device found. 
+ return ""; +} + +string GPUDeviceName() { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TF_Status* s = status.get(); + std::unique_ptr graph(TF_NewGraph(), + TF_DeleteGraph); + + TF_SessionOptions* opts = TF_NewSessionOptions(); + TF_Session* sess = TF_NewSession(graph.get(), opts, s); + TF_DeleteSessionOptions(opts); + + const string gpu_device_name = GPUDeviceName(sess); + TF_DeleteSession(sess, s); + CHECK_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + return gpu_device_name; +} + TEST(CAPI, Version) { EXPECT_STRNE("", TF_Version()); } TEST(CAPI, Status) { @@ -93,6 +140,17 @@ TEST(CAPI, Tensor) { EXPECT_TRUE(deallocator_called); } +void NoOpDeallocator(void* data, size_t, void*) {} + +TEST(CAPI, MalformedTensor) { + // See https://github.com/tensorflow/tensorflow/issues/7394 + // num_dims = 0 implies a scalar, so should be backed by at least 4 bytes of + // data. + TF_Tensor* t = + TF_NewTensor(TF_FLOAT, nullptr, 0, nullptr, 0, &NoOpDeallocator, nullptr); + ASSERT_TRUE(t == nullptr); +} + TEST(CAPI, AllocateTensor) { const int num_bytes = 6 * sizeof(float); int64_t dims[] = {2, 3}; @@ -122,6 +180,10 @@ TEST(CAPI, MaybeMove) { } TEST(CAPI, LibraryLoadFunctions) { + // TODO(b/73318067): Fix linking for the GPU test generated by the + // tf_cuda_cc_test() bazel rule and remove the next line. + if (!GPUDeviceName().empty()) return; + // Load the library. TF_Status* status = TF_NewStatus(); TF_Library* lib = @@ -574,7 +636,7 @@ TEST(CAPI, ImportGraphDef) { TF_Status* s = TF_NewStatus(); TF_Graph* graph = TF_NewGraph(); - // Create a graph with two nodes: x and 3 + // Create a simple graph. Placeholder(graph, s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); ASSERT_TRUE(TF_GraphOperationByName(graph, "feed") != nullptr); @@ -585,7 +647,7 @@ TEST(CAPI, ImportGraphDef) { ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); ASSERT_TRUE(TF_GraphOperationByName(graph, "neg") != nullptr); - // Export to a GraphDef + // Export to a GraphDef. 
TF_Buffer* graph_def = TF_NewBuffer(); TF_GraphToGraphDef(graph, graph_def, s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); @@ -605,6 +667,31 @@ TEST(CAPI, ImportGraphDef) { ASSERT_TRUE(feed != nullptr); ASSERT_TRUE(neg != nullptr); + // Test basic structure of the imported graph. + EXPECT_EQ(0, TF_OperationNumInputs(scalar)); + EXPECT_EQ(0, TF_OperationNumInputs(feed)); + ASSERT_EQ(1, TF_OperationNumInputs(neg)); + TF_Output neg_input = TF_OperationInput({neg, 0}); + EXPECT_EQ(scalar, neg_input.oper); + EXPECT_EQ(0, neg_input.index); + + // Test that we can't see control edges involving the source and sink nodes. + TF_Operation* control_ops[100]; + EXPECT_EQ(0, TF_OperationNumControlInputs(scalar)); + EXPECT_EQ(0, TF_OperationGetControlInputs(scalar, control_ops, 100)); + EXPECT_EQ(0, TF_OperationNumControlOutputs(scalar)); + EXPECT_EQ(0, TF_OperationGetControlOutputs(scalar, control_ops, 100)); + + EXPECT_EQ(0, TF_OperationNumControlInputs(feed)); + EXPECT_EQ(0, TF_OperationGetControlInputs(feed, control_ops, 100)); + EXPECT_EQ(0, TF_OperationNumControlOutputs(feed)); + EXPECT_EQ(0, TF_OperationGetControlOutputs(feed, control_ops, 100)); + + EXPECT_EQ(0, TF_OperationNumControlInputs(neg)); + EXPECT_EQ(0, TF_OperationGetControlInputs(neg, control_ops, 100)); + EXPECT_EQ(0, TF_OperationNumControlOutputs(neg)); + EXPECT_EQ(0, TF_OperationGetControlOutputs(neg, control_ops, 100)); + // Import it again, with an input mapping, return outputs, and a return // operation, into the same graph. 
TF_DeleteImportGraphDefOptions(opts); @@ -628,7 +715,7 @@ TEST(CAPI, ImportGraphDef) { ASSERT_TRUE(neg2 != nullptr); // Check input mapping - TF_Output neg_input = TF_OperationInput({neg, 0}); + neg_input = TF_OperationInput({neg, 0}); EXPECT_EQ(scalar, neg_input.oper); EXPECT_EQ(0, neg_input.index); @@ -773,7 +860,7 @@ TEST(CAPI, ImportGraphDef_WithReturnOutputs) { TF_DeleteStatus(s); } -TEST(CAPI, ImportGraphDef_UnusedInputMappings) { +TEST(CAPI, ImportGraphDef_MissingUnusedInputMappings) { TF_Status* s = TF_NewStatus(); TF_Graph* graph = TF_NewGraph(); @@ -816,7 +903,7 @@ TEST(CAPI, ImportGraphDef_UnusedInputMappings) { int num_unused_input_mappings; const char** src_names; int* src_indexes; - TF_ImportGraphDefResultsUnusedInputMappings( + TF_ImportGraphDefResultsMissingUnusedInputMappings( results, &num_unused_input_mappings, &src_names, &src_indexes); ASSERT_EQ(1, num_unused_input_mappings); EXPECT_EQ(string("fake"), string(src_names[0])); @@ -886,6 +973,70 @@ TEST(CAPI, Session) { TF_DeleteStatus(s); } +// If `device` is non-empty, run Min op on that device. +// Otherwise run it on the default device (CPU). +void RunMinTest(const string& device, bool use_XLA) { + TF_Status* s = TF_NewStatus(); + TF_Graph* graph = TF_NewGraph(); + + // Make a placeholder operation. + TF_Operation* feed = Placeholder(graph, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Make a constant operation with the scalar "0", for axis. + TF_Operation* one = ScalarConst(0, graph, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Create a session for this graph. + CSession csession(graph, s, use_XLA); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + if (!device.empty()) { + LOG(INFO) << "Setting op Min on device " << device; + } + TF_Operation* min = MinWithDevice(feed, one, graph, device, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Run the graph. 
+ csession.SetInputs({{feed, Int32Tensor({3, 2, 5})}}); + csession.SetOutputs({min}); + csession.Run(s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_Tensor* out = csession.output_tensor(0); + ASSERT_TRUE(out != nullptr); + EXPECT_EQ(TF_INT32, TF_TensorType(out)); + EXPECT_EQ(0, TF_NumDims(out)); // scalar + ASSERT_EQ(sizeof(int32), TF_TensorByteSize(out)); + int32* output_contents = static_cast(TF_TensorData(out)); + EXPECT_EQ(2, *output_contents); + + // Clean up + csession.CloseAndDelete(s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + TF_DeleteGraph(graph); + TF_DeleteStatus(s); +} + +TEST(CAPI, Session_Min_CPU) { RunMinTest(/*device=*/"", /*use_XLA=*/false); } + +TEST(CAPI, Session_Min_XLA_CPU) { RunMinTest(/*device=*/"", /*use_XLA=*/true); } + +TEST(CAPI, Session_Min_GPU) { + const string gpu_device = GPUDeviceName(); + // Skip this test if no GPU is available. + if (gpu_device.empty()) return; + + RunMinTest(gpu_device, /*use_XLA=*/false); +} + +TEST(CAPI, Session_Min_XLA_GPU) { + const string gpu_device = GPUDeviceName(); + // Skip this test if no GPU is available. + if (gpu_device.empty()) return; + + RunMinTest(gpu_device, /*use_XLA=*/true); +} + TEST(CAPI, SessionPRun) { TF_Status* s = TF_NewStatus(); TF_Graph* graph = TF_NewGraph(); @@ -1930,7 +2081,7 @@ TEST_F(CApiAttributesTest, Tensor) { } TEST_F(CApiAttributesTest, StringTensor) { - // Create the string-Tensor "atttribute" value. + // Create the string-Tensor "attribute" value. char encoded[] = { 0, 0, 0, 0, 0, 0, 0, 0, // array[uint64] offsets 1, // varint encoded string length @@ -2027,6 +2178,85 @@ TEST_F(CApiAttributesTest, Errors) { EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)) << TF_Message(s_); } +TEST(TestApiDef, TestCreateApiDef) { + // TODO(b/73318067): Fix linking for the GPU test generated by the + // tf_cuda_cc_test() bazel rule and remove the next line. 
+ if (!GPUDeviceName().empty()) return; + + TF_Status* status = TF_NewStatus(); + TF_Library* lib = + TF_LoadLibrary("tensorflow/c/test_op.so", status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + + TF_Buffer op_list_buf = TF_GetOpList(lib); + status = TF_NewStatus(); + auto* api_def_map = TF_NewApiDefMap(&op_list_buf, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + + string op_name = "TestCApi"; + status = TF_NewStatus(); + auto* api_def_buf = + TF_ApiDefMapGet(api_def_map, op_name.c_str(), op_name.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + + tensorflow::ApiDef api_def; + EXPECT_TRUE(api_def.ParseFromArray(api_def_buf->data, api_def_buf->length)); + EXPECT_EQ(op_name, api_def.graph_op_name()); + EXPECT_EQ(R"doc(Used to test C API)doc", api_def.summary()); + + TF_DeleteBuffer(api_def_buf); + TF_DeleteApiDefMap(api_def_map); + TF_DeleteLibraryHandle(lib); +} + +TEST(TestApiDef, TestCreateApiDefWithOverwrites) { + // TODO(b/73318067): Fix linking for the GPU test generated by the + // tf_cuda_cc_test() bazel rule and remove the next line. 
+ if (!GPUDeviceName().empty()) return; + + TF_Status* status = TF_NewStatus(); + TF_Library* lib = + TF_LoadLibrary("tensorflow/c/test_op.so", status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + + TF_Buffer op_list_buf = TF_GetOpList(lib); + status = TF_NewStatus(); + auto* api_def_map = TF_NewApiDefMap(&op_list_buf, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + + string api_def_overwrites = R"(op: < + graph_op_name: "TestCApi" + summary: "New summary" +> +)"; + status = TF_NewStatus(); + TF_ApiDefMapPut(api_def_map, api_def_overwrites.c_str(), + api_def_overwrites.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + + string op_name = "TestCApi"; + status = TF_NewStatus(); + auto* api_def_buf = + TF_ApiDefMapGet(api_def_map, op_name.c_str(), op_name.size(), status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); + + tensorflow::ApiDef api_def; + EXPECT_TRUE(api_def.ParseFromArray(api_def_buf->data, api_def_buf->length)); + EXPECT_EQ(op_name, api_def.graph_op_name()); + EXPECT_EQ("New summary", api_def.summary()); + + TF_DeleteBuffer(api_def_buf); + TF_DeleteApiDefMap(api_def_map); + TF_DeleteLibraryHandle(lib); +} + #undef EXPECT_TF_META } // namespace diff --git a/tensorflow/c/c_test_util.cc b/tensorflow/c/c_test_util.cc index c291a2e440a8515e968b0ce0395b289080f04e8b..a55af46ae2baef1cd4f55f478ec234551f370503 100644 --- a/tensorflow/c/c_test_util.cc +++ b/tensorflow/c/c_test_util.cc @@ -15,11 +15,13 @@ limitations under the License. 
#include "tensorflow/c/c_test_util.h" +#include "tensorflow/compiler/jit/legacy_flags/mark_for_compilation_pass_flags.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/session_options.h" using tensorflow::GraphDef; using tensorflow::NodeDef; @@ -124,8 +126,9 @@ TF_Operation* ScalarConst(double v, TF_Graph* graph, TF_Status* s, return Const(tensor.get(), graph, s, name); } -void AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s, - const char* name, TF_Operation** op, bool check) { +void AddOpHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph, + TF_Status* s, const char* name, TF_Operation** op, + bool check) { TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name); TF_Output add_inputs[2] = {{l, 0}, {r, 0}}; TF_AddInputList(desc, add_inputs, 2); @@ -139,14 +142,14 @@ void AddHelper(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s, TF_Operation* Add(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s, const char* name) { TF_Operation* op; - AddHelper(l, r, graph, s, name, &op, true); + AddOpHelper(l, r, graph, s, name, &op, true); return op; } TF_Operation* AddNoCheck(TF_Operation* l, TF_Operation* r, TF_Graph* graph, TF_Status* s, const char* name) { TF_Operation* op; - AddHelper(l, r, graph, s, name, &op, false); + AddOpHelper(l, r, graph, s, name, &op, false); return op; } @@ -160,6 +163,36 @@ TF_Operation* AddWithCtrlDependency(TF_Operation* l, TF_Operation* r, return TF_FinishOperation(desc, s); } +// If `op_device` is non-empty, set the created op on that device. 
+void BinaryOpHelper(const char* op_name, TF_Operation* l, TF_Operation* r, + TF_Graph* graph, TF_Status* s, const char* name, + TF_Operation** op, const string& op_device, bool check) { + TF_OperationDescription* desc = TF_NewOperation(graph, op_name, name); + if (!op_device.empty()) { + TF_SetDevice(desc, op_device.c_str()); + } + TF_AddInput(desc, {l, 0}); + TF_AddInput(desc, {r, 0}); + *op = TF_FinishOperation(desc, s); + if (check) { + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + ASSERT_NE(*op, nullptr); + } +} + +TF_Operation* MinWithDevice(TF_Operation* l, TF_Operation* r, TF_Graph* graph, + const string& op_device, TF_Status* s, + const char* name) { + TF_Operation* op; + BinaryOpHelper("Min", l, r, graph, s, name, &op, op_device, true); + return op; +} + +TF_Operation* Min(TF_Operation* l, TF_Operation* r, TF_Graph* graph, + TF_Status* s, const char* name) { + return MinWithDevice(l, r, graph, /*op_device=*/"", s, name); +} + TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s, const char* name) { TF_OperationDescription* desc = TF_NewOperation(graph, "AddN", name); @@ -193,6 +226,15 @@ TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph, return TF_FinishOperation(desc, s); } +TF_Operation* RandomUniform(TF_Operation* shape, TF_DataType dtype, + TF_Graph* graph, TF_Status* s) { + TF_OperationDescription* desc = + TF_NewOperation(graph, "RandomUniform", "random_uniform"); + TF_AddInput(desc, {shape, 0}); + TF_SetAttrType(desc, "dtype", dtype); + return TF_FinishOperation(desc, s); +} + void Split3Helper(TF_Operation* input, TF_Graph* graph, TF_Status* s, const char* name, TF_Operation** op) { TF_Operation* zero = ScalarConst( @@ -360,8 +402,21 @@ std::vector GetFuncNames(const tensorflow::GraphDef& graph_def) { return names; } -CSession::CSession(TF_Graph* graph, TF_Status* s) { +CSession::CSession(TF_Graph* graph, TF_Status* s, bool use_XLA) { TF_SessionOptions* opts = TF_NewSessionOptions(); + 
tensorflow::legacy_flags::MarkForCompilationPassFlags* flags = + tensorflow::legacy_flags::GetMarkForCompilationPassFlags(); + flags->tf_xla_cpu_global_jit = use_XLA; + if (use_XLA) { + tensorflow::ConfigProto config; + config.mutable_graph_options() + ->mutable_optimizer_options() + ->set_global_jit_level(tensorflow::OptimizerOptions::ON_1); + std::string contents; + contents.resize(config.ByteSizeLong()); + config.SerializeToArray(&contents[0], contents.size()); + TF_SetConfig(opts, contents.data(), contents.size(), s); + } session_ = TF_NewSession(graph, opts, s); TF_DeleteSessionOptions(opts); } diff --git a/tensorflow/c/c_test_util.h b/tensorflow/c/c_test_util.h index d54733749248fa32c39d88bb0281d329dd50c7bd..2a70177c724c569844a5d8ad42b99bed20209946 100644 --- a/tensorflow/c/c_test_util.h +++ b/tensorflow/c/c_test_util.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_C_C_TEST_UTIL_H_ -#define THIRD_PARTY_TENSORFLOW_C_C_TEST_UTIL_H_ +#ifndef TENSORFLOW_C_C_TEST_UTIL_H_ +#define TENSORFLOW_C_C_TEST_UTIL_H_ #include "tensorflow/c/c_api.h" @@ -69,12 +69,23 @@ TF_Operation* AddWithCtrlDependency(TF_Operation* l, TF_Operation* r, TF_Operation* Add(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s, const char* name = "add"); +TF_Operation* Min(TF_Operation* l, TF_Operation* r, TF_Graph* graph, + TF_Status* s, const char* name = "min"); + +// If `op_device` is non-empty, set the created op on that device. 
+TF_Operation* MinWithDevice(TF_Operation* l, TF_Operation* r, TF_Graph* graph, + const string& op_device, TF_Status* s, + const char* name = "min"); + TF_Operation* Neg(TF_Operation* n, TF_Graph* graph, TF_Status* s, const char* name = "neg"); TF_Operation* LessThan(TF_Output l, TF_Output r, TF_Graph* graph, TF_Status* s); -// Split `input` along the first dimention into 3 tensors +TF_Operation* RandomUniform(TF_Operation* shape, TF_DataType dtype, + TF_Graph* graph, TF_Status* s); + +// Split `input` along the first dimension into 3 tensors TF_Operation* Split3(TF_Operation* input, TF_Graph* graph, TF_Status* s, const char* name = "split3"); @@ -105,7 +116,7 @@ std::vector GetFuncNames(const tensorflow::GraphDef& graph_def); class CSession { public: - CSession(TF_Graph* graph, TF_Status* s); + CSession(TF_Graph* graph, TF_Status* s, bool use_XLA = false); explicit CSession(TF_Session* session); ~CSession(); @@ -121,6 +132,8 @@ class CSession { TF_Tensor* output_tensor(int i) { return output_values_[i]; } + TF_Session* mutable_session() { return session_; } + private: void DeleteInputValues(); void ResetOutputValues(); @@ -133,4 +146,4 @@ class CSession { std::vector targets_; }; -#endif // THIRD_PARTY_TENSORFLOW_C_C_TEST_UTIL_H_ +#endif // TENSORFLOW_C_C_TEST_UTIL_H_ diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index d533758e360bc44a6f52f57eaae5b222e0482860..e55cb672e97e1403a3dd864c91c176426eb3f067 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -6,6 +6,7 @@ load( "tf_cuda_cc_test", "tf_cc_test", "tf_copts", + "tfe_xla_copts", "tf_cuda_library", ) @@ -16,7 +17,7 @@ tf_cuda_library( "c_api_internal.h", ], hdrs = ["c_api.h"], - copts = tf_copts(), + copts = tf_copts() + tfe_xla_copts(), visibility = ["//visibility:public"], deps = select({ "//tensorflow:android": [ @@ -33,7 +34,15 @@ tf_cuda_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", ], - }), + }) + select({ + 
"//tensorflow:with_xla_support": [ + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/jit", + ], + "//conditions:default": [], + }) + [ + "//tensorflow/core:gpu_runtime", + ], ) tf_cuda_library( @@ -46,6 +55,7 @@ tf_cuda_library( "//tensorflow/c:c_api", "//tensorflow/c:c_api_internal", "//tensorflow/core:core_cpu_lib", + "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:framework_lite", "//tensorflow/core:lib_internal", @@ -55,8 +65,14 @@ tf_cuda_library( tf_cuda_cc_test( name = "c_api_test", srcs = ["c_api_test.cc"], + extra_copts = tfe_xla_copts(), + tags = [ + "guitar", + "multi_gpu", + ], deps = [ ":c_api", + "//tensorflow/c:c_test_util", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", @@ -113,3 +129,9 @@ cc_library( "//tensorflow/core:lib", ], ) + +filegroup( + name = "headers", + srcs = ["c_api.h"], + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 706c89536db019c7f7389af576815746b2425520..8e834eb99c13d1f26da9f0860897267efc2fd01c 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -25,6 +25,10 @@ limitations under the License. #include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/runtime.h" +#ifdef TENSORFLOW_EAGER_USE_XLA +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#endif // TENSORFLOW_EAGER_USE_XLA +#include "tensorflow/core/common_runtime/copy_tensor.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/function.h" @@ -33,6 +37,7 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/refcount.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/gtl/stl_util.h" #include "tensorflow/core/platform/mutex.h" @@ -43,13 +48,23 @@ using tensorflow::int64; using tensorflow::string; namespace { -bool IsCPU(tensorflow::Device* d) { +bool IsCPU(const tensorflow::Device* d) { return d == nullptr || d->tensorflow_gpu_device_info() == nullptr; } -string DeviceName(tensorflow::Device* d) { +bool IsXLA(const tensorflow::Device* d) { + if (d == nullptr) return false; + const auto& device_type = d->attributes().device_type(); + return device_type.find("XLA") != std::string::npos; +} + +string DeviceName(const tensorflow::Device* d) { return (d == nullptr) ? "cpu:0" : d->name(); } + +#ifdef TENSORFLOW_EAGER_USE_XLA +std::atomic_int_fast64_t func_id_generator(0); +#endif // TENSORFLOW_EAGER_USE_XLA } // namespace extern "C" { @@ -84,20 +99,15 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { return nullptr; } - TFE_Context* ret = new TFE_Context(session); - ret->policy = opts->policy; - ret->pflr.reset(new tensorflow::ProcessFunctionLibraryRuntime( - ret->session->device_mgr, opts->session_options.options.env, - TF_GRAPH_DEF_VERSION, &ret->func_lib_def, {})); - ret->rendezvous = - new tensorflow::IntraProcessRendezvous(ret->session->device_mgr); - - return ret; + return new TFE_Context(*opts, session); } void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) { status->status = tensorflow::Status::OK(); - tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache); + { + tensorflow::mutex_lock ml(ctx->cache_mu); + tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache); + } TF_Graph* graph = ctx->session->graph; TF_DeleteSession(ctx->session, status); TF_DeleteGraph(graph); @@ -109,6 +119,28 @@ TF_DeviceList* 
TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) { return TF_SessionListDevices(ctx->session, status); } +void TFE_ContextClearCaches(TFE_Context* ctx) { + tensorflow::mutex_lock ml(ctx->cache_mu); + tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache); +} + +void TFE_ContextSetThreadLocalDevicePlacementPolicy( + TFE_Context* ctx, TFE_ContextDevicePlacementPolicy policy) { + tensorflow::mutex_lock ml(ctx->policy_map_mu); + ctx->thread_local_policies[std::this_thread::get_id()] = policy; +} + +extern TFE_ContextDevicePlacementPolicy TFE_ContextGetDevicePlacementPolicy( + TFE_Context* ctx) { + tensorflow::mutex_lock ml(ctx->policy_map_mu); + auto policy_map_it = + ctx->thread_local_policies.find(std::this_thread::get_id()); + if (policy_map_it != ctx->thread_local_policies.end()) { + return policy_map_it->second; + } + return ctx->policy; +} + TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) { tensorflow::Tensor tensor; status->status = tensorflow::TF_TensorToTensor(t, &tensor); @@ -164,23 +196,16 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, bool is_same_device = (srcd == dstd) || (DeviceName(srcd) == DeviceName(dstd)); const bool dst_cpu = IsCPU(dstd); - if (is_same_device) { - return new TFE_TensorHandle(h->t, dst_cpu ? nullptr : dstd); - } const bool src_cpu = IsCPU(srcd); - if (src_cpu == dst_cpu) { - TF_SetStatus( - status, TF_INVALID_ARGUMENT, - tensorflow::strings::StrCat( - "TFE_TensorHandleCopyToDevice requires either the source " - "TFE_TensorHandle be on or the destination device be on CPU " - "or be the same (they are ", - DeviceName(srcd), " and ", DeviceName(dstd), " in this call)") - .c_str()); - return nullptr; + // both_on_cpu can be true and yet is_same_device is false, if one of src/dst + // has device type XLA_CPU, and the other CPU. + const bool both_on_cpu = src_cpu && dst_cpu; + if (is_same_device || both_on_cpu) { + return new TFE_TensorHandle(h->t, dst_cpu ? 
nullptr : dstd); } tensorflow::Tensor* src = &(h->t); - if (!dst_cpu && !tensorflow::DataTypeCanUseMemcpy(src->dtype())) { + if (!dst_cpu && (src->dtype() != tensorflow::DT_VARIANT && + !tensorflow::DataTypeCanUseMemcpy(src->dtype()))) { TF_SetStatus( status, TF_INVALID_ARGUMENT, tensorflow::strings::StrCat("Can't copy Tensor with type ", @@ -189,26 +214,22 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, .c_str()); return nullptr; } - if (src_cpu) { - tensorflow::Tensor dst( - dstd->GetAllocator(tensorflow::AllocatorAttributes()), src->dtype(), - src->shape()); - if (src->shape().num_elements() == 0) { - return new TFE_TensorHandle(dst, dstd); - } - tensorflow::Notification n; - dstd->tensorflow_gpu_device_info()->default_context->CopyCPUTensorToDevice( - src, dstd, &dst, [status, &n](const tensorflow::Status& s) { - status->status = s; - n.Notify(); - }); - n.WaitForNotification(); - return (TF_GetCode(status) == TF_OK) ? new TFE_TensorHandle(dst, dstd) - : nullptr; - } - CHECK(dst_cpu); - tensorflow::Tensor dst(src->dtype(), src->shape()); - tensorflow::Notification n; + tensorflow::AllocatorAttributes attr; + if (src->dtype() == tensorflow::DT_VARIANT) { + attr.set_on_host(true); + } + tensorflow::Tensor dst(dstd->GetAllocator(attr), src->dtype(), src->shape()); + if (src->shape().num_elements() == 0) { + return new TFE_TensorHandle(dst, dst_cpu ? nullptr : dstd); + } + tensorflow::DeviceContext* src_device_context = nullptr; + if (!src_cpu) { + src_device_context = srcd->tensorflow_gpu_device_info()->default_context; + } + tensorflow::DeviceContext* dst_device_context = nullptr; + if (!dst_cpu) { + dst_device_context = dstd->tensorflow_gpu_device_info()->default_context; + } // TODO(ashankar): The Sync() call below may be more aggressive than // necessary. 
It is based on knowledge of implementation details - that // GPU devices are implemented using 3 streams - one for host->device copies, @@ -217,16 +238,18 @@ TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, // but more than necessary (since it waits for operations that might have // nothing to do with this tensor to complete). status->status = srcd->Sync(); - if (!status->status.ok()) return nullptr; - srcd->tensorflow_gpu_device_info()->default_context->CopyDeviceTensorToCPU( - src, "IGNORE_MY_TENSOR_NAME", srcd, &dst, - [status, &n](const tensorflow::Status& s) { - status->status = s; - n.Notify(); - }); + tensorflow::Notification n; + tensorflow::CopyTensor::ViaDMA("copy", src_device_context, dst_device_context, + srcd, dstd, tensorflow::AllocatorAttributes(), + tensorflow::AllocatorAttributes(), src, &dst, + [status, &n](const tensorflow::Status& s) { + status->status = s; + n.Notify(); + }); n.WaitForNotification(); - return (TF_GetCode(status) == TF_OK) ? new TFE_TensorHandle(dst, nullptr) - : nullptr; + return (TF_GetCode(status) == TF_OK) + ? new TFE_TensorHandle(dst, dst_cpu ? nullptr : dstd) + : nullptr; } TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, @@ -247,15 +270,6 @@ TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, void TFE_DeleteOp(TFE_Op* op) { delete op; } -static void TFE_OpSetDeviceHelper(TFE_Op* op, tensorflow::Device* device, - TF_Status* status) { - // Questionable heuristic: Place the op on the same device as the first input - // placed outside of host memory? 
- if (IsCPU(op->device) && !IsCPU(device)) { - op->device = device; - } -} - void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) { tensorflow::Device* d = nullptr; if (device_name != nullptr && strlen(device_name) > 0) { @@ -263,11 +277,32 @@ void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status) { op->ctx->session->device_mgr->LookupDevice(device_name, &d); if (!status->status.ok()) return; } - TFE_OpSetDeviceHelper(op, d, status); + op->device = d; +} + +const char* TFE_OpGetDevice(TFE_Op* op, TF_Status* status) { + tensorflow::Device* device = + (op->device == nullptr) ? op->ctx->devices()[0] : op->device; + return device->name().c_str(); +} + +void TFE_OpSetXLACompilation(TFE_Op* op, unsigned char enable) { + op->use_xla = enable; +#ifndef TENSORFLOW_EAGER_USE_XLA + LOG(WARNING) << "This call is a no-op, as the TensorFlow library is not " + "built with XLA support."; +#endif // TENSORFLOW_EAGER_USE_XLA } void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status) { - TFE_OpSetDeviceHelper(op, h->d, status); + // Questionable heuristic ... + // + // Motivation: After an 'op' is placed on GPU because some of its earlier + // inputs are on GPU, we want to keep the 'op' there, even if some later + // inputs of it are not on GPU. + if (IsCPU(op->device) && !IsCPU(h->d)) { + op->device = h->d; + } if (!status->status.ok()) return; op->inputs.push_back(h->t); op->input_devices.push_back(h->d); @@ -284,7 +319,7 @@ TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, return TF_ATTR_INT; // The compiler requires that we return something. 
} status->status = - tensorflow::AttrTypeByName(op->attr_types, attr_name, &ret, is_list); + tensorflow::AttrTypeByName(*op->attr_types, attr_name, &ret, is_list); return ret; } @@ -420,6 +455,19 @@ void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name, proto.get(), num_values)); } +void TFE_OpSetAttrFunctionList(TFE_Op* op, const char* attr_name, + const TFE_Op** value, int num_values) { + std::unique_ptr funcs( + new tensorflow::NameAttrList[num_values]); + for (int i = 0; i < num_values; i++) { + funcs[i].set_name(value[i]->name); + value[i]->attrs.FillAttrValueMap(funcs[i].mutable_attr()); + } + op->attrs.Set(attr_name, + tensorflow::gtl::ArraySlice( + funcs.get(), num_values)); +} + namespace { tensorflow::Status ValidateInputTypeAndPlacement( @@ -438,10 +486,17 @@ tensorflow::Status ValidateInputTypeAndPlacement( const tensorflow::Device* actual_device = op->input_devices[i] == nullptr ? host_device : op->input_devices[i]; if (expected_device != actual_device) { - switch (ctx->policy) { - case TFE_DEVICE_PLACEMENT_EXPLICIT: + switch (TFE_ContextGetDevicePlacementPolicy(ctx)) { + case TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32: // TODO(xpan): See if we could bubble python related error up // to python level. + if (op->inputs[i].dtype() == tensorflow::DT_INT32) { + // Note: enabling silent copies of int32 tensors to match behavior + // of graph mode. + break; + } + TF_FALLTHROUGH_INTENDED; + case TFE_DEVICE_PLACEMENT_EXPLICIT: return tensorflow::errors::InvalidArgument( "Tensors on conflicting devices:" " cannot compute ", @@ -494,6 +549,228 @@ tensorflow::Status ValidateInputTypeAndPlacement( } return tensorflow::Status::OK(); } + +#ifdef TENSORFLOW_EAGER_USE_XLA +// Synthesizes and returns a wrapper function over `op`, which must be a +// primitive op (e.g. matmul). +// +// The wrapper function conforms to the function signature expected by +// _XlaLaunchOp, with input params ordered by . 
For example, if the op has input params , they will be reordered to as the input params to the synthesized function. +// +// It populates `const_input_types`, `arg_input_types` and +// `op_input_to_func_input` based on the reordering results, that the caller can +// use them to build an _XlaLaunchOp. On error, it returns NULL, and sets +// `status` accordingly. +const tensorflow::FunctionDef* OpToFunction( + TFE_Op* op, std::vector* const_input_types, + std::vector* arg_input_types, + tensorflow::gtl::FlatMap* op_input_to_func_input, + TF_Status* status) { + DCHECK(!op->is_function()); + + tensorflow::FunctionDef fdef; + + // Get the OpDef of the op we are trying to encapsulate. + TFE_Context* ctx = op->ctx; + const tensorflow::OpRegistrationData* op_data; + { + tensorflow::tf_shared_lock l(ctx->functions_mu); + status->status = ctx->func_lib_def.LookUp(op->name, &op_data); + if (!status->status.ok()) { + return nullptr; + } + } + const tensorflow::OpDef& op_def = op_data->op_def; + + tensorflow::OpDef* signature = fdef.mutable_signature(); + + // Handle constant inputs. + const std::unordered_set const_inputs( + *tensorflow::XlaOpRegistry::CompileTimeConstantInputs(op->name)); + + // First add place holders for the input args, so that we can refer to them by + // position in the next loop. Also tally up the resource inputs. + int num_resource_inputs = 0; + for (int i = 0; i < op_def.input_arg_size(); ++i) { + if (op_def.input_arg(i).type() == tensorflow::DT_RESOURCE) { + ++num_resource_inputs; + } + signature->add_input_arg(); + } + + // Now we map the input params from `op_def` to `signature`, where the param + // ordering for `signature` is: . 
+ int const_index = 0; + int arg_index = const_inputs.size(); + int resource_index = op_def.input_arg_size() - num_resource_inputs; + for (int i = 0; i < op_def.input_arg_size(); ++i) { + const tensorflow::OpDef::ArgDef& op_input_arg = op_def.input_arg(i); + tensorflow::OpDef::ArgDef* func_input_arg = nullptr; + if (const_inputs.find(op_input_arg.name()) != const_inputs.end()) { + VLOG(1) << "For const input, mapping op input " << i << " to func input " + << const_index; + (*op_input_to_func_input)[i] = const_index; + func_input_arg = signature->mutable_input_arg(const_index++); + const_input_types->push_back( + static_cast(op->inputs[i].dtype())); + } else if (op_input_arg.type() == tensorflow::DT_RESOURCE) { + VLOG(1) << "For resource input, mapping op input " << i + << " to func input " << resource_index; + (*op_input_to_func_input)[i] = resource_index; + func_input_arg = signature->mutable_input_arg(resource_index++); + } else { + VLOG(1) << "For arg input, mapping op input " << i << " to func input " + << arg_index; + (*op_input_to_func_input)[i] = arg_index; + func_input_arg = signature->mutable_input_arg(arg_index++); + arg_input_types->push_back( + static_cast(op->inputs[i].dtype())); + } + + func_input_arg->set_name(op_input_arg.name()); + func_input_arg->set_type(op->inputs[i].dtype()); + } + VLOG(1) << "Added OpDef Inputs: " << fdef.DebugString(); + + // Resources args are at the end of the function input params, and we should + // have iterated over all of them. + DCHECK_EQ(signature->input_arg_size(), resource_index); + + // Make the synthesized function's name unique. + signature->set_name(tensorflow::strings::StrCat( + op_def.name(), func_id_generator.fetch_add(1))); + + // Add the node def and set its input names to match op_def's names. 
+ const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef(); + DCHECK_EQ(signature->input_arg_size(), ndef.input_size()); + *fdef.add_node_def() = ndef; + for (int i = 0; i < op_def.input_arg_size(); ++i) { + fdef.mutable_node_def(0)->set_input(i, op_def.input_arg(i).name()); + } + VLOG(1) << "Added NodeDef: " << fdef.DebugString(); + + // Fix the output names and set output types. + for (int i = 0; i < op_def.output_arg_size(); ++i) { + tensorflow::OpDef::ArgDef* arg = signature->add_output_arg(); + const tensorflow::OpDef::ArgDef& op_def_arg = op_def.output_arg(i); + const string& out_tensor_name = tensorflow::strings::StrCat( + ndef.name(), ":", op_def_arg.name(), ":", 0); + arg->set_name(op_def_arg.name()); + (*fdef.mutable_ret())[op_def_arg.name()] = out_tensor_name; + const string& type_attr = op_def_arg.type_attr(); + if (!type_attr.empty()) { + auto i = ndef.attr().find(type_attr); + if (i == ndef.attr().end()) { + status->status = tensorflow::errors::InvalidArgument( + tensorflow::strings::StrCat("Could not find attr ", type_attr, + " in NodeDef ", ndef.DebugString())); + return nullptr; + } + arg->set_type(i->second.type()); + } + } + VLOG(1) << "Fixed Output names and all types: " << fdef.DebugString(); + + tensorflow::mutex_lock l(ctx->functions_mu); + status->status = ctx->func_lib_def.AddFunctionDef(fdef); + if (!status->status.ok()) return nullptr; + const auto ret = ctx->func_lib_def.Find(signature->name()); + DCHECK(ret != nullptr); + return ret; +} + +// Builds an _XLALaunchOp as a wrapper over 'op', so that 'op' can be executed +// via XLA. 
+std::unique_ptr BuildXlaLaunch(TFE_Op* op, TF_Status* status) { + VLOG(1) << "Creating _XlaLaunchOp for TFE_Op " << op->name; + auto launch_op = + std::unique_ptr(TFE_NewOp(op->ctx, "_XlaLaunch", status)); + if (TF_GetCode(status) != TF_OK) return nullptr; + if (op->device) { + TFE_OpSetDevice(launch_op.get(), op->device->name().c_str(), status); + if (TF_GetCode(status) != TF_OK) return nullptr; + } + + const tensorflow::FunctionDef* fdef; + { + tensorflow::tf_shared_lock l(op->ctx->functions_mu); + fdef = op->ctx->func_lib_def.Find(op->name); + } + std::vector const_input_types; + std::vector arg_input_types; + tensorflow::gtl::FlatMap op_input_to_func_input; + if (fdef == nullptr) { + // See if this is a primitive op, and if so create a function for it, so + // that _XlaLaunchOp can access it. + fdef = OpToFunction(op, &const_input_types, &arg_input_types, + &op_input_to_func_input, status); + if (!status->status.ok()) return nullptr; + } else { + // TODO(hongm): XlaOpRegistry::CompileTimeConstantInputs() does not work for + // functions, so we need to find another way to handle constant inputs. + for (int i = const_input_types.size(); + i < fdef->signature().input_arg_size(); ++i) { + VLOG(1) << "Adding Targs from input arg " << i; + const tensorflow::OpDef::ArgDef& arg = fdef->signature().input_arg(i); + arg_input_types.push_back(static_cast(arg.type())); + } + } + DCHECK(fdef != nullptr); + + // Copy inputs and their devices. + // Since input param reordering may have occurred between `op` and `launch_op` + // via `op_input_to_func_input`, adjust the actual inputs accordingly. 
+ launch_op->inputs = op->inputs; + launch_op->input_devices = op->input_devices; + if (!op_input_to_func_input.empty()) { + DCHECK_EQ(op->inputs.size(), op_input_to_func_input.size()); + if (!op->input_devices.empty()) { + DCHECK_EQ(op->input_devices.size(), op_input_to_func_input.size()); + } + for (int i = 0; i < op_input_to_func_input.size(); ++i) { + VLOG(1) << "mapping op input " << i << " to func input " + << op_input_to_func_input[i]; + + launch_op->inputs[op_input_to_func_input[i]] = op->inputs[i]; + if (!op->input_devices.empty()) { + launch_op->input_devices[op_input_to_func_input[i]] = + op->input_devices[i]; + } + } + } + launch_op->attrs.NumInputs(op->inputs.size()); + + TFE_OpSetAttrTypeList(launch_op.get(), "Tconstants", const_input_types.data(), + const_input_types.size()); + + // Set Targs and Nresources attrs. + TFE_OpSetAttrTypeList(launch_op.get(), "Targs", arg_input_types.data(), + arg_input_types.size()); + const int num_resource_inputs = fdef->signature().input_arg_size() - + const_input_types.size() - + arg_input_types.size(); + TFE_OpSetAttrInt(launch_op.get(), "Nresources", num_resource_inputs); + + // Set Tresults attr. + std::vector tresults; + for (const tensorflow::OpDef::ArgDef& arg : fdef->signature().output_arg()) { + tresults.push_back(static_cast(arg.type())); + } + TFE_OpSetAttrTypeList(launch_op.get(), "Tresults", tresults.data(), + tresults.size()); + + // Set function attr. + tensorflow::AttrValue attr_value; + tensorflow::NameAttrList* func = attr_value.mutable_func(); + func->set_name(fdef->signature().name()); + launch_op->attrs.Set("function", attr_value); + + return launch_op; +} +#endif // TENSORFLOW_EAGER_USE_XLA } // namespace void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, @@ -502,11 +779,26 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, // TODO(ashankar): ASSUMPTION: ctx->devices()[0] is always CPU tensorflow::Device* device = (op->device == nullptr) ? 
ctx->devices()[0] : op->device; + +#ifdef TENSORFLOW_EAGER_USE_XLA + std::unique_ptr xla_launch_op; + if (op->use_xla && op->name != "_XlaLaunch") { + xla_launch_op = BuildXlaLaunch(op, status); + if (!status->status.ok()) { + return; + } + op = xla_launch_op.get(); + } +#endif // TENSORFLOW_EAGER_USE_XLA + std::vector outputs(1); const tensorflow::MemoryTypeVector* output_memory_types = nullptr; tensorflow::Fprint128 cache_key = op->attrs.CacheKey(device->name()); - tensorflow::KernelAndDevice* kernel = - tensorflow::gtl::FindPtrOrNull(ctx->kernel_cache, cache_key); + tensorflow::KernelAndDevice* kernel; + { + tensorflow::tf_shared_lock l(ctx->cache_mu); + kernel = tensorflow::gtl::FindPtrOrNull(ctx->kernel_cache, cache_key); + } if (kernel == nullptr) { const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef(); kernel = new tensorflow::KernelAndDevice(ctx->rendezvous); @@ -522,6 +814,7 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, delete kernel; return; } + tensorflow::mutex_lock ml(ctx->cache_mu); tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel); } std::vector copied_tensors; @@ -534,19 +827,54 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, } return; } + std::unique_ptr maybe_stats; + if (ctx->should_store_metadata.load()) { + maybe_stats.reset(new tensorflow::NodeExecStats); + maybe_stats->set_node_name(op->name); + maybe_stats->set_all_start_micros(tensorflow::Env::Default()->NowMicros()); + maybe_stats->set_op_start_rel_micros(0); + maybe_stats->set_scheduled_micros(tensorflow::Env::Default()->NowMicros()); + // TODO(apassos) track referenced tensors + } // WARNING: kernel->Run utilizes the FunctionLibraryRuntime // (ctx->func_lib(device)), which in turn holds a pointer to func_lib_def, // which is GUARDED_BY(ctx->functions_mu). 
But knowledge of the implementation - // of FunctionLibraryRuntime tells use that func_lib_def is not accessed by + // of FunctionLibraryRuntime tells us that func_lib_def is not accessed by // FunctionLibraryRuntime::Run(), so there is no thread-safety concern here. // This is quite subtle. Re-work things to make this better? (Would it make // sense for FunctionLibraryRuntime to ensure thread-safe access to - // FunctionLibraryDefinition?). - status->status = kernel->Run(&op->inputs, &outputs); + // FunctionLibraryDefinition?). TODO(apassos) figure out how to record stats + // for ops which are a part of functions. + status->status = kernel->Run(&op->inputs, &outputs, maybe_stats.get()); for (auto* t : copied_tensors) { TFE_DeleteTensorHandle(t); } if (!status->status.ok()) return; + if (maybe_stats != nullptr) { + maybe_stats->set_op_end_rel_micros(tensorflow::Env::Default()->NowMicros() - + maybe_stats->all_start_micros()); + tensorflow::mutex_lock ml(ctx->metadata_mu); + if (ctx->should_store_metadata.load()) { + auto* step_stats = ctx->run_metadata.mutable_step_stats(); + // Lazily initialize the RunMetadata with information about all devices if + // this is the first call. + while (step_stats->dev_stats_size() < ctx->devices().size()) { + step_stats->add_dev_stats(); + } + // Find the current device's index. + int device_idx = 0; + for (int i = 0; i < ctx->devices().size(); ++i) { + if (ctx->devices()[i] == device) { + device_idx = i; + break; + } + } + // Populate the device stats for this device. + auto* dev_stats = step_stats->mutable_dev_stats(device_idx); + dev_stats->set_device(device->name()); + *dev_stats->add_node_stats() = *maybe_stats; + } + } *num_retvals = std::min(*num_retvals, outputs.size()); for (int i = 0; i < *num_retvals; ++i) { tensorflow::Device* d = IsCPU(device) ? 
nullptr : device; @@ -593,3 +921,20 @@ const tensorflow::Tensor* TFE_TensorHandleUnderlyingTensorInHostMemory( } return &h->t; } + +void TFE_ContextEnableRunMetadata(TFE_Context* ctx) { + ctx->should_store_metadata.store(true); +} + +void TFE_ContextDisableRunMetadata(TFE_Context* ctx) { + tensorflow::mutex_lock ml(ctx->metadata_mu); + ctx->should_store_metadata.store(false); + ctx->run_metadata.Clear(); +} + +void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf, + TF_Status* status) { + tensorflow::mutex_lock ml(ctx->metadata_mu); + status->status = MessageToBuffer(ctx->run_metadata, buf); + ctx->run_metadata.Clear(); +} diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index ca105962df0d6655946304159937621022e7fcba..7a321b54da343fd2b8912187bc620c1e7456db0c 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -17,6 +17,8 @@ limitations under the License. #define TENSORFLOW_C_EAGER_C_API_H_ // C API extensions to experiment with eager execution of kernels. +// WARNING: Unlike tensorflow/c/c_api.h, the API here is not guaranteed to be +// stable and can change without notice. #include "tensorflow/c/c_api.h" @@ -59,14 +61,16 @@ TF_CAPI_EXPORT extern void TFE_ContextOptionsSetConfig( // Controls how to act when we try to run an operation on a given device but // some input tensors are not on that device. typedef enum TFE_ContextDevicePlacementPolicy { - // The default: running operations with input tensors on the wrong device will - // fail. + // Running operations with input tensors on the wrong device will fail. TFE_DEVICE_PLACEMENT_EXPLICIT = 0, // Copy the tensor to the right device but log a warning. TFE_DEVICE_PLACEMENT_WARN = 1, // Silently copy the tensor, which has a performance cost since the // operation will be blocked till the copy completes. TFE_DEVICE_PLACEMENT_SILENT = 2, + // Default placement policy which silently copies int32 tensors but not other + // dtypes. 
+ TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32 = 3, } TFE_ContextDevicePlacementPolicy; TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy( @@ -83,10 +87,27 @@ typedef struct TFE_Context TFE_Context; TF_CAPI_EXPORT extern TFE_Context* TFE_NewContext( const TFE_ContextOptions* opts, TF_Status* status); -TF_CAPI_EXPORT extern void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status); +TF_CAPI_EXPORT extern void TFE_DeleteContext(TFE_Context* ctx, + TF_Status* status); TF_CAPI_EXPORT extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status); +// Clears the internal caches in the TFE context. Useful when reseeding random +// ops. +TF_CAPI_EXPORT extern void TFE_ContextClearCaches(TFE_Context* ctx); + +// Sets a thread-local device placement policy. After this call, other calls to +// TFE_Execute in the same thread will use the device policy specified here +// instead of the device policy used to construct the context. This has no +// effect on the device policy used by other program threads. +TF_CAPI_EXPORT extern void TFE_ContextSetThreadLocalDevicePlacementPolicy( + TFE_Context*, TFE_ContextDevicePlacementPolicy); + +// Returns the device placement policy to be used by this context in the current +// thread. +TF_CAPI_EXPORT extern TFE_ContextDevicePlacementPolicy +TFE_ContextGetDevicePlacementPolicy(TFE_Context*); + // A handle to a tensor on a device. 
// // Like a TF_Tensor, a TFE_TensorHandle refers to a tensor with a value, shape, @@ -99,8 +120,10 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_CAPI_EXPORT extern void TFE_DeleteTensorHandle(TFE_TensorHandle* h); TF_CAPI_EXPORT extern TF_DataType TFE_TensorHandleDataType(TFE_TensorHandle* h); TF_CAPI_EXPORT extern int TFE_TensorHandleNumDims(TFE_TensorHandle* h); -TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, int dim_index); -TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName(TFE_TensorHandle* h); +TF_CAPI_EXPORT extern int64_t TFE_TensorHandleDim(TFE_TensorHandle* h, + int dim_index); +TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceName( + TFE_TensorHandle* h); TF_CAPI_EXPORT extern TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, TF_Status* status); @@ -110,10 +133,9 @@ TF_CAPI_EXPORT extern TF_Tensor* TFE_TensorHandleResolve(TFE_TensorHandle* h, // that shares the underlying buffer. Otherwise, it currently requires at least // one of the source or destination devices to be CPU (i.e., for the source or // destination tensor to be placed in host memory). -TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorHandle* h, - TFE_Context* ctx, - const char* device_name, - TF_Status* status); +TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopyToDevice( + TFE_TensorHandle* h, TFE_Context* ctx, const char* device_name, + TF_Status* status); // Description of the TensorFlow op to execute. 
// @@ -128,17 +150,31 @@ TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_TensorHandleCopyToDevice(TFE_TensorH // the additional sanity checks there seem unnecessary; typedef struct TFE_Op TFE_Op; -TF_CAPI_EXPORT extern TFE_Op* TFE_NewOp(TFE_Context* ctx, const char* op_or_function_name, +TF_CAPI_EXPORT extern TFE_Op* TFE_NewOp(TFE_Context* ctx, + const char* op_or_function_name, TF_Status* status); TF_CAPI_EXPORT extern void TFE_DeleteOp(TFE_Op* op); TF_CAPI_EXPORT extern void TFE_OpSetDevice(TFE_Op* op, const char* device_name, TF_Status* status); +// The returned string remains valid throughout the lifetime of 'op'. +TF_CAPI_EXPORT extern const char* TFE_OpGetDevice(TFE_Op* op, + TF_Status* status); + +// When 'enable' is set to 1, and if TensorFlow library is built with XLA +// support, a subsequent TFE_Execute() call on `op` will run the op via XLA. +// +// If the library is not built with XLA support, this call would be a no-op. +TF_CAPI_EXPORT extern void TFE_OpSetXLACompilation(TFE_Op* op, + unsigned char enable); -TF_CAPI_EXPORT extern void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_Status* status); +TF_CAPI_EXPORT extern void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, + TF_Status* status); -TF_CAPI_EXPORT extern TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, - unsigned char* is_list, TF_Status* status); +TF_CAPI_EXPORT extern TF_AttrType TFE_OpGetAttrType(TFE_Op* op, + const char* attr_name, + unsigned char* is_list, + TF_Status* status); // Get an attribute type given an op name; a fusion of TFE_NewOp and // TFE_OpGetAttrType for use from Python without the overhead of the individual // calls and memory management of TFE_Op. 
@@ -146,10 +182,13 @@ TF_CAPI_EXPORT extern TF_AttrType TFE_OpNameGetAttrType( TFE_Context* ctx, const char* op_or_function_name, const char* attr_name, unsigned char* is_list, TF_Status* status); -TF_CAPI_EXPORT extern void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, +TF_CAPI_EXPORT extern void TFE_OpSetAttrString(TFE_Op* op, + const char* attr_name, const char* value); -TF_CAPI_EXPORT extern void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, int64_t value); -TF_CAPI_EXPORT extern void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, float value); +TF_CAPI_EXPORT extern void TFE_OpSetAttrInt(TFE_Op* op, const char* attr_name, + int64_t value); +TF_CAPI_EXPORT extern void TFE_OpSetAttrFloat(TFE_Op* op, const char* attr_name, + float value); TF_CAPI_EXPORT extern void TFE_OpSetAttrBool(TFE_Op* op, const char* attr_name, unsigned char value); TF_CAPI_EXPORT extern void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, @@ -158,7 +197,8 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrType(TFE_Op* op, const char* attr_name, // -1 and `dims` can be null. If a dimension is unknown, the // corresponding entry in the `dims` array must be -1. TF_CAPI_EXPORT extern void TFE_OpSetAttrShape(TFE_Op* op, const char* attr_name, - const int64_t* dims, const int num_dims, + const int64_t* dims, + const int num_dims, TF_Status* out_status); // Sets the attribute attr_name to be a function specified by 'function'. 
@@ -169,19 +209,33 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name, const TFE_Op* value); -TF_CAPI_EXPORT extern void TFE_OpSetAttrStringList(TFE_Op* op, const char* attr_name, - const char** value, int num_values); -TF_CAPI_EXPORT extern void TFE_OpSetAttrIntList(TFE_Op* op, const char* attr_name, - const int64_t* values, int num_values); -TF_CAPI_EXPORT extern void TFE_OpSetAttrFloatList(TFE_Op* op, const char* attr_name, - const float* values, int num_values); -TF_CAPI_EXPORT extern void TFE_OpSetAttrBoolList(TFE_Op* op, const char* attr_name, - const unsigned char* values, int num_values); -TF_CAPI_EXPORT extern void TFE_OpSetAttrTypeList(TFE_Op* op, const char* attr_name, - const TF_DataType* values, int num_values); -TF_CAPI_EXPORT extern void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name, - const int64_t** dims, const int* num_dims, - int num_values, TF_Status* out_status); +TF_CAPI_EXPORT extern void TFE_OpSetAttrStringList(TFE_Op* op, + const char* attr_name, + const char** value, + int num_values); +TF_CAPI_EXPORT extern void TFE_OpSetAttrIntList(TFE_Op* op, + const char* attr_name, + const int64_t* values, + int num_values); +TF_CAPI_EXPORT extern void TFE_OpSetAttrFloatList(TFE_Op* op, + const char* attr_name, + const float* values, + int num_values); +TF_CAPI_EXPORT extern void TFE_OpSetAttrBoolList(TFE_Op* op, + const char* attr_name, + const unsigned char* values, + int num_values); +TF_CAPI_EXPORT extern void TFE_OpSetAttrTypeList(TFE_Op* op, + const char* attr_name, + const TF_DataType* values, + int num_values); +TF_CAPI_EXPORT extern void TFE_OpSetAttrShapeList( + TFE_Op* op, const char* attr_name, const int64_t** dims, + const int* num_dims, int num_values, TF_Status* out_status); +TF_CAPI_EXPORT extern void TFE_OpSetAttrFunctionList(TFE_Op* op, + const char* attr_name, + const TFE_Op** value, + int num_values); // Execute the operation defined by 'op' and return handles to computed // tensors in 
'retvals'. @@ -196,9 +250,9 @@ TF_CAPI_EXPORT extern void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, // Add a function (serialized FunctionDef protocol buffer) to ctx so // that it can be invoked using TFE_Execute. -TF_CAPI_EXPORT extern void TFE_ContextAddFunctionDef(TFE_Context* ctx, - const char* serialized_function_def, - size_t size, TF_Status* status); +TF_CAPI_EXPORT extern void TFE_ContextAddFunctionDef( + TFE_Context* ctx, const char* serialized_function_def, size_t size, + TF_Status* status); // Adds a function (created from TF_GraphToFunction or // TF_FunctionImportFunctionDef) to the context, allowing it to be executed with @@ -207,6 +261,19 @@ TF_CAPI_EXPORT extern void TFE_ContextAddFunction(TFE_Context* ctx, TF_Function* function, TF_Status* status); +// Enables tracing of RunMetadata on the ops executed from this context. +TF_CAPI_EXPORT extern void TFE_ContextEnableRunMetadata(TFE_Context* ctx); + +// Disables tracing of RunMetadata on the ops executed from this context. +TF_CAPI_EXPORT extern void TFE_ContextDisableRunMetadata(TFE_Context* ctx); + +// Populates the passed-in buffer with a serialized RunMetadata protocol buffer +// containing any run metadata information accumulated so far and clears this +// information. +TF_CAPI_EXPORT extern void TFE_ContextExportRunMetadata(TFE_Context* ctx, + TF_Buffer* buf, + TF_Status* status); + #ifdef __cplusplus } /* end extern "C" */ #endif diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h index 0971e2ab2fe98cc8bf6f631f41d5adce90ee7051..7b9f1db02ed9c53a280c7bd1284165cac4fb6353 100644 --- a/tensorflow/c/eager/c_api_internal.h +++ b/tensorflow/c/eager/c_api_internal.h @@ -21,6 +21,7 @@ limitations under the License. #include #include #include +#include #include #include "tensorflow/c/c_api.h" @@ -34,20 +35,34 @@ limitations under the License. 
#include "tensorflow/core/lib/gtl/stl_util.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/public/version.h" struct TFE_ContextOptions { TF_SessionOptions session_options; - TFE_ContextDevicePlacementPolicy policy{TFE_DEVICE_PLACEMENT_EXPLICIT}; + TFE_ContextDevicePlacementPolicy policy{ + TFE_DEVICE_PLACEMENT_SILENT_FOR_INT32}; }; struct TFE_Context { - explicit TFE_Context(TF_Session* s) : session(s) {} - - TFE_ContextDevicePlacementPolicy policy; + explicit TFE_Context(const TFE_ContextOptions& opts, TF_Session* s) + : policy(opts.policy), + session(s), + rendezvous(new tensorflow::IntraProcessRendezvous(s->device_mgr)), + pflr(new tensorflow::ProcessFunctionLibraryRuntime( + session->device_mgr, opts.session_options.options.env, + TF_GRAPH_DEF_VERSION, &func_lib_def, {})) {} + + const TFE_ContextDevicePlacementPolicy policy; + + // Note: we cannot use C++11 thread_local here as there is no concept of a + // thread-local-object-local variable in C++11. + tensorflow::mutex policy_map_mu; + std::unordered_map + thread_local_policies GUARDED_BY(policy_map_mu); // TFE_Context is an extension of TF_Session. And TF_Session needs a TF_Graph. - TF_Session* session; - tensorflow::Rendezvous* rendezvous; + TF_Session* const session; + tensorflow::Rendezvous* const rendezvous; tensorflow::mutex functions_mu; tensorflow::FunctionLibraryDefinition func_lib_def GUARDED_BY(functions_mu){ @@ -56,17 +71,23 @@ struct TFE_Context { // One FunctionLibraryRuntime per device. // func_libs[i] is the FunctionLibraryRuntime corresponding to // session->devices[i]. 
- std::unique_ptr pflr; + const std::unique_ptr pflr; + tensorflow::mutex cache_mu; std::unordered_map - kernel_cache; + kernel_cache GUARDED_BY(cache_mu); - tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) { + tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) const { return pflr->GetFLR(d->name()); } const std::vector& devices() { return session->devices; } + + // Whether we should compute RunMetadata. + std::atomic should_store_metadata{false}; + tensorflow::mutex metadata_mu; + tensorflow::RunMetadata run_metadata GUARDED_BY(metadata_mu); }; struct TFE_TensorHandle { @@ -86,6 +107,8 @@ struct TFE_TensorHandle { }; struct TFE_Op { + // t is NULL iff the TFE_Op corresponds to a TensorFlow function instead of a + // primitive operation. TFE_Op(TFE_Context* ctx, const char* op, const tensorflow::AttrTypeMap* t) : ctx(ctx), name(op), attrs(op), attr_types(t), device(nullptr) {} @@ -98,6 +121,7 @@ struct TFE_Op { std::vector inputs; std::vector input_devices; tensorflow::Device* device; + bool use_xla = false; }; #endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_ diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 3fe0b7efa11bc619ed98bf9a1634ade5b6ed0a7c..4a3ecbc0abb16296a84c0d2184dc3fc9f7f3ebb4 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/protobuf/config.pb.h" using tensorflow::string; @@ -59,6 +60,63 @@ TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) { return op; } +TFE_TensorHandle* TestAxisTensorHandle() { + int64_t dims[] = {1}; + int data[] = {1}; + TF_Tensor* t = TF_AllocateTensor( + TF_INT32, &dims[0], sizeof(dims) / sizeof(int64_t), sizeof(data)); + memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t)); + TF_Status* status = TF_NewStatus(); + TFE_TensorHandle* th = TFE_NewTensorHandle(t, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteTensor(t); + TF_DeleteStatus(status); + return th; +} + +TFE_Op* MinOp(TFE_Context* ctx, TFE_TensorHandle* input, + TFE_TensorHandle* axis) { + TF_Status* status = TF_NewStatus(); + + TFE_Op* op = TFE_NewOp(ctx, "Min", status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, input, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpAddInput(op, axis, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_OpSetAttrBool(op, "keep_dims", 1); + TFE_OpSetAttrType(op, "Tidx", TF_INT32); + TF_DeleteStatus(status); + TFE_OpSetAttrType(op, "T", TFE_TensorHandleDataType(input)); + + return op; +} + +// If there is a GPU device, returns true and sets 'gpu_device_name' +// accordingly. 
+bool GetGPUDeviceName(TFE_Context* ctx, string* gpu_device_name) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get()); + CHECK_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + const int num_devices = TF_DeviceListCount(devices); + for (int i = 0; i < num_devices; ++i) { + const string device_type(TF_DeviceListType(devices, i, status.get())); + CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); + const string device_name(TF_DeviceListName(devices, i, status.get())); + CHECK_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get()); + if (device_type == "GPU") { + *gpu_device_name = device_name; + LOG(INFO) << "Found GPU device " << device_name; + TF_DeleteDeviceList(devices); + return true; + } + } + TF_DeleteDeviceList(devices); + return false; +} + void BM_InitOp(int iters) { tensorflow::testing::StopTiming(); TF_Status* status = TF_NewStatus(); @@ -216,11 +274,10 @@ TEST(CAPI, TensorHandleCopyBetweenDevices) { EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); } -TEST(CAPI, TensorHandleSilentCopy) { +TEST(CAPI, TensorHandleCopyBetweenTwoGPUDevices) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); TFE_ContextOptions* opts = TFE_NewContextOptions(); - TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); TFE_Context* ctx = TFE_NewContext(opts, status.get()); TFE_DeleteContextOptions(opts); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); @@ -233,18 +290,111 @@ TEST(CAPI, TensorHandleSilentCopy) { ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); const int num_devices = TF_DeviceListCount(devices); + const char* kCPUDevice = "CPU:0"; + if (num_devices < 3) { + TF_DeleteDeviceList(devices); + TF_DeleteTensor(t); + TFE_DeleteTensorHandle(hcpu); + TFE_DeleteContext(ctx, status.get()); + return; + } + const string 
gpu_1_name(TF_DeviceListName(devices, 1, status.get())); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK); + const string gpu_2_name(TF_DeviceListName(devices, 2, status.get())); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK); + TFE_TensorHandle* hdevice = + TFE_TensorHandleCopyToDevice(hcpu, ctx, gpu_1_name.c_str(), status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK); + + TFE_TensorHandle* hdevice2 = TFE_TensorHandleCopyToDevice( + hdevice, ctx, gpu_2_name.c_str(), status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK); + TFE_DeleteTensorHandle(hdevice); + // Copy back to CPU + TFE_TensorHandle* hcopy = + TFE_TensorHandleCopyToDevice(hdevice2, ctx, kCPUDevice, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK); + TFE_DeleteTensorHandle(hdevice2); + + // Ensure that the contents are the same! + TF_Tensor* tcopy = TFE_TensorHandleResolve(hcopy, status.get()); + TFE_DeleteTensorHandle(hcopy); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK); + EXPECT_EQ(TF_TensorByteSize(t), TF_TensorByteSize(tcopy)); + EXPECT_EQ( + 0, memcmp(TF_TensorData(t), TF_TensorData(tcopy), TF_TensorByteSize(t))); + TF_DeleteTensor(tcopy); + + TF_DeleteDeviceList(devices); + TF_DeleteTensor(t); + TFE_DeleteTensorHandle(hcpu); + TFE_DeleteContext(ctx, status.get()); + EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); +} + +TEST(CAPI, TensorHandleSilentCopy) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); + TFE_Context* ctx = TFE_NewContext(opts, status.get()); + TFE_DeleteContextOptions(opts); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + TFE_TensorHandle* hcpu = TestMatrixTensorHandle(); + TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + // Disable the test if no GPU is 
present. - if (num_devices > 1) { - const int device_to_use = 1; - const string name(TF_DeviceListName(devices, device_to_use, status.get())); + string gpu_device_name; + if (GetGPUDeviceName(ctx, &gpu_device_name)) { + TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice( + hcpu, ctx, gpu_device_name.c_str(), status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); - TFE_TensorHandle* hgpu = - TFE_TensorHandleCopyToDevice(hcpu, ctx, name.c_str(), status.get()); + TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu); + TFE_OpSetDevice(matmul, gpu_device_name.c_str(), status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TFE_TensorHandle* retvals[1]; + int num_retvals = 1; + TFE_Execute(matmul, &retvals[0], &num_retvals, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TFE_DeleteOp(matmul); + TFE_DeleteTensorHandle(retvals[0]); + TFE_DeleteTensorHandle(hgpu); + } + + TF_DeleteTensor(t); + TFE_DeleteTensorHandle(hcpu); + TFE_DeleteContext(ctx, status.get()); + EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); +} + +TEST(CAPI, TensorHandleSilentCopyLocal) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, + TFE_DEVICE_PLACEMENT_EXPLICIT); + TFE_Context* ctx = TFE_NewContext(opts, status.get()); + TFE_ContextSetThreadLocalDevicePlacementPolicy(ctx, + TFE_DEVICE_PLACEMENT_SILENT); + TFE_DeleteContextOptions(opts); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + TFE_TensorHandle* hcpu = TestMatrixTensorHandle(); + TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + // Disable the test if no GPU is present. 
+ string gpu_device_name; + if (GetGPUDeviceName(ctx, &gpu_device_name)) { + TFE_TensorHandle* hgpu = TFE_TensorHandleCopyToDevice( + hcpu, ctx, gpu_device_name.c_str(), status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu); - TFE_OpSetDevice(matmul, name.c_str(), status.get()); + TFE_OpSetDevice(matmul, gpu_device_name.c_str(), status.get()); ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); TFE_TensorHandle* retvals[1]; int num_retvals = 1; @@ -255,20 +405,195 @@ TEST(CAPI, TensorHandleSilentCopy) { TFE_DeleteTensorHandle(hgpu); } - TF_DeleteDeviceList(devices); TF_DeleteTensor(t); TFE_DeleteTensorHandle(hcpu); TFE_DeleteContext(ctx, status.get()); EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); } -TEST(CAPI, Execute) { +TEST(CAPI, SetAndGetOpDevices) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* m = TestMatrixTensorHandle(); + TFE_Op* matmul = MatMulOp(ctx, m, m); + + // Disable the test if no GPU is present. 
+ string gpu_device_name; + if (GetGPUDeviceName(ctx, &gpu_device_name)) { + TFE_OpSetDevice(matmul, "GPU:0", status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + const char* device_name = TFE_OpGetDevice(matmul, status); + ASSERT_TRUE(strstr(device_name, "GPU:0") != nullptr); + + TFE_OpSetDevice(matmul, "CPU:0", status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + device_name = TFE_OpGetDevice(matmul, status); + ASSERT_TRUE(strstr(device_name, "CPU:0") != nullptr); + } + + TFE_DeleteOp(matmul); + TFE_DeleteTensorHandle(m); + TFE_DeleteContext(ctx, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TF_DeleteStatus(status); +} + +TEST(CAPI, Execute_MatMul_CPU) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* m = TestMatrixTensorHandle(); + TFE_Op* matmul = MatMulOp(ctx, m, m); + TFE_TensorHandle* retvals[2] = {nullptr}; + int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call. 
+ TFE_Execute(matmul, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteOp(matmul); + TFE_DeleteTensorHandle(m); + TFE_DeleteContext(ctx, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(1, num_retvals); + + TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); + TFE_DeleteTensorHandle(retvals[0]); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + float product[4] = {0}; + EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); + memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(7, product[0]); + EXPECT_EQ(10, product[1]); + EXPECT_EQ(15, product[2]); + EXPECT_EQ(22, product[3]); + TF_DeleteStatus(status); +} + +TEST(CAPI, Execute_Min_CPU) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* input = TestMatrixTensorHandle(); + TFE_TensorHandle* axis = TestAxisTensorHandle(); + TFE_Op* minOp = MinOp(ctx, input, axis); + TFE_TensorHandle* retvals[2] = {nullptr}; + int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call. 
+ TFE_Execute(minOp, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteOp(minOp); + TFE_DeleteTensorHandle(input); + TFE_DeleteTensorHandle(axis); + TFE_DeleteContext(ctx, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(1, num_retvals); + + TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); + TFE_DeleteTensorHandle(retvals[0]); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + float output[2] = {0}; + EXPECT_EQ(sizeof(output), TF_TensorByteSize(t)); + memcpy(&output[0], TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(1, output[0]); + EXPECT_EQ(3, output[1]); + TF_DeleteStatus(status); +} + +#ifdef TENSORFLOW_EAGER_USE_XLA +TEST(CAPI, Execute_MatMul_XLA_CPU) { TF_Status* status = TF_NewStatus(); TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteContextOptions(opts); + TFE_TensorHandle* m = TestMatrixTensorHandle(); + TFE_Op* matmul = MatMulOp(ctx, m, m); + + TFE_OpSetXLACompilation(matmul, true); + + TFE_TensorHandle* retvals[2] = {nullptr}; + int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call. + TFE_Execute(matmul, &retvals[0], &num_retvals, status); + // Running a primitive TF operator via XLA is not yet supported. 
+ ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + TFE_DeleteOp(matmul); + TFE_DeleteTensorHandle(m); + TFE_DeleteContext(ctx, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + + EXPECT_EQ(1, num_retvals); + + TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); + TFE_DeleteTensorHandle(retvals[0]); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + float product[4] = {0}; + EXPECT_EQ(sizeof(product), TF_TensorByteSize(t)); + memcpy(&product[0], TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(7, product[0]); + EXPECT_EQ(10, product[1]); + EXPECT_EQ(15, product[2]); + EXPECT_EQ(22, product[3]); + + TF_DeleteStatus(status); +} + +TEST(CAPI, Execute_Min_XLA_CPU) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + + TFE_TensorHandle* input = TestMatrixTensorHandle(); + TFE_TensorHandle* axis = TestAxisTensorHandle(); + TFE_Op* minOp = MinOp(ctx, input, axis); + + TFE_OpSetXLACompilation(minOp, true); + + TFE_TensorHandle* retvals[2] = {nullptr}; + int num_retvals = 2; // Should be reduced to 1 by the TFE_Execute call. 
+ TFE_Execute(minOp, &retvals[0], &num_retvals, status); + EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteOp(minOp); + TFE_DeleteTensorHandle(input); + TFE_DeleteTensorHandle(axis); + TFE_DeleteContext(ctx, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + ASSERT_EQ(1, num_retvals); + + TF_Tensor* t = TFE_TensorHandleResolve(retvals[0], status); + TFE_DeleteTensorHandle(retvals[0]); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + float output[2] = {0}; + EXPECT_EQ(sizeof(output), TF_TensorByteSize(t)); + memcpy(&output[0], TF_TensorData(t), TF_TensorByteSize(t)); + TF_DeleteTensor(t); + EXPECT_EQ(1, output[0]); + EXPECT_EQ(3, output[1]); + TF_DeleteStatus(status); +} +#endif // TENSORFLOW_EAGER_USE_XLA + +TEST(CAPI, ExecuteWithTracing) { + TF_Status* status = TF_NewStatus(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + TFE_ContextEnableRunMetadata(ctx); + CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + TFE_DeleteContextOptions(opts); + TFE_TensorHandle* m = TestMatrixTensorHandle(); TFE_Op* matmul = MatMulOp(ctx, m, m); TFE_TensorHandle* retvals[2] = {nullptr}; @@ -277,6 +602,13 @@ TEST(CAPI, Execute) { EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); TFE_DeleteOp(matmul); TFE_DeleteTensorHandle(m); + TF_Buffer* b = TF_NewBuffer(); + TFE_ContextExportRunMetadata(ctx, b, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); + tensorflow::RunMetadata rm; + EXPECT_TRUE( + rm.ParseFromString({reinterpret_cast(b->data), b->length})); + TF_DeleteBuffer(b); TFE_DeleteContext(ctx, status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); ASSERT_EQ(1, num_retvals); @@ -295,7 +627,7 @@ TEST(CAPI, Execute) { TF_DeleteStatus(status); } -TEST(CAPI, Function) { +TEST(CAPI, Function_ident_CPU) { // First create a simple identity function. 
TF_Graph* function_graph = TF_NewGraph(); TF_OperationDescription* arg_descr = @@ -356,6 +688,72 @@ TEST(CAPI, Function) { TF_DeleteStatus(status); } +#ifdef TENSORFLOW_EAGER_USE_XLA +TEST(CAPI, Function_ident_XLA_CPU) { + // First create a simple identity function. + TF_Graph* function_graph = TF_NewGraph(); + TF_OperationDescription* arg_descr = + TF_NewOperation(function_graph, "Placeholder", "arg"); + TF_SetAttrType(arg_descr, "dtype", TF_INT32); + TF_Status* status = TF_NewStatus(); + TF_Operation* arg = TF_FinishOperation(arg_descr, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_OperationDescription* id_descr = + TF_NewOperation(function_graph, "Identity", "id"); + TF_SetAttrType(id_descr, "T", TF_INT32); + TF_AddInput(id_descr, {arg, 0}); + TF_Operation* id = TF_FinishOperation(id_descr, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_Output input{arg, 0}; + TF_Output output{id, 0}; + TF_Function* fn = + TF_GraphToFunction(function_graph, "ident", 0, 1, &id, 1, &input, 1, + &output, nullptr, nullptr, "test", status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteGraph(function_graph); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_Context* ctx = TFE_NewContext(opts, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TFE_DeleteContextOptions(opts); + TFE_ContextAddFunction(ctx, fn, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteFunction(fn); + + TF_Tensor* t = + TF_AllocateTensor(TF_INT32, nullptr, 0, 1 * sizeof(tensorflow::int32)); + *reinterpret_cast(TF_TensorData(t)) = 42; + TFE_TensorHandle* h = TFE_NewTensorHandle(t, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteTensor(t); + + TFE_Op* op = TFE_NewOp(ctx, "ident", status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TFE_OpAddInput(op, h, status); + 
ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + + // Now run it via XLA. + TFE_OpSetXLACompilation(op, true); + + std::vector result; + result.push_back(nullptr); + int num_retvals = 1; + TFE_Execute(op, result.data(), &num_retvals, status); + TFE_DeleteOp(op); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + ASSERT_EQ(num_retvals, 1); + + TF_Tensor* r = TFE_TensorHandleResolve(result[0], status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + EXPECT_EQ(*reinterpret_cast(TF_TensorData(r)), 42); + TFE_DeleteTensorHandle(h); + TF_DeleteTensor(r); + TFE_DeleteTensorHandle(result[0]); + TFE_DeleteContext(ctx, status); + ASSERT_TRUE(TF_GetCode(status) == TF_OK) << TF_Message(status); + TF_DeleteStatus(status); +} +#endif // TENSORFLOW_EAGER_USE_XLA + string MatMulFunction() { tensorflow::FunctionDef def; CHECK(tensorflow::protobuf::TextFormat::ParseFromString( diff --git a/tensorflow/c/eager/runtime.cc b/tensorflow/c/eager/runtime.cc index 38066682a9fc5038c34a4ac3b20a67ceb08ab951..f77a937f1ffc2d146224cb3191a5ca127daefc22 100644 --- a/tensorflow/c/eager/runtime.cc +++ b/tensorflow/c/eager/runtime.cc @@ -86,10 +86,9 @@ Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out) { return Status::OK(); } -Status AttrTypeByName(const AttrTypeMap* m, const string& attr_name, +Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name, TF_AttrType* out, unsigned char* is_list) { - CHECK(m); - auto* t = gtl::FindOrNull(*m, attr_name); + auto* t = gtl::FindOrNull(m, attr_name); if (t == nullptr) { return errors::InvalidArgument("Attribute '", attr_name, "' does not exist for this operation"); @@ -173,14 +172,14 @@ void CombineUnordered(const tensorflow::Fprint128& a, b->high64 += a.high64; } -inline tensorflow::Fprint128 CacheKeyHelper(const StringPiece& s, +inline tensorflow::Fprint128 CacheKeyHelper(StringPiece s, const tensorflow::Fprint128& b) { // TODO(agarwal): avoid ToString(). 
tensorflow::Fprint128 a = tensorflow::Fingerprint128(s.ToString()); return FingerprintCat128(a, b); } -inline tensorflow::Fprint128 CacheKeyHelper(const StringPiece& s, uint64 b) { +inline tensorflow::Fprint128 CacheKeyHelper(StringPiece s, uint64 b) { return CacheKeyHelper(s, {b, b}); } @@ -262,7 +261,8 @@ Status KernelAndDevice::Init(const NodeDef& ndef, FunctionLibraryRuntime* flib, } Status KernelAndDevice::Run(std::vector* input_tensors, - std::vector* output_tensors) { + std::vector* output_tensors, + NodeExecStats* stats) { gtl::InlinedVector inputs; for (Tensor& t : *input_tensors) { inputs.push_back(TensorValue(&t)); @@ -284,6 +284,9 @@ Status KernelAndDevice::Run(std::vector* input_tensors, params.function_library = flib_; params.slice_reader_cache = &slice_reader_cache_; params.rendezvous = rendez_; + if (stats != nullptr) { + params.track_allocations = true; + } // TODO(apassos): use a thread pool. std::function)> runner = [](std::function f) { f(); }; @@ -297,6 +300,28 @@ Status KernelAndDevice::Run(std::vector* input_tensors, for (int i = 0; i < context.num_outputs(); ++i) { output_tensors->push_back(Tensor(*context.mutable_output(i))); } + if (stats != nullptr) { + for (const auto& allocator_pair : context.wrapped_allocators()) { + AllocatorMemoryUsed* memory = stats->add_memory(); + memory->set_allocator_name(allocator_pair.first->Name()); + auto sizes = allocator_pair.second->GetSizes(); + memory->set_total_bytes(std::get<0>(sizes)); + memory->set_peak_bytes(std::get<1>(sizes)); + memory->set_live_bytes(std::get<2>(sizes)); + + AllocatorStats allocator_stats; + allocator_pair.first->GetStats(&allocator_stats); + memory->set_allocator_bytes_in_use(allocator_stats.bytes_in_use); + allocator_pair.second->GetRecordsAndUnRef(); + } + auto* ms = stats->mutable_memory_stats(); + ms->set_temp_memory_size(context.temp_memory_allocated()); + for (const auto& alloc_id : context.persistent_alloc_ids()) { + 
ms->mutable_persistent_tensor_alloc_ids()->Add(alloc_id); + } + + ms->set_persistent_memory_size(context.persistent_memory_allocated()); + } return Status::OK(); } diff --git a/tensorflow/c/eager/runtime.h b/tensorflow/c/eager/runtime.h index fb97e94a94103d17164cb30f6c6e0ed3e07dc103..4d20b5244a46fcde2eed0a429dced2a77b86aedd 100644 --- a/tensorflow/c/eager/runtime.h +++ b/tensorflow/c/eager/runtime.h @@ -43,7 +43,7 @@ typedef std::unordered_map AttrTypeMap; Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out); // Looks for 'attr_name' in 'm' and sets 'out' and 'is_list'. -Status AttrTypeByName(const AttrTypeMap* m, const string& attr_name, +Status AttrTypeByName(const AttrTypeMap& m, const string& attr_name, TF_AttrType* out, unsigned char* is_list); // KernelAndDevice::Init needs a NodeDef only to pass the attribute map through. @@ -175,7 +175,8 @@ class KernelAndDevice { : device_(nullptr), flib_(nullptr), rendez_(rendez) {} // TODO(ashankar): Handle list-valued inputs. 
- Status Run(std::vector* inputs, std::vector* outputs); + Status Run(std::vector* inputs, std::vector* outputs, + NodeExecStats* stats); const OpKernel* kernel() const { return kernel_.get(); } diff --git a/tensorflow/c/eager/runtime_test.cc b/tensorflow/c/eager/runtime_test.cc index 3236c6be0ec5281e8099219968dd5f5c6c2048c3..643153058ce3d6f0c88dd23a0dec4c6eff060319 100644 --- a/tensorflow/c/eager/runtime_test.cc +++ b/tensorflow/c/eager/runtime_test.cc @@ -63,17 +63,17 @@ TEST(AttrTypeMap, Lookup) { TF_AttrType t; unsigned char is_list = 1; - s = AttrTypeByName(m, "ThisAttribyteCannotPossiblyExist", &t, &is_list); + s = AttrTypeByName(*m, "ThisAttribyteCannotPossiblyExist", &t, &is_list); EXPECT_FALSE(s.ok()); EXPECT_NE(is_list, 0); - s = AttrTypeByName(m, "transpose_a", &t, &is_list); + s = AttrTypeByName(*m, "transpose_a", &t, &is_list); ASSERT_TRUE(s.ok()) << s; EXPECT_EQ(TF_ATTR_BOOL, t); EXPECT_EQ(is_list, 0); s = AttrTypeMapForOp("Squeeze", &m); ASSERT_TRUE(s.ok()) << s; - s = AttrTypeByName(m, "squeeze_dims", &t, &is_list); + s = AttrTypeByName(*m, "squeeze_dims", &t, &is_list); ASSERT_TRUE(s.ok()) << s; EXPECT_EQ(TF_ATTR_INT, t); EXPECT_NE(is_list, 0); @@ -96,7 +96,7 @@ TEST(KernelAndDevice, Run) { KernelAndDevice::Init(ndef, env.function_library_runtime(), &kernel); ASSERT_TRUE(s.ok()) << s; std::vector outputs; - s = kernel.Run(&inputs, &outputs); + s = kernel.Run(&inputs, &outputs, nullptr); ASSERT_TRUE(s.ok()) << s; ASSERT_EQ(1, outputs.size()); const Tensor& out = outputs[0]; @@ -183,7 +183,7 @@ void BM_KernelAndDeviceRun(int iters) { KernelAndDevice::Init(ndef, env.function_library_runtime(), &kernel)); tensorflow::testing::StartTiming(); for (int i = 0; i < iters; ++i) { - TF_CHECK_OK(kernel.Run(&inputs, &outputs)); + TF_CHECK_OK(kernel.Run(&inputs, &outputs, nullptr)); } } BENCHMARK(BM_KernelAndDeviceRun); diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h index 
f52248e7d567b8edd911c6dba1786ceb5d5c721c..bdb0815d6b68444ec1c89b835d563db20ce4d8a1 100644 --- a/tensorflow/c/eager/tape.h +++ b/tensorflow/c/eager/tape.h @@ -18,12 +18,12 @@ limitations under the License. // Language-agnostic gradient tape. Does not perform backpropagation, just // maintains the data structures required to do so. -#include -#include #include #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { @@ -54,11 +54,11 @@ struct OpTapeEntry { // Map from tensor_id to internally-defined operation-id of the operation which // produced this tensor. A value of -1 means that the tensor was directly // watched and not the result of any operation in the tape. -using TensorTape = std::unordered_map; +using TensorTape = gtl::FlatMap; // Map from operation-id to tape entry. template -using OpTape = std::unordered_map>; +using OpTape = gtl::FlatMap>; // Operations the tape needs to perform on tensors to do backpropagation. Named // "vspace" because a subset of these are related to a vector space, such as @@ -159,9 +159,9 @@ class GradientTape { // Map from tensor id to number of remaining usages (i.e. how many entries in // the tape refer to it); to aid in tape garbage collection. - std::unordered_map tensor_usage_; + gtl::FlatMap tensor_usage_; - // If true, all activations are deleted in the first call to ComputeGradient. + // If false, all activations are deleted in the first call to ComputeGradient. // Else, only when this is destructed. bool persistent_; }; @@ -286,11 +286,11 @@ struct BackpropInitialState { // Map from tensor ID to how many references still exist for this tensor in // the tape. 
- std::unordered_map tensor_usage_counts; + gtl::FlatMap tensor_usage_counts; // Maps from op ID to how many output tensors of this op still need to have // their gradients computed. - std::unordered_map op_missing_tensor; + gtl::FlatMap op_missing_tensor; }; // If `persistent_tape` is true, op_tape is not changed and none of the @@ -301,8 +301,8 @@ struct BackpropInitialState { template BackpropInitialState PrepareBackprop( gtl::ArraySlice target, const TensorTape& tensor_tape, - OpTape* op_tape, - const std::unordered_set& sources_set, bool persistent_tape) { + OpTape* op_tape, const gtl::FlatSet& sources_set, + bool persistent_tape) { std::vector tensor_stack; tensor_stack.reserve(target.size()); for (auto t : target) { @@ -350,7 +350,7 @@ BackpropInitialState PrepareBackprop( // Call destructors for all unneeded gradient functions and // clear the op_tape. We can clear the tape because ownership of // backward functions that will be used for gradient computation - // has been transfered to `result`. + // has been transferred to `result`. 
for (const auto& op_pair : *op_tape) { op_pair.second.backward_function_deleter(); } @@ -362,7 +362,7 @@ BackpropInitialState PrepareBackprop( template std::vector InitialStack( const OpTape& op_tape, - const std::unordered_map& op_missing_tensor) { + const gtl::FlatMap& op_missing_tensor) { std::vector result; for (auto& op_entry : op_tape) { if (op_missing_tensor.find(op_entry.first) == op_missing_tensor.end()) { @@ -373,13 +373,13 @@ std::vector InitialStack( } template -Status InitialGradients( - const VSpace& vspace, - gtl::ArraySlice target_tensor_ids, - gtl::ArraySlice output_gradients, const TensorTape& tensor_tape, - const OpTape& op_tape, - const std::unordered_map& tensor_usage_counts, - std::unordered_map>* result) { +Status InitialGradients(const VSpace& vspace, + gtl::ArraySlice target_tensor_ids, + gtl::ArraySlice output_gradients, + const TensorTape& tensor_tape, + const OpTape& op_tape, + const gtl::FlatMap& tensor_usage_counts, + gtl::FlatMap>* result) { for (int i = 0; i < target_tensor_ids.size(); ++i) { const int64 id = target_tensor_ids[i]; if (tensor_usage_counts.find(id) != tensor_usage_counts.end()) { @@ -441,13 +441,13 @@ Status GradientTape::ComputeGradient( gtl::ArraySlice source_tensor_ids, gtl::ArraySlice output_gradients, std::vector* result) { - std::unordered_set sources_set(source_tensor_ids.begin(), - source_tensor_ids.end()); + gtl::FlatSet sources_set(source_tensor_ids.begin(), + source_tensor_ids.end()); BackpropInitialState state = PrepareBackprop( target_tensor_ids, tensor_tape_, &op_tape_, sources_set, persistent_); std::vector op_stack = InitialStack(state.op_tape, state.op_missing_tensor); - std::unordered_map> gradients; + gtl::FlatMap> gradients; Status s = InitialGradients(vspace, target_tensor_ids, output_gradients, tensor_tape_, state.op_tape, state.tensor_usage_counts, &gradients); @@ -463,7 +463,7 @@ Status GradientTape::ComputeGradient( cleanup(); return s; } - std::unordered_map gradients_size; + gtl::FlatMap 
gradients_size; // TODO(apassos) multiple threads could be dequeuing from op_stack at the same // time, for better CPU backprop performance. VLOG(1) << "Initial stack:"; @@ -472,11 +472,10 @@ Status GradientTape::ComputeGradient( VLOG(1) << " " << t; } } - std::unordered_map> - functions_accept_none_for_indices({ - {"SoftmaxCrossEntropyWithLogits", {1}}, - {"FusedBatchNorm", {1, 2, 3, 4}}, - }); + gtl::FlatMap> functions_accept_none_for_indices({ + {"SoftmaxCrossEntropyWithLogits", {1}}, + {"FusedBatchNorm", {1, 2, 3, 4}}, + }); while (!op_stack.empty()) { const int64 op = op_stack.back(); VLOG(1) << "Popped " << op; @@ -491,6 +490,7 @@ Status GradientTape::ComputeGradient( state.op_tape.erase(op_it); std::vector out_gradients; out_gradients.reserve(trace.output_tensor_info.size()); + bool any_gradient_nonzero = false; for (int i = 0; i < trace.output_tensor_info.size(); ++i) { const int64 id = trace.output_tensor_info[i].id; auto grad_it = gradients.find(id); @@ -506,6 +506,7 @@ Status GradientTape::ComputeGradient( trace.output_tensor_info[i].dtype)); } } else { + any_gradient_nonzero = true; out_gradients.push_back(vspace.AggregateGradients(grad_it->second)); if (sources_set.find(grad_it->first) == sources_set.end()) { gradients.erase(grad_it); @@ -513,14 +514,26 @@ Status GradientTape::ComputeGradient( } } std::vector in_gradients; - Status s = vspace.CallBackwardFunction(trace.backward_function, - out_gradients, &in_gradients); - if (!persistent_) { - vspace.ReleaseBackwardFunction(trace.backward_function); - } - if (!s.ok()) { - cleanup(); - return s; + if (any_gradient_nonzero) { + Status s = vspace.CallBackwardFunction(trace.backward_function, + out_gradients, &in_gradients); + if (!persistent_) { + vspace.ReleaseBackwardFunction(trace.backward_function); + } + if (!s.ok()) { + cleanup(); + return s; + } + } else { + in_gradients.resize(trace.input_tensor_id.size()); + if (!persistent_) { + vspace.ReleaseBackwardFunction(trace.backward_function); + } + for 
(Gradient* grad : out_gradients) { + if (grad != nullptr) { + vspace.DeleteGradient(grad); + } + } } VLOG(1) << "Got " << in_gradients.size() << " in_gradients for " << trace.input_tensor_id.size() << " sources"; diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index ba5a9268b4f671499590d66fb41060dd18e1ce47..6e37cdb5f4beea53d4a2ded0705ae482d0bc2d68 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -22,6 +22,7 @@ namespace tensorflow { void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input) { mutex_lock l(graph->mu); graph->graph.AddControlEdge(&input->node, &op->node); + RecordMutation(graph, *op, "adding control input"); } void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, @@ -36,11 +37,13 @@ void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, mutex_lock l(graph->mu); op->node.AddAttr(attr_name, attr_val); + RecordMutation(graph, *op, "setting attribute"); } void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) { mutex_lock l(graph->mu); op->node.set_requested_device(device); + RecordMutation(graph, *op, "setting device"); } void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, @@ -75,6 +78,25 @@ void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, } status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index, &dst.oper->node, dst.index); + + if (status->status.ok()) { + // This modification only updates the destination node for + // the purposes of running this graph in a session. Thus, we don't + // record the source node as being modified. 
+ RecordMutation(graph, *dst.oper, "updating input tensor"); + } +} + +void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op) { + mutex_lock l(graph->mu); + std::vector control_edges; + for (const Edge* edge : op->node.in_edges()) { + if (!edge->IsControlEdge()) continue; + control_edges.push_back(edge); + } + for (const Edge* edge : control_edges) { + graph->graph.RemoveControlEdge(edge); + } } } // namespace tensorflow diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h index f54585b0a1034ff108202272a11416e34985959e..aa9d9e06b28c54cb8869eb547d36ee3cb0d4e6b8 100644 --- a/tensorflow/c/python_api.h +++ b/tensorflow/c/python_api.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_C_PYTHON_API_H_ -#define THIRD_PARTY_TENSORFLOW_C_PYTHON_API_H_ +#ifndef TENSORFLOW_C_PYTHON_API_H_ +#define TENSORFLOW_C_PYTHON_API_H_ #include "tensorflow/c/c_api.h" @@ -35,6 +35,8 @@ void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device); void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, TF_Status* status); +void RemoveAllControlInputs(TF_Graph* graph, TF_Operation* op); + } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_C_PYTHON_API_H_ +#endif // TENSORFLOW_C_PYTHON_API_H_ diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index e354831d7d25af83c068a68a4f844056263a598c..9060c19e9d2cf965c2b9be07be07c42017da45a8 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -421,7 +421,7 @@ tf_cc_test( tf_gen_op_wrappers_cc( name = "cc_ops", - api_def_srcs = ["//tensorflow/core:base_api_def"], + api_def_srcs = ["//tensorflow/core/api_def:base_api_def"], op_lib_names = [ "array_ops", "audio_ops", @@ -433,6 +433,7 @@ tf_gen_op_wrappers_cc( "linalg_ops", "logging_ops", "lookup_ops", + "manip_ops", "math_ops", "nn_ops", "no_op", @@ -448,7 
+449,6 @@ tf_gen_op_wrappers_cc( "ops/const_op.h", "ops/standard_ops.h", ], - override_file = "ops/op_gen_overrides.pbtxt", pkg = "//tensorflow/core", ) @@ -527,14 +527,13 @@ cc_library_with_android_deps( ], copts = tf_copts(), data = [ - "//tensorflow/core:base_api_def", + "//tensorflow/core/api_def:base_api_def", ], deps = [ "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:op_gen_lib", - "//tensorflow/core:op_gen_overrides_proto_cc", "//tensorflow/core:proto_text", "//tensorflow/core:protos_all_cc", ], @@ -547,15 +546,11 @@ tf_cc_test( "framework/cc_op_gen.h", "framework/cc_op_gen_test.cc", ], - data = [ - "//tensorflow/cc:ops/op_gen_overrides.pbtxt", - ], deps = [ "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:op_gen_lib", - "//tensorflow/core:op_gen_overrides_proto_cc", "//tensorflow/core:proto_text", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", @@ -679,7 +674,6 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", - "//tensorflow/core:tensorflow", ], ) diff --git a/tensorflow/cc/client/client_session_test.cc b/tensorflow/cc/client/client_session_test.cc index dfbac9788e16e9c7c65abcd1ea213b51d5d5d060..ea5cf5a1f12be316cc6e0d0a02cd3caf4d177400 100644 --- a/tensorflow/cc/client/client_session_test.cc +++ b/tensorflow/cc/client/client_session_test.cc @@ -23,7 +23,13 @@ limitations under the License. 
#include "tensorflow/core/platform/test.h" namespace tensorflow { -using namespace ops; // NOLINT(build/namespaces) +namespace { + +using ops::Add; +using ops::Const; +using ops::Mul; +using ops::Placeholder; +using ops::Sub; TEST(ClientSessionTest, Basic) { Scope root = Scope::NewRootScope(); @@ -89,4 +95,5 @@ TEST(ClientSessionTest, MultiThreaded) { test::ExpectTensorEqual(outputs[0], test::AsTensor({-1, 2}, {2})); } -} // end namespace tensorflow +} // namespace +} // namespace tensorflow diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc index d889c518f9c38a9f070970b37a2ad4b1fc26671b..a40ad1ffc3b262840e6ca0043139b1b61e04510d 100644 --- a/tensorflow/cc/framework/cc_op_gen.cc +++ b/tensorflow/cc/framework/cc_op_gen.cc @@ -1057,16 +1057,9 @@ string MakeInternal(const string& fname) { } // namespace void WriteCCOps(const OpList& ops, const ApiDefMap& api_def_map, - const string& dot_h_fname, const string& dot_cc_fname, - const string& overrides_fnames) { + const string& dot_h_fname, const string& dot_cc_fname) { Env* env = Env::Default(); - // Load the override map. - OpGenOverrideMap override_map; - if (!overrides_fnames.empty()) { - TF_CHECK_OK(override_map.LoadFileList(env, overrides_fnames)); - } - // Write the initial boilerplate to the .h and .cc files. std::unique_ptr h = nullptr; std::unique_ptr cc = nullptr; diff --git a/tensorflow/cc/framework/cc_op_gen.h b/tensorflow/cc/framework/cc_op_gen.h index cea28990144b9371e8009ce13f912b44044f9aac..c7256a7dc384e652fa1bddfe3aa9893491c2b14c 100644 --- a/tensorflow/cc/framework/cc_op_gen.h +++ b/tensorflow/cc/framework/cc_op_gen.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_H_ -#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_H_ +#ifndef TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_H_ +#define TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_H_ #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/op_gen_lib.h" @@ -24,9 +24,8 @@ namespace tensorflow { /// Result is written to files dot_h and dot_cc. void WriteCCOps(const OpList& ops, const ApiDefMap& api_def_map, - const string& dot_h_fname, const string& dot_cc_fname, - const string& overrides_fnames); + const string& dot_h_fname, const string& dot_cc_fname); } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_H_ +#endif // TENSORFLOW_CC_FRAMEWORK_CC_OP_GEN_H_ diff --git a/tensorflow/cc/framework/cc_op_gen_main.cc b/tensorflow/cc/framework/cc_op_gen_main.cc index 326d5668b8803ee39ffe24900c92e1db87b93601..3157792e15a006555e4924eea3c72ea643e79c1c 100644 --- a/tensorflow/cc/framework/cc_op_gen_main.cc +++ b/tensorflow/cc/framework/cc_op_gen_main.cc @@ -28,7 +28,7 @@ namespace tensorflow { namespace { void PrintAllCCOps(const std::string& dot_h, const std::string& dot_cc, - const std::string& overrides_fnames, bool include_internal, + bool include_internal, const std::vector& api_def_dirs) { OpList ops; OpRegistry::Global()->Export(include_internal, &ops); @@ -49,7 +49,7 @@ void PrintAllCCOps(const std::string& dot_h, const std::string& dot_cc, api_def_map.UpdateDocs(); - WriteCCOps(ops, api_def_map, dot_h, dot_cc, overrides_fnames); + WriteCCOps(ops, api_def_map, dot_h, dot_cc); } } // namespace @@ -57,24 +57,21 @@ void PrintAllCCOps(const std::string& dot_h, const std::string& dot_cc, int main(int argc, char* argv[]) { tensorflow::port::InitMain(argv[0], &argc, &argv); - // TODO(annarev): Update this file to no longer take op_gen_overrides.pbtxt - // as an argument. 
- if (argc != 6) { + if (argc != 5) { for (int i = 1; i < argc; ++i) { fprintf(stderr, "Arg %d = %s\n", i, argv[i]); } fprintf(stderr, - "Usage: %s out.h out.cc overrides1.pbtxt,2.pbtxt include_internal " + "Usage: %s out.h out.cc include_internal " "api_def_dirs1,api_def_dir2 ...\n" " include_internal: 1 means include internal ops\n", argv[0]); exit(1); } - bool include_internal = tensorflow::StringPiece("1") == argv[4]; + bool include_internal = tensorflow::StringPiece("1") == argv[3]; std::vector api_def_dirs = tensorflow::str_util::Split( - argv[5], ",", tensorflow::str_util::SkipEmpty()); - tensorflow::PrintAllCCOps(argv[1], argv[2], argv[3], include_internal, - api_def_dirs); + argv[4], ",", tensorflow::str_util::SkipEmpty()); + tensorflow::PrintAllCCOps(argv[1], argv[2], include_internal, api_def_dirs); return 0; } diff --git a/tensorflow/cc/framework/cc_op_gen_test.cc b/tensorflow/cc/framework/cc_op_gen_test.cc index 0b7e720a5c7b343415eee1aa157b8de755a1e1a5..1e0f2d241bb350897a840dda90d6d0c009b1daad 100644 --- a/tensorflow/cc/framework/cc_op_gen_test.cc +++ b/tensorflow/cc/framework/cc_op_gen_test.cc @@ -24,10 +24,6 @@ limitations under the License. namespace tensorflow { namespace { -// TODO(annarev): Remove this op_gen_overrides.pbtxt reference. -// It is needed only because WriteCCOps takes it as an argument. 
-constexpr char kOverridesFnames[] = - "tensorflow/cc/ops/op_gen_overrides.pbtxt"; constexpr char kBaseOpDef[] = R"( op { name: "Foo" @@ -96,7 +92,7 @@ void GenerateCcOpFiles(Env* env, const OpList& ops, const auto internal_h_file_path = io::JoinPath(tmpdir, "test_internal.h"); const auto internal_cc_file_path = io::JoinPath(tmpdir, "test_internal.cc"); - WriteCCOps(ops, api_def_map, h_file_path, cc_file_path, kOverridesFnames); + WriteCCOps(ops, api_def_map, h_file_path, cc_file_path); TF_ASSERT_OK(ReadFileToString(env, h_file_path, h_file_text)); TF_ASSERT_OK( diff --git a/tensorflow/cc/framework/cc_ops_test.cc b/tensorflow/cc/framework/cc_ops_test.cc index 5da23036eaadbef270ba839357dc4613bf3bf490..ac05e3cf95b1ce4009ee1424713baf2d34902a94 100644 --- a/tensorflow/cc/framework/cc_ops_test.cc +++ b/tensorflow/cc/framework/cc_ops_test.cc @@ -22,8 +22,7 @@ limitations under the License. #include "tensorflow/core/lib/core/status_test_util.h" namespace tensorflow { -using namespace ops; // NOLINT(build/namespaces) - +namespace ops { namespace { Output Linear(const Scope& scope, Input x, Input w, Input b) { @@ -39,8 +38,6 @@ void GetColocationConstraints(const Output& tensor, constraints)); } -} // namespace - TEST(CCOpTest, Basic) { Scope root = Scope::NewRootScope(); auto c = Const(root, {{1, 1}}); @@ -249,4 +246,6 @@ TEST(CCOpTest, InvalidFinalize) { string::npos); } +} // namespace +} // namespace ops } // namespace tensorflow diff --git a/tensorflow/cc/framework/grad_op_registry.h b/tensorflow/cc/framework/grad_op_registry.h index 190b96f68506c6b5252d6c0184f1712310477a8a..0fc5abb20c884a66539682099497e2c8511a620f 100644 --- a/tensorflow/cc/framework/grad_op_registry.h +++ b/tensorflow/cc/framework/grad_op_registry.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_ -#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_ +#ifndef TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_ +#define TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_ #include @@ -72,4 +72,4 @@ class GradOpRegistry { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_ +#endif // TENSORFLOW_CC_FRAMEWORK_GRAD_OP_REGISTRY_H_ diff --git a/tensorflow/cc/framework/gradient_checker.h b/tensorflow/cc/framework/gradient_checker.h index d055c60d09c2f33fb1f61165f75b2d04618620b7..1aa215a9088335580667e0c23c7244e6e5047f1a 100644 --- a/tensorflow/cc/framework/gradient_checker.h +++ b/tensorflow/cc/framework/gradient_checker.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRADIENT_CHECKER_H_ -#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRADIENT_CHECKER_H_ +#ifndef TENSORFLOW_CC_FRAMEWORK_GRADIENT_CHECKER_H_ +#define TENSORFLOW_CC_FRAMEWORK_GRADIENT_CHECKER_H_ #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" @@ -60,4 +60,4 @@ Status ComputeGradientError(const Scope& scope, const Output& x, } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRADIENT_CHECKER_H_ +#endif // TENSORFLOW_CC_FRAMEWORK_GRADIENT_CHECKER_H_ diff --git a/tensorflow/cc/framework/gradient_checker_test.cc b/tensorflow/cc/framework/gradient_checker_test.cc index fdc457f40af875d7c0c243246755d0cb87c44a62..d4f0a7f5ab3716be41e22c02a21aca028f76fb88 100644 --- a/tensorflow/cc/framework/gradient_checker_test.cc +++ b/tensorflow/cc/framework/gradient_checker_test.cc @@ -24,10 +24,18 @@ limitations under the License. 
#include "tensorflow/core/util/equal_graph_def.h" namespace tensorflow { -using namespace ops; // NOLINT(build/namespaces) - namespace { +using ops::Complex; +using ops::Const; +using ops::MatMul; +using ops::Placeholder; +using ops::Real; +using ops::Split; +using ops::Square; +using ops::Stack; +using ops::Unstack; + TEST(GradientCheckerTest, BasicFloat) { Scope scope = Scope::NewRootScope(); TensorShape shape({2, 4, 3}); diff --git a/tensorflow/cc/framework/gradients.h b/tensorflow/cc/framework/gradients.h index 717f6f0636d3dd1a546ef7477b100bbfc86ba13d..0a377ad56d139a6ec26ea97b4e1e43495d0b3165 100644 --- a/tensorflow/cc/framework/gradients.h +++ b/tensorflow/cc/framework/gradients.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRADIENTS_H_ -#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRADIENTS_H_ +#ifndef TENSORFLOW_CC_FRAMEWORK_GRADIENTS_H_ +#define TENSORFLOW_CC_FRAMEWORK_GRADIENTS_H_ #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" @@ -49,4 +49,4 @@ Output NoGradient(); } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_GRADIENTS_H_ +#endif // TENSORFLOW_CC_FRAMEWORK_GRADIENTS_H_ diff --git a/tensorflow/cc/framework/gradients_test.cc b/tensorflow/cc/framework/gradients_test.cc index 07a062e704ed6ffc6389b5897309957a1bfcd1c2..26e3170ad8e4f4fba1c2dc014086acf24d949f72 100644 --- a/tensorflow/cc/framework/gradients_test.cc +++ b/tensorflow/cc/framework/gradients_test.cc @@ -26,10 +26,20 @@ limitations under the License. 
#include "tensorflow/core/util/equal_graph_def.h" namespace tensorflow { -using namespace ops; // NOLINT(build/namespaces) - namespace { +using ops::Assign; +using ops::Const; +using ops::Identity; +using ops::MatMul; +using ops::OnesLike; +using ops::Placeholder; +using ops::Square; +using ops::Stack; +using ops::StopGradient; +using ops::Unstack; +using ops::Variable; + // TODO(andydavis) Add more unit tests once more gradient functions are ported. class GradientsTest : public ::testing::Test { protected: diff --git a/tensorflow/cc/framework/ops.h b/tensorflow/cc/framework/ops.h index 8d4154220c4b18f9286094b10c1b1e96eb4e31e7..a085e1d6e2de5ad63d11eb8979ae64c26b91366f 100644 --- a/tensorflow/cc/framework/ops.h +++ b/tensorflow/cc/framework/ops.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_OPS_H_ -#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_OPS_H_ +#ifndef TENSORFLOW_CC_FRAMEWORK_OPS_H_ +#define TENSORFLOW_CC_FRAMEWORK_OPS_H_ #include @@ -296,4 +296,4 @@ class InputList { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_OPS_H_ +#endif // TENSORFLOW_CC_FRAMEWORK_OPS_H_ diff --git a/tensorflow/cc/framework/scope.h b/tensorflow/cc/framework/scope.h index 0225ac047291d6297af558fddad6e5315389ff40..30c32bd44b0f22d6b29dd3836d431807d0216818 100644 --- a/tensorflow/cc/framework/scope.h +++ b/tensorflow/cc/framework/scope.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_SCOPE_H_ -#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_SCOPE_H_ +#ifndef TENSORFLOW_CC_FRAMEWORK_SCOPE_H_ +#define TENSORFLOW_CC_FRAMEWORK_SCOPE_H_ #include #include @@ -242,4 +242,4 @@ struct CompositeOpScopes { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_SCOPE_H_ +#endif // TENSORFLOW_CC_FRAMEWORK_SCOPE_H_ diff --git a/tensorflow/cc/framework/scope_internal.h b/tensorflow/cc/framework/scope_internal.h index 968c366550ef6f46557cd9b5662d9d0719b31531..8efcfed20d0b86d86d8c20a3d8630c7c6bc909c3 100644 --- a/tensorflow/cc/framework/scope_internal.h +++ b/tensorflow/cc/framework/scope_internal.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_SCOPE_INTERNAL_H_ -#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_SCOPE_INTERNAL_H_ +#ifndef TENSORFLOW_CC_FRAMEWORK_SCOPE_INTERNAL_H_ +#define TENSORFLOW_CC_FRAMEWORK_SCOPE_INTERNAL_H_ #include "tensorflow/cc/framework/scope.h" @@ -117,4 +117,4 @@ class Scope::Impl { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_SCOPE_INTERNAL_H_ +#endif // TENSORFLOW_CC_FRAMEWORK_SCOPE_INTERNAL_H_ diff --git a/tensorflow/cc/framework/testutil.h b/tensorflow/cc/framework/testutil.h index a3e19870ec847bcd4f0e0bf0e71dda724024d5d2..7ad6fb4a676639f5d6d3da6a7c08de1894162f0c 100644 --- a/tensorflow/cc/framework/testutil.h +++ b/tensorflow/cc/framework/testutil.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_TESTUTIL_H_ -#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_TESTUTIL_H_ +#ifndef TENSORFLOW_CC_FRAMEWORK_TESTUTIL_H_ +#define TENSORFLOW_CC_FRAMEWORK_TESTUTIL_H_ #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" @@ -44,4 +44,4 @@ void GetTensor(const Scope& scope, const std::vector& assign_vars, } // namespace test } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_TESTUTIL_H_ +#endif // TENSORFLOW_CC_FRAMEWORK_TESTUTIL_H_ diff --git a/tensorflow/cc/framework/while_gradients.h b/tensorflow/cc/framework/while_gradients.h index 8f592accc93573cb8953a5ab25c04881ca0c2333..cb4e579c8548294ec45b0c3f42cb844e0b87c390 100644 --- a/tensorflow/cc/framework/while_gradients.h +++ b/tensorflow/cc/framework/while_gradients.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_ -#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_ +#ifndef TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_ +#define TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_ #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" @@ -37,4 +37,4 @@ Status AddWhileLoopGradient(WhileContext* while_ctx, const Scope& scope, } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_ +#endif // TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_ diff --git a/tensorflow/cc/gradients/array_grad_test.cc b/tensorflow/cc/gradients/array_grad_test.cc index 455d7330c10cf230462869475f25a1f1b9bf9e9e..4a215fcc9299cf8b8da04cbf151640631ed0d449 100644 --- a/tensorflow/cc/gradients/array_grad_test.cc +++ b/tensorflow/cc/gradients/array_grad_test.cc @@ -23,11 +23,11 @@ limitations under the License. 
#include "tensorflow/core/lib/core/status_test_util.h" namespace tensorflow { +namespace { + using namespace ops; // NOLINT(build/namespaces) using ops::internal::MirrorPadGrad; -namespace { - class ArrayGradTest : public ::testing::Test { protected: ArrayGradTest() : scope_(Scope::NewRootScope()) {} diff --git a/tensorflow/cc/gradients/data_flow_grad_test.cc b/tensorflow/cc/gradients/data_flow_grad_test.cc index 734dfd3af97b856a7c8c4894c4a6d1a3ade10992..0ba3c0e27b1e545a30925ea3ef9e2c54dc9d0ae9 100644 --- a/tensorflow/cc/gradients/data_flow_grad_test.cc +++ b/tensorflow/cc/gradients/data_flow_grad_test.cc @@ -23,10 +23,13 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" namespace tensorflow { -using namespace ops; // NOLINT(build/namespaces) - namespace { +using ops::Const; +using ops::DynamicPartition; +using ops::DynamicStitch; +using ops::Placeholder; + class DataFlowGradTest : public ::testing::Test { protected: DataFlowGradTest() : scope_(Scope::NewRootScope()) {} diff --git a/tensorflow/cc/gradients/grad_testutil.cc b/tensorflow/cc/gradients/grad_testutil.cc index 04b29d4e8b21eeee200d9e7390868d701eda3c22..304117d3719346202d3a8a18637f7c915d4a47f9 100644 --- a/tensorflow/cc/gradients/grad_testutil.cc +++ b/tensorflow/cc/gradients/grad_testutil.cc @@ -18,16 +18,14 @@ limitations under the License. 
#include "tensorflow/cc/framework/grad_op_registry.h" namespace tensorflow { -using namespace ops; // NOLINT(build/namespaces) - namespace test { Status CallGradFunction(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { - GradFunc grad_fn; - TF_RETURN_IF_ERROR( - GradOpRegistry::Global()->Lookup(op.node()->type_string(), &grad_fn)); + ops::GradFunc grad_fn; + TF_RETURN_IF_ERROR(ops::GradOpRegistry::Global()->Lookup( + op.node()->type_string(), &grad_fn)); TF_RETURN_IF_ERROR(grad_fn(scope, op, grad_inputs, grad_outputs)); TF_RETURN_IF_ERROR(scope.status()); return Status::OK(); diff --git a/tensorflow/cc/gradients/grad_testutil.h b/tensorflow/cc/gradients/grad_testutil.h index d31f412754ff59cc7782b14e285071a8d4218d08..70c81f1a73a394322c602a5c51e3c2a40aca2397 100644 --- a/tensorflow/cc/gradients/grad_testutil.h +++ b/tensorflow/cc/gradients/grad_testutil.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CC_GRADIENTS_GRAD_TESTUTIL_H_ -#define THIRD_PARTY_TENSORFLOW_CC_GRADIENTS_GRAD_TESTUTIL_H_ +#ifndef TENSORFLOW_CC_GRADIENTS_GRAD_TESTUTIL_H_ +#define TENSORFLOW_CC_GRADIENTS_GRAD_TESTUTIL_H_ #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" @@ -32,4 +32,4 @@ Status CallGradFunction(const Scope& scope, const Operation& op, } // namespace test } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CC_GRADIENTS_GRAD_TESTUTIL_H_ +#endif // TENSORFLOW_CC_GRADIENTS_GRAD_TESTUTIL_H_ diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc index d7446b9560fd7dc8377ea3710641906b274313a9..52c177212a8c88f1857defcc38de4a01ac47dab0 100644 --- a/tensorflow/cc/gradients/math_grad.cc +++ b/tensorflow/cc/gradients/math_grad.cc @@ -473,6 +473,41 @@ Status AddNGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("AddN", AddNGrad); +Status PowGrad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + auto x = ConjugateHelper(scope, op.input(0)); + auto y = ConjugateHelper(scope, op.input(1)); + auto z = ConjugateHelper(scope, op.output(0)); + auto grad = grad_inputs[0]; + // grad * y * pow(x, y - 1) + auto one = Cast(scope, Const(scope, 1.0), y.type()); + auto gx_1 = Mul(scope, + Mul(scope, grad, y), + Pow(scope, x, Sub(scope, y, one))); + // Avoid false singularity at x = 0 + DataType x_dtype = x.type(); + auto zero = Cast(scope, Const(scope, 0.0), x_dtype); + if (x_dtype == DT_COMPLEX64 || x_dtype == DT_COMPLEX128) { + // real(x) < 0 is fine for the complex case + auto log_x = Where3(scope, + NotEqual(scope, x, zero), + Log(scope, x), + ZerosLike(scope, x)); + auto gy_1 = Mul(scope, Mul(scope, grad, z), log_x); + return BinaryGradCommon(scope, op, grad_outputs, gx_1, gy_1); + } else { + // There's no sensible real value to return if 
x < 0, so return 0 + auto log_x = Where3(scope, + Greater(scope, x, zero), + Log(scope, x), + ZerosLike(scope, x)); + auto gy_1 = Mul(scope, Mul(scope, grad, z), log_x); + return BinaryGradCommon(scope, op, grad_outputs, gx_1, gy_1); + } +} +REGISTER_GRADIENT_OP("Pow", PowGrad); + // MaximumMinimumGradCommon adds shared ops to calculate gradients for // the binary Maximum and Minimum ops. Status MaximumMinimumGradCommon(const Scope& scope, const Operation& op, @@ -794,6 +829,183 @@ Status MinOrMaxGrad(const Scope& scope, const Operation& op, REGISTER_GRADIENT_OP("Min", MinOrMaxGrad); REGISTER_GRADIENT_OP("Max", MinOrMaxGrad); +Status ProdGrad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + auto zero = Const(scope, 0); + auto one = Const(scope, 1); + + // The gradient can be expressed by dividing the product by each entry of + // the input tensor. If our input is + // [ + // [3, 4], + // [5, 6], + // [7, 8] + // ] + // and we do a Prod operation on the axis 0, we will obtain [[105, 192]]. + // The gradient will have the same shape as the input + // [ + // [105/3, 192/4], + // dz * [105/5, 192/6], + // [105/7, 192/8] + // ] + // If the input contains a zero, the division is impossible but + // if we take the calculation that gave the first gradient + // (3 * 5 * 7)/3 is equal to 5 * 7 + // the trick will be to cumprod the elements on the axis without + // the element at the current position (3 in the example above). + // We will take as example: + // [ + // [ + // [3.0, 4.0], + // [5.0, 6.0], + // [7.0, 8.0] + // ], + // [ + // [3.0, 5.0], + // [0.0, 6.0], + // [5.0, 6.0] + // ] + // ] + + // [2, 3, 2] + auto input_shape = Shape(scope, op.input(0)); + + // The Reshape with -1 flattens the reduction indices. 
+ // [1] + auto reduction_indices = Reshape(scope, op.input(1), {-1}); + + // [2, 1, 2] + auto output_shape_kept_dims = + ReducedShapeHelper(scope, input_shape, reduction_indices); + + // [1, 3, 1] + auto tile_scaling = SafeDivHelper(scope, input_shape, output_shape_kept_dims); + + // [[[105, 192]], [[0, 180]]] + auto grad = Reshape(scope, grad_inputs[0], output_shape_kept_dims); + + // [[[105, 192], [105, 192], [105, 192]], [[0, 180], [0, 180], [0, 180]]] + auto grad_tiled = Tile(scope, grad, tile_scaling); + + Scope cpu_scope = scope.WithDevice("/cpu:0"); + + // [3] + auto rank = Rank(cpu_scope, op.input(0)); + + + // Normalize any negative indices in the reduction_axes to positive values. + auto reduction_indices_pos = Mod(cpu_scope, Add(cpu_scope, reduction_indices, rank), rank); + + // [1] + auto reduced = Cast(cpu_scope, reduction_indices_pos, DataType::DT_INT32); + + // [0, 1, 2] + auto idx = Range(cpu_scope, zero, rank, one); + + // [0, 2] + auto other = SetDiff1D(cpu_scope, idx, reduced).out; + + // [1, 0, 2] + auto perm = + Concat(cpu_scope, std::initializer_list{reduced, other}, 0); + + // 3 => [3] + auto reduced_num = Prod(cpu_scope, Gather(scope, input_shape, reduced), 0); + + // 2 * 2 => [2] + auto other_num = Prod(cpu_scope, Gather(scope, input_shape, other), 0); + + // [ + // [ + // [ 3., 4.], + // [ 3., 5.] + // ], + // [ + // [ 5., 6.], + // [ 0., 6.] + // ], + // [ + // [ 7., 8.], + // [ 5., 6.] + // ] + // ] + auto permuted = Transpose(scope, op.input(0), perm); + + // [3, 2, 2] + auto permuted_shape = Shape(scope, permuted); + + // [ + // [ 3., 4., 3., 5.], + // [ 5., 6., 0., 6.], + // [ 7., 8., 5., 6.] + // ] + auto reshaped = Reshape( + scope, permuted, + Stack(scope, std::initializer_list{reduced_num, other_num})); + + // [ + // [ 1., 1., 1., 1.], + // [ 3., 4., 3., 5.], + // [ 15., 24., 0., 30.] 
+ // ] + auto left = Cumprod(scope, reshaped, zero, Cumprod::Exclusive(true)); + + // [ + // [ 35., 48., 0., 36.], + // [ 7., 8., 5., 6.], + // [ 1., 1., 1., 1.] + // ] + auto right = + Cumprod(scope, reshaped, zero, Cumprod::Exclusive(true).Reverse(true)); + + // left * right = + // [ + // [ 35., 48., 0., 36.], + // [ 21., 32., 15., 30.], + // [ 15., 24., 0., 30.] + // ] + // y = + // [ + // [ + // [ 35., 48.], + // [ 0., 36.] + // ], + // [ + // [ 21., 32.], + // [ 15., 30.] + // ], + // [ + // [ 15., 24.], + // [ 0., 30.] + // ] + // ] + auto y = Reshape(scope, Mul(scope, left, right), permuted_shape); + + // out = + // [ + // [ + // [ 35., 48.], + // [ 21., 32.], + // [ 15., 24.] + // ], + // [ + // [ 0., 36.], + // [ 15., 30.], + // [ 0., 30.] + // ] + // ] + auto out = + Mul(scope, grad_tiled, Transpose(scope, y, InvertPermutation(scope, perm))); + + grad_outputs->push_back(Reshape(scope, out, input_shape)); + + // stop propagation along reduction_indices + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("Prod", ProdGrad); + // MatMulGrad helper function used to compute two MatMul operations // based on input matrix transposition combinations. Status MatMulGradHelper(const Scope& scope, const bool is_batch, diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc index 6313f41da5e5f9cf88be4c8a84408a8df77f0e25..1b4c7c2688083e74433da3dce2849b8c37443684 100644 --- a/tensorflow/cc/gradients/math_grad_test.cc +++ b/tensorflow/cc/gradients/math_grad_test.cc @@ -23,10 +23,31 @@ limitations under the License. 
#include "tensorflow/core/lib/random/random.h" namespace tensorflow { -using namespace ops; // NOLINT(build/namespaces) - namespace { +using ops::Abs; +using ops::Add; +using ops::AddN; +using ops::BatchMatMul; +using ops::Const; +using ops::Div; +using ops::Greater; +using ops::MatMul; +using ops::Max; +using ops::Maximum; +using ops::Mean; +using ops::Min; +using ops::Minimum; +using ops::Mul; +using ops::Placeholder; +using ops::Pow; +using ops::Prod; +using ops::RealDiv; +using ops::SquaredDifference; +using ops::Sub; +using ops::Sum; +using ops::Where3; + // TODO(andydavis) Test gradient function against numeric gradients output. // TODO(andydavis) As more gradients are added move common test functions // to a testutil library. @@ -83,6 +104,7 @@ class CWiseUnaryGradTest : public ::testing::Test { Output y; switch (op_type) { + using namespace ops; // NOLINT(build/namespaces) case ABS: y = Abs(scope_, x); break; @@ -843,6 +865,14 @@ TEST_F(NaryGradTest, SquaredDifference) { RunTest({x1, x2}, {x1_shape, x2_shape}, {y}, {x1_shape}); } +TEST_F(NaryGradTest, Pow) { + TensorShape shape({3}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); + // fix exponent to avoid overflow + auto y = Pow(scope_, x, Const(scope_, {1.f, 2.f, 3.f})); + RunTest({x}, {shape}, {y}, {shape}); +} + TEST_F(NaryGradTest, Maximum) { TensorShape shape({3, 2}); auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); @@ -865,5 +895,14 @@ TEST_F(NaryGradTest, Minimum) { RunTest(x, x_init_value, y, shape); } +TEST_F(NaryGradTest, Prod) { + TensorShape x_shape({2, 3, 2}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + auto y = Prod(scope_, x, {1}); + // y's shape is the result of reducing x along axes 1 + TensorShape y_shape({2, 1, 2}); + RunTest({x}, {x_shape}, {y}, {y_shape}); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/cc/gradients/nn_grad_test.cc b/tensorflow/cc/gradients/nn_grad_test.cc index 
f9063e836509669d81d03b1d2f0d32d1166b6eca..0cfe5f6e3c49f7c4a3cafbf48ff4e54a0ffd0d47 100644 --- a/tensorflow/cc/gradients/nn_grad_test.cc +++ b/tensorflow/cc/gradients/nn_grad_test.cc @@ -23,10 +23,22 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" namespace tensorflow { -using namespace ops; // NOLINT(build/namespaces) - namespace { +using ops::BiasAdd; +using ops::Conv2D; +using ops::Elu; +using ops::L2Loss; +using ops::LogSoftmax; +using ops::LRN; +using ops::MaxPool; +using ops::MaxPoolV2; +using ops::Placeholder; +using ops::Relu; +using ops::Relu6; +using ops::Selu; +using ops::Softmax; + class NNGradTest : public ::testing::Test { protected: NNGradTest() : scope_(Scope::NewRootScope()) {} diff --git a/tensorflow/cc/ops/const_op.h b/tensorflow/cc/ops/const_op.h index d11fda475b3db58bf83cdb94079c8fde8d1170f7..424a683665f31b5e25eeceeb40477fc31640ce90 100644 --- a/tensorflow/cc/ops/const_op.h +++ b/tensorflow/cc/ops/const_op.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CC_OPS_CONST_OP_H_ -#define THIRD_PARTY_TENSORFLOW_CC_OPS_CONST_OP_H_ +#ifndef TENSORFLOW_CC_OPS_CONST_OP_H_ +#define TENSORFLOW_CC_OPS_CONST_OP_H_ #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" @@ -82,4 +82,4 @@ std::vector AsNodeOutList(const Scope& scope, } // namespace ops } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CC_OPS_CONST_OP_H_ +#endif // TENSORFLOW_CC_OPS_CONST_OP_H_ diff --git a/tensorflow/cc/ops/op_gen_overrides.pbtxt b/tensorflow/cc/ops/op_gen_overrides.pbtxt deleted file mode 100644 index 4aac990e748b0a79cbc3b353b4121a582b0883b0..0000000000000000000000000000000000000000 --- a/tensorflow/cc/ops/op_gen_overrides.pbtxt +++ /dev/null @@ -1,238 +0,0 @@ -# array_ops -op { name: "BroadcastArgs" rename_to: "BroadcastDynamicShape" } -op { name: "BroadcastGradientArgs" hide: true } -op { name: "ConcatOffset" skip: true } # Maybe should just be hidden? 
-op { name: "Concat" skip: true } -op { name: "ConcatV2" rename_to: "Concat" } -op { name: "ExpandDims" input_rename: { from: "dim" to: "axis" } } -op { name: "ListDiff" rename_to: "SetDiff1D" } -op { name: "MirrorPadGrad" hide: true } -op { name: "Reverse" skip: true } -op { name: "ReverseV2" rename_to: "Reverse" } -op { name: "Split" input_rename: { from: "split_dim" to: "axis" } } -op { name: "SplitV" input_rename: { from: "split_dim" to: "axis" } } -op { name: "Squeeze" attr_rename: { from: "squeeze_dims" to: "axis" } } -op { name: "Pack" rename_to: "Stack" } -op { name: "Unpack" rename_to: "Unstack" } -op { name: "Select" rename_to: "Where3" input_rename: { from: "t" to: "x" } input_rename: { from: "e" to: "y" } } -op { name: "Where" input_rename: { from: "input" to: "condition" } } - - -# candidate_sampling_ops -op { name: "ThreadUnsafeUnigramCandidateSampler", skip: true } - -# control_flow_ops -# TODO(joshl): Hide Switch and Merge once we write and migrate users to -# a Cond() API. 
-#op { name: "Switch" hide: true } -#op { name: "Merge" hide: true } -op { name: "RefMerge" hide: true } -op { name: "Exit" hide: true } -op { name: "RefExit" hide: true } -op { name: "Enter" hide: true } -op { name: "RefEnter" hide: true } -op { name: "RefIdentity" hide: true } - -# ctc_ops - -# data_flow_ops -op { name: "FakeQueue" skip: true } -op { name: "FIFOQueue" skip: true} -op { name: "FIFOQueueV2" rename_to: "FIFOQueue" } -op { name: "PaddingFIFOQueue" skip: true } -op { name: "PaddingFIFOQueueV2" rename_to: "PaddingFIFOQueue" } -op { name: "PriorityQueue" skip: true } -op { name: "PriorityQueueV2" rename_to: "PriorityQueue" } -op { name: "QueueClose" skip: true } -op { name: "QueueCloseV2" rename_to: "QueueClose" } -op { name: "QueueDequeue" skip: true } -op { name: "QueueDequeueV2" rename_to: "QueueDequeue" } -op { name: "QueueDequeueMany" skip: true } -op { name: "QueueDequeueManyV2" rename_to: "QueueDequeueMany" } -op { name: "QueueDequeueUpTo" skip: true } -op { name: "QueueDequeueUpToV2" rename_to: "QueueDequeueUpTo" } -op { name: "QueueEnqueue" skip: true } -op { name: "QueueEnqueueV2" rename_to: "QueueEnqueue" } -op { name: "QueueEnqueueMany" skip: true } -op { name: "QueueEnqueueManyV2" rename_to: "QueueEnqueueMany" } -op { name: "QueueSize" skip: true } -op { name: "QueueSizeV2" rename_to: "QueueSize" } -op { name: "RandomShuffleQueue" skip: true } -op { name: "RandomShuffleQueueV2" rename_to: "RandomShuffleQueue" } -op { name: "ReaderNumRecordsProduced" skip: true } -op { name: "ReaderNumRecordsProducedV2" rename_to: "ReaderNumRecordsProduced" } -op { name: "ReaderNumWorkUnitsCompleted" skip: true } -op { name: "ReaderNumWorkUnitsCompletedV2" rename_to: "ReaderNumWorkUnitsCompleted" } -op { name: "ReaderRead" skip: true } -op { name: "ReaderReadUpTo" skip: true } -op { name: "ReaderReadUpToV2" rename_to: "ReaderReadUpTo" } -op { name: "ReaderReadV2" rename_to: "ReaderRead" } -op { name: "ReaderReset" skip: true } -op { name: "ReaderResetV2" 
rename_to: "ReaderReset" } -op { name: "ReaderRestoreState" skip: true } -op { name: "ReaderRestoreStateV2" rename_to: "ReaderRestoreState" } -op { name: "ReaderSerializeState" skip: true } -op { name: "ReaderSerializeStateV2" rename_to: "ReaderSerializeState" } -op { name: "FixedLengthRecordReader" skip: true } -op { name: "FixedLengthRecordReaderV2" rename_to: "FixedLengthRecordReader" } -op { name: "IdentityReader" skip: true } -op { name: "IdentityReaderV2" rename_to: "IdentityReader" } -op { name: "TFRecordReader" skip: true } -op { name: "TFRecordReaderV2" rename_to: "TFRecordReader" } -op { name: "TextLineReader" skip: true } -op { name: "TextLineReaderV2" rename_to: "TextLineReader" } - -# Skip hash table ops until we have better support in C++ (ops are currently -# only used in contrib) -op { name: "HashTable" skip: true } -op { name: "InitializeTable" skip: true } -op { name: "InitializeTableFromTextFile" skip: true } -op { name: "LookupTableFind" skip: true } -op { name: "LookupTableImport" skip: true } -op { name: "LookupTableInsert" skip: true } -op { name: "LookupTableSize" skip: true } -op { name: "MutableDenseHashTable" skip: true } -op { name: "MutableHashTable" skip: true } -op { name: "MutableHashTableOfTensors" skip: true } - -# Stack ops are internal to control flow gradients (not yet implemented in C++) -op { name: "Stack" skip: true } -op { name: "StackClose" skip: true } -op { name: "StackPop" skip: true } -op { name: "StackPush" skip: true } -op { name: "StackV2" skip: true } -op { name: "StackCloseV2" skip: true } -op { name: "StackPopV2" skip: true } -op { name: "StackPushV2" skip: true } - -op { name: "TensorArrayCloseV2" skip: true } -op { name: "TensorArrayCloseV3" rename_to: "TensorArrayClose" } -op { name: "TensorArrayConcatV2" skip: true } -op { name: "TensorArrayConcatV3" rename_to: "TensorArrayConcat" } -op { name: "TensorArrayGatherV2" skip: true } -op { name: "TensorArrayGatherV3" rename_to: "TensorArrayGather" } -op { name: 
"TensorArrayGradV2" skip: true } -op { name: "TensorArrayGradV3" rename_to: "TensorArrayGrad" } -op { name: "TensorArrayReadV2" skip: true } -op { name: "TensorArrayReadV3" rename_to: "TensorArrayRead" } -op { name: "TensorArrayScatterV2" skip: true } -op { name: "TensorArrayScatterV3" rename_to: "TensorArrayScatter" } -op { name: "TensorArraySizeV2" skip: true } -op { name: "TensorArraySizeV3" rename_to: "TensorArraySize" } -op { name: "TensorArraySplitV2" skip: true } -op { name: "TensorArraySplitV3" rename_to: "TensorArraySplit" } -op { name: "TensorArrayV2" skip: true } -op { name: "TensorArrayV3" rename_to: "TensorArray" } -op { name: "TensorArrayWriteV2" skip: true } -op { name: "TensorArrayWriteV3" rename_to: "TensorArrayWrite" } - -op { name: "WholeFileReader" skip: true } -op { name: "WholeFileReaderV2" rename_to: "WholeFileReader" } - -# functional_ops - -# image_ops -op { name: "AdjustContrastv2" rename_to: "AdjustContrast" } -op { name: "ResizeBilinearGrad" hide: true } -op { name: "ResizeBicubicGrad" hide: true } -op { name: "ResizeNearestNeighborGrad" hide: true } - -# io_ops - -# linalg_ops -op { name: "SelfAdjointEigV2" rename_to: "SelfAdjointEig" } - -# logging_ops -op { name: "AudioSummaryV2" rename_to: "AudioSummary" } - -# lookup_ops -op { name: "LookupTableFind" skip: true } -op { name: "LookupTableFindV2" rename_to: "LookupTableFind" } -op { name: "LookupTableInsert" skip: true } -op { name: "LookupTableInsertV2" rename_to: "LookupTableInsert" } -op { name: "LookupTableSize" skip: true } -op { name: "LookupTableSizeV2" rename_to: "LookupTableSize" } -op { name: "LookupTableExport" skip: true } -op { name: "LookupTableExportV2" rename_to: "LookupTableExport" } -op { name: "LookupTableImport" skip: true } -op { name: "LookupTableImportV2" rename_to: "LookupTableImport" } -op { name: "HashTable" skip: true } -op { name: "HashTableV2" rename_to: "HashTable" } -op { name: "MutableHashTable" skip: true } -op { name: "MutableHashTableV2" rename_to: 
"MutableHashTable" } -op { name: "MutableHashTableOfTensors" skip: true } -op { name: "MutableHashTableOfTensorsV2" rename_to: "MutableHashTableOfTensors" } -op { name: "MutableDenseHashTable" skip: true } -op { name: "MutableDenseHashTableV2" rename_to: "MutableDenseHashTable" } -op { name: "InitializeTable" skip: true } -op { name: "InitializeTableV2" rename_to: "InitializeTable" } -op { name: "InitializeTableFromTextFile" skip: true } -op { name: "InitializeTableFromTextFileV2" rename_to: "InitializeTableFromTextFile" } - -# math_ops -op { name: "All" alias: "ReduceAll" input_rename: { from: "reduction_indices" to: "axis" } } -op { name: "Any" alias: "ReduceAny" input_rename: { from: "reduction_indices" to: "axis" } } -op { name: "Max" alias: "ReduceMax" input_rename: { from: "reduction_indices" to: "axis" } } -op { name: "Mean" alias: "ReduceMean" input_rename: { from: "reduction_indices" to: "axis" } } -op { name: "Min" alias: "ReduceMin" input_rename: { from: "reduction_indices" to: "axis" } } -op { name: "Mul" rename_to: "Multiply" alias: "Mul" } -op { name: "Neg" rename_to: "Negate" alias: "Neg" } -op { name: "Prod" alias: "ReduceProd" input_rename: { from: "reduction_indices" to: "axis" } } -op { name: "Sub" rename_to: "Subtract" alias: "Sub" } -op { name: "Sum" alias: "ReduceSum" input_rename: { from: "reduction_indices" to: "axis" } } -op { name: "SigmoidGrad" hide: true } -op { name: "TanhGrad" hide: true } -op { name: "InvGrad" hide: true } -op { name: "ReciprocalGrad" hide: true } -op { name: "SqrtGrad" hide: true } -op { name: "RsqrtGrad" hide: true } - -# *Grad ops get hidden, only for use by the gradient code. 
-op { name: "SigmoidGrad" hide: true } -op { name: "TanhGrad" hide: true } -op { name: "InvGrad" hide: true } -op { name: "ReciprocalGrad" hide: true } -op { name: "SqrtGrad" hide: true } -op { name: "RsqrtGrad" hide: true } - -# nn_ops -op { name: "AvgPoolGrad" hide: true } -op { name: "LRNGrad" hide: true } -op { name: "MaxPoolGrad" hide: true } -op { name: "MaxPoolGradWithArgmax" hide: true } -op { name: "ReluGrad" hide: true } -op { name: "Relu6Grad" hide: true } -op { name: "EluGrad" hide: true } -op { name: "SeluGrad" hide: true } -op { name: "SoftplusGrad" hide: true } -op { name: "SoftsignGrad" hide: true } -op { name: "FractionalAvgPoolGrad" hide: true } -op { name: "FractionalMaxPoolGrad" hide: true } -op { name: "TopKV2" rename_to: "TopK" } -op { name: "BiasAddV1" skip: true } # Use BiasAdd instead - -# parsing_ops - -# random_ops - -op { name: "RandomStandardNormal" rename_to: "RandomNormal" } -# script_ops -# Calling Python functions from a C++ program isn't supported -op { name: "PyFunc" skip: true } -op { name: "PyFuncStateless" skip: true} - -# sdca_ops - -# state_ops - -op { name: "Variable" skip: true } -op { name: "VariableV2" rename_to: "Variable" } - -# sparse_ops - -# string_ops - -# user_ops - -# training_ops - diff --git a/tensorflow/cc/ops/standard_ops.h b/tensorflow/cc/ops/standard_ops.h index 0c021f0b3ac02c596e0511e650a3caa0002c25d1..98f53010ecf78f769c7d89d6aafc48fdb772f42e 100644 --- a/tensorflow/cc/ops/standard_ops.h +++ b/tensorflow/cc/ops/standard_ops.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CC_OPS_STANDARD_OPS_H_ -#define THIRD_PARTY_TENSORFLOW_CC_OPS_STANDARD_OPS_H_ +#ifndef TENSORFLOW_CC_OPS_STANDARD_OPS_H_ +#define TENSORFLOW_CC_OPS_STANDARD_OPS_H_ #include "tensorflow/cc/ops/array_ops.h" #include "tensorflow/cc/ops/candidate_sampling_ops.h" @@ -37,4 +37,4 @@ limitations under the License. #include "tensorflow/cc/ops/training_ops.h" #include "tensorflow/cc/ops/user_ops.h" -#endif // THIRD_PARTY_TENSORFLOW_CC_OPS_STANDARD_OPS_H_ +#endif // TENSORFLOW_CC_OPS_STANDARD_OPS_H_ diff --git a/tensorflow/cc/ops/while_loop.cc b/tensorflow/cc/ops/while_loop.cc index e0251efb2a424f86bd5a4885ef22d1928e04bd3e..d1c918d464bc9684b0db6dade2fb80cb2bd6691a 100644 --- a/tensorflow/cc/ops/while_loop.cc +++ b/tensorflow/cc/ops/while_loop.cc @@ -116,7 +116,7 @@ Status CreateCond(const Scope& scope, const CondGraphBuilderFn& cond, return Status::OK(); } -// Create the bdoy subgraph defined by `body`. `outputs` must be non-null and +// Create the body subgraph defined by `body`. `outputs` must be non-null and // empty. Status CreateBody(const Scope& scope, const BodyGraphBuilderFn& body, const std::vector& inputs, diff --git a/tensorflow/cc/ops/while_loop.h b/tensorflow/cc/ops/while_loop.h index a04476056a058ff0951a6347e8ffc05bc5ff5023..727237b5c7ad4d31dba1aaaf6d5600773d69223e 100644 --- a/tensorflow/cc/ops/while_loop.h +++ b/tensorflow/cc/ops/while_loop.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CC_OPS_WHILE_LOOP_H_ -#define THIRD_PARTY_TENSORFLOW_CC_OPS_WHILE_LOOP_H_ +#ifndef TENSORFLOW_CC_OPS_WHILE_LOOP_H_ +#define TENSORFLOW_CC_OPS_WHILE_LOOP_H_ #include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" @@ -71,4 +71,4 @@ Status BuildWhileLoop(const Scope& scope, const std::vector& inputs, } // namespace ops } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CC_OPS_WHILE_LOOP_H_ +#endif // TENSORFLOW_CC_OPS_WHILE_LOOP_H_ diff --git a/tensorflow/cc/profiler/BUILD b/tensorflow/cc/profiler/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..00799526fce572e7bb80199ccb8ce1cc89874031 --- /dev/null +++ b/tensorflow/cc/profiler/BUILD @@ -0,0 +1,36 @@ +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") + +tf_cuda_cc_test( + name = "profiler_test", + srcs = ["profiler_test.cc"], + deps = [ + ":profiler", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "profiler", + srcs = ["profiler.cc"], + hdrs = ["profiler.h"], + deps = [ + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/profiler:protos_all_cc", + "//tensorflow/core/profiler:tfprof_options", + "//tensorflow/core/profiler/internal:tfprof_stats", + ], +) diff --git a/tensorflow/cc/profiler/profiler.cc b/tensorflow/cc/profiler/profiler.cc new file mode 100644 index 0000000000000000000000000000000000000000..3e55bac73e6d32a1fa5ddcc1937744e2cf56657d --- /dev/null +++ b/tensorflow/cc/profiler/profiler.cc @@ -0,0 +1,57 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/cc/profiler/profiler.h" + +namespace tensorflow { +namespace tfprof { + +Profiler::Profiler(const GraphDef& graph) { + std::unique_ptr graph_ptr(new GraphDef()); + *graph_ptr = graph; + stats_.reset(new TFStats(std::move(graph_ptr), nullptr, nullptr, nullptr)); +} + +void Profiler::AddStep(int64 step, const RunMetadata& run_meta) { + std::unique_ptr run_meta_ptr(new RunMetadata()); + *run_meta_ptr = run_meta; + stats_->AddRunMeta(step, std::move(run_meta_ptr)); +} + +GraphNodeProto Profiler::ProfileGraph(const Options& options) { + stats_->BuildView(kCmds[1]); + return stats_->ShowGraphNode(kCmds[1], options); +} + +GraphNodeProto Profiler::ProfileNameScope(const Options& options) { + stats_->BuildView(kCmds[0]); + return stats_->ShowGraphNode(kCmds[0], options); +} + +MultiGraphNodeProto Profiler::ProfileOperations(const Options& options) { + stats_->BuildView(kCmds[3]); + return stats_->ShowMultiGraphNode(kCmds[3], options); +} + +Status Profiler::SerializeToString(string* content) { + if (!content) { + return Status(error::Code::INVALID_ARGUMENT, + "Cannot use null string pointer for SerializeToString."); + } + stats_->SerializeToString(content); + return Status::OK(); +} + +} // namespace tfprof +} // namespace tensorflow diff --git a/tensorflow/cc/profiler/profiler.h b/tensorflow/cc/profiler/profiler.h new file mode 100644 index 
0000000000000000000000000000000000000000..6077c45c5854fd5812ccb7c91522f93ed4e54883 --- /dev/null +++ b/tensorflow/cc/profiler/profiler.h @@ -0,0 +1,97 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CC_PROFILER_PROFILER_H_ +#define TENSORFLOW_CC_PROFILER_PROFILER_H_ + +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/profiler/internal/tfprof_stats.h" +#include "tensorflow/core/profiler/tfprof_options.h" +#include "tensorflow/core/profiler/tfprof_output.pb.h" + +namespace tensorflow { +namespace tfprof { + +/// @addtogroup core +/// @{ + +/// A `Profiler` object lets the caller profile the execution of a graph. +/// +/// Example: +/// // First build a graph and run tracing. +/// Scope root = Scope::NewRootScope(); +/// auto a = Placeholder(root, DT_INT32); +/// auto c = Add(root, a, {41}); +/// +/// ClientSession session(root); +/// std::vector outputs; +/// RunOptions run_options; +/// run_options.set_trace_level(RunOptions::FULL_TRACE); +/// RunMetadata run_meta; +/// Status s = session.Run(run_options, { {a, {1}} }, {c}, &outputs, +/// &run_meta); +/// if (!s.ok()) { ... } +/// +/// // Then create profiler to do profiling. 
+/// GraphDef graph; +/// root.ToGraphDef(&graph); +/// Profiler profiler(graph); +/// profiler.AddStep(0, run_meta); +/// Options opts = ... // TODO(xpan): Support option building API. +/// MultiGraphNodeProto r = profiler.ProfileOperations(opts); +/// +class Profiler { + public: + /// `graph` is the model's GraphDef. + Profiler(const GraphDef& graph); + + /// Adds tracing information `run_meta` to profiler. A `run_meta` is + /// generated by a TensorFlow session run call. `step` is the key + /// to the `run_meta`. When calling ProfileXXX methods, caller can specify + /// `step` in `options` to selectively profile the corresponding `run_meta`. + /// Multiple different `run_meta` can be keyed by the same `step` in order + /// to group them together. + void AddStep(int64 step, const RunMetadata& run_meta); + + /// Profiles the model by organizing nodes in graph structure. + /// Each node is an op and the nodes are connected by the op inputs/outputs. + GraphNodeProto ProfileGraph(const Options& options); + + /// Profiles the model by organizing nodes in name scope structure. + /// Each node is an op, and nodes are organized by the ops' name + /// scope, similar to a filesystem tree. + /// E.g. /foo is the root of operation /foo/matmul_1 and foo/conv_2. + GraphNodeProto ProfileNameScope(const Options& options); + + /// Profiles the model by organizing nodes by operation types. + /// Each node is an operation type (e.g. Conv2D or MatMul), containing all + /// ops belonging to that type in the model. + MultiGraphNodeProto ProfileOperations(const Options& options); + + /// Serialize the profile content (ProfileProto) into a binary string. + /// User can write the string to file for offline analysis by + /// tfprof command-line tools or graphical user interface. 
+ Status SerializeToString(string* content); + + private: + std::unique_ptr stats_; +}; +/// @} + +} // namespace tfprof +} // namespace tensorflow + +#endif // TENSORFLOW_CC_PROFILER_PROFILER_H_ diff --git a/tensorflow/cc/profiler/profiler_test.cc b/tensorflow/cc/profiler/profiler_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..280cd74827fc8ae80737eaf61286535fec959aa8 --- /dev/null +++ b/tensorflow/cc/profiler/profiler_test.cc @@ -0,0 +1,177 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/platform/test.h" + +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/cc/profiler/profiler.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/graph/default_device.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/public/session.h" + +namespace tensorflow { +namespace tfprof { + +class ProfilerTest : public ::testing::Test { + protected: + ProfilerTest() {} +}; + +GraphDef CreateGraphDef() { + Scope root = Scope::NewRootScope(); + + auto a = ops::Const(root, {{3, 2}, {-1, 0}}); + + auto x = ops::Const(root.WithOpName("x"), {{1.f}, {1.f}}); + + auto y = ops::MatMul(root.WithOpName("y"), a, x); + + auto y2 = ops::Square(root, y); + + auto y2_sum = ops::Sum(root, y2, 0); + + auto y_norm = ops::Sqrt(root, y2_sum); + + auto y_div = ops::Div(root.WithOpName("y_normalized"), y, y_norm); + + GraphDef def; + TF_CHECK_OK(root.ToGraphDef(&def)); + + return def; +} + +Options Default() { + Options opts(1000, /* max_depth */ + 0, /* min_bytes */ + 0, /* min_peak_bytes */ + 0, /* min_residual_bytes */ + 0, /* min_output_bytes */ + 0, /* min_micros */ + 0, /* min_accelerator_micros */ + 0, /* min_cpu_micros */ + 0, /* min_params */ + 0, /* min_float_ops */ + 0, /* min_occurrence */ + 0, /* step */ + "name", /* order_by */ + {".*"}, /* account_type_regexes */ + {".*"}, /* start_name_regexes */ + {}, /* trim_name_regexes */ + {".*"}, {}, /* hide_name_regexes */ + false, /* account_displayed_op_only */ + {"micros"}, /* select */ + {"none"}, /* output_type */ + {}); + return opts; +} + +template +const T* ExtractNode(const T& pb, const string& name) { + if (pb.name() == name) { + return &pb; + } + for (const T& c : pb.children()) { + const T* ret = ExtractNode(c, name); + if (ret) return ret; + } + return nullptr; +} + 
+TEST_F(ProfilerTest, Basics) { + SessionOptions options; + options.config.set_allow_soft_placement(true); + std::unique_ptr session(NewSession(options)); + GraphDef def = CreateGraphDef(); + if (options.target.empty()) { + graph::SetDefaultDevice("/gpu:0", &def); + } + + TF_CHECK_OK(session->Create(def)); + + Tensor x(DT_FLOAT, TensorShape({2, 1})); + auto x_flat = x.flat(); + x_flat.setRandom(); + Eigen::Tensor inv_norm = + x_flat.square().sum().sqrt().inverse(); + x_flat = x_flat * inv_norm(); + + std::vector outputs; + RunOptions run_options; + run_options.set_trace_level(RunOptions::FULL_TRACE); + RunMetadata run_metadata; + outputs.clear(); + + Profiler profiler(def); + for (int i = 0; i < 2; ++i) { + TF_CHECK_OK(session->Run(run_options, {{"x", x}}, {"y:0", "y_normalized:0"}, + {}, &outputs, &run_metadata)); + profiler.AddStep(i, run_metadata); + CHECK_EQ(size_t{2}, outputs.size()); + } + + std::vector resp; + TF_CHECK_OK(session->ListDevices(&resp)); + bool has_gpu = false; + for (const auto& dev : resp) { + if (dev.device_type() == "GPU") { + has_gpu = true; + } + } + + GraphNodeProto ret = profiler.ProfileNameScope(Default()); + const GraphNodeProto* matmul = ExtractNode(ret, "y"); + EXPECT_TRUE(matmul); + EXPECT_GT(matmul->exec_micros(), 0); + if (has_gpu) { + EXPECT_GT(matmul->accelerator_exec_micros(), 0); + } else { + EXPECT_EQ(matmul->accelerator_exec_micros(), 0); + } + const GraphNodeProto* square = ExtractNode(ret, "Square"); + EXPECT_TRUE(square); + EXPECT_GT(square->exec_micros(), 0); + if (has_gpu) { + EXPECT_GT(square->accelerator_exec_micros(), 0); + } else { + EXPECT_EQ(square->accelerator_exec_micros(), 0); + } + + Options opts2 = Default(); + opts2.output_type = "timeline"; + string timeline_file = io::JoinPath(testing::TmpDir(), "timeline"); + opts2.output_options["outfile"] = timeline_file; + GraphNodeProto ret2 = profiler.ProfileGraph(opts2); + string s; + TF_CHECK_OK(ReadFileToString(Env::Default(), timeline_file + "_0", &s)); + 
EXPECT_TRUE(s.find("Square") != s.npos); + + MultiGraphNodeProto ret3 = profiler.ProfileOperations(Default()); + const MultiGraphNodeProto* matmul2 = ExtractNode(ret3, "MatMul"); + EXPECT_TRUE(matmul2); + EXPECT_GT(matmul2->exec_micros(), 0); + if (has_gpu) { + EXPECT_GT(matmul2->accelerator_exec_micros(), 0); + } else { + EXPECT_EQ(matmul2->accelerator_exec_micros(), 0); + } + + TF_CHECK_OK(session->Close()); +} + +} // namespace tfprof +} // namespace tensorflow diff --git a/tensorflow/cc/saved_model/constants.h b/tensorflow/cc/saved_model/constants.h index c940df8a8761d97a859be3af30980ff79ca3577a..645a3f101d1ae7dda88ec4ca622c694dc5a7a919 100644 --- a/tensorflow/cc/saved_model/constants.h +++ b/tensorflow/cc/saved_model/constants.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_ -#define THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_ +#ifndef TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_ +#define TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_ namespace tensorflow { @@ -47,4 +47,4 @@ constexpr char kSavedModelVariablesFilename[] = "variables"; } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_ +#endif // TENSORFLOW_CC_SAVED_MODEL_CONSTANTS_H_ diff --git a/tensorflow/cc/saved_model/loader.cc b/tensorflow/cc/saved_model/loader.cc index f98abc8a817eca7bc129bb03a2ad31b97d957065..faa1e378d07ea94ad08ee084d18bf6a113f054af 100644 --- a/tensorflow/cc/saved_model/loader.cc +++ b/tensorflow/cc/saved_model/loader.cc @@ -62,6 +62,15 @@ Status ReadSavedModel(const string& export_dir, SavedModel* saved_model_proto) { export_dir); } +string GetTagsAsString(const std::unordered_set& tags) { + string tags_as_string = "{ "; + for (const string& tag : tags) { + tags_as_string = strings::StrCat(tags_as_string, tag, " "); + } + tags_as_string = 
strings::StrCat(tags_as_string, "}"); + return tags_as_string; +} + Status FindMetaGraphDefToLoad(const SavedModel& saved_model_proto, const std::unordered_set& tags, MetaGraphDef* meta_graph_def_to_load) { @@ -77,14 +86,9 @@ Status FindMetaGraphDefToLoad(const SavedModel& saved_model_proto, return Status::OK(); } } - string tags_as_string = "{ "; - for (const string& tag : tags) { - tags_as_string = strings::StrCat(tags_as_string, tag, " "); - } - tags_as_string = strings::StrCat(tags_as_string, "}"); return Status(error::Code::NOT_FOUND, "Could not find meta graph def matching supplied tags: " + - tags_as_string + + GetTagsAsString(tags) + ". To inspect available tag-sets in the SavedModel, please " "use the SavedModel CLI: `saved_model_cli`"); } @@ -92,7 +96,9 @@ Status FindMetaGraphDefToLoad(const SavedModel& saved_model_proto, Status LoadMetaGraphIntoSession(const MetaGraphDef& meta_graph_def, const SessionOptions& session_options, std::unique_ptr* session) { - session->reset(NewSession(session_options)); + Session* session_p = nullptr; + TF_RETURN_IF_ERROR(NewSession(session_options, &session_p)); + session->reset(session_p); return (*session)->Create(meta_graph_def.graph_def()); } @@ -233,7 +239,8 @@ Status LoadSavedModelInternal(const SessionOptions& session_options, return Status(error::Code::NOT_FOUND, "SavedModel not found in export directory: " + export_dir); } - LOG(INFO) << "Loading SavedModel from: " << export_dir; + LOG(INFO) << "Loading SavedModel with tags: " << GetTagsAsString(tags) + << "; from: " << export_dir; SavedModel saved_model_proto; TF_RETURN_IF_ERROR(ReadSavedModel(export_dir, &saved_model_proto)); @@ -281,7 +288,8 @@ Status LoadSavedModel(const SessionOptions& session_options, return end_microseconds - start_microseconds; }(); auto log_and_count = [&](const string& status_str) { - LOG(INFO) << "Loading SavedModel: " << status_str << ". 
Took " + LOG(INFO) << "SavedModel load for tags " << GetTagsAsString(tags) + << "; Status: " << status_str << ". Took " << load_latency_microsecs << " microseconds."; load_attempt_count->GetCell(export_dir, status_str)->IncrementBy(1); }; diff --git a/tensorflow/cc/saved_model/loader.h b/tensorflow/cc/saved_model/loader.h index 3d634dd51543bed8d3c074bdc56c251f97d56976..a8e098fa5440e7a8f72fd0b52737dcb06435b908 100644 --- a/tensorflow/cc/saved_model/loader.h +++ b/tensorflow/cc/saved_model/loader.h @@ -15,8 +15,8 @@ limitations under the License. /// SavedModel loading functions and SavedModelBundle struct. -#ifndef THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_LOADER_H_ -#define THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_LOADER_H_ +#ifndef TENSORFLOW_CC_SAVED_MODEL_LOADER_H_ +#define TENSORFLOW_CC_SAVED_MODEL_LOADER_H_ #include #include @@ -61,4 +61,4 @@ bool MaybeSavedModelDirectory(const string& export_dir); } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_LOADER_H_ +#endif // TENSORFLOW_CC_SAVED_MODEL_LOADER_H_ diff --git a/tensorflow/cc/saved_model/loader_test.cc b/tensorflow/cc/saved_model/loader_test.cc index 0ad6b33bba5fcceaca68e2f179cef2232c689a80..4c64d2cfe3c10e6c7ed82a2d72460a0b34283bb2 100644 --- a/tensorflow/cc/saved_model/loader_test.cc +++ b/tensorflow/cc/saved_model/loader_test.cc @@ -155,6 +155,24 @@ TEST_F(LoaderTest, NoTagMatchMultiple) { << st.error_message(); } +TEST_F(LoaderTest, SessionCreationFailure) { + SavedModelBundle bundle; + // Use invalid SessionOptions to cause session creation to fail. Default + // options work, so provide an invalid value for the target field. 
+ SessionOptions session_options; + constexpr char kInvalidTarget[] = "invalid target"; + session_options.target = kInvalidTarget; + RunOptions run_options; + + const string export_dir = + io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataSharded); + Status st = LoadSavedModel(session_options, run_options, export_dir, + {kSavedModelTagServe}, &bundle); + EXPECT_FALSE(st.ok()); + EXPECT_TRUE(StringPiece(st.error_message()).contains(kInvalidTarget)) + << st.error_message(); +} + TEST_F(LoaderTest, PbtxtFormat) { SavedModelBundle bundle; SessionOptions session_options; diff --git a/tensorflow/cc/saved_model/python/BUILD b/tensorflow/cc/saved_model/python/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..f5fbc75edcba9d5ae9ef7432de224df766bcab9e --- /dev/null +++ b/tensorflow/cc/saved_model/python/BUILD @@ -0,0 +1,30 @@ +# Description: +# CLIF wrappers for TensorFlow SavedModels. + +licenses(["notice"]) # Apache 2.0 + +package( + default_visibility = ["//visibility:public"], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +load("//tensorflow/core:platform/default/build_config.bzl", "tf_py_clif_cc") + +tf_py_clif_cc( + name = "loader", + srcs = ["loader.clif"], + deps = [ + "//tensorflow/cc/saved_model:loader", + ], +) diff --git a/tensorflow/cc/saved_model/python/loader.clif b/tensorflow/cc/saved_model/python/loader.clif new file mode 100644 index 0000000000000000000000000000000000000000..b102757d2eeb46ee713d8ed0d0c3d66b58740ee0 --- /dev/null +++ b/tensorflow/cc/saved_model/python/loader.clif @@ -0,0 +1,4 @@ +from "third_party/tensorflow/cc/saved_model/loader.h": + namespace `tensorflow`: + class SavedModelBundle: + def __init__(self) diff --git a/tensorflow/cc/saved_model/signature_constants.h b/tensorflow/cc/saved_model/signature_constants.h index 
b2d39bd55beb48a05489236395a208e41deb9c8f..7d8c07f5cf0a310c20193469cb6d18664f738d96 100644 --- a/tensorflow/cc/saved_model/signature_constants.h +++ b/tensorflow/cc/saved_model/signature_constants.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_SIGNATURE_CONSTANTS_H_ -#define THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_SIGNATURE_CONSTANTS_H_ +#ifndef TENSORFLOW_CC_SAVED_MODEL_SIGNATURE_CONSTANTS_H_ +#define TENSORFLOW_CC_SAVED_MODEL_SIGNATURE_CONSTANTS_H_ namespace tensorflow { @@ -66,4 +66,4 @@ static constexpr char kRegressOutputs[] = "outputs"; } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_SIGNATURE_CONSTANTS_H_ +#endif // TENSORFLOW_CC_SAVED_MODEL_SIGNATURE_CONSTANTS_H_ diff --git a/tensorflow/cc/saved_model/tag_constants.h b/tensorflow/cc/saved_model/tag_constants.h index b71cb263ca42dab7e830c1880ec4b311bc272f82..68a090e0c4cf79cfa87771a80447b8112fc37fb9 100644 --- a/tensorflow/cc/saved_model/tag_constants.h +++ b/tensorflow/cc/saved_model/tag_constants.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_TAG_CONSTANTS_H_ -#define THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_TAG_CONSTANTS_H_ +#ifndef TENSORFLOW_CC_SAVED_MODEL_TAG_CONSTANTS_H_ +#define TENSORFLOW_CC_SAVED_MODEL_TAG_CONSTANTS_H_ namespace tensorflow { @@ -32,4 +32,4 @@ constexpr char kSavedModelTagTrain[] = "train"; } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CC_SAVED_MODEL_TAG_CONSTANTS_H_ +#endif // TENSORFLOW_CC_SAVED_MODEL_TAG_CONSTANTS_H_ diff --git a/tensorflow/cc/tools/BUILD b/tensorflow/cc/tools/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..97f66e79b8ad9f383b22f56e9385fc6d2080e1f8 --- /dev/null +++ b/tensorflow/cc/tools/BUILD @@ -0,0 +1,57 @@ +# Description: +# TensorFlow cc tools. + +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", +) + +cc_library( + name = "freeze_saved_model", + srcs = ["freeze_saved_model.cc"], + hdrs = ["freeze_saved_model.h"], + deps = [ + "//tensorflow/cc/saved_model:loader", + "//tensorflow/core:core_cpu", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + +tf_cc_test( + name = "freeze_saved_model_test", + srcs = ["freeze_saved_model_test.cc"], + deps = [ + ":freeze_saved_model", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +# ----------------------------------------------------------------------------- +# Google-internal targets. 
+ +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/cc/tools/freeze_saved_model.cc b/tensorflow/cc/tools/freeze_saved_model.cc new file mode 100644 index 0000000000000000000000000000000000000000..ddf372cdef21e1b3892c9a03714478d5a5785517 --- /dev/null +++ b/tensorflow/cc/tools/freeze_saved_model.cc @@ -0,0 +1,194 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/tools/freeze_saved_model.h" + +#include + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" + +namespace tensorflow { + +namespace { + +// Gets tensor names from tensor_info and inserts them into the set of tensor +// names. +void GetTensorNamesFromTensorInfo(const TensorInfo& tensor_info, + std::unordered_set* tensor_names) { + if (tensor_info.has_coo_sparse()) { + // If the tensor is sparse we have to add all three tensors of the sparse + // representations. 
+ const TensorInfo_CooSparse& coo_sparse = tensor_info.coo_sparse(); + tensor_names->insert(coo_sparse.values_tensor_name()); + tensor_names->insert(coo_sparse.indices_tensor_name()); + tensor_names->insert(coo_sparse.dense_shape_tensor_name()); + } else { + tensor_names->insert(tensor_info.name()); + } +} + +// Gets the union of all inputs and outputs of all SignatureDefs in the bundle +void GetSignatureDefsInputsAndOutputs( + const SavedModelBundle& saved_model_bundle, + std::unordered_set* inputs, std::unordered_set* outputs) { + for (auto& sigdef_elem : saved_model_bundle.meta_graph_def.signature_def()) { + const SignatureDef& signature_def = sigdef_elem.second; + for (auto& input_elem : signature_def.inputs()) { + GetTensorNamesFromTensorInfo(input_elem.second, inputs); + } + for (auto& output_elem : signature_def.outputs()) { + GetTensorNamesFromTensorInfo(output_elem.second, outputs); + } + } +} + +// Gets a map from string node name to NodeDef. +void GetNodeNameToNodeDefMap( + GraphDef* graph_def, + std::unordered_map* name_to_node_map) { + for (size_t i = 0; i < graph_def->node_size(); i++) { + NodeDef* node = graph_def->mutable_node(i); + (*name_to_node_map)[node->name()] = node; + } +} + +// Gets the set of node names needed by `outputs` and the corresponding set of +// variable nodes to convert. +void GetReachableNodesAndVariables( + GraphDef* graph_def, const std::unordered_set& outputs, + std::unordered_set* reachable_node_names, + std::unordered_set* variable_node_names) { + // TODO(suharshs): Add support for ResourceVariables. + static const std::unordered_set* kVariableTypes = + new std::unordered_set({"Variable", "VariableV2"}); + // name_to_node_map is needed to get the inputs from the NodeDef corresponding + // to a string node name. These inputs are used when doing our backwards + // traversal.
+ std::unordered_map name_to_node_map; + GetNodeNameToNodeDefMap(graph_def, &name_to_node_map); + std::queue nodes_to_visit; + for (const string& tensor_name : outputs) { + // We need to strip off the tensor part to get the node name. + std::vector tensor_name_parts = str_util::Split(tensor_name, ':'); + nodes_to_visit.push(tensor_name_parts[0]); + } + // We do a traversal backwards from the outputs specified in the MetaGraphDef. + while (!nodes_to_visit.empty()) { + const string node_name = nodes_to_visit.front(); + nodes_to_visit.pop(); + if (reachable_node_names->find(node_name) != reachable_node_names->end()) { + continue; + } + reachable_node_names->insert(node_name); + NodeDef* node = name_to_node_map[node_name]; + if (kVariableTypes->find(node->op()) != kVariableTypes->end()) { + variable_node_names->insert(node->name()); + } + for (const string& input : node->input()) { + nodes_to_visit.push(input); + } + } +} + +// Gets a map from variable name to variable value. +Status GetVariableNameToTensorMap( + Session* session, std::unordered_set variable_names_set, + std::unordered_map* variable_name_to_value_map) { + if (variable_names_set.empty()) { + return Status::OK(); + } + std::vector variable_names; + std::vector tensor_names; + for (const string& node_name : variable_names_set) { + variable_names.push_back(node_name); + // We need to run tensors, so append ":0". + tensor_names.push_back(node_name + ":0"); + } + std::vector outputs; + TF_RETURN_IF_ERROR( + session->Run(/* inputs */ {}, tensor_names, /* targets */ {}, &outputs)); + for (size_t i = 0; i < variable_names.size(); i++) { + (*variable_name_to_value_map)[variable_names[i]] = outputs[i]; + } + return Status::OK(); +} + +// Converts a Variable NodeDef into a Constant NodeDef. 
+void ConvertVariableToConstant(const NodeDef& variable_node, + const Tensor& variable_value, + NodeDef* const_node) { + const_node->set_name(variable_node.name()); + const_node->set_op("Const"); + (*const_node->mutable_attr())["dtype"] = variable_node.attr().at("dtype"); + variable_value.AsProtoTensorContent( + (*const_node->mutable_attr())["value"].mutable_tensor()); +} + +// Freezes the subgraph of all nodes needed by `outputs`. +Status FreezeGraphDef(const SavedModelBundle& saved_model_bundle, + const std::unordered_set& outputs, + GraphDef* frozen_graph_def) { + GraphDef graph_def = saved_model_bundle.meta_graph_def.graph_def(); + // Copy versions and library as-is from original graph. + *frozen_graph_def->mutable_versions() = graph_def.versions(); + *frozen_graph_def->mutable_library() = graph_def.library(); + // If the graph is empty there is nothing left to do. + if (graph_def.node_size() == 0) { + return Status::OK(); + } + std::unordered_set reachable_node_names; + std::unordered_set variable_node_names; + GetReachableNodesAndVariables(&graph_def, outputs, &reachable_node_names, + &variable_node_names); + std::unordered_map variable_to_value_map; + TF_RETURN_IF_ERROR( + GetVariableNameToTensorMap(saved_model_bundle.session.get(), + variable_node_names, &variable_to_value_map)); + // We copy the nodes in the same order they were in the original graph_def. + for (const NodeDef& node : graph_def.node()) { + if (reachable_node_names.find(node.name()) == reachable_node_names.end()) { + continue; + } + if (variable_node_names.find(node.name()) != variable_node_names.end()) { + ConvertVariableToConstant(node, variable_to_value_map[node.name()], + frozen_graph_def->add_node()); + } else { + // If the node isn't a variable, just copy the node as-is. 
+ *frozen_graph_def->add_node() = node; + } + } + return Status::OK(); +} + +} // namespace + +Status FreezeSavedModel(const SavedModelBundle& saved_model_bundle, + GraphDef* frozen_graph_def, + std::unordered_set* inputs, + std::unordered_set* outputs) { + GetSignatureDefsInputsAndOutputs(saved_model_bundle, inputs, outputs); + TF_RETURN_IF_ERROR( + FreezeGraphDef(saved_model_bundle, *outputs, frozen_graph_def)); + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/cc/tools/freeze_saved_model.h b/tensorflow/cc/tools/freeze_saved_model.h new file mode 100644 index 0000000000000000000000000000000000000000..b10f29805a4515f9d49426cc41e0d375cd32b072 --- /dev/null +++ b/tensorflow/cc/tools/freeze_saved_model.h @@ -0,0 +1,43 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CC_TOOLS_FREEZE_SAVED_MODEL_H_ +#define TENSORFLOW_CC_TOOLS_FREEZE_SAVED_MODEL_H_ + +#include + +#include "tensorflow/cc/saved_model/loader.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Returns a frozen GraphDef, input tensors, and output tensors from the loaded +// SavedModelBundle. +// `inputs` and `outputs` consist of the union of all inputs and outputs in the +// SignatureDefs in the SavedModelBundle. 
+// FreezeSavedModel sets `frozen_graph_def` to a GraphDef of all nodes needed by +// `outputs`. All variables in the supplied SavedModelBundle are converted to +// constants, set to the value of the variables, by running the restored Session +// in the SavedModelBundle. +// WARNING: Only the variable checkpoints will be reflected in the frozen +// graph_def. All saved_model assets will be ignored. +Status FreezeSavedModel(const SavedModelBundle& saved_model_bundle, + GraphDef* frozen_graph_def, + std::unordered_set* inputs, + std::unordered_set* outputs); + +} // namespace tensorflow + +#endif // TENSORFLOW_CC_TOOLS_FREEZE_SAVED_MODEL_H_ diff --git a/tensorflow/cc/tools/freeze_saved_model_test.cc b/tensorflow/cc/tools/freeze_saved_model_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..52a81a50284aec36bba4e56a0232c886cb0cb6cf --- /dev/null +++ b/tensorflow/cc/tools/freeze_saved_model_test.cc @@ -0,0 +1,307 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/cc/tools/freeze_saved_model.h" + +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { +namespace { + +class FreezeTest : public ::testing::Test { + protected: + void GraphDefEqual(const GraphDef& actual, const GraphDef& expected) { + EXPECT_EQ(actual.ShortDebugString(), expected.ShortDebugString()); + } + + // Builds a SignatureDef with the provided `inputs` and `outputs`. + SignatureDef BuildSignatureDef(const std::unordered_set& inputs, + const std::unordered_set& outputs) { + SignatureDef signature_def; + for (const string& input : inputs) { + (*signature_def.mutable_inputs())[input].set_name(input); + } + for (const string& output : outputs) { + (*signature_def.mutable_outputs())[output].set_name(output); + } + return signature_def; + } + + // Adds `signature_def` to `saved_model_bundle` under `key`. + void AddSignatureDefToSavedModelBundle(const SignatureDef& signature_def, + const string& key, + SavedModelBundle* saved_model_bundle) { + MetaGraphDef* meta_graph_def = &saved_model_bundle->meta_graph_def; + (*meta_graph_def->mutable_signature_def())[key] = signature_def; + } + + // Adds an initialized session to `saved_model_bundle` using `graph_def` and + // initializing with `init_node`. 
+ Status InitializeSavedModelBundleSession( + const GraphDef& graph_def, const string& init_node, + SavedModelBundle* saved_model_bundle) { + SessionOptions session_options; + saved_model_bundle->session.reset(NewSession(session_options)); + TF_RETURN_IF_ERROR(saved_model_bundle->session->Create(graph_def)); + if (!init_node.empty()) { + std::vector outputs; + return saved_model_bundle->session->Run( + /* inputs */ {}, /* output_tensors */ {}, {init_node}, &outputs); + } + return Status::OK(); + } + + // Adds `graph_def` to `saved_model_bundle` and initializes a session with + // `init_node`. + Status AddGraphDefToSavedModelBundle(const GraphDef& graph_def, + const string& init_node, + SavedModelBundle* saved_model_bundle) { + MetaGraphDef* meta_graph_def = &saved_model_bundle->meta_graph_def; + *meta_graph_def->mutable_graph_def() = graph_def; + return InitializeSavedModelBundleSession(graph_def, init_node, + saved_model_bundle); + } + + // Adds `graph_def` and `outputs` as the GraphDef and SignatureDef in + // `saved_model_bundle` and initializes a session with `init_node`. + Status AddGraphDefWithOutputsToSavedModelBundle( + const GraphDef& graph_def, const std::unordered_set& outputs, + const string& init_node, SavedModelBundle* saved_model_bundle) { + SignatureDef signature_def = + BuildSignatureDef(std::unordered_set(), outputs); + AddSignatureDefToSavedModelBundle(signature_def, "signature_def", + saved_model_bundle); + return AddGraphDefToSavedModelBundle(graph_def, init_node, + saved_model_bundle); + } + + // Runs and compares the outputs of `tensor_name` on both the + // `unfrozen_session` and the `frozen_graph_def. 
+ void RunAndCompareFrozenAndUnfrozenGraphs(Session* unfrozen_session, + const GraphDef& frozen_graph_def, + const string& tensor_name) { + std::vector unfrozen_outputs; + TF_ASSERT_OK(unfrozen_session->Run(/* inputs */ {}, {tensor_name}, + /* targets */ {}, &unfrozen_outputs)); + + SessionOptions session_options; + std::unique_ptr frozen_session(NewSession(session_options)); + TF_ASSERT_OK(frozen_session->Create(frozen_graph_def)); + std::vector frozen_outputs; + TF_ASSERT_OK(frozen_session->Run(/* inputs */ {}, {tensor_name}, + /* targets */ {}, &frozen_outputs)); + + test::ExpectTensorEqual(unfrozen_outputs[0], frozen_outputs[0]); + } +}; + +TEST_F(FreezeTest, InputsAndOutputsSingleSignatureDef) { + // Test that inputs and outputs get correctly populated for a single + // SignatureDef. + SavedModelBundle saved_model_bundle; + std::unordered_set expected_inputs = {"input0:0", "input1:0"}; + std::unordered_set expected_outputs = {"output0:0", "output1:0"}; + SignatureDef signature_def = + BuildSignatureDef(expected_inputs, expected_outputs); + AddSignatureDefToSavedModelBundle(signature_def, "signature_def", + &saved_model_bundle); + GraphDef frozen_graph_def; + std::unordered_set inputs; + std::unordered_set outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, + &outputs)); + EXPECT_EQ(expected_inputs, inputs); + EXPECT_EQ(expected_outputs, outputs); +} + +TEST_F(FreezeTest, InputsAndOutputsMultipleSignatureDefs) { + // Test that inputs and outputs get correctly merged and populated when + // multiple SignatureDefs are provided. 
+ SavedModelBundle saved_model_bundle; + SignatureDef signature_def_0 = BuildSignatureDef({"input0:0"}, {"output0:0"}); + SignatureDef signature_def_1 = BuildSignatureDef({"input1:0"}, {"output1:0"}); + AddSignatureDefToSavedModelBundle(signature_def_0, "signature_def_0", + &saved_model_bundle); + AddSignatureDefToSavedModelBundle(signature_def_1, "signature_def_1", + &saved_model_bundle); + GraphDef frozen_graph_def; + std::unordered_set inputs; + std::unordered_set outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, + &outputs)); + std::unordered_set expected_inputs = {"input0:0", "input1:0"}; + std::unordered_set expected_outputs = {"output0:0", "output1:0"}; + EXPECT_EQ(expected_inputs, inputs); + EXPECT_EQ(expected_outputs, outputs); +} + +TEST_F(FreezeTest, GraphDefVersionsAndLibrary) { + // Test that GraphDef versions and library are copied correctly into the + // frozen graph. + SavedModelBundle saved_model_bundle; + GraphDef graph_def; + graph_def.mutable_versions()->set_producer(1234); + graph_def.mutable_versions()->set_min_consumer(1234); + *graph_def.mutable_library()->add_function() = test::function::NonZero(); + TF_ASSERT_OK( + AddGraphDefToSavedModelBundle(graph_def, "", &saved_model_bundle)); + + GraphDef frozen_graph_def; + std::unordered_set inputs; + std::unordered_set outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, + &outputs)); + + GraphDefEqual(frozen_graph_def, graph_def); +} + +TEST_F(FreezeTest, GraphDefWithNoVariables) { + // Test freezing a graph with no variables. 
+ SavedModelBundle saved_model_bundle; + GraphDef graph_def; + Scope scope = Scope::NewRootScope(); + Output a = ops::Const(scope.WithOpName("a"), 10.0f, {}); + Output b = ops::Const(scope.WithOpName("b"), 10.0f, {}); + Output c = ops::Mul(scope.WithOpName("c"), a, b); + TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); + TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle(graph_def, {"c:0"}, "", + &saved_model_bundle)); + + GraphDef frozen_graph_def; + std::unordered_set inputs; + std::unordered_set outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, + &outputs)); + + GraphDefEqual(frozen_graph_def, graph_def); +} + +TEST_F(FreezeTest, GraphDefWithVariablesNotNeededByOutputs) { + // Test freezing a graph with variables that are not needed by the outputs in + // the SignatureDef. The resulting graph shouldn't be frozen, but + // non-dependent nodes should be pruned. + SavedModelBundle saved_model_bundle; + GraphDef graph_def; + Scope scope = Scope::NewRootScope(); + Output a = ops::Const(scope.WithOpName("a"), 10.0f, {}); + Output b = ops::Const(scope.WithOpName("b"), 10.0f, {}); + Output c = ops::Mul(scope.WithOpName("c"), a, b); + Output var = ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT); + Output assign = ops::Assign(scope.WithOpName("assign"), var, a); + TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); + // "c" isnt dependent on the variable, so nothing should be frozen. 
+ TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle( + graph_def, {"c:0"}, assign.name(), &saved_model_bundle)); + + GraphDef frozen_graph_def; + std::unordered_set inputs; + std::unordered_set outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, + &outputs)); + + GraphDef expected_graph_def; + Scope expected_scope = Scope::NewRootScope(); + Output expected_a = ops::Const(expected_scope.WithOpName("a"), 10.0f, {}); + Output expected_b = ops::Const(expected_scope.WithOpName("b"), 10.0f, {}); + Output expected_c = + ops::Mul(expected_scope.WithOpName("c"), expected_a, expected_b); + TF_ASSERT_OK(expected_scope.ToGraphDef(&expected_graph_def)); + + GraphDefEqual(frozen_graph_def, expected_graph_def); + + RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(), + frozen_graph_def, "c:0"); +} + +TEST_F(FreezeTest, GraphDefWithVariablesNeededByOutputs) { + // Test freezing a graph with variables that are needed by outputs in the + // SignatureDef. The variables should be frozen. + SavedModelBundle saved_model_bundle; + GraphDef graph_def; + Scope scope = Scope::NewRootScope(); + Output a = ops::Const(scope.WithOpName("a"), 10.0f, {}); + Output var = ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT); + Output c = ops::Mul(scope.WithOpName("c"), a, var); + Output assign = ops::Assign(scope.WithOpName("assign"), var, a); + TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); + // "c" isnt dependent on the variable, so nothing should be frozen. + TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle( + graph_def, {"c:0"}, assign.name(), &saved_model_bundle)); + + GraphDef frozen_graph_def; + std::unordered_set inputs; + std::unordered_set outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, + &outputs)); + + // There should be 3 nodes in the resulting graph_def, and none should be + // variables. 
+ EXPECT_EQ(frozen_graph_def.node_size(), 3); + for (const NodeDef& node : frozen_graph_def.node()) { + EXPECT_NE(node.op(), "Variable") << node.name(); + EXPECT_NE(node.op(), "VariableV2") << node.name(); + } + + RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(), + frozen_graph_def, "c:0"); +} + +TEST_F(FreezeTest, GraphDefWithVariablesNeededAndNotNeededByOutputs) { + // Test freezing a graph with some variables that are needed and not needed by + // the outputs in the SignatureDef. The resulting graph should only freeze + // dependent variables. + SavedModelBundle saved_model_bundle; + GraphDef graph_def; + Scope scope = Scope::NewRootScope(); + Output a = ops::Const(scope.WithOpName("a"), 10.0f, {}); + Output var = ops::Variable(scope.WithOpName("var"), {}, DataType::DT_FLOAT); + Output c = ops::Mul(scope.WithOpName("c"), a, var); + Output assign = ops::Assign(scope.WithOpName("assign"), var, a); + Output var_1 = + ops::Variable(scope.WithOpName("var_1"), {}, DataType::DT_FLOAT); + Output assign_1 = ops::Assign(scope.WithOpName("assign_1"), var, a); + TF_ASSERT_OK(scope.ToGraphDef(&graph_def)); + // "c" isnt dependent on the variable, so nothing should be frozen. + TF_ASSERT_OK(AddGraphDefWithOutputsToSavedModelBundle( + graph_def, {"c:0"}, assign.name(), &saved_model_bundle)); + + GraphDef frozen_graph_def; + std::unordered_set inputs; + std::unordered_set outputs; + TF_ASSERT_OK(FreezeSavedModel(saved_model_bundle, &frozen_graph_def, &inputs, + &outputs)); + + // There should be 3 nodes in the resulting graph_def, and none should be + // variables. 
+ EXPECT_EQ(frozen_graph_def.node_size(), 3); + for (const NodeDef& node : frozen_graph_def.node()) { + EXPECT_NE(node.op(), "Variable") << node.name(); + EXPECT_NE(node.op(), "VariableV2") << node.name(); + } + + RunAndCompareFrozenAndUnfrozenGraphs(saved_model_bundle.session.get(), + frozen_graph_def, "c:0"); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/cc/training/coordinator.h b/tensorflow/cc/training/coordinator.h index 0e01b19cd98bc797b7bb25da55c05d96f3eb93c7..7168b775251d38687d604b5294405389a8c1b04f 100644 --- a/tensorflow/cc/training/coordinator.h +++ b/tensorflow/cc/training/coordinator.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CC_TRAINING_COORDINATOR_H_ -#define THIRD_PARTY_TENSORFLOW_CC_TRAINING_COORDINATOR_H_ +#ifndef TENSORFLOW_CC_TRAINING_COORDINATOR_H_ +#define TENSORFLOW_CC_TRAINING_COORDINATOR_H_ #include #include @@ -128,4 +128,4 @@ class Coordinator { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CC_TRAINING_COORDINATOR_H_ +#endif // TENSORFLOW_CC_TRAINING_COORDINATOR_H_ diff --git a/tensorflow/cc/training/queue_runner.h b/tensorflow/cc/training/queue_runner.h index 2d3450032388bfee96055f23cf621af0fa4731ae..21189b4b046b87b8609483109096fda6144681b8 100644 --- a/tensorflow/cc/training/queue_runner.h +++ b/tensorflow/cc/training/queue_runner.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_ -#define THIRD_PARTY_TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_ +#ifndef TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_ +#define TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_ #include #include @@ -137,4 +137,4 @@ class QueueRunner : public RunnerInterface { } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_ +#endif // TENSORFLOW_CC_TRAINING_QUEUE_RUNNER_H_ diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index a9a6ea84319a18a8fbce648391bf5918ff6d9a08..0900e87ebabd378e6237b77ca0ef01677c07c244 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -24,7 +24,6 @@ tf_cc_test( srcs = ["runtime_test.cc"], deps = [ ":runtime", - "//tensorflow/compiler/tf2xla:xla_local_runtime_context", "//tensorflow/core:framework", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -53,6 +52,7 @@ cc_library( "flags.h", ], deps = [ + ":embedded_protocol_buffers", ":runtime", # needed by codegen to print aligned_buffer_bytes "//tensorflow/compiler/tf2xla", "//tensorflow/compiler/tf2xla:common", @@ -69,9 +69,7 @@ cc_library( "//tensorflow/compiler/xla/client:compile_only_client", "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", - "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -81,13 +79,18 @@ cc_library( tf_cc_test( name = "codegen_test", srcs = ["codegen_test.cc"], - data = ["codegen_test_h.golden"], + data = [ + "codegen_test_h.golden", + "codegen_test_o.golden", + ], deps = [ ":tfcompile_lib", "//tensorflow/compiler/xla:shape_util", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", + "@llvm//:support", # fixdeps: keep + 
"@llvm//:x86_code_gen", # fixdeps: keep ], ) @@ -111,6 +114,7 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", + "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", ], @@ -128,7 +132,9 @@ tf_library( config = "test_graph_tfadd.config.pbtxt", cpp_class = "AddComp", graph = "test_graph_tfadd.pbtxt", - tags = ["manual"], + tags = [ + "manual", + ], ) # A test of tf_library that includes a graph with an unknown op, but where @@ -139,7 +145,9 @@ tf_library( config = "test_graph_tfunknownop.config.pbtxt", cpp_class = "UnknownOpAddComp", graph = "test_graph_tfunknownop.pbtxt", - tags = ["manual"], + tags = [ + "manual", + ], ) # A test of tf_library that includes a graph with an unknown op, but where @@ -151,7 +159,9 @@ tf_library( config = "test_graph_tfunknownop2.config.pbtxt", cpp_class = "UnknownOpAddComp", graph = "test_graph_tfunknownop.pbtxt", - tags = ["manual"], + tags = [ + "manual", + ], ) # A test of tf_library that includes a graph with an unknown op, but where @@ -162,7 +172,9 @@ tf_library( config = "test_graph_tfunknownop3.config.pbtxt", cpp_class = "UnknownOpAddComp", graph = "test_graph_tfunknownop.pbtxt", - tags = ["manual"], + tags = [ + "manual", + ], ) # Utility library for benchmark binaries, used by the *_benchmark rules that are @@ -185,11 +197,27 @@ cc_library( name = "benchmark_extra_android", tags = [ "manual", - "notap", ], visibility = ["//visibility:public"], ) +cc_library( + name = "embedded_protocol_buffers", + srcs = ["embedded_protocol_buffers.cc"], + hdrs = ["embedded_protocol_buffers.h"], + deps = [ + "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/core:lib", + "@llvm//:core", + "@llvm//:execution_engine", + "@llvm//:support", + "@llvm//:target", + ], +) + tf_cc_test( name = 
"benchmark_test", srcs = ["benchmark_test.cc"], diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index ae22f7edc423247b34895411d19d7a3c21f86d4f..2cae85e8965216eaaee4d3032015d0016258a5c1 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/aot/embedded_protocol_buffers.h" #include "tensorflow/compiler/aot/runtime.h" #include "tensorflow/compiler/tf2xla/str_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h" @@ -101,21 +102,8 @@ Status ComputeArgSizes(const CompileResult& compile_result, std::vector* arg_sizes) { const xla::ProgramShape& ps = compile_result.program_shape; for (int i = 0; i < ps.parameters_size(); ++i) { - if (i == ps.parameters_size() - 1 && compile_result.has_context_arg) { - // If the compiled function needs a XlaLocalRuntimeContext* arg, it's - // always last, and must be represented as an opaque type. - const xla::PrimitiveType type = ps.parameters(i).element_type(); - if (type != xla::OPAQUE) { - return errors::InvalidArgument( - "expected final context arg to be opaque, but got type: ", - xla::PrimitiveType_Name(type), ", from program shape: ", - xla::ShapeUtil::HumanString(ps)); - } - arg_sizes->push_back(-1); - } else { - arg_sizes->push_back(xla::ShapeUtil::ByteSizeOf( - ps.parameters(i), compile_result.pointer_size)); - } + arg_sizes->push_back(xla::ShapeUtil::ByteSizeOf( + ps.parameters(i), compile_result.pointer_size)); } return Status::OK(); } @@ -165,11 +153,6 @@ string RewriteWithName(const string& name, string code, Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps, const CompileResult& compile_result, string* methods) { size_t num_args = ps.parameters_size(); - if (compile_result.has_context_arg) { - // If the compiled function needs a XlaLocalRuntimeContext* arg, it's - // always last, and is set in the class constructor. 
- num_args--; - } if (config.feed_size() != num_args) { return errors::InvalidArgument("mismatch between feed_size(", config.feed_size(), ") and num_args(", @@ -281,49 +264,6 @@ string GenNameToIndexCode(const T& entries, bool generate) { return code; } -// Converts the given `str` into a comma-separated list of per-character values. -string StringToCharList(const string& str) { - string list; - for (const char c : str) { - if (!list.empty()) { - list += ","; - } - list += strings::StrCat(static_cast(c)); - } - return list; -} - -string GenProgramShapeCode(xla::ProgramShape program_shape, bool generate) { - // No need for any static magic if we're not supposed to generate the data. - if (!generate) { - return "{\n return nullptr;\n }"; - } - // The parameter names are currently meaningless, and redundant with the rest - // of our metadata, so clear them out to avoid confusion and save space. - program_shape.clear_parameter_names(); - const string proto_str = program_shape.SerializeAsString(); - // Embed the program shape as a serialized protobuf in the header file. - // - // TODO(toddw): This strategy will likely fail for larger protobufs, depending - // on the C++ compiler that is used. Figure out another solution if necessary. 
- string code = R"({ - static const xla::ProgramShape* kShape = []() { - static const char kProto[] = {{{PROTO_LIST}}}; - static constexpr int kProtoSize = {{PROTO_SIZE}}; - xla::ProgramShape* shape = new xla::ProgramShape; - shape->ParseFromArray(kProto, kProtoSize); - return shape; - }(); - return kShape; - })"; - str_util::ReplaceAllPairs( - &code, { - {"{{PROTO_LIST}}", StringToCharList(proto_str)}, - {"{{PROTO_SIZE}}", strings::StrCat(proto_str.size())}, - }); - return code; -} - Status ValidateFeedFetchCppNames(const tf2xla::Config& config) { for (const tf2xla::Feed& feed : config.feed()) { if (!feed.name().empty()) { @@ -340,8 +280,9 @@ Status ValidateFeedFetchCppNames(const tf2xla::Config& config) { } // namespace -Status GenerateHeader(const HeaderOpts& opts, const tf2xla::Config& config, - const CompileResult& compile_result, string* header) { +Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, + const CompileResult& compile_result, + const MetadataResult& metadata_result, string* header) { TF_RETURN_IF_ERROR(ValidateConfig(config)); TF_RETURN_IF_ERROR(ValidateFeedFetchCppNames(config)); const int64 result_index = compile_result.aot->result_buffer_index(); @@ -391,8 +332,6 @@ Status GenerateHeader(const HeaderOpts& opts, const tf2xla::Config& config, ? R"(#include "tensorflow/compiler/xla/xla_data.pb.h")" : ""; - const string program_shape_code = - GenProgramShapeCode(ps, opts.gen_program_shape); // Use a poor-man's text templating mechanism; first populate the full header // with placeholder tokens, and then rewrite the tokens with real values. @@ -418,7 +357,9 @@ namespace xla { class ExecutableRunOptions; } // (Implementation detail) Entry point to the function in the object file. 
extern "C" void {{ENTRY}}( void* result, const xla::ExecutableRunOptions* run_options, - const void** args, void** temps); + const void** args, void** temps, tensorflow::int64* profile_counters); + +{{DECLS_FROM_OBJ_FILE}} {{NS_START}} // {{CLASS}} represents a computation previously specified in a @@ -474,7 +415,6 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { data->temp_sizes = TempSizes(); data->num_temps = kNumTemps; data->result_index = kResultIndex; - data->requires_runtime_context = {{HAS_CONTEXT_ARG}}; data->arg_names = StaticArgNames(); data->result_names = StaticResultNames(); data->program_shape = StaticProgramShape(); @@ -483,7 +423,7 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { return *kStaticData; } - {{CLASS}}(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_AND_TEMPS) + {{CLASS}}(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS) : XlaCompiledCpuFunction(StaticData(), alloc_mode) {} {{CLASS}}(const {{CLASS}}&) = delete; @@ -496,8 +436,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { // void set_argN_data(void* data) // Sets the buffer of type T for positional argument N. May be called in // any AllocMode. Must be called before Run to have an affect. Must be - // called in AllocMode::RESULTS_AND_TEMPS_ONLY for each positional argument, - // to set the argument buffers. + // called in AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY for each positional + // argument, to set the argument buffers. // // T* argN_data() // Returns the buffer of type T for positional argument N. @@ -543,7 +483,10 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { static const char** StaticResultNames() {{RESULT_NAMES_CODE}} // Shape of the args and results. 
- static const xla::ProgramShape* StaticProgramShape() {{PROGRAM_SHAPE_CODE}} + static const xla::ProgramShape* StaticProgramShape() { + static const xla::ProgramShape* kShape = {{PROGRAM_SHAPE_SHIM_EXPRESSION}}; + return kShape; + } }; {{NS_END}} @@ -560,26 +503,68 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { {"{{ARG_SIZES}}", str_util::Join(arg_sizes, ", ")}, {"{{CLASS}}", opts.class_name}, {"{{ENTRY}}", compile_result.entry_point}, - {"{{HAS_CONTEXT_ARG}}", - compile_result.has_context_arg ? "true" : "false"}, {"{{INCLUDE_XLA_DATA_PROTO}}", include_xla_data_proto}, {"{{METHODS_ARG}}\n", methods_arg}, {"{{METHODS_RESULT}}\n", methods_result}, {"{{NS_END}}\n", ns_end}, {"{{NS_START}}\n", ns_start}, {"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(ps)}, - {"{{PROGRAM_SHAPE_CODE}}", program_shape_code}, {"{{RESULT_INDEX}}", strings::StrCat(result_index)}, {"{{RESULT_NAMES_CODE}}", result_names_code}, {"{{TEMP_BYTES_ALIGNED}}", strings::StrCat(temp_bytes_aligned)}, {"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)}, {"{{TEMP_NUM}}", strings::StrCat(temp_sizes.size())}, {"{{TEMP_SIZES}}", str_util::Join(temp_sizes, ", ")}, - }; + {"{{DECLS_FROM_OBJ_FILE}}", + str_util::Join(metadata_result.header_variable_decls, "\n")}, + {"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}", + metadata_result.program_shape_access_shim}}; str_util::ReplaceAllPairs(header, rewrites); return Status::OK(); } +static string CreateUniqueIdentifierForProgramShape(const CodegenOpts& opts) { + string result = "__tfcompile"; + for (const string& n : opts.namespaces) { + strings::StrAppend(&result, "_", n); + } + + strings::StrAppend(&result, "_", opts.class_name, "_ProgramShape"); + return result; +} + +Status GenerateMetadata(const CodegenOpts& opts, + const CompileResult& compile_result, + MetadataResult* metadata_result) { + std::unique_ptr program_shape; + + if (opts.gen_program_shape) { + program_shape = + tensorflow::MakeUnique(compile_result.program_shape); + // The 
parameter names are currently meaningless, and redundant with the + // rest of our metadata, so clear them out to avoid confusion and save + // space. + program_shape->clear_parameter_names(); + } + + // When asked to serialize a null protobuf, CreateEmbeddedProtocolBuffer gives + // a shim that evaluates to nullptr, which is what we want. + + TF_ASSIGN_OR_RETURN( + EmbeddedProtocolBuffer embedded_program_shape, + CreateEmbeddedProtocolBuffer(opts.target_triple, + CreateUniqueIdentifierForProgramShape(opts), + "xla::ProgramShape", program_shape.get())); + + metadata_result->program_shape_access_shim = + std::move(embedded_program_shape.cpp_shim_expression); + metadata_result->header_variable_decls.emplace_back( + std::move(embedded_program_shape.cpp_variable_decl)); + metadata_result->object_file_data = + std::move(embedded_program_shape.object_file_data); + return Status::OK(); +} + Status ParseCppClass(const string& cpp_class, string* class_name, std::vector* namespaces) { class_name->clear(); diff --git a/tensorflow/compiler/aot/codegen.h b/tensorflow/compiler/aot/codegen.h index 76dd0cc3cf9470a1beb2a4725724f640aecfec7f..3430b1f96cf4d3c035b76c77ccf124c5d164751e 100644 --- a/tensorflow/compiler/aot/codegen.h +++ b/tensorflow/compiler/aot/codegen.h @@ -26,11 +26,15 @@ limitations under the License. namespace tensorflow { namespace tfcompile { -// HeaderOpts specifies options for header-file generation. -struct HeaderOpts { +// CodegenOpts specifies code generation options for the generated header file +// and the generated metadata object file. +struct CodegenOpts { // The name of the generated C++ class, wrapping the generated function. string class_name; + // Target triple for the architecture we're targeting. + string target_triple; + // Namespaces specifies a list of C++ namespaces to add to the generated // header. If empty, all symbols will be in the global namespace. 
std::vector namespaces; @@ -42,11 +46,36 @@ struct HeaderOpts { bool gen_program_shape = false; }; +// Describes a generated metadata object file. +struct MetadataResult { + // These are top level "extern C" declarations that are expected to be visible + // wherever program_shape_access_shim is emitted. + std::vector header_variable_decls; + + // program_shape_access_shim is a C++ expression that constructs the + // xla::ProgramShape instance for the CompileResult passed to + // GenerateMetadata. + string program_shape_access_shim; + + // The contents of the object (".o") file. + string object_file_data; +}; + +// Generates a metadata object file according to `opts` and `compile_result`. +// The generated object file is returned via `metadata_result`. +Status GenerateMetadata(const CodegenOpts& opts, + const CompileResult& compile_result, + MetadataResult* metadata_result); + // GenerateHeader uses the meta-information from compile_result to generate a // C++ header giving access to the function in the generated object file. The // header includes API usage documentation. -Status GenerateHeader(const HeaderOpts& opts, const tf2xla::Config& config, - const CompileResult& compile_result, string* header); +// +// metadata_result is an instance of MetadataResult obtained by a previous +// invocation to GenerateMetadata. +Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, + const CompileResult& compile_result, + const MetadataResult& metadata_result, string* header); // ParseCppClass parses `cpp_class` into its `class_name` and `namespaces` // components. The syntax is [[::],...]. This diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index 0f6114666fcc89c631434527d2ae8c92c039ffea..972b7d51ecb3798e61757ac55e973075a23b433a 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include #include +#include "llvm/Support/TargetSelect.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -123,9 +124,39 @@ TEST_F(ParseCppClassTest, ParseFail) { ExpectFail("good::0bad"); } -TEST(GenerateHeader, Golden) { - HeaderOpts opts; +static void CompareWithGoldenFile( + const string& tensorflow_relative_golden_file_name, + const string& expected_contents) { + // To update the golden file, flip update_golden to true and run the + // following: + // bazel test --test_strategy=local \ + // third_party/tensorflow/compiler/aot:codegen_test + const bool update_golden = false; + const string golden_file_name = io::JoinPath( + testing::TensorFlowSrcRoot(), tensorflow_relative_golden_file_name); + + if (update_golden) { + TF_EXPECT_OK( + WriteStringToFile(Env::Default(), golden_file_name, expected_contents)); + } + + string golden_file_contents; + TF_ASSERT_OK(ReadFileToString(Env::Default(), golden_file_name, + &golden_file_contents)); + EXPECT_EQ(golden_file_contents, expected_contents); +} + +TEST(CodegenTest, Golden) { + // Normally CpuCompiler::CpuCompiler does this, but in this test we've + // bypassed the Cpu compiler so we have to do this manually. 
+ llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + LLVMInitializeX86Target(); + LLVMInitializeX86TargetMC(); + + CodegenOpts opts; opts.class_name = "MyClass"; + opts.target_triple = "x86_64-pc-linux"; opts.namespaces = {"foo", "bar"}; opts.gen_name_to_index = true; opts.gen_program_shape = true; @@ -145,32 +176,27 @@ TEST(GenerateHeader, Golden) { { xla::ShapeUtil::MakeShape(xla::F32, {1, 2}), xla::ShapeUtil::MakeShape(xla::S64, {3, 4}), - xla::ShapeUtil::MakeOpaqueShape(), }, xla::ShapeUtil::MakeTupleShape( {xla::ShapeUtil::MakeShape(xla::U32, {5, 6})})); - compile_result.has_context_arg = true; compile_result.entry_point = "entry_point"; compile_result.pointer_size = 8; + + MetadataResult metadata_result; + TF_ASSERT_OK(GenerateMetadata(opts, compile_result, &metadata_result)); + + // The other fields in metadata_result are tested as part of the generated + // header test. + + CompareWithGoldenFile("compiler/aot/codegen_test_o.golden", + metadata_result.object_file_data); + string header; - TF_EXPECT_OK(GenerateHeader(opts, config, compile_result, &header)); + TF_ASSERT_OK( + GenerateHeader(opts, config, compile_result, metadata_result, &header)); - // Compare against the golden file. 
- const string golden_name = io::JoinPath(testing::TensorFlowSrcRoot(), - "compiler/aot/codegen_test_h.golden"); - // To update the golden file, flip update_golden to true and run the - // following: - // bazel test --test_strategy=local \ - // third_party/tensorflow/compiler/aot:codegen_test - const bool update_golden = false; - if (update_golden) { - TF_EXPECT_OK(WriteStringToFile(Env::Default(), golden_name, header)); - } - string golden_data; - TF_EXPECT_OK(ReadFileToString(Env::Default(), golden_name, &golden_data)); - EXPECT_EQ(header, golden_data); + CompareWithGoldenFile("compiler/aot/codegen_test_h.golden", header); } - } // namespace } // namespace tfcompile } // namespace tensorflow diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden index 65f342ce27ef09092f252f791973f245a8cdd6f3..ac3b5873318873b5fdf41bd556a0b2abddc2b30b 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -19,7 +19,9 @@ namespace xla { class ExecutableRunOptions; } // (Implementation detail) Entry point to the function in the object file. extern "C" void entry_point( void* result, const xla::ExecutableRunOptions* run_options, - const void** args, void** temps); + const void** args, void** temps, tensorflow::int64* profile_counters); + +extern "C" char __tfcompile_foo_bar_MyClass_ProgramShape_protobuf_array_contents[]; namespace foo { namespace bar { @@ -48,7 +50,7 @@ namespace bar { // is guaranteed that no thread may call a non-const method. // // The logical function signature is: -// ((unknown): f32[1,2], (unknown): s64[3,4], (unknown): opaque[]) -> (u32[5,6]) +// ((unknown): f32[1,2], (unknown): s64[3,4]) -> (u32[5,6]) // // Memory stats: // arg bytes total: 104 @@ -58,11 +60,11 @@ namespace bar { class MyClass : public tensorflow::XlaCompiledCpuFunction { public: // Number of input arguments for the compiled computation. 
- static constexpr size_t kNumArgs = 3; + static constexpr size_t kNumArgs = 2; // Byte size of each argument buffer. There are kNumArgs entries. static const intptr_t* ArgSizes() { - static constexpr intptr_t kArgSizes[kNumArgs] = {8, 96, -1}; + static constexpr intptr_t kArgSizes[kNumArgs] = {8, 96}; return kArgSizes; } @@ -77,7 +79,6 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { data->temp_sizes = TempSizes(); data->num_temps = kNumTemps; data->result_index = kResultIndex; - data->requires_runtime_context = true; data->arg_names = StaticArgNames(); data->result_names = StaticResultNames(); data->program_shape = StaticProgramShape(); @@ -86,7 +87,7 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { return *kStaticData; } - MyClass(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_AND_TEMPS) + MyClass(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS) : XlaCompiledCpuFunction(StaticData(), alloc_mode) {} MyClass(const MyClass&) = delete; @@ -99,8 +100,8 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { // void set_argN_data(void* data) // Sets the buffer of type T for positional argument N. May be called in // any AllocMode. Must be called before Run to have an affect. Must be - // called in AllocMode::RESULTS_AND_TEMPS_ONLY for each positional argument, - // to set the argument buffers. + // called in AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY for each positional + // argument, to set the argument buffers. // // T* argN_data() // Returns the buffer of type T for positional argument N. @@ -236,12 +237,10 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { // Shape of the args and results. 
static const xla::ProgramShape* StaticProgramShape() { static const xla::ProgramShape* kShape = []() { - static const char kProto[] = {10,12,16,11,26,2,1,2,42,4,10,2,1,0,10,12,16,5,26,2,3,4,42,4,10,2,1,0,10,2,16,14,18,16,16,13,34,12,16,8,26,2,5,6,42,4,10,2,1,0}; - static constexpr int kProtoSize = 50; - xla::ProgramShape* shape = new xla::ProgramShape; - shape->ParseFromArray(kProto, kProtoSize); - return shape; - }(); + xla::ProgramShape* proto = new xla::ProgramShape; + proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShape_protobuf_array_contents[0], 52); + return proto; + }(); return kShape; } }; diff --git a/tensorflow/compiler/aot/codegen_test_o.golden b/tensorflow/compiler/aot/codegen_test_o.golden new file mode 100644 index 0000000000000000000000000000000000000000..eb001c5d45bdfefc76629d7303d89f5480432235 Binary files /dev/null and b/tensorflow/compiler/aot/codegen_test_o.golden differ diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index 2b8cc6024cb85e4f6269313927ff66d1d9a1cf79..c87f2b75dfa18ad5c3eda4bd6fcbcb3083ef73fd 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -94,9 +94,8 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config, xla::ClientLibrary::GetOrCreateCompileOnlyClient(cpu_platform) .ValueOrDie(); xla::Computation computation; - TF_RETURN_IF_ERROR(ConvertGraphDefToXla(graph_def, config, client, - &computation, - &compile_result->has_context_arg)); + TF_RETURN_IF_ERROR( + ConvertGraphDefToXla(graph_def, config, client, &computation)); if (!flags.out_session_module.empty()) { TF_ASSIGN_OR_RETURN(std::unique_ptr module, computation.Snapshot()); diff --git a/tensorflow/compiler/aot/compile.h b/tensorflow/compiler/aot/compile.h index 965c2960816b3acc8d2209e6824d88647de0ce14..e03c5b1aa77c1262ed903aae3072ef65f34d80a2 100644 --- a/tensorflow/compiler/aot/compile.h +++ b/tensorflow/compiler/aot/compile.h @@ -34,7 +34,6 @@ struct 
CompileResult { // Contains object file and meta-info. std::unique_ptr aot; xla::ProgramShape program_shape; // Static shape of args and results. - bool has_context_arg = false; // Is last arg XlaLocalRuntimeContext? string entry_point; // Name of generated function. int pointer_size = 0; // Size of a pointer in bytes. }; diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.cc b/tensorflow/compiler/aot/embedded_protocol_buffers.cc new file mode 100644 index 0000000000000000000000000000000000000000..6489929a576d6469c4ff1358ca5ee9d27fb578bb --- /dev/null +++ b/tensorflow/compiler/aot/embedded_protocol_buffers.cc @@ -0,0 +1,158 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/aot/embedded_protocol_buffers.h" + +#include +#include + +#include "llvm/ADT/Triple.h" +#include "llvm/ExecutionEngine/ObjectMemoryBuffer.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/TargetRegistry.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/Target/TargetOptions.h" +#include "tensorflow/compiler/tf2xla/str_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/util.h" + +namespace tensorflow { +namespace tfcompile { + +using xla::llvm_ir::AsStringRef; + +static std::unique_ptr CreateModuleWithEmbeddedProtocolBuffer( + llvm::LLVMContext* llvm_context, llvm::TargetMachine* target_machine, + const ::tensorflow::protobuf::MessageLite& proto, + StringPiece unique_identifier, string* protobuf_array_symbol_name, + int64* protobuf_array_size) { + string protobuf_array_contents = proto.SerializeAsString(); + *protobuf_array_symbol_name = + strings::StrCat(unique_identifier, "_protobuf_array_contents"); + *protobuf_array_size = protobuf_array_contents.size(); + + std::unique_ptr module = + MakeUnique("embedded_data_module", *llvm_context); + + llvm::Constant* protobuf_array_initializer = + llvm::ConstantDataArray::getString(*llvm_context, + AsStringRef(protobuf_array_contents), + /*AddNull=*/false); + new llvm::GlobalVariable( + *module, protobuf_array_initializer->getType(), + /*isConstant=*/true, llvm::GlobalValue::ExternalLinkage, + protobuf_array_initializer, AsStringRef(*protobuf_array_symbol_name)); + + return module; +} + +static string CreateCPPShimExpression(StringPiece qualified_cpp_protobuf_name, + StringPiece protobuf_array_symbol_name, + int64 protobuf_array_size) { + string code = + "[]() {\n" + " {{PROTOBUF_NAME}}* 
proto = new {{PROTOBUF_NAME}};\n" + " proto->ParseFromArray(&{{ARRAY_SYMBOL}}[0], {{ARRAY_SIZE}});\n" + " return proto;\n" + " }()"; + + str_util::ReplaceAllPairs( + &code, + { + {"{{ARRAY_SYMBOL}}", strings::StrCat(protobuf_array_symbol_name)}, + {"{{ARRAY_SIZE}}", strings::StrCat(protobuf_array_size)}, + {"{{PROTOBUF_NAME}}", strings::StrCat(qualified_cpp_protobuf_name)}, + }); + return code; +} + +static StatusOr CodegenModule(llvm::TargetMachine* target_machine, + std::unique_ptr module) { + llvm::SmallVector stream_buffer; + llvm::raw_svector_ostream ostream(stream_buffer); + llvm::legacy::PassManager codegen_passes; + + if (target_machine->addPassesToEmitFile( + codegen_passes, ostream, llvm::TargetMachine::CGFT_ObjectFile)) { + return xla::InternalError( + "Could not create pass pipeline to generate object file"); + } + + codegen_passes.run(*module); + + return string(stream_buffer.begin(), stream_buffer.end()); +} + +static StatusOr> +GetTargetMachineFromTriple(StringPiece target_triple) { + std::string error; + std::string normalized_triple = + llvm::Triple::normalize(AsStringRef(target_triple)); + const llvm::Target* target = + llvm::TargetRegistry::lookupTarget(normalized_triple, error); + if (target == nullptr) { + return xla::InternalError("TargetRegistry::lookupTarget failed: %s", + error.c_str()); + } + + return WrapUnique(target->createTargetMachine( + normalized_triple, /*CPU=*/"", + /*Features=*/"", llvm::TargetOptions(), llvm::None)); +} + +StatusOr CreateEmbeddedProtocolBuffer( + StringPiece target_triple, StringPiece symbol_prefix, + StringPiece qualified_cpp_protobuf_name, + const ::tensorflow::protobuf::MessageLite* proto) { + TF_ASSIGN_OR_RETURN(std::unique_ptr target_machine, + GetTargetMachineFromTriple(target_triple)); + + llvm::LLVMContext llvm_context; + string object_file, cpp_shim, cpp_variable_decl; + + if (proto) { + string protobuf_array_symbol_name; + int64 protobuf_array_size; + + std::unique_ptr module_with_serialized_proto = + 
CreateModuleWithEmbeddedProtocolBuffer( + &llvm_context, target_machine.get(), *proto, symbol_prefix, + &protobuf_array_symbol_name, &protobuf_array_size); + TF_ASSIGN_OR_RETURN(object_file, + CodegenModule(target_machine.get(), + std::move(module_with_serialized_proto))); + cpp_shim = CreateCPPShimExpression(qualified_cpp_protobuf_name, + protobuf_array_symbol_name, + protobuf_array_size); + + cpp_variable_decl = strings::StrCat("extern \"C\" char ", + protobuf_array_symbol_name, "[];"); + } else { + TF_ASSIGN_OR_RETURN( + object_file, + CodegenModule(target_machine.get(), + MakeUnique("empty_module", llvm_context))); + cpp_shim = "nullptr"; + } + + return {{cpp_shim, cpp_variable_decl, object_file}}; +} + +} // namespace tfcompile +} // namespace tensorflow diff --git a/tensorflow/compiler/aot/embedded_protocol_buffers.h b/tensorflow/compiler/aot/embedded_protocol_buffers.h new file mode 100644 index 0000000000000000000000000000000000000000..8436e0ff67f352a24e3d16b46f16c1ad2f3a5957 --- /dev/null +++ b/tensorflow/compiler/aot/embedded_protocol_buffers.h @@ -0,0 +1,73 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines utilities to help "embed" protocol buffers into object +// (".o") files. These C++ binaries and shared objects can link in these .o to +// get access to said protocol buffers at runtime. 
+ +#ifndef TENSORFLOW_COMPILER_AOT_EMBEDDED_PROTOCOL_BUFFERS_H_ +#define TENSORFLOW_COMPILER_AOT_EMBEDDED_PROTOCOL_BUFFERS_H_ + +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { +namespace tfcompile { +using xla::StatusOr; + +// Represents a protocol buffer embedded into an object file and describes a way +// to access it at runtime. +struct EmbeddedProtocolBuffer { + // cpp_shim_expression is a C++ expression that creates an instance of said + // protocol buffer when executed. + string cpp_shim_expression; + + // cpp_variable_decl is an "extern C" array declaration that is used in + // cpp_shim_expression. It must be visible wherever cpp_shim_expression is + // emitted. + string cpp_variable_decl; + + // The contents of the object (".o") file the protocol buffer is embbed in. + // This needs to be linked in to any program that wants to execute + // cpp_variable_decl . + string object_file_data; +}; + +// Creates an object file that contains `proto`. +// +// `proto` is allowed to be nullptr, in which case the generated C++ shim +// expression is just `nullptr`, and the generated object file does not define +// any symbols. +// +// `target_triple` is the target triple for the target architecture for the +// generated object file. +// +// `symbol_prefix` is prefix that is guaranteed to be unique across the binary +// or DSO the generated object file will be linked into. +// +// `qualified_cpp_protobuf_name` is a qualified ("qualified" as in C++ +// namespace qualified) protocol buffer name. This needs is only used in +// EmbeddedProtocolBuffer::cpp_shim_expression so relatively qualified +// names are fine as long as they're valid wherever cpp_shim_expression +// is emitted. 
+StatusOr CreateEmbeddedProtocolBuffer( + StringPiece target_triple, StringPiece symbol_prefix, + StringPiece qualified_cpp_protobuf_name, + const ::tensorflow::protobuf::MessageLite* proto); + +} // namespace tfcompile +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_AOT_EMBEDDED_PROTOCOL_BUFFERS_H_ diff --git a/tensorflow/compiler/aot/flags.cc b/tensorflow/compiler/aot/flags.cc index 7c2f27e550d44c2487f91acf1029c962ac3f5d01..8c95cb8f90ee031fdbb97fabd9d86f848b42e4c5 100644 --- a/tensorflow/compiler/aot/flags.cc +++ b/tensorflow/compiler/aot/flags.cc @@ -59,8 +59,13 @@ void AppendMainFlags(std::vector* flag_list, MainFlags* flags) { "namespaces may precede the class name, separated by double-colons. " "The class will be generated in the given namespace(s), or if no " "namespaces are given, within the global namespace."}, - {"out_object", &flags->out_object, "Output object file name."}, + {"out_function_object", &flags->out_function_object, + "Output object file containing the generated function for the " + "TensorFlow model."}, {"out_header", &flags->out_header, "Output header file name."}, + {"out_metadata_object", &flags->out_metadata_object, + "Output object file name containing optional metadata for the generated " + "function."}, {"out_session_module", &flags->out_session_module, "Output session module proto."}, {"gen_name_to_index", &flags->gen_name_to_index, diff --git a/tensorflow/compiler/aot/flags.h b/tensorflow/compiler/aot/flags.h index 3519659e3af7cd345f30080a07ce91fb858623fb..d266fbead61f7eb43863d1c67c0f86926ae9452d 100644 --- a/tensorflow/compiler/aot/flags.h +++ b/tensorflow/compiler/aot/flags.h @@ -34,7 +34,8 @@ struct MainFlags { string target_features; string entry_point; string cpp_class; - string out_object; + string out_function_object; + string out_metadata_object; string out_header; string out_session_module; diff --git a/tensorflow/compiler/aot/runtime_test.cc b/tensorflow/compiler/aot/runtime_test.cc index 
ac79c278c1fdf8b6aedcb52121c767b8ba0ad358..6d603a02eb4ceade6832ba67b2981814ee25327a 100644 --- a/tensorflow/compiler/aot/runtime_test.cc +++ b/tensorflow/compiler/aot/runtime_test.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/compiler/aot/runtime.h" -#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index 7dfd49cc3b92f83fd64ca62bd2230938ce2d0a65..28aab6eb614ca7123d9e00f7f5cc3661b62e23f7 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -74,7 +74,9 @@ tf_library( # compile but the others in this directory succeed, you may need to # expand the "required by all tf_library targets" list in tfcompile.bzl. include_standard_runtime_deps = False, - tags = ["manual"], + tags = [ + "manual", + ], ) tf_library( @@ -84,7 +86,9 @@ tf_library( cpp_class = "AddWithCkptComp", freeze_checkpoint = "test_graph_tfadd_with_ckpt.ckpt", graph = "test_graph_tfadd_with_ckpt.pb", - tags = ["manual"], + tags = [ + "manual", + ], ) tf_library( @@ -95,7 +99,9 @@ tf_library( freeze_checkpoint = "test_graph_tfadd_with_ckpt_saver.ckpt", freeze_saver = "test_graph_tfadd_with_ckpt_saver.saver", graph = "test_graph_tfadd_with_ckpt_saver.pb", - tags = ["manual"], + tags = [ + "manual", + ], ) tf_library( @@ -104,7 +110,9 @@ tf_library( config = "test_graph_tffunction.config.pbtxt", cpp_class = "FunctionComp", graph = "test_graph_tffunction.pb", - tags = ["manual"], + tags = [ + "manual", + ], ) tf_library( @@ -113,7 +121,9 @@ tf_library( config = "test_graph_tfgather.config.pbtxt", cpp_class = "GatherComp", graph = "test_graph_tfgather.pb", - tags = ["manual"], + tags = [ + "manual", + ], ) tf_library( @@ -122,7 +132,9 @@ tf_library( config = "test_graph_tfmatmul.config.pbtxt", cpp_class = "foo::bar::MatMulComp", graph = 
"test_graph_tfmatmul.pb", - tags = ["manual"], + tags = [ + "manual", + ], ) tf_library( @@ -131,7 +143,9 @@ tf_library( config = "test_graph_tfmatmulandadd.config.pbtxt", cpp_class = "MatMulAndAddComp", graph = "test_graph_tfmatmulandadd.pb", - tags = ["manual"], + tags = [ + "manual", + ], tfcompile_flags = "--gen_name_to_index --gen_program_shape", ) @@ -141,13 +155,17 @@ tf_library( config = "test_graph_tfsplits.config.pbtxt", cpp_class = "SplitsComp", graph = "test_graph_tfsplits.pb", - tags = ["manual"], + tags = [ + "manual", + ], ) tf_cc_test( name = "tfcompile_test", srcs = ["tfcompile_test.cc"], - tags = ["manual"], + tags = [ + "manual", + ], deps = [ ":test_graph_tfadd", ":test_graph_tfadd_with_ckpt", diff --git a/tensorflow/compiler/aot/tests/make_test_graphs.py b/tensorflow/compiler/aot/tests/make_test_graphs.py index a898eab1d1ab0eb5d55983bf366753c968887296..89c7cd4507cbd476104a039d6083d8f89de11278 100644 --- a/tensorflow/compiler/aot/tests/make_test_graphs.py +++ b/tensorflow/compiler/aot/tests/make_test_graphs.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import argparse +import os import sys from tensorflow.core.protobuf import saver_pb2 @@ -53,7 +54,7 @@ def tfadd_with_ckpt(out_dir): sess.run(init_op) sess.run(y.assign(y + 42)) # Without the checkpoint, the variable won't be set to 42. - ckpt = '%s/test_graph_tfadd_with_ckpt.ckpt' % out_dir + ckpt = os.path.join(out_dir, 'test_graph_tfadd_with_ckpt.ckpt') saver.save(sess, ckpt) @@ -68,10 +69,10 @@ def tfadd_with_ckpt_saver(out_dir): sess.run(init_op) sess.run(y.assign(y + 42)) # Without the checkpoint, the variable won't be set to 42. - ckpt_file = '%s/test_graph_tfadd_with_ckpt_saver.ckpt' % out_dir + ckpt_file = os.path.join(out_dir, 'test_graph_tfadd_with_ckpt_saver.ckpt') saver.save(sess, ckpt_file) # Without the SaverDef, the restore op won't be named correctly. 
- saver_file = '%s/test_graph_tfadd_with_ckpt_saver.saver' % out_dir + saver_file = os.path.join(out_dir, 'test_graph_tfadd_with_ckpt_saver.saver') with open(saver_file, 'wb') as f: f.write(saver.as_saver_def().SerializeToString()) @@ -129,7 +130,7 @@ def write_graph(build_graph, out_dir): g = ops.Graph() with g.as_default(): build_graph(out_dir) - filename = '%s/test_graph_%s.pb' % (out_dir, build_graph.__name__) + filename = os.path.join(out_dir, 'test_graph_%s.pb' % build_graph.__name__) with open(filename, 'wb') as f: f.write(g.as_graph_def().SerializeToString()) diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index 6b037f276ad1d6771b904bb970f45f32ae9531b8..413efd9cea3b6f71574615ad9ca92471ff925781 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -70,7 +70,7 @@ TEST(TFCompileTest, Add) { // Run tests that use set_argN_data separately, to avoid accidentally re-using // non-existent buffers. TEST(TFCompileTest, Add_SetArg) { - AddComp add(AddComp::AllocMode::RESULTS_AND_TEMPS_ONLY); + AddComp add(AddComp::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY); int32 arg_x = 10; int32 arg_y = 32; @@ -258,7 +258,7 @@ TEST(TFCompileTest, MatMul2_SetArg) { Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); foo::bar::MatMulComp matmul( - foo::bar::MatMulComp::AllocMode::RESULTS_AND_TEMPS_ONLY); + foo::bar::MatMulComp::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY); matmul.set_thread_pool(&device); // Test using the set_argN_data() methods. 
diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 6c385af3b36df78b3f674b3464d68d904ca92907..9dff1be09fede6f65f82c2f36d94be07e781949f 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -4,7 +4,7 @@ To use from your BUILD file, add the following line to load the macro: -load("@org_tensorflow//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") +load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") Then call the macro like this: @@ -16,14 +16,15 @@ tf_library( ) """ -load("@org_tensorflow//tensorflow:tensorflow.bzl", "if_android", "tf_copts") +load("//tensorflow:tensorflow.bzl", + "if_android", "tf_cc_test", "tf_copts") def tf_library(name, graph, config, freeze_checkpoint=None, freeze_saver=None, cpp_class=None, gen_test=True, gen_benchmark=True, visibility=None, testonly=None, tfcompile_flags=None, - tfcompile_tool="@org_tensorflow//tensorflow/compiler/aot:tfcompile", + tfcompile_tool="//tensorflow/compiler/aot:tfcompile", include_standard_runtime_deps=True, deps=None, tags=None): """Runs tfcompile to compile a TensorFlow graph into executable code. @@ -102,6 +103,7 @@ def tf_library(name, graph, config, # Now run freeze_graph to convert variables into constants. 
freeze_args = (" --input_graph=$(location " + graph + ")" + + " --checkpoint_version=1" + " --input_binary=" + str(not graph.endswith(".pbtxt")) + " --input_checkpoint=$(location " + freeze_checkpoint + ")" + " --output_graph=$(location " + freeze_file + ")" + @@ -119,16 +121,17 @@ def tf_library(name, graph, config, out_nodes_file, ] + freeze_saver_srcs, outs=[freeze_file], - cmd=("$(location @org_tensorflow//tensorflow/python/tools:freeze_graph)" + + cmd=("$(location //tensorflow/python/tools:freeze_graph)" + freeze_args), - tools=["@org_tensorflow//tensorflow/python/tools:freeze_graph"], + tools=["//tensorflow/python/tools:freeze_graph"], tags=tags, ) tfcompile_graph = freeze_file # Rule that runs tfcompile to produce the header and object file. header_file = name + ".h" - object_file = name + ".o" + metadata_object_file = name + "_tfcompile_metadata.o" + function_object_file = name + "_tfcompile_function.o" ep = ("__" + PACKAGE_NAME + "__" + name).replace("/", "_") if type(tfcompile_flags) == type(""): flags = tfcompile_flags @@ -142,7 +145,8 @@ def tf_library(name, graph, config, ], outs=[ header_file, - object_file, + metadata_object_file, + function_object_file, ], cmd=("$(location " + tfcompile_tool + ")" + " --graph=$(location " + tfcompile_graph + ")" + @@ -151,7 +155,8 @@ def tf_library(name, graph, config, " --cpp_class=" + cpp_class + " --target_triple=" + target_llvm_triple() + " --out_header=$(@D)/" + header_file + - " --out_object=$(@D)/" + object_file + + " --out_metadata_object=$(@D)/" + metadata_object_file + + " --out_function_object=$(@D)/" + function_object_file + " " + flags), tools=[tfcompile_tool], visibility=visibility, @@ -202,7 +207,7 @@ def tf_library(name, graph, config, need_xla_data_proto = (flags and flags.find("--gen_program_shape") != -1) native.cc_library( name=name, - srcs=[object_file], + srcs=[function_object_file, metadata_object_file], hdrs=[header_file], visibility=visibility, testonly=testonly, @@ -210,22 +215,19 @@ def 
tf_library(name, graph, config, # These deps are required by all tf_library targets even if # include_standard_runtime_deps is False. Without them, the # generated code will fail to compile. - "@org_tensorflow//tensorflow/compiler/tf2xla:xla_compiled_cpu_function", - "@org_tensorflow//tensorflow/core:framework_lite", + "//tensorflow/compiler/tf2xla:xla_compiled_cpu_function", + "//tensorflow/core:framework_lite", ] + (need_xla_data_proto and [ # If we're generating the program shape, we must depend on the proto. - "@org_tensorflow//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla:xla_data_proto", ] or []) + (include_standard_runtime_deps and [ # TODO(cwhipkey): only depend on kernel code that the model actually needed. - "@org_tensorflow//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d", - "@org_tensorflow//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d", - "@org_tensorflow//tensorflow/compiler/xla/service/cpu:cpu_runtime_avx", - "@org_tensorflow//tensorflow/compiler/xla/service/cpu:cpu_runtime_neon", - "@org_tensorflow//tensorflow/compiler/xla/service/cpu:cpu_runtime_sse4_1", - "@org_tensorflow//tensorflow/compiler/xla/service/cpu:runtime_conv2d", - "@org_tensorflow//tensorflow/compiler/xla/service/cpu:runtime_matmul", - "@org_tensorflow//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d", - "@org_tensorflow//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul", + "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d", + "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d", + "//tensorflow/compiler/xla/service/cpu:runtime_conv2d", + "//tensorflow/compiler/xla/service/cpu:runtime_matmul", + "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d", + "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul", "//third_party/eigen3", ] or []) + (deps or []), tags=tags, @@ -251,29 +253,32 @@ def tf_library(name, 
graph, config, name=("gen_" + test_name), testonly=1, srcs=[ - "@org_tensorflow//tensorflow/compiler/aot:test.cc", + "//tensorflow/compiler/aot:test.cc", header_file, ], outs=[test_file], cmd=("sed " + sed_replace + - " $(location @org_tensorflow//tensorflow/compiler/aot:test.cc) " + + " $(location //tensorflow/compiler/aot:test.cc) " + "> $(OUTS)"), tags=tags, ) - # The cc_test rule for the generated code. - native.cc_test( + # The cc_test rule for the generated code. To ensure that this works + # reliably across build configurations, we must use tf_cc_test instead of + # native.cc_test. This is related to how we build + # //tensorflow/core:lib -- see the note in tensorflow/core/BUILD + # for more details. + tf_cc_test( name=test_name, srcs=[test_file], deps=[ ":" + name, - "@org_tensorflow//tensorflow/compiler/tf2xla:xla_local_runtime_context", - "@org_tensorflow//tensorflow/compiler/aot:runtime", - "@org_tensorflow//tensorflow/compiler/aot:tf_library_test_main", - "@org_tensorflow//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/compiler/aot:runtime", + "//tensorflow/compiler/aot:tf_library_test_main", + "//tensorflow/compiler/xla:executable_run_options", "//third_party/eigen3", - "@org_tensorflow//tensorflow/core:lib", - "@org_tensorflow//tensorflow/core:test", + "//tensorflow/core:lib", + "//tensorflow/core:test", ], tags=tags, ) @@ -281,7 +286,7 @@ def tf_library(name, graph, config, if gen_benchmark: benchmark_name = name + "_benchmark" benchmark_file = benchmark_name + ".cc" - benchmark_main = ("@org_tensorflow//tensorflow/compiler/aot:" + + benchmark_main = ("//tensorflow/compiler/aot:" + "benchmark_main.template") # Rule to rewrite benchmark.cc to produce the benchmark_file. @@ -299,7 +304,9 @@ def tf_library(name, graph, config, tags=tags, ) - # The cc_benchmark rule for the generated code. + # The cc_benchmark rule for the generated code. 
This does not need the + # tf_cc_binary since we (by deliberate design) do not depend on + # //tensorflow/core:lib. # # Note: to get smaller size on android for comparison, compile with: # --copt=-fvisibility=hidden @@ -313,13 +320,12 @@ def tf_library(name, graph, config, linkopts = if_android(["-pie", "-s"]), deps=[ ":" + name, - "@org_tensorflow//tensorflow/compiler/tf2xla:xla_local_runtime_context", - "@org_tensorflow//tensorflow/compiler/aot:benchmark", - "@org_tensorflow//tensorflow/compiler/aot:runtime", - "@org_tensorflow//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/compiler/aot:benchmark", + "//tensorflow/compiler/aot:runtime", + "//tensorflow/compiler/xla:executable_run_options", "//third_party/eigen3", ] + if_android([ - "@org_tensorflow//tensorflow/compiler/aot:benchmark_extra_android", + "//tensorflow/compiler/aot:benchmark_extra_android", ]), tags=tags, ) @@ -329,11 +335,11 @@ def target_llvm_triple(): # TODO(toddw): Add target_triple for other targets. 
For details see: # http://llvm.org/docs/doxygen/html/Triple_8h_source.html return select({ - "@org_tensorflow//tensorflow:android_armeabi": "armv5-none-android", - "@org_tensorflow//tensorflow:android_arm": "armv7-none-android", - "@org_tensorflow//tensorflow:android_arm64": "aarch64-none-android", - "@org_tensorflow//tensorflow:android_x86": "i686-none-android", - "@org_tensorflow//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu", - "@org_tensorflow//tensorflow:darwin": "x86_64-none-darwin", + "//tensorflow:android_armeabi": "armv5-none-android", + "//tensorflow:android_arm": "armv7-none-android", + "//tensorflow:android_arm64": "aarch64-none-android", + "//tensorflow:android_x86": "i686-none-android", + "//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu", + "//tensorflow:darwin": "x86_64-none-darwin", "//conditions:default": "x86_64-pc-linux", }) diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index 6ab3d474187c7df2131f94c9f42f0d0f2f9d99d7..e2f01179d4e2e4f6ef72b2761d06e130ffa3a94f 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -91,19 +91,26 @@ Status Main(const MainFlags& flags) { // Write output files. 
Env* env = Env::Default(); const std::vector& obj = compile_result.aot->object_file_data(); - TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_object, + TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_function_object, StringPiece(obj.data(), obj.size()))); - HeaderOpts header_opts; - header_opts.gen_name_to_index = flags.gen_name_to_index; - header_opts.gen_program_shape = flags.gen_program_shape; + CodegenOpts codegen_opts; + codegen_opts.gen_name_to_index = flags.gen_name_to_index; + codegen_opts.gen_program_shape = flags.gen_program_shape; + codegen_opts.target_triple = flags.target_triple; if (flags.cpp_class.empty()) { return errors::InvalidArgument("Must specify --cpp_class"); } - TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &header_opts.class_name, - &header_opts.namespaces)); - string header; + TF_RETURN_IF_ERROR(ParseCppClass(flags.cpp_class, &codegen_opts.class_name, + &codegen_opts.namespaces)); + + MetadataResult metadata_result; TF_RETURN_IF_ERROR( - GenerateHeader(header_opts, config, compile_result, &header)); + GenerateMetadata(codegen_opts, compile_result, &metadata_result)); + TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_metadata_object, + metadata_result.object_file_data)); + string header; + TF_RETURN_IF_ERROR(GenerateHeader(codegen_opts, config, compile_result, + metadata_result, &header)); TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_header, header)); return Status::OK(); } @@ -114,7 +121,8 @@ Status Main(const MainFlags& flags) { int main(int argc, char** argv) { tensorflow::tfcompile::MainFlags flags; flags.target_triple = "x86_64-pc-linux"; - flags.out_object = "out.o"; + flags.out_function_object = "out_model.o"; + flags.out_metadata_object = "out_helper.o"; flags.out_header = "out.h"; flags.entry_point = "entry"; diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index bf7d9cf14d10f41aa48ea594a8d63db97b9973e1..a711319607f4ff2b83aa0ebe50e215b3d0e2258e 100644 --- 
a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -110,19 +110,6 @@ cc_library( alwayslink = True, ) -# Internal targets below this point. - -cc_library( - name = "common", - srcs = [ - "defs.cc", - ], - hdrs = [ - "defs.h", - ], - visibility = [":friends"], -) - cc_library( name = "xla_device", srcs = [ @@ -135,6 +122,8 @@ cc_library( "xla_device_context.h", "xla_device_ops.h", ], + # Public visibility is needed for external TF/XLA backends. + visibility = ["//visibility:public"], deps = [ ":common", ":jit_compilation_passes", @@ -164,6 +153,19 @@ cc_library( ], ) +# Internal targets below this point. + +cc_library( + name = "common", + srcs = [ + "defs.cc", + ], + hdrs = [ + "defs.h", + ], + visibility = [":friends"], +) + cc_library( name = "xla_compilation_cache", srcs = ["xla_compilation_cache.cc"], @@ -215,7 +217,6 @@ cc_library( ":common", ":compilation_passes", "//tensorflow/compiler/jit/kernels:xla_launch_op", - "//tensorflow/compiler/tf2xla:const_analysis", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -245,12 +246,13 @@ cc_library( "//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags", "//tensorflow/compiler/jit/ops:parallel_check_op", "//tensorflow/compiler/jit/ops:xla_ops", - "//tensorflow/compiler/tf2xla:const_analysis", "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", + "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc index 22899ebeebc929055518893b358f7950d380d6f6..9c372a012789fc25ca0a711349c09ca62edc6754 100644 --- 
a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc @@ -16,22 +16,30 @@ limitations under the License. #include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h" #include +#include #include +#include +#include +#include #include "tensorflow/compiler/jit/graph_to_functiondef.h" #include "tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/common_runtime/shape_refiner.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/graph/tensor_id.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -48,19 +56,75 @@ const char* const kXlaNumResourceArgsAttr = "_XlaNumResourceArgs"; namespace { +bool AreAllParentsConst(const Node& n, + const gtl::FlatSet& runtime_const_nodes) { + if (n.type_string() == "GuaranteeConst" || n.type_string() == "Const") { + // If the current node is itself a cast-to-const, no need + // to look at the incoming edges. 
+ return true; + } + + bool all_parents_const = true; + bool atleast_one_non_control_edge = false; + for (const Edge* in : n.in_edges()) { + atleast_one_non_control_edge = + atleast_one_non_control_edge || !in->IsControlEdge(); + if (!in->IsControlEdge() && runtime_const_nodes.count(in->src()) == 0) { + all_parents_const = false; + break; + } + } + return all_parents_const && atleast_one_non_control_edge; +} + +void MarkGuaranteedConstants( + const Graph& graph, + const std::vector>& src_arg_pairs) { + gtl::FlatSet guaranteed_const_nodes; + std::vector srcs; + srcs.reserve(src_arg_pairs.size()); + for (const auto& src_arg : src_arg_pairs) { + srcs.push_back(src_arg.first); + } + ReverseDFSFrom(graph, srcs, /*enter=*/nullptr, + /*leave=*/[&guaranteed_const_nodes](const Node* n) { + // TODO(vinuraja): Doesn't work in the presence of loops. + if (AreAllParentsConst(*n, guaranteed_const_nodes)) { + guaranteed_const_nodes.insert(n); + } + }); + + for (auto& src_arg : src_arg_pairs) { + if (guaranteed_const_nodes.count(src_arg.first) != 0) { + VLOG(1) << "Guaranteed const found: " << src_arg.first->DebugString(); + src_arg.second->AddAttr("_is_guaranteed_constant", true); + } + } +} + // A node/slot pair. // TODO(phawkins): is there a common definition of this? struct NodeSlot { - NodeSlot() : node(nullptr), slot(-1) {} - NodeSlot(const Node* node, int slot) : node(node), slot(slot) {} + NodeSlot() : node(nullptr), slot(-1), dtype(DT_INVALID) {} + NodeSlot(const Node* node, int slot) + : node(node), slot(slot), dtype(DT_INVALID) {} + NodeSlot(const Node* node, int slot, DataType dtype) + : node(node), slot(slot), dtype(dtype) {} const Node* node; int slot; + // Optional: used to record the destination type of a source NodeSlot in case + // the source output is a Ref type that is cast to a Tensor at the + // destination. 
+ DataType dtype; + bool operator==(const NodeSlot& other) const { - return node == other.node && slot == other.slot; + return node == other.node && slot == other.slot && dtype == other.dtype; } + // Leave dtype out of the hash since there are never two NodeSlots with the + // same node and slot and different dtypes. struct Hasher { uint64 operator()(NodeSlot const& s) const { return Hash64Combine(std::hash()(s.node), @@ -75,10 +139,22 @@ struct NodeSlot { }; }; +// TODO(phawkins) add a canonical copy of these operator names and refactor +// everything to use it. +static const char* const kArgOp = "_Arg"; +static const char* const kRetValOp = "_Retval"; +static const char* const kHostComputeOp = "_XlaHostCompute"; +static const char* const kSendFromHostOp = "_XlaSendFromHost"; +static const char* const kRecvAtHostOp = "_XlaRecvAtHost"; + class Encapsulator { public: - Encapsulator(string group_attribute, Graph const* graph_in) - : group_attribute_(std::move(group_attribute)), graph_in_(graph_in) {} + Encapsulator(string group_attribute, string outside_compilation_attribute, + Graph const* graph_in) + : group_attribute_(std::move(group_attribute)), + outside_compilation_attribute_( + std::move(outside_compilation_attribute)), + graph_in_(graph_in) {} // Find subgraphs marked with 'group_attribute', and build a new // subgraph, one for each value of 'group_attribute'. @@ -96,57 +172,419 @@ class Encapsulator { // Write a copy of the input graph to 'graph_out', where the subgraphs are // replaced with calls to the new functions. - Status BuildOutputGraph(bool parallel_checking, Graph* graph_out); + Status BuildOutputGraph(bool parallel_checking, Graph* graph_out, + FunctionLibraryDefinition* library); private: - // Returns the key attribute associated with a node. Returns the empty string - // if no key attribute is found. - string GetFunctionNameAttr(const Node* node) const; - // A subgraph of the input, all marked with a common 'group_attribute' - // value. 
- struct Subgraph { + // value. A subgraph may contain multiple `outside_compilation' clusters. + // + // In the following simple example, A, B, ..., E are nodes in the original + // graph. The group attributes and outside_compilation attributes g and oc are + // each shown as either 0 or empty. + // + // A --> B --> C --> D --> E + // g: g:0 g:0 g:0 g: + // oc: oc: oc:0 oc: oc: + // + // The example is rewritten to two graphs; one on the host and one to be + // compiled. The host graph is as follows. RAH is a RecvAtHost node receiving + // input from the compiled cluster, and SFH is a SendFromHost node sending + // input back to the compiled cluster. Dotted edges are control edges. A + // 'sequencing' node S is inserted, and both RAH and SFH are connected via S + // to E (and in general all nodes that depend on nodes in the compiled + // cluster) to ensure that they are not pruned. + // + // A --> Call --> E + // ^ + // . + // ........> S + // .... ^ + // .. . + // RAH --> C --> SFH + // + // The compiled cluster is as follows. HC is a HostCompute node which is the + // source of a channel to the RAH node above and the destination of a channel + // from the SFH node above. + // + // Arg --> B --> HC --> D --> Retval + // + // The channels HC/RAH and SFH/HC each transmit multiple tensors, so there is + // at most one RAH and SFH in each outside_compilation cluster. This design is + // preferred over adding separate Arg/Retval nodes for each transmitted value + // because it allows optimizations to the host code that would like to limit + // communication between host and device and, e.g., raise only one interrupt + // per channel rather than one per transmitted value. + // + // The shapes of the outputs from the HC node in general cannot be determined + // until the shapes of its inputs are known at compile time, since e.g., + // above, the shape of C's outputs aren't known until the shape of its inputs + // are known. 
If the shapes of the HC's outputs can be determined during the + // rewrite, they are stored in the node's 'shapes' attr. Otherwise a minimal + // graph is stored in the shape_inference_graph attr. This graph can be used + // when compiling the HC Op to determined the shape of the SFH inputs given + // the shapes of any ancestor RAH outputs. If it can be determined that the + // shape of the SFH inputs will not be inferrable even once the shapes of the + // RAH outputs are known, an error is returned by the rewriter. + class Subgraph { + public: + // Creates a graph to build the subgraph in, if it doesn't already exist, + // using the same op registry and versions as graph_in. + Node* MakeNodeImage(const Graph* graph_in, Node* node); + + // Returns the graph the subgraph is being built in. + Graph* GetGraph() const; + + // Builds a FunctionDef, and adds it to 'library'. The value of the + // 'group_attribute' annotations becomes the function name. If + // 'reuse_existing_functions' is set, use an existing function with the same + // name, if any. If 'rewrite_subgraph_fn' is set, it is applied to the + // subgraph before function conversion. + Status BuildFunctionDef(const string& name_in, + const RewriteSubgraphFn& rewrite_subgraph_fn, + bool reuse_existing_functions, + FunctionLibraryDefinition* library); + + // Adds the function call node to graph_out. + Status AddFunctionCallNode( + const std::unordered_map& node_images, + bool parallel_checking, Graph* graph_out); + + // Adds _RecvAtHost and _SendFromHost nodes, where needed, to graph_out. + Status AddOutsideCompilationHostIONodes( + const string& subgraph_name, + const std::unordered_map& node_images, + Graph* graph_out); + + // Returns the names of all the outside_compilation subgraphs in this + // Subgraph. + void GetOutsideCompilationSubgraphNames(std::vector* names) const; + + // Returns the Node that inputs to the function should be wired up to. 
+ Node* GetCallNodeForInputs() const; + + // Returns the Node that outputs to the function should be wired up to. + Node* GetCallNodeForOutputs() const; + + // Returns the index of the arg that the dst of edge should connect to. + int GetArgIndexForEdge(const Edge* edge) const; + + // Returns the index of the result that the src of edge should connect to. + int GetResultIndexForEdge(const Edge* edge) const; + + // Returns the RecvAtHost node for an outside_compilation subgraph. + Node* GetRecvAtHostNode( + const string& outside_compilation_subgraph_name) const; + + // Returns the output slot for the RecvAtHost node that corresponds to the + // source of edge in an outside_compilation subgraph. + int GetRecvAtHostSlot(const string& outside_compilation_subgraph_name, + const Edge* edge) const; + + // Returns the SendFromHost node for an outside_compilation subgraph. + Node* GetSendFromHostNode( + const string& outside_compilation_subgraph_name) const; + + // Returns the input slot for the SendFromHost node that corresponds to the + // destination of edge in an outside_compilation subgraph. + int GetSendFromHostSlot(const string& outside_compilation_subgraph_name, + const Edge* edge) const; + + // Creates an _Arg node for the src node of edge, and add its index to + // args_by_src_, if none exists yet. Also adds its index to args_by_dst_, + // and adds the edge within the subgraph from the _Arg node to the image of + // the dst node. + Status RecordArg(const Edge* edge, + const std::unordered_map& node_images, + std::vector>* src_arg_pairs); + + // Creates a _Retval node for the src node of edge, and add it to results_, + // if none exists yet. If a new _Retval node is created, also adds the edge + // within the subgraph from the src to the _Retval node. + Status RecordResult( + const Edge* edge, + const std::unordered_map& node_images); + + // Creates an outside_compilation subgraph for outside_compilation_id if + // none exists yet. 
Creates an entry for the src node of edge in the list of + // inputs for the outside_compilation subgraph, if none exists yet. + void RecordOutsideCompilationInputOrControl( + const string& outside_compilation_id, const Edge* edge); + + // Creates an outside_compilation subgraph for outside_compilation_id if + // none exists yet. Creates an entry for the src node of edge in the list of + // outputs by src for the outside_compilation subgraph, if none exists + // yet. Creates an entry for the dst node of edge in the list of outputs by + // dst for the outside_compilation subgraph. + void RecordOutsideCompilationOutputOrControl( + const string& outside_compilation_id, const Edge* edge); + + // Adds the HostCompute nodes for each outside_compilation subgraph. + Status AddHostComputes( + const string& subgraph_name, + const std::unordered_map& node_images); + + // Creates the sequencer node if it doesn't exist, adding it to graph_out. + Status MakeSequencingNode(const string& subgraph_name, Graph* graph_out); + + // If there is a sequencer node, adds a control edge from the sequencer to + // all the downstream nodes of call_node_outputs. + void ConnectSequencerToOutputs(Graph* graph_out); + + Status AddShapeInferenceInfo( + const string& outside_compilation_subgraph_name, + const std::vector& shapes, GraphDef* inference_graph); + + Status ReplaceFunctionDef(FunctionLibraryDefinition* library); + + private: + struct OutsideCompilationSubgraph { + // Map from source (producer node/slot) tensors in the original graph to + // input index (slot number in the HostCompute/RecvAtHost nodes that will + // be created) for the outside_compilation subgraph. + std::unordered_map inputs; + + // Set of nodes in the original graph that are the source of control edges + // that cross from the containing compiled subgraph into the + // outside_compilation subgraph. 
These are recorded by + // RecordOutsideCompilationInputOrControl while walking all the subgraph + // edges, and lifted control edges within the subgraph are added by + // AddSendsToOutsideCompilation once the _HostCompute node has been + // created. The matching control edge from _RecvAtHost to the + // destination is added by CopyEdgeToOutputGraph. + std::unordered_set control_inputs; + + // Maps from source (producer node/slot) and destination (consumer + // node/slot) tensors in the original graph to output index (slot number + // in the SendFromHost/HostCompute nodes that will be created) for the + // outside_compilation subgraph. + std::unordered_map outputs_by_src; + std::unordered_map outputs_by_dst; + + // Set of nodes in the original graph that are the destination of control + // edges that cross from the outside_compilation subgraph into the + // containing compiled subgraph. These are recorded by + // RecordOutsideCompilationOutputOrControl while walking all the subgraph + // edges, and lifted control edges within the subgraph are added by + // AddRecvsFromToOutsideCompilation once the _HostCompute node has been + // created. The matching control edge from the source to _SendFromHost to + // the destination is added by CopyEdgeToOutputGraph. + std::unordered_set control_outputs; + + // Name of the _HostCompute node in the subgraph. + string host_compute_name; + + // _RecvAtHost node in the output graph. Not owned. + Node* recv_at_host = nullptr; + + // _SendFromHost node in the output graph. Not owned. + Node* send_from_host = nullptr; + }; + + // Builds a ParallelCheck op that compares the output of the original + // subgraph with the encapsulated subgraph. + Status BuildParallelCheckOp( + const std::unordered_map& node_images, + Graph* graph_out); + + // Builds a _RecvAtHost node producing all the inputs of an + // outside_compilation subgraph and stores it in oc_subgraph.recv_at_host. 
+ Status AddRecvAtHostNode(const string& subgraph_name, + const string& oc_subgraph_name, + OutsideCompilationSubgraph* oc_subgraph, + Graph* graph_out); + + // Builds a _SendFromHost node consuming all the outputs of an + // outside_compilation subgraph and stores it in oc_subgraph.send_from_host. + Status AddSendFromHostNode( + const std::unordered_map& node_images, + const string& subgraph_name, const string& oc_subgraph_name, + OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out); + // The subgraph extracted from the input graph, suitable for being turned // into a FunctionDef. Inputs are fed by _Arg nodes, and outputs are // returned by _Retval nodes. - std::unique_ptr graph; + std::unique_ptr graph_; // Which device are these nodes on? Used to assign a device to the call // node. - string device; + string device_; // NodeDef for the function call node. - NodeDef call_node_def; + NodeDef call_node_def_; // Function call node(s) in the output graph. Not owned. // If parallel_checking is enabled, 'call_node_inputs' is the function call // node to which inputs should be fed, and 'call_node_outputs' is the // parallel check op from which outputs should be read. If parallel checking // is disabled, both point to the function call node. - Node* call_node_inputs; - Node* call_node_outputs; + Node* call_node_inputs_; + Node* call_node_outputs_; // Maps from source (producer node/slot) and destination // (consumer node/slot) tensors in the input graph to _Arg numbers in // the subgraph. The source map is one-to-one, whereas the dest map may be // many-to-one. - std::unordered_map args_by_src; - std::unordered_map args_by_dst; + std::unordered_map args_by_src_; + std::unordered_map args_by_dst_; // The _Arg nodes in the subgraph, in order by argument number. - std::vector args; + std::vector args_; // Map from source tensor in the input graph to result #. 
- std::unordered_map results; + std::unordered_map results_; + + // The outside_compilation clusters in this subgraph. + std::unordered_map + outside_compilation_subgraphs_; + + // NoOp node in the output graph that is sequenced after the call node and + // used to prevent host-side outside_compilation sends and recvs from being + // pruned. + Node* sequencer_ = nullptr; }; - // Builds a ParallelCheck op that compares the output of the original subgraph - // with the encapsulated subgraph. - Status BuildParallelCheckOp( + // Returns the key attribute and outside_compilation attribute associated + // with a node in attr, and outside_compilation_attr, respectively. Sets + // either result to the empty string if the respective attribute is not + // found. Returns error status if there is an outside_compilation attribute + // and no key attribute, + Status GetFunctionNameAttr(Node const* node, string* attr, + string* outside_compilation_attr) const; + + // Copies edges local to a subgraph. Adds _Arg and _Retval nodes to + // subgraphs for data edges that cross subgraph boundaries. + Status CopySubgraphEdges( + const std::unordered_map& node_images, + std::vector>* src_arg_pairs); + + // Copies all marked nodes to a subgraph. Does nothing for unmarked nodes, + // or nodes marked outside_compilation. + Status CopySubgraphNodes(std::unordered_map* node_images); + + // Copies all nodes that aren't in a compiled subgraph to the output graph. + Status CopyNodesToOutputGraph( + bool parallel_checking, Graph* graph_out, + std::unordered_map* node_images); + + // Adds function call nodes for each compiled subgraph. + Status AddFunctionCallNodes( + const std::unordered_map& node_images, + bool parallel_checking, Graph* graph_out); + + // Adds _RecvAtHost and _SendFromHost nodes, where needed, for all + // outside_compilation subgraphs. 
+ Status AddOutsideCompilationHostIONodes( + const std::unordered_map& node_images, + Graph* graph_out); + + // Finds the image of an edge source in the output graph. If the edge crosses + // a subgraph boundary it is the output of a call node, otherwise it is a node + // in the output graph. + Status FindOutputImageOfEdgeSrc( + const string& src_func_id, const string& src_outside_compilation_id, + const string& dst_func_id, const string& dst_outside_compilation_id, + const std::unordered_map& node_images, + const Node* original_src_node, Node** src_image); + + // Finds an edge source slot in the output graph. If the edge crosses a + // subgraph boundary it is a slot on the output of a call node or a + // _RecvAtHost node, otherwise it is a slot on a node in the output graph. + int FindOutputSlotOfEdgeSrc(const string& src_func_id, + const string& src_outside_compilation_id, + const string& dst_func_id, + const string& dst_outside_compilation_id, + const Edge* edge); + + // Finds the image of an edge destination in the output graph. If the edge + // crosses a subgraph boundary it is the input of a call node or a + // _SendFromHost node, otherwise it is a node in the output graph. + Status FindOutputImageOfEdgeDst( + const string& src_func_id, const string& src_outside_compilation_id, + const string& dst_func_id, const string& dst_outside_compilation_id, + const std::unordered_map& node_images, + const Node* original_dst_node, Node** dst_image); + + // Finds an edge destination slot in the output graph. If the edge crosses a + // subgraph boundary it is a slot on the input of a call node or a + // _SendFromHost node, otherwise it is a slot on a node in the output graph. + int FindOutputSlotOfEdgeDst(const string& src_func_id, + const string& src_outside_compilation_id, + const string& dst_func_id, + const string& dst_outside_compilation_id, + const Edge* edge); + + // Copies a single edge to the output graph. 
The edge is either entirely + // within the output graph, or crosses into or out of a compiled subgraph. + Status CopyEdgeToOutputGraph( + const Edge* edge, const string& src_func_id, + const string& src_outside_compilation_id, const string& dst_func_id, + const string& dst_outside_compilation_id, const std::unordered_map& node_images, - const Subgraph& subgraph, Graph* graph_out, Node** parallel_check_op); + bool parallel_checking, Graph* graph_out, + std::unordered_set, NodeSlot::PairHasher>* + edges_added); + + // Adds all edges to the output graph. + Status AddEdgesToOutputGraph( + const std::unordered_map& node_images, + bool parallel_checking, Graph* graph_out); + + // Constructs a minimal shape inference graph that can be used to determine + // the shape of send_node at the time that the subgraph is compiled. + // recv_at_host_nodes contains the names of all the recv_at_host nodes that + // send_node might depend on. These recv_at_host nodes have shapes that are + // not known during the rewrite pass, but will be known at compile time. + // + // If the shapes of all the inputs to send_node can be determined during the + // rewrite pass, on exit graphdef_out is empty and the shapes are returned in + // static_shape_out. Otherwise graphdef_out contains a graph that can be used + // for shape inference at compile time, where all the source nodes of the + // graph are either constants with known shapes, or nodes named in + // recv_at_host_nodes. + // + // A non-OK status is returned if neither of the above conditions can be + // satisfied, e.g., because send_node depends on a node that doesn't have a + // registered shape inference function. 
+ Status DoStaticShapeInferenceForOutsideCompilationSend( + const Graph& graph_in, const ShapeRefiner& shape_refiner, + const std::unordered_set& recv_at_host_nodes, Node* send_node, + FunctionLibraryDefinition* library, + std::vector* static_shape_out, + std::unique_ptr* graphdef_out); + + // Makes a copy of graph containing only nodes that are ancestors of at least + // one node in send_from_host_nodes and store it in pruned_graph. On exit + // nodes_images contains a mapping from nodes in graph to nodes in + // pruned_graph. All functions in the copied graph are inlined. + Status MakePrunedGraphCopyAndInline( + const Graph& graph, const std::vector& sink_nodes, + std::unique_ptr* pruned_graph, + std::unordered_map* node_images, + FunctionLibraryDefinition* library); + + // Makes a copy of graph containing only nodes that are ancestors of a + // send_from_host node in an outside_compilation subgraph, and store it in + // pruned_graph. Also perform shape inference on the pruned graph, using + // shape_refiner. On exit node_images contains a mapping from nodes in graph + // to nodes in pruned_graph. + Status MakeGraphForOutsideCompilationSends( + const Graph& graph, std::unique_ptr* pruned_graph, + ShapeRefiner* shape_refiner, + std::unordered_map* node_images, + FunctionLibraryDefinition* library); + + // Performs static shape inference, as far as possible, for the send_from_host + // nodes in each outside_compilation subgraph. Where it is not possible to + // determine the shape statically, stores a serialized GraphDef in the + // HostCompute 'shape_inference_graph' attr, to be used at compile time for + // final inference. If the shapes are known statically they are stored in the + // HostCompute 'shapes' attr. 
+ Status GetShapeInfoForOutsideCompilationSends( + Graph* graph_out, FunctionLibraryDefinition* library); const string group_attribute_; + const string outside_compilation_attribute_; const Graph* graph_in_; std::unordered_map subgraphs_; @@ -154,224 +592,401 @@ class Encapsulator { TF_DISALLOW_COPY_AND_ASSIGN(Encapsulator); }; -// TODO(phawkins) add a canonical copy of these operator names and refactor -// everything to use it. -static const char* const kArgOp = "_Arg"; -static const char* const kRetValOp = "_Retval"; +Node* Encapsulator::Subgraph::GetCallNodeForInputs() const { + return call_node_inputs_; +} -// Returns the function name attached to 'node', or the empty string if there is -// none. -string Encapsulator::GetFunctionNameAttr(Node const* node) const { - string attr; - if (!GetNodeAttr(node->attrs(), group_attribute_, &attr).ok()) { - attr.clear(); - } - return attr; +Node* Encapsulator::Subgraph::GetCallNodeForOutputs() const { + return call_node_outputs_; } -Status Encapsulator::SplitIntoSubgraphs() { - Status s; +int Encapsulator::Subgraph::GetArgIndexForEdge(const Edge* edge) const { + return args_by_dst_.at(NodeSlot(edge->dst(), edge->dst_input())); +} - // Map from input graph nodes to subgraph nodes. - std::unordered_map node_images; +int Encapsulator::Subgraph::GetResultIndexForEdge(const Edge* edge) const { + return results_.at(NodeSlot(edge->src(), edge->src_output())); +} - // Copy all marked nodes to a subgraph. Do nothing for unmarked nodes. 
- for (Node* node : graph_in_->op_nodes()) { - string func_id = GetFunctionNameAttr(node); - if (func_id.empty()) continue; +Node* Encapsulator::Subgraph::GetRecvAtHostNode( + const string& outside_compilation_subgraph_name) const { + return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name) + .recv_at_host; +} - Subgraph& subgraph = subgraphs_[func_id]; - if (!subgraph.graph) { - subgraph.graph.reset(new Graph(graph_in_->op_registry())); - subgraph.graph->set_versions(graph_in_->versions()); - } +int Encapsulator::Subgraph::GetRecvAtHostSlot( + const string& outside_compilation_subgraph_name, const Edge* edge) const { + return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name) + .inputs.at(NodeSlot(edge->src(), edge->src_output())); +} - Node* image = subgraph.graph->CopyNode(node); - image->ClearAttr(group_attribute_); - node_images[node] = image; +Node* Encapsulator::Subgraph::GetSendFromHostNode( + const string& outside_compilation_subgraph_name) const { + return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name) + .send_from_host; +} - if (subgraph.device.empty()) { - subgraph.device = node->assigned_device_name().empty() - ? node->requested_device() - : node->assigned_device_name(); - } +int Encapsulator::Subgraph::GetSendFromHostSlot( + const string& outside_compilation_subgraph_name, const Edge* edge) const { + return outside_compilation_subgraphs_.at(outside_compilation_subgraph_name) + .outputs_by_dst.at(NodeSlot(edge->dst(), edge->dst_input())); +} + +Node* Encapsulator::Subgraph::MakeNodeImage(const Graph* graph_in, Node* node) { + if (!graph_) { + graph_.reset(new Graph(graph_in->op_registry())); + graph_->set_versions(graph_in->versions()); } - // Copy edges local to a subgraph. Add _Arg and _Retval nodes to subgraphs for - // data edges that cross subgraph boundaries. 
- for (const Edge* edge : graph_in_->edges()) { - string src_func_id = GetFunctionNameAttr(edge->src()); - string dst_func_id = GetFunctionNameAttr(edge->dst()); - Node* src_image = gtl::FindWithDefault(node_images, edge->src(), nullptr); - Node* dst_image = gtl::FindWithDefault(node_images, edge->dst(), nullptr); + if (device_.empty()) { + device_ = node->assigned_device_name().empty() + ? node->requested_device() + : node->assigned_device_name(); + } - // Copy edges that are local to a subgraph. - if (!src_func_id.empty() && src_func_id == dst_func_id) { - Graph* g = subgraphs_[src_func_id].graph.get(); - if (edge->IsControlEdge()) { - g->AddControlEdge(src_image, dst_image); - } else { - g->AddEdge(src_image, edge->src_output(), dst_image, edge->dst_input()); - } - continue; - } + return graph_->CopyNode(node); +} - // Ignore cross-boundary control edges for right now. We will lift them - // onto the enclosing call operators in BuildOutputGraph(). - if (edge->IsControlEdge()) continue; +Graph* Encapsulator::Subgraph::GetGraph() const { return graph_.get(); } + +Status Encapsulator::Subgraph::RecordArg( + const Edge* edge, const std::unordered_map& node_images, + std::vector>* src_arg_pairs) { + Node* src_node = edge->src(); + int src_slot = edge->src_output(); + std::unordered_map::iterator iter; + bool inserted; + std::tie(iter, inserted) = + args_by_src_.emplace(NodeSlot(src_node, src_slot), args_by_src_.size()); + int arg_index = iter->second; + if (inserted) { + NodeDef arg_def; + NodeDefBuilder builder( + strings::StrCat(src_node->name(), "_", src_slot, "_arg"), kArgOp); + DataType dtype = edge->dst()->input_type(edge->dst_input()); + builder.Attr("T", dtype); + builder.Attr("index", arg_index); + Status s = builder.Finalize(&arg_def); + if (!s.ok()) return s; - // Add 'src' as an output of its subgraph, if applicable. 
- if (!src_func_id.empty()) { - Subgraph& src_subgraph = subgraphs_[src_func_id]; - int ret_index = src_subgraph.results.size(); - if (src_subgraph.results - .emplace(NodeSlot(edge->src(), edge->src_output()), ret_index) - .second) { - // Create a new _Retval node - DataType dtype = edge->src()->output_type(edge->src_output()); + Node* arg = graph_->AddNode(arg_def, &s); + if (!s.ok()) return s; - if (IsRefType(dtype)) { - return errors::InvalidArgument( - "Ref Tensors (e.g., Variables) are not supported: tensor ", - edge->src()->name(), ":", edge->src_output()); - } + src_arg_pairs->push_back({src_node, arg}); + args_.push_back(arg); + } + Node* dst_node = edge->dst(); + Node* dst_image = node_images.at(dst_node); + int dst_slot = edge->dst_input(); + args_by_dst_[NodeSlot(dst_node, dst_slot)] = arg_index; + graph_->AddEdge(args_[arg_index], 0, dst_image, dst_slot); + return Status::OK(); +} + +Status Encapsulator::Subgraph::RecordResult( + const Edge* edge, + const std::unordered_map& node_images) { + Node* src_node = edge->src(); + Node* src_image = node_images.at(src_node); + int src_slot = edge->src_output(); + std::unordered_map::iterator iter; + bool inserted; + std::tie(iter, inserted) = + results_.emplace(NodeSlot(src_node, src_slot), results_.size()); + int ret_index = iter->second; + if (inserted) { + NodeDef ret_def; + NodeDefBuilder builder( + strings::StrCat(src_node->name(), "_", src_slot, "_retval"), kRetValOp); + DataType dtype = src_node->output_type(src_slot); + builder.Attr("T", dtype); + builder.Attr("index", ret_index); + builder.Input(src_image->name(), src_slot, dtype); + Status s = builder.Finalize(&ret_def); + if (!s.ok()) return s; + Node* ret = graph_->AddNode(ret_def, &s); + if (!s.ok()) return s; - NodeDef ret_def; - ret_def.set_op(kRetValOp); - ret_def.set_name(strings::StrCat(edge->src()->name(), "_", - edge->src_output(), "_retval")); - AddNodeAttr("T", dtype, &ret_def); - AddNodeAttr("index", ret_index, &ret_def); - Node* ret = 
src_subgraph.graph->AddNode(ret_def, &s); - if (!s.ok()) return s; - - // Add an edge from 'src' to _Retval. - src_subgraph.graph->AddEdge(src_image, edge->src_output(), ret, 0); + graph_->AddEdge(src_image, src_slot, ret, 0); + } + return Status::OK(); +} + +void Encapsulator::Subgraph::RecordOutsideCompilationInputOrControl( + const string& outside_compilation_id, const Edge* edge) { + auto iter = outside_compilation_subgraphs_ + .emplace(outside_compilation_id, OutsideCompilationSubgraph()) + .first; + OutsideCompilationSubgraph& outside_subgraph = iter->second; + if (edge->IsControlEdge()) { + outside_subgraph.control_inputs.insert(edge->src()); + } else { + int input_index = outside_subgraph.inputs.size(); + outside_subgraph.inputs.emplace(NodeSlot(edge->src(), edge->src_output()), + input_index); + } +} + +void Encapsulator::Subgraph::RecordOutsideCompilationOutputOrControl( + const string& outside_compilation_id, const Edge* edge) { + auto subgraph_iter = + outside_compilation_subgraphs_ + .emplace(outside_compilation_id, OutsideCompilationSubgraph()) + .first; + OutsideCompilationSubgraph& outside_subgraph = subgraph_iter->second; + if (edge->IsControlEdge()) { + outside_subgraph.control_outputs.insert(edge->dst()); + } else { + DataType dtype = edge->dst()->input_type(edge->dst_input()); + auto output_iter = + outside_subgraph.outputs_by_src + .emplace(NodeSlot(edge->src(), edge->src_output(), dtype), + outside_subgraph.outputs_by_src.size()) + .first; + int output_index = output_iter->second; + outside_subgraph.outputs_by_dst[NodeSlot(edge->dst(), edge->dst_input())] = + output_index; + } +} + +Status Encapsulator::Subgraph::AddHostComputes( + const string& subgraph_name, + const std::unordered_map& node_images) { + for (auto& oc_subgraph_iter : outside_compilation_subgraphs_) { + const string& oc_subgraph_name = oc_subgraph_iter.first; + OutsideCompilationSubgraph& oc_subgraph = oc_subgraph_iter.second; + if (!oc_subgraph.inputs.empty() || 
!oc_subgraph.control_inputs.empty() || + !oc_subgraph.outputs_by_src.empty() || + !oc_subgraph.control_outputs.empty()) { + // Build a _HostCompute node. + std::vector inputs(oc_subgraph.inputs.size()); + std::vector input_dtypes(oc_subgraph.inputs.size(), DT_INVALID); + std::vector output_dtypes(oc_subgraph.outputs_by_src.size(), + DT_INVALID); + + for (const auto& input_src : oc_subgraph.inputs) { + const Node* src_node = input_src.first.node; + Node* src_image = node_images.at(src_node); + int src_slot = input_src.first.slot; + int input_index = input_src.second; + + DataType dtype = src_node->output_type(src_slot); + inputs[input_index].Reset(src_image->name(), src_slot, dtype); + input_dtypes[input_index] = dtype; } - } - // Add 'dst' as an input of its subgraph, if applicable. - if (!dst_func_id.empty()) { - Subgraph& dst_subgraph = subgraphs_[dst_func_id]; + for (const auto& output : oc_subgraph.outputs_by_src) { + DataType dtype = output.first.dtype; + int output_index = output.second; + output_dtypes[output_index] = dtype; + } - // Create an _Arg node for this tensor, if none exists yet. - std::unordered_map::iterator iter; - bool inserted; - std::tie(iter, inserted) = dst_subgraph.args_by_src.emplace( - NodeSlot(edge->src(), edge->src_output()), dst_subgraph.args.size()); - int arg_index = iter->second; - if (inserted) { - // This is the first time we have seen this tensor. Create an _Arg node. 
- DataType dtype = edge->dst()->input_type(edge->dst_input()); + NodeDef host_compute_def; + NodeDefBuilder builder(strings::StrCat("outside_compilation_", + oc_subgraph_name, "_host_compute"), + kHostComputeOp); + builder.Input(inputs); + builder.Attr("Tinputs", input_dtypes); + builder.Attr("Toutputs", output_dtypes); + builder.Attr("key", + strings::StrCat("host_compute_channel_", subgraph_name, "_", + oc_subgraph_name)); + Status s = builder.Finalize(&host_compute_def); + if (!s.ok()) return s; + + Node* host_compute = graph_->AddNode(host_compute_def, &s); + if (!s.ok()) return s; + oc_subgraph.host_compute_name = host_compute->name(); + + // Connect the _HostCompute node to its producers in the subgraph. + for (auto& input_src : oc_subgraph.inputs) { + const Node* src_node = input_src.first.node; + Node* src_image = node_images.at(src_node); + int src_slot = input_src.first.slot; + int input_index = input_src.second; + graph_->AddEdge(src_image, src_slot, host_compute, input_index); + } - if (IsRefType(dtype)) { - return errors::InvalidArgument( - "Ref Tensors (e.g., Variables) are not supported: tensor ", - edge->src()->name(), ":", edge->src_output()); - } + // Connect the _HostCompute node to its control edge producers in the + // subgraph. + for (const auto& src_node : oc_subgraph.control_inputs) { + Node* src_image = node_images.at(src_node); + graph_->AddControlEdge(src_image, host_compute); + } - NodeDef arg_def; - NodeDefBuilder builder(strings::StrCat(edge->src()->name(), "_", - edge->src_output(), "_arg"), - kArgOp); - builder.Attr("T", dtype); - builder.Attr("index", arg_index); - s = builder.Finalize(&arg_def); - if (!s.ok()) return s; + // Connect the consumers in the subgraph to the _HostCompute node. 
+ for (const auto& output : oc_subgraph.outputs_by_dst) { + const Node* dst_node = output.first.node; + Node* dst_image = node_images.at(dst_node); + int dst_slot = output.first.slot; + int output_index = output.second; - Node* arg = dst_subgraph.graph->AddNode(arg_def, &s); - if (!s.ok()) return s; + graph_->AddEdge(host_compute, output_index, dst_image, dst_slot); + } - dst_subgraph.args.push_back(arg); + // Connect the control edge consumers in the subgraph to the _HostCompute + // node. + for (const auto& dst_node : oc_subgraph.control_outputs) { + Node* dst_image = node_images.at(dst_node); + graph_->AddControlEdge(host_compute, dst_image); } - // Add an edge from the _Arg node to 'dst' in the subgraph. - dst_subgraph.args_by_dst[NodeSlot(edge->dst(), edge->dst_input())] = - arg_index; - dst_subgraph.graph->AddEdge(dst_subgraph.args[arg_index], 0, dst_image, - edge->dst_input()); } } - for (auto& entry : subgraphs_) { - FixupSourceAndSinkEdges(entry.second.graph.get()); - } + return Status::OK(); +} - return s; +Status Encapsulator::Subgraph::MakeSequencingNode(const string& subgraph_name, + Graph* graph_out) { + if (sequencer_ == nullptr) { + NodeDef seq_def; + NodeDefBuilder builder(strings::StrCat(subgraph_name, "_sequencer"), + "NoOp"); + Status s = builder.Finalize(&seq_def); + if (!s.ok()) return s; + + sequencer_ = graph_out->AddNode(seq_def, &s); + if (!s.ok()) return s; + sequencer_->set_assigned_device_name(device_); + } + return Status::OK(); } -Status Encapsulator::BuildFunctionDefs( - const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions, - FunctionLibraryDefinition* library) { - // For each subgraph, build a FunctionDef. 
- for (auto& subgraph_entry : subgraphs_) { - string name = subgraph_entry.first; - Subgraph& subgraph = subgraph_entry.second; +void Encapsulator::Subgraph::ConnectSequencerToOutputs(Graph* graph_out) { + if (sequencer_ != nullptr) { + std::unordered_set output_dependencies; + for (Node* node : call_node_outputs_->out_nodes()) { + output_dependencies.insert(node); + } + for (Node* node : output_dependencies) { + graph_out->AddControlEdge(sequencer_, node); + } + } +} - subgraph.call_node_def.set_op(name); - subgraph.call_node_def.set_name(name); - subgraph.call_node_def.set_device(subgraph.device); +Status Encapsulator::Subgraph::BuildFunctionDef( + const string& name_in, const RewriteSubgraphFn& rewrite_subgraph_fn, + bool reuse_existing_functions, FunctionLibraryDefinition* library) { + // name_in is copied here because name may be modified below if + // rewrite_subgraph_fn is true. + string name = name_in; + call_node_def_.set_op(name); + call_node_def_.set_name(name); + call_node_def_.set_device(device_); + + if (rewrite_subgraph_fn) { + // Initialize the input and output permutations to the identity. + std::vector input_permutation(args_by_src_.size()); + std::iota(input_permutation.begin(), input_permutation.end(), 0); + std::vector output_permutation(results_.size()); + std::iota(output_permutation.begin(), output_permutation.end(), 0); + + TF_RETURN_IF_ERROR(rewrite_subgraph_fn( + &graph_, &input_permutation, &output_permutation, &call_node_def_)); + + // Apply the input/output permutations to the 'args_by_...' and 'results_' + // mappings, so when we build edges in BuildOutputGraph() we + // connect them to the right input/output positions. 
+ if (input_permutation.size() != args_by_src_.size()) { + return errors::InvalidArgument("Input permutation has incorrect size."); + } + if (output_permutation.size() != results_.size()) { + return errors::InvalidArgument("Output permutation has incorrect size."); + } + for (auto& arg : args_by_src_) { + arg.second = input_permutation[arg.second]; + } + for (auto& arg : args_by_dst_) { + arg.second = input_permutation[arg.second]; + } + for (auto& result : results_) { + result.second = output_permutation[result.second]; + } - if (rewrite_subgraph_fn) { - // Initialize the input and output permutations to the identity. - std::vector input_permutation(subgraph.args_by_src.size()); - std::iota(input_permutation.begin(), input_permutation.end(), 0); - std::vector output_permutation(subgraph.results.size()); - std::iota(output_permutation.begin(), output_permutation.end(), 0); + name = call_node_def_.op(); + } - TF_RETURN_IF_ERROR( - rewrite_subgraph_fn(&subgraph.graph, &input_permutation, - &output_permutation, &subgraph.call_node_def)); - - // Apply the input/output permutations to the 'args_by_...' and 'results' - // mappings in 'subgraph', so when we build edges in BuildOutputGraph() we - // connect them to the right input/output positions. 
- if (input_permutation.size() != subgraph.args_by_src.size()) { - return errors::InvalidArgument("Input permutation has incorrect size."); - } - if (output_permutation.size() != subgraph.results.size()) { - return errors::InvalidArgument( - "Output permutation has incorrect size."); - } - for (auto& arg : subgraph.args_by_src) { - arg.second = input_permutation[arg.second]; - } - for (auto& arg : subgraph.args_by_dst) { - arg.second = input_permutation[arg.second]; - } - for (auto& result : subgraph.results) { - result.second = output_permutation[result.second]; - } + FunctionDef fdef; + TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, &fdef)); - name = subgraph.call_node_def.op(); - } + if (VLOG_IS_ON(1)) { + VLOG(2) << "Build function def " << name; + dump_graph::DumpGraphToFile( + strings::StrCat("encapsulate_fdef_graph_", name), *graph_, library); + dump_graph::DumpFunctionDefToFile( + strings::StrCat("encapsulate_fdef_", name), fdef); + } - FunctionDef fdef; - TF_RETURN_IF_ERROR(GraphToFunctionDef(*subgraph.graph, name, &fdef)); + if (!reuse_existing_functions || library->Find(name) == nullptr) { + TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef)); + } + return Status::OK(); +} - if (VLOG_IS_ON(1)) { - VLOG(2) << "Build function def " << name; - dump_graph::DumpGraphToFile( - strings::StrCat("encapsulate_fdef_graph_", name), *subgraph.graph, - library); - dump_graph::DumpFunctionDefToFile( - strings::StrCat("encapsulate_fdef_", name), fdef); +Status Encapsulator::Subgraph::AddShapeInferenceInfo( + const string& outside_compilation_subgraph_name, + const std::vector& shapes, GraphDef* inference_graph) { + OutsideCompilationSubgraph& oc_subgraph = + outside_compilation_subgraphs_.at(outside_compilation_subgraph_name); + + Node* host_compute = nullptr; + for (Node* n : graph_->nodes()) { + if (n->name() == oc_subgraph.host_compute_name) { + host_compute = n; + break; } + } + if (host_compute == nullptr) { + return errors::InvalidArgument( + "After rewriting 
subgraph ", outside_compilation_subgraph_name, + " there is no HostCompute Op for outside compilation subgraph ", + oc_subgraph.host_compute_name); + } - if (!reuse_existing_functions || library->Find(name) == nullptr) { - TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef)); + if (inference_graph == nullptr) { + host_compute->AddAttr("shape_inference_graph", ""); + host_compute->AddAttr("shapes", shapes); + } else { + string serialized_graph; + if (!inference_graph->SerializeToString(&serialized_graph)) { + return errors::Internal( + "Failed to serialize graph for outside compilation subgraph ", + oc_subgraph.host_compute_name); } + host_compute->AddAttr("shape_inference_graph", serialized_graph); + host_compute->AddAttr("shapes", std::vector()); + } + return Status::OK(); +} + +Status Encapsulator::Subgraph::ReplaceFunctionDef( + FunctionLibraryDefinition* library) { + const string& name = call_node_def_.name(); + + FunctionDef fdef; + TF_RETURN_IF_ERROR(GraphToFunctionDef(*graph_, name, &fdef)); + + if (VLOG_IS_ON(1)) { + VLOG(2) << "Replace function def " << name; + dump_graph::DumpGraphToFile( + strings::StrCat("replace_encapsulate_fdef_graph_", name), *graph_, + library); + dump_graph::DumpFunctionDefToFile( + strings::StrCat("replace_encapsulate_fdef_", name), fdef); } + + TF_RETURN_IF_ERROR(library->RemoveFunction(name)); + TF_RETURN_IF_ERROR(library->AddFunctionDef(fdef)); return Status::OK(); } -Status Encapsulator::BuildParallelCheckOp( +Status Encapsulator::Subgraph::BuildParallelCheckOp( const std::unordered_map& node_images, - const Encapsulator::Subgraph& subgraph, Graph* graph_out, - Node** parallel_check_op) { + Graph* graph_out) { // Build an index mapping output positions to node/slot pairs in the // original graph. 
- std::vector results_by_num(subgraph.results.size()); - for (const auto& entry : subgraph.results) { + std::vector results_by_num(results_.size()); + for (const auto& entry : results_) { results_by_num[entry.second] = entry.first; } @@ -386,22 +1001,22 @@ Status Encapsulator::BuildParallelCheckOp( expected_outputs[i] = NodeDefBuilder::NodeOut(node_images.at(node_slot.node)->name(), node_slot.slot, result_dtypes[i]); - actual_outputs[i] = NodeDefBuilder::NodeOut(subgraph.call_node_def.name(), - i, result_dtypes[i]); + actual_outputs[i] = + NodeDefBuilder::NodeOut(call_node_def_.name(), i, result_dtypes[i]); } // Assign the parallel check op to a CPU on the same task as the cluster it is // checking. string device, dummy; if (!DeviceNameUtils::SplitDeviceName( - subgraph.call_node_inputs->assigned_device_name(), &device, &dummy)) { + call_node_inputs_->assigned_device_name(), &device, &dummy)) { return errors::InvalidArgument("Could not parse device name"); } strings::StrAppend(&device, "/cpu:0"); NodeDef check_def; TF_RETURN_IF_ERROR( - NodeDefBuilder(graph_out->NewName(strings::StrCat( - subgraph.call_node_def.name(), "_parallel_check")), + NodeDefBuilder(graph_out->NewName(strings::StrCat(call_node_def_.name(), + "_parallel_check")), "ParallelCheck") .Device(device) .Attr("T", result_dtypes) @@ -421,65 +1036,558 @@ Status Encapsulator::BuildParallelCheckOp( const NodeSlot& node_slot = results_by_num[i]; graph_out->AddEdge(node_images.at(node_slot.node), node_slot.slot, check_op, i); - graph_out->AddEdge(subgraph.call_node_inputs, i, check_op, num_results + i); + graph_out->AddEdge(call_node_inputs_, i, check_op, num_results + i); } - *parallel_check_op = check_op; + call_node_outputs_ = check_op; return Status::OK(); } -Status Encapsulator::BuildOutputGraph(bool parallel_checking, - Graph* graph_out) { +Status Encapsulator::Subgraph::AddFunctionCallNode( + const std::unordered_map& node_images, + bool parallel_checking, Graph* graph_out) { Status s; + 
call_node_inputs_ = graph_out->AddNode(call_node_def_, &s); + if (!s.ok()) return s; - // Map from nodes in the input graph to nodes in the output graph. + // Copy the assigned device and the key_annotation over. + call_node_inputs_->set_assigned_device_name(device_); + call_node_outputs_ = call_node_inputs_; + + if (parallel_checking) { + TF_RETURN_IF_ERROR(BuildParallelCheckOp(node_images, graph_out)); + } + return Status::OK(); +} + +Status Encapsulator::Subgraph::AddRecvAtHostNode( + const string& subgraph_name, const string& oc_subgraph_name, + OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out) { + std::vector dtypes(oc_subgraph->inputs.size(), DT_INVALID); + + for (const auto& input : oc_subgraph->inputs) { + const Node* src_node = input.first.node; + int src_slot = input.first.slot; + int input_index = input.second; + + DataType dtype = src_node->output_type(src_slot); + dtypes[input_index] = dtype; + } + + NodeDef recv_def; + NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name, + "_", oc_subgraph_name, "_recv"), + kRecvAtHostOp); + builder.Attr("Toutputs", dtypes); + builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name, + "_", oc_subgraph_name)); + Status s = builder.Finalize(&recv_def); + if (!s.ok()) return s; + + oc_subgraph->recv_at_host = graph_out->AddNode(recv_def, &s); + if (!s.ok()) return s; + oc_subgraph->recv_at_host->set_assigned_device_name(device_); + + // Add a control dependency forcing the RecvAtHost to run before the subgraph + // completes. This has no effect on execution order but prevents the + // RecvAtHost being pruned. 
+ TF_RETURN_IF_ERROR(MakeSequencingNode(subgraph_name, graph_out)); + graph_out->AddControlEdge(oc_subgraph->recv_at_host, sequencer_); + + return Status::OK(); +} + +Status Encapsulator::Subgraph::AddSendFromHostNode( + const std::unordered_map& node_images, + const string& subgraph_name, const string& oc_subgraph_name, + OutsideCompilationSubgraph* oc_subgraph, Graph* graph_out) { + std::vector dtypes(oc_subgraph->outputs_by_src.size(), DT_INVALID); + std::vector inputs( + oc_subgraph->outputs_by_src.size()); + + for (const auto& output : oc_subgraph->outputs_by_src) { + const Node* src_node = output.first.node; + Node* src_image = node_images.at(src_node); + int src_slot = output.first.slot; + int output_index = output.second; + + DataType dtype = src_node->output_type(src_slot); + dtypes[output_index] = dtype; + inputs[output_index].Reset(src_image->name(), src_slot, dtype); + } + + NodeDef send_def; + NodeDefBuilder builder(strings::StrCat("outside_compilation_", subgraph_name, + "_", oc_subgraph_name, "_send"), + kSendFromHostOp); + builder.Attr("Tinputs", dtypes); + builder.Attr("key", strings::StrCat("host_compute_channel_", subgraph_name, + "_", oc_subgraph_name)); + builder.Input(inputs); + Status s = builder.Finalize(&send_def); + if (!s.ok()) return s; + + oc_subgraph->send_from_host = graph_out->AddNode(send_def, &s); + if (!s.ok()) return s; + oc_subgraph->send_from_host->set_assigned_device_name(device_); + + // Add a control dependency forcing the SendFromHost to run before the + // subgraph completes. This has no effect on execution order but prevents the + // RecvAtHost being pruned. 
+ TF_RETURN_IF_ERROR(MakeSequencingNode(subgraph_name, graph_out)); + graph_out->AddControlEdge(oc_subgraph->send_from_host, sequencer_); + + return Status::OK(); +} + +Status Encapsulator::Subgraph::AddOutsideCompilationHostIONodes( + const string& subgraph_name, + const std::unordered_map& node_images, + Graph* graph_out) { + for (auto& outside_compilation_subgraph_entry : + outside_compilation_subgraphs_) { + const string& oc_name = outside_compilation_subgraph_entry.first; + OutsideCompilationSubgraph& oc_subgraph = + outside_compilation_subgraph_entry.second; + + if (!oc_subgraph.inputs.empty() || !oc_subgraph.control_inputs.empty()) { + TF_RETURN_IF_ERROR( + AddRecvAtHostNode(subgraph_name, oc_name, &oc_subgraph, graph_out)); + } + + if (!oc_subgraph.outputs_by_src.empty() || + !oc_subgraph.control_outputs.empty()) { + TF_RETURN_IF_ERROR(AddSendFromHostNode(node_images, subgraph_name, + oc_name, &oc_subgraph, graph_out)); + } + } + return Status::OK(); +} + +void Encapsulator::Subgraph::GetOutsideCompilationSubgraphNames( + std::vector* names) const { + for (auto& entry : outside_compilation_subgraphs_) { + names->push_back(entry.first); + } +} + +Status Encapsulator::GetFunctionNameAttr( + Node const* node, string* attr, string* outside_compilation_attr) const { + Status s = GetNodeAttr(node->attrs(), group_attribute_, attr); + if (s.code() == error::Code::NOT_FOUND) { + // Return empty attr if there's no group_attribute. + attr->clear(); + } else { + TF_RETURN_IF_ERROR(s); + } + bool has_group_attr = s.ok(); + s = GetNodeAttr(node->attrs(), outside_compilation_attribute_, + outside_compilation_attr); + if (s.code() == error::Code::NOT_FOUND) { + // Return empty attr if there's no outside_compilation attribute. 
+ outside_compilation_attr->clear(); + } else { + TF_RETURN_IF_ERROR(s); + if (!has_group_attr) { + return errors::InvalidArgument( + "Node ", node->name(), " has ", outside_compilation_attribute_, + " attribute but no ", group_attribute_, " attribute."); + } + } + return Status::OK(); +} + +bool IsInSubgraph(const string& func_id, const string& outside_compilation_id) { + return !func_id.empty() && outside_compilation_id.empty(); +} + +Status Encapsulator::CopySubgraphNodes( + std::unordered_map* node_images) { + for (Node* node : graph_in_->op_nodes()) { + string func_id; + string outside_compilation_id; + TF_RETURN_IF_ERROR( + GetFunctionNameAttr(node, &func_id, &outside_compilation_id)); + if (!IsInSubgraph(func_id, outside_compilation_id)) continue; + + Subgraph& subgraph = subgraphs_[func_id]; + Node* image = subgraph.MakeNodeImage(graph_in_, node); + image->ClearAttr(group_attribute_); + (*node_images)[node] = image; + } + return Status::OK(); +} + +Status Encapsulator::CopySubgraphEdges( + const std::unordered_map& node_images, + std::vector>* src_arg_pairs) { + for (const Edge* edge : graph_in_->edges()) { + string src_func_id; + string src_outside_compilation_id; + TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->src(), &src_func_id, + &src_outside_compilation_id)); + string dst_func_id; + string dst_outside_compilation_id; + TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->dst(), &dst_func_id, + &dst_outside_compilation_id)); + Node* src_image = gtl::FindWithDefault(node_images, edge->src(), nullptr); + Node* dst_image = gtl::FindWithDefault(node_images, edge->dst(), nullptr); + + // Copy edges that are local to a subgraph. 
+ if (IsInSubgraph(src_func_id, src_outside_compilation_id) && + IsInSubgraph(dst_func_id, dst_outside_compilation_id) && + src_func_id == dst_func_id) { + Graph* g = subgraphs_[src_func_id].GetGraph(); + if (edge->IsControlEdge()) { + g->AddControlEdge(src_image, dst_image); + } else { + g->AddEdge(src_image, edge->src_output(), dst_image, edge->dst_input()); + } + continue; + } + + // Record 'src' as an output of its subgraph, if applicable. + if (IsInSubgraph(src_func_id, src_outside_compilation_id)) { + if (!edge->IsControlEdge()) { + DataType dtype = edge->src()->output_type(edge->src_output()); + if (IsRefType(dtype)) { + return errors::InvalidArgument( + "Ref Tensors (e.g., Variables) are not supported as results: " + "tensor ", + edge->src()->name(), ":", edge->src_output()); + } + } + + Subgraph& src_subgraph = subgraphs_[src_func_id]; + if (src_func_id == dst_func_id) { + // src is in the subgraph and dst is outside_compilation in the same + // subgraph. + src_subgraph.RecordOutsideCompilationInputOrControl( + dst_outside_compilation_id, edge); + } else { + // Ignore control edges leaving the subgraph. We will lift them onto the + // enclosing call operators in BuildOutputGraph(). + if (!edge->IsControlEdge()) { + TF_RETURN_IF_ERROR(src_subgraph.RecordResult(edge, node_images)); + } + } + } + + // Record 'dst' as an input of its subgraph, if applicable. + if (IsInSubgraph(dst_func_id, dst_outside_compilation_id)) { + // Look at the type of the destination not the source, since Ref output + // Tensors can be automatically cast to non-Ref Tensors at the + // destination. 
+ if (!edge->IsControlEdge()) { + DataType dtype = edge->dst()->input_type(edge->dst_input()); + if (IsRefType(dtype)) { + return errors::InvalidArgument( + "Ref Tensors (e.g., Variables) are not supported as args: " + "tensor ", + edge->src()->name(), ":", edge->src_output()); + } + } + + Subgraph& dst_subgraph = subgraphs_[dst_func_id]; + if (src_func_id == dst_func_id) { + // dst is in the subgraph and src is outside_compilation in the same + // subgraph. + dst_subgraph.RecordOutsideCompilationOutputOrControl( + src_outside_compilation_id, edge); + } else { + // Ignore control edges entering the subgraph. We will lift them onto + // the enclosing call operators in BuildOutputGraph(). + if (!edge->IsControlEdge()) { + TF_RETURN_IF_ERROR( + dst_subgraph.RecordArg(edge, node_images, src_arg_pairs)); + } + } + } + } + return Status::OK(); +} + +Status Encapsulator::SplitIntoSubgraphs() { + Status s; + + // Map from input graph nodes to subgraph nodes. std::unordered_map node_images; - // Copy all unmarked nodes to the output graph. + // Each entry of src_arg_pairs is a pair whose first element is a node in the + // original graph that has an output edge in the subgraph, and whose second + // element is the arg node in the subgraph that it sends to. The vector will + // be filled in below in AddArgs. + std::vector> src_arg_pairs; + + TF_RETURN_IF_ERROR(CopySubgraphNodes(&node_images)); + TF_RETURN_IF_ERROR(CopySubgraphEdges(node_images, &src_arg_pairs)); + + // For each subgraph, add the nodes that deal with inputs and outputs its + // nested outside_compilation subgraphs. These could not be added earlier + // during CopySubgraphEdges since we need to discover all the types of the + // inputs and outputs for an outside_compilation subgraph before creating a + // single input and output node for it. 
+ for (auto& entry : subgraphs_) { + Subgraph& subgraph = entry.second; + TF_RETURN_IF_ERROR(subgraph.AddHostComputes(entry.first, node_images)); + } + + MarkGuaranteedConstants(*graph_in_, src_arg_pairs); + + for (auto& entry : subgraphs_) { + Subgraph& subgraph = entry.second; + FixupSourceAndSinkEdges(subgraph.GetGraph()); + } + + return s; +} + +Status Encapsulator::BuildFunctionDefs( + const RewriteSubgraphFn& rewrite_subgraph_fn, bool reuse_existing_functions, + FunctionLibraryDefinition* library) { + for (auto& subgraph_entry : subgraphs_) { + string name = subgraph_entry.first; + Subgraph& subgraph = subgraph_entry.second; + TF_RETURN_IF_ERROR(subgraph.BuildFunctionDef( + name, rewrite_subgraph_fn, reuse_existing_functions, library)); + } + return Status::OK(); +} + +Status Encapsulator::CopyNodesToOutputGraph( + bool parallel_checking, Graph* graph_out, + std::unordered_map* node_images) { for (Node* node : graph_in_->op_nodes()) { - string func_id = GetFunctionNameAttr(node); + string func_id; + string outside_compilation_id; + TF_RETURN_IF_ERROR( + GetFunctionNameAttr(node, &func_id, &outside_compilation_id)); // Don't copy nodes that going to be encapsulated, unless parallel checking // is enabled. 
- if (!func_id.empty() && !parallel_checking) continue; + if (IsInSubgraph(func_id, outside_compilation_id) && !parallel_checking) + continue; Node* image = graph_out->CopyNode(node); - node_images[node] = image; + if (!outside_compilation_id.empty()) { + if (parallel_checking) { + return errors::InvalidArgument( + "Parallel checking is not supported when outside_compilation " + "clusters are present."); + } + image->ClearAttr(group_attribute_); + image->ClearAttr(outside_compilation_attribute_); + } + (*node_images)[node] = image; + } + (*node_images)[graph_in_->source_node()] = graph_out->source_node(); + (*node_images)[graph_in_->sink_node()] = graph_out->sink_node(); + return Status::OK(); +} + +Status Encapsulator::AddFunctionCallNodes( + const std::unordered_map& node_images, + bool parallel_checking, Graph* graph_out) { + for (auto& subgraph_entry : subgraphs_) { + TF_RETURN_IF_ERROR(subgraph_entry.second.AddFunctionCallNode( + node_images, parallel_checking, graph_out)); } - node_images[graph_in_->source_node()] = graph_out->source_node(); - node_images[graph_in_->sink_node()] = graph_out->sink_node(); + return Status::OK(); +} - // Add function call nodes for each subgraph. 
+Status Encapsulator::AddOutsideCompilationHostIONodes( + const std::unordered_map& node_images, + Graph* graph_out) { for (auto& subgraph_entry : subgraphs_) { + const string& subgraph_name = subgraph_entry.first; Subgraph& subgraph = subgraph_entry.second; + TF_RETURN_IF_ERROR(subgraph.AddOutsideCompilationHostIONodes( + subgraph_name, node_images, graph_out)); + } + return Status::OK(); +} - subgraph.call_node_inputs = graph_out->AddNode(subgraph.call_node_def, &s); - if (!s.ok()) return s; +Status Encapsulator::FindOutputImageOfEdgeSrc( + const string& src_func_id, const string& src_outside_compilation_id, + const string& dst_func_id, const string& dst_outside_compilation_id, + const std::unordered_map& node_images, + const Node* original_src_node, Node** src_image) { + if (IsInSubgraph(src_func_id, src_outside_compilation_id)) { + if (dst_func_id == src_func_id) { + // The edge is from a subgraph to an outside_compilation cluster in the + // same subgraph so use the appropriate _RecvAtHost node in the output + // graph. + TF_RET_CHECK(!dst_outside_compilation_id.empty()); + *src_image = subgraphs_.at(src_func_id) + .GetRecvAtHostNode(dst_outside_compilation_id); + } else { + // The edge is from a subgraph to a regular node in the output graph so + // use the subgraph's call node output. + *src_image = subgraphs_.at(src_func_id).GetCallNodeForOutputs(); + } + } else { + // The source of the edge is in the output graph so use the node image in + // the output graph. + *src_image = node_images.at(original_src_node); + } + return Status::OK(); +} - // Copy the assigned device and the key_annotation over. 
- subgraph.call_node_inputs->set_assigned_device_name(subgraph.device); - subgraph.call_node_outputs = subgraph.call_node_inputs; +int Encapsulator::FindOutputSlotOfEdgeSrc( + const string& src_func_id, const string& src_outside_compilation_id, + const string& dst_func_id, const string& dst_outside_compilation_id, + const Edge* edge) { + if (IsInSubgraph(src_func_id, src_outside_compilation_id)) { + const Subgraph& src_subgraph = subgraphs_.at(src_func_id); + if (src_func_id == dst_func_id) { + // 'src' is in a subgraph and 'dst' is outside_compilation in the same + // subgraph. Use the corresponding _RecvAtHost output instead. + return src_subgraph.GetRecvAtHostSlot(dst_outside_compilation_id, edge); + } else { + // 'src' is in a subgraph and 'dst' is a regular node in the output + // graph. Use the corresponding call output instead. + return src_subgraph.GetResultIndexForEdge(edge); + } + } else { + // The source of the edge is in the output graph so use the regular edge + // slot. + return edge->src_output(); + } +} +Status Encapsulator::FindOutputImageOfEdgeDst( + const string& src_func_id, const string& src_outside_compilation_id, + const string& dst_func_id, const string& dst_outside_compilation_id, + const std::unordered_map& node_images, + const Node* original_dst_node, Node** dst_image) { + if (IsInSubgraph(dst_func_id, dst_outside_compilation_id)) { + if (src_func_id == dst_func_id) { + // The edge is to a subgraph from an outside_compilation cluster in the + // same subgraph so use the appropriate _SendFromHost node in the output + // graph. + TF_RET_CHECK(!src_outside_compilation_id.empty()); + *dst_image = subgraphs_.at(dst_func_id) + .GetSendFromHostNode(src_outside_compilation_id); + } else { + // The edge is to a subgraph from a regular node in the output graph so + // use the subgraph's call node input. 
+ *dst_image = subgraphs_.at(dst_func_id).GetCallNodeForInputs(); + } + } else { + // The destination of the edge is in the output graph so use the node image + // in the output graph. + *dst_image = node_images.at(original_dst_node); + } + return Status::OK(); +} + +int Encapsulator::FindOutputSlotOfEdgeDst( + const string& src_func_id, const string& src_outside_compilation_id, + const string& dst_func_id, const string& dst_outside_compilation_id, + const Edge* edge) { + if (IsInSubgraph(dst_func_id, dst_outside_compilation_id)) { + const Subgraph& dst_subgraph = subgraphs_.at(dst_func_id); + if (dst_func_id == src_func_id) { + // 'dst' is in a subgraph and 'src' is outside_compilation in the same + // subgraph. Use the corresponding _SendFromHost input instead. + return dst_subgraph.GetSendFromHostSlot(src_outside_compilation_id, edge); + } else { + // 'dst' is in a subgraph and 'src' is a regular node in the output + // graph. Use the corresponding call input instead. + return dst_subgraph.GetArgIndexForEdge(edge); + } + } else { + // The destination of the edge is in the output graph so use the regular + // edge slot. + return edge->dst_input(); + } +} + +Status Encapsulator::CopyEdgeToOutputGraph( + const Edge* edge, const string& src_func_id, + const string& src_outside_compilation_id, const string& dst_func_id, + const string& dst_outside_compilation_id, + const std::unordered_map& node_images, + bool parallel_checking, Graph* graph_out, + std::unordered_set, NodeSlot::PairHasher>* + edges_added) { + Node* src_image; + TF_RETURN_IF_ERROR(FindOutputImageOfEdgeSrc( + src_func_id, src_outside_compilation_id, dst_func_id, + dst_outside_compilation_id, node_images, edge->src(), &src_image)); + Node* dst_image; + TF_RETURN_IF_ERROR(FindOutputImageOfEdgeDst( + src_func_id, src_outside_compilation_id, dst_func_id, + dst_outside_compilation_id, node_images, edge->dst(), &dst_image)); + + // If this is a control edge then copy it and return. 
Lift control edges onto + // the enclosing call operator. + if (edge->IsControlEdge()) { + // Add the control edge, if we have not already added it, using the images + // determined above (potentially call operators or RecvAtHost/SendFromHost). + if (edges_added->emplace(NodeSlot(src_image, -1), NodeSlot(dst_image, -1)) + .second) { + graph_out->AddControlEdge(src_image, dst_image); + } + + // If parallel checking is enabled, also add a control edge to the + // corresponding parallel check op. if (parallel_checking) { - TF_RETURN_IF_ERROR(BuildParallelCheckOp(node_images, subgraph, graph_out, - &subgraph.call_node_outputs)); + graph_out->AddControlEdge(src_image, node_images.at(edge->dst())); } + return Status::OK(); + } + + int src_output = + FindOutputSlotOfEdgeSrc(src_func_id, src_outside_compilation_id, + dst_func_id, dst_outside_compilation_id, edge); + + int dst_input = + FindOutputSlotOfEdgeDst(src_func_id, src_outside_compilation_id, + dst_func_id, dst_outside_compilation_id, edge); + + if (IsInSubgraph(dst_func_id, dst_outside_compilation_id) && + parallel_checking) { + // If we are parallel checking, also feed the tensor as an input to the + // corresponding parallel check subgraph. + graph_out->AddEdge(src_image, src_output, node_images.at(edge->dst()), + edge->dst_input()); + } + + // Add the edge, if we have not already added it. + if (edges_added + ->emplace(NodeSlot(src_image, src_output), + NodeSlot(dst_image, dst_input)) + .second) { + graph_out->AddEdge(src_image, src_output, dst_image, dst_input); } + return Status::OK(); +} +Status Encapsulator::AddEdgesToOutputGraph( + const std::unordered_map& node_images, + bool parallel_checking, Graph* graph_out) { // Set of edges already added to the output graph, represented as (src, dst) // pairs. We use the set to deduplicate edges; multiple edges in the input // graph may map to one edge in the output graph. std::unordered_set, NodeSlot::PairHasher> edges_added; - // Add edges to the graph_out graph. 
for (const Edge* edge : graph_in_->edges()) { - string src_func_id = GetFunctionNameAttr(edge->src()); - string dst_func_id = GetFunctionNameAttr(edge->dst()); + string src_func_id; + string src_outside_compilation_id; + TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->src(), &src_func_id, + &src_outside_compilation_id)); + string dst_func_id; + string dst_outside_compilation_id; + TF_RETURN_IF_ERROR(GetFunctionNameAttr(edge->dst(), &dst_func_id, + &dst_outside_compilation_id)); // Ignore edges that are strictly contained within one subgraph, unless // we are constructing parallel check graphs. - if (!src_func_id.empty() && src_func_id == dst_func_id) { + if (IsInSubgraph(src_func_id, src_outside_compilation_id) && + IsInSubgraph(dst_func_id, dst_outside_compilation_id) && + src_func_id == dst_func_id) { if (parallel_checking) { Node* src_image = node_images.at(edge->src()); Node* dst_image = node_images.at(edge->dst()); @@ -493,89 +1601,403 @@ Status Encapsulator::BuildOutputGraph(bool parallel_checking, continue; } - // We have an edge that crosses a cluster boundary. - Node* src_image = src_func_id.empty() - ? node_images.at(edge->src()) - : subgraphs_.at(src_func_id).call_node_outputs; - Node* dst_image = dst_func_id.empty() - ? node_images.at(edge->dst()) - : subgraphs_.at(dst_func_id).call_node_inputs; - - // Copy control edges. Lift control edges onto the enclosing call operator. - if (edge->IsControlEdge()) { - // Add the control edge, if we have not already added it. - if (edges_added.emplace(NodeSlot(src_image, -1), NodeSlot(dst_image, -1)) - .second) { - graph_out->AddControlEdge(src_image, dst_image); + // We have an edge that crosses a cluster boundary or is entirely within the + // unclustered graph. 
+ TF_RETURN_IF_ERROR(CopyEdgeToOutputGraph( + edge, src_func_id, src_outside_compilation_id, dst_func_id, + dst_outside_compilation_id, node_images, parallel_checking, graph_out, + &edges_added)); + } + + for (auto& subgraph_entry : subgraphs_) { + Subgraph& subgraph = subgraph_entry.second; + subgraph.ConnectSequencerToOutputs(graph_out); + } + + return Status::OK(); +} + +namespace { + +// Adds a dummy Const node to graph_out. The "constant" has the type of +// data_type and the shape indicated in 'shape'. The dummy node is not a valid +// Const node because it does not have any value defined, but this doesn't +// matter because it will only be used subsequently for shape inference. (It +// would be possible to add a switch statement over data_type to create a value +// for the constant, but that would entail maintaining the logic as new types +// are added, and is not necessary.) +Node* AddDummyShapedNode(DataType data_type, const TensorShapeProto& shape, + Graph* graph_out) { + TensorProto dummy_proto; + dummy_proto.set_dtype(data_type); + *dummy_proto.mutable_tensor_shape() = shape; + // Don't set any value field in the proto, since it is only going to be used + // for shape inference. + + GraphDefBuilder::Options options(graph_out, /*status=*/nullptr); + NodeBuilder node_builder(options.GetNameForOp("KnownShape"), "Const", + options.op_registry()); + node_builder.Attr("dtype", data_type).Attr("value", dummy_proto); + return options.FinalizeBuilder(&node_builder); +} + +// Adds a copy of node_in to graph_out and adds the mapping to +// copied_node_images. +Status CopyShapeInferenceNodeToGraph( + Node* node_in, const Node* send_node, + const std::unordered_map& dummy_node_images, + FunctionLibraryDefinition* library, + std::unordered_map* copied_node_images, Graph* graph_out) { + // Once all the ancestor nodes have been added to graph_out, add this node + // and connect it to its ancestors. 
+ Node* node_out = graph_out->CopyNode(node_in); + (*copied_node_images)[node_in] = node_out; + // Don't bother to build the shape inference graph if there's a node with no + // shape inference function, since it would just result in an error later at + // compile time. + const OpRegistrationData* op_reg_data; + TF_RETURN_IF_ERROR(library->LookUp(node_in->type_string(), &op_reg_data)); + if (op_reg_data->shape_inference_fn == nullptr) { + return errors::InvalidArgument( + "Shape inference is not possible for outside_compilation " + "SendFromHost node ", + send_node->name(), " because it depends on node ", node_in->name(), + " which does not have a shape inference function registered."); + } + // Add all the edges to the newly copied node. + for (const Edge* in_edge : node_in->in_edges()) { + if (!in_edge->IsControlEdge()) { + Node* src = in_edge->src(); + const auto iter = dummy_node_images.find(src); + if (iter == dummy_node_images.end()) { + // The src is a copied node so use the original output port. + graph_out->AddEdge((*copied_node_images)[in_edge->src()], + in_edge->src_output(), node_out, + in_edge->dst_input()); + } else { + // The src is a dummy node so use output port 0. + graph_out->AddEdge(iter->second, 0, node_out, in_edge->dst_input()); } + } + } + return Status::OK(); +} - // If parallel checking is enabled, also add a control edge to the - // corresponding parallel check op. - if (parallel_checking) { - graph_out->AddControlEdge(src_image, node_images.at(edge->dst())); +} // namespace + +Status Encapsulator::DoStaticShapeInferenceForOutsideCompilationSend( + const Graph& graph_in, const ShapeRefiner& shape_refiner, + const std::unordered_set& recv_at_host_nodes, Node* send_node, + FunctionLibraryDefinition* library, + std::vector* static_shape_out, + std::unique_ptr* graphdef_out) { + // Maps from nodes in graph_in to nodes in graph_out. 
+ // + // When an edge has fully defined shape the source node in graph_in is + // replaced in graph_out by a dummy constant node. The mapping from nodes + // in graph_in to dummy nodes is stored in dummy_node_images. + // + // When a node in graph_in has at least one ancestor that doesn't have fully + // defined shape, it is copied into graph_out. The mapping from nodes in + // graph_in to copied nodes is stored in copied_node_images. + // + // The two types of node are treated differently because, when adding edges to + // graph_out, an output from a dummy node always uses port 0, whereas an + // output from a copied node uses the same port that was used in graph_in. + std::unordered_map dummy_node_images; + std::unordered_map copied_node_images; + + std::unique_ptr graph_out(new Graph(graph_in.op_registry())); + graph_out->set_versions(graph_in.versions()); + static_shape_out->resize(send_node->num_inputs()); + + // We don't use the standard ReverseDFS because we want to cut off traversal + // whenever we find an output with fully defined shape. + // TODO(misard) make this work properly in the presence of control flow. + struct Work { + Node* node; + bool leave; // Are we entering or leaving node? + }; + std::vector stack({{send_node, false}}); + std::vector visited(graph_in.num_node_ids(), false); + while (!stack.empty()) { + Work w = stack.back(); + stack.pop_back(); + Node* n = w.node; + + if (w.leave) { + TF_RETURN_IF_ERROR(CopyShapeInferenceNodeToGraph( + n, send_node, dummy_node_images, library, &copied_node_images, + graph_out.get())); + } else { + if (visited[n->id()]) continue; + visited[n->id()] = true; + + // Arrange to revisit when all done with all inputs. 
+ stack.push_back(Work{n, true}); + + bool has_parent_with_unknown_shape = false; + for (const Edge* in_edge : n->in_edges()) { + if (!in_edge->IsControlEdge()) { + Node* src_node = in_edge->src(); + int src_port = in_edge->src_output(); + shape_inference::InferenceContext* context = + shape_refiner.GetContext(src_node); + shape_inference::ShapeHandle shape = context->output(src_port); + if (context->FullyDefined(shape)) { + // This ancestor has known shape, so instead of adding it to the + // stack, add a dummy node with that shape to graph_out and + // continue. + TensorShapeProto proto; + context->ShapeHandleToProto(shape, &proto); + dummy_node_images[src_node] = AddDummyShapedNode( + src_node->output_type(src_port), proto, graph_out.get()); + if (n == send_node) { + (*static_shape_out)[in_edge->dst_input()] = proto; + } + } else { + if (!visited[src_node->id()]) { + has_parent_with_unknown_shape = true; + stack.push_back({src_node, false}); + } + } + } + } + if (!has_parent_with_unknown_shape) { + if (n == send_node) { + // The shapes of all the inputs to send_node are statically known. We + // won't have to do any inference at compile time so return now: the + // shapes were stored in static_shape_out above. + graphdef_out->reset(); + return Status::OK(); + } else { + // Any shape that is being processed is either the original send node + // or has at least one output with statically-unknown shape. If the + // latter and it doesn't have any inputs with statically-unknown + // shape, then check that it is of the recv nodes that we can fill in + // the shape of at run-time later. If it isn't one of those, then we + // won't have any additional knowledge at compile time, so we already + // know we won't be able to do shape inference and we can return an + // error now. 
+ if (recv_at_host_nodes.find(n->name()) == recv_at_host_nodes.end()) { + return errors::InvalidArgument( + "Shape inference is not possible for outside_compilation " + "SendFromHost node ", + send_node->name(), " because shape of node ", n->name(), + " will not be known at compilation time."); + } + } } - continue; } + } + + graphdef_out->reset(new GraphDef()); + graph_out->ToGraphDef(graphdef_out->get()); - int src_output = edge->src_output(); - if (!src_func_id.empty()) { - // 'src' is in a subgraph. Use the corresponding call output instead. - const Subgraph& src_subgraph = subgraphs_.at(src_func_id); - src_output = - src_subgraph.results.at(NodeSlot(edge->src(), edge->src_output())); + return Status::OK(); +} + +Status Encapsulator::MakePrunedGraphCopyAndInline( + const Graph& graph, const std::vector& sink_nodes, + std::unique_ptr* pruned_graph, + std::unordered_map* node_images, + FunctionLibraryDefinition* library) { + // First copy all ancestor nodes of sink_nodes into a new graph. + pruned_graph->reset(new Graph(library)); + (*pruned_graph)->set_versions(graph.versions()); + ReverseDFSFrom(graph, sink_nodes, + /*enter=*/nullptr, + /*leave=*/[&](Node* n) { + if (!n->IsSource()) { + Node* copied = (*pruned_graph)->CopyNode(n); + node_images->emplace(n, copied); + } + }); + + // Add all the edges between copied nodes. + for (auto entry : *node_images) { + const Node* orig = entry.first; + Node* image = entry.second; + for (const Edge* out_edge : orig->out_edges()) { + auto iter = node_images->find(out_edge->dst()); + if (iter != node_images->end()) { + // The source and destination are both in the copied graph. + (*pruned_graph) + ->AddEdge(image, out_edge->src_output(), iter->second, + out_edge->dst_input()); + } } + } - int dst_input = edge->dst_input(); + // Find all the function call nodes, and inline them. 
+ std::vector function_nodes; + for (auto node : (*pruned_graph)->nodes()) { + const OpRegistrationData* op_reg_data; + TF_RETURN_IF_ERROR(library->LookUp(node->type_string(), &op_reg_data)); + if (op_reg_data->is_function_op) { + function_nodes.push_back(node); + } + } + for (auto node : function_nodes) { + VLOG(2) << "Inlining function " << node->name(); + const FunctionDef* fdef = library->Find(node->type_string()); + if (fdef == nullptr) { + return errors::Internal("Failed to find function ", node->type_string(), + " in function library."); + } + FunctionBody* fbody = nullptr; + TF_RETURN_IF_ERROR( + FunctionDefToBodyHelper(*fdef, node->attrs(), library, + [library](const string& op, const OpDef** sig) { + return library->LookUpOpDef(op, sig); + }, + &fbody)); + InlineFunctionBody(*library, pruned_graph->get(), node, fbody); + delete fbody; + } - if (!dst_func_id.empty()) { - // 'dst' is in a subgraph. Use the corresponding call input instead. - const Subgraph& dst_subgraph = subgraphs_.at(dst_func_id); - dst_input = - dst_subgraph.args_by_dst.at(NodeSlot(edge->dst(), edge->dst_input())); + return Status::OK(); +} - // If we are parallel checking, also feed the tensor as an input to the - // corresponding parallel check subgraph. - if (parallel_checking) { - graph_out->AddEdge(src_image, src_output, node_images.at(edge->dst()), - edge->dst_input()); +Status Encapsulator::MakeGraphForOutsideCompilationSends( + const Graph& graph, std::unique_ptr* pruned_graph, + ShapeRefiner* shape_refiner, + std::unordered_map* node_images, + FunctionLibraryDefinition* library) { + // Find all the send_from_host nodes in all subgraphs, to use as roots for the + // pruning. 
+ std::vector send_from_host_nodes; + for (auto& subgraph_entry : subgraphs_) { + Subgraph& subgraph = subgraph_entry.second; + std::vector outside_compilation_names; + subgraph.GetOutsideCompilationSubgraphNames(&outside_compilation_names); + for (const auto& name : outside_compilation_names) { + Node* send_node = subgraph.GetSendFromHostNode(name); + if (send_node != nullptr) { + send_from_host_nodes.push_back(send_node); } } - // Add the edge, if we have not already added it. - if (edges_added - .emplace(NodeSlot(src_image, src_output), - NodeSlot(dst_image, dst_input)) - .second) { - graph_out->AddEdge(src_image, src_output, dst_image, dst_input); + } + + // Make a copy of all the graph nodes needed to evaluate the send_from_host + // nodes, inlining any functions as needed. + TF_RETURN_IF_ERROR(MakePrunedGraphCopyAndInline( + graph, send_from_host_nodes, pruned_graph, node_images, library)); + + // Perform shape inference on the pruned graph. + shape_refiner->set_require_shape_inference_fns(false); + FixupSourceAndSinkEdges(pruned_graph->get()); + std::vector post_order; + GetReversePostOrder(*(*pruned_graph), &post_order); + for (auto node : post_order) { + // Ignore the status returned by the shape_refiner. At this point we want + // the best effort shapes, even if no shape function is registered for a + // node. 
+ Status status = shape_refiner->AddNode(node); + if (!status.ok()) { + VLOG(1) << "Shape inference failed for node: " << status; } } - return s; + return Status::OK(); +} + +Status Encapsulator::GetShapeInfoForOutsideCompilationSends( + Graph* graph_out, FunctionLibraryDefinition* library) { + std::unique_ptr pruned_graph; + ShapeRefiner shape_refiner(graph_out->versions(), graph_out->op_registry()); + std::unordered_map node_images; + TF_RETURN_IF_ERROR(MakeGraphForOutsideCompilationSends( + *graph_out, &pruned_graph, &shape_refiner, &node_images, library)); + + for (auto& subgraph_entry : subgraphs_) { + Subgraph& subgraph = subgraph_entry.second; + // Find all the recv_at_host nodes in this subgraph. + std::vector outside_compilation_names; + subgraph.GetOutsideCompilationSubgraphNames(&outside_compilation_names); + std::unordered_set recv_at_host_names; + for (const auto& name : outside_compilation_names) { + Node* recv_node = subgraph.GetRecvAtHostNode(name); + if (recv_node != nullptr) { + recv_at_host_names.insert(recv_node->name()); + } + } + // For each send_from_host node, do as much shape inference as possible + // without knowing the shape of the recv_at_host nodes, and store the + // result, along with enough information to complete the job at compile time + // once the recv_at_host shapes are known. 
+ for (const auto& name : outside_compilation_names) { + Node* send_node = subgraph.GetSendFromHostNode(name); + std::vector static_shape; + std::unique_ptr graphdef; + if (send_node != nullptr) { + TF_RETURN_IF_ERROR(DoStaticShapeInferenceForOutsideCompilationSend( + *pruned_graph, shape_refiner, recv_at_host_names, + node_images[send_node], library, &static_shape, &graphdef)); + if (graphdef == nullptr) { + VLOG(2) << "Send node " << send_node->name() << " shapes"; + for (int i = 0; i < static_shape.size(); ++i) { + VLOG(2) << static_shape[i].DebugString(); + } + } else { + VLOG(2) << "Send node " << send_node->name() << " graph\n" + << graphdef->DebugString(); + } + } + TF_RETURN_IF_ERROR( + subgraph.AddShapeInferenceInfo(name, static_shape, graphdef.get())); + } + if (!outside_compilation_names.empty()) { + TF_RETURN_IF_ERROR(subgraph.ReplaceFunctionDef(library)); + } + } + + return Status::OK(); +} + +Status Encapsulator::BuildOutputGraph(bool parallel_checking, Graph* graph_out, + FunctionLibraryDefinition* library) { + // Map from nodes in the input graph to nodes in the output graph. 
+ std::unordered_map node_images; + + TF_RETURN_IF_ERROR( + CopyNodesToOutputGraph(parallel_checking, graph_out, &node_images)); + TF_RETURN_IF_ERROR( + AddFunctionCallNodes(node_images, parallel_checking, graph_out)); + TF_RETURN_IF_ERROR(AddOutsideCompilationHostIONodes(node_images, graph_out)); + TF_RETURN_IF_ERROR( + AddEdgesToOutputGraph(node_images, parallel_checking, graph_out)); + + TF_RETURN_IF_ERROR( + GetShapeInfoForOutsideCompilationSends(graph_out, library)); + + return Status::OK(); } } // anonymous namespace Status EncapsulateSubgraphsInFunctions( - string group_attribute, const Graph& graph_in, - const RewriteSubgraphFn& rewrite_subgraph_fn, bool parallel_checking, - bool reuse_existing_functions, std::unique_ptr* graph_out, - FunctionLibraryDefinition* library) { + string group_attribute, string outside_compilation_attribute, + const Graph& graph_in, const RewriteSubgraphFn& rewrite_subgraph_fn, + bool parallel_checking, bool reuse_existing_functions, + std::unique_ptr* graph_out, FunctionLibraryDefinition* library) { Status s; - Encapsulator encapsulator(std::move(group_attribute), &graph_in); - s = encapsulator.SplitIntoSubgraphs(); - if (!s.ok()) return s; + Encapsulator encapsulator(std::move(group_attribute), + std::move(outside_compilation_attribute), + &graph_in); + TF_RETURN_IF_ERROR(encapsulator.SplitIntoSubgraphs()); - s = encapsulator.BuildFunctionDefs(rewrite_subgraph_fn, - reuse_existing_functions, library); - if (!s.ok()) return s; + TF_RETURN_IF_ERROR(encapsulator.BuildFunctionDefs( + rewrite_subgraph_fn, reuse_existing_functions, library)); std::unique_ptr out(new Graph(library)); out->set_versions(graph_in.versions()); - s = encapsulator.BuildOutputGraph(parallel_checking, out.get()); - if (!s.ok()) return s; + TF_RETURN_IF_ERROR( + encapsulator.BuildOutputGraph(parallel_checking, out.get(), library)); *graph_out = std::move(out); - return s; + return Status::OK(); } // Finds the types of the _Arg nodes, indexed by position. 
@@ -690,9 +2112,9 @@ Status EncapsulateSubgraphsPass::Run( }; TF_RETURN_IF_ERROR(EncapsulateSubgraphsInFunctions( - kXlaClusterAttr, **options.graph, rewrite_subgraph, - flags->tf_xla_parallel_checking, /*reuse_existing_functions=*/false, - &graph_out, library)); + kXlaClusterAttr, kXlaOutsideCompilationAttr, **options.graph, + rewrite_subgraph, flags->tf_xla_parallel_checking, + /*reuse_existing_functions=*/false, &graph_out, library)); if (VLOG_IS_ON(1)) { dump_graph::DumpGraphToFile("after_encapsulate_subgraphs", *graph_out, diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h index b0987f76c91ed48df52fab303ea6052ebd8fd336..34be4409a381197d2191e083727aa8d48ab8cd63 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.h @@ -48,6 +48,16 @@ typedef std::function* graph_out, - FunctionLibraryDefinition* library); + string group_attribute, string outside_compilation_attribute, + const Graph& graph_in, const RewriteSubgraphFn& rewrite_subgraph_fn, + bool parallel_checking, bool reuse_existing_functions, + std::unique_ptr* graph_out, FunctionLibraryDefinition* library); // The attribute that marks function calls produced by the encapsulate // subgraphs pass and that should in turn be compiled via _XlaLaunch operators. diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc index 4a1dbaf05dc7824835f3567c6abcf48222720230..aed9cae0f1799c4524da8ee309344849798755d5 100644 --- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc +++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc @@ -29,17 +29,181 @@ limitations under the License. 
namespace tensorflow { namespace { +template +bool EqualProtoMap(const ::tensorflow::protobuf::Map& a, + const ::tensorflow::protobuf::Map& b, + const std::function& key_to_string, + const std::function& value_to_string, + const std::function& compare, + const string& map_name, string* diff) { + for (const auto& elt_a : a) { + const auto iter = b.find(elt_a.first); + if (iter == b.end()) { + if (diff) { + *diff = strings::StrCat( + map_name, " expected: contains element with key '", + key_to_string(elt_a.first), "' got: map has no such element"); + } + return false; + } + if (!compare(elt_a.first, elt_a.second, iter->second)) { + if (diff) { + *diff = strings::StrCat(map_name, " expected: element with key '", + key_to_string(elt_a.first), " has value '", + value_to_string(elt_a.second), "' got: '", + value_to_string(iter->second), "'"); + } + return false; + } + } + for (const auto& elt_b : b) { + const auto iter = a.find(elt_b.first); + if (iter == a.end()) { + if (diff) { + *diff = strings::StrCat(map_name, " got: contains element with key '", + key_to_string(elt_b.first), + "' expected: map has no such element"); + } + return false; + } + } + return true; +} + +bool EqualFunctionNodeDef(const NodeDef& a, const NodeDef& b, + const string& diff_preamble, string* diff) { + if (a.op() != b.op()) { + if (diff) { + *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), + ", expected op '", a.op(), "' got '", b.op()); + } + return false; + } + if (a.device() != b.device()) { + if (diff) { + *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), + ", expected device '", a.device(), "' got '", + b.device()); + } + return false; + } + if (a.input_size() != b.input_size()) { + if (diff) { + *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), + ", expected ", a.input_size(), " inputs got ", + b.input_size(), " expected:\n", a.DebugString(), + "\ngot:\n", b.DebugString()); + } + return false; + } + for (int i = 0; i < 
a.input_size(); ++i) { + if (a.input(i) != b.input(i)) { + if (diff) { + *diff = strings::StrCat(diff_preamble, " mismatch for node ", a.name(), + " input ", i, ", expected ", a.input(i), + " got ", b.input(i), " expected:\n", + a.DebugString(), "\ngot:\n", b.DebugString()); + } + return false; + } + } + return EqualProtoMap( + a.attr(), b.attr(), [](const string& s) { return s; }, + [](const AttrValue& v) { return v.DebugString(); }, + [](const string& key, const AttrValue& av, const AttrValue& bv) { + if (key == "shape_inference_graph") { + // Default serialization of GraphDef is unstable because maps don't + // serialize deterministically. Rather than go through the hoops to + // turn on deterministic serialization of this attr just for this + // test, add logic here to compare determinstically. + GraphDef ga; + if (!ga.ParseFromString(av.s())) { + return false; + } + GraphDef gb; + if (!gb.ParseFromString(bv.s())) { + return false; + } + return EqualGraphDef(ga, gb, nullptr); + } else { + return av.DebugString() == bv.DebugString(); + } + }, + strings::StrCat(diff_preamble, " attr mismatch for node ", a.name()), + diff); +} + bool EqualFunctionDef(const FunctionDef& a, const FunctionDef& b, string* diff) { - // TODO(phawkins) use a more sophisticated equality test. 
- if (a.DebugString() != b.DebugString()) { + if (a.signature().DebugString() != b.signature().DebugString()) { if (diff) { - *diff = strings::StrCat("Definition mismatch for function ", + *diff = strings::StrCat("Signature mismatch for function ", a.signature().name(), ", expected:\n", - a.DebugString()); + a.signature().DebugString(), "\ngot:\n", + b.signature().DebugString()); } return false; } + if (!EqualProtoMap( + a.attr(), b.attr(), [](const string& s) { return s; }, + [](const AttrValue& v) { return v.DebugString(); }, + [](const string& key, const AttrValue& av, const AttrValue& bv) { + return av.DebugString() == bv.DebugString(); + }, + strings::StrCat("attr mismatch for function ", a.signature().name()), + diff)) { + return false; + } + if (!EqualProtoMap( + a.ret(), b.ret(), [](const string& s) { return s; }, + [](const string& s) { return s; }, + [](const string& key, const string& av, const string& bv) { + return av == bv; + }, + strings::StrCat("ret mismatch for function ", a.signature().name()), + diff)) { + return false; + } + for (int i = 0; i < a.node_def_size(); ++i) { + bool found = false; + for (int j = 0; j < b.node_def_size(); ++j) { + if (a.node_def(i).name() == b.node_def(j).name()) { + if (!EqualFunctionNodeDef( + a.node_def(i), b.node_def(j), + strings::StrCat("Function ", a.signature().name()), diff)) { + return false; + } + found = true; + break; + } + } + if (!found) { + if (diff) { + *diff = strings::StrCat("Function ", a.signature().name(), + ", expected: has node '", a.node_def(i).name(), + "' got: no node of that name"); + } + return false; + } + } + for (int i = 0; i < b.node_def_size(); ++i) { + bool found = false; + for (int j = 0; j < a.node_def_size(); ++j) { + if (b.node_def(i).name() == a.node_def(j).name()) { + found = true; + break; + } + } + if (!found) { + if (diff) { + *diff = strings::StrCat("Function ", a.signature().name(), + ", got: has node '", b.node_def(i).name(), + "' expected: no node of that name"); + } + 
return false; + } + } return true; } @@ -82,13 +246,66 @@ bool EqualFunctionDefLibrary(const FunctionDefLibrary& expected, << diff << "\nActual: " << actual.DebugString(); \ } while (false) -REGISTER_OP("InputTest").Output("o: float"); - -REGISTER_OP("UnaryTest").Input("a: float").Output("o: float"); +// TODO(misard): remove these fake registrations once there are real Ops to be +// compiled. +REGISTER_OP("_XlaHostCompute") + .Input("inputs: Tinputs") + .Output("outputs: Toutputs") + .Attr("Tinputs: list(type) >= 0") + .Attr("Toutputs: list(type) >= 0") + .Attr("key: string") + .SetShapeFn(::tensorflow::shape_inference::UnknownShape); + +REGISTER_OP("_XlaSendFromHost") + .Input("input: Tinputs") + .Attr("Tinputs: list(type) >= 0") + .Attr("key: string") + .SetShapeFn(::tensorflow::shape_inference::UnknownShape); + +REGISTER_OP("_XlaRecvAtHost") + .Output("output: Toutputs") + .Attr("Toutputs: list(type) >= 0") + .Attr("key: string") + .SetShapeFn(::tensorflow::shape_inference::UnknownShape); + +REGISTER_OP("InputTest") + .Output("o: float") + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + c->set_output(0, c->UnknownShape()); + return Status::OK(); + }); + +REGISTER_OP("InputTestShaped") + .Output("o: float") + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + c->set_output(0, c->Vector(2)); + return Status::OK(); + }); + +REGISTER_OP("UnaryTest") + .Input("a: float") + .Output("o: float") + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + ::tensorflow::shape_inference::ShapeHandle o; + TF_RETURN_IF_ERROR(c->Merge(c->UnknownShape(), c->input(0), &o)); + c->set_output(0, o); + return Status::OK(); + }); REGISTER_OP("BinaryTest") .Input("a: float") .Input("b: float") - .Output("o: float"); + .Output("o: float") + .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { + ::tensorflow::shape_inference::ShapeHandle o; + TF_RETURN_IF_ERROR(c->Merge(c->UnknownShape(), c->input(0), &o)); + 
c->set_output(0, o); + return Status::OK(); + }); +REGISTER_OP("BinaryTest2") + .Input("a: float") + .Input("b: float") + .Output("o: float") + .SetShapeFn(::tensorflow::shape_inference::UnknownShape); REGISTER_OP("AddNLikeTest") .Input("inputs: N * T") @@ -98,10 +315,58 @@ REGISTER_OP("AddNLikeTest") .SetIsCommutative() .SetIsAggregate(); +Node* NoOp(const GraphDefBuilder::Options& opts) { + return ops::SourceOp("NoOp", opts); +} + Node* Input(const GraphDefBuilder::Options& opts) { return ops::SourceOp("InputTest", opts); } +Node* InputShaped(const GraphDefBuilder::Options& opts) { + return ops::SourceOp("InputTestShaped", opts); +} + +Node* KnownShape(const gtl::ArraySlice& shape, + const GraphDefBuilder::Options& opts) { + if (opts.HaveError()) return nullptr; + NodeBuilder node_builder(opts.GetNameForOp("Const"), "Const", + opts.op_registry()); + TensorProto value; + value.set_dtype(DT_FLOAT); + for (int dim : shape) { + value.mutable_tensor_shape()->add_dim()->set_size(dim); + } + return opts.WithAttr("value", value) + .WithAttr("dtype", DT_FLOAT) + .FinalizeBuilder(&node_builder); +} + +Node* RecvAtHost(const string& key, const gtl::ArraySlice& dtypes, + const GraphDefBuilder::Options& opts) { + if (opts.HaveError()) return nullptr; + NodeBuilder node_builder(opts.GetNameForOp("_XlaRecvAtHost"), + "_XlaRecvAtHost", opts.op_registry()); + return opts.WithAttr("Toutputs", dtypes) + .WithAttr("key", key) + .FinalizeBuilder(&node_builder); +} + +Node* SendFromHost(const string& key, const std::vector& inputs, + const GraphDefBuilder::Options& opts) { + if (opts.HaveError()) return nullptr; + NodeBuilder node_builder(opts.GetNameForOp("_XlaSendFromHost"), + "_XlaSendFromHost", opts.op_registry()); + node_builder.Input(inputs); + std::vector dtypes; + for (const auto& node : inputs) { + dtypes.push_back(node.dt); + } + return opts.WithAttr("key", key) + .WithAttr("Tinputs", dtypes) + .FinalizeBuilder(&node_builder); +} + Node* Unary(ops::NodeOut a, const 
GraphDefBuilder::Options& opts) { return ops::UnaryOp("UnaryTest", std::move(a), opts); } @@ -111,6 +376,11 @@ Node* Binary(ops::NodeOut a, ops::NodeOut b, return ops::BinaryOp("BinaryTest", std::move(a), std::move(b), opts); } +Node* BinaryUnknownShape(ops::NodeOut a, ops::NodeOut b, + const GraphDefBuilder::Options& opts) { + return ops::BinaryOp("BinaryTest2", std::move(a), std::move(b), opts); +} + Node* AddNLike(const std::vector& inputs, const GraphDefBuilder::Options& opts) { if (opts.HaveError()) return nullptr; @@ -145,7 +415,7 @@ Status Encapsulate(GraphDef* graphdef, FunctionDefLibrary* library) { if (!s.ok()) return s; std::unique_ptr graph_out; - s = EncapsulateSubgraphsInFunctions("_encapsulate", *graph, + s = EncapsulateSubgraphsInFunctions("_encapsulate", "_outside", *graph, /*rewrite_subgraph_fn=*/{}, /*parallel_checking=*/false, /*reuse_existing_functions=*/false, @@ -178,6 +448,7 @@ TEST(EncapsulateSubgraphsTest, NoFunctions) { FunctionDefLibrary library_out = library_in; TF_EXPECT_OK(Encapsulate(&graphdef_out, &library_out)); + // If there are no marked nodes, funcification should be a no-op. TF_EXPECT_GRAPH_EQ(graphdef_in, graphdef_out); TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_in, library_out); } @@ -230,7 +501,6 @@ TEST(EncapsulateSubgraphsTest, OneFunction) { TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); } - // If there are no marked nodes, funcification should be a no-op. 
TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); } @@ -342,9 +612,9 @@ TEST(EncapsulateSubgraphsTest, InputDeduplication) { FunctionLibraryDefinition library(OpRegistry::Global(), {}); std::unique_ptr graph; TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( - "_cluster", graph_before_encapsulation, /*rewrite_subgraph_fn=*/{}, - /*parallel_checking=*/false, /*reuse_existing_functions=*/false, &graph, - &library)); + "_cluster", "_outside", graph_before_encapsulation, + /*rewrite_subgraph_fn=*/{}, /*parallel_checking=*/false, + /*reuse_existing_functions=*/false, &graph, &library)); std::vector expected_nodes = {"cluster1", "cluster2", "mul", "x"}; EXPECT_EQ(expected_nodes, GraphNodes(*graph)); @@ -374,9 +644,9 @@ TEST(EncapsulateSubgraphsTest, ParallelChecking) { FunctionLibraryDefinition library(OpRegistry::Global(), {}); std::unique_ptr graph; TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( - "_cluster", graph_before_encapsulation, /*rewrite_subgraph_fn=*/{}, - /*parallel_checking=*/true, /*reuse_existing_functions=*/false, &graph, - &library)); + "_cluster", "_outside", graph_before_encapsulation, + /*rewrite_subgraph_fn=*/{}, /*parallel_checking=*/true, + /*reuse_existing_functions=*/false, &graph, &library)); std::vector expected_nodes = { "add1", "add2", "cluster1", "cluster1_parallel_check/_0", @@ -398,5 +668,978 @@ TEST(EncapsulateSubgraphsTest, ParallelChecking) { EXPECT_EQ(expected_edges, GraphEdges(*graph)); } +const Node* FindNodeByName(const Graph& graph, const string& name) { + for (const Node* node : graph.nodes()) { + if (node->name() == name) return node; + } + return nullptr; +} + +bool HasGuaranteeConstAttr(const Node& n) { + bool is_guaranteed_constant = false; + if (!GetNodeAttr(n.attrs(), "_is_guaranteed_constant", + &is_guaranteed_constant) + .ok()) { + return false; + } + return is_guaranteed_constant; +} + +TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Simple) { + Scope root = 
Scope::NewRootScope().ExitOnError().WithDevice( + "/job:localhost/replica:0/task:0/cpu:0"); + auto x1 = ops::Placeholder(root.WithOpName("x1"), DT_FLOAT); + auto const_x2 = ops::Const(root.WithOpName("const_x2"), 10.0f); + auto const_guarantee_x1 = + ops::GuaranteeConst(root.WithOpName("const_guarantee_x1"), x1); + auto add1 = ops::Add(root.WithOpName("add1"), const_guarantee_x1, const_x2); + add1.node()->AddAttr("_encapsulate", "encapsulate1"); + + Graph graph_before(OpRegistry::Global()); + TF_ASSERT_OK(root.ToGraph(&graph_before)); + + std::unique_ptr graph_after; + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + int guaranteed_consts = 0; + TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( + "_encapsulate", "_outside", graph_before, + /*rewrite_subgraph_fn=*/ + [&guaranteed_consts](std::unique_ptr* graph_ptr, + std::vector* input_permutation, + std::vector* output_permutation, + NodeDef* call_def) { + Graph* graph = graph_ptr->get(); + for (const Node* n : graph->nodes()) { + if (n->type_string() == "_Arg" && + StringPiece(n->name()).starts_with("const")) { + ++guaranteed_consts; + EXPECT_TRUE(HasGuaranteeConstAttr(*n)); + } else { + EXPECT_FALSE(HasGuaranteeConstAttr(*n)); + } + } + return Status::OK(); + }, + /*parallel_checking=*/false, + /*reuse_existing_functions=*/false, &graph_after, &library)); + EXPECT_EQ(2, guaranteed_consts); +} + +TEST(EncapsulateSubgraphsWithGuaranteeConstOpTest, Add) { + Scope root = Scope::NewRootScope().ExitOnError().WithDevice( + "/job:localhost/replica:0/task:0/cpu:0"); + auto x1 = ops::Placeholder(root.WithOpName("x1"), DT_FLOAT); + auto x2 = ops::Placeholder(root.WithOpName("x2"), DT_FLOAT); + auto const_guarantee_x1 = + ops::GuaranteeConst(root.WithOpName("const_guarantee_x1"), x1); + auto const_guarantee_x2 = + ops::GuaranteeConst(root.WithOpName("const_guarantee_x2"), x2); + auto const_guarantee_add1 = ops::Add(root.WithOpName("const_guarantee_add1"), + const_guarantee_x1, const_guarantee_x2); + auto add2 = 
ops::Add(root.WithOpName("add2"), const_guarantee_x1, x2); + auto mul1 = ops::Mul(root.WithOpName("mul1"), const_guarantee_add1, add2); + mul1.node()->AddAttr("_encapsulate", "encapsulate1"); + + Graph graph_before(OpRegistry::Global()); + TF_ASSERT_OK(root.ToGraph(&graph_before)); + + std::unique_ptr graph_after; + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + int guaranteed_consts = 0; + TF_ASSERT_OK(EncapsulateSubgraphsInFunctions( + "_encapsulate", "_outside", graph_before, + /*rewrite_subgraph_fn=*/ + [&guaranteed_consts](std::unique_ptr* graph_ptr, + std::vector* input_permutation, + std::vector* output_permutation, + NodeDef* call_def) { + Graph* graph = graph_ptr->get(); + for (const Node* n : graph->nodes()) { + if (n->type_string() == "_Arg" && + StringPiece(n->name()).starts_with("const")) { + ++guaranteed_consts; + EXPECT_TRUE(HasGuaranteeConstAttr(*n)); + } else { + EXPECT_FALSE(HasGuaranteeConstAttr(*n)); + } + } + return Status::OK(); + }, + /*parallel_checking=*/false, + /*reuse_existing_functions=*/false, &graph_after, &library)); + // Only 1 runtime const, which is const_guarantee_add1. Add2 has one const + // and another non-const, so overall non-const. + EXPECT_EQ(1, guaranteed_consts); +} + +// Test with one function to transform and one outside_compilation cluster. +TEST(EncapsulateSubgraphsTest, OneFunctionOneOutside) { + FunctionDefLibrary library; + GraphDef graphdef; + + { + *library.add_function() = test::function::XTimesTwo(); + + GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); + Node* a = Input(b1.opts().WithName("A")); + Node* b = Input(b1.opts().WithName("B")); + // Give nodes 'c' and 'd' names that collide after lowercasing. 
+ Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); + Node* d = Binary(b, c, + b1.opts().WithName("c").WithControlInput(c).WithAttr( + "_encapsulate", "F1")); + Node* e = Binary(c, d, + b1.opts() + .WithName("E") + .WithControlInputs({b, d}) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* f = Binary(c, e, + b1.opts().WithName("F").WithControlInput(e).WithAttr( + "_encapsulate", "F1")); + Binary(a, f, b1.opts().WithName("G").WithControlInput(e)); + TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); + } + + TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + + FunctionDefLibrary library_expected; + GraphDef graphdef_expected; + + string shape_string_expected; + { + GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); + Node* recv = + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, + shape.opts().WithName("outside_compilation_F1_O1_recv")); + Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), + shape.opts().WithName("E")); + SendFromHost("host_compute_channel_F1_O1", {e}, + shape.opts().WithName("outside_compilation_F1_O1_send")); + GraphDef shape_graph; + TF_EXPECT_OK(shape.ToGraphDef(&shape_graph)); + EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected)); + } + + *library_expected.add_function() = test::function::XTimesTwo(); + *library_expected.add_function() = FunctionDefHelper::Create( + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {}, + { + {{"C"}, "UnaryTest", {"a_0_arg"}}, + {{"c"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}, {"C"}}, + {{"F"}, + "BinaryTest", + {"C:o:0", "outside_compilation_O1_host_compute:outputs:0"}, + {}, + {"outside_compilation_O1_host_compute"}}, + {{"outside_compilation_O1_host_compute"}, + "_XlaHostCompute", + {"C:o:0", "c:o:0"}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", shape_string_expected}, + {"shapes", 
gtl::ArraySlice({})}}, + {"c"}}, + }, + {{"f_0_retval", "F:o:0"}}); + + { + std::unique_ptr lib_def( + new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); + GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); + Node* a = Input(b2.opts().WithName("A")); + Node* b = Input(b2.opts().WithName("B")); + + NodeBuilder node_builder("F1", "F1", lib_def.get()); + node_builder.Input(a).Input(b); + Node* call = b2.opts().FinalizeBuilder(&node_builder); + + Node* recv = + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, + b2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), + b2.opts().WithName("E").WithControlInputs({recv, b})); + Node* send = SendFromHost("host_compute_channel_F1_O1", {e}, + b2.opts() + .WithName("outside_compilation_F1_O1_send") + .WithControlInput(e)); + + Node* s = NoOp( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send})); + + Binary(a, call, b2.opts().WithName("G").WithControlInputs({s, e})); + TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); + } + + TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); + TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); +} + +// Test with one function to transform and two outside_compilation clusters. 
+TEST(EncapsulateSubgraphsTest, OneFunctionTwoOutside) { + FunctionDefLibrary library; + GraphDef graphdef; + + { + GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); + Node* a = Input(b1.opts().WithName("A")); + Node* b = Input(b1.opts().WithName("B")); + Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); + Node* d = + Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1")); + Node* e = Binary(c, d, + b1.opts() + .WithName("E") + .WithControlInputs({b, d}) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* f = Binary(c, e, + b1.opts().WithName("F").WithControlInput(e).WithAttr( + "_encapsulate", "F1")); + Node* g = Binary(e, f, + b1.opts() + .WithName("G") + .WithControlInputs({e, f}) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O2")); + Node* h = Binary(d, e, + b1.opts() + .WithName("H") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O2")); + Node* i = Unary(h, b1.opts().WithName("I").WithAttr("_encapsulate", "F1")); + Binary(g, i, b1.opts().WithName("J")); + TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); + } + + TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + + FunctionDefLibrary library_expected; + GraphDef graphdef_expected; + + string shape_string_expected_1; + { + GraphDefBuilder shape1(GraphDefBuilder::kFailImmediately); + Node* recv = + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, + shape1.opts().WithName("outside_compilation_F1_O1_recv")); + Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), + shape1.opts().WithName("E")); + SendFromHost("host_compute_channel_F1_O1", {e}, + shape1.opts().WithName("outside_compilation_F1_O1_send")); + GraphDef shape1_graph; + TF_EXPECT_OK(shape1.ToGraphDef(&shape1_graph)); + EXPECT_TRUE(shape1_graph.SerializeToString(&shape_string_expected_1)); + } + + string shape_string_expected_2; + { + GraphDefBuilder shape2(GraphDefBuilder::kFailImmediately); + Node* recv1 = + RecvAtHost("host_compute_channel_F1_O1", 
{DT_FLOAT, DT_FLOAT}, + shape2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), + shape2.opts().WithName("E")); + Node* recv2 = + RecvAtHost("host_compute_channel_F1_O2", {DT_FLOAT, DT_FLOAT}, + shape2.opts().WithName("outside_compilation_F1_O2_recv")); + Node* h = Binary(ops::NodeOut(recv2, 0), e, shape2.opts().WithName("H")); + SendFromHost("host_compute_channel_F1_O2", {h}, + shape2.opts().WithName("outside_compilation_F1_O2_send")); + GraphDef shape2_graph; + TF_EXPECT_OK(shape2.ToGraphDef(&shape2_graph)); + EXPECT_TRUE(shape2_graph.SerializeToString(&shape_string_expected_2)); + } + + *library_expected.add_function() = FunctionDefHelper::Create( + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"i_0_retval:float"}, {}, + { + {{"C"}, "UnaryTest", {"a_0_arg"}}, + {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}, {}}, + {{"I"}, + "UnaryTest", + {"outside_compilation_O2_host_compute:outputs:0"}}, + {{"F"}, + "BinaryTest", + {"C:o:0", "outside_compilation_O1_host_compute:outputs:0"}, + {}, + {"outside_compilation_O1_host_compute"}}, + {{"outside_compilation_O2_host_compute"}, + "_XlaHostCompute", + {"D:o:0", "F:o:0"}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"key", "host_compute_channel_F1_O2"}, + {"shape_inference_graph", shape_string_expected_2}, + {"shapes", gtl::ArraySlice({})}}, + {"F"}}, + {{"outside_compilation_O1_host_compute"}, + "_XlaHostCompute", + {"C:o:0", "D:o:0"}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", shape_string_expected_1}, + {"shapes", gtl::ArraySlice({})}}, + {"D"}}, + }, + {{"i_0_retval", "I:o:0"}}); + + { + std::unique_ptr lib_def( + new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); + GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); + Node* a = 
Input(b2.opts().WithName("A")); + Node* b = Input(b2.opts().WithName("B")); + + NodeBuilder node_builder("F1", "F1", lib_def.get()); + node_builder.Input(a).Input(b); + Node* call = b2.opts().FinalizeBuilder(&node_builder); + + Node* recv1 = + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, + b2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), + b2.opts().WithName("E").WithControlInputs({recv1, b})); + Node* send1 = SendFromHost("host_compute_channel_F1_O1", {e}, + b2.opts() + .WithName("outside_compilation_F1_O1_send") + .WithControlInput(e)); + + Node* recv2 = + RecvAtHost("host_compute_channel_F1_O2", {DT_FLOAT, DT_FLOAT}, + b2.opts().WithName("outside_compilation_F1_O2_recv")); + Node* g = Binary(e, ops::NodeOut(recv2, 1), + b2.opts().WithName("G").WithControlInputs({recv2, e})); + Node* h = Binary(ops::NodeOut(recv2, 0), e, b2.opts().WithName("H")); + Node* send2 = + SendFromHost("host_compute_channel_F1_O2", {h}, + b2.opts().WithName("outside_compilation_F1_O2_send")); + + Node* s = NoOp(b2.opts() + .WithName("F1_sequencer") + .WithControlInputs({recv1, send1, recv2, send2})); + + Binary(g, call, b2.opts().WithName("J").WithControlInput(s)); + TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); + } + + TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); + TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); +} + +// Test with two functions to transform, each with one outside_compilation +// cluster. 
+TEST(EncapsulateSubgraphsTest, TwoFunctionsTwoOutside) { + FunctionDefLibrary library; + GraphDef graphdef; + + { + GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); + Node* a = InputShaped(b1.opts().WithName("A")); + Node* b = InputShaped(b1.opts().WithName("B")); + Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); + Node* d = + Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1")); + Node* e = Binary(c, d, + b1.opts() + .WithName("E") + .WithControlInputs({b, d}) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* f = Binary(c, e, + b1.opts().WithName("F").WithControlInput(e).WithAttr( + "_encapsulate", "F1")); + Node* g = Binary(e, f, + b1.opts().WithName("G").WithControlInputs({e, f}).WithAttr( + "_encapsulate", "F2")); + Node* h = Binary(d, g, + b1.opts() + .WithName("H") + .WithAttr("_encapsulate", "F2") + .WithAttr("_outside", "O1")); + Node* i = + Binary(f, h, b1.opts().WithName("I").WithAttr("_encapsulate", "F2")); + Binary(g, i, b1.opts().WithName("J")); + TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); + } + + TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + + FunctionDefLibrary library_expected; + GraphDef graphdef_expected; + + string shape_string_expected; + { + GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); + Node* recv = + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, + shape.opts().WithName("outside_compilation_F1_O1_recv")); + Node* e = Binary(ops::NodeOut(recv, 0), ops::NodeOut(recv, 1), + shape.opts().WithName("E")); + SendFromHost("host_compute_channel_F1_O1", {e}, + shape.opts().WithName("outside_compilation_F1_O1_send")); + GraphDef shape_graph; + TF_EXPECT_OK(shape.ToGraphDef(&shape_graph)); + EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected)); + } + + TensorShapeProto shape_proto_expected; + shape_proto_expected.add_dim()->set_size(2); + + *library_expected.add_function() = FunctionDefHelper::Create( + "F1", {"a_0_arg:float", 
"b_0_arg:float"}, + {"f_0_retval:float", "d_0_retval:float"}, {}, + { + {{"C"}, "UnaryTest", {"a_0_arg"}}, + {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, + {{"F"}, + "BinaryTest", + {"C:o:0", "outside_compilation_O1_host_compute:outputs:0"}, + {}, + {"outside_compilation_O1_host_compute"}}, + {{"outside_compilation_O1_host_compute"}, + "_XlaHostCompute", + {"C:o:0", "D:o:0"}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT, DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", shape_string_expected}, + {"shapes", gtl::ArraySlice({})}}, + {"D"}}, + }, + {{"d_0_retval", "D:o:0"}, {"f_0_retval", "F:o:0"}}); + + *library_expected.add_function() = FunctionDefHelper::Create( + "F2", {"e_0_arg:float", "f_0_arg:float"}, + {"g_0_retval:float", "i_0_retval:float"}, {}, + { + {{"G"}, "BinaryTest", {"e_0_arg", "f_0_arg"}}, + {{"I"}, + "BinaryTest", + {"f_0_arg", "outside_compilation_O1_host_compute:outputs:0"}}, + {{"outside_compilation_O1_host_compute"}, + "_XlaHostCompute", + {"G:o:0"}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"key", "host_compute_channel_F2_O1"}, + {"shape_inference_graph", ""}, + {"shapes", + gtl::ArraySlice({shape_proto_expected})}}}, + }, + {{"g_0_retval", "G:o:0"}, {"i_0_retval", "I:o:0"}}); + + { + std::unique_ptr lib_def( + new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); + GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); + Node* a = InputShaped(b2.opts().WithName("A")); + Node* b = InputShaped(b2.opts().WithName("B")); + + Node* recv1 = + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT, DT_FLOAT}, + b2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* e = Binary(ops::NodeOut(recv1, 0), ops::NodeOut(recv1, 1), + b2.opts().WithName("E").WithControlInputs({recv1, b})); + Node* send1 = SendFromHost("host_compute_channel_F1_O1", {e}, + b2.opts() + 
.WithName("outside_compilation_F1_O1_send") + .WithControlInput(e)); + NodeBuilder node_builder1("F1", "F1", lib_def.get()); + node_builder1.Input(a).Input(b); + Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); + Node* s1 = NoOp( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1})); + + Node* recv2 = + RecvAtHost("host_compute_channel_F2_O1", {DT_FLOAT}, + b2.opts().WithName("outside_compilation_F2_O1_recv")); + Node* h = Binary(ops::NodeOut(call1, 1), recv2, + b2.opts().WithName("H").WithControlInput(s1)); + Node* send2 = + SendFromHost("host_compute_channel_F2_O1", {h}, + b2.opts().WithName("outside_compilation_F2_O1_send")); + + NodeBuilder node_builder2("F2", "F2", lib_def.get()); + node_builder2.Input(e).Input(call1); + Node* call2 = b2.opts() + .WithControlInputs({s1, e, call1}) + .FinalizeBuilder(&node_builder2); + Node* s2 = NoOp( + b2.opts().WithName("F2_sequencer").WithControlInputs({recv2, send2})); + Binary(call2, ops::NodeOut(call2, 1), + b2.opts().WithName("J").WithControlInput(s2)); + TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); + } + + TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); + TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); +} + +// Test with one outside_compilation cluster that has no inputs from the +// compiled subgraph. 
+TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputs) { + FunctionDefLibrary library; + GraphDef graphdef; + + { + GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); + Node* a = InputShaped(b1.opts().WithName("A")); + Node* b = Input(b1.opts().WithName("B")); + Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); + Node* d = + Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1")); + Node* e = Unary(a, b1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* f = + Binary(d, e, b1.opts().WithName("F").WithAttr("_encapsulate", "F1")); + Unary(f, b1.opts().WithName("G")); + TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); + } + + TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + + FunctionDefLibrary library_expected; + GraphDef graphdef_expected; + + TensorShapeProto shape_proto_expected; + shape_proto_expected.add_dim()->set_size(2); + + *library_expected.add_function() = FunctionDefHelper::Create( + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {}, + { + {{"C"}, "UnaryTest", {"a_0_arg"}}, + {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, + {{"F"}, + "BinaryTest", + {"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}}, + {{"outside_compilation_O1_host_compute"}, + "_XlaHostCompute", + {}, + {{"Tinputs", gtl::ArraySlice({})}, + {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", ""}, + {"shapes", + gtl::ArraySlice({shape_proto_expected})}}}, + }, + {{"f_0_retval", "F:o:0"}}); + + { + std::unique_ptr lib_def( + new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); + GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); + Node* a = InputShaped(b2.opts().WithName("A")); + Node* b = Input(b2.opts().WithName("B")); + + Node* e = Unary(a, b2.opts().WithName("E")); + Node* send1 = + SendFromHost("host_compute_channel_F1_O1", {e}, + 
b2.opts().WithName("outside_compilation_F1_O1_send")); + NodeBuilder node_builder1("F1", "F1", lib_def.get()); + node_builder1.Input(a).Input(b); + Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); + Node* s1 = NoOp(b2.opts().WithName("F1_sequencer").WithControlInput(send1)); + + Unary(call1, b2.opts().WithName("G").WithControlInput(s1)); + TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); + } + + TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); + TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); +} + +// Test with one outside_compilation cluster that has no data inputs but has a +// control input from the compiled subgraph. +TEST(EncapsulateSubgraphsTest, OutsideCompilationControlInput) { + FunctionDefLibrary library; + GraphDef graphdef; + + { + GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); + Node* a = InputShaped(b1.opts().WithName("A")); + Node* b = Input(b1.opts().WithName("B")); + Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); + Node* d = + Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1")); + Node* e = Unary(a, b1.opts() + .WithName("E") + .WithControlInput(d) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* f = + Binary(d, e, b1.opts().WithName("F").WithAttr("_encapsulate", "F1")); + Unary(f, b1.opts().WithName("G")); + TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); + } + + TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + + FunctionDefLibrary library_expected; + GraphDef graphdef_expected; + + TensorShapeProto shape_proto_expected; + shape_proto_expected.add_dim()->set_size(2); + + *library_expected.add_function() = FunctionDefHelper::Create( + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {}, + { + {{"C"}, "UnaryTest", {"a_0_arg"}}, + {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, + {{"F"}, + "BinaryTest", + {"D:o:0", "outside_compilation_O1_host_compute:outputs:0"}}, + {{"outside_compilation_O1_host_compute"}, + "_XlaHostCompute", + {}, + 
{{"Tinputs", gtl::ArraySlice({})}, + {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", ""}, + {"shapes", + gtl::ArraySlice({shape_proto_expected})}}, + {"D"}}, + }, + {{"f_0_retval", "F:o:0"}}); + + { + std::unique_ptr lib_def( + new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); + GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); + Node* a = InputShaped(b2.opts().WithName("A")); + Node* b = Input(b2.opts().WithName("B")); + + Node* recv1 = + RecvAtHost("host_compute_channel_F1_O1", {}, + b2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* e = Unary(a, b2.opts().WithName("E").WithControlInput(recv1)); + Node* send1 = + SendFromHost("host_compute_channel_F1_O1", {e}, + b2.opts().WithName("outside_compilation_F1_O1_send")); + NodeBuilder node_builder1("F1", "F1", lib_def.get()); + node_builder1.Input(a).Input(b); + Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); + Node* s1 = NoOp( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1})); + + Unary(call1, b2.opts().WithName("G").WithControlInput(s1)); + TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); + } + + TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); + TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); +} + +// Test with one outside_compilation cluster that has no outputs to the +// compiled subgraph. 
+TEST(EncapsulateSubgraphsTest, OutsideCompilationNoOutputs) { + FunctionDefLibrary library; + GraphDef graphdef; + + { + GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); + Node* a = Input(b1.opts().WithName("A")); + Node* b = Input(b1.opts().WithName("B")); + Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); + Node* d = + Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1")); + Node* e = Unary(d, b1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* f = Unary(d, b1.opts().WithName("F").WithAttr("_encapsulate", "F1")); + Binary(e, f, b1.opts().WithName("G")); + TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); + } + + TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + + FunctionDefLibrary library_expected; + GraphDef graphdef_expected; + + *library_expected.add_function() = FunctionDefHelper::Create( + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {}, + { + {{"C"}, "UnaryTest", {"a_0_arg"}}, + {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, + {{"F"}, "UnaryTest", {"D:o:0"}}, + {{"outside_compilation_O1_host_compute"}, + "_XlaHostCompute", + {"D:o:0"}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", ""}, + {"shapes", gtl::ArraySlice({})}}}, + }, + {{"f_0_retval", "F:o:0"}}); + + { + std::unique_ptr lib_def( + new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); + GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); + Node* a = Input(b2.opts().WithName("A")); + Node* b = Input(b2.opts().WithName("B")); + + Node* recv1 = + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT}, + b2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* e = Unary(recv1, b2.opts().WithName("E")); + NodeBuilder node_builder1("F1", "F1", lib_def.get()); + node_builder1.Input(a).Input(b); + Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); + Node* 
s1 = NoOp(b2.opts().WithName("F1_sequencer").WithControlInput(recv1)); + + Binary(e, call1, b2.opts().WithName("G").WithControlInput(s1)); + TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); + } + + TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); + TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); +} + +// Test with one outside_compilation cluster that has no data outputs but has a +// control output to the compiled subgraph. +TEST(EncapsulateSubgraphsTest, OutsideCompilationControlOutput) { + FunctionDefLibrary library; + GraphDef graphdef; + + { + GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); + Node* a = Input(b1.opts().WithName("A")); + Node* b = Input(b1.opts().WithName("B")); + Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); + Node* d = + Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1")); + Node* e = Unary(d, b1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* f = Unary(d, b1.opts().WithName("F").WithControlInput(e).WithAttr( + "_encapsulate", "F1")); + Binary(e, f, b1.opts().WithName("G")); + TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); + } + + TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + + FunctionDefLibrary library_expected; + GraphDef graphdef_expected; + + *library_expected.add_function() = FunctionDefHelper::Create( + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {}, + { + {{"C"}, "UnaryTest", {"a_0_arg"}}, + {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, + {{"F"}, + "UnaryTest", + {"D:o:0"}, + {}, + {"outside_compilation_O1_host_compute"}}, + {{"outside_compilation_O1_host_compute"}, + "_XlaHostCompute", + {"D:o:0"}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", ""}, + {"shapes", gtl::ArraySlice({})}}}, + }, + {{"f_0_retval", "F:o:0"}}); + + { + std::unique_ptr lib_def( + new 
FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); + GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); + Node* a = Input(b2.opts().WithName("A")); + Node* b = Input(b2.opts().WithName("B")); + + Node* recv1 = + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT}, + b2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* e = Unary(recv1, b2.opts().WithName("E")); + Node* send1 = SendFromHost("host_compute_channel_F1_O1", {}, + b2.opts() + .WithName("outside_compilation_F1_O1_send") + .WithControlInput(e)); + NodeBuilder node_builder1("F1", "F1", lib_def.get()); + node_builder1.Input(a).Input(b); + Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); + Node* s1 = NoOp( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv1, send1})); + + Binary(e, call1, b2.opts().WithName("G").WithControlInput(s1)); + TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); + } + + TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); + TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); +} + +// Test with one outside_compilation cluster that has neither inputs from nor +// outputs to the compiled subgraph. 
+TEST(EncapsulateSubgraphsTest, OutsideCompilationNoInputsOrOutputs) { + FunctionDefLibrary library; + GraphDef graphdef; + + { + GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); + Node* a = Input(b1.opts().WithName("A")); + Node* b = Input(b1.opts().WithName("B")); + Node* c = Unary(a, b1.opts().WithName("C").WithAttr("_encapsulate", "F1")); + Node* d = + Binary(b, c, b1.opts().WithName("D").WithAttr("_encapsulate", "F1")); + Node* e = Unary(a, b1.opts() + .WithName("E") + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* f = Unary(d, b1.opts().WithName("F").WithAttr("_encapsulate", "F1")); + Binary(e, f, b1.opts().WithName("G")); + TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); + } + + TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + + FunctionDefLibrary library_expected; + GraphDef graphdef_expected; + + *library_expected.add_function() = FunctionDefHelper::Create( + "F1", {"a_0_arg:float", "b_0_arg:float"}, {"f_0_retval:float"}, {}, + { + {{"C"}, "UnaryTest", {"a_0_arg"}}, + {{"D"}, "BinaryTest", {"b_0_arg", "C:o:0"}}, + {{"F"}, "UnaryTest", {"D:o:0"}}, + }, + {{"f_0_retval", "F:o:0"}}); + + { + std::unique_ptr lib_def( + new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); + GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); + Node* a = Input(b2.opts().WithName("A")); + Node* b = Input(b2.opts().WithName("B")); + + Node* e = Unary(a, b2.opts().WithName("E")); + NodeBuilder node_builder1("F1", "F1", lib_def.get()); + node_builder1.Input(a).Input(b); + Node* call1 = b2.opts().FinalizeBuilder(&node_builder1); + + Binary(e, call1, b2.opts().WithName("G")); + TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); + } + + TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); + TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); +} + +// Test for shape inference of outside compilation. 
+TEST(EncapsulateSubgraphsTest, OutsideCompilationShapeInference) { + FunctionDefLibrary library; + GraphDef graphdef; + + { + *library.add_function() = test::function::XTimesTwo(); + + GraphDefBuilder b1(GraphDefBuilder::kFailImmediately); + Node* a = InputShaped(b1.opts().WithName("A")); + Node* b = Input(b1.opts().WithName("B")); + // Give nodes 'c' and 'd' names that collide after lowercasing. + Node* c = Unary(a, b1.opts().WithName("C")); + Node* d = Unary(b, b1.opts().WithName("c").WithControlInput(c).WithAttr( + "_encapsulate", "F1")); + Node* e = BinaryUnknownShape(c, d, + b1.opts() + .WithName("E") + .WithControlInputs({b, d}) + .WithAttr("_encapsulate", "F1") + .WithAttr("_outside", "O1")); + Node* f = Binary(c, e, + b1.opts().WithName("F").WithControlInput(e).WithAttr( + "_encapsulate", "F1")); + Binary(a, f, b1.opts().WithName("G").WithControlInput(e)); + TF_EXPECT_OK(b1.ToGraphDef(&graphdef)); + } + + TF_EXPECT_OK(Encapsulate(&graphdef, &library)); + + FunctionDefLibrary library_expected; + GraphDef graphdef_expected; + + string shape_string_expected; + { + GraphDefBuilder shape(GraphDefBuilder::kFailImmediately); + Node* known = KnownShape({2}, shape.opts().WithName("KnownShape/_0")); + Node* recv = + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT}, + shape.opts().WithName("outside_compilation_F1_O1_recv")); + Node* e = BinaryUnknownShape(known, recv, shape.opts().WithName("E")); + SendFromHost("host_compute_channel_F1_O1", {e}, + shape.opts().WithName("outside_compilation_F1_O1_send")); + GraphDef shape_graph; + TF_EXPECT_OK(shape.ToGraphDef(&shape_graph)); + EXPECT_TRUE(shape_graph.SerializeToString(&shape_string_expected)); + } + + *library_expected.add_function() = test::function::XTimesTwo(); + *library_expected.add_function() = FunctionDefHelper::Create( + "F1", {"b_0_arg:float", "c_0_arg:float"}, {"f_0_retval:float"}, {}, + { + {{"c"}, "UnaryTest", {"b_0_arg"}, {}, {}}, + {{"F"}, + "BinaryTest", + {"c_0_arg", 
"outside_compilation_O1_host_compute:outputs:0"}, + {}, + {"outside_compilation_O1_host_compute"}}, + {{"outside_compilation_O1_host_compute"}, + "_XlaHostCompute", + {"c:o:0"}, + {{"Tinputs", gtl::ArraySlice({DT_FLOAT})}, + {"Toutputs", gtl::ArraySlice({DT_FLOAT})}, + {"key", "host_compute_channel_F1_O1"}, + {"shape_inference_graph", shape_string_expected}, + {"shapes", gtl::ArraySlice({})}}, + {"c"}}, + }, + {{"f_0_retval", "F:o:0"}}); + + { + std::unique_ptr lib_def( + new FunctionLibraryDefinition(OpRegistry::Global(), library_expected)); + GraphDefBuilder b2(GraphDefBuilder::kFailImmediately, lib_def.get()); + Node* a = InputShaped(b2.opts().WithName("A")); + Node* b = Input(b2.opts().WithName("B")); + Node* c = Unary(a, b2.opts().WithName("C")); + + NodeBuilder node_builder("F1", "F1", lib_def.get()); + node_builder.Input(b).Input(c); + Node* call = + b2.opts().WithControlInputs({c}).FinalizeBuilder(&node_builder); + + Node* recv = + RecvAtHost("host_compute_channel_F1_O1", {DT_FLOAT}, + b2.opts().WithName("outside_compilation_F1_O1_recv")); + Node* e = BinaryUnknownShape( + c, ops::NodeOut(recv, 0), + b2.opts().WithName("E").WithControlInputs({recv, b})); + Node* send = SendFromHost("host_compute_channel_F1_O1", {e}, + b2.opts() + .WithName("outside_compilation_F1_O1_send") + .WithControlInput(e)); + + Node* s = NoOp( + b2.opts().WithName("F1_sequencer").WithControlInputs({recv, send})); + + Binary(a, call, b2.opts().WithName("G").WithControlInputs({s, e})); + TF_EXPECT_OK(b2.ToGraphDef(&graphdef_expected)); + } + + TF_EXPECT_GRAPH_EQ(graphdef_expected, graphdef); + TF_EXPECT_FUNCTIONDEFLIBRARY_EQ(library_expected, library); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index 459a582e157f5ddc63997ca93e7c0294293517d3..9bea5663319c8a25249fdc265cee0191556a7c04 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -16,7 +16,6 @@ 
cc_library( "//tensorflow/compiler/jit:xla_device", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/tf2xla:xla_local_runtime_context", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:local_client", diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index e481796d9e626fc8cdf36687ad110b0a8a788be0..6353149e4afdf739fe44dd5c76502ef5d98b8477 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -19,7 +19,6 @@ limitations under the License. #include "tensorflow/compiler/jit/xla_device.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/local_client.h" @@ -46,7 +45,7 @@ namespace tensorflow { // see comment on `AllowsAsynchronousDeallocation()`. 
class XlaAllocator : public xla::DeviceMemoryAllocator { public: - XlaAllocator(gpu::Platform* platform, OpKernelContext* op_context); + XlaAllocator(const gpu::Platform* platform, OpKernelContext* op_context); ~XlaAllocator() override; xla::StatusOr Allocate(int device_ordinal, uint64 size, bool retry_on_failure) override; @@ -80,7 +79,8 @@ class XlaAllocator : public xla::DeviceMemoryAllocator { std::unordered_map tensors_; }; -XlaAllocator::XlaAllocator(gpu::Platform* platform, OpKernelContext* op_context) +XlaAllocator::XlaAllocator(const gpu::Platform* platform, + OpKernelContext* op_context) : xla::DeviceMemoryAllocator(platform), op_context_(op_context) {} XlaAllocator::~XlaAllocator() = default; @@ -103,7 +103,6 @@ xla::StatusOr XlaAllocator::Allocate( } void* data = reinterpret_cast(const_cast(t.tensor_data().data())); - TF_RET_CHECK(data != nullptr); tensors_[data] = t; return gpu::DeviceMemoryBase(data, size); } @@ -111,7 +110,6 @@ xla::StatusOr XlaAllocator::Allocate( Status XlaAllocator::RegisterArgument(const Tensor* t) { void* data = reinterpret_cast(const_cast(t->tensor_data().data())); - TF_RET_CHECK(data != nullptr); tensors_[data] = *t; return Status::OK(); } @@ -251,24 +249,26 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { xla::LocalClient* client = static_cast(cache->client()); + // Builds an XLA allocator for the device. 
+ XlaAllocator xla_allocator(client->platform(), ctx); + XlaCompiler::Options options; options.client = client; options.device_type = &cache->device_type(); options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); options.graph_def_version = ctx->function_library()->graph_def_version(); options.allow_cpu_custom_calls = (platform_id_ == gpu::host::kHostPlatformId); + options.device_allocator = &xla_allocator; const XlaCompiler::CompilationResult* kernel; xla::LocalExecutable* executable; + OP_REQUIRES_OK(ctx, cache->Compile(options, function_, num_constant_args_, - variables, ctx, &kernel, &executable)); + variables, ctx, &kernel, &executable, + /*compile_options=*/nullptr)); VLOG(1) << "Executing XLA Computation..."; - // Builds an XLA allocator for the device. - XlaAllocator xla_allocator(client->platform(), ctx); - XlaLocalRuntimeContext local_runtime_context; - std::unique_ptr output; // Build xla::ShapedBuffers that point directly to the Tensor buffers. std::vector> arg_buffers; @@ -291,27 +291,22 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { gpu::DeviceMemoryBase dmem = gpu::DeviceMemoryBase( const_cast(t->tensor_data().data()), t->tensor_data().size()); - arg_buffers[i] = - xla::ShapedBuffer::MakeArrayShapedBuffer( - shape, client->platform(), client->default_device_ordinal(), dmem) - .ConsumeValueOrDie(); + const xla::Shape on_device_shape = + client->backend().transfer_manager()->HostShapeToDeviceShape(shape); + CHECK(xla::ShapeUtil::Equal(shape, on_device_shape)) + << "On-device shape " + << xla::ShapeUtil::HumanStringWithLayout(on_device_shape) + << " not the same as on-host shape " + << xla::ShapeUtil::HumanStringWithLayout(shape); + arg_buffers[i] = xla::MakeUnique( + /*on_host_shape=*/shape, /*on_device_shape=*/shape, client->platform(), + client->default_device_ordinal()); + arg_buffers[i]->set_buffer(dmem, /*index=*/{}); arg_ptrs[i] = arg_buffers[i].get(); OP_REQUIRES_OK(ctx, xla_allocator.RegisterArgument(t)); } - // 
Make the final parameter point at local_runtime_context. - if (kernel->requires_runtime_context) { - gpu::DeviceMemoryBase local_runtime_context_dmem( - &local_runtime_context, sizeof(local_runtime_context)); - arg_buffers.push_back( - xla::ShapedBuffer::MakeArrayShapedBuffer( - xla::ShapeUtil::MakeOpaqueShape(), client->platform(), - client->default_device_ordinal(), local_runtime_context_dmem) - .ConsumeValueOrDie()); - arg_ptrs.push_back(arg_buffers.back().get()); - } - // Execute the computation. VLOG(2) << "Executing computation."; xla::ExecutableRunOptions run_options; @@ -323,19 +318,13 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { auto run_result = executable->Run(arg_ptrs, run_options); OP_REQUIRES(ctx, run_result.ok(), run_result.status()); - if (local_runtime_context.error) { - ctx->CtxFailure(errors::InvalidArgument("Compiled kernel returned error: ", - local_runtime_context.error_msg)); - return; - } - output = run_result.ConsumeValueOrDie()->release(); auto elapsed = env->NowMicros() - start_time; VLOG(2) << "Elapsed time: " << elapsed << "us"; // Computation output should always be a tuple. if (VLOG_IS_ON(2)) { - VLOG(2) << "Result tuple shape: " << output->shape().DebugString(); + VLOG(2) << "Result tuple shape: " << output->on_host_shape().DebugString(); } CHECK_EQ(ctx->num_outputs(), kernel->outputs.size()); @@ -387,8 +376,6 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { OP_REQUIRES(ctx, write.input_index >= 0 && write.input_index < ctx->num_inputs(), errors::Internal("Invalid input index for variable write.")); - TensorShape write_shape; - OP_REQUIRES_OK(ctx, XLAShapeToTensorShape(write.shape, &write_shape)); gpu::DeviceMemoryBase buffer = output->buffer({output_num}); @@ -410,7 +397,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { // Looks up the owning Tensor by buffer address. 
OP_REQUIRES_OK( - ctx, xla_allocator.MakeTensorFromBuffer(buffer, write.type, write_shape, + ctx, xla_allocator.MakeTensorFromBuffer(buffer, write.type, write.shape, variable->tensor())); ++output_num; } diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 74c9791f5eaf1fbc43b152520df496a3b552af18..a0211acbbe9eec77d30c7d14293650de8826f41c 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -41,6 +41,7 @@ limitations under the License. namespace tensorflow { const char* const kXlaClusterAttr = "_XlaCluster"; +const char* const kXlaOutsideCompilationAttr = "_XlaOutsideCompilation"; namespace { @@ -172,10 +173,15 @@ bool HasResourceInputOrOutput(const Node& node) { DT_RESOURCE) != node.output_types().end(); } +struct NodeCompare { + bool operator()(const Node* a, const Node* b) { return a->id() < b->id(); } +}; +using OrderedNodeSet = std::set; + Status FindCompilationCandidates( const Graph& graph, FunctionLibraryDefinition* flib_def, Env* env, const std::function& is_compilable_fn, - std::unordered_set* candidates) { + OrderedNodeSet* candidates) { OptimizerOptions opts; std::unique_ptr pflr( new ProcessFunctionLibraryRuntime(nullptr, env, TF_GRAPH_DEF_VERSION, @@ -184,6 +190,9 @@ Status FindCompilationCandidates( pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice); for (Node* node : graph.op_nodes()) { + VLOG(2) << "FindCompilationCandidates(): Processing " + << node->DebugString(); + DeviceType device_type(""); TF_RETURN_IF_ERROR( DeviceTypeOfDevice(node->assigned_device_name(), &device_type)); @@ -210,6 +219,20 @@ Status FindCompilationCandidates( !IsCompilableWhile(*node, jit_device_type, 0, lib_runtime)) { continue; } + // _Arg nodes in a top-level function represent feeds. + // Do not compile them. 
+ if (node->type_string() == "_Arg") { + VLOG(2) << "Skipping jit compilation for '_Arg'-typed node " + << node->DebugString(); + continue; + } + // _Retval nodes in a top-level function represent fetches. + // Do not compile them. + if (node->type_string() == "_Retval") { + VLOG(2) << "Compilation rejected node: return value " << node->name() + << ": " << node->type_string(); + continue; + } candidates->insert(node); } return Status::OK(); @@ -291,6 +314,7 @@ Status MarkForCompilationPass::Run( static_cast(flags->tf_xla_auto_jit); } bool cpu_global_jit = flags->tf_xla_cpu_global_jit; + VLOG(1) << "flags->tf_xla_cpu_global_jit = " << flags->tf_xla_cpu_global_jit; const FunctionLibraryDefinition* fld = options.flib_def; auto is_compilable = [global_jit_level, cpu_global_jit, fld]( @@ -347,7 +371,7 @@ Status MarkForCompilationPass::RunImpl( Graph* graph = options.graph->get(); - std::unordered_set compilation_candidates; + OrderedNodeSet compilation_candidates; TF_RETURN_IF_ERROR(FindCompilationCandidates( *graph, options.flib_def, (options.session_options != nullptr) ? options.session_options->env diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.h b/tensorflow/compiler/jit/mark_for_compilation_pass.h index f91695800f585f37b72173d5e582c38b1154b69b..e9acbfb19e42cb43cb0b986c438a569de29b2ebc 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.h +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.h @@ -28,6 +28,10 @@ namespace tensorflow { // encapsulate subgraphs pass. extern const char* const kXlaClusterAttr; +// The attribute that marks nodes in a cluster to be placed outside the xla +// compilation by the encapsulate subgraphs pass. +extern const char* const kXlaOutsideCompilationAttr; + // Pass that marks a subset of operators in the graph with attribute // _XlaCluster so they are compiled by the EncapsulateSubgraphsPass. 
class MarkForCompilationPass : public GraphOptimizationPass { diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index b3d258aea177fbefa4bae51d8156da2ff86c9032..1a8858cccef623185709ab5dc2187a313dd130f7 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/graph_def_builder.h" +#include "tensorflow/core/graph/graph_def_builder_util.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -80,7 +81,7 @@ TEST(XlaCompilationTest, Chains) { ops::UnaryOp("UncompilableUnary", c, builder.opts().WithName("D")); Node* e = ops::UnaryOp("Relu", d, builder.opts().WithName("E")); ops::UnaryOp("Relu", e, builder.opts().WithName("F")); - TF_EXPECT_OK(builder.ToGraph(graph.get())); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } TF_ASSERT_OK(MarkForCompilation(&graph)); @@ -105,7 +106,7 @@ TEST(XlaCompilationTest, UncompilableCycles) { Node* b = ops::UnaryOp("UncompilableUnary", a, builder.opts().WithName("B")); ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C")); - TF_EXPECT_OK(builder.ToGraph(graph.get())); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } TF_ASSERT_OK(MarkForCompilation(&graph)); @@ -125,7 +126,7 @@ TEST(XlaCompilationTest, CompilableCycles) { .WithAttr("value", Tensor())); Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B")); ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C")); - TF_EXPECT_OK(builder.ToGraph(graph.get())); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } TF_ASSERT_OK(MarkForCompilation(&graph)); @@ -148,7 +149,7 @@ TEST(XlaCompilationTest, UnsupportedTypes) { .WithAttr("value", 
Tensor(DT_COMPLEX128, TensorShape()))); Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B")); ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C")); - TF_EXPECT_OK(builder.ToGraph(graph.get())); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } TF_ASSERT_OK(MarkForCompilation(&graph)); @@ -177,7 +178,7 @@ TEST(XlaCompilationTest, ConcatWithConstArg) { concat_builder.Input(dim).Input({a, a}).Attr("N", 2); builder.opts().FinalizeBuilder(&concat_builder); - TF_EXPECT_OK(builder.ToGraph(graph.get())); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } TF_ASSERT_OK(MarkForCompilation(&graph)); @@ -212,7 +213,7 @@ TEST(XlaCompilationTest, FunctionCalls) { Node* c = ops::UnaryOp("Relu", b, builder.opts().WithName("C")); ops::UnaryOp("UncompilableFn", c, builder.opts().WithName("D")); ops::BinaryOp("NoInlineFn", c, c, builder.opts().WithName("E")); - TF_EXPECT_OK(builder.ToGraph(graph.get())); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } TF_ASSERT_OK(MarkForCompilation(&graph, &flib_def)); @@ -244,7 +245,7 @@ TEST(XlaCompilationTest, MetadataOpsDontStartClusters) { Node* c = ops::UnaryOp("Rank", b, builder.opts().WithName("C")); Node* d = ops::UnaryOp("Size", c, builder.opts().WithName("D")); ops::UnaryOp("Shape", d, builder.opts().WithName("E")); - TF_EXPECT_OK(builder.ToGraph(graph.get())); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } TF_ASSERT_OK(MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); @@ -330,7 +331,7 @@ TEST(XlaCompilationTest, SymbolicGradients) { d_builder.Input({c, c}); builder.opts().FinalizeBuilder(&d_builder); - TF_EXPECT_OK(builder.ToGraph(graph.get())); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } TF_ASSERT_OK(MarkForCompilation(&graph)); @@ -382,7 +383,7 @@ TEST(XlaCompilationTest, CyclesWithAllDifferentScopes) { ops::BinaryOp( "MatMul", a, b, builder.opts().WithName("C").WithAttr(kXlaScopeAttr, "ScopeC")); - 
TF_CHECK_OK(builder.ToGraph(graph.get())); + TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get())); } TF_ASSERT_OK(MarkForCompilation(&graph)); @@ -413,7 +414,7 @@ TEST(XlaCompilationTest, CyclesWithSplittingScopes) { ops::BinaryOp( "Add", b, c, builder.opts().WithName("D").WithAttr(kXlaScopeAttr, "Scope2")); - TF_CHECK_OK(builder.ToGraph(graph.get())); + TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get())); } TF_ASSERT_OK(MarkForCompilation(&graph)); @@ -443,7 +444,7 @@ TEST(XlaCompilationTest, CyclesWithDifferentScopesAndBridge) { "Relu", a, builder.opts().WithName("B").WithAttr(kXlaScopeAttr, "ScopeB")); ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C")); - TF_CHECK_OK(builder.ToGraph(graph.get())); + TF_CHECK_OK(GraphDefBuilderToGraph(builder, graph.get())); } TF_ASSERT_OK(MarkForCompilation(&graph)); @@ -484,7 +485,7 @@ TEST(XlaCompilationTest, Resources) { Node* c = ops::UnaryOp("ResourceOutput", b, builder.opts().WithName("C")); Node* d = ops::UnaryOp("ResourceInput", c, builder.opts().WithName("D")); ops::UnaryOp("Relu", d, builder.opts().WithName("E")); - TF_EXPECT_OK(builder.ToGraph(graph.get())); + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); } TF_ASSERT_OK(MarkForCompilation(&graph)); auto clusters = GetClusters(*graph); @@ -525,5 +526,32 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) { "+-- c\n")); } +TEST(XlaCompilationTest, Retval) { + std::unique_ptr graph(new Graph(OpRegistry::Global())); + GraphDef graphdef; + { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + Node* a = ops::SourceOp("Const", builder.opts() + .WithName("A") + .WithAttr("dtype", DT_FLOAT) + .WithAttr("value", Tensor())); + Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B")); + ops::UnaryOp("_Retval", b, + builder.opts() + .WithName("R") + .WithAttr("T", DT_FLOAT) + .WithAttr("index", 0)); + + TF_EXPECT_OK(GraphDefBuilderToGraph(builder, graph.get())); + } + + TF_ASSERT_OK(MarkForCompilation(&graph)); + auto 
clusters = GetClusters(*graph); + + EXPECT_EQ(2, clusters.size()); + EXPECT_TRUE(clusters.find("R") == clusters.cend()); + EXPECT_EQ(clusters["A"], clusters["B"]); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index bc2eccd2779b9ff68ae2121f7bc53d6f74aec3e3..6d854a920eb0b4c01b09024ceaef5035e847d392 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -148,8 +148,7 @@ Status BuildArguments(int num_constant_args, XlaCompiler::Argument& arg = (*args)[input_num]; arg.kind = XlaCompiler::Argument::kConstant; arg.type = input.dtype(); - TF_RETURN_IF_ERROR( - TensorShapeToXLAShape(input.dtype(), input.shape(), &arg.shape)); + arg.shape = input.shape(); arg.constant_value = input; ++input_num; } @@ -170,8 +169,7 @@ Status BuildArguments(int num_constant_args, arg.constant_value = input; } arg.type = input.dtype(); - TF_RETURN_IF_ERROR( - TensorShapeToXLAShape(input.dtype(), input.shape(), &arg.shape)); + arg.shape = input.shape(); ++input_num; } @@ -189,8 +187,7 @@ Status BuildArguments(int num_constant_args, if (variable_args[variable_id].present) { const Tensor& value = variable_args[variable_id].value; arg.type = value.dtype(); - TF_RETURN_IF_ERROR( - TensorShapeToXLAShape(value.dtype(), value.shape(), &arg.shape)); + arg.shape = value.shape(); arg.initialized = true; } else { // The values of uninitialized variables are not passed as inputs, since @@ -199,7 +196,7 @@ Status BuildArguments(int num_constant_args, // uninitialized variables. 
arg.initialized = false; arg.type = DT_INVALID; - arg.shape = xla::Shape(); + arg.shape = TensorShape(); } ++input_num; } @@ -214,20 +211,16 @@ Status XlaCompilationCache::BuildExecutable( const XlaCompiler::CompilationResult& result, std::unique_ptr* executable) { VLOG(2) << "Compiling to local executable"; - xla::Shape opaque_shape = xla::ShapeUtil::MakeOpaqueShape(); std::vector argument_layouts( result.xla_input_shapes.size()); for (int i = 0; i < result.xla_input_shapes.size(); ++i) { argument_layouts[i] = &result.xla_input_shapes[i]; } - if (result.requires_runtime_context) { - // The final arg is the XlaLocalRuntimeContext*. - argument_layouts.push_back(&opaque_shape); - } xla::ExecutableBuildOptions build_options; build_options.set_device_ordinal(client_->default_device_ordinal()); build_options.set_result_layout(result.xla_output_shape); + build_options.set_device_allocator(options.device_allocator); auto compile_result = client_->Compile(*result.computation, argument_layouts, build_options); @@ -243,7 +236,8 @@ Status XlaCompilationCache::Compile( int num_constant_args, const std::vector& variable_args, OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, - xla::LocalExecutable** executable) { + xla::LocalExecutable** executable, + const XlaCompiler::CompileOptions* compile_options) { VLOG(1) << "XlaCompilationCache::Compile " << DebugString(); if (VLOG_IS_ON(2)) { @@ -302,9 +296,9 @@ Status XlaCompilationCache::Compile( XlaCompiler compiler(options); entry->compiled = true; - entry->compilation_status = - compiler.CompileFunction(XlaCompiler::CompileOptions(), function, args, - &entry->compilation_result); + entry->compilation_status = compiler.CompileFunction( + compile_options ? 
*compile_options : XlaCompiler::CompileOptions(), + function, args, &entry->compilation_result); } *compilation_result = &entry->compilation_result; if (entry->compilation_status.ok() && executable) { diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h index c3a8f68a157a2d34d4a6716c9951b2b698aead79..0858020716fcf4763e42dc0699ad22cfda756942 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.h +++ b/tensorflow/compiler/jit/xla_compilation_cache.h @@ -66,7 +66,8 @@ class XlaCompilationCache : public ResourceBase { const std::vector& variable_args, OpKernelContext* ctx, const XlaCompiler::CompilationResult** compilation_result, - xla::LocalExecutable** executable); + xla::LocalExecutable** executable, + const XlaCompiler::CompileOptions* compile_options); xla::LocalClient* client() const { return client_; } const DeviceType& device_type() const { return device_type_; } diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index fed2c92d763c33aad3c5b3f07c1f33364c797793..c936222f32056e92efced82d5adb3a96c8041a17 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -71,12 +71,14 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor, void* dst_ptr = DMAHelper::base(device_tensor); se::DeviceMemoryBase dev_dst_ptr(dst_ptr, total_bytes); - Status status = Status::OK(); + Status status; stream_->ThenMemcpy(&dev_dst_ptr, src_ptr, total_bytes); // TODO(hpucha): Make this asynchronous. 
- if (!stream_->BlockHostUntilDone()) { + Status block_status = stream_->BlockHostUntilDone(); + if (!block_status.ok()) { status = xla::InternalError( - "Failed to complete data transfer on stream %p", stream_); + "Failed to complete data transfer on stream %p: %s", stream_, + block_status.error_message().c_str()); } done(status); @@ -105,12 +107,14 @@ void XlaTransferManager::CopyDeviceTensorToCPU(const Tensor* device_tensor, se::DeviceMemoryBase dev_src_ptr(src_ptr, total_bytes); void* dst_ptr = DMAHelper::base(cpu_tensor); - Status status = Status::OK(); + Status status; stream_->ThenMemcpy(dst_ptr, dev_src_ptr, total_bytes); // TODO(hpucha): Make this asynchronous. - if (!stream_->BlockHostUntilDone()) { + Status block_status = stream_->BlockHostUntilDone(); + if (!block_status.ok()) { status = xla::InternalError( - "Failed to complete data transfer on stream %p", stream_); + "Failed to complete data transfer on stream %p: %s", stream_, + block_status.error_message().c_str()); } done(status); diff --git a/tensorflow/compiler/plugin/BUILD b/tensorflow/compiler/plugin/BUILD index c1edf2448c54ffddd7b70dcdfb1609080ca81b65..da4bc44c7a75c9f8faf16c537a17a1f2d16d5d61 100644 --- a/tensorflow/compiler/plugin/BUILD +++ b/tensorflow/compiler/plugin/BUILD @@ -41,6 +41,15 @@ cc_library( ], ) +# This target is added purely for the purpose of ensuring that `:xla_device` is +# always publicly visible to external XLA backend/plugin developers. 
+cc_library( + name = "plugin_device", + deps = [ + "//tensorflow/compiler/jit:xla_device", + ], +) + #----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index 6cad2b0824d86a9549cb77518448a7e4eb781bef..782bf82d4149968d5e5fbfb93bbd4ff1dcd75494 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -144,6 +144,21 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "matrix_triangular_solve_op_test", + size = "small", + srcs = ["matrix_triangular_solve_op_test.py"], + tags = ["optonly"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:training", + ], +) + tf_xla_py_test( name = "clustering_test", size = "small", @@ -240,6 +255,35 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "extract_image_patches_op_test", + size = "small", + srcs = ["extract_image_patches_op_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform_test", + ], +) + +tf_xla_py_test( + name = "fft_test", + size = "medium", + srcs = ["fft_test.py"], + shard_count = 3, + tags = ["optonly"], + deps = [ + ":xla_test", + "//tensorflow/contrib/signal:signal_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:extra_py_tests_deps", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform_test", + "//tensorflow/python:spectral_ops", + ], +) + tf_xla_py_test( name = "slice_ops_test", size = "small", @@ -279,6 +323,22 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "image_ops_test", + size = "small", + srcs = ["image_ops_test.py"], + tags = [ + "optonly", # Times out frequently in fastbuild mode. 
+ ], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:image_ops", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "lrn_ops_test", size = "medium", @@ -293,6 +353,19 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "matrix_band_part_test", + size = "medium", + srcs = ["matrix_band_part_test.py"], + tags = ["optonly"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "momentum_test", size = "small", @@ -367,7 +440,9 @@ tf_xla_py_test( size = "small", srcs = ["random_ops_test.py"], # TODO(b/31361304): enable RNG ops on GPU when parallelized. - disabled_backends = ["gpu"], + disabled_backends = [ + "gpu", + ], deps = [ ":xla_test", "//tensorflow/python:framework_for_generated_wrappers", @@ -402,6 +477,19 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "reverse_sequence_op_test", + size = "medium", + srcs = ["reverse_sequence_op_test.py"], + tags = ["optonly"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "rmsprop_test", size = "small", @@ -416,6 +504,20 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "scan_ops_test", + size = "small", + srcs = ["scan_ops_test.py"], + tags = ["optonly"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "segment_reduction_ops_test", size = "medium", @@ -538,6 +640,7 @@ tf_xla_py_test( name = "variable_ops_test", size = "small", srcs = ["variable_ops_test.py"], + tags = ["optonly"], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -564,6 +667,31 @@ 
tf_xla_py_test( ], ) +tf_xla_py_test( + name = "gather_nd_op_test", + size = "medium", + srcs = ["gather_nd_op_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform_test", + ], +) + +tf_xla_py_test( + name = "scatter_nd_op_test", + size = "medium", + srcs = ["scatter_nd_op_test.py"], + tags = ["optonly"], + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "xla_device_test", size = "small", @@ -688,6 +816,17 @@ tf_library( tfcompile_flags = ["--xla_cpu_multi_thread_eigen=false"], ) +tf_xla_py_test( + name = "fake_quant_ops_test", + size = "medium", + srcs = ["fake_quant_ops_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform_test", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 654dc15e86b21c7742d49281d53c1a75e6a45d3b..30a6d3a74d64f90ad33062df6d1e16e3a575bd63 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -43,7 +43,7 @@ class BinaryOpsTest(XLATestCase): output = op(pa, pb) result = session.run(output, {pa: a, pb: b}) if equality_test is None: - equality_test = self.assertAllClose + equality_test = self.assertAllCloseAccordingToType equality_test(result, expected, rtol=1e-3) def _testSymmetricBinary(self, op, a, b, expected, equality_test=None): @@ -54,14 +54,20 @@ class BinaryOpsTest(XLATestCase): """Tests closeness of two lists of floats.""" self.assertEqual(len(result), len(expected)) for i in range(len(result)): - self.assertAllClose(result[i], expected[i], rtol) + self.assertAllCloseAccordingToType(result[i], expected[i], 
rtol) def testFloatOps(self): for dtype in self.float_types: + if dtype == dtypes.bfloat16.as_numpy_dtype: + a = -1.01 + b = 4.1 + else: + a = -1.001 + b = 4.01 self._testBinary( lambda x, y: math_ops.approximate_equal(x, y, tolerance=0.0001), - np.array([[[[-1, 2.00009999], [-3, 4.01]]]], dtype=dtype), - np.array([[[[-1.001, 2], [-3.00009, 4]]]], dtype=dtype), + np.array([[[[-1, 2.00009999], [-3, b]]]], dtype=dtype), + np.array([[[[a, 2], [-3.00009, 4]]]], dtype=dtype), expected=np.array([[[[False, True], [True, False]]]], dtype=dtype)) self._testBinary( @@ -94,14 +100,12 @@ class BinaryOpsTest(XLATestCase): dtype(4), expected=np.array([[16], [81]], dtype=dtype)) - atan2_supported = self.device == "XLA_GPU" - if atan2_supported: - self._testBinary( - math_ops.atan2, - np.array([0, np.sqrt(2), 1, np.sqrt(2), 0], dtype), - np.array([1, np.sqrt(2), 0, -np.sqrt(2), -1], dtype), - expected=np.array( - [0, np.pi / 4, np.pi / 2, np.pi * 3 / 4, np.pi], dtype=dtype)) + self._testBinary( + math_ops.atan2, + np.array([0, np.sqrt(2), 1, np.sqrt(2), 0], dtype), + np.array([1, np.sqrt(2), 0, -np.sqrt(2), -1], dtype), + expected=np.array( + [0, np.pi / 4, np.pi / 2, np.pi * 3 / 4, np.pi], dtype=dtype)) self._testBinary( gen_math_ops._reciprocal_grad, @@ -388,30 +392,28 @@ class BinaryOpsTest(XLATestCase): ], dtype=dtype)) - atan2_supported = self.device == "XLA_GPU" - if atan2_supported: - self._testBinary( - math_ops.pow, - dtype(3 + 2j), - dtype(4 - 5j), - expected=np.power(dtype(3 + 2j), dtype(4 - 5j))) - self._testBinary( # empty rhs - math_ops.pow, - np.array([1 + 2j, 2 - 3j], dtype=dtype), - np.zeros(shape=[0, 2], dtype=dtype), - expected=np.zeros(shape=[0, 2], dtype=dtype)) - self._testBinary( # to zero power - math_ops.pow, - np.array([1 + 2j, 2 - 3j], dtype=dtype), - np.zeros(shape=[1, 2], dtype=dtype), - expected=np.ones(shape=[1, 2], dtype=dtype)) - lhs = np.array([1 - 2j, 4 + 3j, 2 - 3j, 3, 2j, 1, 4], dtype=dtype) - rhs = np.array([2, 3j, 3 + 4j, 2 + 3j, 3 - 2j, 2, 3 
+ 3j], dtype=dtype) - scalar = dtype(2 + 2j) - self._testBinary(math_ops.pow, lhs, rhs, expected=np.power(lhs, rhs)) - self._testBinary( - math_ops.pow, scalar, rhs, expected=np.power(scalar, rhs)) - self._testBinary(math_ops.pow, lhs, scalar, np.power(lhs, scalar)) + self._testBinary( + math_ops.pow, + dtype(3 + 2j), + dtype(4 - 5j), + expected=np.power(dtype(3 + 2j), dtype(4 - 5j))) + self._testBinary( # empty rhs + math_ops.pow, + np.array([1 + 2j, 2 - 3j], dtype=dtype), + np.zeros(shape=[0, 2], dtype=dtype), + expected=np.zeros(shape=[0, 2], dtype=dtype)) + self._testBinary( # to zero power + math_ops.pow, + np.array([1 + 2j, 2 - 3j], dtype=dtype), + np.zeros(shape=[1, 2], dtype=dtype), + expected=np.ones(shape=[1, 2], dtype=dtype)) + lhs = np.array([1 - 2j, 4 + 3j, 2 - 3j, 3, 2j, 1, 4], dtype=dtype) + rhs = np.array([2, 3j, 3 + 4j, 2 + 3j, 3 - 2j, 2, 3 + 3j], dtype=dtype) + scalar = dtype(2 + 2j) + self._testBinary(math_ops.pow, lhs, rhs, expected=np.power(lhs, rhs)) + self._testBinary( + math_ops.pow, scalar, rhs, expected=np.power(scalar, rhs)) + self._testBinary(math_ops.pow, lhs, scalar, np.power(lhs, scalar)) lhs = np.array([4 + 2j, -3 - 1j, 2j, 1], dtype=dtype) rhs = np.array([5, -6j, 7 - 3j, -8j], dtype=dtype) @@ -421,9 +423,8 @@ class BinaryOpsTest(XLATestCase): self._testBinary( gen_math_ops._sigmoid_grad, lhs, rhs, expected=rhs * lhs * (1 - lhs)) - if atan2_supported: - self._testBinary( - gen_math_ops._rsqrt_grad, lhs, rhs, expected=lhs**3 * rhs / -2) + self._testBinary( + gen_math_ops._rsqrt_grad, lhs, rhs, expected=lhs**3 * rhs / -2) self._testBinary( gen_math_ops._sqrt_grad, lhs, rhs, expected=rhs / (2 * lhs)) @@ -547,7 +548,7 @@ class BinaryOpsTest(XLATestCase): self._testDivision(dtype) def testFloatDivision(self): - for dtype in self.float_types + self.complex_types: + for dtype in self.float_types | self.complex_types: self._testDivision(dtype) def _testRemainder(self, dtype): @@ -773,15 +774,15 @@ class BinaryOpsTest(XLATestCase): def 
DISABLED_testSparseMatMul(self): # Binary wrappers for sparse_matmul with different hints def SparseMatmulWrapperTF(a, b): - return tf.sparse_matmul(a, b, a_is_sparse=True) + return math_ops.sparse_matmul(a, b, a_is_sparse=True) def SparseMatmulWrapperFT(a, b): - return tf.sparse_matmul(a, b, b_is_sparse=True) + return math_ops.sparse_matmul(a, b, b_is_sparse=True) def SparseMatmulWrapperTT(a, b): - return tf.sparse_matmul(a, b, a_is_sparse=True, b_is_sparse=True) + return math_ops.sparse_matmul(a, b, a_is_sparse=True, b_is_sparse=True) - self._testMatMul(tf.sparse_matmul) + self._testMatMul(math_ops.sparse_matmul) self._testMatMul(SparseMatmulWrapperTF) self._testMatMul(SparseMatmulWrapperFT) self._testMatMul(SparseMatmulWrapperTT) @@ -1180,6 +1181,50 @@ class BinaryOpsTest(XLATestCase): np.array([4, 5, 6], dtype=np.int32), expected=None) + def testMatrixSetDiag(self): + for dtype in self.numeric_types: + # Square + self._testBinary( + array_ops.matrix_set_diag, + np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [1.0, 1.0, 1.0]], + dtype=dtype), + np.array([1.0, 2.0, 3.0], dtype=dtype), + expected=np.array([[1.0, 1.0, 0.0], [1.0, 2.0, 1.0], [1.0, 1.0, 3.0]], + dtype=dtype)) + + self._testBinary( + array_ops.matrix_set_diag, + np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0], [1.0, 0.0, 3.0]], + [[4.0, 0.0, 4.0], [0.0, 5.0, 0.0], [2.0, 0.0, 6.0]]], + dtype=dtype), + np.array([[-1.0, 0.0, -3.0], [-4.0, -5.0, -6.0]], dtype=dtype), + expected=np.array( + [[[-1.0, 0.0, 3.0], [0.0, 0.0, 0.0], [1.0, 0.0, -3.0]], + [[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0], [2.0, 0.0, -6.0]]], + dtype=dtype)) + + # Rectangular + self._testBinary( + array_ops.matrix_set_diag, + np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0]], dtype=dtype), + np.array([3.0, 4.0], dtype=dtype), + expected=np.array([[3.0, 1.0, 0.0], [1.0, 4.0, 1.0]], dtype=dtype)) + + self._testBinary( + array_ops.matrix_set_diag, + np.array([[0.0, 1.0], [1.0, 0.0], [1.0, 1.0]], dtype=dtype), + np.array([3.0, 4.0], dtype=dtype), + 
expected=np.array([[3.0, 1.0], [1.0, 4.0], [1.0, 1.0]], dtype=dtype)) + + self._testBinary( + array_ops.matrix_set_diag, + np.array([[[1.0, 0.0, 3.0], [0.0, 2.0, 0.0]], + [[4.0, 0.0, 4.0], [0.0, 5.0, 0.0]]], dtype=dtype), + np.array([[-1.0, -2.0], [-4.0, -5.0]], + dtype=dtype), + expected=np.array([[[-1.0, 0.0, 3.0], [0.0, -2.0, 0.0]], + [[-4.0, 0.0, 4.0], [0.0, -5.0, 0.0]]], + dtype=dtype)) if __name__ == "__main__": googletest.main() diff --git a/tensorflow/compiler/tests/categorical_op_test.py b/tensorflow/compiler/tests/categorical_op_test.py index 5e06f9a72401935b9681c35a164b51f50a8538ae..035cdea1786d39f3d21bb63be5c8ccffe1608bdf 100644 --- a/tensorflow/compiler/tests/categorical_op_test.py +++ b/tensorflow/compiler/tests/categorical_op_test.py @@ -35,6 +35,9 @@ from tensorflow.python.platform import googletest class CategoricalTest(XLATestCase): """Test cases for random-number generating operators.""" + def output_dtypes(self): + return set(self.int_types).intersection([np.int32, np.int64]) + def _chi2(self, expected, actual): """Returns Chi2 GOF statistic.""" actual = np.asarray(actual) @@ -55,7 +58,8 @@ class CategoricalTest(XLATestCase): """ with self.test_session() as sess, self.test_scope(): random_seed.set_random_seed(1618) - op = random_ops.multinomial(logits, num_samples) + op = random_ops.multinomial(logits, num_samples, + output_dtype=dtypes.int32) d = sess.run(op) batch_size, num_classes = logits.shape @@ -73,11 +77,11 @@ class CategoricalTest(XLATestCase): return freqs_mat - def _testRngIsNotConstant(self, rng, dtype): + def _testRngIsNotConstant(self, rng, dtype, output_dtype): # Tests that 'rng' does not always return the same value. with self.test_session() as sess: with self.test_scope(): - x = rng(dtype) + x = rng(dtype, output_dtype) # The random-number generator, if working correctly, should produce the # same output multiple times with low probability. 
@@ -92,21 +96,25 @@ class CategoricalTest(XLATestCase): (not np.array_equal(y, w))) def testCategoricalIsNotConstant(self): - def rng(unused_dtype): - return random_ops.multinomial([[1., 1., 1.]], 10) + def rng(dtype, output_dtype): + return random_ops.multinomial(np.array([[1., 1., 1.]], dtype=dtype), 10, + output_dtype=output_dtype) - dtype = dtypes.float32 - self._testRngIsNotConstant(rng, dtype) + dtype = np.float32 + for output_dtype in self.output_dtypes(): + self._testRngIsNotConstant(rng, dtype, output_dtype) def testCategoricalIsInRange(self): - for dtype in [dtypes.float32, dtypes.float64]: - with self.test_session() as sess: - with self.test_scope(): - x = random_ops.multinomial( - array_ops.ones(shape=[1, 20], dtype=dtype), 1000) - y = sess.run(x) - self.assertTrue((y >= 0).sum() == 1000) - self.assertTrue((y < 20).sum() == 1000) + for dtype in self.float_types: + for output_dtype in self.output_dtypes(): + with self.test_session() as sess: + with self.test_scope(): + x = random_ops.multinomial( + array_ops.ones(shape=[1, 20], dtype=dtype), 1000, + output_dtype=output_dtype) + y = sess.run(x) + self.assertTrue((y >= 0).sum() == 1000) + self.assertTrue((y < 20).sum() == 1000) def testSamplingCorrectness(self): np.random.seed(1618) # Make it reproducible. diff --git a/tensorflow/compiler/tests/conv2d_test.py b/tensorflow/compiler/tests/conv2d_test.py index 0d617eb37c5d92c87abb0f996b731112257a2b80..62577b70ce96e220d79978f01614b2d9a3647680 100644 --- a/tensorflow/compiler/tests/conv2d_test.py +++ b/tensorflow/compiler/tests/conv2d_test.py @@ -34,7 +34,13 @@ from tensorflow.python.platform import googletest class Conv2DTest(XLATestCase): - def _VerifyValues(self, input_sizes, filter_sizes, stride, padding, expected): + def _VerifyValues(self, + input_sizes=None, + filter_sizes=None, + strides=None, + dilations=None, + padding=None, + expected=None): """Tests that tf.nn.conv2d produces the expected value. 
Args: @@ -42,7 +48,8 @@ class Conv2DTest(XLATestCase): [batch, input_rows, input_cols, input_depth]. filter_sizes: Filter tensor dimensions in [kernel_rows, kernel_cols, input_depth, output_depth]. - stride: Stride. + strides: Strides. + dilations: RHS dilations. padding: Padding type. expected: Expected output. """ @@ -50,73 +57,136 @@ class Conv2DTest(XLATestCase): total_size_2 = np.prod(filter_sizes) x1 = np.arange(1, total_size_1 + 1, dtype=np.float32).reshape(input_sizes) x2 = np.arange(1, total_size_2 + 1, dtype=np.float32).reshape(filter_sizes) - strides = [1, stride, stride, 1] + strides = [1] + strides + [1] + if dilations is None: + dilations = [1, 1] + dilations = [1] + dilations + [1] with self.test_session() as sess: + t1 = array_ops.placeholder(dtypes.float32, shape=input_sizes) + t2 = array_ops.placeholder(dtypes.float32, shape=filter_sizes) with self.test_scope(): - t1 = array_ops.placeholder(dtypes.float32, shape=input_sizes) - t2 = array_ops.placeholder(dtypes.float32, shape=filter_sizes) out = nn_ops.conv2d( - t1, t2, strides=strides, padding=padding, data_format="NHWC") + t1, + t2, + strides=strides, + padding=padding, + data_format="NHWC", + dilations=dilations) value = sess.run(out, {t1: x1, t2: x2}) - self.assertArrayNear(expected, np.ravel(value), 1e-3) + self.assertAllClose(expected, value, 1e-3) def testConv2D1x1Filter(self): - expected_output = [ + expected_output = np.reshape([ 30.0, 36.0, 42.0, 66.0, 81.0, 96.0, 102.0, 126.0, 150.0, 138.0, 171.0, 204.0, 174.0, 216.0, 258.0, 210.0, 261.0, 312.0 - ] + ], [1, 2, 3, 3]) self._VerifyValues( input_sizes=[1, 2, 3, 3], filter_sizes=[1, 1, 3, 3], - stride=1, + strides=[1, 1], padding="VALID", expected=expected_output) def testConv2D2x2Filter(self): - expected_output = [2271.0, 2367.0, 2463.0, 2901.0, 3033.0, 3165.0] + expected_output = np.reshape( + [2271.0, 2367.0, 2463.0, 2901.0, 3033.0, 3165.0], [1, 1, 2, 3]) self._VerifyValues( input_sizes=[1, 2, 3, 3], filter_sizes=[2, 2, 3, 3], - stride=1, 
+ strides=[1, 1], + padding="VALID", + expected=expected_output) + + def testConv2D2x2Filter2x1Dilation(self): + expected_output = np.array([[[[72], [82], [92]], [[112], [122], [132]]]]) + self._VerifyValues( + input_sizes=[1, 4, 4, 1], + filter_sizes=[2, 2, 1, 1], + strides=[1, 1], + dilations=[2, 1], padding="VALID", expected=expected_output) def testConv2D1x2Filter(self): - expected_output = [ + expected_output = np.reshape([ 231.0, 252.0, 273.0, 384.0, 423.0, 462.0, 690.0, 765.0, 840.0, 843.0, 936.0, 1029.0 - ] + ], [1, 2, 2, 3]) self._VerifyValues( input_sizes=[1, 2, 3, 3], filter_sizes=[1, 2, 3, 3], - stride=1, + strides=[1, 1], padding="VALID", expected=expected_output) def testConv2D2x2FilterStride2(self): - expected_output = [2271.0, 2367.0, 2463.0] + expected_output = np.reshape([2271.0, 2367.0, 2463.0], [1, 1, 1, 3]) self._VerifyValues( input_sizes=[1, 2, 3, 3], filter_sizes=[2, 2, 3, 3], - stride=2, + strides=[2, 2], padding="VALID", expected=expected_output) def testConv2D2x2FilterStride2Same(self): - expected_output = [2271.0, 2367.0, 2463.0, 1230.0, 1305.0, 1380.0] + expected_output = np.reshape( + [2271.0, 2367.0, 2463.0, 1230.0, 1305.0, 1380.0], [1, 1, 2, 3]) self._VerifyValues( input_sizes=[1, 2, 3, 3], filter_sizes=[2, 2, 3, 3], - stride=2, + strides=[2, 2], padding="SAME", expected=expected_output) + def testConv2DEmptyDilation(self): + self._VerifyValues( + input_sizes=[0, 2, 3, 3], + filter_sizes=[1, 1, 3, 3], + strides=[1, 1], + dilations=[2, 1], + padding="VALID", + expected=np.zeros([0, 2, 3, 3])) + + def testConv2D2x2FilterDilation(self): + self._VerifyValues( + input_sizes=[1, 2, 3, 3], + filter_sizes=[2, 2, 3, 3], + strides=[1, 1], + dilations=[1, 2], + padding="VALID", + expected=np.reshape([2667, 2781, 2895], [1, 1, 1, 3])) + + def testConv2D1x2FilterDilation(self): + self._VerifyValues( + input_sizes=[1, 2, 3, 3], + filter_sizes=[1, 2, 3, 3], + strides=[1, 1], + dilations=[2, 1], + padding="VALID", + expected=np.array([[[[231, 252, 
273], [384, 423, 462]], + [[690, 765, 840], [843, 936, 1029]]]])) + + def testConv2DKernelSizeMatchesInputSizeDilation(self): + self._VerifyValues( + input_sizes=[1, 3, 3, 1], + filter_sizes=[2, 2, 1, 2], + strides=[1, 1], + dilations=[2, 2], + padding="VALID", + expected=np.reshape([108, 128], [1, 1, 1, 2])) + class Conv2DBackpropInputTest(XLATestCase): - def _VerifyValues(self, input_sizes, filter_sizes, out_backprop_sizes, stride, - padding, expected): + def _VerifyValues(self, + input_sizes=None, + filter_sizes=None, + out_backprop_sizes=None, + strides=None, + dilations=None, + padding=None, + expected=None): """Tests that gen_nn_ops.conv2d_backprop_input produces the expected output. Args: @@ -125,7 +195,8 @@ class Conv2DBackpropInputTest(XLATestCase): filter_sizes: Filter tensor dimensions in [kernel_rows, kernel_cols, input_depth, output_depth]. out_backprop_sizes: Output gradients tensor dimensions. - stride: Stride. + strides: Strides. + dilations: Dilations. padding: Padding type. expected: Expected output. 
""" @@ -134,21 +205,25 @@ class Conv2DBackpropInputTest(XLATestCase): x1 = np.arange(1, total_size_1 + 1, dtype=np.float32).reshape(filter_sizes) x2 = np.arange( 1, total_size_2 + 1, dtype=np.float32).reshape(out_backprop_sizes) - strides = [1, stride, stride, 1] + strides = [1] + strides + [1] + if dilations is not None: + dilations = [1] + dilations + [1] with self.test_session() as sess: + t1 = array_ops.placeholder(dtypes.float32, shape=filter_sizes) + t2 = array_ops.placeholder(dtypes.float32, shape=out_backprop_sizes) with self.test_scope(): - t1 = array_ops.placeholder(dtypes.float32, shape=filter_sizes) - t2 = array_ops.placeholder(dtypes.float32, shape=out_backprop_sizes) out = gen_nn_ops.conv2d_backprop_input( input_sizes=input_sizes, filter=t1, out_backprop=t2, strides=strides, + dilations=dilations, padding=padding, data_format="NHWC") value = sess.run(out, {t1: x1, t2: x2}) - self.assertArrayNear(expected, np.ravel(value), 1e-3) + self.assertAllEqual(input_sizes, value.shape) + self.assertAllClose(expected, np.ravel(value), 1e-3) def testConv2D1x1Filter(self): expected_output = [ @@ -160,7 +235,7 @@ class Conv2DBackpropInputTest(XLATestCase): input_sizes=[1, 4, 4, 3], filter_sizes=[1, 1, 3, 2], out_backprop_sizes=[1, 4, 4, 2], - stride=1, + strides=[1, 1], padding="VALID", expected=expected_output) @@ -170,7 +245,7 @@ class Conv2DBackpropInputTest(XLATestCase): input_sizes=[1, 1, 5, 1], filter_sizes=[1, 2, 1, 1], out_backprop_sizes=[1, 1, 2, 1], - stride=3, + strides=[3, 3], padding="VALID", expected=expected_output) @@ -180,7 +255,7 @@ class Conv2DBackpropInputTest(XLATestCase): input_sizes=[1, 1, 6, 1], filter_sizes=[1, 2, 1, 1], out_backprop_sizes=[1, 1, 2, 1], - stride=3, + strides=[3, 3], padding="VALID", expected=expected_output) @@ -190,7 +265,7 @@ class Conv2DBackpropInputTest(XLATestCase): input_sizes=[1, 1, 7, 1], filter_sizes=[1, 2, 1, 1], out_backprop_sizes=[1, 1, 2, 1], - stride=3, + strides=[3, 3], padding="VALID", 
expected=expected_output) @@ -200,7 +275,7 @@ class Conv2DBackpropInputTest(XLATestCase): input_sizes=[1, 2, 3, 1], filter_sizes=[2, 2, 1, 1], out_backprop_sizes=[1, 2, 3, 1], - stride=1, + strides=[1, 1], padding="SAME", expected=expected_output) @@ -213,7 +288,7 @@ class Conv2DBackpropInputTest(XLATestCase): input_sizes=[1, 2, 3, 3], filter_sizes=[2, 2, 3, 3], out_backprop_sizes=[1, 1, 2, 3], - stride=1, + strides=[1, 1], padding="VALID", expected=expected_output) @@ -226,7 +301,7 @@ class Conv2DBackpropInputTest(XLATestCase): input_sizes=[1, 2, 3, 3], filter_sizes=[2, 2, 3, 3], out_backprop_sizes=[1, 2, 3, 3], - stride=1, + strides=[1, 1], padding="SAME", expected=expected_output) @@ -236,7 +311,7 @@ class Conv2DBackpropInputTest(XLATestCase): input_sizes=[1, 3, 3, 1], filter_sizes=[1, 2, 1, 1], out_backprop_sizes=[1, 3, 2, 1], - stride=1, + strides=[1, 1], padding="VALID", expected=expected_output) @@ -246,7 +321,7 @@ class Conv2DBackpropInputTest(XLATestCase): input_sizes=[1, 3, 3, 1], filter_sizes=[1, 2, 1, 1], out_backprop_sizes=[1, 3, 3, 1], - stride=1, + strides=[1, 1], padding="SAME", expected=expected_output) @@ -256,7 +331,7 @@ class Conv2DBackpropInputTest(XLATestCase): input_sizes=[1, 3, 5, 1], filter_sizes=[1, 3, 1, 1], out_backprop_sizes=[1, 2, 2, 1], - stride=2, + strides=[2, 2], padding="VALID", expected=expected_output) @@ -266,15 +341,76 @@ class Conv2DBackpropInputTest(XLATestCase): input_sizes=[1, 2, 3, 1], filter_sizes=[2, 2, 1, 1], out_backprop_sizes=[1, 1, 2, 1], - stride=2, + strides=[2, 2], padding="SAME", expected=expected_output) + def testConv2D2x2Depth3ValidBackpropInputStride1x1Dilation2x1(self): + self._VerifyValues( + input_sizes=[1, 3, 6, 1], + filter_sizes=[2, 2, 1, 1], + out_backprop_sizes=[1, 1, 5, 1], + strides=[1, 1], + dilations=[2, 1], + padding="VALID", + expected=[1, 4, 7, 10, 13, 10, 0, 0, 0, 0, 0, 0, 3, 10, 17, 24, 31, 20]) + + def testConv2D2x2Depth1ValidBackpropInputDilation1x2(self): + self._VerifyValues( + 
input_sizes=[1, 2, 3, 1], + filter_sizes=[2, 2, 1, 1], + out_backprop_sizes=[1, 1, 1, 1], + strides=[1, 1], + dilations=[1, 2], + padding="VALID", + expected=[1, 0, 2, 3, 0, 4]) + + def testConv2DEmptyBackpropInputDilation1x2(self): + self._VerifyValues( + input_sizes=[0, 2, 3, 1], + filter_sizes=[2, 2, 1, 1], + out_backprop_sizes=[0, 1, 1, 1], + strides=[1, 1], + dilations=[1, 2], + padding="VALID", + expected=np.zeros([0])) + + def testConv2D2x2Depth3ValidBackpropInputDilation2x1(self): + # The GPU version of this test is not very stable. So adjusting the + # error threshold to 1e-4. + self._VerifyValues( + input_sizes=[1, 3, 2, 3], + filter_sizes=[2, 2, 3, 3], + out_backprop_sizes=[1, 1, 1, 3], + strides=[1, 1], + dilations=[2, 1], + padding="VALID", + expected=[ + 14, 32, 50, 68, 86, 104, 0, 0, 0, 0, 0, 0, 122, 140, 158, 176, 194, + 212 + ]) + + def testConv2DKernelSizeMatchesInputSizeBackpropInputDilation2x2(self): + self._VerifyValues( + input_sizes=[1, 3, 3, 1], + filter_sizes=[2, 2, 1, 2], + out_backprop_sizes=[1, 1, 1, 2], + strides=[1, 1], + dilations=[2, 2], + padding="VALID", + expected=[5, 0, 11, 0, 0, 0, 17, 0, 23]) + class Conv2DBackpropFilterTest(XLATestCase): - def _VerifyValues(self, input_sizes, filter_sizes, out_backprop_sizes, stride, - padding, expected): + def _VerifyValues(self, + input_sizes=None, + filter_sizes=None, + out_backprop_sizes=None, + strides=None, + dilations=None, + padding=None, + expected=None): """Tests that gen_nn_ops.conv2d_backprop_filter produces the right output. Args: @@ -283,7 +419,8 @@ class Conv2DBackpropFilterTest(XLATestCase): filter_sizes: Filter tensor dimensions in [kernel_rows, kernel_cols, input_depth, output_depth]. out_backprop_sizes: Output gradients tensor dimensions. - stride: Stride. + strides: Stride. + dilations: Dilations. padding: Padding type. expected: Expected output. 
""" @@ -293,22 +430,26 @@ class Conv2DBackpropFilterTest(XLATestCase): x1 = np.arange(1, total_size_1 + 1, dtype=np.float32).reshape(input_sizes) x2 = np.arange( 1, total_size_2 + 1, dtype=np.float32).reshape(out_backprop_sizes) - strides = [1, stride, stride, 1] + strides = [1] + strides + [1] + if dilations is not None: + dilations = [1] + dilations + [1] with self.test_session() as sess: + t1 = array_ops.placeholder(dtypes.float32, shape=input_sizes) + t2 = array_ops.placeholder(dtypes.float32, shape=out_backprop_sizes) with self.test_scope(): - t1 = array_ops.placeholder(dtypes.float32, shape=input_sizes) - t2 = array_ops.placeholder(dtypes.float32, shape=out_backprop_sizes) tensor = gen_nn_ops.conv2d_backprop_filter( input=t1, filter_sizes=filter_sizes, out_backprop=t2, strides=strides, + dilations=dilations, padding=padding, data_format="NHWC") value = sess.run(tensor, {t1: x1, t2: x2}) - self.assertArrayNear(expected, np.ravel(value), 1e-3) + self.assertAllEqual(filter_sizes, value.shape) + self.assertAllClose(expected, np.ravel(value), 1e-3) def testConv2D1x1Filter(self): expected_output = [8056, 8432, 8312, 8704, 8568, 8976] @@ -316,7 +457,7 @@ class Conv2DBackpropFilterTest(XLATestCase): input_sizes=[1, 4, 4, 3], filter_sizes=[1, 1, 3, 2], out_backprop_sizes=[1, 4, 4, 2], - stride=1, + strides=[1, 1], padding="VALID", expected=expected_output) @@ -326,7 +467,7 @@ class Conv2DBackpropFilterTest(XLATestCase): input_sizes=[1, 3, 3, 1], filter_sizes=[1, 2, 1, 1], out_backprop_sizes=[1, 3, 2, 1], - stride=1, + strides=[1, 1], padding="VALID", expected=expected_output) @@ -336,7 +477,7 @@ class Conv2DBackpropFilterTest(XLATestCase): input_sizes=[1, 2, 3, 1], filter_sizes=[2, 2, 1, 1], out_backprop_sizes=[1, 1, 2, 1], - stride=1, + strides=[1, 1], padding="VALID", expected=expected_output) @@ -350,7 +491,7 @@ class Conv2DBackpropFilterTest(XLATestCase): input_sizes=[1, 2, 3, 3], filter_sizes=[2, 2, 3, 3], out_backprop_sizes=[1, 1, 2, 3], - stride=1, + 
strides=[1, 1], padding="VALID", expected=expected_output) @@ -360,7 +501,7 @@ class Conv2DBackpropFilterTest(XLATestCase): input_sizes=[1, 1, 5, 1], filter_sizes=[1, 2, 1, 1], out_backprop_sizes=[1, 1, 2, 1], - stride=3, + strides=[3, 3], padding="VALID", expected=expected_output) @@ -370,7 +511,7 @@ class Conv2DBackpropFilterTest(XLATestCase): input_sizes=[1, 1, 6, 1], filter_sizes=[1, 2, 1, 1], out_backprop_sizes=[1, 1, 2, 1], - stride=3, + strides=[3, 3], padding="VALID", expected=expected_output) @@ -380,7 +521,7 @@ class Conv2DBackpropFilterTest(XLATestCase): input_sizes=[1, 1, 7, 1], filter_sizes=[1, 2, 1, 1], out_backprop_sizes=[1, 1, 2, 1], - stride=3, + strides=[3, 3], padding="VALID", expected=expected_output) @@ -390,7 +531,7 @@ class Conv2DBackpropFilterTest(XLATestCase): input_sizes=[1, 1, 4, 1], filter_sizes=[1, 3, 1, 1], out_backprop_sizes=[1, 1, 2, 1], - stride=1, + strides=[1, 1], padding="VALID", expected=expected_output) @@ -400,7 +541,7 @@ class Conv2DBackpropFilterTest(XLATestCase): input_sizes=[1, 1, 4, 1], filter_sizes=[1, 3, 1, 1], out_backprop_sizes=[1, 1, 4, 1], - stride=1, + strides=[1, 1], padding="SAME", expected=expected_output) @@ -410,7 +551,7 @@ class Conv2DBackpropFilterTest(XLATestCase): input_sizes=[1, 1, 4, 1], filter_sizes=[1, 3, 1, 1], out_backprop_sizes=[1, 1, 2, 1], - stride=2, + strides=[2, 2], padding="SAME", expected=expected_output) @@ -420,7 +561,7 @@ class Conv2DBackpropFilterTest(XLATestCase): input_sizes=[1, 2, 3, 1], filter_sizes=[2, 2, 1, 1], out_backprop_sizes=[1, 2, 3, 1], - stride=1, + strides=[1, 1], padding="SAME", expected=expected_output) @@ -430,7 +571,7 @@ class Conv2DBackpropFilterTest(XLATestCase): input_sizes=[1, 3, 5, 1], filter_sizes=[1, 3, 1, 1], out_backprop_sizes=[1, 2, 2, 1], - stride=2, + strides=[2, 2], padding="VALID", expected=expected_output) @@ -440,10 +581,64 @@ class Conv2DBackpropFilterTest(XLATestCase): input_sizes=[1, 2, 3, 1], filter_sizes=[2, 2, 1, 1], out_backprop_sizes=[1, 1, 2, 
1], - stride=2, + strides=[2, 2], padding="SAME", expected=expected_output) + def testConv2D2x2Depth3ValidBackpropFilterStride1x1Dilation2x1(self): + self._VerifyValues( + input_sizes=[1, 3, 6, 1], + filter_sizes=[2, 2, 1, 1], + out_backprop_sizes=[1, 1, 5, 1], + strides=[1, 1], + dilations=[2, 1], + padding="VALID", + expected=[55, 70, 235, 250]) + + def testConv2D2x2Depth1ValidBackpropFilterDilation1x2(self): + self._VerifyValues( + input_sizes=[1, 2, 3, 1], + filter_sizes=[2, 2, 1, 1], + out_backprop_sizes=[1, 1, 1, 1], + strides=[1, 1], + dilations=[1, 2], + padding="VALID", + expected=[1, 3, 4, 6]) + + def testConv2DEmptyBackpropFilterDilation1x2(self): + self._VerifyValues( + input_sizes=[1, 2, 3, 1], + filter_sizes=[2, 2, 1, 0], + out_backprop_sizes=[1, 1, 1, 0], + strides=[1, 1], + dilations=[1, 2], + padding="VALID", + expected=np.zeros([0])) + + def testConv2D2x2Depth3ValidBackpropFilterDilation2x2(self): + self._VerifyValues( + input_sizes=[1, 3, 4, 3], + filter_sizes=[2, 2, 3, 3], + out_backprop_sizes=[1, 1, 2, 3], + strides=[1, 1], + dilations=[2, 2], + padding="VALID", + expected=[ + 17, 22, 27, 22, 29, 36, 27, 36, 45, 47, 64, 81, 52, 71, 90, 57, 78, + 99, 137, 190, 243, 142, 197, 252, 147, 204, 261, 167, 232, 297, 172, + 239, 306, 177, 246, 315 + ]) + + def testConv2DKernelSizeMatchesInputSizeBackpropFilterDilation2x2(self): + self._VerifyValues( + input_sizes=[1, 3, 3, 1], + filter_sizes=[2, 2, 1, 2], + out_backprop_sizes=[1, 1, 1, 2], + strides=[1, 1], + dilations=[2, 2], + padding="VALID", + expected=[1, 2, 3, 6, 7, 14, 9, 18]) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/compiler/tests/extract_image_patches_op_test.py b/tensorflow/compiler/tests/extract_image_patches_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0361702e7af778176daed941d64e61198090daf2 --- /dev/null +++ b/tensorflow/compiler/tests/extract_image_patches_op_test.py @@ -0,0 +1,134 @@ +# Copyright 2018 The TensorFlow 
Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functional tests for ExtractImagePatches op.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class ExtractImagePatches(XLATestCase): + """Functional tests for ExtractImagePatches op.""" + + def _VerifyValues(self, image, ksizes, strides, rates, padding, patches): + """Tests input-output pairs for the ExtractImagePatches op. + + Args: + image: Input tensor with shape: [batch, in_rows, in_cols, depth]. + ksizes: Patch size specified as: [ksize_rows, ksize_cols]. + strides: Output strides, specified as [stride_rows, stride_cols]. + rates: Atrous rates, specified as [rate_rows, rate_cols]. + padding: Padding type. + patches: Expected output. 
+ """ + ksizes = [1] + ksizes + [1] + strides = [1] + strides + [1] + rates = [1] + rates + [1] + + with self.test_session(): + image_placeholder = array_ops.placeholder(dtypes.float32) + with self.test_scope(): + out_tensor = array_ops.extract_image_patches( + image_placeholder, + ksizes=ksizes, + strides=strides, + rates=rates, + padding=padding, + name="im2col") + feed_dict = {image_placeholder: image} + self.assertAllClose(patches, out_tensor.eval(feed_dict=feed_dict)) + + def testKsize1x1Stride1x1Rate1x1(self): + """Verifies that for 1x1 kernel the output equals the input.""" + # [2, 3, 4, 5] + image = np.reshape(range(120), [2, 3, 4, 5]) + # [2, 3, 4, 5] + patches = np.reshape(range(120), [2, 3, 4, 5]) + for padding in ["VALID", "SAME"]: + self._VerifyValues( + image, + ksizes=[1, 1], + strides=[1, 1], + rates=[1, 1], + padding=padding, + patches=patches) + + def testKsize1x1Stride2x3Rate1x1(self): + """Test for 1x1 kernel and strides.""" + # [2, 4, 5, 3] + image = np.reshape(range(120), [2, 4, 5, 3]) + # [2, 2, 2, 3] + patches = image[:, ::2, ::3, :] + for padding in ["VALID", "SAME"]: + self._VerifyValues( + image, + ksizes=[1, 1], + strides=[2, 3], + rates=[1, 1], + padding=padding, + patches=patches) + + def testKsize2x2Stride1x1Rate1x1Valid(self): + """Test for 2x2 kernel with VALID padding.""" + # [1, 2, 2, 1] + image = [[[[1], [2]], [[3], [4]]]] + # [1, 1, 1, 4] + patches = [[[[1, 2, 3, 4]]]] + self._VerifyValues( + image, + ksizes=[2, 2], + strides=[1, 1], + rates=[1, 1], + padding="VALID", + patches=patches) + + def testKsize2x2Stride1x1Rate1x1Same(self): + """Test for 2x2 kernel with SAME padding.""" + # [1, 2, 2, 1] + image = [[[[1], [2]], [[3], [4]]]] + # [1, 2, 2, 4] + patches = [[[[1, 2, 3, 4], [2, 0, 4, 0]], [[3, 4, 0, 0], [4, 0, 0, 0]]]] + self._VerifyValues( + image, + ksizes=[2, 2], + strides=[1, 1], + rates=[1, 1], + padding="SAME", + patches=patches) + + def testKsize2x2Stride1x1Rate2x2Valid(self): + """Test for 2x2 kernel with 2x2 
dilation.""" + # [1, 2, 2, 1] + image = np.arange(16).reshape(1, 4, 4, 1).astype(np.float32) + # [1, 2, 2, 4] + patches = [[[[0, 2, 8, 10], [1, 3, 9, 11]], + [[4, 6, 12, 14], [5, 7, 13, 15]]]] + self._VerifyValues( + image, + ksizes=[2, 2], + strides=[1, 1], + rates=[2, 2], + padding="VALID", + patches=patches) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/fake_quant_ops_test.py b/tensorflow/compiler/tests/fake_quant_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..dfe9400ef0f55ca011d4e23ba5d735899ca2e054 --- /dev/null +++ b/tensorflow/compiler/tests/fake_quant_ops_test.py @@ -0,0 +1,452 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.platform import googletest + + +class FakeQuantWithMinMaxArgsTest(XLATestCase): + """Test cases for FakeQuantWithMinMaxArgs operation.""" + + # 8 bits, wide range. 
+ def testOp_with8BitsNoScalingNoNudging(self): + self._TestOp(0.0, 255.0, 8, False, 0.0, 255.0, 1.0) + + def testOp_with8BitsScalingAndNudgingDown(self): + self._TestOp(0.5, 128.0, 8, False, 0.0, 127.5, 0.5) + + def testOp_with8BitsScalingAndNudgingUp(self): + self._TestOp(-128.0, -0.5, 8, False, -127.5, 0.0, 0.5) + + def testOp_with8BitsScalingAndNudgingBetween(self): + self._TestOp(-0.1, 127.4, 8, False, 0.0, 127.5, 0.5) + + # 8 bits, narrow range. + def testOp_with8BitsNarrowRangeNoScalingNoNudging(self): + self._TestOp(0.0, 254.0, 8, True, 0.0, 254.0, 1.0) + + def testOp_with8BitsNarrowRangeScalingAndNudgingDown(self): + self._TestOp(0.1, 127.1, 8, True, 0.0, 127.0, 0.5) + + def testOp_with8BitsNarrowRangeScalingAndNudgingUp(self): + self._TestOp(-127.1, -0.1, 8, True, -127.0, 0.0, 0.5) + + def testOp_with8BitsNarrowRangeScalingAndNudgingBetween(self): + self._TestOp(-0.1, 126.9, 8, True, 0.0, 127.0, 0.5) + + # 7 bits, wide range. + def testOp_with7BitsNoScalingNoNudging(self): + self._TestOp(0.0, 127.0, 7, False, 0.0, 127.0, 1.0) + + def testOp_with7BitsScalingAndNudgingDown(self): + self._TestOp(0.5, 64.0, 7, False, 0.0, 63.5, 0.5) + + def testOp_with7BitsScalingAndNudgingUp(self): + self._TestOp(-64.0, -0.5, 7, False, -63.5, 0.0, 0.5) + + def testOp_with7BitsScalingAndNudgingBetween(self): + self._TestOp(-0.1, 63.4, 7, False, 0.0, 63.5, 0.5) + + # 7 bits, narrow range. 
+ def testOp_with7BitsNarrowRangeNoScalingNoNudging(self): + self._TestOp(0.0, 126.0, 7, True, 0.0, 126.0, 1.0) + + def testOp_with7BitsNarrowRangeScalingAndNudgingDown(self): + self._TestOp(0.1, 63.1, 7, True, 0.0, 63.0, 0.5) + + def testOp_with7BitsNarrowRangeScalingAndNudgingUp(self): + self._TestOp(-63.1, -0.1, 7, True, -63.0, 0.0, 0.5) + + def testOp_with7BitsNarrowRangeScalingAndNudgingBetween(self): + self._TestOp(-0.1, 62.9, 7, True, 0.0, 63.0, 0.5) + + def _TestOp(self, input_min, input_max, num_bits, narrow_range, + expected_nudged_input_min, expected_nudged_input_max, + expected_step): + inputs = np.array( + [ + expected_nudged_input_min - expected_step, + expected_nudged_input_min - 0.01, expected_nudged_input_min, + expected_nudged_input_min + 0.01, + expected_nudged_input_min + expected_step - 0.01, + expected_nudged_input_min + expected_step, + expected_nudged_input_min + expected_step + 0.01, + expected_nudged_input_max - 0.01, expected_nudged_input_max, + expected_nudged_input_max + 0.01, + expected_nudged_input_max + expected_step + ], + dtype=np.float32) + expected = np.array( + [ + expected_nudged_input_min, expected_nudged_input_min, + expected_nudged_input_min, expected_nudged_input_min, + expected_nudged_input_min + expected_step, + expected_nudged_input_min + expected_step, + expected_nudged_input_min + expected_step, + expected_nudged_input_max, expected_nudged_input_max, + expected_nudged_input_max, expected_nudged_input_max + ], + dtype=np.float32) + + with self.test_session() as session: + with self.test_scope(): + input_placeholder = array_ops.placeholder( + dtypes.float32, inputs.shape, name="inputs") + outputs = array_ops.fake_quant_with_min_max_args( + input_placeholder, + min=input_min, + max=input_max, + num_bits=num_bits, + narrow_range=narrow_range) + result = session.run(outputs, {input_placeholder: inputs}) + self.assertAllCloseAccordingToType( + result, expected, rtol=1e-3, atol=1e-5, bfloat16_rtol=0.03) + + +class 
FakeQuantWithMinMaxArgsGradientTest(XLATestCase): + """Test cases for FakeQuantWithMinMaxArgsGradient operation.""" + + # 8 bits, wide range. + def testOp_with8BitsNoScalingNoNudging(self): + self._TestOp(0.0, 255.0, 8, False, 0.0, 255.0, 1.0) + + def testOp_with8BitsScalingAndNudgingDown(self): + self._TestOp(0.5, 128.0, 8, False, 0.0, 127.5, 0.5) + + def testOp_with8BitsScalingAndNudgingUp(self): + self._TestOp(-128.0, -0.5, 8, False, -127.5, 0.0, 0.5) + + def testOp_with8BitsScalingAndNudgingBetween(self): + self._TestOp(-0.1, 127.4, 8, False, 0.0, 127.5, 0.5) + + # 8 bits, narrow range. + def testOp_with8BitsNarrowRangeNoScalingNoNudging(self): + self._TestOp(0.0, 254.0, 8, True, 0.0, 254.0, 1.0) + + def testOp_with8BitsNarrowRangeScalingAndNudgingDown(self): + self._TestOp(0.1, 127.1, 8, True, 0.0, 127.0, 0.5) + + def testOp_with8BitsNarrowRangeScalingAndNudgingUp(self): + self._TestOp(-127.1, -0.1, 8, True, -127.0, 0.0, 0.5) + + def testOp_with8BitsNarrowRangeScalingAndNudgingBetween(self): + self._TestOp(-0.1, 126.9, 8, True, 0.0, 127.0, 0.5) + + # 7 bits, wide range. + def testOp_with7BitsNoScalingNoNudging(self): + self._TestOp(0.0, 127.0, 7, False, 0.0, 127.0, 1.0) + + def testOp_with7BitsScalingAndNudgingDown(self): + self._TestOp(0.5, 64.0, 7, False, 0.0, 63.5, 0.5) + + def testOp_with7BitsScalingAndNudgingUp(self): + self._TestOp(-64.0, -0.5, 7, False, -63.5, 0.0, 0.5) + + def testOp_with7BitsScalingAndNudgingBetween(self): + self._TestOp(-0.1, 63.4, 7, False, 0.0, 63.5, 0.5) + + # 7 bits, narrow range. 
+ def testOp_with7BitsNarrowRangeNoScalingNoNudging(self): + self._TestOp(0.0, 126.0, 7, True, 0.0, 126.0, 1.0) + + def testOp_with7BitsNarrowRangeScalingAndNudgingDown(self): + self._TestOp(0.1, 63.1, 7, True, 0.0, 63.0, 0.5) + + def testOp_with7BitsNarrowRangeScalingAndNudgingUp(self): + self._TestOp(-63.1, -0.1, 7, True, -63.0, 0.0, 0.5) + + def testOp_with7BitsNarrowRangeScalingAndNudgingBetween(self): + self._TestOp(-0.1, 62.9, 7, True, 0.0, 63.0, 0.5) + + def _TestOp(self, input_min, input_max, num_bits, narrow_range, + expected_nudged_input_min, expected_nudged_input_max, + expected_step): + inputs = np.array( + [ + expected_nudged_input_min - expected_step, + expected_nudged_input_min - 0.01, expected_nudged_input_min, + expected_nudged_input_min + 0.01, + expected_nudged_input_min + expected_step - 0.01, + expected_nudged_input_min + expected_step, + expected_nudged_input_min + expected_step + 0.01, + expected_nudged_input_max - 0.01, expected_nudged_input_max, + expected_nudged_input_max + 0.01, + expected_nudged_input_max + expected_step + ], + dtype=np.float32) + gradients = np.arange(1, len(inputs) + 1, dtype=np.float32) + expected_backprops = np.array( + [0.0, 0.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 0.0, 0.0], + dtype=np.float32) + + with self.test_session() as session: + with self.test_scope(): + gradient_placeholder = array_ops.placeholder( + dtypes.float32, gradients.shape, name="gradients") + input_placeholder = array_ops.placeholder( + dtypes.float32, inputs.shape, name="inputs") + outputs = gen_array_ops.fake_quant_with_min_max_args_gradient( + gradient_placeholder, + input_placeholder, + min=input_min, + max=input_max, + num_bits=num_bits, + narrow_range=narrow_range) + backprops = session.run(outputs, { + gradient_placeholder: gradients, + input_placeholder: inputs + }) + self.assertAllCloseAccordingToType( + backprops, + expected_backprops, + rtol=1e-3, + atol=1e-5, + bfloat16_rtol=0.03) + + +class FakeQuantWithMinMaxVarsTest(XLATestCase): + 
"""Test cases for FakeQuantWithMinMaxVars operation.""" + + # 8 bits, wide range. + def testOp_with8BitsNoScalingNoNudging(self): + self._TestOp(0.0, 255.0, 8, False, 0.0, 255.0, 1.0) + + def testOp_with8BitsScalingAndNudgingDown(self): + self._TestOp(0.5, 128.0, 8, False, 0.0, 127.5, 0.5) + + def testOp_with8BitsScalingAndNudgingUp(self): + self._TestOp(-128.0, -0.5, 8, False, -127.5, 0.0, 0.5) + + def testOp_with8BitsScalingAndNudgingBetween(self): + self._TestOp(-0.1, 127.4, 8, False, 0.0, 127.5, 0.5) + + # 8 bits, narrow range. + def testOp_with8BitsNarrowRangeNoScalingNoNudging(self): + self._TestOp(0.0, 254.0, 8, True, 0.0, 254.0, 1.0) + + def testOp_with8BitsNarrowRangeScalingAndNudgingDown(self): + self._TestOp(0.1, 127.1, 8, True, 0.0, 127.0, 0.5) + + def testOp_with8BitsNarrowRangeScalingAndNudgingUp(self): + self._TestOp(-127.1, -0.1, 8, True, -127.0, 0.0, 0.5) + + def testOp_with8BitsNarrowRangeScalingAndNudgingBetween(self): + self._TestOp(-0.1, 126.9, 8, True, 0.0, 127.0, 0.5) + + # 7 bits, wide range. + def testOp_with7BitsNoScalingNoNudging(self): + self._TestOp(0.0, 127.0, 7, False, 0.0, 127.0, 1.0) + + def testOp_with7BitsScalingAndNudgingDown(self): + self._TestOp(0.5, 64.0, 7, False, 0.0, 63.5, 0.5) + + def testOp_with7BitsScalingAndNudgingUp(self): + self._TestOp(-64.0, -0.5, 7, False, -63.5, 0.0, 0.5) + + def testOp_with7BitsScalingAndNudgingBetween(self): + self._TestOp(-0.1, 63.4, 7, False, 0.0, 63.5, 0.5) + + # 7 bits, narrow range. 
+ def testOp_with7BitsNarrowRangeNoScalingNoNudging(self): + self._TestOp(0.0, 126.0, 7, True, 0.0, 126.0, 1.0) + + def testOp_with7BitsNarrowRangeScalingAndNudgingDown(self): + self._TestOp(0.1, 63.1, 7, True, 0.0, 63.0, 0.5) + + def testOp_with7BitsNarrowRangeScalingAndNudgingUp(self): + self._TestOp(-63.1, -0.1, 7, True, -63.0, 0.0, 0.5) + + def testOp_with7BitsNarrowRangeScalingAndNudgingBetween(self): + self._TestOp(-0.1, 62.9, 7, True, 0.0, 63.0, 0.5) + + def _TestOp(self, input_min, input_max, num_bits, narrow_range, + expected_nudged_input_min, expected_nudged_input_max, + expected_step): + inputs = np.array( + [ + expected_nudged_input_min - expected_step, + expected_nudged_input_min - 0.01, expected_nudged_input_min, + expected_nudged_input_min + 0.01, + expected_nudged_input_min + expected_step - 0.01, + expected_nudged_input_min + expected_step, + expected_nudged_input_min + expected_step + 0.01, + expected_nudged_input_max - 0.01, expected_nudged_input_max, + expected_nudged_input_max + 0.01, + expected_nudged_input_max + expected_step + ], + dtype=np.float32) + expected = np.array( + [ + expected_nudged_input_min, expected_nudged_input_min, + expected_nudged_input_min, expected_nudged_input_min, + expected_nudged_input_min + expected_step, + expected_nudged_input_min + expected_step, + expected_nudged_input_min + expected_step, + expected_nudged_input_max, expected_nudged_input_max, + expected_nudged_input_max, expected_nudged_input_max + ], + dtype=np.float32) + + with self.test_session() as session: + with self.test_scope(): + input_placeholder = array_ops.placeholder( + dtypes.float32, inputs.shape, name="inputs") + min_placeholder = array_ops.placeholder(dtypes.float32, (), name="min") + max_placeholder = array_ops.placeholder(dtypes.float32, (), name="max") + outputs = array_ops.fake_quant_with_min_max_vars( + input_placeholder, + min_placeholder, + max_placeholder, + num_bits=num_bits, + narrow_range=narrow_range) + result = session.run( + 
outputs, { + input_placeholder: inputs, + min_placeholder: input_min, + max_placeholder: input_max + }) + self.assertAllCloseAccordingToType( + result, expected, rtol=1e-3, atol=1e-5, bfloat16_rtol=0.03) + + +class FakeQuantWithMinMaxVarsGradientTest(XLATestCase): + """Test cases for FakeQuantWithMinMaxVarsGradient operation.""" + + # 8 bits, wide range. + def testOp_with8BitsNoScalingNoNudging(self): + self._TestOp(0.0, 255.0, 8, False, 0.0, 255.0, 1.0) + + def testOp_with8BitsScalingAndNudgingDown(self): + self._TestOp(0.5, 128.0, 8, False, 0.0, 127.5, 0.5) + + def testOp_with8BitsScalingAndNudgingUp(self): + self._TestOp(-128.0, -0.5, 8, False, -127.5, 0.0, 0.5) + + def testOp_with8BitsScalingAndNudgingBetween(self): + self._TestOp(-0.1, 127.4, 8, False, 0.0, 127.5, 0.5) + + # 8 bits, narrow range. + def testOp_with8BitsNarrowRangeNoScalingNoNudging(self): + self._TestOp(0.0, 254.0, 8, True, 0.0, 254.0, 1.0) + + def testOp_with8BitsNarrowRangeScalingAndNudgingDown(self): + self._TestOp(0.1, 127.1, 8, True, 0.0, 127.0, 0.5) + + def testOp_with8BitsNarrowRangeScalingAndNudgingUp(self): + self._TestOp(-127.1, -0.1, 8, True, -127.0, 0.0, 0.5) + + def testOp_with8BitsNarrowRangeScalingAndNudgingBetween(self): + self._TestOp(-0.1, 126.9, 8, True, 0.0, 127.0, 0.5) + + # 7 bits, wide range. + def testOp_with7BitsNoScalingNoNudging(self): + self._TestOp(0.0, 127.0, 7, False, 0.0, 127.0, 1.0) + + def testOp_with7BitsScalingAndNudgingDown(self): + self._TestOp(0.5, 64.0, 7, False, 0.0, 63.5, 0.5) + + def testOp_with7BitsScalingAndNudgingUp(self): + self._TestOp(-64.0, -0.5, 7, False, -63.5, 0.0, 0.5) + + def testOp_with7BitsScalingAndNudgingBetween(self): + self._TestOp(-0.1, 63.4, 7, False, 0.0, 63.5, 0.5) + + # 7 bits, narrow range. 
+ def testOp_with7BitsNarrowRangeNoScalingNoNudging(self): + self._TestOp(0.0, 126.0, 7, True, 0.0, 126.0, 1.0) + + def testOp_with7BitsNarrowRangeScalingAndNudgingDown(self): + self._TestOp(0.1, 63.1, 7, True, 0.0, 63.0, 0.5) + + def testOp_with7BitsNarrowRangeScalingAndNudgingUp(self): + self._TestOp(-63.1, -0.1, 7, True, -63.0, 0.0, 0.5) + + def testOp_with7BitsNarrowRangeScalingAndNudgingBetween(self): + self._TestOp(-0.1, 62.9, 7, True, 0.0, 63.0, 0.5) + + def _TestOp(self, input_min, input_max, num_bits, narrow_range, + expected_nudged_input_min, expected_nudged_input_max, + expected_step): + inputs = np.array( + [ + expected_nudged_input_min - expected_step, + expected_nudged_input_min - 0.01, expected_nudged_input_min, + expected_nudged_input_min + 0.01, + expected_nudged_input_min + expected_step - 0.01, + expected_nudged_input_min + expected_step, + expected_nudged_input_min + expected_step + 0.01, + expected_nudged_input_max - 0.01, expected_nudged_input_max, + expected_nudged_input_max + 0.01, + expected_nudged_input_max + expected_step + ], + dtype=np.float32) + gradients = np.arange(1, len(inputs) + 1, dtype=np.float32) + expected_backprops_wrt_input = np.array( + [0.0, 0.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 0.0, 0.0], + dtype=np.float32) + expected_backprops_wrt_min = 1.0 + 2.0 + expected_backprops_wrt_max = 10.0 + 11.0 + + with self.test_session() as session: + with self.test_scope(): + gradient_placeholder = array_ops.placeholder( + dtypes.float32, gradients.shape, name="gradients") + input_placeholder = array_ops.placeholder( + dtypes.float32, inputs.shape, name="inputs") + min_placeholder = array_ops.placeholder(dtypes.float32, (), name="min") + max_placeholder = array_ops.placeholder(dtypes.float32, (), name="max") + outputs = array_ops.fake_quant_with_min_max_vars_gradient( + gradient_placeholder, + input_placeholder, + min_placeholder, + max_placeholder, + num_bits=num_bits, + narrow_range=narrow_range) + backprops_wrt_input, 
backprops_wrt_min, backprops_wrt_max = session.run( + outputs, { + gradient_placeholder: gradients, + input_placeholder: inputs, + min_placeholder: input_min, + max_placeholder: input_max + }) + self.assertAllCloseAccordingToType( + backprops_wrt_input, + expected_backprops_wrt_input, + rtol=1e-3, + atol=1e-5, + bfloat16_rtol=0.03) + self.assertAllCloseAccordingToType( + backprops_wrt_min, + expected_backprops_wrt_min, + rtol=1e-3, + atol=1e-5, + bfloat16_rtol=0.03) + self.assertAllCloseAccordingToType( + backprops_wrt_max, + expected_backprops_wrt_max, + rtol=1e-3, + atol=1e-5, + bfloat16_rtol=0.03) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/compiler/tests/fft_test.py b/tensorflow/compiler/tests/fft_test.py new file mode 100644 index 0000000000000000000000000000000000000000..afb5fa4bb4fefe5bc2ecded826143ffc83c2b559 --- /dev/null +++ b/tensorflow/compiler/tests/fft_test.py @@ -0,0 +1,204 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for FFT via the XLA JIT.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools + +import numpy as np +import scipy.signal as sps + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.contrib.signal.python.ops import spectral_ops as signal +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import spectral_ops +from tensorflow.python.platform import googletest + +BATCH_DIMS = (3, 5) +RTOL = 0.02 # Eigen/cuFFT differ widely from np, especially for FFT3D +ATOL = 1e-3 + + +def pick_10(x): + x = list(x) + np.random.seed(123) + np.random.shuffle(x) + return x[:10] + + +def to_32bit(x): + if x.dtype == np.complex128: + return x.astype(np.complex64) + if x.dtype == np.float64: + return x.astype(np.float32) + return x + + +POWS_OF_2 = 2**np.arange(3, 12) +INNER_DIMS_1D = list((x,) for x in POWS_OF_2) +POWS_OF_2 = 2**np.arange(3, 8) # To avoid OOM on GPU. 
+INNER_DIMS_2D = pick_10(itertools.product(POWS_OF_2, POWS_OF_2)) +INNER_DIMS_3D = pick_10(itertools.product(POWS_OF_2, POWS_OF_2, POWS_OF_2)) + + +class FFTTest(XLATestCase): + + def _VerifyFftMethod(self, inner_dims, complex_to_input, input_to_expected, + tf_method): + for indims in inner_dims: + print("nfft =", indims) + shape = BATCH_DIMS + indims + data = np.arange(np.prod(shape) * 2) / np.prod(indims) + np.random.seed(123) + np.random.shuffle(data) + data = np.reshape(data.astype(np.float32).view(np.complex64), shape) + data = to_32bit(complex_to_input(data)) + expected = to_32bit(input_to_expected(data)) + with self.test_session() as sess: + with self.test_scope(): + ph = array_ops.placeholder( + dtypes.as_dtype(data.dtype), shape=data.shape) + out = tf_method(ph) + value = sess.run(out, {ph: data}) + self.assertAllClose(expected, value, rtol=RTOL, atol=ATOL) + + def testContribSignalSTFT(self): + ws = 512 + hs = 128 + dims = (ws * 20,) + shape = BATCH_DIMS + dims + data = np.arange(np.prod(shape)) / np.prod(dims) + np.random.seed(123) + np.random.shuffle(data) + data = np.reshape(data.astype(np.float32), shape) + window = sps.get_window("hann", ws) + expected = sps.stft( + data, nperseg=ws, noverlap=ws - hs, boundary=None, window=window)[2] + expected = np.swapaxes(expected, -1, -2) + expected *= window.sum() # scipy divides by window sum + with self.test_session() as sess: + with self.test_scope(): + ph = array_ops.placeholder( + dtypes.as_dtype(data.dtype), shape=data.shape) + out = signal.stft(ph, ws, hs) + + value = sess.run(out, {ph: data}) + self.assertAllClose(expected, value, rtol=RTOL, atol=ATOL) + + def testFFT(self): + self._VerifyFftMethod(INNER_DIMS_1D, lambda x: x, np.fft.fft, + spectral_ops.fft) + + def testFFT2D(self): + self._VerifyFftMethod(INNER_DIMS_2D, lambda x: x, np.fft.fft2, + spectral_ops.fft2d) + + def testFFT3D(self): + self._VerifyFftMethod(INNER_DIMS_3D, lambda x: x, + lambda x: np.fft.fftn(x, axes=(-3, -2, -1)), + 
spectral_ops.fft3d) + + def testIFFT(self): + self._VerifyFftMethod(INNER_DIMS_1D, lambda x: x, np.fft.ifft, + spectral_ops.ifft) + + def testIFFT2D(self): + self._VerifyFftMethod(INNER_DIMS_2D, lambda x: x, np.fft.ifft2, + spectral_ops.ifft2d) + + def testIFFT3D(self): + self._VerifyFftMethod(INNER_DIMS_3D, lambda x: x, + lambda x: np.fft.ifftn(x, axes=(-3, -2, -1)), + spectral_ops.ifft3d) + + def testRFFT(self): + self._VerifyFftMethod( + INNER_DIMS_1D, np.real, lambda x: np.fft.rfft(x, n=x.shape[-1]), + lambda x: spectral_ops.rfft(x, fft_length=[x.shape[-1].value])) + + def testRFFT2D(self): + + def _tf_fn(x): + return spectral_ops.rfft2d( + x, fft_length=[x.shape[-2].value, x.shape[-1].value]) + + self._VerifyFftMethod( + INNER_DIMS_2D, np.real, + lambda x: np.fft.rfft2(x, s=[x.shape[-2], x.shape[-1]]), _tf_fn) + + def testRFFT3D(self): + + def _to_expected(x): + return np.fft.rfftn( + x, axes=(-3, -2, -1), s=[x.shape[-3], x.shape[-2], x.shape[-1]]) + + def _tf_fn(x): + return spectral_ops.rfft3d( + x, + fft_length=[x.shape[-3].value, x.shape[-2].value, x.shape[-1].value]) + + self._VerifyFftMethod(INNER_DIMS_3D, np.real, _to_expected, _tf_fn) + + def testIRFFT(self): + + def _tf_fn(x): + return spectral_ops.irfft(x, fft_length=[2 * (x.shape[-1].value - 1)]) + + self._VerifyFftMethod( + INNER_DIMS_1D, lambda x: np.fft.rfft(np.real(x), n=x.shape[-1]), + lambda x: np.fft.irfft(x, n=2 * (x.shape[-1] - 1)), _tf_fn) + + def testIRFFT2D(self): + + def _tf_fn(x): + return spectral_ops.irfft2d( + x, fft_length=[x.shape[-2].value, 2 * (x.shape[-1].value - 1)]) + + self._VerifyFftMethod( + INNER_DIMS_2D, + lambda x: np.fft.rfft2(np.real(x), s=[x.shape[-2], x.shape[-1]]), + lambda x: np.fft.irfft2(x, s=[x.shape[-2], 2 * (x.shape[-1] - 1)]), + _tf_fn) + + def testIRFFT3D(self): + + def _to_input(x): + return np.fft.rfftn( + np.real(x), + axes=(-3, -2, -1), + s=[x.shape[-3], x.shape[-2], x.shape[-1]]) + + def _to_expected(x): + return np.fft.irfftn( + x, + axes=(-3, -2, 
-1), + s=[x.shape[-3], x.shape[-2], 2 * (x.shape[-1] - 1)]) + + def _tf_fn(x): + return spectral_ops.irfft3d( + x, + fft_length=[ + x.shape[-3].value, x.shape[-2].value, 2 * (x.shape[-1].value - 1) + ]) + + self._VerifyFftMethod(INNER_DIMS_3D, _to_input, _to_expected, _tf_fn) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/compiler/tests/ftrl_test.py b/tensorflow/compiler/tests/ftrl_test.py index 7e3871312c86530b6d3cb0bbacc16c25d3469832..f9db4cf2017c0b4b6dc0cfeeda6dca7bb9d14f19 100644 --- a/tensorflow/compiler/tests/ftrl_test.py +++ b/tensorflow/compiler/tests/ftrl_test.py @@ -161,9 +161,9 @@ class FtrlOptimizerTest(XLATestCase): ftrl_update.run() # Validate updated params - self.assertAllClose( + self.assertAllCloseAccordingToType( np.array([-2.55607247, -3.98729396]), var0.eval(), 1e-5, 1e-5) - self.assertAllClose( + self.assertAllCloseAccordingToType( np.array([-0.28232238, -0.56096673]), var1.eval(), 1e-5, 1e-5) def testFtrlWithL1(self): @@ -189,10 +189,10 @@ class FtrlOptimizerTest(XLATestCase): ftrl_update.run() # Validate updated params - self.assertAllClose(np.array([-7.66718769, -10.91273689]), var0.eval(), - rtol=1e-4) - self.assertAllClose(np.array([-0.93460727, -1.86147261]), var1.eval(), - rtol=1e-4) + self.assertAllCloseAccordingToType( + np.array([-7.66718769, -10.91273689]), var0.eval(), rtol=1e-4) + self.assertAllCloseAccordingToType( + np.array([-0.93460727, -1.86147261]), var1.eval(), rtol=1e-4) def testFtrlWithL1_L2(self): for dtype in self.float_types: @@ -217,10 +217,10 @@ class FtrlOptimizerTest(XLATestCase): ftrl_update.run() # Validate updated params - self.assertAllClose(np.array([-0.24059935, -0.46829352]), var0.eval(), - rtol=1e-5) - self.assertAllClose(np.array([-0.02406147, -0.04830509]), var1.eval(), - rtol=1e-5) + self.assertAllCloseAccordingToType( + np.array([-0.24059935, -0.46829352]), var0.eval(), rtol=1e-5) + self.assertAllCloseAccordingToType( + np.array([-0.02406147, -0.04830509]), var1.eval(), 
rtol=1e-5) def testFtrlWithL1_L2_L2Shrinkage(self): """Test the new FTRL op with support for l2 shrinkage. @@ -244,18 +244,18 @@ class FtrlOptimizerTest(XLATestCase): ftrl_update = opt.apply_gradients(zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values - self.assertAllClose([1.0, 2.0], var0.eval()) - self.assertAllClose([4.0, 3.0], var1.eval()) + self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval()) + self.assertAllCloseAccordingToType([4.0, 3.0], var1.eval()) # Run 10 steps FTRL for _ in range(10): ftrl_update.run() # Validate updated params - self.assertAllClose(np.array([-0.21931979, -0.40642974]), var0.eval(), - rtol=1e-4) - self.assertAllClose(np.array([-0.0282721, -0.07188385]), var1.eval(), - rtol=1e-4) + self.assertAllCloseAccordingToType( + np.array([-0.21931979, -0.40642974]), var0.eval(), rtol=1e-4) + self.assertAllCloseAccordingToType( + np.array([-0.0282721, -0.07188385]), var1.eval(), rtol=1e-4) # When variables are initialized with Zero, FTRL-Proximal has two properties: # 1. 
Without L1&L2 but with fixed learning rate, FTRL-Proximal is identical @@ -272,8 +272,8 @@ class FtrlOptimizerTest(XLATestCase): with self.test_session(), self.test_scope(): val2, val3 = self.equivAdagradTest_AdagradPart(steps, dtype) - self.assertAllClose(val0, val2, rtol=1e-4) - self.assertAllClose(val1, val3, rtol=1e-4) + self.assertAllCloseAccordingToType(val0, val2, rtol=1e-4) + self.assertAllCloseAccordingToType(val1, val3, rtol=1e-4) def testEquivGradientDescentwithoutRegularization(self): steps = 5 @@ -284,8 +284,8 @@ class FtrlOptimizerTest(XLATestCase): val2, val3 = self.equivGradientDescentTest_GradientDescentPart( steps, dtype) - self.assertAllClose(val0, val2, rtol=1e-5) - self.assertAllClose(val1, val3, rtol=1e-5) + self.assertAllCloseAccordingToType(val0, val2, rtol=1e-5) + self.assertAllCloseAccordingToType(val1, val3, rtol=1e-5) if __name__ == "__main__": diff --git a/tensorflow/compiler/tests/fused_batchnorm_test.py b/tensorflow/compiler/tests/fused_batchnorm_test.py index 00a9c9a65ba03d099581a3ee0dbe32c33e111231..a80d69fa5f5099b8a8b67df0da9c92b957e9d194 100644 --- a/tensorflow/compiler/tests/fused_batchnorm_test.py +++ b/tensorflow/compiler/tests/fused_batchnorm_test.py @@ -155,7 +155,7 @@ class FusedBatchNormTest(XLATestCase): def testLearningWithGradientChecker(self): self._testLearning(True) - def testGradient(self): + def testGradientTraining(self): # TODO(b/64270657): Use gradient_checker here in addition to comparing with # this reference implementation. 
channel = 3 @@ -175,7 +175,7 @@ class FusedBatchNormTest(XLATestCase): var = array_ops.placeholder(np.float32, shape=scale_shape, name="var") scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale") grad_x, grad_scale, grad_offset, _, _ = gen_nn_ops.fused_batch_norm_grad( - grad, x, scale, mean, var, data_format="NHWC") + grad, x, scale, mean, var, data_format="NHWC", is_training=True) grad_x_val, grad_scale_val, grad_offset_val = sess.run( [grad_x, grad_scale, grad_offset], { @@ -193,6 +193,53 @@ class FusedBatchNormTest(XLATestCase): self.assertAllClose(grad_scale_val, grad_scale_ref, atol=1e-2) self.assertAllClose(grad_offset_val, grad_offset_ref, atol=1e-3) + def testGradientInference(self): + # TODO(b/64270657): Use gradient_checker here in addition to comparing with + # this reference implementation. + channel = 3 + x_shape = [2, 2, 6, channel] + scale_shape = [channel] + grad_val = np.random.random_sample(x_shape).astype(np.float32) + x_val = np.random.random_sample(x_shape).astype(np.float32) + scale_val = np.random.random_sample(scale_shape).astype(np.float32) + mean_val = np.random.random_sample(scale_shape).astype(np.float32) + var_val = np.random.random_sample(scale_shape).astype(np.float32) + + with self.test_session() as sess, self.test_scope(): + grad = array_ops.placeholder(np.float32, shape=x_shape, name="grad") + x = array_ops.placeholder(np.float32, shape=x_shape, name="x") + mean = array_ops.placeholder(np.float32, shape=scale_shape, name="mean") + var = array_ops.placeholder(np.float32, shape=scale_shape, name="var") + scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale") + with self.test_scope(): + out = gen_nn_ops.fused_batch_norm_grad( + grad, x, scale, mean, var, data_format="NHWC", is_training=False) + grad_x, grad_scale, grad_offset, _, _ = out + + ref_x, ref_scale, ref_offset, _, _ = gen_nn_ops.fused_batch_norm_grad( + grad, x, scale, mean, var, data_format="NHWC", is_training=False) + + grad_x_val, 
grad_scale_val, grad_offset_val, = sess.run( + [grad_x, grad_scale, grad_offset], { + grad: grad_val, + x: x_val, + mean: mean_val, + var: var_val, + scale: scale_val + }) + grad_x_ref, grad_scale_ref, grad_offset_ref, = sess.run( + [ref_x, ref_scale, ref_offset], { + grad: grad_val, + x: x_val, + mean: mean_val, + var: var_val, + scale: scale_val + }) + + self.assertAllClose(grad_x_val, grad_x_ref, atol=1e-2) + self.assertAllClose(grad_scale_val, grad_scale_ref, atol=1e-2) + self.assertAllClose(grad_offset_val, grad_offset_ref, atol=1e-3) + if __name__ == "__main__": test.main() diff --git a/tensorflow/compiler/tests/gather_nd_op_test.py b/tensorflow/compiler/tests/gather_nd_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9378b1db7245c0da3e8298e7dcd972491616b0cd --- /dev/null +++ b/tensorflow/compiler/tests/gather_nd_op_test.py @@ -0,0 +1,147 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for tensorflow.ops.tf.gather_nd.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import errors +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class GatherNdTest(XLATestCase): + + def _runGather(self, params, indices): + with self.test_session(): + paramsp = array_ops.placeholder(params.dtype) + indicesp = array_ops.placeholder(indices.dtype) + with self.test_scope(): + gather_nd_t = array_ops.gather_nd(paramsp, indicesp) + feed_dict = {paramsp: params, indicesp: indices} + return gather_nd_t.eval(feed_dict=feed_dict) + + def testSimpleDtype(self): + for dtype in self.numeric_types: + self.assertAllEqual( + np.array([7, 7, 8], dtype=dtype), + self._runGather( + np.array([8, 1, 2, 3, 7, 5], dtype=dtype), + np.array([[4], [4], [0]], np.int32))) + + def testEmptyIndicesAndParamsOKButJustEmptyParamsFails(self): + with self.test_session(): + params = np.ones((3, 3), dtype=np.float32) + + indices_empty = np.empty((0, 2), dtype=np.int32) + gather_nd_ok_val = self._runGather(params, indices_empty) + self.assertAllClose(np.empty((0,), dtype=np.float32), gather_nd_ok_val) + + indices_empty = np.empty((0, 1), dtype=np.int32) + gather_nd_ok_val = self._runGather(params, indices_empty) + self.assertAllClose(np.empty((0, 3), dtype=np.float32), gather_nd_ok_val) + + params_empty = np.empty((0, 3), dtype=np.float32) + indices_empty = np.empty((0, 2), dtype=np.int32) + gather_nd_ok_val = self._runGather(params_empty, indices_empty) + self.assertAllClose(np.empty((0,), dtype=np.float32), gather_nd_ok_val) + + params_empty = np.empty((0, 3), dtype=np.float32) + indices_nonempty = np.zeros((1, 2), dtype=np.int32) + with self.assertRaisesWithPredicateMatch( + 
errors.InvalidArgumentError, r"Gather dimension 0 is of size zero"): + self._runGather(params_empty, indices_nonempty) + + def testIndexScalar(self): + params = np.array( + [[-8, -1, -2, -3, -7, -5], [8, 1, 2, 3, 7, 5]], dtype=np.float32).T + indices = np.array([4, 1], dtype=np.int32) + gather_nd_val = self._runGather(params, indices) + self.assertAllEqual(np.array(7), gather_nd_val) + + def testParamsRankLargerThanIndexIndexScalarSlices(self): + params = np.array( + [[-8, -1, -2, -3, -7, -5], [8, 1, 2, 3, 7, 5]], dtype=np.float32).T + indices = np.array( + [ + 4, + ], dtype=np.int32) + gather_nd_val = self._runGather(params, indices) + self.assertAllEqual(np.array([-7, 7]), gather_nd_val) + + def testParamsRankLargerThanIndexSlices(self): + params = np.array( + [[-8, -1, -2, -3, -7, -5], [8, 1, 2, 3, 7, 5]], dtype=np.float32).T + indices = np.array([[4], [4], [0]], np.int32) + gather_nd_val = self._runGather(params, indices) + self.assertAllEqual(np.array([[-7, 7], [-7, 7], [-8, 8]]), gather_nd_val) + + def testHigherRankParamsLargerThanIndexSlices(self): + params = np.array( + [[[-8, -1, -2, -3, -7, -5], [8, 1, 2, 3, 7, 5]], + [[-80, -10, -20, -30, -70, -50], [80, 10, 20, 30, 70, 50]]], + dtype=np.float32).T + indices = np.array([[4], [4], [0]], np.int32) + gather_nd_val = self._runGather(params, indices) + self.assertAllEqual(params[[4, 4, 0]], gather_nd_val) + + def testEmptyIndicesLastRankMeansCopyEntireTensor(self): + params = np.array( + [[[-8, -1, -2, -3, -7, -5], [8, 1, 2, 3, 7, 5]], + [[-80, -10, -20, -30, -70, -50], [80, 10, 20, 30, 70, 50]]], + dtype=np.float32).T + indices = np.array([[], []], dtype=np.int32) # Size (2, 0) + gather_nd_val = self._runGather(params, indices) + self.assertAllEqual( + np.vstack((params[np.newaxis, :], params[np.newaxis, :])), + gather_nd_val) + + def testHigherRankParamsAndIndicesLargerThanIndexSlices(self): + params = np.array( + [[[-8, -1, -2, -3, -7, -5], [8, 1, 2, 3, 7, 5]], + [[-80, -10, -20, -30, -70, -50], [80, 10, 
20, 30, 70, 50]]], + dtype=np.float32).T + indices = np.array([[[3], [2], [1]], [[4], [4], [0]]], np.int32) + gather_nd_val = self._runGather(params, indices) + self.assertAllEqual(params[[3, 2, 1, 4, 4, 0]].reshape(2, 3, 2, 2), + gather_nd_val) + + def testHigherRankParams(self): + shape = (10, 20, 5, 1, 17) + params = np.random.rand(*shape).astype(np.float32) + indices = np.vstack( + [np.random.randint(0, s, size=2000, dtype=np.int32) for s in shape]).T + gather_nd_val = self._runGather(params, indices) + + expected = params[tuple(indices.T)] + self.assertAllEqual(expected, gather_nd_val) + + def testHigherRankParamsAndIndices(self): + shape = (10, 20, 5, 1, 17) + params = np.random.rand(*shape).astype(np.float32) + indices = np.vstack( + [np.random.randint(0, s, size=2000, dtype=np.int32) for s in shape]).T + indices_reshaped = indices.reshape([10, 10, 20, 5]) + gather_nd_val = self._runGather(params, indices_reshaped) + expected = params[tuple(indices.T)] + self.assertAllEqual(expected.reshape([10, 10, 20]), gather_nd_val) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/gather_test.py b/tensorflow/compiler/tests/gather_test.py index 13cbe6f312f5175edaec28fa7a8f28064194b0e9..1a8c4519118f69ce51ca9a5eb95a9d706c7766cc 100644 --- a/tensorflow/compiler/tests/gather_test.py +++ b/tensorflow/compiler/tests/gather_test.py @@ -122,6 +122,20 @@ class GatherTest(xla_test.XLATestCase): gather_np = np.take(params, indices, axis=axis) self.assertAllEqual(gather_np, gather_value) + def testIndicesWithDifferentDimensions(self): + with self.test_session(): + for dtype in self.numeric_tf_types: + params = array_ops.placeholder(dtype=dtype) + indices = array_ops.placeholder(dtype=np.int32) + with self.test_scope(): + gather = array_ops.gather(params, indices) + self.assertAllEqual( + 7, gather.eval(feed_dict={params: [4, 7, 2], indices: 1})) + self.assertAllEqual( + [7], gather.eval(feed_dict={params: [4, 7, 2], indices: [1]})) + 
self.assertAllEqual( + [[7]], gather.eval(feed_dict={params: [4, 7, 2], indices: [[1]]})) + class GatherBenchmark(test.Benchmark): """Microbenchmarks for the gather op.""" diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..538fa8e8e570b83ed681ecc0501285520cabdecb --- /dev/null +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -0,0 +1,552 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for image ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import colorsys +import math + +import numpy as np + +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_image_ops +from tensorflow.python.ops import image_ops +from tensorflow.python.platform import test + + +class RGBToHSVTest(XLATestCase): + + def testBatch(self): + # Build an arbitrary RGB image + np.random.seed(7) + batch_size = 5 + shape = (batch_size, 2, 7, 3) + + for nptype in self.float_types: + inp = np.random.rand(*shape).astype(nptype) + + # Convert to HSV and back, as a batch and individually + with self.test_session() as sess: + batch0 = array_ops.placeholder(nptype, shape=shape) + with self.test_scope(): + batch1 = image_ops.rgb_to_hsv(batch0) + batch2 = image_ops.hsv_to_rgb(batch1) + split0 = array_ops.unstack(batch0) + with self.test_scope(): + split1 = list(map(image_ops.rgb_to_hsv, split0)) + split2 = list(map(image_ops.hsv_to_rgb, split1)) + join1 = array_ops.stack(split1) + join2 = array_ops.stack(split2) + batch1, batch2, join1, join2 = sess.run([batch1, batch2, join1, join2], + { + batch0: inp + }) + + # Verify that processing batch elements together is the same as separate + self.assertAllClose(batch1, join1) + self.assertAllClose(batch2, join2) + self.assertAllCloseAccordingToType(batch2, inp, bfloat16_atol=0.03) + + def testRGBToHSVRoundTrip(self): + data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + for nptype in self.float_types: + rgb_np = np.array(data, dtype=nptype).reshape([2, 2, 3]) / 255. 
+ with self.test_session(): + placeholder = array_ops.placeholder(nptype) + with self.test_scope(): + hsv = image_ops.rgb_to_hsv(placeholder) + rgb = image_ops.hsv_to_rgb(hsv) + rgb_tf = rgb.eval(feed_dict={placeholder: rgb_np}) + self.assertAllCloseAccordingToType(rgb_tf, rgb_np, bfloat16_atol=0.03) + + def testRGBToHSVNumpy(self): + """Tests the RGB to HSV conversion matches a reference implementation.""" + for nptype in self.float_types: + rgb_flat = np.random.random(64 * 3).reshape((64, 3)).astype(nptype) + rgb_np = rgb_flat.reshape(4, 4, 4, 3) + hsv_np = np.array([ + colorsys.rgb_to_hsv( + r.astype(np.float64), g.astype(np.float64), b.astype(np.float64)) + for r, g, b in rgb_flat + ]) + hsv_np = hsv_np.reshape(4, 4, 4, 3) + with self.test_session(): + placeholder = array_ops.placeholder(nptype) + with self.test_scope(): + hsv_op = image_ops.rgb_to_hsv(placeholder) + hsv_tf = hsv_op.eval(feed_dict={placeholder: rgb_np}) + self.assertAllCloseAccordingToType(hsv_tf, hsv_np) + + +class AdjustContrastTest(XLATestCase): + + def _testContrast(self, x_np, y_np, contrast_factor): + with self.test_session(): + x = array_ops.placeholder(x_np.dtype, shape=x_np.shape) + flt_x = image_ops.convert_image_dtype(x, dtypes.float32) + with self.test_scope(): + y = image_ops.adjust_contrast(flt_x, contrast_factor) + y = image_ops.convert_image_dtype(y, x.dtype, saturate=True) + y_tf = y.eval({x: x_np}) + self.assertAllClose(y_tf, y_np, 1e-6) + + def testFloatContrast(self): + x_shape = [1, 2, 2, 3] + x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + x_np = np.array(x_data, dtype=np.float32).reshape(x_shape) / 255. + + y_data = [ + -45.25, -90.75, -92.5, 62.75, 169.25, 333.5, 28.75, -84.75, 349.5, + 134.75, 409.25, -116.5 + ] + y_np = np.array(y_data, dtype=np.float32).reshape(x_shape) / 255. 
+ + self._testContrast(x_np, y_np, contrast_factor=2.0) + + def testBatchContrast(self): + x_shape = [2, 1, 2, 3] + x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) + + y_data = [0, 0, 0, 81, 200, 255, 10, 0, 255, 116, 255, 0] + y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape) + + self._testContrast(x_np, y_np, contrast_factor=2.0) + + def _adjustContrastNp(self, x_np, contrast_factor): + mean = np.mean(x_np, (1, 2), keepdims=True) + y_np = mean + contrast_factor * (x_np - mean) + return y_np + + def _adjustContrastTf(self, x_np, contrast_factor): + with self.test_session(): + x = array_ops.placeholder(np.float32) + with self.test_scope(): + y = image_ops.adjust_contrast(x, contrast_factor) + y_tf = y.eval({x: x_np}) + return y_tf + + def testRandomContrast(self): + x_shapes = [ + [1, 2, 2, 3], + [2, 1, 2, 3], + [1, 2, 2, 3], + [2, 5, 5, 3], + [2, 1, 1, 3], + ] + for x_shape in x_shapes: + x_np = np.random.rand(*x_shape) * 255. 
+ contrast_factor = np.random.rand() * 2.0 + 0.1 + y_np = self._adjustContrastNp(x_np, contrast_factor) + y_tf = self._adjustContrastTf(x_np, contrast_factor) + self.assertAllClose(y_tf, y_np, rtol=1e-5, atol=1e-5) + + +class AdjustHueTest(XLATestCase): + + def testAdjustNegativeHue(self): + x_shape = [2, 2, 3] + x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) + + delta = -0.25 + y_data = [0, 13, 1, 54, 226, 59, 8, 234, 150, 255, 39, 1] + y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape) + + with self.test_session(): + x = array_ops.placeholder(x_np.dtype, shape=x_shape) + flt_x = image_ops.convert_image_dtype(x, dtypes.float32) + with self.test_scope(): + y = gen_image_ops.adjust_hue(flt_x, delta) + y = image_ops.convert_image_dtype(y, x.dtype, saturate=True) + y_tf = y.eval({x: x_np}) + self.assertAllEqual(y_tf, y_np) + + def testAdjustPositiveHue(self): + x_shape = [2, 2, 3] + x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) + + delta = 0.25 + y_data = [13, 0, 11, 226, 54, 221, 234, 8, 92, 1, 217, 255] + y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape) + + with self.test_session(): + x = array_ops.placeholder(x_np.dtype, shape=x_shape) + flt_x = image_ops.convert_image_dtype(x, dtypes.float32) + with self.test_scope(): + y = gen_image_ops.adjust_hue(flt_x, delta) + y = image_ops.convert_image_dtype(y, x.dtype, saturate=True) + y_tf = y.eval({x: x_np}) + self.assertAllEqual(y_tf, y_np) + + def testBatchAdjustHue(self): + x_shape = [2, 1, 2, 3] + x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) + + delta = 0.25 + y_data = [13, 0, 11, 226, 54, 221, 234, 8, 92, 1, 217, 255] + y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape) + + with self.test_session(): + x = array_ops.placeholder(x_np.dtype, shape=x_shape) + flt_x = 
image_ops.convert_image_dtype(x, dtypes.float32) + with self.test_scope(): + y = gen_image_ops.adjust_hue(flt_x, delta) + y = image_ops.convert_image_dtype(y, x.dtype, saturate=True) + y_tf = y.eval({x: x_np}) + self.assertAllEqual(y_tf, y_np) + + def _adjustHueNp(self, x_np, delta_h): + self.assertEqual(x_np.shape[-1], 3) + x_v = x_np.reshape([-1, 3]) + y_v = np.ndarray(x_v.shape, dtype=x_v.dtype) + channel_count = x_v.shape[0] + for i in xrange(channel_count): + r = x_v[i][0] + g = x_v[i][1] + b = x_v[i][2] + h, s, v = colorsys.rgb_to_hsv(r, g, b) + h += delta_h + h = math.fmod(h + 10.0, 1.0) + r, g, b = colorsys.hsv_to_rgb(h, s, v) + y_v[i][0] = r + y_v[i][1] = g + y_v[i][2] = b + return y_v.reshape(x_np.shape) + + def _adjustHueTf(self, x_np, delta_h): + with self.test_session(): + x = array_ops.placeholder(dtypes.float32) + with self.test_scope(): + y = gen_image_ops.adjust_hue(x, delta_h) + y_tf = y.eval({x: x_np}) + return y_tf + + def testAdjustRandomHue(self): + x_shapes = [ + [2, 2, 3], + [4, 2, 3], + [2, 4, 3], + [2, 5, 3], + [1000, 1, 3], + ] + test_styles = [ + "all_random", + "rg_same", + "rb_same", + "gb_same", + "rgb_same", + ] + for x_shape in x_shapes: + for test_style in test_styles: + x_np = np.random.rand(*x_shape) * 255. + delta_h = np.random.rand() * 2.0 - 1.0 + if test_style == "all_random": + pass + elif test_style == "rg_same": + x_np[..., 1] = x_np[..., 0] + elif test_style == "rb_same": + x_np[..., 2] = x_np[..., 0] + elif test_style == "gb_same": + x_np[..., 2] = x_np[..., 1] + elif test_style == "rgb_same": + x_np[..., 1] = x_np[..., 0] + x_np[..., 2] = x_np[..., 0] + else: + raise AssertionError("Invalid test style: %s" % (test_style)) + y_np = self._adjustHueNp(x_np, delta_h) + y_tf = self._adjustHueTf(x_np, delta_h) + self.assertAllClose(y_tf, y_np, rtol=2e-5, atol=1e-4) + + def testInvalidShapes(self): + fused = False + if not fused: + # The tests are known to pass with the fused adjust_hue. 
We will enable + # them when the fused implementation is the default. + return + x_np = np.random.rand(2, 3) * 255. + delta_h = np.random.rand() * 2.0 - 1.0 + fused = False + with self.assertRaisesRegexp(ValueError, "Shape must be at least rank 3"): + self._adjustHueTf(x_np, delta_h) + x_np = np.random.rand(4, 2, 4) * 255. + delta_h = np.random.rand() * 2.0 - 1.0 + with self.assertRaisesOpError("input must have 3 channels"): + self._adjustHueTf(x_np, delta_h) + + +class AdjustSaturationTest(XLATestCase): + + def _adjust_saturation(self, image, saturation_factor): + image = ops.convert_to_tensor(image, name="image") + orig_dtype = image.dtype + flt_image = image_ops.convert_image_dtype(image, dtypes.float32) + with self.test_scope(): + saturation_adjusted_image = gen_image_ops.adjust_saturation( + flt_image, saturation_factor) + return image_ops.convert_image_dtype(saturation_adjusted_image, orig_dtype) + + def testHalfSaturation(self): + x_shape = [2, 2, 3] + x_rgb_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + x_np = np.array(x_rgb_data, dtype=np.uint8).reshape(x_shape) + + saturation_factor = 0.5 + y_rgb_data = [6, 9, 13, 140, 180, 226, 135, 121, 234, 172, 255, 128] + y_np = np.array(y_rgb_data, dtype=np.uint8).reshape(x_shape) + + with self.test_session(): + x = array_ops.placeholder(x_np.dtype, shape=x_shape) + y = self._adjust_saturation(x, saturation_factor) + y_tf = y.eval({x: x_np}) + self.assertAllEqual(y_tf, y_np) + + def testTwiceSaturation(self): + x_shape = [2, 2, 3] + x_data = [0, 5, 13, 54, 135, 226, 37, 8, 234, 90, 255, 1] + x_np = np.array(x_data, dtype=np.uint8).reshape(x_shape) + + saturation_factor = 2.0 + y_data = [0, 5, 13, 0, 106, 226, 30, 0, 234, 89, 255, 0] + y_np = np.array(y_data, dtype=np.uint8).reshape(x_shape) + + with self.test_session(): + x = array_ops.placeholder(x_np.dtype, shape=x_shape) + y = self._adjust_saturation(x, saturation_factor) + y_tf = y.eval({x: x_np}) + self.assertAllEqual(y_tf, y_np) + + def 
_adjustSaturationNp(self, x_np, scale): + self.assertEqual(x_np.shape[-1], 3) + x_v = x_np.reshape([-1, 3]) + y_v = np.ndarray(x_v.shape, dtype=x_v.dtype) + channel_count = x_v.shape[0] + for i in xrange(channel_count): + r = x_v[i][0] + g = x_v[i][1] + b = x_v[i][2] + h, s, v = colorsys.rgb_to_hsv(r, g, b) + s *= scale + s = min(1.0, max(0.0, s)) + r, g, b = colorsys.hsv_to_rgb(h, s, v) + y_v[i][0] = r + y_v[i][1] = g + y_v[i][2] = b + return y_v.reshape(x_np.shape) + + def testAdjustRandomSaturation(self): + x_shapes = [ + [2, 2, 3], + [4, 2, 3], + [2, 4, 3], + [2, 5, 3], + [1000, 1, 3], + ] + test_styles = [ + "all_random", + "rg_same", + "rb_same", + "gb_same", + "rgb_same", + ] + with self.test_session(): + for x_shape in x_shapes: + for test_style in test_styles: + x_np = np.random.rand(*x_shape) * 255. + scale = np.random.rand() + if test_style == "all_random": + pass + elif test_style == "rg_same": + x_np[..., 1] = x_np[..., 0] + elif test_style == "rb_same": + x_np[..., 2] = x_np[..., 0] + elif test_style == "gb_same": + x_np[..., 2] = x_np[..., 1] + elif test_style == "rgb_same": + x_np[..., 1] = x_np[..., 0] + x_np[..., 2] = x_np[..., 0] + else: + raise AssertionError("Invalid test style: %s" % (test_style)) + y_baseline = self._adjustSaturationNp(x_np, scale) + x = array_ops.placeholder(dtypes.float32, shape=x_shape) + with self.test_scope(): + y_fused = self._adjust_saturation(x, + scale).eval(feed_dict={ + x: x_np + }) + self.assertAllClose(y_fused, y_baseline, rtol=2e-5, atol=1e-5) + + +class ResizeBilinearTest(XLATestCase): + + def _assertForwardOpMatchesExpected(self, + image_np, + target_shape, + expected=None): + if expected is None: + self.fail("expected must be specified") + with self.test_session() as sess, self.test_scope(): + image = array_ops.placeholder(image_np.dtype) + resized = gen_image_ops.resize_bilinear( + image, target_shape, align_corners=True) + out = sess.run(resized, {image: image_np[np.newaxis, :, :, np.newaxis]}) + 
self.assertAllClose(expected[np.newaxis, :, :, np.newaxis], out) + + def _assertBackwardOpMatchesExpected(self, + grads_np, + input_shape=None, + dtype=None, + expected=None): + if input_shape is None: + self.fail("input_shape must be specified") + if expected is None: + self.fail("expected must be specified") + with self.test_session() as sess, self.test_scope(): + dtype = dtype or np.float32 + grads = array_ops.placeholder(np.float32) + resized = gen_image_ops._resize_bilinear_grad( + grads, + np.zeros([1, input_shape[0], input_shape[1], 1], dtype=dtype), + align_corners=True) + out = sess.run(resized, {grads: grads_np[np.newaxis, :, :, np.newaxis]}) + self.assertAllCloseAccordingToType(expected[np.newaxis, :, :, np.newaxis], + out) + + def testAlignCorners1x2To3x2(self): + for dtype in self.float_types: + self._assertForwardOpMatchesExpected( + np.array([[1, 2]], dtype=dtype), [3, 3], + expected=np.array( + [[1, 1.5, 2], [1, 1.5, 2], [1, 1.5, 2]], dtype=np.float32)) + + def testAlignCorners1x2To3x2Grad(self): + for dtype in self.float_types: + self._assertBackwardOpMatchesExpected( + np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32), + input_shape=[1, 2], + dtype=dtype, + expected=np.array([[9, 12]], dtype=np.float32)) + + def testAlignCorners2x2To1x1(self): + for dtype in self.float_types: + self._assertForwardOpMatchesExpected( + np.array([[1, 2], [3, 4]], dtype=dtype), [1, 1], + expected=np.array([[1]], dtype=np.float32)) + + def testAlignCorners2x2To1x1Grad(self): + for dtype in self.float_types: + self._assertBackwardOpMatchesExpected( + np.array([[7]], dtype=np.float32), + input_shape=[2, 2], + dtype=dtype, + expected=np.array([[7, 0], [0, 0]], dtype=np.float32)) + + def testAlignCorners2x2To3x3(self): + for dtype in self.float_types: + self._assertForwardOpMatchesExpected( + np.array([[1, 2], [3, 4]], dtype=dtype), [3, 3], + expected=np.array( + [[1, 1.5, 2], [2, 2.5, 3], [3, 3.5, 4]], dtype=np.float32)) + + def testAlignCorners2x2To3x3Grad(self): + 
self._assertBackwardOpMatchesExpected( + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32), + input_shape=[2, 2], + expected=np.array([[5.25, 8.25], [14.25, 17.25]], dtype=np.float32)) + + def testAlignCorners3x3To2x2(self): + for dtype in self.float_types: + self._assertForwardOpMatchesExpected( + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype), [2, 2], + expected=np.array([[1, 3], [7, 9]], dtype=np.float32)) + + def testAlignCorners3x3To2x2Grad(self): + for dtype in self.float_types: + self._assertBackwardOpMatchesExpected( + np.array([[7, 13], [22, 4]], dtype=np.float32), + input_shape=[3, 3], + dtype=dtype, + expected=np.array( + [[7, 0, 13], [0, 0, 0], [22, 0, 4]], dtype=np.float32)) + + def testAlignCorners4x4To3x3(self): + for dtype in self.float_types: + self._assertForwardOpMatchesExpected( + np.array( + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]], + dtype=dtype), [3, 3], + expected=np.array( + [[1, 2.5, 4], [7, 8.5, 10], [13, 14.5, 16]], dtype=np.float32)) + + def testAlignCorners4x4To3x3Grad(self): + for dtype in self.float_types: + self._assertBackwardOpMatchesExpected( + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.float32), + input_shape=[4, 4], + dtype=dtype, + expected=np.array( + [[1, 1, 1, 3], [2, 1.25, 1.25, 3], [2, 1.25, 1.25, 3], + [7, 4, 4, 9]], + dtype=np.float32)) + + def testAlignCorners3x3To9x9(self): + for dtype in self.float_types: + self._assertForwardOpMatchesExpected( + np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=dtype), [9, 9], + expected=np.array( + [[1.0, 1.25, 1.50, 1.75, 2.00, 2.25, 2.50, 2.75, 3.00], [ + 1.75, 2.00, 2.25, 2.50, 2.75, 3.00, 3.25, 3.50, 3.75 + ], [2.50, 2.75, 3.00, 3.25, 3.50, 3.75, 4.00, 4.25, 4.50], [ + 3.25, 3.50, 3.75, 4.00, 4.25, 4.50, 4.75, 5.00, 5.25 + ], [4.00, 4.25, 4.50, 4.75, 5.00, 5.25, 5.50, 5.75, 6.00], [ + 4.75, 5.00, 5.25, 5.50, 5.75, 6.00, 6.25, 6.50, 6.75 + ], [5.50, 5.75, 6.00, 6.25, 6.50, 6.75, 7.00, 7.25, 7.50], [ + 6.25, 6.50, 6.75, 7.00, 
7.25, 7.50, 7.75, 8.00, 8.25 + ], [7.00, 7.25, 7.50, 7.75, 8.00, 8.25, 8.50, 8.75, 9.00]], + dtype=np.float32)) + + def testAlignCorners3x3To9x9Grad(self): + for dtype in self.float_types: + self._assertBackwardOpMatchesExpected( + np.array( + [[1.00, 1.25, 1.50, 1.75, 2.00, 2.25, 2.50, 2.75, 3.00], [ + 1.75, 2.00, 2.25, 2.50, 2.75, 3.00, 3.25, 3.50, 3.75 + ], [2.50, 2.75, 3.00, 3.25, 3.50, 3.75, 4.00, 4.25, 4.50], [ + 3.25, 3.50, 3.75, 4.00, 4.25, 4.50, 4.75, 5.00, 5.25 + ], [4.00, 4.25, 4.50, 4.75, 5.00, 5.25, 5.50, 5.75, 6.00], [ + 4.75, 5.00, 5.25, 5.50, 5.75, 6.00, 6.25, 6.50, 6.75 + ], [5.50, 5.75, 6.00, 6.25, 6.50, 6.75, 7.00, 7.25, 7.50], [ + 6.25, 6.50, 6.75, 7.00, 7.25, 7.50, 7.75, 8.00, 8.25 + ], [7.00, 7.25, 7.50, 7.75, 8.00, 8.25, 8.50, 8.75, 9.00]], + dtype=np.float32), + input_shape=[3, 3], + dtype=dtype, + expected=np.array( + [[12.5, 27.5, 21.875], [42.5, 80.0, 57.5], [40.625, 72.5, 50]], + dtype=np.float32)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/matrix_band_part_test.py b/tensorflow/compiler/tests/matrix_band_part_test.py new file mode 100644 index 0000000000000000000000000000000000000000..29394f9ea5139b30f88f53de0469b27e37d79195 --- /dev/null +++ b/tensorflow/compiler/tests/matrix_band_part_test.py @@ -0,0 +1,64 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class MatrixBandPartTest(XLATestCase): + + def _testMatrixBandPart(self, dtype, shape): + with self.test_session(): + batch_shape = shape[:-2] + mat = np.ones(shape).astype(dtype) + batch_mat = np.tile(mat, batch_shape + [1, 1]) + for lower in -1, 0, 1, shape[-2] - 1: + for upper in -1, 0, 1, shape[-1] - 1: + band_np = mat + if lower >= 0: + band_np = np.triu(band_np, -lower) + if upper >= 0: + band_np = np.tril(band_np, upper) + if batch_shape: + band_np = np.tile(band_np, batch_shape + [1, 1]) + + placeholder = array_ops.placeholder(dtype) + with self.test_scope(): + band = array_ops.matrix_band_part( + placeholder, + constant_op.constant(lower, dtype=dtypes.int32), + constant_op.constant(upper, dtype=dtypes.int32)) + feed_dict = {placeholder: batch_mat} + self.assertAllEqual(band_np, band.eval(feed_dict=feed_dict)) + + def testMatrixBandPart(self): + for dtype in self.float_types: + for batch_shape in [[], [2,], [1, 3, 2]]: + for rows in 1, 2, 7: + for cols in 1, 2, 7: + self._testMatrixBandPart(dtype, batch_shape + [rows, cols]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..cccb7f5789dce39ef8c3d4b3a7573aaa983b3fbd --- /dev/null +++ b/tensorflow/compiler/tests/matrix_triangular_solve_op_test.py @@ -0,0 +1,130 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tensorflow.ops.tf.MatrixTriangularSolve.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +def MakePlaceholder(x): + return array_ops.placeholder(dtypes.as_dtype(x.dtype), shape=x.shape) + + +class MatrixTriangularSolveOpTest(XLATestCase): + + def _VerifyTriangularSolveBase(self, sess, placeholder_a, placeholder_ca, + placeholder_b, a, clean_a, b, verification, + atol): + feed_dict = {placeholder_a: a, placeholder_ca: clean_a, placeholder_b: b} + verification_np = sess.run(verification, feed_dict) + self.assertAllClose(b, verification_np, atol=atol) + + def _VerifyTriangularSolve(self, a, b, lower, adjoint, atol): + clean_a = np.tril(a) if lower else np.triu(a) + with self.test_session() as sess: + placeholder_a = MakePlaceholder(a) + placeholder_ca = MakePlaceholder(clean_a) + placeholder_b = MakePlaceholder(b) + with self.test_scope(): + x = linalg_ops.matrix_triangular_solve( 
+ placeholder_a, placeholder_b, lower=lower, adjoint=adjoint) + verification = math_ops.matmul(placeholder_ca, x, adjoint_a=adjoint) + self._VerifyTriangularSolveBase(sess, placeholder_a, placeholder_ca, + placeholder_b, a, clean_a, b, + verification, atol) + + def _VerifyTriangularSolveCombo(self, a, b, atol=1e-4): + transp = lambda x: np.swapaxes(x, -1, -2) + for lower, adjoint in itertools.product([True, False], repeat=2): + self._VerifyTriangularSolve( + a if lower else transp(a), b, lower, adjoint, atol) + + def testBasic(self): + rng = np.random.RandomState(0) + a = np.tril(rng.randn(5, 5)) + b = rng.randn(5, 7) + for dtype in self.float_types: + self._VerifyTriangularSolveCombo(a.astype(dtype), b.astype(dtype)) + + def testBasicNotActuallyTriangular(self): + rng = np.random.RandomState(0) + a = rng.randn(5, 5) # the `a` matrix is not lower-triangular + b = rng.randn(5, 7) + for dtype in self.float_types: + self._VerifyTriangularSolveCombo(a.astype(dtype), b.astype(dtype)) + + def testBasicComplexDtypes(self): + rng = np.random.RandomState(0) + a = np.tril(rng.randn(5, 5) + rng.randn(5, 5) * 1j) + b = rng.randn(5, 7) + rng.randn(5, 7) * 1j + for dtype in self.complex_types: + self._VerifyTriangularSolveCombo(a.astype(dtype), b.astype(dtype)) + + def testBatch(self): + rng = np.random.RandomState(0) + shapes = [((4, 3, 3), (4, 3, 5)), ((1, 2, 2), (1, 2, 1)), + ((1, 1, 1), (1, 1, 2)), ((2, 3, 4, 4), (2, 3, 4, 1))] + tuples = itertools.product(self.float_types, shapes) + for dtype, (a_shape, b_shape) in tuples: + n = a_shape[-1] + a = np.tril(rng.rand(*a_shape) - 0.5) / (2.0 * n) + np.eye(n) + b = rng.randn(*b_shape) + self._VerifyTriangularSolveCombo( + a.astype(dtype), b.astype(dtype), atol=1e-3) + + def testLarge(self): + n = 1024 + rng = np.random.RandomState(0) + a = np.tril(rng.rand(n, n) - 0.5) / (2.0 * n) + np.eye(n) + b = rng.randn(n, n) + self._VerifyTriangularSolve( + a.astype(np.float32), b.astype(np.float32), True, False, 1e-4) + + def 
testNonSquareCoefficientMatrix(self): + rng = np.random.RandomState(0) + for dtype in self.float_types: + a = rng.randn(3, 4).astype(dtype) + b = rng.randn(4, 4).astype(dtype) + with self.assertRaises(ValueError): + linalg_ops.matrix_triangular_solve(a, b) + with self.assertRaises(ValueError): + linalg_ops.matrix_triangular_solve(a, b) + + def testWrongDimensions(self): + randn = np.random.RandomState(0).randn + for dtype in self.float_types: + lhs = constant_op.constant(randn(3, 3), dtype=dtype) + rhs = constant_op.constant(randn(4, 3), dtype=dtype) + with self.assertRaises(ValueError): + linalg_ops.matrix_triangular_solve(lhs, rhs) + with self.assertRaises(ValueError): + linalg_ops.matrix_triangular_solve(lhs, rhs) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/momentum_test.py b/tensorflow/compiler/tests/momentum_test.py index c00e3035a0982b2b2e59eb6f53499918515ae71d..af9394e7d7dc9cf7dd009420ff9c845aec8785bd 100644 --- a/tensorflow/compiler/tests/momentum_test.py +++ b/tensorflow/compiler/tests/momentum_test.py @@ -96,28 +96,27 @@ class MomentumOptimizerTest(XLATestCase): def testNesterovMomentum(self): for dtype in self.float_types: with self.test_session(), self.test_scope(): - var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype) - var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype) - var0_np = np.array([1.0, 2.0], dtype=dtype) - var1_np = np.array([3.0, 4.0], dtype=dtype) + var0 = resource_variable_ops.ResourceVariable([0.1, 0.2], dtype=dtype) + var1 = resource_variable_ops.ResourceVariable([0.3, 0.4], dtype=dtype) + var0_np = np.array([0.1, 0.2], dtype=dtype) + var1_np = np.array([0.3, 0.4], dtype=dtype) accum0_np = np.array([0.0, 0.0], dtype=dtype) accum1_np = np.array([0.0, 0.0], dtype=dtype) - cost = 5 * var0 * var0 + 3 * var1 + cost = 0.4 * var0 * var0 + 0.9 * var1 global_step = resource_variable_ops.ResourceVariable( array_ops.zeros([], dtypes.int32), name="global_step") mom_op = 
momentum_lib.MomentumOptimizer( - learning_rate=2.0, momentum=0.9, use_nesterov=True) + learning_rate=0.1, momentum=0.9, use_nesterov=True) opt_op = mom_op.minimize(cost, global_step, [var0, var1]) variables.global_variables_initializer().run() for _ in range(1, 5): opt_op.run() var0_np, accum0_np = self._update_nesterov_momentum_numpy( - var0_np, accum0_np, var0_np * 10, 2.0, 0.9) - var1_np, accum1_np = self._update_nesterov_momentum_numpy(var1_np, - accum1_np, - 3, 2.0, 0.9) - self.assertAllClose(var0_np, var0.eval()) - self.assertAllClose(var1_np, var1.eval()) + var0_np, accum0_np, var0_np * 0.8, 0.1, 0.9) + var1_np, accum1_np = self._update_nesterov_momentum_numpy( + var1_np, accum1_np, 0.9, 0.1, 0.9) + self.assertAllCloseAccordingToType(var0_np, var0.eval()) + self.assertAllCloseAccordingToType(var1_np, var1.eval()) def testTensorLearningRateAndMomentum(self): for dtype in self.float_types: diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index 6a8c3bcd55a6e454a19b6249cf4eb48739c8657f..e72dd4eea9f127e1df96ab166103c4c16372adb6 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -93,11 +93,11 @@ class OpTestBuilder { public: explicit OpTestBuilder(const string& op_name); - // Adds an input 'tensor'. + // Adds an input 'tensor' as a Placeholder node. OpTestBuilder& Input(const Tensor& tensor); - // Adds a random input tensor with 'type'. If 'dims' is not provided, - // RandomDims() is used. + // Adds a random input tensor with 'type' as a Placeholder node. + // If 'dims' is not provided, RandomDims() is used. 
OpTestBuilder& RandomInput(DataType type); OpTestBuilder& RandomInput(DataType type, std::vector dims); @@ -998,6 +998,13 @@ TEST_F(OpTest, Atanh) { }); } +TEST_F(OpTest, Atan) { + Repeatedly([this]() { + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Atan").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + }); +} + TEST_F(OpTest, Atan2) { Repeatedly([this]() { auto dims = BroadcastableDims(); @@ -1368,6 +1375,121 @@ TEST_F(OpTest, Conj) { }); } +TEST_F(OpTest, FFT) { + Repeatedly([this]() { + std::vector dims = RandomDims(1, kDefaultMaxRank); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("FFT").RandomInput(DT_COMPLEX64, dims)); + }); +} + +TEST_F(OpTest, FFT2D) { + Repeatedly([this]() { + std::vector dims = RandomDims(2, kDefaultMaxRank); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("FFT2D").RandomInput(DT_COMPLEX64, dims)); + }); +} + +TEST_F(OpTest, FFT3D) { + Repeatedly([this]() { + std::vector dims = RandomDims(3, kDefaultMaxRank); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("FFT3D").RandomInput(DT_COMPLEX64, dims)); + }); +} + +TEST_F(OpTest, IFFT) { + Repeatedly([this]() { + std::vector dims = RandomDims(1, kDefaultMaxRank); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("IFFT").RandomInput(DT_COMPLEX64, dims)); + }); +} + +TEST_F(OpTest, IFFT2D) { + Repeatedly([this]() { + std::vector dims = RandomDims(2, kDefaultMaxRank); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("IFFT2D").RandomInput(DT_COMPLEX64, dims)); + }); +} + +TEST_F(OpTest, IFFT3D) { + Repeatedly([this]() { + std::vector dims = RandomDims(3, kDefaultMaxRank); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("IFFT3D").RandomInput(DT_COMPLEX64, dims)); + }); +} + +TEST_F(OpTest, RFFT) { + Repeatedly([this]() { + std::vector dims = RandomDims(1, kDefaultMaxRank, 3); + Tensor fft_shape = test::AsTensor(AsInt32s({dims[dims.size() - 1]})); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("RFFT").RandomInput(DT_FLOAT, 
dims).Input(fft_shape)); + }); +} + +TEST_F(OpTest, RFFT2D) { + Repeatedly([this]() { + std::vector dims = RandomDims(2, kDefaultMaxRank, 3); + Tensor fft_shape = test::AsTensor( + AsInt32s({dims[dims.size() - 2], dims[dims.size() - 1]})); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("RFFT2D").RandomInput(DT_FLOAT, dims).Input(fft_shape)); + }); +} + +TEST_F(OpTest, RFFT3D) { + Repeatedly([this]() { + std::vector dims = RandomDims(3, kDefaultMaxRank, 3); + Tensor fft_shape = test::AsTensor(AsInt32s( + {dims[dims.size() - 3], dims[dims.size() - 2], dims[dims.size() - 1]})); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("RFFT3D").RandomInput(DT_FLOAT, dims).Input(fft_shape)); + }); +} + +TEST_F(OpTest, IRFFT) { + Repeatedly([this]() { + std::vector dims = RandomDims(1, kDefaultMaxRank, 3); + int64 orig_size = dims[dims.size() - 1]; + dims[dims.size() - 1] = dims[dims.size() - 1] / 2 + 1; + Tensor fft_shape = test::AsTensor(AsInt32s({orig_size})); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("IRFFT") + .RandomInput(DT_COMPLEX64, dims) + .Input(fft_shape)); + }); +} + +TEST_F(OpTest, IRFFT2D) { + Repeatedly([this]() { + std::vector dims = RandomDims(2, kDefaultMaxRank, 3); + std::vector orig_size = {dims[dims.size() - 2], + dims[dims.size() - 1]}; + dims[dims.size() - 1] = dims[dims.size() - 1] / 2 + 1; + Tensor fft_shape = test::AsTensor(AsInt32s({orig_size})); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("IRFFT2D") + .RandomInput(DT_COMPLEX64, dims) + .Input(fft_shape)); + }); +} + +TEST_F(OpTest, IRFFT3D) { + Repeatedly([this]() { + std::vector dims = RandomDims(3, kDefaultMaxRank, 3); + std::vector orig_size = { + dims[dims.size() - 3], dims[dims.size() - 2], dims[dims.size() - 1]}; + dims[dims.size() - 1] = dims[dims.size() - 1] / 2 + 1; + Tensor fft_shape = test::AsTensor(AsInt32s({orig_size})); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("IRFFT3D") + .RandomInput(DT_COMPLEX64, dims) + .Input(fft_shape)); + }); +} + 
TEST_F(OpTest, Conv2D) { Repeatedly([this]() { WindowedSpatialDims d = ChooseWindowedSpatialDims(2); @@ -1382,7 +1504,7 @@ TEST_F(OpTest, Conv2D) { std::vector kernel_dims = {d.kernel_dims[0], d.kernel_dims[1], features_in, features_out}; - DataType type = DT_FLOAT; // TODO(b/65408531): COMPLEX_64 support + DataType type = DT_FLOAT; return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Conv2D") .RandomInput(type, data_dims) @@ -1407,7 +1529,7 @@ TEST_F(OpTest, Conv2DBackpropFilter) { ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims); Tensor kernel_shape = test::AsTensor(AsInt32s( {d.kernel_dims[0], d.kernel_dims[1], features_in, features_out})); - DataType type = DT_FLOAT; // TODO(b/65408531): COMPLEX_64 support + DataType type = DT_FLOAT; return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Conv2DBackpropFilter") .RandomInput(type, activations) @@ -1433,7 +1555,7 @@ TEST_F(OpTest, Conv2DBackpropInput) { ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims); std::vector kernel = {d.kernel_dims[0], d.kernel_dims[1], features_in, features_out}; - DataType type = DT_FLOAT; // TODO(b/65408531): COMPLEX_64 support + DataType type = DT_FLOAT; return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Conv2DBackpropInput") .Input(in_shape) @@ -1457,7 +1579,7 @@ TEST_F(OpTest, Conv3D) { std::vector kernel = {d.kernel_dims[0], d.kernel_dims[1], d.kernel_dims[2], features_in, features_out}; - DataType type = DT_FLOAT; // TODO(b/65408531): COMPLEX_64 support + DataType type = DT_FLOAT; return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Conv3D") .RandomInput(type, data) @@ -1482,7 +1604,7 @@ TEST_F(OpTest, Conv3DBackpropFilter) { Tensor kernel_shape = test::AsTensor( AsInt32s({d.kernel_dims[0], d.kernel_dims[1], d.kernel_dims[2], features_in, features_out})); - DataType type = DT_FLOAT; // TODO(b/65408531): COMPLEX_64 support + DataType type = DT_FLOAT; return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Conv3DBackpropFilterV2") .RandomInput(type, activations) @@ 
-2460,6 +2582,36 @@ TEST_F(OpTest, Reshape) { }); } +TEST_F(OpTest, ResizeBilinear) { + Repeatedly([this]() { + std::vector in_dims = RandomDims(4, 4); + std::vector out_dims = RandomDims(2, 2); + + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("ResizeBilinear") + .RandomInput(DT_FLOAT, in_dims) + .Input(test::AsTensor( + std::vector(out_dims.begin(), out_dims.end()))) + .Attr("T", DT_FLOAT) + .Attr("align_corners", true)); + }); +} + +TEST_F(OpTest, ResizeBilinearGrad) { + Repeatedly([this]() { + std::vector in_dims = RandomDims(4, 4); + std::vector out_dims = RandomDims(2, 2); + + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("ResizeBilinearGrad") + .RandomInput(DT_FLOAT, in_dims) + .RandomInput(DT_FLOAT, + {in_dims[0], out_dims[0], out_dims[1], in_dims[3]}) + .Attr("T", DT_FLOAT) + .Attr("align_corners", true)); + }); +} + TEST_F(OpTest, Reverse) { Repeatedly([this]() { std::vector dims = RandomDims(1); diff --git a/tensorflow/compiler/tests/reverse_sequence_op_test.py b/tensorflow/compiler/tests/reverse_sequence_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1a5d05094e53cfecd9476d7d87f023e8a02d7458 --- /dev/null +++ b/tensorflow/compiler/tests/reverse_sequence_op_test.py @@ -0,0 +1,93 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for tensorflow.ops.reverse_sequence_op.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class ReverseSequenceTest(XLATestCase): + + def _testReverseSequence(self, + x, + batch_axis, + seq_axis, + seq_lengths, + truth, + expected_err_re=None): + with self.test_session(): + p = array_ops.placeholder(dtypes.as_dtype(x.dtype)) + lengths = array_ops.placeholder(dtypes.as_dtype(seq_lengths.dtype)) + with self.test_scope(): + ans = array_ops.reverse_sequence( + p, batch_axis=batch_axis, seq_axis=seq_axis, seq_lengths=lengths) + if expected_err_re is None: + tf_ans = ans.eval(feed_dict={p: x, lengths: seq_lengths}) + self.assertAllClose(tf_ans, truth, atol=1e-10) + else: + with self.assertRaisesOpError(expected_err_re): + ans.eval(feed_dict={p: x, lengths: seq_lengths}) + + def testSimple(self): + x = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=np.int32) + expected = np.array([[1, 2, 3], [6, 5, 4], [8, 7, 9]], dtype=np.int32) + self._testReverseSequence( + x, + batch_axis=0, + seq_axis=1, + seq_lengths=np.array([1, 3, 2], np.int32), + truth=expected) + + def _testBasic(self, dtype, len_dtype): + x = np.asarray( + [[[1, 2, 3, 4], [5, 6, 7, 8]], [[9, 10, 11, 12], [13, 14, 15, 16]], + [[17, 18, 19, 20], [21, 22, 23, 24]]], + dtype=dtype) + x = x.reshape(3, 2, 4, 1, 1) + x = x.transpose([2, 1, 0, 3, 4]) # permute axes 0 <=> 2 + + # reverse dim 2 up to (0:3, none, 0:4) along dim=0 + seq_lengths = np.asarray([3, 0, 4], dtype=len_dtype) + + truth_orig = np.asarray( + [ + [[3, 2, 1, 4], [7, 6, 5, 8]], # reverse 0:3 + [[9, 10, 11, 12], [13, 14, 15, 16]], # reverse none + [[20, 19, 18, 17], [24, 
23, 22, 21]] + ], # reverse 0:4 (all) + dtype=dtype) + truth_orig = truth_orig.reshape(3, 2, 4, 1, 1) + truth = truth_orig.transpose([2, 1, 0, 3, 4]) # permute axes 0 <=> 2 + + seq_axis = 0 # permute seq_axis and batch_axis (originally 2 and 0, resp.) + batch_axis = 2 + self._testReverseSequence(x, batch_axis, seq_axis, seq_lengths, truth) + + def testSeqLength(self): + for dtype in self.all_types: + for seq_dtype in self.int_types: + self._testBasic(dtype, seq_dtype) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/scan_ops_test.py b/tensorflow/compiler/tests/scan_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3260e63b23226d736a7ddc0f21a94a8c791e0442 --- /dev/null +++ b/tensorflow/compiler/tests/scan_ops_test.py @@ -0,0 +1,229 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Functional tests for scan ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import errors_impl +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +def numpy_reverse(x, axis): + length = len(x.shape) + if axis < 0: + axis = length + axis + + ix = [ + slice(None, None, -1) if i == axis else slice(None) for i in range(length) + ] + return x[ix] + + +def handle_options(func, x, axis, exclusive, reverse): + """Adds tf options to numpy scan ops.""" + length = len(x.shape) + if axis < 0: + axis = length + axis + + if reverse: + x = numpy_reverse(x, axis) + + if exclusive: + ix_head = [slice(0, 1) if i == axis else slice(None) for i in range(length)] + ix_init = [ + slice(0, -1) if i == axis else slice(None) for i in range(length) + ] + if func == np.cumsum: + init = np.zeros_like(x[ix_head]) + elif func == np.cumprod: + init = np.ones_like(x[ix_head]) + else: + raise ValueError("Unknown scan function.") + x = np.concatenate([init, func(x[ix_init], axis)], axis=axis) + else: + x = func(x, axis=axis) + + if reverse: + x = numpy_reverse(x, axis) + return x + + +class CumsumTest(XLATestCase): + + valid_dtypes = [np.float32] + + def axis_dtypes(self): + return set(self.int_types).intersection([np.int32, np.int64]) + + def _compare(self, x, axis, exclusive, reverse): + np_out = handle_options(np.cumsum, x, axis, exclusive, reverse) + with self.test_session(), self.test_scope(): + p = array_ops.placeholder(x.dtype) + tf_out = math_ops.cumsum(p, axis, exclusive, reverse).eval( + feed_dict={p: x}) + + self.assertAllClose(np_out, 
tf_out) + + def _compareAll(self, x, axis): + for exclusive in [True, False]: + for reverse in [True, False]: + self._compare(x, axis, exclusive, reverse) + + def testEmpty(self): + for dtype in self.valid_dtypes: + x = np.zeros([0]).astype(dtype) + for axis in (-1, 0): + self._compareAll(x, axis) + + def testAxisType(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 6).reshape([5]).astype(dtype) + for axis_dtype in self.axis_dtypes(): + with self.test_session(), self.test_scope(): + p = array_ops.placeholder(x.dtype) + axis = constant_op.constant(0, axis_dtype) + math_ops.cumsum(p, axis).eval(feed_dict={p: x}) + + def test1D(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 6).reshape([5]).astype(dtype) + for axis in (-1, 0): + self._compareAll(x, axis) + + def test2D(self): + for dtype in self.valid_dtypes: + x = np.arange(0, 10).reshape([2, 5]).astype(dtype) + for axis in (-2, -1, 0, 1): + self._compareAll(x, axis) + + def test3D(self): + for dtype in self.valid_dtypes: + x = np.arange(0, 20).reshape([2, 2, 5]).astype(dtype) + for axis in (-3, -2, -1, 0, 1, 2): + self._compareAll(x, axis) + + def test6D(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 145).reshape([2, 2, 3, 3, 2, 2]).astype(dtype) + for axis in range(-6, 6, 3): + self._compareAll(x, axis) + + def testInvalidAxis(self): + x = np.arange(0, 10).reshape([2, 5]).astype(np.float32) + with self.test_session(), self.test_scope(): + input_tensor = ops.convert_to_tensor(x) + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + lambda e: "Expected scan axis in the range [-2, 2)" in str(e)): + math_ops.cumsum(input_tensor, -3).eval() + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + lambda e: "Expected scan axis in the range [-2, 2)" in str(e)): + math_ops.cumsum(input_tensor, 2).eval() + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + lambda e: "axis must be a scalar" in str(e)): + 
math_ops.cumsum(input_tensor, [0]).eval() + + +class CumprodTest(XLATestCase): + + valid_dtypes = [np.float32] + + def axis_dtypes(self): + return set(self.int_types).intersection([np.int32, np.int64]) + + def _compare(self, x, axis, exclusive, reverse): + np_out = handle_options(np.cumprod, x, axis, exclusive, reverse) + with self.test_session(), self.test_scope(): + p = array_ops.placeholder(x.dtype) + prod = math_ops.cumprod(p, axis, exclusive, reverse) + tf_out = prod.eval(feed_dict={p: x}) + + self.assertAllClose(np_out, tf_out) + + def _compareAll(self, x, axis): + for exclusive in [True, False]: + for reverse in [True, False]: + self._compare(x, axis, exclusive, reverse) + + def testEmpty(self): + for dtype in self.valid_dtypes: + x = np.zeros([0]).astype(dtype) + for axis in (-1, 0): + self._compareAll(x, axis) + + def testAxisType(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 6).reshape([5]).astype(dtype) + for axis_dtype in self.axis_dtypes(): + with self.test_session(), self.test_scope(): + p = array_ops.placeholder(x.dtype) + axis = constant_op.constant(0, axis_dtype) + math_ops.cumprod(x, axis).eval(feed_dict={p: x}) + + def test1D(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 6).reshape([5]).astype(dtype) + for axis in (-1, 0): + self._compareAll(x, axis) + + def test2D(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 11).reshape([2, 5]).astype(dtype) + for axis in (-2, -1, 0, 1): + self._compareAll(x, axis) + + def test3D(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 21).reshape([2, 2, 5]).astype(dtype) + for axis in (-3, -2, -1, 0, 1, 2): + self._compareAll(x, axis) + + def test6D(self): + for dtype in self.valid_dtypes: + x = np.arange(1, 145).reshape([2, 2, 3, 3, 2, 2]).astype(dtype) + for axis in range(-6, 6, 3): + self._compareAll(x, axis) + + def testInvalidAxis(self): + x = np.arange(0, 10).reshape([2, 5]).astype(np.float32) + with self.test_session(), self.test_scope(): + input_tensor = 
ops.convert_to_tensor(x) + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + lambda e: "Expected scan axis in the range [-2, 2)" in str(e)): + math_ops.cumprod(input_tensor, -3).eval() + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + lambda e: "Expected scan axis in the range [-2, 2)" in str(e)): + math_ops.cumprod(input_tensor, 2).eval() + with self.assertRaisesWithPredicateMatch( + errors_impl.InvalidArgumentError, + lambda e: "axis must be a scalar" in str(e)): + math_ops.cumprod(input_tensor, [0]).eval() + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/scatter_nd_op_test.py b/tensorflow/compiler/tests/scatter_nd_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..638946e234daf28dc4a34e6c33fc0f78b8e8699b --- /dev/null +++ b/tensorflow/compiler/tests/scatter_nd_op_test.py @@ -0,0 +1,188 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for tensorflow.ops.tf.scatter_nd.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import errors +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +def _AsType(v, vtype): + return v.astype(vtype) if isinstance(v, np.ndarray) else vtype(v) + + +def _FlatInnerDims(tensor, ndims=2): + shape = list(tensor.shape) + return tensor.reshape( + [functools.reduce(lambda x, y: x * y, shape[:-ndims + 1], 1)] + + shape[-ndims + 1:]) + + +def _FlatOuterDims(tensor, ndims=2): + shape = list(tensor.shape) + return tensor.reshape( + shape[:ndims - 1] + + [functools.reduce(lambda x, y: x * y, shape[ndims - 1:], 1)]) + + +def _NumpyScatterNd(ref, indices, updates, op): + ixdim = indices.shape[-1] + num_updates = indices.size // ixdim + total_nd = len(ref.shape) + slice_size = 1 + for i in range(ixdim, total_nd): + slice_size *= ref.shape[i] + flat_indices = _FlatInnerDims(indices) + flat_updates = updates.reshape((num_updates, slice_size)) + output_flat = _FlatOuterDims(ref, ixdim + 1) + for ix_updates, ix_output in enumerate(flat_indices): + ix_output = tuple(ix_output) + output_flat[ix_output] = op(output_flat[ix_output], + flat_updates[ix_updates]) + return output_flat.reshape(ref.shape) + + +def _NumpyUpdate(indices, updates, shape): + ref = np.zeros(shape, dtype=updates.dtype) + return _NumpyScatterNd(ref, indices, updates, lambda p, u: u) + + +class ScatterNdTest(XLATestCase): + + def _VariableRankTest(self, + np_scatter, + tf_scatter, + vtype, + itype, + repeat_indices=False): + np.random.seed(8) + ref_shapes = [(3, 6), (3, 6), (3, 6, 9), (3, 6, 9), (3, 6, 9), (3, 6, 9)] + indices_shapes = [(2,), (2, 2), (2,), (2, 2), (2, 3), (2, 3, 3)] + for 
ref_shape, indices_shape in zip(ref_shapes, indices_shapes): + num_updates = indices_shape[0] + ixdim = indices_shape[-1] + + indexable_area_shape = () + for i in range(ixdim): + indexable_area_shape += (ref_shape[i],) + all_indices = [ + list(coord) + for coord, _ in np.ndenumerate(np.empty(indexable_area_shape, vtype)) + ] + np.random.shuffle(all_indices) + indices = np.array(all_indices[:num_updates]) + + if num_updates > 1 and repeat_indices: + indices = indices[:num_updates // 2] + for _ in range(num_updates - num_updates // 2): + indices = np.append( + indices, [indices[np.random.randint(num_updates // 2)]], axis=0) + np.random.shuffle(indices) + indices = _AsType(indices[:num_updates], itype) + + updates_shape = (num_updates,) + for i in range(ixdim, len(ref_shape)): + updates_shape += (ref_shape[i],) + updates = _AsType(np.random.randn(*(updates_shape)), vtype) + + # Scatter via numpy + np_out = np_scatter(indices, updates, ref_shape) + # Scatter via tensorflow + tf_out = tf_scatter(indices, updates, ref_shape) + + self.assertAllClose(np_out, tf_out) + + def _VariableRankTests(self, np_scatter, tf_scatter): + for vtype in self.numeric_types: + for itype in set([np.int32, np.int64]).intersection(set(self.int_types)): + self._VariableRankTest(np_scatter, tf_scatter, vtype, itype) + + def _runScatterNd(self, indices, updates, shape): + with self.test_session(): + updates_placeholder = array_ops.placeholder(updates.dtype) + indices_placeholder = array_ops.placeholder(indices.dtype) + with self.test_scope(): + output = array_ops.scatter_nd(indices_placeholder, updates_placeholder, + shape) + feed_dict = {updates_placeholder: updates, indices_placeholder: indices} + return output.eval(feed_dict=feed_dict) + + def testSimple(self): + indices = np.array([[4], [3], [1], [7]], dtype=np.int32) + updates = np.array([9, 10, 11, 12], dtype=np.float32) + expected = np.array([0, 11, 0, 10, 9, 0, 0, 12], dtype=np.int32) + self.assertAllEqual(expected, 
self._runScatterNd(indices, updates, [8])) + + def testSimple2(self): + indices = np.array([[1, 0], [1, 1]], dtype=np.int32) + updates = np.array([11., 12.], dtype=np.float32) + expected = np.array([[0., 0.], [11., 12.], [0., 0.]], dtype=np.float32) + self.assertAllEqual(expected, self._runScatterNd(indices, updates, [3, 2])) + + def testSimple3(self): + indices = np.array([[1]], dtype=np.int32) + updates = np.array([[11., 12.]], dtype=np.float32) + expected = np.array([[0., 0.], [11., 12.], [0., 0.]]) + self.assertAllEqual(expected, self._runScatterNd(indices, updates, [3, 2])) + + def testVariableRankUpdate(self): + self._VariableRankTests(_NumpyUpdate, self._runScatterNd) + + def testExtraIndicesDimensions(self): + indices = np.zeros([1, 1, 2], np.int32) + updates = np.zeros([1, 1], np.int32) + expected = np.zeros([2, 2], dtype=np.int32) + self.assertAllEqual(expected, self._runScatterNd(indices, updates, [2, 2])) + + def testRank3InvalidShape1(self): + indices = np.zeros([3, 2, 2], np.int32) + updates = np.zeros([2, 2, 2], np.int32) + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + "Must have updates.shape"): + self._runScatterNd(indices, updates, [2, 2, 2]) + + def testRank3InvalidShape2(self): + indices = np.zeros([2, 2, 1], np.int32) + updates = np.zeros([2, 2], np.int32) + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + "Must have updates.shape"): + self._runScatterNd(indices, updates, [2, 2, 2]) + + def testScatterOutOfRange(self): + updates = np.array([-3, -4, -5]).astype(np.float32) + + # Indices all in range, no problem. + indices = np.array([[2], [0], [5]], dtype=np.int32) + self._runScatterNd(indices, updates, [6]) + + # Indices out of range should not fail. It produces implementation-defined + # output. 
+ indices = np.array([[-1], [0], [5]], dtype=np.int32) + self._runScatterNd(indices, updates, [6]) + indices = np.array([[2], [0], [6]], dtype=np.int32) + self._runScatterNd(indices, updates, [6]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/segment_reduction_ops_test.py b/tensorflow/compiler/tests/segment_reduction_ops_test.py index 260a04421b62310c109d8f0ea72875a50c234bb0..23bc39cf3f7087424719edfb8b6ee35d87295534 100644 --- a/tensorflow/compiler/tests/segment_reduction_ops_test.py +++ b/tensorflow/compiler/tests/segment_reduction_ops_test.py @@ -60,6 +60,14 @@ class SegmentReductionOpsTest(XLATestCase): np.array([0, 1, 2, 3, 4, 5], dtype=dtype), np.array([3, 0, 2, 1, 3, 3], dtype=np.int32), 4)) + def testUnsortedSegmentSum1DIndices1DDataNegativeIndices(self): + for dtype in self.numeric_types: + self.assertAllClose( + np.array([0, 3, 2, 5], dtype=dtype), + self.UnsortedSegmentSum( + np.array([0, 1, 2, 3, 4, 5], dtype=dtype), + np.array([3, -1, 2, 1, -1, 3], dtype=np.int32), 4)) + def testUnsortedSegmentSum1DIndices2DDataDisjoint(self): for dtype in self.numeric_types: data = np.array( diff --git a/tensorflow/compiler/tests/tensor_array_ops_test.py b/tensorflow/compiler/tests/tensor_array_ops_test.py index ac039e01623b954e291760fb9b50ef8eae3da7c1..a62925a1818da00cb0a9e82e1281db20fb38b208 100644 --- a/tensorflow/compiler/tests/tensor_array_ops_test.py +++ b/tensorflow/compiler/tests/tensor_array_ops_test.py @@ -330,8 +330,7 @@ class TensorArrayTest(xla_test.XLATestCase): # Find two different floating point types, create an array of # the first type, but try to read the other type. 
if len(self.float_types) > 1: - dtype1 = self.float_types[0] - dtype2 = self.float_types[1] + dtype1, dtype2 = list(self.float_types)[:2] with self.test_session(), self.test_scope(): ta = tensor_array_ops.TensorArray( dtype=dtype1, tensor_array_name="foo", size=3) diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index a9a3f4f97f649260e9863fff8ff05d046bd91947..3d3e112f4821ea8e57ea9589a5b4433647ad294b 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -33,6 +33,17 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.platform import googletest +def nhwc_to_format(x, data_format): + """Converts a numpy array from NHWC format to `data_format`.""" + rank = len(x.shape) + if data_format == "NCHW": + return np.transpose(x, [0, rank - 1] + list(range(1, rank - 1))) + elif data_format == "NHWC": + return x + else: + raise ValueError("Unknown format {}".format(data_format)) + + class UnaryOpsTest(XLATestCase): """Test cases for unary operators.""" @@ -56,8 +67,10 @@ class UnaryOpsTest(XLATestCase): output = op(pinp) result = session.run(output, {pinp: inp}) if equality_test is None: - equality_test = self.assertAllClose - equality_test(result, expected, rtol=rtol, atol=atol) + self.assertAllCloseAccordingToType( + result, expected, rtol=rtol, atol=atol, bfloat16_rtol=0.03) + else: + equality_test(result, expected, rtol=rtol, atol=atol) def ListsAreClose(self, result, expected, rtol, atol): """Tests closeness of two lists of floats.""" @@ -76,6 +89,12 @@ class UnaryOpsTest(XLATestCase): array_ops.diag_part, np.arange(36).reshape([2, 3, 2, 3]).astype(dtype), np.array([[0, 7, 14], [21, 28, 35]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + array_ops.diag, np.array([[1, 2], [3, 4]], dtype=dtype), + np.array( + [[[[1, 0], [0, 0]], [[0, 2], [0, 0]]], [[[0, 0], [3, 0]], + [[0, 0], [0, 4]]]], + dtype=dtype)) self._assertOpOutputMatchesExpected( 
array_ops.identity, @@ -86,6 +105,21 @@ class UnaryOpsTest(XLATestCase): array_ops.matrix_diag, np.array([[1, 2], [3, 4]], dtype=dtype), np.array([[[1, 0], [0, 2]], [[3, 0], [0, 4]]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + array_ops.matrix_diag, np.array([1, 2, 3, 4], dtype=dtype), + np.array( + [[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]], + dtype=dtype)) + self._assertOpOutputMatchesExpected( + array_ops.matrix_diag, + np.array( + [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], dtype=dtype), + np.array( + [[[[1, 0, 0], [0, 2, 0], [0, 0, 3]], + [[4, 0, 0], [0, 5, 0], [0, 0, 6]]], + [[[7, 0, 0], [0, 8, 0], [0, 0, 9]], + [[10, 0, 0], [0, 11, 0], [0, 0, 12]]]], + dtype=dtype)) self._assertOpOutputMatchesExpected( array_ops.matrix_diag_part, np.arange(3 * 2 * 4).reshape([3, 2, 4]).astype(dtype), @@ -120,6 +154,21 @@ class UnaryOpsTest(XLATestCase): def testFloatOps(self): for dtype in self.float_types: + x = np.arange(-0.90, 0.90, 0.25) + self._assertOpOutputMatchesExpected( + math_ops.acos, + x.astype(dtype), + expected=np.arccos(x).astype(dtype)) + self._assertOpOutputMatchesExpected( + math_ops.asin, + x.astype(dtype), + expected=np.arcsin(x).astype(dtype)) + x = np.arange(-3, 3).reshape(1, 3, 2) + self._assertOpOutputMatchesExpected( + math_ops.atan, + x.astype(dtype), + expected=np.arctan(x).astype(dtype)) + self._assertOpOutputMatchesExpected( math_ops.acosh, np.array([1, 2, 3, 4], dtype=dtype), @@ -331,26 +380,23 @@ class UnaryOpsTest(XLATestCase): def testComplexOps(self): for dtype in self.complex_types: - # TODO(b/65408531): Wider support for log (needs atan2). 
- atan2_supported = self.device == "XLA_GPU" - if atan2_supported: - self._assertOpOutputMatchesExpected( - math_ops.acosh, - np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype), - expected=np.arccosh( - np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype))) + self._assertOpOutputMatchesExpected( + math_ops.acosh, + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype), + expected=np.arccosh( + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype))) - self._assertOpOutputMatchesExpected( - math_ops.asinh, - np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype), - expected=np.arcsinh( - np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype))) + self._assertOpOutputMatchesExpected( + math_ops.asinh, + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype), + expected=np.arcsinh( + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype))) - self._assertOpOutputMatchesExpected( - math_ops.atanh, - np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype), - expected=np.arctanh( - np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype))) + self._assertOpOutputMatchesExpected( + math_ops.atanh, + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype), + expected=np.arctanh( + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype))) self._assertOpOutputMatchesExpected( math_ops.cosh, @@ -377,11 +423,10 @@ class UnaryOpsTest(XLATestCase): np.array([[1, 2j, 2 + 3j]], dtype=dtype), expected=1.0 / np.array([[1, 2j, 2 + 3j]], dtype=dtype)) - if atan2_supported: - self._assertOpOutputMatchesExpected( - math_ops.log, - np.array([[5j, 3 - 2j]], dtype=dtype), - expected=np.log(np.array([[5j, 3 - 2j]], dtype=dtype))) + self._assertOpOutputMatchesExpected( + math_ops.log, + np.array([[5j, 3 - 2j]], dtype=dtype), + expected=np.log(np.array([[5j, 3 - 2j]], dtype=dtype))) self._assertOpOutputMatchesExpected( math_ops.sin, @@ -395,27 +440,26 @@ class UnaryOpsTest(XLATestCase): # TODO(b/34703906): improve log1p implementation and make 
tolerance # tighter. - if atan2_supported: # TODO(b/34703906): log support - self._assertOpOutputMatchesExpected( - math_ops.log1p, - np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype), - expected=np.log1p( - np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype))) + self._assertOpOutputMatchesExpected( + math_ops.log1p, + np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype), + expected=np.log1p( + np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype))) - val = np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype) - self._assertOpOutputMatchesExpected( - math_ops.rsqrt, val, expected=1 / np.sqrt(val)) + val = np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype) + self._assertOpOutputMatchesExpected( + math_ops.rsqrt, val, expected=1 / np.sqrt(val)) - self._assertOpOutputMatchesExpected( - math_ops.sigmoid, val, expected=1 / (1 + np.exp(-val))) + self._assertOpOutputMatchesExpected( + math_ops.sigmoid, val, expected=1 / (1 + np.exp(-val))) - self._assertOpOutputMatchesExpected( - math_ops.sqrt, val, expected=np.sqrt(val)) + self._assertOpOutputMatchesExpected( + math_ops.sqrt, val, expected=np.sqrt(val)) - self._assertOpOutputMatchesExpected( - math_ops.tanh, - np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype), - expected=np.tanh(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype))) + self._assertOpOutputMatchesExpected( + math_ops.tanh, + np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype), + expected=np.tanh(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype))) self._assertOpOutputMatchesExpected( math_ops.tan, @@ -448,12 +492,10 @@ class UnaryOpsTest(XLATestCase): np.array([[-4j, 3 + 2j], [2, -1j]], dtype=dtype), expected=np.array([[1, 1], [1, 1]], dtype=dtype)) - if atan2_supported: # TODO(b/34703906): atan2 support - self._assertOpOutputMatchesExpected( - math_ops.angle, - np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype), - expected=np.angle( - np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype))) + self._assertOpOutputMatchesExpected( + math_ops.angle, + np.array([1 + 3j, -4 + 7j, 2.7, 
-3j], dtype=dtype), + expected=np.angle(np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype))) self._assertOpOutputMatchesExpected( math_ops.conj, @@ -541,7 +583,8 @@ class UnaryOpsTest(XLATestCase): def testCast(self): shapes = [[], [4], [2, 3], [2, 0, 4]] - types = [dtypes.bool, dtypes.int32, dtypes.float32] + self.complex_tf_types + types = (set([dtypes.bool, dtypes.int32, dtypes.float32]) | + self.complex_tf_types) for shape in shapes: for src_type in types: for dst_type in types: @@ -641,55 +684,88 @@ class UnaryOpsTest(XLATestCase): equality_test=self.ListsAreClose) def testDepthToSpace(self): + def make_op(data_format): + def op(x): + return array_ops.depth_to_space(x, block_size=2, + data_format=data_format) + return op + for dtype in self.numeric_types: - self._assertOpOutputMatchesExpected( - lambda x: array_ops.depth_to_space(x, block_size=2), - np.array([[[[1, 2, 3, 4]]]], dtype=dtype), - expected=np.array([[[[1], [2]], - [[3], [4]]]], dtype=dtype)) + for data_format in ["NCHW", "NHWC"]: + self._assertOpOutputMatchesExpected( + make_op(data_format), + nhwc_to_format(np.array([[[[1, 2, 3, 4]]]], dtype=dtype), + data_format), + expected=nhwc_to_format(np.array([[[[1], [2]], + [[3], [4]]]], dtype=dtype), + data_format)) - self._assertOpOutputMatchesExpected( - lambda x: array_ops.depth_to_space(x, block_size=2), - np.array([[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]], dtype=dtype), - expected=np.array([[[[1, 2, 3], [4, 5, 6]], - [[7, 8, 9], [10, 11, 12]]]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + make_op(data_format), + nhwc_to_format( + np.array([[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]], + dtype=dtype), + data_format), + expected=nhwc_to_format( + np.array([[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]], + dtype=dtype), + data_format)) - self._assertOpOutputMatchesExpected( - lambda x: array_ops.depth_to_space(x, block_size=2), - np.array([[[[1, 2, 3, 4], - [5, 6, 7, 8]], - [[9, 10, 11, 12], - [13, 14, 15, 16]]]], dtype=dtype), - 
expected=np.array([[[[1], [2], [5], [6]], - [[3], [4], [7], [8]], - [[9], [10], [13], [14]], - [[11], [12], [15], [16]]]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + make_op(data_format), + nhwc_to_format( + np.array([[[[1, 2, 3, 4], + [5, 6, 7, 8]], + [[9, 10, 11, 12], + [13, 14, 15, 16]]]], dtype=dtype), + data_format), + expected=nhwc_to_format( + np.array([[[[1], [2], [5], [6]], + [[3], [4], [7], [8]], + [[9], [10], [13], [14]], + [[11], [12], [15], [16]]]], dtype=dtype), + data_format)) def testSpaceToDepth(self): + def make_op(data_format): + def op(x): + return array_ops.space_to_depth(x, block_size=2, + data_format=data_format) + return op + for dtype in self.numeric_types: - self._assertOpOutputMatchesExpected( - lambda x: array_ops.space_to_depth(x, block_size=2), - np.array([[[[1], [2]], - [[3], [4]]]], dtype=dtype), - expected=np.array([[[[1, 2, 3, 4]]]], dtype=dtype)) + for data_format in ["NCHW", "NHWC"]: + self._assertOpOutputMatchesExpected( + make_op(data_format), + nhwc_to_format(np.array([[[[1], [2]], + [[3], [4]]]], dtype=dtype), + data_format), + expected=nhwc_to_format(np.array([[[[1, 2, 3, 4]]]], dtype=dtype), + data_format)) - self._assertOpOutputMatchesExpected( - lambda x: array_ops.space_to_depth(x, block_size=2), - np.array([[[[1, 2, 3], [4, 5, 6]], - [[7, 8, 9], [10, 11, 12]]]], dtype=dtype), - expected=np.array([[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]], - dtype=dtype)) + self._assertOpOutputMatchesExpected( + make_op(data_format), + nhwc_to_format(np.array([[[[1, 2, 3], [4, 5, 6]], + [[7, 8, 9], [10, 11, 12]]]], dtype=dtype), + data_format), + expected=nhwc_to_format( + np.array([[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]], + dtype=dtype), + data_format)) - self._assertOpOutputMatchesExpected( - lambda x: array_ops.space_to_depth(x, block_size=2), - np.array([[[[1], [2], [5], [6]], - [[3], [4], [7], [8]], - [[9], [10], [13], [14]], - [[11], [12], [15], [16]]]], dtype=dtype), - expected=np.array([[[[1, 2, 3, 4], - [5, 
6, 7, 8]], - [[9, 10, 11, 12], - [13, 14, 15, 16]]]], dtype=dtype)) + self._assertOpOutputMatchesExpected( + make_op(data_format), + nhwc_to_format(np.array([[[[1], [2], [5], [6]], + [[3], [4], [7], [8]], + [[9], [10], [13], [14]], + [[11], [12], [15], [16]]]], dtype=dtype), + data_format), + expected=nhwc_to_format( + np.array([[[[1, 2, 3, 4], + [5, 6, 7, 8]], + [[9, 10, 11, 12], + [13, 14, 15, 16]]]], dtype=dtype), + data_format)) def _assertSoftplusMatchesExpected(self, features, dtype): features = np.array(features, dtype=dtype) diff --git a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py index c50342dee45eba6ae54f01653ecc81ef096b547b..b08d6ab21e0746558cb3d4818d4c822c45d2e9ee 100644 --- a/tensorflow/compiler/tests/variable_ops_test.py +++ b/tensorflow/compiler/tests/variable_ops_test.py @@ -107,11 +107,26 @@ class VariableOpsTest(XLATestCase): [[[30, 31, 32], [33, 34, 35]], [[0, 1, 2], [3, 4, 5]]]], ).astype(dtype), sess.run(x)) + def testShape(self): + for dtype in self.numeric_types: + init = np.ones([2, 3]).astype(dtype) + with self.test_session() as session, self.test_scope(): + v = resource_variable_ops.ResourceVariable(init) + session.run(variables.variables_initializer([v])) + h = v.handle + s32, s64 = session.run([ + resource_variable_ops.variable_shape(h), + resource_variable_ops.variable_shape(h, out_type=dtypes.int64) + ]) + self.assertEqual(s32.dtype, np.int32) + self.assertEqual(s64.dtype, np.int64) + self.assertAllEqual(s32, [2, 3]) + self.assertAllEqual(s64, [2, 3]) + def testReadWrite(self): """Tests initialization, reading, and writing a resource variable.""" for dtype in self.numeric_types: with self.test_session() as session: - print(ops.get_default_graph()) with self.test_scope(): with variable_scope.variable_scope("ascope", use_resource=True): x = variable_scope.get_variable( diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py index 
0be127997e5211f810ca791187486760881fe172..7e1f5c76ed65946363cc3c113ab1a9862f87b289 100644 --- a/tensorflow/compiler/tests/xla_test.py +++ b/tensorflow/compiler/tests/xla_test.py @@ -53,41 +53,100 @@ class XLATestCase(test.TestCase): super(XLATestCase, self).__init__(method_name) self.device = FLAGS.test_device self.has_custom_call = (self.device == 'XLA_CPU') - self.all_tf_types = [ + self._all_tf_types = set([ dtypes.as_dtype(types_pb2.DataType.Value(name)) for name in FLAGS.types.split(',') - ] - self.int_tf_types = [ - dtype for dtype in self.all_tf_types if dtype.is_integer - ] - self.float_tf_types = [ - dtype for dtype in self.all_tf_types if dtype.is_floating - ] - self.complex_tf_types = [ - dtype for dtype in self.all_tf_types if dtype.is_complex - ] - self.numeric_tf_types = ( - self.int_tf_types + self.float_tf_types + self.complex_tf_types) - - self.all_types = [dtype.as_numpy_dtype for dtype in self.all_tf_types] - self.int_types = [dtype.as_numpy_dtype for dtype in self.int_tf_types] - self.float_types = [dtype.as_numpy_dtype for dtype in self.float_tf_types] - self.complex_types = [ + ]) + self.int_tf_types = set([ + dtype for dtype in self._all_tf_types if dtype.is_integer + ]) + self._float_tf_types = set([ + dtype for dtype in self._all_tf_types if dtype.is_floating + ]) + self.complex_tf_types = set([ + dtype for dtype in self._all_tf_types if dtype.is_complex + ]) + self._numeric_tf_types = set( + self.int_tf_types | self._float_tf_types | self.complex_tf_types) + + self._all_types = set( + [dtype.as_numpy_dtype for dtype in self._all_tf_types]) + self.int_types = set([dtype.as_numpy_dtype for dtype in self.int_tf_types]) + self._float_types = set( + [dtype.as_numpy_dtype for dtype in self._float_tf_types]) + self.complex_types = set([ dtype.as_numpy_dtype for dtype in self.complex_tf_types - ] - self.numeric_types = self.int_types + self.float_types + self.complex_types + ]) + self._numeric_types = set( + self.int_types | self._float_types | 
self.complex_types) # Parse the manifest file, if any, into a regex identifying tests to # disable self.disabled_regex = None + self._method_types_filter = dict() + # TODO(xpan): Make it text proto if it doesn't scale. + # Each line of the manifest file specifies an entry. The entry can be + # 1) TestNameRegex // E.g. CumprodTest.* Or + # 2) TestName TypeName // E.g. AdamOptimizerTest.testSharing DT_BFLOAT16 + # The 1) disables the entire test. While 2) only filter some numeric types + # so that they are not used in those tests. + if FLAGS.disabled_manifest is not None: comments_re = re.compile('#.*$') manifest_file = open(FLAGS.disabled_manifest, 'r') - lines = manifest_file.read().splitlines() - lines = [comments_re.sub('', l).strip() for l in lines] - self.disabled_regex = re.compile('|'.join(lines)) + disabled_tests = [] + disabled_method_types = [] + for l in manifest_file.read().splitlines(): + entry = comments_re.sub('', l).strip().split(' ') + if len(entry) == 1: + disabled_tests.append(entry[0]) + elif len(entry) == 2: + disabled_method_types.append( + (entry[0], entry[1].strip().split(','))) + else: + raise ValueError('Bad entry in manifest file.') + + self.disabled_regex = re.compile('|'.join(disabled_tests)) + for method, types in disabled_method_types: + self._method_types_filter[method] = set([ + dtypes.as_dtype(types_pb2.DataType.Value(name)).as_numpy_dtype + for name in types]) manifest_file.close() + @property + def all_tf_types(self): + name = '{}.{}'.format(type(self).__name__, self._testMethodName) + tf_types = set([dtypes.as_dtype(t) + for t in self._method_types_filter.get(name, set())]) + return self._all_tf_types - tf_types + + @property + def float_types(self): + name = '{}.{}'.format(type(self).__name__, self._testMethodName) + return self._float_types - self._method_types_filter.get(name, set()) + + @property + def float_tf_types(self): + name = '{}.{}'.format(type(self).__name__, self._testMethodName) + return self._float_tf_types - 
self._method_types_filter.get(name, set()) + + @property + def numeric_tf_types(self): + name = '{}.{}'.format(type(self).__name__, self._testMethodName) + tf_types = set([dtypes.as_dtype(t) + for t in self._method_types_filter.get(name, set())]) + return self._numeric_tf_types - tf_types + + @property + def numeric_types(self): + name = '{}.{}'.format(type(self).__name__, self._testMethodName) + return self._numeric_types - self._method_types_filter.get(name, set()) + + @property + def all_types(self): + name = '{}.{}'.format(type(self).__name__, self._testMethodName) + return self._all_types - self._method_types_filter.get(name, set()) + def setUp(self): super(XLATestCase, self).setUp() name = '{}.{}'.format(type(self).__name__, self._testMethodName) diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 5a81438b1c48e7f0ef66dae072092974db24c621..3c7dfef03dfb5d86dd63fd4aa84ad56081833035 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -1,6 +1,6 @@ licenses(["notice"]) # Apache 2.0 -load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") package_group( name = "internal", @@ -25,6 +25,30 @@ package( load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") +cc_library( + name = "tf2xla_supported_ops_lib", + srcs = ["tf2xla_supported_ops.cc"], + hdrs = ["tf2xla_supported_ops.h"], + visibility = ["//visibility:public"], + deps = [ + ":xla_compiler", + "//tensorflow/compiler/tf2xla/kernels:xla_cpu_only_ops", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:ops", + "//tensorflow/core:protos_all_cc", + ], +) + +tf_cc_binary( + name = "tf2xla_supported_ops", + srcs = ["tf2xla_supported_ops_main.cc"], + visibility = ["//visibility:public"], + 
deps = [":tf2xla_supported_ops_lib"], +) + xla_proto_library( name = "tf2xla_proto", srcs = ["tf2xla.proto"], @@ -67,7 +91,6 @@ cc_library( # Keep dependencies to a minimum here; this library is used in every AOT # binary produced by tfcompile. "//tensorflow/compiler/aot:runtime", - "//tensorflow/compiler/tf2xla:xla_local_runtime_context", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/core:framework_lite", ], @@ -97,18 +120,21 @@ cc_library( cc_library( name = "xla_compiler", srcs = [ + "const_analysis.cc", + "graph_compiler.cc", "xla_compilation_device.cc", "xla_compiler.cc", "xla_context.cc", "xla_helpers.cc", "xla_op_kernel.cc", "xla_op_registry.cc", - "graph_compiler.cc", + "xla_resource.cc", "xla_cpu_backend.cc", ] + if_cuda_is_configured([ "xla_gpu_backend.cc", ]), hdrs = [ + "const_analysis.h", "graph_compiler.h", "xla_compilation_device.h", "xla_compiler.h", @@ -116,11 +142,11 @@ cc_library( "xla_helpers.h", "xla_op_kernel.h", "xla_op_registry.h", + "xla_resource.h", ], visibility = [":friends"], deps = [ ":common", - ":const_analysis", ":dump_graph", ":functionalize_control_flow", ":sharding_util", @@ -180,6 +206,7 @@ cc_library( deps = [ "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:sharding_builder", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -215,6 +242,7 @@ cc_library( "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", + "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", ], @@ -328,28 +356,16 @@ tf_cc_test( ], ) -cc_library( - name = "const_analysis", - srcs = ["const_analysis.cc"], - hdrs = ["const_analysis.h"], - deps = [ - "//tensorflow/core:core_cpu", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", - 
], -) - tf_cc_test( name = "const_analysis_test", size = "small", srcs = ["const_analysis_test.cc"], deps = [ - ":const_analysis", + ":xla_compiler", "//tensorflow/cc:cc_ops", "//tensorflow/cc:function_ops", "//tensorflow/cc:ops", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:ops", "//tensorflow/core:test", @@ -357,13 +373,6 @@ tf_cc_test( ], ) -cc_library( - name = "xla_local_runtime_context", - hdrs = ["xla_local_runtime_context.h"], - visibility = ["//visibility:public"], - deps = ["//tensorflow/core:framework_lite"], -) - cc_library( name = "dump_graph", srcs = [ @@ -400,6 +409,7 @@ cc_library( "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", + "//tensorflow/core:graph", "//tensorflow/core:lib", ], ) diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index d57273d84442c17565a6ace1c29170a0f3ba583b..82923722c54d235716b9138d95a75a441df924ca 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/graph/algorithm.h" @@ -27,93 +28,18 @@ namespace tensorflow { // compile-time constants. Status BackwardsConstAnalysis(const Graph& g, std::vector* compile_time_const_args) { - // TODO(phawkins): annotate these on the kernel registrations, rather than - // using a hard-coded list. - // (operator, argument) pairs that must be compile-time constants. 
- const std::unordered_multimap compile_time_const_inputs = { - {"All", "reduction_indices"}, - {"Any", "reduction_indices"}, - {"ArgMin", "dimension"}, - {"ArgMax", "dimension"}, - {"AvgPoolGrad", "orig_input_shape"}, - {"AvgPool3DGrad", "orig_input_shape"}, - {"BatchToSpace", "crops"}, - {"BatchToSpaceND", "block_shape"}, - {"BatchToSpaceND", "crops"}, - {"BroadcastArgs", "s0"}, - {"BroadcastArgs", "s1"}, - {"BroadcastGradientArgs", "s0"}, - {"BroadcastGradientArgs", "s1"}, - {"Concat", "concat_dim"}, - {"ConcatV2", "axis"}, - {"ConcatOffset", "concat_dim"}, - {"ConcatOffset", "shape"}, - {"Conv2DBackpropFilter", "filter_sizes"}, - {"Conv2DBackpropInput", "input_sizes"}, - {"Conv3DBackpropFilterV2", "filter_sizes"}, - {"Conv3DBackpropInputV2", "input_sizes"}, - {"DepthwiseConv2dNativeBackpropFilter", "filter_sizes"}, - {"DepthwiseConv2dNativeBackpropInput", "input_sizes"}, - {"DynamicStitch", "indices"}, - {"ExpandDims", "dim"}, - {"Fill", "dims"}, - {"GatherV2", "axis"}, - {"InvertPermutation", "x"}, - {"LinSpace", "start"}, - {"LinSpace", "stop"}, - {"LinSpace", "num"}, - {"Max", "reduction_indices"}, - {"Mean", "reduction_indices"}, - {"Min", "reduction_indices"}, - {"OneHot", "depth"}, - {"Pad", "paddings"}, - {"PadV2", "paddings"}, - {"MirrorPad", "paddings"}, - {"Multinomial", "num_samples"}, - {"Prod", "reduction_indices"}, - {"RandomStandardNormal", "shape"}, - {"RandomUniform", "shape"}, - {"RandomUniformInt", "shape"}, - {"Range", "start"}, - {"Range", "limit"}, - {"Range", "delta"}, - {"Reshape", "shape"}, - {"ResourceStridedSliceAssign", "begin"}, - {"ResourceStridedSliceAssign", "end"}, - {"ResourceStridedSliceAssign", "strides"}, - {"Reverse", "dims"}, - {"ReverseV2", "axis"}, - {"Slice", "begin"}, - {"Slice", "size"}, - {"SpaceToBatch", "paddings"}, - {"SpaceToBatchND", "block_shape"}, - {"SpaceToBatchND", "paddings"}, - {"Split", "split_dim"}, - {"SplitV", "split_dim"}, - {"SplitV", "size_splits"}, - {"StackV2", "max_size"}, - {"StridedSlice", 
"begin"}, - {"StridedSlice", "end"}, - {"StridedSlice", "strides"}, - {"StridedSliceGrad", "shape"}, - {"StridedSliceGrad", "begin"}, - {"StridedSliceGrad", "end"}, - {"StridedSliceGrad", "strides"}, - {"Sum", "reduction_indices"}, - {"TensorArrayV3", "size"}, - {"TensorArraySplitV3", "lengths"}, - {"Tile", "multiples"}, - {"Transpose", "perm"}}; - // Operators that don't look at the data of their inputs, just the shapes. const std::unordered_set metadata_ops = { - "Rank", "Shape", "ShapeN", "Size", + "Rank", + "Shape", + "ShapeN", + "Size", }; Status status; std::unordered_set must_be_const; - auto visit = [&status, &metadata_ops, &compile_time_const_inputs, - &must_be_const, compile_time_const_args](Node* node) { + auto visit = [&status, &metadata_ops, &must_be_const, + compile_time_const_args](Node* node) { if (!status.ok()) return; // If this is a metadata-only op, don't propagate the const requirement. @@ -136,16 +62,17 @@ Status BackwardsConstAnalysis(const Graph& g, } // Mark any compile-time constant operator arguments as const. 
- auto range = compile_time_const_inputs.equal_range(node->type_string()); - if (range.first == range.second) return; + const std::unordered_set* const_inputs = + XlaOpRegistry::CompileTimeConstantInputs(node->type_string()); + if (!const_inputs || const_inputs->empty()) return; NameRangeMap input_name_ranges; status = NameRangesForNode(*node, node->op_def(), &input_name_ranges, nullptr); if (!status.ok()) return; - for (auto it = range.first; it != range.second; ++it) { - auto name_range = input_name_ranges.find(it->second); + for (const string& input : *const_inputs) { + auto name_range = input_name_ranges.find(input); if (name_range == input_name_ranges.end()) continue; for (Edge const* edge : node->in_edges()) { diff --git a/tensorflow/compiler/tf2xla/dump_graph.cc b/tensorflow/compiler/tf2xla/dump_graph.cc index ddd912b87315f7943915153b5bf73531107af54d..03603ee9baefd1d20d220faf63c9c1c427ebdf31 100644 --- a/tensorflow/compiler/tf2xla/dump_graph.cc +++ b/tensorflow/compiler/tf2xla/dump_graph.cc @@ -63,7 +63,12 @@ string MakeUniquePath(string name) { string DumpGraphDefToFile(const string& name, GraphDef const& graph_def) { string path = MakeUniquePath(name); - TF_CHECK_OK(WriteTextProto(Env::Default(), path, graph_def)); + Status status = WriteTextProto(Env::Default(), path, graph_def); + if (!status.ok()) { + VLOG(1) << "Failed to dump GraphDef to file: " << path << " : " << status; + path.clear(); + path = "(unavailable)"; + } return path; } @@ -79,7 +84,13 @@ string DumpGraphToFile(const string& name, Graph const& graph, string DumpFunctionDefToFile(const string& name, FunctionDef const& fdef) { string path = MakeUniquePath(name); - TF_CHECK_OK(WriteTextProto(Env::Default(), path, fdef)); + Status status = WriteTextProto(Env::Default(), path, fdef); + if (!status.ok()) { + VLOG(1) << "Failed to dump FunctionDef to file: " << path << " : " + << status; + path.clear(); + path = "(unavailable)"; + } return path; } diff --git 
a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 5726d8294a7c7fe81d7f6b803af89ca305aa2deb..f8169795ddfb7fd4e93d3f136c51623385868951 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/control_flow.h" #include "tensorflow/core/lib/gtl/optional.h" @@ -36,6 +37,8 @@ namespace tensorflow { namespace { +using xla::StatusOr; + const char* const kArgOp = "_Arg"; const char* const kRetValOp = "_Retval"; @@ -75,6 +78,20 @@ struct Frame { std::unordered_set nodes; }; +// Comparison function used for sorting nodes consistently. +// a) resource variables are last, and +// b) sort lexicographically by name (for deterministic output). +struct NodeCmp { + bool operator()(const Node* lhs, const Node* rhs) const { + bool lhs_is_resource = + lhs->num_inputs() > 0 ? (lhs->input_type(0) == DT_RESOURCE) : false; + bool rhs_is_resource = + rhs->num_inputs() > 0 ? (rhs->input_type(0) == DT_RESOURCE) : false; + return std::tie(lhs_is_resource, lhs->name()) < + std::tie(rhs_is_resource, rhs->name()); + } +}; + // Returns a textual representation of the names of the nodes in the input. 
template string NodesToString(const T& nodes) { @@ -140,7 +157,7 @@ Status CopySubgraph(const Graph& graph, const Frame* frame, return Status::OK(); } -xla::StatusOr AddNode(const NodeDef& node_def, Graph* graph) { +StatusOr AddNode(const NodeDef& node_def, Graph* graph) { Status status; Node* inserted_node = graph->AddNode(node_def, &status); if (!status.ok()) { @@ -149,7 +166,7 @@ xla::StatusOr AddNode(const NodeDef& node_def, Graph* graph) { return inserted_node; } -xla::StatusOr BuildArgNode(Graph* graph, DataType type, int index) { +StatusOr BuildArgNode(Graph* graph, DataType type, int index) { NodeDef arg_def; NodeDefBuilder builder(strings::StrCat(kArgOp, index), kArgOp); builder.Attr("T", type); @@ -158,7 +175,7 @@ xla::StatusOr BuildArgNode(Graph* graph, DataType type, int index) { return AddNode(arg_def, graph); } -xla::StatusOr BuildRetvalNode(Graph* graph, DataType type, int index) { +StatusOr BuildRetvalNode(Graph* graph, DataType type, int index) { NodeDef ret_def; ret_def.set_op(kRetValOp); ret_def.set_name(strings::StrCat(kRetValOp, index)); @@ -268,7 +285,8 @@ Status BuildLoopBody(const Graph& graph, Frame* frame, Status FunctionalizeLoop(Graph* graph, Frame* frame, FunctionLibraryDefinition* library) { VLOG(2) << "Frame " << frame->name << " before: " - << dump_graph::DumpGraphToFile("functionalize_before", *graph); + << dump_graph::DumpGraphToFile("functionalize_before", *graph, + library); // Split loop-varying Enter nodes with multiple successors. If the same // Tensor is fed as input to multiple loop arguments, we may end up with a @@ -309,16 +327,9 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame, } frame->args = std::move(args); - // Order the arguments so that: - // a) resource variables are last, and - // b) sort lexicographically by name (for deterministic output). 
- std::sort(frame->args.begin(), frame->args.end(), - [](const Arg& a, const Arg& b) { - bool a_is_resource = (a.enter->input_type(0) == DT_RESOURCE); - bool b_is_resource = (b.enter->input_type(0) == DT_RESOURCE); - return std::tie(a_is_resource, a.enter->name()) < - std::tie(b_is_resource, b.enter->name()); - }); + std::sort( + frame->args.begin(), frame->args.end(), + [](const Arg& a, const Arg& b) { return NodeCmp()(a.enter, b.enter); }); if (frame->loop_cond == nullptr) { return errors::InvalidArgument("Loop ", frame->name, @@ -417,16 +428,36 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame, // identity nodes are values used by the loop body or condition. // The Identity node may have the wrong device so copy the device from // one of its outputs instead. + std::deque possible_exit; for (const Edge* edge : arg.switch_node->out_edges()) { - if (edge->src_output() == 0 && IsExit(edge->dst())) { + if (edge->src_output() == 0) { + possible_exit.push_back(edge); + } + if (IsIdentity(edge->dst())) { + TF_RETURN_IF_ERROR( + SetNodeShardingFromNeighbors(edge->dst(), /*out_edges=*/true)); + } + } + // TODO(b/67425339): Allow general graph between switch and exit. 
+ while (!possible_exit.empty()) { + const Edge* edge = possible_exit.front(); + possible_exit.pop_front(); + if (IsExit(edge->dst())) { if (arg.exit != nullptr) { return errors::InvalidArgument("Duplicate Exit successors to ", arg.switch_node->name()); } arg.exit = edge->dst(); - } else if (StringPiece(edge->dst()->type_string()) == "Identity") { - TF_RETURN_IF_ERROR( - SetNodeShardingFromNeighbors(edge->dst(), /*out_edges=*/true)); + } else { + if (!IsIdentity(edge->dst())) { + return errors::Unimplemented("General graph between switch (", + arg.switch_node->name(), + ") and exit node of frame ", + frame->name, " not supported yet."); + } + for (const Edge* out : edge->dst()->out_edges()) { + possible_exit.push_back(out); + } } } } @@ -440,7 +471,7 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame, TF_RETURN_IF_ERROR(BuildLoopBody(*graph, frame, &arg_types, &body_graph)); VLOG(2) << "Frame " << frame->name << " condition: " - << dump_graph::DumpGraphToFile("loop_condition", *cond_graph) + << dump_graph::DumpGraphToFile("loop_condition", *cond_graph, library) << " body: " << dump_graph::DumpGraphToFile("loop_body", *body_graph); static std::atomic sequence_num(0LL); @@ -521,266 +552,141 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame, frame->parent->nodes.insert(while_node); VLOG(2) << "Frame " << frame->name << " after: " - << dump_graph::DumpGraphToFile("functionalize_after", *graph); + << dump_graph::DumpGraphToFile("functionalize_after", *graph, + library); return Status::OK(); } class FunctionalizeCond { public: - // Identifies the connected parts of the tf.Cond. - struct ClusterHandle { - explicit ClusterHandle(int representative = -1) - : representative(representative) {} + // All nodes are assumed to be either in no branch, then branch, else branch, + // or both branches (such as merge nodes). 
+ enum Branch { + kElseBranch = 0, + kThenBranch = 1, + kBoth = 2, + kNeither = 3, + kNumBranchTypes = 4 + }; - bool operator==(const ClusterHandle& other) const { - return representative == other.representative; - } + // Returns a textual representation of the Branch b. + static string Branch_Name(FunctionalizeCond::Branch b); - bool operator!=(const ClusterHandle& other) const { - return !(*this == other); - } + // Functionalize all the switch-merge nodes of a loop-free graph into XlaIf + // nodes. That is, attempt to transform every remaining switch and merge nodes + // in the graph into XlaIf nodes. + // Precondition: All while loops have been removed from graph. + static Status Functionalize(Graph* graph, FunctionLibraryDefinition* library); - bool operator<(const ClusterHandle& other) const { - return representative < other.representative; + private: + // CondArgNode represents a input to the conditional and its corresponding + // switch nodes. + struct CondArgNode { + explicit CondArgNode(Node* input) : input(input) {} + string ToString() const { + return strings::StrCat("input=", input->name(), + " switches=", NodesToString(switches)); } - bool operator>(const ClusterHandle& other) const { - return representative > other.representative; - } + Node* input; + std::vector switches; + }; + using CondArgNodes = std::vector; + struct ForwardFlowNode { + explicit ForwardFlowNode(Branch branch = Branch::kNeither) + : branch(branch), count(0) {} string ToString() const { - return strings::StrCat("Cluster_", representative); + return strings::StrCat("branch=", Branch_Name(branch), " count=", count); } - - // Vector of UnionFind indexable by ClusterHandle and Node*. 
- struct Vector { - explicit Vector(size_t size) : clusters(size) {} - - UnionFind& at(const ClusterHandle& cluster) { - return clusters.at(cluster.representative); - } - - UnionFind& at(const Node* node) { - return clusters.at(node->id()); - } - - UnionFind& operator[](const Node* node) { - return clusters.at(node->id()); - } - - size_t size() const { return clusters.size(); } - - void resize(size_t count) { return clusters.resize(count); } - - private: - std::vector> clusters; - }; - - private: - int representative; + Branch branch; + int count; }; - // Represents a node in the clustered graph consisting of switch_nodes, - // merge_nodes as well as the edges into and out of this node to other - // Clusters. Each Cluster corresponds to a ClusterHandle and has a - // corresponding representative. - struct Cluster { - std::unordered_set switch_nodes; - std::unordered_set merge_nodes; - std::unordered_set in_nodes; - std::unordered_set out_nodes; - - // A member of the ClusterHandle corresponding to this Cluster. - ClusterHandle representative; - bool visited = false; - }; + // Group of switch nodes that will be part of the same XlaIf. + struct SwitchCluster { + explicit SwitchCluster(Node* predicate) : predicate(predicate) {} + string ToString() const { + return strings::StrCat(name, " predicate=", predicate->name(), + " switches=", NodesToString(switches)); + } - // Represent the clustered graph as map from cluster representative to - // Cluster. - using ClusteredGraph = std::map; - - // The arguments and condition of a XlaIf. The arguments are ordered by node - // id in the original graph. - struct CondArgs { - struct CondCmp { - bool operator()(const Node* lhs, const Node* rhs) const { - bool lhs_is_resource = - lhs->num_inputs() > 0 ? (lhs->input_type(0) == DT_RESOURCE) : false; - bool rhs_is_resource = - rhs->num_inputs() > 0 ? 
(rhs->input_type(0) == DT_RESOURCE) : false; - return std::tie(lhs_is_resource, lhs->name()) < - std::tie(rhs_is_resource, rhs->name()); - } - }; - Node* conditional = nullptr; - std::set args; + string name; + Node* predicate; + std::vector switches; }; - static Status Functionalize(Graph* graph, FunctionLibraryDefinition* library); - - private: - FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library) - : clusters_(graph->num_node_ids()), library_(library), graph_(graph) {} - - // Returns a vector of Switch nodes from the clustered graph where the nodes - // are sorted by the number of switch nodes minus number of merge nodes - // from a root of the clustered graph to the given Merge node, with ties - // broken by the representative of the Cluster. This corresponds to sorting by - // nesting depth, from deepest nested to outermost. - std::vector> SortedSwitchNodes(); - - // Returns whether the graph has no conditionals. - bool NoConditionals() const { return merge_nodes_.empty(); } - - // Construct the clustered graph by creating nodes for each cluster and the - // connections between the clusters. Switch and Merge nodes partition - // clusters, so iterate over those. Note: a Cluster may have neither a - // Merge or Switch but will have an in/out edge from a Cluster that has. - void CreateClusters(); - - // Creates the clustered graph by identifying all the edges between different - // clusters and collecting all switch and merge nodes that correspond to a - // cluster. - void CreateClusteredGraph(); - - // If `from` and `to` correspond to different clusters, then merge the nodes - // in the clustered graph corresponding to `from` and `to`. - // - // If `remove_from_graph` is specified then the `from` node is also removed - // from the clustered graph post contracting the edge. - void ContractEdge(Cluster* from, Cluster* to, bool remove_from_graph = false); - - // Converts a Merge node to a XlaIf. 
This encapsulates the process of - // extracting the bodies needed for the then and else branch, creates a XlaIf - // node, removing the nodes of the branches from the graph and replacing the - // merge node with a XlaIf. - Status ConvertCorrespondingMergeToXlaIf(Cluster* switch_cluster); - - // Removes a Switch cluster feeding directly into a Merge cluster by removing - // the Switch and Merge nodes and collapsing into a single cluster. - Status RemoveTrivialSwitch(Cluster* switch_cluster); - - // Returns the merge cluster corresponding to the switch node. This function - // only returns the merge cluster in the case where we have a switch node that - // is the single entry point for all paths to a common merge cluster, this - // merge cluster may be created by combining multiple merge clusters, that - // share the switch cluster as common ancestor, together. - // - // Switch - // / \ - // Branch Branch - // \ / - // merge_cluster - // - // Note: either of the branches may be empty. The case where both branches are - // empty is handled by RemoveTrivialSwitch. - gtl::optional CreateCorrespondingMergeCluster( - Cluster* switch_cluster); - - // Determines the arguments needed as input to the Merge cluster originating - // from the Switch cluster. - xla::StatusOr DetermineCondArgs(const Cluster& merge_cluster, - const Cluster& switch_cluster); - - // Builds a XlaIfOp to replace the Merge node with. - xla::StatusOr BuildAndAddXlaIfOp(const CondArgs& cond_args, - const Cluster& merge_cluster, - const std::vector& outputs); + FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library, + bool dump_graphs) + : library_(library), graph_(graph), dump_graphs_(dump_graphs) {} + + // Perform the actual cond functionalization. Iterate over groups of switch + // nodes (linked by common predicate), from innermost to outermost, and + // extract into XlaIf nodes. 
+ Status FunctionalizeInternal(); + + // Determines the branch_map (mapping from node to branch of cond) and + // frontier (the nodes where the cond ends). + StatusOr, + std::unordered_set>> + DetermineBranchMapAndFrontier(const SwitchCluster& switch_cluster); + + // Returns XlaIf node created from subgraph of merge and switch nodes. This + // encapsulates the process of extracting the bodies needed for the then and + // else branch, creates a XlaIf node, removing the nodes of the branches from + // the graph and replacing the merge node with a XlaIf. + StatusOr ConvertToXlaIf(const CondArgNodes& cond_arg_nodes, + const SwitchCluster& switch_cluster, + const std::vector& switches); + + // Builds a XlaIfOp to replace the Switch-Graph-Merge cluster with. + StatusOr BuildAndAddXlaIfOp(const CondArgNodes& cond_arg_nodes, + const SwitchCluster& switch_cluster, + const std::vector& merge_nodes); // Extracts a function body corresponding to the given input edge of the merge // node. - Status ExtractBody(const CondArgs& cond_args, const Cluster& merge_cluster, - const std::vector& outputs, int input_edge, + Status ExtractBody(const CondArgNodes& cond_arg_nodes, + const std::vector& switches, + const std::vector& merge_nodes, int input_edge, Graph* body); // Adds all the input edges to `if_node` corresponding to the arguments. - Status AddInputEdges(const CondArgs& cond_args, Node* if_node); + Status AddInputEdges(const CondArgNodes& cond_arg_nodes, Node* predicate, + Node* if_node); // Adds all output edges from the `if_node`. Status AddOutputEdges(const std::vector& outputs, Node* if_node); - // Removes all nodes from the graph that are part of cluster. - void RemoveClusterNodes(Cluster* cluster); - - // Removes all argument nodes that are unused. - template - void RemoveUnusedArgs(const T& args); - - // Removes all Merge nodes in merge_cluster. - void RemoveMergeNodes(Cluster* merge_cluster); + // Returns the switch clusters of graph_ in postorder. 
Dead switch nodes are + // skipped and removed from the graph. + StatusOr> DeterminePredicateSwitchOrder(); + + // Update the state for destination based on the state of source and the node + // being updated. + Status Join(const ForwardFlowNode& src_state, const Node* dst, + ForwardFlowNode* dst_state); + + // Ensure that all nodes in the branch_map are dominated by the switch + // nodes. Returns nodes that are not dominated by the switches but are a + // control dependency of a node in the cond, and remove such control + // dependencies. + StatusOr> EnsureDominanceAndReturnNonDominatedControlNodes( + const std::unordered_map& branch_map, + const std::vector& switches); + + // Validates that the frontier of nodes for the conditional + // section are as expected. + Status ValidateFrontier( + const std::unordered_map& branch_map, + const std::unordered_set& frontier); - // Returns the representative member of the corresponding cluster. - ClusterHandle Representative(const Node* node) { - return clusters_.at(node).Get(); - } - - ClusteredGraph clustered_graph_; - ClusterHandle::Vector clusters_; - std::unordered_set merge_nodes_; - std::unordered_set switch_nodes_; FunctionLibraryDefinition* library_; Graph* graph_; + bool dump_graphs_; }; -std::ostream& operator<<(std::ostream& os, - const FunctionalizeCond::ClusterHandle& c) { - os << c.ToString(); - return os; -} - -// Returns a dot representation of the clustered graph showing the connections -// between the nodes and the nodes in each cluster. 
-string DebugString(const Graph& graph, - FunctionalizeCond::ClusterHandle::Vector* clusters) { - string ret = "digraph {\ncompound=true;labeljust=\"r\";ranksep=0.24\n"; - std::map subgraphs; - auto name = [](const Node* n) { - return strings::StrCat(n->type_string(), "_", n->id()); - }; - for (Node* n : graph.nodes()) { - strings::StrAppend(&subgraphs[clusters->at(n).Get()], n->id(), " [label=\"", - name(n), "\"];\n"); - } - for (auto kv : subgraphs) { - strings::StrAppend(&ret, "subgraph cluster_", kv.first.ToString(), " {\n", - "style=filled; color=lightgrey;", "label = \"", - kv.first.ToString(), "\";\n", kv.second, "}\n"); - } - for (Node* n : graph.nodes()) { - for (Node* in : n->in_nodes()) { - strings::StrAppend(&ret, in->id(), " -> ", n->id(), ";\n"); - } - } - return strings::StrCat(ret, "} // end"); -} - -string DebugString(const FunctionalizeCond::ClusteredGraph& clustered_graph) { - string ret = "digraph {\ncompound=true;labeljust=\"r\";\n"; - auto name = [](const FunctionalizeCond::Cluster& cluster) { - return cluster.representative.ToString(); - }; - for (auto kv : clustered_graph) { - if (!kv.second.switch_nodes.empty() || !kv.second.merge_nodes.empty()) { - strings::StrAppend( - &ret, kv.first.ToString(), " [label=\"", name(kv.second), - kv.second.switch_nodes.empty() - ? "" - : strings::StrCat(" switches=", kv.second.switch_nodes.size()), - kv.second.merge_nodes.empty() - ? 
"" - : strings::StrCat(" merges=", kv.second.merge_nodes.size()), - "\"];\n"); - } - } - for (auto kv : clustered_graph) { - for (auto in : kv.second.in_nodes) { - strings::StrAppend(&ret, name(*in), " -> ", name(kv.second), ";\n"); - } - } - return strings::StrCat(ret, "} // end"); -} - bool IsDeadSwitch(const Node* node) { for (const Edge* e : node->out_edges()) { const Node* dst = e->dst(); @@ -796,337 +702,454 @@ bool IsDeadSwitch(const Node* node) { return true; } -void FunctionalizeCond::CreateClusters() { - ClusterHandle source_cluster = ClusterHandle(Graph::kSourceId); - auto& source = clusters_.at(source_cluster); - std::deque>> workqueue; - workqueue.push_back({source_cluster, {}}); - for (Node* node : graph_->nodes()) { - if (IsSwitch(node)) { - switch_nodes_.insert(node); - } else if (IsMerge(node)) { - merge_nodes_.insert(node); - } - ClusterHandle& cluster = clusters_.at(node).Get(); - cluster = ClusterHandle(node->id()); - // Group all source clusters together. - if (node->IsSource() || node->in_edges().empty()) { - clusters_.at(node).Merge(&source); - source.Merge(&clusters_.at(node)); - workqueue.front().second.push_back(node); +string FunctionalizeCond::Branch_Name(FunctionalizeCond::Branch b) { + const string branch_name[FunctionalizeCond::kNumBranchTypes + 1] = { + "else", "then", "both", "neither", "count"}; + return branch_name[b]; +} + +Status FunctionalizeCond::ValidateFrontier( + const std::unordered_map& + branch_map, + const std::unordered_set& frontier) { + std::unordered_set pending[kNumBranchTypes]; + for (Node* n : frontier) { + pending[branch_map.at(n).branch].insert(n); + } + TF_RET_CHECK(pending[kNeither].empty()) << NodesToString(pending[kNeither]); + for (const Node* n : pending[kBoth]) { + TF_RET_CHECK(IsMerge(n)) << n->DebugString(); + // Merge nodes may be in then or else branch too + } + int index = (pending[kThenBranch].size() <= pending[kElseBranch].size()) + ? 
kThenBranch + : kElseBranch; + int other = 1 - index; + for (const Node* n : pending[index]) { + if (pending[other].find(n) != pending[other].end()) { + return errors::Internal( + "Node (", n->DebugString().c_str(), + ") in both Else and Then branch should be in Both."); } } + // An empty frontier indicates a dead switch. Above we attempt to remove dead + // switch nodes, but not all are removed so don't treat it as an error yet. + // TODO(jpienaar): Find out why dead switch nodes remain. + // if (pending[kBoth].empty() && pending[kThenBranch].empty() && + // pending[kElseBranch].empty()) { + // return errors::Internal("Unexpected empty frontier for switch nodes"); + // } + return Status::OK(); +} - // If there are no Merge nodes, then terminate. - if (merge_nodes_.empty()) { - return; +Status FunctionalizeCond::Join(const ForwardFlowNode& src_state, + const Node* dst, ForwardFlowNode* dst_state) { + TF_RET_CHECK(dst_state->branch != Branch::kBoth && + dst_state->branch != Branch::kNumBranchTypes) + << "Unexpected/Invalid branch type: Merging " + << Branch_Name(src_state.branch) << " with " + << Branch_Name(dst_state->branch); + if (dst_state->branch == Branch::kNeither) { + dst_state->branch = src_state.branch; + } else if (src_state.branch != dst_state->branch && + src_state.branch != Branch::kNeither) { + if (IsMerge(dst)) { + dst_state->branch = Branch::kBoth; + } else { + return errors::Internal("Illegal merge: ", src_state.ToString(), " with ", + dst_state->ToString(), " for ", + dst->DebugString()); + } } + ++dst_state->count; + return Status::OK(); +} - // Remove all dead Switch nodes. - RemoveUnusedArgs(switch_nodes_); - - // All parent_'s are still nullptr so clusters_ may still be resized. Resize - // conservatively assuming all merge nodes become XlaIf nodes. 
- clusters_.resize(clusters_.size() + merge_nodes_.size()); - - std::unordered_set marked; - while (!workqueue.empty()) { - auto cluster_queue = workqueue.front(); - VLOG(4) << "Cluster: " << cluster_queue.first << " Queue: {" - << str_util::Join(cluster_queue.second, ",", - [](string* output, const Node* node) { - strings::StrAppend(output, node->id()); - }) - << "}"; - - UnionFind& repr = clusters_.at(cluster_queue.first); - workqueue.pop_front(); - std::deque switch_nodes; - std::deque merge_nodes; - std::unordered_set cluster_member; - while (!cluster_queue.second.empty()) { - // Iterate node workqueue and flow forward merging all nodes reachable - // that are neither a Switch or a Merge and whose inputs are all part of - // the same cluster. - Node* cur = cluster_queue.second.front(); - cluster_queue.second.pop_front(); - if (marked.find(cur) != marked.end()) { - continue; - } - if (IsMerge(cur)) { - merge_nodes.push_back(cur); - marked.insert(cur); - continue; - } - if (IsSwitch(cur)) { - switch_nodes.push_back(cur); - marked.insert(cur); - continue; - } - clusters_.at(cur).Merge(&repr); - cluster_member.insert(cur); - for (Node* out : cur->out_nodes()) { - bool all_ancestors_in_cluster = true; - for (Node* in : out->in_nodes()) { - if (IsMerge(out)) { - merge_nodes.push_back(out); - } - if (IsSwitch(out)) { - switch_nodes.push_back(out); - } - if (cluster_member.find(in) == cluster_member.end()) { - all_ancestors_in_cluster = false; - break; - } - } - if (all_ancestors_in_cluster && out->IsOp()) { - cluster_queue.second.push_back(out); - marked.insert(cur); - } +StatusOr> +FunctionalizeCond::DeterminePredicateSwitchOrder() { + struct Cluster { + bool operator==(const Cluster& other) const { + return representative == other.representative; + } + int representative = -1; + }; + + // Perform a DFS over the graph and + // * Determine the reverse topological order of the nodes (there should be no + // cycles at this point so the post-order numbering corresponds 
to the + // reverse topological sorting); + // * Identify dead switches; + // * Initialize the cluster's representative; + std::vector> clusters(graph_->num_node_ids()); + std::vector dead_switches; + std::vector switch_order; + std::vector rev_topo_sorted_nodes; + DFS(*graph_, nullptr, [&](Node* n) { + clusters[n->id()].Get().representative = n->id(); + if (IsSwitch(n)) { + if (IsDeadSwitch(n)) { + dead_switches.push_back(n); + } else { + rev_topo_sorted_nodes.push_back(n); + switch_order.push_back(n); } + } else if (n->IsOp()) { + // Exclude src and sink nodes from further consideration. + rev_topo_sorted_nodes.push_back(n); } + }); + + std::vector switch_clusters; + // Return early if there are no switches in the graph. + if (switch_order.empty()) { + return switch_clusters; + } + + // Remove all dead switch nodes. + for (Node* n : dead_switches) { + VLOG(2) << "Removing dead switch: " << n->DebugString(); + graph_->RemoveNode(n); + } + + // Identify switch nodes that are part of the same control flow context by + // considering the operands of operations: an operation is part of the same + // control context as its operands unless the operation is a switch. Control + // dependencies are considered part of the same control flow context if the + // switch depth is the same (see comment below). + + // entry_cluster records the input cluster to a switch node. This is used when + // merging with a merge node where the dst's cluster is merged with the entry + // cluster of the merge node's cluster (which corresponds to a switch cluster + // and so has an entry cluster). + std::unordered_map*> entry_cluster; + + // Returns the output cluster of a node. Where the output cluster is cluster + // where the output of the node is used. For non-merge nodes this is simply + // the cluster they are part of, while for merge nodes it is the entry cluster + // of the cluster they are part of (this will correspond to the entry node of + // a switch node that dominates the merge). 
+ auto find_output_cluster = [&](Node* n) { + UnionFind* cluster = &clusters[n->id()]; + if (!IsMerge(n)) return cluster; + auto it = entry_cluster.find(clusters[n->id()].Get().representative); + // If the cluster is not found in the entry_cluster map then an + // instruction not dominated by a switch node has been merged into the + // cluster of the merge. This indicates a failure of the clustering. + CHECK(it != entry_cluster.end()) + << "Unable to find entry for n=" << n->id() << " (" + << cluster->Get().representative << ")"; + return it->second; + }; + + // TODO(jpienaar): This could be combined with DetermineBranchMapAndFrontier. + std::vector switch_depth(graph_->num_node_ids()); + for (auto it = rev_topo_sorted_nodes.rbegin(); + it != rev_topo_sorted_nodes.rend(); ++it) { + Node* n = *it; - VLOG(4) << "Switches: {" - << str_util::Join(switch_nodes, ",", - [](string* output, const Node* node) { - strings::StrAppend(output, node->id()); - }) - << "}"; - - // Merge Switch nodes with common predicate. - std::unordered_map> predicate_to_switch; - for (Node* node : switch_nodes) { - Node* tmp; - TF_CHECK_OK(node->input_node(1, &tmp)); - predicate_to_switch[tmp].push_back(node); + // Compute switch depth. + int new_switch_depth = 0; + for (const Edge* e : n->in_edges()) { + Node* src = e->src(); + new_switch_depth = std::max( + new_switch_depth, switch_depth[src->id()] - (IsMerge(src) ? 1 : 0)); } - for (auto kv : predicate_to_switch) { - Node* first = kv.second.front(); - for (Node* switch_node : kv.second) { - clusters_.at(first).Merge(&clusters_.at(switch_node)); + switch_depth[n->id()] = new_switch_depth + (IsSwitch(n) ? 1 : 0); + + // Only merge the input operands of a switch. The switch's clustering itself + // is determined by the interaction of the switch's outputs. 
+ if (IsSwitch(n)) { + Node* input; + TF_CHECK_OK(n->input_node(0, &input)); + entry_cluster[n->id()] = &clusters[input->id()]; + UnionFind* cluster = find_output_cluster(input); + int cluster_depth = switch_depth[cluster->Get().representative]; + // Merge the inputs of the switch node with one another. This results in + // predicates and control input residing in the same cluster. + for (const Edge* e : n->in_edges()) { + Node* src = e->src(); + UnionFind* src_cluster = find_output_cluster(src); + int src_cluster_depth = switch_depth[src_cluster->Get().representative]; + if (cluster_depth != src_cluster_depth) { + return errors::InvalidArgument( + "Unable to functionalize control flow in graph: Switch ('", + n->name(), "') has operands ('", input->name(), "' and '", + src->name(), "') that have different switch depths (", + cluster_depth, " != ", src_cluster_depth, ")"); + } + cluster->Merge(src_cluster); } + continue; } - // Enqueue each edge of the switch node separately. That is, group all the - // nodes that are due to the true/false edge of the switch together and - // consider all nodes that only have a control dependency on the switch node - // separately. We want to group together all nodes that are part of the same - // branch, as these will be extracted into the `then` and `else` functions - // of the functional if. The ops due to control edges are different as they - // could be involved with either branch and merging them here could result - // in invalid graphs. - for (auto kv : predicate_to_switch) { - ClusterHandle none = ClusterHandle(-1); - ClusterHandle first[2] = {none, none}; - std::deque* queue[2]; - for (auto switch_node : kv.second) { - for (const auto e : switch_node->out_edges()) { - if (IsSwitch(e->dst()) || IsMerge(e->dst())) { - continue; - } - // Control edges are enqueued on their own. 
- if (e->IsControlEdge()) { - workqueue.push_back({Representative(e->dst()), {e->dst()}}); - continue; - } - // Combine all outputs of the same output port of a switch cluster - // into the same workqueue entry. - if (first[e->src_output()] == none) { - ClusterHandle repr = Representative(e->dst()); - first[e->src_output()] = repr; - workqueue.push_back({repr, {}}); - queue[e->src_output()] = &workqueue.back().second; - } - clusters_.at(first[e->src_output()]).Merge(&clusters_.at(e->dst())); - queue[e->src_output()]->push_back(e->dst()); + for (const Edge* e : n->in_edges()) { + Node* src = e->src(); + if (!src->IsOp()) continue; + UnionFind* cluster = find_output_cluster(src); + // Merge a node with its data operands and with its control operands if + // the src and dst are in the same ControlContext. The ControlContext is + // not explicitly available here, and instead the switch depth is used as + // a proxy here. Due to the invariant that control edges can only be from + // a containing scope to an inner scope or from the inner scope to its + // containing scope (for exit nodes), the switch depth will only match if + // the src and dst are in the same ControlContext. Control edges between + // ControlContexts are handled during the extraction. 
+ int src_id = cluster->Get().representative; + int src_depth = switch_depth[src_id]; + if (!e->IsControlEdge() || new_switch_depth == src_depth) { + if (src_depth != new_switch_depth) { + return errors::InvalidArgument( + "Unable to functionalize control flow in graph: Operand ('", + src->name(), "') and operator ('", n->name(), + "') have different switch depths (", src_depth, + " != ", new_switch_depth, ")"); } + cluster->Merge(&clusters[n->id()]); } } } -} -void FunctionalizeCond::ContractEdge(Cluster* from, Cluster* to, - bool remove_from_graph) { - VLOG(3) << "ContractEdge from = " << from->representative - << " to = " << to->representative; - if (from->representative == to->representative) { - return; - } - to->merge_nodes.insert(from->merge_nodes.begin(), from->merge_nodes.end()); - from->merge_nodes.clear(); - to->switch_nodes.insert(from->switch_nodes.begin(), from->switch_nodes.end()); - from->switch_nodes.clear(); - - for (Cluster* from_out : from->out_nodes) { - from_out->in_nodes.erase(from); - if (from_out->representative != to->representative) { - from_out->in_nodes.insert(to); - to->out_nodes.insert(from_out); + if (dump_graphs_) { + // Mark the switch cluster each node is part of. + for (Node* n : graph_->nodes()) { + n->ClearAttr("_XlaFunctionalizeSwitchGroup"); + n->AddAttr("_XlaFunctionalizeSwitchGroup", + clusters[n->id()].Get().representative); } + LOG(INFO) << "FunctionalizeControlFlow (with_clusters): " + << dump_graph::DumpGraphToFile("functionalize_clustered", *graph_, + library_); } - from->out_nodes.clear(); - for (Cluster* from_in : from->in_nodes) { - from_in->out_nodes.erase(from); - if (from_in->representative != to->representative) { - from_in->out_nodes.insert(to); - to->in_nodes.insert(from_in); + // Verify all the nodes of a cluster are at the same depth. 
+ std::unordered_map> cluster_to_depth_node; + for (Node* n : graph_->nodes()) { + int depth = switch_depth[n->id()]; + int cluster_rep = clusters[n->id()].Get().representative; + auto it = cluster_to_depth_node.find(cluster_rep); + if (it == cluster_to_depth_node.end()) { + cluster_to_depth_node[cluster_rep] = std::make_pair(depth, n); + } else { + if (it->second.first != depth) { + return errors::Internal( + "Illegal clustering created, mismatch in depths:", "\n\t", + n->DebugString(), "(", clusters[n->id()].Get().representative, + ") at depth=", depth, " vs\n\t", it->second.second->DebugString(), + "(", clusters[n->id()].Get().representative, ") at depth ", + it->second.first); + } } } - from->in_nodes.clear(); - to->in_nodes.erase(from); - to->out_nodes.erase(from); - clusters_.at(to->representative).Merge(&clusters_.at(from->representative)); - from->visited = true; + struct Hash { + size_t operator()(const std::pair& item) const { + return Hash64Combine(hash()(item.first), + std::hash()(item.second.representative)); + } + }; - if (remove_from_graph) { - clustered_graph_.erase(from->representative); + // Merge Switch nodes with common predicate. + std::unordered_map, int, Hash> predicate_index; + // The nodes in switch_order are in reverse topological order, but the + // clustered switches need not be (i.e., when considered as a cluster one + // element of a cluster may be later in the topological order than another + // node whose cluster is later in the topological order of clustered + // switches). 
+ for (auto it = switch_order.rbegin(); it != switch_order.rend(); ++it) { + Node* pred; + TF_CHECK_OK((*it)->input_node(1, &pred)); + auto repr = std::make_pair(pred, clusters[(*it)->id()].Get()); + if (predicate_index.find(repr) == predicate_index.end()) { + predicate_index[repr] = switch_clusters.size(); + switch_clusters.emplace_back(pred); + // Generate a name by concatenating with the cluster representative as + // there could be multiple switch clusters with the same predicate. + switch_clusters[predicate_index[repr]].name = + strings::StrCat(pred->name(), "_", repr.second.representative, "_If"); + } + switch_clusters[predicate_index[repr]].switches.push_back(*it); } + + return switch_clusters; } -void FunctionalizeCond::CreateClusteredGraph() { - auto update_cluster_for_node = [this](Node* node) -> Cluster& { - ClusterHandle repr = Representative(node); - Cluster& cluster_node = clustered_graph_[repr]; - cluster_node.representative = repr; - for (const Node* in : node->in_nodes()) { - ClusterHandle other_repr = Representative(in); - // Skip source, sink and internal edges. 
- if (other_repr == repr) { - continue; +StatusOr> +FunctionalizeCond::EnsureDominanceAndReturnNonDominatedControlNodes( + const std::unordered_map& branch_map, + const std::vector& switches) { + std::vector old_control_nodes; + for (const auto& kv : branch_map) { + if (kv.second.count != kv.first->in_edges().size()) { + std::vector delete_edges; + for (const Edge* in : kv.first->in_edges()) { + auto it = branch_map.find(in->src()); + if (it == branch_map.end()) { + if (in->IsControlEdge()) { + old_control_nodes.push_back(in->src()); + delete_edges.push_back(in); + } else { + if (IsSwitch(in->src())) { + if (std::find(switches.begin(), switches.end(), in->src()) == + switches.end()) { + return errors::Internal( + "Unexpected switch node found during flow forward: ", + in->src()->DebugString()); + } + continue; + } + return errors::InvalidArgument( + "Value ", kv.first->name(), "'s input, ", in->src()->name(), + ", is not dominated by switch nodes ", NodesToString(switches)); + } + } } - Cluster& cluster_node_in = clustered_graph_[other_repr]; - cluster_node.in_nodes.insert(&cluster_node_in); - cluster_node_in.out_nodes.insert(&cluster_node); - cluster_node_in.representative = other_repr; - } - for (const Node* out : node->out_nodes()) { - ClusterHandle other_repr = Representative(out); - // Skip source, sink and internal edges. - if (other_repr == repr) { - continue; + // Remove control edges from nodes that are not dominated by the switch + // nodes. New control dependencies will be added between these nodes and + // the XlaIf node inserted. 
+ for (const Edge* e : delete_edges) { + graph_->RemoveEdge(e); } - Cluster& cluster_node_out = clustered_graph_[other_repr]; - cluster_node.out_nodes.insert(&cluster_node_out); - cluster_node_out.in_nodes.insert(&cluster_node); - cluster_node_out.representative = other_repr; } - return cluster_node; - }; - update_cluster_for_node(graph_->source_node()); - for (Node* node : switch_nodes_) { - update_cluster_for_node(node).switch_nodes.insert(node); - } - for (Node* node : merge_nodes_) { - update_cluster_for_node(node).merge_nodes.insert(node); } - - VLOG(3) << "Graph with clusters: " << DebugString(*graph_, &clusters_); - VLOG(3) << "ClusteredGraph: " << DebugString(clustered_graph_); + return old_control_nodes; } -gtl::optional -FunctionalizeCond::CreateCorrespondingMergeCluster(Cluster* switch_cluster) { - VLOG(3) << "CreateCorrespondingMergeCluster for " - << switch_cluster->representative; - std::unordered_set merges; - std::unordered_set dominated; - dominated.insert(switch_cluster); - std::deque queue; - auto enqueue_or_update_merge = [this, &queue, &merges](Cluster* c) { - if (c->merge_nodes.empty()) { - queue.push_back(c); - } else { - merges.insert(c); - } - }; - // Enqueue all the outputs of the switch cluster in the workqueue. 
- for (auto* out : switch_cluster->out_nodes) { - enqueue_or_update_merge(out); - } - std::unordered_set visited; - while (!queue.empty()) { - Cluster* cur = queue.front(); - queue.pop_front(); - if (visited.find(cur) != visited.end()) { +StatusOr< + std::pair, + std::unordered_set>> +FunctionalizeCond::DetermineBranchMapAndFrontier( + const SwitchCluster& switch_cluster) { + std::unordered_map branch_map; + std::unordered_set frontier; + std::vector stack = switch_cluster.switches; + std::vector visited(graph_->num_node_ids(), false); + while (!stack.empty()) { + Node* n = stack.back(); + stack.pop_back(); + + if (visited[n->id()]) { continue; } - visited.insert(cur); - // Ensure all inputs to the current node are in the dominated set. - for (Cluster* in : cur->in_nodes) { - if (dominated.find(in) == dominated.end()) { - return gtl::nullopt; + visited[n->id()] = true; + + // Propagate branch state along each edge of a switch node. + bool sink_only = true; + for (const Edge* e : n->out_edges()) { + Node* out = e->dst(); + if (!out->IsOp()) { + continue; + } + sink_only = false; + // Propagate branch information. + ForwardFlowNode& ffn = branch_map[out]; + if (IsSwitch(n)) { + int index = e->IsControlEdge() ? Branch::kNeither : e->src_output(); + TF_RETURN_IF_ERROR(Join(ForwardFlowNode(Branch(index)), out, &ffn)); + } else { + TF_RETURN_IF_ERROR(Join(branch_map[n], out, &ffn)); + } + if (IsMerge(out)) { + if (out->in_edges().size() == ffn.count) { + frontier.insert(out); + } + } else if (!visited[out->id()]) { + stack.push_back(out); } } - for (Cluster* out : cur->out_nodes) { - // No switch nodes beyond the entry one is expected. 
- if (!out->switch_nodes.empty()) { - return gtl::nullopt; + if (sink_only) { + if (!IsIdentity(n)) { + VLOG(1) << "Feeding into sink: " << n->DebugString(); } - enqueue_or_update_merge(out); } } - auto it = merges.begin(); - Cluster* merge_cluster = *it; - for (++it; it != merges.end(); ++it) { - ContractEdge(*it, merge_cluster); - } - - // TODO(jpienaar): Clean up graph, merging nodes. - return merge_cluster; + if (dump_graphs_) { + for (const auto& kv : branch_map) { + // Append attribute to the graph if running with logging to make the + // changes clearer in the visualization. + kv.first->AddAttr("_XlaFunctionalizeBranch", + Branch_Name(kv.second.branch)); + } + } + return std::make_pair(std::move(branch_map), std::move(frontier)); } -xla::StatusOr FunctionalizeCond::DetermineCondArgs( - const Cluster& merge_cluster, const Cluster& switch_cluster) { - VLOG(2) << "DetermineCondArgs for " << merge_cluster.representative - << " with switch cluster " << switch_cluster.representative; - CondArgs ret; - auto feeds_into_branch_cluster = [&](Node* switch_cluster) { - for (Node* out : switch_cluster->out_nodes()) { - ClusterHandle repr = Representative(out); - if (repr == merge_cluster.representative) { - return true; - } - for (Cluster* in : merge_cluster.in_nodes) { - if (repr == in->representative) { - return true; - } +Status FunctionalizeCond::FunctionalizeInternal() { + TF_ASSIGN_OR_RETURN(std::vector predicate_switch_order, + DeterminePredicateSwitchOrder()); + + // Iterate from innermost set of clustered switches to outermost, replacing + // matching switch->merge subgraphs with single XlaIf nodes. 
+ for (auto it = predicate_switch_order.rbegin(); + it != predicate_switch_order.rend(); ++it) { + auto& ps = *it; + VLOG(3) << "Flow down from: " << NodesToString(ps.switches) << " (" + << ps.predicate->name() << ")"; + + std::unordered_map branch_map; + std::unordered_set frontier; + TF_ASSIGN_OR_RETURN(std::tie(branch_map, frontier), + DetermineBranchMapAndFrontier(ps)); + + if (dump_graphs_) + LOG(INFO) << "FunctionalizeControlFlow (before XlaIf conversion): " + << dump_graph::DumpGraphToFile("functionalize_bc", *graph_, + library_); + TF_RETURN_IF_ERROR(ValidateFrontier(branch_map, frontier)); + + // Sort the merge and switch nodes using NodeCmp. The switch-nodes are + // further grouped (post sorting) by input to the switch node as in the + // functionalized form each input will be passed in only once. This grouping + // should retain the sorted order. + CondArgNodes cond_arg_nodes; + std::unordered_map input_index; + std::sort(ps.switches.begin(), ps.switches.end(), NodeCmp()); + for (Node* switch_node : ps.switches) { + Node* in; + TF_RETURN_IF_ERROR(switch_node->input_node(0, &in)); + if (input_index.find(in) == input_index.end()) { + input_index[in] = cond_arg_nodes.size(); + cond_arg_nodes.emplace_back(in); } + cond_arg_nodes.at(input_index.at(in)).switches.push_back(switch_node); } - return false; - }; - for (Node* switch_cluster_node : switch_cluster.switch_nodes) { - if (!feeds_into_branch_cluster(switch_cluster_node)) { - continue; + std::vector merge_nodes(frontier.begin(), frontier.end()); + std::sort(merge_nodes.begin(), merge_nodes.end(), NodeCmp()); + + TF_ASSIGN_OR_RETURN(std::vector old_control_nodes, + EnsureDominanceAndReturnNonDominatedControlNodes( + branch_map, ps.switches)); + + TF_ASSIGN_OR_RETURN(Node * if_node, + ConvertToXlaIf(cond_arg_nodes, ps, merge_nodes)); + for (Node* old : old_control_nodes) { + graph_->AddControlEdge(old, if_node); } - Node* tmp; - TF_RETURN_IF_ERROR(switch_cluster_node->input_node(1, &tmp)); - if 
(ret.conditional == nullptr) { - ret.conditional = tmp; - } else if (ret.conditional != tmp) { - return errors::Unimplemented( - "Switch statements with different conditionals cannot be " - "converted into functional conditional."); + for (auto& del_kv : branch_map) { + graph_->RemoveNode(del_kv.first); + } + for (auto& kv : cond_arg_nodes) { + for (Node* node : kv.switches) { + graph_->RemoveNode(node); + } } - ret.args.insert(switch_cluster_node); + if (dump_graphs_) + LOG(INFO) << "FunctionalizeControlFlow (after XlaIf conversion): " + << dump_graph::DumpGraphToFile("functionalize_ac", *graph_, + library_); } - return ret; + return Status::OK(); } -xla::StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( - const CondArgs& cond_args, const Cluster& merge_cluster, - const std::vector& outputs) { - VLOG(2) << "Build if op for " << NodesToString(merge_cluster.merge_nodes) - << " with input " << NodesToString(cond_args.args); +StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( + const CondArgNodes& cond_arg_nodes, const SwitchCluster& switch_cluster, + const std::vector& merge_nodes) { + VLOG(2) << "Build if op for " << switch_cluster.name; NodeDef if_def; // Create a new If node using the name of the merge node. 
- NodeDefBuilder builder( - strings::StrCat((*merge_cluster.merge_nodes.begin())->name(), "_If"), - "XlaIf"); + NodeDefBuilder builder(switch_cluster.name, "XlaIf"); string branch[] = {"else_branch", "then_branch"}; for (int i = 0; i < 2; ++i) { static std::atomic sequence_num(0LL); @@ -1136,8 +1159,8 @@ xla::StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( body_name.set_name( strings::StrCat("_functionalize_if_", branch[i], "_", id)); auto body = xla::MakeUnique(graph_->op_registry()); - TF_RETURN_IF_ERROR( - ExtractBody(cond_args, merge_cluster, outputs, i, body.get())); + TF_RETURN_IF_ERROR(ExtractBody(cond_arg_nodes, switch_cluster.switches, + merge_nodes, i, body.get())); VLOG(3) << "Body " << branch[i] << ": " << DebugString(body.get()); FunctionDef body_fdef; TF_RETURN_IF_ERROR(GraphToFunctionDef(*body, body_name.name(), &body_fdef)); @@ -1148,33 +1171,40 @@ xla::StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( // Build input type. std::vector inputs; DataTypeVector in_arg_types; - for (const Node* arg : cond_args.args) { - const Edge* in_edge; - TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge)); - if (in_edge->IsControlEdge()) { - builder.ControlInput(in_edge->src()->name()); - } else { - DataType dtype = arg->input_type(0); - inputs.emplace_back(NodeDefBuilder::NodeOut( - in_edge->src()->name(), in_edge->src_output(), dtype)); - in_arg_types.push_back(dtype); + for (auto& kv : cond_arg_nodes) { + bool inserted = false; + for (const Node* arg : kv.switches) { + const Edge* in_edge; + TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge)); + if (in_edge->IsControlEdge()) { + builder.ControlInput(in_edge->src()->name()); + } else { + if (!inserted) { + DataType dtype = arg->input_type(0); + inputs.emplace_back(NodeDefBuilder::NodeOut( + in_edge->src()->name(), in_edge->src_output(), dtype)); + in_arg_types.push_back(dtype); + inserted = true; + } + } } } builder.Attr("Tin", in_arg_types); // Build output type. 
DataTypeVector out_type; - for (const Node* merge : merge_cluster.merge_nodes) { + for (const Node* merge : merge_nodes) { DataType dtype = merge->output_type(0); out_type.push_back(dtype); } builder.Attr("Tout", out_type); builder.Attr("Tcond", DT_BOOL); - builder.Device(cond_args.conditional->assigned_device_name()); + builder.Device(switch_cluster.predicate->assigned_device_name()); // Conditional should be the first input ... - builder.Input(NodeDefBuilder::NodeOut(cond_args.conditional->name(), 0, - cond_args.conditional->output_type(0))); + builder.Input( + NodeDefBuilder::NodeOut(switch_cluster.predicate->name(), 0, + switch_cluster.predicate->output_type(0))); // ... followed by the other inputs. builder.Input(inputs); @@ -1183,64 +1213,31 @@ xla::StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( return if_node; } -void FunctionalizeCond::RemoveClusterNodes(Cluster* cluster) { - VLOG(3) << "RemoveClusterNodes for " << cluster->representative; - ClusterHandle repr = cluster->representative; - std::deque to_delete; - for (Node* node : graph_->nodes()) { - if (Representative(node) == repr) { - to_delete.push_back(node); - } - } - for (Node* n : to_delete) { - graph_->RemoveNode(n); - } -} - -template -void FunctionalizeCond::RemoveUnusedArgs(const T& args) { - VLOG(2) << "RemoveUnusedArgs among: " << NodesToString(args); - - std::deque to_delete; - for (Node* arg : args) { - if (IsDeadSwitch(arg)) { - to_delete.push_back(arg); - for (Node* n : arg->out_nodes()) { - to_delete.push_back(n); - } - } - } - for (Node* n : to_delete) { - switch_nodes_.erase(n); - auto it = clustered_graph_.find(Representative(n)); - if (it != clustered_graph_.end()) { - it->second.switch_nodes.erase(n); - } - graph_->RemoveNode(n); - } -} - -Status FunctionalizeCond::ExtractBody(const CondArgs& cond_args, - const Cluster& merge_cluster, - const std::vector& outputs, +Status FunctionalizeCond::ExtractBody(const CondArgNodes& cond_arg_nodes, + const std::vector& switches, + const 
std::vector& merge_nodes, int input_edge, Graph* body) { - VLOG(2) << "ExtractBody for " << merge_cluster.representative - << " along edge " << input_edge; + VLOG(2) << "ExtractBody for " << NodesToString(merge_nodes) << " along edge " + << input_edge; std::vector squash_src_outputs(graph_->num_node_ids(), false); std::vector node_map(graph_->num_node_ids(), nullptr); int arg_count = 0; - for (const auto* arg : cond_args.args) { - DataType dtype = arg->input_type(0); - TF_ASSIGN_OR_RETURN(Node * arg_node, - BuildArgNode(body, dtype, arg_count++)); - node_map.at(arg->id()) = arg_node; - squash_src_outputs.at(arg->id()) = true; + for (auto& kv : cond_arg_nodes) { + Node* arg_node = nullptr; + for (const auto* arg : kv.switches) { + DataType dtype = arg->input_type(0); + if (arg_node == nullptr) { + TF_ASSIGN_OR_RETURN(arg_node, BuildArgNode(body, dtype, arg_count++)); + } + node_map.at(arg->id()) = arg_node; + squash_src_outputs.at(arg->id()) = true; + } } std::vector stack; - stack.reserve(outputs.size()); - for (int j = 0; j < outputs.size(); ++j) { - Node* node = outputs[j]; + stack.reserve(merge_nodes.size()); + for (int j = 0; j < merge_nodes.size(); ++j) { + Node* node = merge_nodes[j]; TF_ASSIGN_OR_RETURN(node_map.at(node->id()), BuildRetvalNode(body, node->output_type(0), /*index=*/j)); @@ -1251,7 +1248,7 @@ Status FunctionalizeCond::ExtractBody(const CondArgs& cond_args, node_map.at(in->id()) = body->CopyNode(in); } - if (cond_args.args.find(in) == cond_args.args.end()) { + if (std::find(switches.begin(), switches.end(), in) == switches.end()) { body->AddEdge(node_map.at(in->id()), in_edge->src_output(), node_map.at(node->id()), 0); } else { @@ -1266,18 +1263,25 @@ Status FunctionalizeCond::ExtractBody(const CondArgs& cond_args, body); } -Status FunctionalizeCond::AddInputEdges(const CondArgs& cond_args, - Node* if_node) { +Status FunctionalizeCond::AddInputEdges(const CondArgNodes& cond_arg_nodes, + Node* predicate, Node* if_node) { VLOG(3) << 
"AddInputEdges for " << if_node->name(); - int i = 0; - graph_->AddEdge(cond_args.conditional, 0, if_node, i++); - for (const Node* arg : cond_args.args) { - const Edge* in_edge; - TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge)); - if (in_edge->IsControlEdge()) { - graph_->AddControlEdge(in_edge->src(), if_node); - } else { - graph_->AddEdge(in_edge->src(), in_edge->src_output(), if_node, i++); + int index = 0; + graph_->AddEdge(predicate, 0, if_node, index++); + for (auto& kv : cond_arg_nodes) { + bool inserted = false; + for (const Node* arg : kv.switches) { + const Edge* in_edge; + TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge)); + if (in_edge->IsControlEdge()) { + graph_->AddControlEdge(in_edge->src(), if_node); + } else { + if (!inserted) { + graph_->AddEdge(in_edge->src(), in_edge->src_output(), if_node, + index++); + inserted = true; + } + } } } return Status::OK(); @@ -1308,196 +1312,39 @@ Status FunctionalizeCond::AddOutputEdges(const std::vector& outputs, return Status::OK(); } -void FunctionalizeCond::RemoveMergeNodes(Cluster* merge_cluster) { - VLOG(3) << "RemoveMergeNodes for " << merge_cluster->representative; - // Remove all merge nodes now dead post extraction of If. 
- for (auto it = merge_cluster->merge_nodes.begin(); - it != merge_cluster->merge_nodes.end();) { - Node* node = *it; - graph_->RemoveNode(node); - merge_cluster->merge_nodes.erase(*it++); - } -} - -Status FunctionalizeCond::RemoveTrivialSwitch(Cluster* switch_cluster) { - Cluster* merge_cluster = *switch_cluster->out_nodes.begin(); - if (merge_cluster->merge_nodes.empty()) { - return errors::FailedPrecondition( - "Not a trivial switch: no Merge node feeding into Switch node"); - } - - for (auto it = merge_cluster->merge_nodes.begin(); - it != merge_cluster->merge_nodes.end();) { - // We have the following structure: - // Op -> Switch -> Merge -> Consumer - // and we want to transform it to: - // Op -> Consumer - Node* merge_node = *it; - Node* switch_node; - const Edge* in = nullptr; - TF_RETURN_IF_ERROR(merge_node->input_node(0, &switch_node)); - TF_RETURN_IF_ERROR(switch_node->input_edge(0, &in)); - for (auto out : merge_node->out_edges()) { - int src_output = out->dst_input() == Graph::kControlSlot - ? Graph::kControlSlot - : in->src_output(); - graph_->AddEdge(in->src(), src_output, out->dst(), out->dst_input()); - } - graph_->RemoveNode(*it++); - } - RemoveUnusedArgs(switch_cluster->switch_nodes); - - return Status::OK(); -} - -Status FunctionalizeCond::ConvertCorrespondingMergeToXlaIf( - Cluster* switch_cluster) { - VLOG(1) << "ConvertMergeToXlaIf for " << switch_cluster->representative; - gtl::optional maybe_merge = - CreateCorrespondingMergeCluster(switch_cluster); - if (!maybe_merge.has_value()) { - return errors::FailedPrecondition( - "Switch cluster was not part of a simple conditional in the clustered " - "graph. 
Graph nodes in switch cluster ", - NodesToString(switch_cluster->switch_nodes)); - } - Cluster* merge_cluster = *maybe_merge; - if (merge_cluster->merge_nodes.empty()) { - return errors::Internal( - "Merge node in clustered graph contains no merge nodes: ", - merge_cluster->representative.ToString()); - } - TF_ASSIGN_OR_RETURN(auto cond_args, - DetermineCondArgs(*merge_cluster, *switch_cluster)); - - // Sort the outputs by ID to produce more stable output. - std::vector outputs(merge_cluster->merge_nodes.begin(), - merge_cluster->merge_nodes.end()); - std::sort(outputs.begin(), outputs.end(), CondArgs::CondCmp()); +StatusOr FunctionalizeCond::ConvertToXlaIf( + const CondArgNodes& cond_arg_nodes, const SwitchCluster& switch_cluster, + const std::vector& merge_nodes) { + VLOG(1) << "ConvertToXlaIf for " << switch_cluster.ToString() << " -> " + << NodesToString(merge_nodes); // Extract bodies and builds a If operator. - TF_ASSIGN_OR_RETURN(Node * if_node, - BuildAndAddXlaIfOp(cond_args, *merge_cluster, outputs)); - TF_RETURN_IF_ERROR(AddInputEdges(cond_args, if_node)); - TF_RETURN_IF_ERROR(AddOutputEdges(outputs, if_node)); - - // Remove the old nodes from the graph_ and contract the edges of the - // clustered graph. - for (auto in : merge_cluster->in_nodes) { - if (in != switch_cluster) { - RemoveClusterNodes(in); - } - } - RemoveMergeNodes(merge_cluster); - RemoveUnusedArgs(cond_args.args); - auto in_nodes = merge_cluster->in_nodes; - for (auto it = in_nodes.begin(); it != in_nodes.end();) { - ContractEdge(*it++, switch_cluster); - } - ContractEdge(merge_cluster, switch_cluster); - clusters_[if_node].Get() = ClusterHandle(switch_cluster->representative); - - return Status::OK(); -} - -std::vector> -FunctionalizeCond::SortedSwitchNodes() { - VLOG(2) << "ProcessClusteredGraph"; - std::stack> stack; - // Initialize with the source node. 
- stack.push({0, &clustered_graph_[Representative(graph_->source_node())]}); - - // Perform a depth-first traversal of the clustered graph computing the - // switch-merge depth. - std::vector> queue; - std::unordered_set visited; - while (!stack.empty()) { - Cluster* n = stack.top().second; - size_t depth = stack.top().first; - stack.pop(); - - auto inserted = visited.insert(n); - if (!inserted.second) { - continue; - } - - size_t new_depth = depth; - if (!n->merge_nodes.empty()) { - --new_depth; - } - if (!n->switch_nodes.empty()) { - queue.emplace_back(depth, n); - ++new_depth; - } - for (Cluster* e : n->out_nodes) { - stack.emplace(new_depth, e); - } - } - - // Sort in reverse order of switch-merge depth with ties broken by the - // ClusterHandle. - std::sort(queue.begin(), queue.end(), - [](const std::pair& lhs, - const std::pair& rhs) { - return std::tie(lhs.first, lhs.second->representative) > - std::tie(rhs.first, rhs.second->representative); - }); + TF_ASSIGN_OR_RETURN( + Node * if_node, + BuildAndAddXlaIfOp(cond_arg_nodes, switch_cluster, merge_nodes)); + TF_RETURN_IF_ERROR( + AddInputEdges(cond_arg_nodes, switch_cluster.predicate, if_node)); + TF_RETURN_IF_ERROR(AddOutputEdges(merge_nodes, if_node)); - return queue; + return if_node; } Status FunctionalizeCond::Functionalize(Graph* graph, FunctionLibraryDefinition* library) { VLOG(1) << "FunctionalizeCond::Functionalize"; - FunctionalizeCond fc(graph, library); - fc.CreateClusters(); - if (fc.NoConditionals()) { - return Status::OK(); - } - fc.CreateClusteredGraph(); - - auto queue = fc.SortedSwitchNodes(); - for (auto it = queue.begin(); it != queue.end();) { - Cluster* switch_cluster = (*it).second; - ++it; - if (switch_cluster->out_nodes.size() == 1) { - TF_RETURN_IF_ERROR(fc.RemoveTrivialSwitch(switch_cluster)); - } else { - TF_RETURN_IF_ERROR(fc.ConvertCorrespondingMergeToXlaIf(switch_cluster)); - } - - // Contract newly Switch free switch_cluster with outgoing nodes without - // Switch or Merge 
nodes. - for (auto& nodes : {switch_cluster->out_nodes, switch_cluster->in_nodes}) { - std::vector copy_nodes(nodes.begin(), nodes.end()); - for (auto* node : copy_nodes) { - if (node->merge_nodes.empty() && node->switch_nodes.empty()) { - fc.ContractEdge(node, switch_cluster); - } - } - } - - VLOG(3) << "Graph with clusters: " - << DebugString(*fc.graph_, &fc.clusters_); - VLOG(3) << "ClusteredGraph: " << DebugString(fc.clustered_graph_); - } - - if (!fc.switch_nodes_.empty()) { - return errors::Internal( - "Failed to functionalize control flow with Switch nodes remaining: ", - NodesToString(fc.switch_nodes_)); - } - return Status::OK(); + FunctionalizeCond fc(graph, library, /*dump_graphs=*/VLOG_IS_ON(2)); + return fc.FunctionalizeInternal(); } } // namespace -// Transformation that converts Tensorflow's graph control flow constructs into +// Transformation that converts TensorFlow's graph control flow constructs into // functional equivalents. Status FunctionalizeControlFlow(Graph* graph, FunctionLibraryDefinition* library) { VLOG(2) << "FunctionalizeControlFlow (initial): " - << dump_graph::DumpGraphToFile("functionalize_initial", *graph); + << dump_graph::DumpGraphToFile("functionalize_initial", *graph, + library); // Note: BuildControlFlowInfo() requires that the graph's source node is // connected to all source nodes in the graph. Many graphs violate this // invariant. @@ -1509,7 +1356,8 @@ Status FunctionalizeControlFlow(Graph* graph, for (Node* node : graph->op_nodes()) { const ControlFlowInfo& cf = cf_info[node->id()]; - VLOG(2) << "node: " << node->name() << " frame_name: " << cf.frame_name + VLOG(2) << "node: " << node->name() << " (" << node->id() + << ") frame_name: " << cf.frame_name << " frame: " << (cf.frame ? cf.frame->name() : "---") << " parent_frame: " << (cf.parent_frame ? 
cf.parent_frame->name() : "---"); @@ -1577,7 +1425,8 @@ Status FunctionalizeControlFlow(Graph* graph, TF_RETURN_IF_ERROR(FunctionalizeCond::Functionalize(graph, library)); VLOG(2) << "FunctionalizeControlFlow (final): " - << dump_graph::DumpGraphToFile("functionalize_final", *graph); + << dump_graph::DumpGraphToFile("functionalize_final", *graph, + library); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index 01d2b282751f387cfa9c8887cdeb48090c96bff4..bc7276c3afd5060d6faeceb4d479416299ecc5da 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -38,10 +38,11 @@ namespace { // Returns the names of the "then" and "else" functions for the XlaIf node in a // graph. -Status FindIfThenAndElse(const GraphDef& graph, NameAttrList* then_fn, - NameAttrList* else_fn) { +Status FindIfThenAndElse(const GraphDef& graph, string* op_name, + NameAttrList* then_fn, NameAttrList* else_fn) { for (const NodeDef& node : graph.node()) { if (node.op() == "XlaIf") { + *op_name = node.name(); const NameAttrList* result; TF_RETURN_IF_ERROR(GetNodeAttr(node, "then_branch", &result)); *then_fn = *result; @@ -96,9 +97,10 @@ TEST(FunctionalizeControlFlow, Conditional) { GraphDef graph_def; graph.ToGraphDef(&graph_def); + string op_name; NameAttrList then_fn; NameAttrList else_fn; - TF_EXPECT_OK(FindIfThenAndElse(graph_def, &then_fn, &else_fn)); + TF_EXPECT_OK(FindIfThenAndElse(graph_def, &op_name, &then_fn, &else_fn)); InstantiationResultForTest else_result; TF_EXPECT_OK( InstantiateFunctionForTest(else_fn.name(), library, &else_result)); @@ -109,7 +111,7 @@ TEST(FunctionalizeControlFlow, Conditional) { auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); auto less = ops::Less(scope.WithOpName("cond/Less"), y, x); - auto if_op 
= ops::XlaIf(scope.WithOpName("cond/Merge_If"), less, + auto if_op = ops::XlaIf(scope.WithOpName(op_name), less, std::initializer_list{less, y, x}, then_fn, else_fn, {DT_INT32}); GraphDef expected; diff --git a/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md b/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md new file mode 100644 index 0000000000000000000000000000000000000000..91351421bcacd26c41b5c9f98ea833730e4aef30 --- /dev/null +++ b/tensorflow/compiler/tf2xla/g3doc/cpu_supported_ops.md @@ -0,0 +1,266 @@ +**Supported operators for device: XLA_CPU_JIT** + +Operator | Type Constraint +------------------------------------- | --------------- +`Abs` | `T={double,float,int32,int64}` +`Acosh` | `T={complex64,double,float}` +`Add` | `T={complex64,double,float,int32,int64}` +`AddN` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`AdjustContrastv2` | +`AdjustHue` | +`AdjustSaturation` | +`All` | `Tidx={int32,int64}` +`Angle` | `Tout={double,float}`
`T={complex64}` +`Any` | `Tidx={int32,int64}` +`ApproximateEqual` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`ArgMax` | `Tidx={int32,int64}`
`output_type={int32,int64}`
`T={float}` +`ArgMin` | `Tidx={int32,int64}`
`output_type={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`Asinh` | `T={complex64,double,float}` +`AssignAddVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}` +`AssignSubVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}` +`AssignVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Atan2` | `T={double,float}` +`Atanh` | `T={complex64,double,float}` +`AvgPool` | `T={double,float}` +`AvgPool3D` | `T={double,float}` +`AvgPool3DGrad` | `T={double,float}` +`AvgPoolGrad` | `T={double,float}` +`BatchMatMul` | `T={complex64,double,float,int32}` +`BatchToSpace` | `Tidx={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`BatchToSpaceND` | `Tcrops={int32,int64}`
`Tblock_shape={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`BiasAdd` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`BiasAddGrad` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`BiasAddV1` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`BitwiseAnd` | `T={int32,int64,uint32,uint64}` +`BitwiseOr` | `T={int32,int64,uint32,uint64}` +`BroadcastArgs` | `T={int32,int64}` +`BroadcastGradientArgs` | `T={int32,int64}` +`Cast` | `DstT={bool,complex64,double,float,int32,int64,uint32,uint64}`
`SrcT={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Ceil` | `T={double,float}` +`Cholesky` | `T={double,float}` +`Complex` | `Tout={complex64}`
`T={double,float}` +`ComplexAbs` | `Tout={double,float}`
`T={complex64}` +`Concat` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ConcatOffset` | +`ConcatV2` | `Tidx={int32}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Conj` | `T={complex64}` +`Const` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ControlTrigger` | +`Conv2D` | `T={float}` +`Conv2DBackpropFilter` | `T={float}` +`Conv2DBackpropInput` | `T={float}` +`Conv3D` | `T={double,float}` +`Conv3DBackpropFilterV2` | `T={double,float}` +`Conv3DBackpropInputV2` | `T={double,float}` +`Cos` | `T={complex64,double,float}` +`Cosh` | `T={complex64,double,float}` +`Cross` | `T={double,float,int32,int64,uint32,uint64}` +`Cumprod` | `Tidx={int32,int64}`
`T={float}` +`Cumsum` | `Tidx={int32,int64}`
`T={float}` +`DepthToSpace` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`DepthwiseConv2dNative` | `T={double,float}` +`DepthwiseConv2dNativeBackpropFilter` | `T={double,float}` +`DepthwiseConv2dNativeBackpropInput` | `T={double,float}` +`Diag` | `T={complex64,double,float,int32,int64}` +`DiagPart` | `T={complex64,double,float,int32,int64}` +`Div` | `T={complex64,double,float,int32,int64}` +`DynamicStitch` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Elu` | `T={double,float}` +`EluGrad` | `T={double,float}` +`Equal` | `T={bool,complex64,double,float,int32,int64}` +`Exp` | `T={complex64,double,float}` +`ExpandDims` | `Tdim={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Expm1` | `T={complex64,double,float}` +`ExtractImagePatches` | `T={double,float,int32,int64,uint32,uint64}` +`FFT` | +`FFT2D` | +`FFT3D` | +`Fill` | `index_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Floor` | `T={double,float}` +`FloorDiv` | `T={complex64,double,float,int32,int64}` +`FloorMod` | `T={double,float,int32,int64}` +`FusedBatchNorm` | `T={float}` +`FusedBatchNormGrad` | `T={float}` +`FusedBatchNormGradV2` | `U={float}`
`T={float}` +`FusedBatchNormV2` | `U={float}`
`T={float}` +`Gather` | `Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` +`GatherV2` | `Taxis={int32,int64}`
`Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Greater` | `T={double,float,int32,int64,uint32,uint64}` +`GreaterEqual` | `T={double,float,int32,int64,uint32,uint64}` +`HSVToRGB` | `T={double,float}` +`IFFT` | +`IFFT2D` | +`IFFT3D` | +`IRFFT` | +`IRFFT2D` | +`IRFFT3D` | +`Identity` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`IdentityN` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Imag` | `Tout={double,float}`
`T={complex64}` +`Inv` | `T={complex64,double,float,int32,int64}` +`Invert` | `T={int32,int64,uint32,uint64}` +`InvertPermutation` | `T={int32}` +`IsFinite` | `T={double,float}` +`IsInf` | `T={double,float}` +`IsNan` | `T={double,float}` +`L2Loss` | `T={double,float}` +`LRN` | `T={float}` +`LRNGrad` | `T={float}` +`LeftShift` | `T={int32,int64,uint32,uint64}` +`Less` | `T={double,float,int32,int64,uint32,uint64}` +`LessEqual` | `T={double,float,int32,int64,uint32,uint64}` +`LinSpace` | `Tidx={int32,int64}`
`T={double,float}` +`Log` | `T={complex64,double,float}` +`Log1p` | `T={complex64,double,float}` +`LogSoftmax` | `T={double,float}` +`LogicalAnd` | +`LogicalNot` | +`LogicalOr` | +`MatMul` | `T={complex64,double,float}` +`MatrixDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`MatrixDiagPart` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`MatrixTriangularSolve` | `T={complex64,double,float}` +`Max` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`MaxPool` | `T={double,float,int32,int64}` +`MaxPool3D` | `T={float}` +`MaxPool3DGrad` | `TInput={float}`
`T={float}` +`MaxPoolGrad` | `T={double,float,int32,int64,uint32,uint64}` +`MaxPoolGradV2` | `T={double,float,int32,int64,uint32,uint64}` +`MaxPoolV2` | `T={double,float,int32,int64}` +`Maximum` | `T={double,float,int32,int64}` +`Mean` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`Min` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`Minimum` | `T={double,float,int32,int64}` +`MirrorPad` | `Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Mod` | `T={double,float,int32,int64}` +`Mul` | `T={complex64,double,float,int32,int64}` +`Multinomial` | `output_dtype={int32,int64}`
`T={double,float,int32,int64,uint32,uint64}` +`Neg` | `T={complex64,double,float,int32,int64}` +`NoOp` | +`NotEqual` | `T={bool,complex64,double,float,int32,int64}` +`OneHot` | `TI={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`OnesLike` | `T={bool,complex64,double,float,int32,int64}` +`Pack` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Pad` | `Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`PadV2` | `Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ParallelDynamicStitch` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Pow` | `T={complex64,double,float,int32,int64}` +`PreventGradient` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Prod` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`QuantizeAndDequantizeV2` | `T={double,float}` +`RFFT` | +`RFFT2D` | +`RFFT3D` | +`RGBToHSV` | `T={double,float}` +`RandomStandardNormal` | `dtype={float}` +`RandomUniform` | `T={int32,int64}`
`dtype={double,float}` +`RandomUniformInt` | `T={int32,int64}`
`Tout={int32,int64}` +`Range` | `Tidx={double,float,int32,int64}` +`Rank` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ReadVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Real` | `Tout={double,float}`
`T={complex64}` +`RealDiv` | `T={complex64,double,float,int32,int64}` +`Reciprocal` | `T={complex64,double,float,int32,int64}` +`ReciprocalGrad` | `T={complex64,double,float}` +`Relu` | `T={double,float,int32,int64,uint32,uint64}` +`Relu6` | `T={double,float,int32,int64,uint32,uint64}` +`Relu6Grad` | `T={double,float,int32,int64,uint32,uint64}` +`ReluGrad` | `T={double,float,int32,int64,uint32,uint64}` +`Reshape` | `Tshape={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ResizeBilinear` | `T={double,float,int32,int64}` +`ResizeBilinearGrad` | `T={double,float}` +`ResourceApplyAdagrad` | `T={double,float}` +`ResourceApplyAdam` | `T={double,float}` +`ResourceApplyFtrl` | `T={double,float}` +`ResourceApplyFtrlV2` | `T={double,float}` +`ResourceApplyGradientDescent` | `T={double,float}` +`ResourceApplyMomentum` | `T={double,float}` +`ResourceApplyRMSProp` | `T={double,float}` +`ResourceGather` | `Tindices={int32,int64}`
`dtype={complex64,double,float,int32,int64,uint32,uint64}` +`ResourceStridedSliceAssign` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Reverse` | `T={bool,complex64,double,float,int32,int64}` +`ReverseSequence` | `Tlen={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ReverseV2` | `T={bool,complex64,double,float,int32,int64}`
`Tidx={int32,int64}` +`RightShift` | `T={int32,int64,uint32,uint64}` +`Rint` | `T={double,float}` +`Round` | `T={complex64,double,float,int32,int64}` +`Rsqrt` | `T={complex64,double,float}` +`RsqrtGrad` | `T={complex64,double,float}` +`Select` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Selu` | `T={double,float}` +`SeluGrad` | `T={double,float}` +`Shape` | `out_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ShapeN` | `out_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Sigmoid` | `T={complex64,double,float}` +`SigmoidGrad` | `T={complex64,double,float}` +`Sign` | `T={complex64,double,float,int32,int64}` +`Sin` | `T={complex64,double,float}` +`Sinh` | `T={complex64,double,float}` +`Size` | `out_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Slice` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Softmax` | `T={double,float}` +`SoftmaxCrossEntropyWithLogits` | `T={double,float}` +`Softplus` | `T={double,float,int32,int64,uint32,uint64}` +`SoftplusGrad` | `T={double,float,int32,int64,uint32,uint64}` +`Softsign` | `T={double,float,int32,int64,uint32,uint64}` +`SoftsignGrad` | `T={double,float,int32,int64,uint32,uint64}` +`SpaceToBatch` | `Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`SpaceToBatchND` | `Tblock_shape={int32,int64}`
`Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`SpaceToDepth` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`SparseMatMul` | `Tb={float}`
`Ta={float}` +`SparseSoftmaxCrossEntropyWithLogits` | `Tlabels={int32,int64}`
`T={double,float}` +`Split` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`SplitV` | `Tlen={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Sqrt` | `T={complex64,double,float}` +`SqrtGrad` | `T={complex64,double,float}` +`Square` | `T={complex64,double,float,int32,int64}` +`SquaredDifference` | `T={complex64,double,float,int32,int64}` +`Squeeze` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StackCloseV2` | +`StackPopV2` | `elem_type={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StackPushV2` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StackV2` | `elem_type={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StatelessRandomNormal` | `Tseed={int32}`
`T={int32,int64}`
`dtype={float}` +`StatelessRandomUniform` | `Tseed={int32}`
`T={int32,int64}`
`dtype={float}` +`StopGradient` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StridedSlice` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StridedSliceGrad` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Sub` | `T={complex64,double,float,int32,int64}` +`Sum` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`SymbolicGradient` | `Tout={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Tin={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Tan` | `T={complex64,double,float,int32,int64}` +`Tanh` | `T={complex64,double,float}` +`TanhGrad` | `T={complex64,double,float}` +`TensorArrayCloseV3` | +`TensorArrayConcatV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayGatherV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayGradV3` | +`TensorArrayReadV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayScatterV3` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArraySizeV3` | +`TensorArraySplitV3` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayWriteV3` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Tile` | `Tmultiples={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Transpose` | `Tperm={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TruncateDiv` | `T={complex64,double,float,int32,int64}` +`TruncateMod` | `T={double,float,int32,int64}` +`TruncatedNormal` | `T={int32,int64}`
`dtype={double,float}` +`Unpack` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`UnsortedSegmentSum` | `Tnumsegments={int32,int64}`
`Tindices={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`VarIsInitializedOp` | +`VariableShape` | `out_type={int32,int64}` +`XlaWhile` | `T={bool,complex64,double,float,int32,int64,resource,uint32,uint64}` +`ZerosLike` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_Arg` | `T={bool,complex64,double,float,int32,int64,resource,uint32,uint64}` +`_ArrayToList` | `out_types={bool,complex64,double,float,int32,int64,uint32,uint64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_ListToArray` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Tin={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_Retval` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_XLARecv` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_XLASend` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` + +To regenerate this table, run: + +```shell +bazel run -c opt -- tensorflow/compiler/tf2xla:tf2xla_supported_ops --device=XLA_CPU_JIT +``` diff --git a/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md b/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md new file mode 100644 index 0000000000000000000000000000000000000000..b9bdb829d773825005a8921f48d28b6892d8f0cd --- /dev/null +++ b/tensorflow/compiler/tf2xla/g3doc/gpu_supported_ops.md @@ -0,0 +1,262 @@ +**Supported operators for device: XLA_GPU_JIT** + +Operator | Type Constraint +------------------------------------- | --------------- +`Abs` | `T={double,float,int32,int64}` +`Acosh` | `T={complex64,double,float}` +`Add` | `T={complex64,double,float,int32,int64}` +`AddN` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`AdjustContrastv2` | +`AdjustHue` | +`AdjustSaturation` | +`All` | `Tidx={int32,int64}` +`Angle` | `Tout={double,float}`
`T={complex64}` +`Any` | `Tidx={int32,int64}` +`ApproximateEqual` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`ArgMax` | `Tidx={int32,int64}`
`output_type={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`ArgMin` | `Tidx={int32,int64}`
`output_type={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`Asinh` | `T={complex64,double,float}` +`AssignAddVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}` +`AssignSubVariableOp` | `dtype={complex64,double,float,int32,int64,uint32,uint64}` +`AssignVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Atan2` | `T={double,float}` +`Atanh` | `T={complex64,double,float}` +`AvgPool` | `T={double,float}` +`AvgPool3D` | `T={double,float}` +`AvgPool3DGrad` | `T={double,float}` +`AvgPoolGrad` | `T={double,float}` +`BatchMatMul` | `T={complex64,double,float,int32}` +`BatchToSpace` | `Tidx={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`BatchToSpaceND` | `Tcrops={int32,int64}`
`Tblock_shape={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`BiasAdd` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`BiasAddGrad` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`BiasAddV1` | `T={complex64,double,float,int32,int64,uint32,uint64}` +`BitwiseAnd` | `T={int32,int64,uint32,uint64}` +`BitwiseOr` | `T={int32,int64,uint32,uint64}` +`BroadcastArgs` | `T={int32,int64}` +`BroadcastGradientArgs` | `T={int32,int64}` +`Cast` | `DstT={bool,complex64,double,float,int32,int64,uint32,uint64}`
`SrcT={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Ceil` | `T={double,float}` +`Cholesky` | `T={double,float}` +`Complex` | `Tout={complex64}`
`T={double,float}` +`ComplexAbs` | `Tout={double,float}`
`T={complex64}` +`Concat` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ConcatOffset` | +`ConcatV2` | `Tidx={int32}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Conj` | `T={complex64}` +`Const` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ControlTrigger` | +`Conv2D` | `T={float}` +`Conv2DBackpropFilter` | `T={float}` +`Conv2DBackpropInput` | `T={float}` +`Conv3D` | `T={double,float}` +`Conv3DBackpropFilterV2` | `T={double,float}` +`Conv3DBackpropInputV2` | `T={double,float}` +`Cos` | `T={complex64,double,float}` +`Cosh` | `T={complex64,double,float}` +`Cross` | `T={double,float,int32,int64,uint32,uint64}` +`Cumprod` | `Tidx={int32,int64}`
`T={float}` +`Cumsum` | `Tidx={int32,int64}`
`T={float}` +`DepthToSpace` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`DepthwiseConv2dNative` | `T={double,float}` +`DepthwiseConv2dNativeBackpropFilter` | `T={double,float}` +`DepthwiseConv2dNativeBackpropInput` | `T={double,float}` +`Diag` | `T={complex64,double,float,int32,int64}` +`DiagPart` | `T={complex64,double,float,int32,int64}` +`Div` | `T={complex64,double,float,int32,int64}` +`DynamicStitch` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Elu` | `T={double,float}` +`EluGrad` | `T={double,float}` +`Equal` | `T={bool,complex64,double,float,int32,int64}` +`Exp` | `T={complex64,double,float}` +`ExpandDims` | `Tdim={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Expm1` | `T={complex64,double,float}` +`ExtractImagePatches` | `T={double,float,int32,int64,uint32,uint64}` +`FFT` | +`FFT2D` | +`FFT3D` | +`Fill` | `index_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Floor` | `T={double,float}` +`FloorDiv` | `T={complex64,double,float,int32,int64}` +`FloorMod` | `T={double,float,int32,int64}` +`FusedBatchNorm` | `T={float}` +`FusedBatchNormGrad` | `T={float}` +`FusedBatchNormGradV2` | `U={float}`
`T={float}` +`FusedBatchNormV2` | `U={float}`
`T={float}` +`Gather` | `Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` +`GatherV2` | `Taxis={int32,int64}`
`Tindices={int32,int64}`
`Tparams={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Greater` | `T={double,float,int32,int64,uint32,uint64}` +`GreaterEqual` | `T={double,float,int32,int64,uint32,uint64}` +`HSVToRGB` | `T={double,float}` +`IFFT` | +`IFFT2D` | +`IFFT3D` | +`IRFFT` | +`IRFFT2D` | +`IRFFT3D` | +`Identity` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`IdentityN` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Imag` | `Tout={double,float}`
`T={complex64}` +`Inv` | `T={complex64,double,float,int32,int64}` +`Invert` | `T={int32,int64,uint32,uint64}` +`InvertPermutation` | `T={int32}` +`IsFinite` | `T={double,float}` +`IsInf` | `T={double,float}` +`IsNan` | `T={double,float}` +`L2Loss` | `T={double,float}` +`LRN` | `T={float}` +`LRNGrad` | `T={float}` +`LeftShift` | `T={int32,int64,uint32,uint64}` +`Less` | `T={double,float,int32,int64,uint32,uint64}` +`LessEqual` | `T={double,float,int32,int64,uint32,uint64}` +`LinSpace` | `Tidx={int32,int64}`
`T={double,float}` +`Log` | `T={complex64,double,float}` +`Log1p` | `T={complex64,double,float}` +`LogSoftmax` | `T={double,float}` +`LogicalAnd` | +`LogicalNot` | +`LogicalOr` | +`MatMul` | `T={complex64,double,float}` +`MatrixDiag` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`MatrixDiagPart` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`MatrixTriangularSolve` | `T={complex64,double,float}` +`Max` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`MaxPool` | `T={double,float,int32,int64}` +`MaxPool3D` | `T={float}` +`MaxPool3DGrad` | `TInput={float}`
`T={float}` +`MaxPoolGrad` | `T={double,float,int32,int64,uint32,uint64}` +`MaxPoolGradV2` | `T={double,float,int32,int64,uint32,uint64}` +`MaxPoolV2` | `T={double,float,int32,int64}` +`Maximum` | `T={double,float,int32,int64}` +`Mean` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`Min` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`Minimum` | `T={double,float,int32,int64}` +`MirrorPad` | `Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Mod` | `T={double,float,int32,int64}` +`Mul` | `T={complex64,double,float,int32,int64}` +`Multinomial` | `output_dtype={int32,int64}`
`T={double,float,int32,int64,uint32,uint64}` +`Neg` | `T={complex64,double,float,int32,int64}` +`NoOp` | +`NotEqual` | `T={bool,complex64,double,float,int32,int64}` +`OneHot` | `TI={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`OnesLike` | `T={bool,complex64,double,float,int32,int64}` +`Pack` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Pad` | `Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`PadV2` | `Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ParallelDynamicStitch` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Pow` | `T={complex64,double,float,int32,int64}` +`PreventGradient` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Prod` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`QuantizeAndDequantizeV2` | `T={double,float}` +`RFFT` | +`RFFT2D` | +`RFFT3D` | +`RGBToHSV` | `T={double,float}` +`Range` | `Tidx={double,float,int32,int64}` +`Rank` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ReadVariableOp` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Real` | `Tout={double,float}`
`T={complex64}` +`RealDiv` | `T={complex64,double,float,int32,int64}` +`Reciprocal` | `T={complex64,double,float,int32,int64}` +`ReciprocalGrad` | `T={complex64,double,float}` +`Relu` | `T={double,float,int32,int64,uint32,uint64}` +`Relu6` | `T={double,float,int32,int64,uint32,uint64}` +`Relu6Grad` | `T={double,float,int32,int64,uint32,uint64}` +`ReluGrad` | `T={double,float,int32,int64,uint32,uint64}` +`Reshape` | `Tshape={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ResizeBilinear` | `T={double,float,int32,int64}` +`ResizeBilinearGrad` | `T={double,float}` +`ResourceApplyAdagrad` | `T={double,float}` +`ResourceApplyAdam` | `T={double,float}` +`ResourceApplyFtrl` | `T={double,float}` +`ResourceApplyFtrlV2` | `T={double,float}` +`ResourceApplyGradientDescent` | `T={double,float}` +`ResourceApplyMomentum` | `T={double,float}` +`ResourceApplyRMSProp` | `T={double,float}` +`ResourceGather` | `Tindices={int32,int64}`
`dtype={complex64,double,float,int32,int64,uint32,uint64}` +`ResourceStridedSliceAssign` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Reverse` | `T={bool,complex64,double,float,int32,int64}` +`ReverseSequence` | `Tlen={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ReverseV2` | `T={bool,complex64,double,float,int32,int64}`
`Tidx={int32,int64}` +`RightShift` | `T={int32,int64,uint32,uint64}` +`Rint` | `T={double,float}` +`Round` | `T={complex64,double,float,int32,int64}` +`Rsqrt` | `T={complex64,double,float}` +`RsqrtGrad` | `T={complex64,double,float}` +`Select` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Selu` | `T={double,float}` +`SeluGrad` | `T={double,float}` +`Shape` | `out_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`ShapeN` | `out_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Sigmoid` | `T={complex64,double,float}` +`SigmoidGrad` | `T={complex64,double,float}` +`Sign` | `T={complex64,double,float,int32,int64}` +`Sin` | `T={complex64,double,float}` +`Sinh` | `T={complex64,double,float}` +`Size` | `out_type={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Slice` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Softmax` | `T={double,float}` +`SoftmaxCrossEntropyWithLogits` | `T={double,float}` +`Softplus` | `T={double,float,int32,int64,uint32,uint64}` +`SoftplusGrad` | `T={double,float,int32,int64,uint32,uint64}` +`Softsign` | `T={double,float,int32,int64,uint32,uint64}` +`SoftsignGrad` | `T={double,float,int32,int64,uint32,uint64}` +`SpaceToBatch` | `Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`SpaceToBatchND` | `Tblock_shape={int32,int64}`
`Tpaddings={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`SpaceToDepth` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`SparseMatMul` | `Tb={float}`
`Ta={float}` +`SparseSoftmaxCrossEntropyWithLogits` | `Tlabels={int32,int64}`
`T={double,float}` +`Split` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`SplitV` | `Tlen={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Sqrt` | `T={complex64,double,float}` +`SqrtGrad` | `T={complex64,double,float}` +`Square` | `T={complex64,double,float,int32,int64}` +`SquaredDifference` | `T={complex64,double,float,int32,int64}` +`Squeeze` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StackCloseV2` | +`StackPopV2` | `elem_type={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StackPushV2` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StackV2` | `elem_type={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StatelessRandomNormal` | `Tseed={int32}`
`T={int32,int64}`
`dtype={float}` +`StatelessRandomUniform` | `Tseed={int32}`
`T={int32,int64}`
`dtype={float}` +`StopGradient` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StridedSlice` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`StridedSliceGrad` | `Index={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Sub` | `T={complex64,double,float,int32,int64}` +`Sum` | `Tidx={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`SymbolicGradient` | `Tout={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Tin={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Tan` | `T={complex64,double,float,int32,int64}` +`Tanh` | `T={complex64,double,float}` +`TanhGrad` | `T={complex64,double,float}` +`TensorArrayCloseV3` | +`TensorArrayConcatV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayGatherV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayGradV3` | +`TensorArrayReadV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayScatterV3` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArraySizeV3` | +`TensorArraySplitV3` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayV3` | `dtype={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TensorArrayWriteV3` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Tile` | `Tmultiples={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`Transpose` | `Tperm={int32,int64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`TruncateDiv` | `T={complex64,double,float,int32,int64}` +`TruncateMod` | `T={double,float,int32,int64}` +`Unpack` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`UnsortedSegmentSum` | `Tnumsegments={int32,int64}`
`Tindices={int32,int64}`
`T={complex64,double,float,int32,int64,uint32,uint64}` +`VarIsInitializedOp` | +`VariableShape` | `out_type={int32,int64}` +`XlaWhile` | `T={bool,complex64,double,float,int32,int64,resource,uint32,uint64}` +`ZerosLike` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_Arg` | `T={bool,complex64,double,float,int32,int64,resource,uint32,uint64}` +`_ArrayToList` | `out_types={bool,complex64,double,float,int32,int64,uint32,uint64}`
`T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_ListToArray` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}`
`Tin={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_Retval` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_XLARecv` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` +`_XLASend` | `T={bool,complex64,double,float,int32,int64,uint32,uint64}` + +To regenerate this table, run: + +```shell +bazel run -c opt -- tensorflow/compiler/tf2xla:tf2xla_supported_ops --device=XLA_GPU_JIT +``` diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index 8062f0c03ca60e88bd5c021092dceb105232219f..058a1f2621c64a735bd9d9c9d0ae007f93aa4dea 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph_constructor.h" @@ -59,9 +60,7 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, for (int i = 0; i < args->size(); ++i) { XlaCompiler::Argument& arg = (*args)[i]; arg.type = ctx->input_type(i); - - TF_RETURN_IF_ERROR( - TensorShapeToXLAShape(arg.type, ctx->InputShape(i), &arg.shape)); + arg.shape = ctx->InputShape(i); if (arg.type == DT_RESOURCE) { return errors::InvalidArgument( @@ -135,7 +134,7 @@ Status GraphCompiler::Compile() { TF_RET_CHECK(src->id() < output_registry.size()); const NodeOutputs& src_outputs = output_registry[src->id()]; - tensor_inputs_[e->dst_input()] = src_outputs[e->src_output()]; + tensor_inputs_.at(e->dst_input()) = src_outputs.at(e->src_output()); } OpKernelContext op_context(&params, n->num_outputs()); @@ -144,7 +143,9 @@ Status GraphCompiler::Compile() { } else { device_->Compute(CHECK_NOTNULL(params.op_kernel), 
&op_context); Status s = op_context.status(); - TF_RETURN_IF_ERROR(s); + if (!s.ok()) { + return AttachDef(s, n->def()); + } } // Set up outputs. Also check if outputs from the previous computation is diff --git a/tensorflow/compiler/tf2xla/graph_compiler.h b/tensorflow/compiler/tf2xla/graph_compiler.h index ba00160b6d78c1e55cc2e053cd5285344e0179fb..127562eb23d775f17179cc9ee968ec2255cf3a14 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.h +++ b/tensorflow/compiler/tf2xla/graph_compiler.h @@ -70,7 +70,7 @@ class GraphCompiler { private: // Partially sets params. This partially set params can be reused - // across multple nodes visit. + // across multiple nodes visit. void PartiallySetupParams(OpKernelContext::Params* params); // Tests if a node is a functional node. A functional node represents a diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 6302fece1ffb27b6c7170fcfb90f5985f5b50659..d2fa933cf9c085f92b2f442827a94d72938e4bb2 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -4,6 +4,7 @@ package( default_visibility = ["//tensorflow/compiler/tf2xla:internal"], ) +load("//tensorflow:tensorflow.bzl", "tf_copts") load("//tensorflow:tensorflow.bzl", "tf_kernel_library") tf_kernel_library( @@ -30,15 +31,23 @@ tf_kernel_library( "diag_op.cc", "dynamic_stitch_op.cc", "elu_op.cc", + "extract_image_patches_op.cc", + "fake_quantize_ops.cc", + "fft_ops.cc", "fill_op.cc", "function_ops.cc", "gather_op.cc", "gather_op_helpers.h", "identity_op.cc", + "image_ops.cc", + "image_resize_ops.cc", "index_ops.cc", "l2loss_op.cc", "lrn_ops.cc", "matmul_op.cc", + "matrix_band_part_op.cc", + "matrix_set_diag_op.cc", + "matrix_triangular_solve_op.cc", "mirror_pad_op.cc", "no_op.cc", "one_hot_op.cc", @@ -54,11 +63,15 @@ tf_kernel_library( "reshape_op.cc", "retval_op.cc", "reverse_op.cc", + "reverse_sequence_op.cc", + "scan_ops.cc", + "scatter_nd_op.cc", "segment_reduction_ops.cc", 
"select_op.cc", "sendrecv_ops.cc", "sequence_ops.cc", "shape_op.cc", + "shape_util.cc", "slice_op.cc", "softmax_op.cc", "spacetobatch_op.cc", @@ -76,8 +89,8 @@ tf_kernel_library( "variable_ops.cc", ], hdrs = [ - "gather_op.h", "index_ops.h", + "shape_util.h", ], deps = [ ":while_op", @@ -85,18 +98,26 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/lib:batch_dot", "//tensorflow/compiler/tf2xla/lib:cholesky", + "//tensorflow/compiler/tf2xla/lib:scatter", + "//tensorflow/compiler/tf2xla/lib:triangular_solve", + "//tensorflow/compiler/tf2xla/lib:util", + "//tensorflow/compiler/tf2xla/lib:while_loop", "//tensorflow/compiler/tf2xla/ops:sendrecv_ops", + "//tensorflow/compiler/xla:array4d", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/core:framework", + "//tensorflow/core:image_ops_op_lib", "//tensorflow/core:lib", "//tensorflow/core:linalg_ops_op_lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:spectral_ops_op_lib", "//tensorflow/core:stateless_random_ops_op_lib", "//tensorflow/core/kernels:bounds_check", "//tensorflow/core/kernels:concat_lib", @@ -157,6 +178,7 @@ tf_kernel_library( cc_library( name = "index_ops_kernel_argmax_float_1d", srcs = ["index_ops_kernel_argmax_float_1d.cc"], + copts = tf_copts(), visibility = ["//visibility:public"], deps = [ "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", @@ -169,6 +191,7 @@ cc_library( cc_library( name = "index_ops_kernel_argmax_float_2d", srcs = ["index_ops_kernel_argmax_float_2d.cc"], + copts = tf_copts(), visibility = ["//visibility:public"], deps = [ 
"//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", diff --git a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc index a015b8e0e8949f8aaa03a78b0f88b7ea8d6aaa1c..b0ba25b9983c3a9af26728ce4b1c263c844327db 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc @@ -28,8 +28,9 @@ class BatchMatMulOp : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - auto result = - BatchDot(ctx->builder(), ctx->Input(0), ctx->Input(1), adj_x_, adj_y_); + auto result = BatchDot(ctx->builder(), ctx->Input(0), ctx->Input(1), + /*transpose_x=*/adj_x_, /*transpose_y=*/adj_y_, + /*conjugate_x=*/adj_x_, /*conjugate_y=*/adj_y_); OP_REQUIRES_OK(ctx, result.status()); ctx->SetOutput(0, result.ValueOrDie()); } diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc index 248e9d111e556dcdd75581aa6562a66fc8b57063..a249b1869f547f8e5aa725f9f5cf391b10429928 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc @@ -14,7 +14,7 @@ limitations under the License. ==============================================================================*/ // XLA implementation of BatchNorm operations. 
-#include "tensorflow/compiler/tf2xla/literal_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -26,43 +26,63 @@ namespace { class FusedBatchNormOp : public XlaOpKernel { public: explicit FusedBatchNormOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - string data_format; OP_REQUIRES_OK(ctx, ctx->GetAttr("epsilon", &epsilon_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("is_training", &is_training_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); - TensorFormat tensor_format; - if (ctx->GetAttr("data_format", &data_format).ok()) { - OP_REQUIRES(ctx, FormatFromString(data_format, &tensor_format), - errors::InvalidArgument("Invalid data format")); - OP_REQUIRES( - ctx, (tensor_format == FORMAT_NHWC || tensor_format == FORMAT_NCHW), - errors::InvalidArgument("Not supported format")); - feature_index_ = GetTensorFeatureDimIndex(/*num_dims=*/4, tensor_format); - } + string data_format_str; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); + OP_REQUIRES( + ctx, FormatFromString(data_format_str, &data_format_), + errors::InvalidArgument("Invalid data format: ", data_format_str)); + OP_REQUIRES(ctx, + (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW), + errors::InvalidArgument( + "Unsupported data format ", ToString(data_format_), + "; supported formats are NHWC and NCHW")); } void Compile(XlaOpKernelContext* ctx) override { + xla::PrimitiveType input_type; + OP_REQUIRES_OK(ctx, + DataTypeToPrimitiveType(ctx->input_type(0), &input_type)); + xla::PrimitiveType scale_type; + OP_REQUIRES_OK(ctx, + DataTypeToPrimitiveType(ctx->input_type(1), &scale_type)); + + xla::ComputationBuilder* builder = ctx->builder(); + + xla::ComputationDataHandle input = ctx->Input(0); + TensorShape input_shape = ctx->InputShape(0); + + int feature_index = + 
GetTensorFeatureDimIndex(input_shape.dims(), data_format_); + + // TODO(b/69928690): support mixed precision in the XLA batch normalization + // operators. As a workaround, cast everything to the statistics type (which + // may be more precise than the input type). + input = builder->ConvertElementType(input, scale_type); + if (is_training_) { - xla::ComputationDataHandle output = ctx->builder()->BatchNormTraining( - ctx->Input(0), ctx->Input(1), ctx->Input(2), epsilon_, - feature_index_); + xla::ComputationDataHandle output = builder->BatchNormTraining( + input, ctx->Input(1), ctx->Input(2), epsilon_, feature_index); // In training mode, outputs the normalized value as well as the // calculated mean and variance. - for (int i = 0; i < 3; i++) { - ctx->SetOutput(i, ctx->builder()->GetTupleElement(output, i)); - } + ctx->SetOutput(0, builder->ConvertElementType( + builder->GetTupleElement(output, 0), input_type)); + ctx->SetOutput(1, builder->GetTupleElement(output, 1)); + ctx->SetOutput(2, builder->GetTupleElement(output, 2)); + // Output 3 and 4 for "FusedBatchNorm" are currently marked as "reserved // space 1 & 2". They are used to pass the per-batch mean and // variance to the gradient. Here we maintain the same behavior by setting // them to the mean and variance calculated by BatchNormTraining. 
- ctx->SetOutput(3, ctx->builder()->GetTupleElement(output, 1)); - ctx->SetOutput(4, ctx->builder()->GetTupleElement(output, 2)); + ctx->SetOutput(3, builder->GetTupleElement(output, 1)); + ctx->SetOutput(4, builder->GetTupleElement(output, 2)); } else { - xla::ComputationDataHandle output = ctx->builder()->BatchNormInference( - ctx->Input(0), ctx->Input(1), ctx->Input(2), ctx->Input(3), - ctx->Input(4), epsilon_, feature_index_); - ctx->SetOutput(0, output); + xla::ComputationDataHandle output = builder->BatchNormInference( + input, ctx->Input(1), ctx->Input(2), ctx->Input(3), ctx->Input(4), + epsilon_, feature_index); + ctx->SetOutput(0, builder->ConvertElementType(output, input_type)); // Directly send input to output as mean and variance in inference mode. ctx->SetOutput(1, ctx->Input(3)); ctx->SetOutput(2, ctx->Input(4)); @@ -73,55 +93,113 @@ class FusedBatchNormOp : public XlaOpKernel { private: float epsilon_; - int64 feature_index_; + TensorFormat data_format_; bool is_training_; }; REGISTER_XLA_OP(Name("FusedBatchNorm"), FusedBatchNormOp); +REGISTER_XLA_OP(Name("FusedBatchNormV2"), FusedBatchNormOp); class FusedBatchNormGradOp : public XlaOpKernel { public: explicit FusedBatchNormGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { - string data_format; OP_REQUIRES_OK(ctx, ctx->GetAttr("epsilon", &epsilon_)); - OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format)); - bool is_training; - OP_REQUIRES_OK(ctx, ctx->GetAttr("is_training", &is_training)); - CHECK(is_training) << "FusedBatchNormGradOp with is_training=False cannot " - "be used with XLA for now!"; - TensorFormat tensor_format; - if (ctx->GetAttr("data_format", &data_format).ok()) { - OP_REQUIRES(ctx, FormatFromString(data_format, &tensor_format), - errors::InvalidArgument("Invalid data format")); - OP_REQUIRES( - ctx, (tensor_format == FORMAT_NHWC || tensor_format == FORMAT_NCHW), - errors::InvalidArgument("Not supported format")); - feature_index_ = GetTensorFeatureDimIndex(4, 
tensor_format); - } + OP_REQUIRES_OK(ctx, ctx->GetAttr("is_training", &is_training_)); + string data_format_str; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); + OP_REQUIRES( + ctx, FormatFromString(data_format_str, &data_format_), + errors::InvalidArgument("Invalid data format: ", data_format_str)); + OP_REQUIRES(ctx, + (data_format_ == FORMAT_NHWC || data_format_ == FORMAT_NCHW), + errors::InvalidArgument( + "Unsupported data format ", ToString(data_format_), + "; supported formats are NHWC and NCHW")); } void Compile(XlaOpKernelContext* ctx) override { - auto grad_output = ctx->Input(0); - auto activation = ctx->Input(1); + xla::ComputationBuilder* b = ctx->builder(); + + auto grad_backprop = ctx->Input(0); + auto activations = ctx->Input(1); auto scale = ctx->Input(2); auto mean = ctx->Input(3); auto var = ctx->Input(4); - xla::ComputationDataHandle output = ctx->builder()->BatchNormGrad( - activation, scale, mean, var, grad_output, epsilon_, feature_index_); - for (int i = 0; i < 3; i++) { - ctx->SetOutput(i, ctx->builder()->GetTupleElement(output, i)); + TensorShape input_shape = ctx->InputShape(0); + int feature_index = + GetTensorFeatureDimIndex(input_shape.dims(), data_format_); + + DataType input_dtype = ctx->input_type(0); + DataType scale_dtype = ctx->input_type(2); + xla::PrimitiveType input_type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(input_dtype, &input_type)); + xla::PrimitiveType scale_type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(scale_dtype, &scale_type)); + + // TODO(b/69928690): support mixed precision in the XLA batch normalization + // operators. For now, cast everything to the statistics type (which + // may be more precise than the input type). 
+ grad_backprop = b->ConvertElementType(grad_backprop, scale_type); + activations = b->ConvertElementType(activations, scale_type); + + xla::ComputationDataHandle x_backprop; + xla::ComputationDataHandle scale_backprop; + xla::ComputationDataHandle offset_backprop; + if (is_training_) { + xla::ComputationDataHandle output = + b->BatchNormGrad(activations, scale, mean, var, grad_backprop, + epsilon_, feature_index); + + x_backprop = b->GetTupleElement(output, 0); + scale_backprop = b->GetTupleElement(output, 1); + offset_backprop = b->GetTupleElement(output, 2); + } else { + // Reduce over all dimensions except the feature dim. + std::vector reduction_dims(input_shape.dims() - 1); + std::iota(reduction_dims.begin(), reduction_dims.begin() + feature_index, + 0); + std::iota(reduction_dims.begin() + feature_index, reduction_dims.end(), + feature_index + 1); + // offset_backprop = sum(y_backprop) + // scale_backprop = y_backprop * ((x - pop_mean) * rsqrt(pop_var + + // epsilon)) + // x_backprop = y_backprop * (scale * rsqrt(pop_var + epsilon)) + offset_backprop = + b->Reduce(grad_backprop, XlaHelpers::Zero(b, scale_dtype), + *ctx->GetOrCreateAdd(scale_dtype), reduction_dims); + + // scratch1 = rsqrt(pop_var + epsilon) + auto neg_half = XlaHelpers::FloatLiteral(b, scale_dtype, -0.5); + auto scratch1 = + b->Pow(b->Add(var, b->ConstantR0(epsilon_)), neg_half); + + // scratch2 = sum(y_backprop * (x - mean)) + auto scratch2 = b->Reduce( + b->Mul(grad_backprop, b->Sub(activations, mean, {feature_index})), + XlaHelpers::Zero(b, scale_dtype), *ctx->GetOrCreateAdd(scale_dtype), + reduction_dims); + + x_backprop = + b->Mul(grad_backprop, b->Mul(scratch1, scale), {feature_index}); + scale_backprop = b->Mul(scratch1, scratch2); } - ctx->SetOutput(3, ctx->builder()->GetTupleElement(output, 1)); - ctx->SetOutput(4, ctx->builder()->GetTupleElement(output, 2)); + + ctx->SetOutput(0, b->ConvertElementType(x_backprop, input_type)); + ctx->SetOutput(1, scale_backprop); + 
ctx->SetOutput(2, offset_backprop); + ctx->SetConstantOutput(3, Tensor(scale_dtype, {})); + ctx->SetConstantOutput(4, Tensor(scale_dtype, {})); } private: + TensorFormat data_format_; float epsilon_; - int64 feature_index_; + bool is_training_; }; REGISTER_XLA_OP(Name("FusedBatchNormGrad"), FusedBatchNormGradOp); +REGISTER_XLA_OP(Name("FusedBatchNormGradV2"), FusedBatchNormGradOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc index 21d3e64872e19109852297838043975cea6d7921..344a2ab2b6835c518c41de6f7a30fb2a34d130d2 100644 --- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc @@ -159,7 +159,8 @@ class BatchToSpaceNDOp : public XlaOpKernel { block_shape, crops); } }; -REGISTER_XLA_OP(Name("BatchToSpaceND"), BatchToSpaceNDOp); +REGISTER_XLA_OP(Name("BatchToSpaceND").CompileTimeConstInput("crops"), + BatchToSpaceNDOp); class BatchToSpaceOp : public XlaOpKernel { public: @@ -181,7 +182,10 @@ class BatchToSpaceOp : public XlaOpKernel { private: int block_size_; }; -REGISTER_XLA_OP(Name("BatchToSpace"), BatchToSpaceOp); +REGISTER_XLA_OP(Name("BatchToSpace") + .CompileTimeConstInput("crops") + .CompileTimeConstInput("block_shape"), + BatchToSpaceOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc index bb031b8c471e08ba90c554e309b850a26c3edae0..ee2c920453c3bbaef2c145df743fddf999167c39 100644 --- a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc @@ -65,7 +65,10 @@ class BCastArgsOp : public XlaOpKernel { private: TF_DISALLOW_COPY_AND_ASSIGN(BCastArgsOp); }; -REGISTER_XLA_OP(Name("BroadcastArgs"), BCastArgsOp); +REGISTER_XLA_OP(Name("BroadcastArgs") + .CompileTimeConstInput("s0") + .CompileTimeConstInput("s1"), + BCastArgsOp); // Given 
shapes of two tensors, computes the reduction indices for the // gradient computation. @@ -121,7 +124,10 @@ class BCastGradArgsOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(BCastGradArgsOp); }; -REGISTER_XLA_OP(Name("BroadcastGradientArgs"), BCastGradArgsOp); +REGISTER_XLA_OP(Name("BroadcastGradientArgs") + .CompileTimeConstInput("s0") + .CompileTimeConstInput("s1"), + BCastGradArgsOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index 1de91924326464338352b1ac9edf77141f25ad35..2436a6074a11ad66387b232dd1c5aa135875bfc3 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" namespace tensorflow { namespace { @@ -75,7 +76,7 @@ static xla::ComputationDataHandle FloorDivImpl(xla::ComputationBuilder* b, auto abs_y = b->Abs(y); auto t = b->Neg(b->Sub(b->Add(abs_x, abs_y), one)); auto result = b->Select(different_sign, b->Div(t, abs_y), b->Div(x, y)); - if (dtype == DT_FLOAT || dtype == DT_DOUBLE) { + if (DataTypeIsFloating(dtype)) { result = b->Floor(result); } return result; diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc index 592f3ecc3ce2abf33ddffe8b0e59c4e12e73e956..545aa364f937b2dc972dbe7b8c18b5897aa8e5c3 100644 --- a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc @@ -92,7 +92,8 @@ class CategoricalOp : public XlaOpKernel { }; // TODO(b/68769717): Rename this sampler to Categorical. 
-REGISTER_XLA_OP(Name("Multinomial"), CategoricalOp); +REGISTER_XLA_OP(Name("Multinomial").CompileTimeConstInput("num_samples"), + CategoricalOp); } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc index 87d858f763560be454c162e0cf40307c68217663..fe6651793dc763d13f4a4b0ac294ec3ecf64af8f 100644 --- a/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cholesky_op.cc @@ -33,7 +33,7 @@ class CholeskyOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Cholesky"), CholeskyOp); +REGISTER_XLA_OP(Name("Cholesky").TypeConstraint("T", kFloatTypes), CholeskyOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/concat_op.cc b/tensorflow/compiler/tf2xla/kernels/concat_op.cc index 73a4740e29af7fa57e71ef42a342f46b0e24231d..1a246e8df9b2cd83147b50d960744332f8582a51 100644 --- a/tensorflow/compiler/tf2xla/kernels/concat_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/concat_op.cc @@ -84,8 +84,8 @@ class ConcatBaseOp : public XlaOpKernel { in_shape.dims() == input_dims || (input_is_scalar && in_is_scalar), errors::InvalidArgument( "ConcatOp : Ranks of all input tensors should match: shape[0] = ", - input_shape.DebugString(), " vs. shape[", i, "] = ", - in_shape.DebugString())); + input_shape.DebugString(), " vs. shape[", i, + "] = ", in_shape.DebugString())); if (in_shape.dims() == 0) { // Inputs that come in as scalars must be reshaped to 1-vectors. 
input_data.push_back(ctx->builder()->Reshape(handle, {1})); @@ -117,8 +117,11 @@ class ConcatV2Op : public ConcatBaseOp { : ConcatBaseOp(c, /* axis_index */ c->num_inputs() - 1) {} }; -REGISTER_XLA_OP(Name("Concat"), ConcatOp); -REGISTER_XLA_OP(Name("ConcatV2").TypeConstraint("Tidx", DT_INT32), ConcatV2Op); +REGISTER_XLA_OP(Name("Concat").CompileTimeConstInput("concat_dim"), ConcatOp); +REGISTER_XLA_OP(Name("ConcatV2") + .TypeConstraint("Tidx", DT_INT32) + .CompileTimeConstInput("axis"), + ConcatV2Op); class ConcatOffsetOp : public XlaOpKernel { public: @@ -189,10 +192,10 @@ class ConcatOffsetOp : public XlaOpKernel { } else { const int32 inp0_element = inp0_literal.Get({j}); const int32 inp_element = inp_literal.Get({j}); - OP_REQUIRES( - ctx, (inp0_element == inp_element), - errors::InvalidArgument("input[", i, ",", j, "] mismatch: ", - inp0_element, " vs. ", inp_element)); + OP_REQUIRES(ctx, (inp0_element == inp_element), + errors::InvalidArgument("input[", i, ",", j, + "] mismatch: ", inp0_element, + " vs. ", inp_element)); out_vec(j) = 0; } } @@ -202,7 +205,10 @@ class ConcatOffsetOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("ConcatOffset"), ConcatOffsetOp); +REGISTER_XLA_OP(Name("ConcatOffset") + .CompileTimeConstInput("concat_dim") + .CompileTimeConstInput("shape"), + ConcatOffsetOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index c5017704e2a45b0bd740f7a8fdcf3a0be1d445a4..81cea6d376d02c956a5257c5475fe5c10b83deb9 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -46,72 +46,130 @@ TensorShape ExpandedFilterShapeForDepthwiseConvolution( return expanded_shape; } +// Broadcast zeros to ExpandedFilterShapeForDepthwiseConvolution. 
+xla::ComputationDataHandle CreateExpandedZero( + const TensorShape& filter_shape, DataType dtype, + xla::ComputationBuilder* builder) { + TensorShape expanded_filter_shape = + ExpandedFilterShapeForDepthwiseConvolution(filter_shape); + return builder->Broadcast(XlaHelpers::Zero(builder, dtype), + expanded_filter_shape.dim_sizes()); +} + +// Create a mask for depthwise convolution that will make a normal convolution +// produce the same results as a depthwise convolution. For a [2, 2, 3, 2] +// depthwise filter this returns a [2, 2, 3, 6] tensor +// 1 1 0 0 0 0 1 1 0 0 0 0 +// 0 0 1 1 0 0 0 0 1 1 0 0 +// 0 0 0 0 1 1 0 0 0 0 1 1 +// +// 1 1 0 0 0 0 1 1 0 0 0 0 +// 0 0 1 1 0 0 0 0 1 1 0 0 +// 0 0 0 0 1 1 0 0 0 0 1 1 +// +// The first step is to create a one tensor, A, that is [3] +// 0 1 2 +// +// and another tensor, B, that is [3 * 2] +// 0 1 2 3 4 5 +// +// and divide B by 2 to get +// 0 0 1 1 2 2 +// +// then we broadcast the B to [2, 2, 3, 3 * 2] +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// 0 0 1 1 2 2 0 0 1 1 2 2 +// +// Finally compare A and broadcasted B in dimension 2 and return the result at +// the beginning of the comment. +xla::ComputationDataHandle CreateExpandedFilterMask( + const TensorShape& filter_shape, xla::ComputationBuilder* builder) { + TensorShape expanded_filter_shape = + ExpandedFilterShapeForDepthwiseConvolution(filter_shape); + int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1); + int64 input_feature = filter_shape.dim_size(filter_shape.dims() - 2); + + // Create a M sized linspace and an M*N sized linspace that will be + // broadcasted into perpendicular dimensions and compared. + xla::ComputationDataHandle input_feature_iota; + // DT_INT32 Iota will always return status::OK(). 
+ TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, input_feature, + &input_feature_iota)); + xla::ComputationDataHandle expanded_feature_iota; + TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, + input_feature * depthwise_multiplier, + &expanded_feature_iota)); + + // Divide the M*N sized linspace by the depthwise_multiplier to create + // [0 0 1 1 2 2] in the example in the function comment. + expanded_feature_iota = + builder->Div(expanded_feature_iota, + XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32, + depthwise_multiplier)); + + // Broadcast the N*M linspace to [H, W, ..., M, M*N]. + auto expanded_feature_broadcast_dims = expanded_filter_shape.dim_sizes(); + expanded_feature_broadcast_dims.pop_back(); + auto broadcasted_expanded_feature_iota = builder->Broadcast( + expanded_feature_iota, expanded_feature_broadcast_dims); + + // Compare the broadcasted linspace to the input feature linspace in the + // input feature dimension to create a diagonal predicate. + return builder->Eq(broadcasted_expanded_feature_iota, input_feature_iota, + {expanded_filter_shape.dims() - 2}); +} + // Expands a filter of shape [H, W, ..., M, N] to [H, W, ..., M, M*N] by adding // zeros for the cross-depth filters. Used to build a depthwise convolution. xla::ComputationDataHandle ExpandFilterForDepthwiseConvolution( const TensorShape& filter_shape, DataType dtype, const xla::ComputationDataHandle& filter, xla::ComputationBuilder* builder) { - // Filter has shape [H, W, ..., M, N] - // Dilate to [H, W, ..., M*M, N] using M inter-element padding, and then - // reshape to [H, W, ..., M, M*N]. 
- int num_spatial_dims = filter_shape.dims() - 2; - const int64 in_depth = filter_shape.dim_size(num_spatial_dims); - xla::PaddingConfig padding = xla::MakeNoPaddingConfig(filter_shape.dims()); - padding.mutable_dimensions(num_spatial_dims)->set_interior_padding(in_depth); - auto dilated_filter = - builder->Pad(filter, XlaHelpers::Zero(builder, dtype), padding); - + int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1); + int64 input_feature = filter_shape.dim_size(filter_shape.dims() - 2); TensorShape expanded_filter_shape = ExpandedFilterShapeForDepthwiseConvolution(filter_shape); - return builder->Reshape(dilated_filter, expanded_filter_shape.dim_sizes()); + + // Create a [H, W, ..., 1, N*M] reshape of the filter. + TensorShape implicit_broadcast_filter_shape = expanded_filter_shape; + implicit_broadcast_filter_shape.set_dim( + implicit_broadcast_filter_shape.dims() - 2, 1); + implicit_broadcast_filter_shape.set_dim( + implicit_broadcast_filter_shape.dims() - 1, + depthwise_multiplier * input_feature); + auto implicit_broadcast_filter = + builder->Reshape(filter, implicit_broadcast_filter_shape.dim_sizes()); + + // Broadcast the filter to [H, W, ..., M, M*N]. + auto expanded_zero = CreateExpandedZero(filter_shape, dtype, builder); + auto expanded_filter = builder->Add(implicit_broadcast_filter, expanded_zero); + + // If the filter mask is set, choose the broadcasted filter, otherwise, + // choose zero. + return builder->Select(CreateExpandedFilterMask(filter_shape, builder), + expanded_filter, expanded_zero); } // Inverse of ExpandFilterForDepthwiseConvolution. 
xla::ComputationDataHandle ContractFilterForDepthwiseBackprop( - const TensorShape& filter_shape, DataType dtype, + XlaOpKernelContext* ctx, const TensorShape& filter_shape, DataType dtype, const xla::ComputationDataHandle& filter_backprop, xla::ComputationBuilder* builder) { - int num_spatial_dims = filter_shape.dims() - 2; - - // Reshape to [H, W, ..., M*M, N] - TensorShape shape = filter_shape; - int64 in_depth = filter_shape.dim_size(num_spatial_dims); - shape.set_dim(num_spatial_dims, in_depth * in_depth); - auto reshaped = builder->Reshape(filter_backprop, shape.dim_sizes()); - - std::vector zeros(filter_shape.dims()); - std::vector strides(filter_shape.dims(), 1LL); - strides[num_spatial_dims] = in_depth + 1; - return builder->Slice(reshaped, zeros, shape.dim_sizes(), strides); - - // Alternate implementation for backends without strided Slice() support. - // TODO(phawkins): Remove when all backends support strided slice. - // // Pad [..., M * (M + 1), N] - // xla::PaddingConfig config = - // xla::MakeNoPaddingConfig(filter_shape.dims()); - // config.mutable_dimensions(num_spatial_dims) - // ->set_edge_padding_high(in_depth); - // auto zero = XlaHelpers::Zero(builder, dtype); - // auto padded = builder->Pad(reshaped, zero, config); - // - // // Reshape to [..., M, M + 1, N] - // shape = filter_shape; - // shape.set_dim(num_spatial_dims, in_depth); - // shape.set_dim(num_spatial_dims + 1, in_depth + 1); - // int64 out_depth = filter_shape.dim_size(num_spatial_dims + 1); - // shape.AddDim(out_depth); - // reshaped = builder->Reshape(padded, shape.dim_sizes()); - // - // // Slice to [..., M, 1, N] - // std::vector zeros(shape.dims()); - // std::vector strides(shape.dims(), 1LL); - // shape.set_dim(num_spatial_dims + 1, 1); - // auto sliced = builder->Slice(reshaped, zeros, shape.dim_sizes(), - // strides); - // - // // Reshape to [..., M, N] - // return builder->Reshape(sliced, filter_shape.dim_sizes()); + TensorShape expanded_filter_shape = + 
ExpandedFilterShapeForDepthwiseConvolution(filter_shape); + auto masked_expanded_filter = builder->Select( + CreateExpandedFilterMask(filter_shape, builder), filter_backprop, + CreateExpandedZero(filter_shape, dtype, builder)); + return builder->Reshape( + builder->Reduce(masked_expanded_filter, XlaHelpers::Zero(builder, dtype), + *ctx->GetOrCreateAdd(dtype), + {expanded_filter_shape.dims() - 2}), + filter_shape.dim_sizes()); } class ConvOp : public XlaOpKernel { @@ -121,6 +179,7 @@ class ConvOp : public XlaOpKernel { : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims), depthwise_(depthwise) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); @@ -144,6 +203,22 @@ class ConvOp : public XlaOpKernel { errors::Unimplemented("Current implementation does not yet support " "strides in the batch and depth dimensions.")); + OP_REQUIRES(ctx, dilations_.size() == num_dims(), + errors::InvalidArgument("Dilations field must " + "specify ", + num_dims(), " dimensions")); + OP_REQUIRES( + ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1, + errors::Unimplemented("Current implementation does not support " + "dilations in the batch and depth dimensions.")); + for (int i = 0; i < num_spatial_dims_; ++i) { + int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); + OP_REQUIRES(ctx, dilations_[input_dim] >= 1, + errors::Unimplemented("Dilation values must be positive; ", i, + "th spatial dimension had dilation ", + dilations_[input_dim])); + } + const TensorShape input_shape = ctx->InputShape(0); // Input filter is of the following dimensions: // [ filter_rows, filter_cols, ..., in_depth, out_depth] @@ -172,38 +247,53 @@ class ConvOp : public XlaOpKernel { xla::ComputationBuilder* b = ctx->builder(); xla::ComputationDataHandle filter = ctx->Input(1); + TensorShape expanded_filter_shape = filter_shape; if (depthwise_) { filter = 
ExpandFilterForDepthwiseConvolution( filter_shape, ctx->input_type(0), filter, b); + expanded_filter_shape = + ExpandedFilterShapeForDepthwiseConvolution(filter_shape); } xla::ConvolutionDimensionNumbers dims; - std::vector window_strides; + std::vector window_strides(num_spatial_dims_); + std::vector lhs_dilation(num_spatial_dims_, 1); + std::vector rhs_dilation(num_spatial_dims_); + std::vector> padding(num_spatial_dims_); + dims.set_input_batch_dimension(batch_dim); dims.set_output_batch_dimension(batch_dim); dims.set_input_feature_dimension(feature_dim); dims.set_output_feature_dimension(feature_dim); + dims.set_kernel_input_feature_dimension(num_spatial_dims_); + dims.set_kernel_output_feature_dimension(num_spatial_dims_ + 1); + for (int i = 0; i < num_spatial_dims_; ++i) { - int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); + const int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); dims.add_input_spatial_dimensions(dim); dims.add_kernel_spatial_dimensions(i); dims.add_output_spatial_dimensions(dim); - window_strides.push_back(strides_.at(dim)); + window_strides[i] = strides_.at(dim); + rhs_dilation[i] = dilations_.at(dim); + + int64 unused_output_size; + OP_REQUIRES_OK( + ctx, GetWindowedOutputSizeVerboseV2( + input_shape.dim_size(dim), expanded_filter_shape.dim_size(i), + rhs_dilation[i], window_strides[i], padding_, + &unused_output_size, &padding[i].first, &padding[i].second)); } - dims.set_kernel_input_feature_dimension(num_spatial_dims_); - dims.set_kernel_output_feature_dimension(num_spatial_dims_ + 1); - xla::Padding xla_padding = - (padding_ == VALID) ? 
xla::Padding::kValid : xla::Padding::kSame; - - xla::ComputationDataHandle conv = b->ConvWithGeneralDimensions( - ctx->Input(0), filter, window_strides, xla_padding, dims); + xla::ComputationDataHandle conv = + b->ConvGeneralDilated(ctx->Input(0), filter, window_strides, padding, + lhs_dilation, rhs_dilation, dims); ctx->SetOutput(0, conv); } protected: const int num_spatial_dims_; const bool depthwise_; + std::vector dilations_; std::vector strides_; Padding padding_; TensorFormat data_format_ = FORMAT_NHWC; @@ -241,6 +331,7 @@ class ConvBackpropInputOp : public XlaOpKernel { : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims), depthwise_(depthwise) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); string data_format; @@ -263,6 +354,22 @@ class ConvBackpropInputOp : public XlaOpKernel { errors::Unimplemented("Current implementation does not yet support " "strides in the batch and depth dimensions.")); + OP_REQUIRES(ctx, dilations_.size() == num_dims(), + errors::InvalidArgument("Dilations field must " + "specify ", + num_dims(), " dimensions")); + OP_REQUIRES( + ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1, + errors::Unimplemented("Current implementation does not support " + "dilations in the batch and depth dimensions.")); + for (int i = 0; i < num_spatial_dims_; ++i) { + int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); + OP_REQUIRES(ctx, dilations_[input_dim] >= 1, + errors::Unimplemented("Dilation values must be positive; ", i, + "th spatial dimension had dilation ", + dilations_[input_dim])); + } + TensorShape input_shape; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_shape)); @@ -274,10 +381,11 @@ class ConvBackpropInputOp : public XlaOpKernel { : filter_shape; // Reuse dimension computation logic from conv_grad_ops.cc. 
ConvBackpropDimensions dims; - OP_REQUIRES_OK(ctx, ConvBackpropComputeDimensions( - type_string(), num_spatial_dims_, input_shape, - expanded_filter_shape, out_backprop_shape, strides_, - padding_, data_format_, &dims)); + OP_REQUIRES_OK(ctx, + ConvBackpropComputeDimensionsV2( + type_string(), num_spatial_dims_, input_shape, + expanded_filter_shape, out_backprop_shape, dilations_, + strides_, padding_, data_format_, &dims)); xla::ComputationBuilder* b = ctx->builder(); auto filter = ctx->Input(1); @@ -301,6 +409,7 @@ class ConvBackpropInputOp : public XlaOpKernel { std::vector kernel_spatial_dims(num_spatial_dims_); std::vector> padding(num_spatial_dims_); std::vector lhs_dilation(num_spatial_dims_); + std::vector rhs_dilation(num_spatial_dims_); std::vector ones(num_spatial_dims_, 1); for (int i = 0; i < num_spatial_dims_; ++i) { int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); @@ -312,6 +421,7 @@ class ConvBackpropInputOp : public XlaOpKernel { padding[i] = {dims.spatial_dims[i].pad_before, dims.spatial_dims[i].pad_after}; lhs_dilation[i] = dims.spatial_dims[i].stride; + rhs_dilation[i] = dilations_[dim]; } // If this is a depthwise convolution, expand the filter. 
@@ -328,7 +438,7 @@ class ConvBackpropInputOp : public XlaOpKernel { // = gradients (with padding and dilation) mirrored_weights xla::ComputationDataHandle in_backprop = b->ConvGeneralDilated( out_backprop, mirrored_weights, /*window_strides=*/ones, padding, - lhs_dilation, /*rhs_dilation=*/ones, dnums); + lhs_dilation, rhs_dilation, dnums); ctx->SetOutput(0, in_backprop); } @@ -336,6 +446,7 @@ class ConvBackpropInputOp : public XlaOpKernel { protected: const int num_spatial_dims_; const bool depthwise_; + std::vector dilations_; std::vector strides_; Padding padding_; TensorFormat data_format_ = FORMAT_NHWC; @@ -349,21 +460,26 @@ class Conv2DBackpropInputOp : public ConvBackpropInputOp { explicit Conv2DBackpropInputOp(OpKernelConstruction* ctx) : ConvBackpropInputOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/false) {} }; -REGISTER_XLA_OP(Name("Conv2DBackpropInput"), Conv2DBackpropInputOp); +REGISTER_XLA_OP( + Name("Conv2DBackpropInput").CompileTimeConstInput("input_sizes"), + Conv2DBackpropInputOp); class Conv3DBackpropInputOp : public ConvBackpropInputOp { public: explicit Conv3DBackpropInputOp(OpKernelConstruction* ctx) : ConvBackpropInputOp(ctx, /*num_spatial_dims=*/3, /*depthwise=*/false) {} }; -REGISTER_XLA_OP(Name("Conv3DBackpropInputV2"), Conv3DBackpropInputOp); +REGISTER_XLA_OP( + Name("Conv3DBackpropInputV2").CompileTimeConstInput("input_sizes"), + Conv3DBackpropInputOp); class DepthwiseConv2DBackpropInputOp : public ConvBackpropInputOp { public: explicit DepthwiseConv2DBackpropInputOp(OpKernelConstruction* ctx) : ConvBackpropInputOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/true) {} }; -REGISTER_XLA_OP(Name("DepthwiseConv2dNativeBackpropInput"), +REGISTER_XLA_OP(Name("DepthwiseConv2dNativeBackpropInput") + .CompileTimeConstInput("input_sizes"), DepthwiseConv2DBackpropInputOp); class ConvBackpropFilterOp : public XlaOpKernel { @@ -373,6 +489,7 @@ class ConvBackpropFilterOp : public XlaOpKernel { : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims), 
depthwise_(depthwise) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); string data_format; @@ -392,6 +509,22 @@ class ConvBackpropFilterOp : public XlaOpKernel { errors::InvalidArgument("Current implementation does not yet support " "strides in the batch and depth dimensions.")); + OP_REQUIRES(ctx, dilations_.size() == num_dims(), + errors::InvalidArgument("Dilations field must " + "specify ", + num_dims(), " dimensions")); + OP_REQUIRES( + ctx, dilations_[n_dim] == 1 && dilations_[c_dim] == 1, + errors::Unimplemented("Current implementation does not support " + "dilations in the batch and depth dimensions.")); + for (int i = 0; i < num_spatial_dims_; ++i) { + int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); + OP_REQUIRES(ctx, dilations_[input_dim] >= 1, + errors::Unimplemented("Dilation values must be positive; ", i, + "th spatial dimension had dilation ", + dilations_[input_dim])); + } + const TensorShape activations_shape = ctx->InputShape(0); TensorShape filter_shape; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &filter_shape)); @@ -403,10 +536,11 @@ class ConvBackpropFilterOp : public XlaOpKernel { // Reuse dimension computation logic from conv_grad_ops.cc. ConvBackpropDimensions dims; - OP_REQUIRES_OK(ctx, ConvBackpropComputeDimensions( - type_string(), num_spatial_dims_, activations_shape, - expanded_filter_shape, out_backprop_shape, strides_, - padding_, data_format_, &dims)); + OP_REQUIRES_OK(ctx, + ConvBackpropComputeDimensionsV2( + type_string(), num_spatial_dims_, activations_shape, + expanded_filter_shape, out_backprop_shape, dilations_, + strides_, padding_, data_format_, &dims)); xla::ComputationBuilder* b = ctx->builder(); xla::ComputationDataHandle activations = ctx->Input(0); @@ -426,9 +560,7 @@ class ConvBackpropFilterOp : public XlaOpKernel { // Swap n_dim and c_dim in the activations. 
dnums.set_input_batch_dimension(c_dim); - dnums.set_output_batch_dimension(c_dim); dnums.set_input_feature_dimension(n_dim); - dnums.set_output_feature_dimension(n_dim); // The gradients become the RHS of the convolution. // The gradients have shape [batch, out_rows, out_cols, ..., out_depth] @@ -438,21 +570,29 @@ class ConvBackpropFilterOp : public XlaOpKernel { std::vector> padding(num_spatial_dims_); std::vector rhs_dilation(num_spatial_dims_); + std::vector window_strides(num_spatial_dims_); std::vector ones(num_spatial_dims_, 1); + // Tensorflow filter shape is [ H, W, ..., inC, outC ]. + for (int i = 0; i < num_spatial_dims_; ++i) { + dnums.add_output_spatial_dimensions(i); + } + dnums.set_output_batch_dimension(num_spatial_dims_); + dnums.set_output_feature_dimension(num_spatial_dims_ + 1); + for (int i = 0; i < num_spatial_dims_; ++i) { int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); dnums.add_input_spatial_dimensions(dim); dnums.add_kernel_spatial_dimensions(dim); - dnums.add_output_spatial_dimensions(dim); // We will also need to pad the input with zeros such that after the // convolution, we get the right size for the filter. // The padded_in_rows should be such that when we convolve this with the // expanded_out_rows as a filter, we should get filter_rows back. // - const int padded_in_size = dims.spatial_dims[i].expanded_output_size + - dims.spatial_dims[i].filter_size - 1; + const int64 padded_in_size = + dims.spatial_dims[i].expanded_output_size + + (dims.spatial_dims[i].filter_size - 1) * dilations_[dim]; // However it can be smaller than input_rows: in this // case it means some of the inputs are not used. @@ -468,8 +608,7 @@ class ConvBackpropFilterOp : public XlaOpKernel { // and input "C" is not used at all. // // We apply negative padding in this case. 
- const int total_pad_in_size = - padded_in_size - dims.spatial_dims[i].input_size; + const int64 pad_total = padded_in_size - dims.spatial_dims[i].input_size; // + For the VALID padding, we don't pad anything on the top/left side // and pad the bottom/right side with the remaining space. @@ -479,13 +618,12 @@ class ConvBackpropFilterOp : public XlaOpKernel { // In addition, if the padded input size is smaller than the input size, // we need to ignore some training elements of the input. We do this by // applying negative padding on the right/bottom. - const int before_pad_in_size = - (total_pad_in_size > 0 && padding_ == Padding::SAME) - ? total_pad_in_size / 2 - : 0; + const int64 pad_before = + padding_ == Padding::SAME ? std::max(pad_total / 2, 0) : 0; - padding[i] = {before_pad_in_size, total_pad_in_size - before_pad_in_size}; + padding[i] = {pad_before, pad_total - pad_before}; rhs_dilation[i] = dims.spatial_dims[i].stride; + window_strides[i] = dilations_[dim]; } // Besides padding the input, we will also expand output_rows to @@ -497,35 +635,20 @@ class ConvBackpropFilterOp : public XlaOpKernel { // This is done by specifying the window dilation factors in the // convolution HLO below. auto filter_backprop = - b->ConvGeneralDilated(activations, gradients, - /*window_strides=*/ones, padding, + b->ConvGeneralDilated(activations, gradients, window_strides, padding, /*lhs_dilation=*/ones, rhs_dilation, dnums); - // The layout of filter_backprop will match the layout of - // padded_activations - // and so will have layout: [out_feature, h, w, ..., in_feature] - // Tensorflow filter shape is [ H, W, ..., inC, outC ], so we transpose the - // output. 
- std::vector transpose_dims; - transpose_dims.reserve(num_dims()); - for (int i = 0; i < num_spatial_dims_; ++i) { - transpose_dims.push_back(dnums.output_spatial_dimensions(i)); - } - transpose_dims.push_back(c_dim); - transpose_dims.push_back(n_dim); - xla::ComputationDataHandle filter_backprop_reshaped = - b->Transpose(filter_backprop, transpose_dims); - if (depthwise_) { - filter_backprop_reshaped = ContractFilterForDepthwiseBackprop( - filter_shape, ctx->input_type(0), filter_backprop_reshaped, b); + filter_backprop = ContractFilterForDepthwiseBackprop( + ctx, filter_shape, ctx->input_type(0), filter_backprop, b); } - ctx->SetOutput(0, filter_backprop_reshaped); + ctx->SetOutput(0, filter_backprop); } protected: const int num_spatial_dims_; const bool depthwise_; + std::vector dilations_; std::vector strides_; Padding padding_; TensorFormat data_format_ = FORMAT_NHWC; @@ -540,7 +663,9 @@ class Conv2DBackpropFilterOp : public ConvBackpropFilterOp { : ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/false) { } }; -REGISTER_XLA_OP(Name("Conv2DBackpropFilter"), Conv2DBackpropFilterOp); +REGISTER_XLA_OP( + Name("Conv2DBackpropFilter").CompileTimeConstInput("filter_sizes"), + Conv2DBackpropFilterOp); class Conv3DBackpropFilterOp : public ConvBackpropFilterOp { public: @@ -548,14 +673,17 @@ class Conv3DBackpropFilterOp : public ConvBackpropFilterOp { : ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/3, /*depthwise=*/false) { } }; -REGISTER_XLA_OP(Name("Conv3DBackpropFilterV2"), Conv3DBackpropFilterOp); +REGISTER_XLA_OP( + Name("Conv3DBackpropFilterV2").CompileTimeConstInput("filter_sizes"), + Conv3DBackpropFilterOp); class DepthwiseConv2DBackpropFilterOp : public ConvBackpropFilterOp { public: explicit DepthwiseConv2DBackpropFilterOp(OpKernelConstruction* ctx) : ConvBackpropFilterOp(ctx, /*num_spatial_dims=*/2, /*depthwise=*/true) {} }; -REGISTER_XLA_OP(Name("DepthwiseConv2dNativeBackpropFilter"), 
+REGISTER_XLA_OP(Name("DepthwiseConv2dNativeBackpropFilter") + .CompileTimeConstInput("filter_sizes"), DepthwiseConv2DBackpropFilterOp); } // namespace diff --git a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc index a4ea65ea89e348cb77412efb0c5c0fcb1a9f33f3..96d7809f7995634b6bc31ab801b93526d9da7e6f 100644 --- a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/util/tensor_format.h" namespace tensorflow { namespace { @@ -23,6 +24,16 @@ namespace { class DepthToSpaceOp : public XlaOpKernel { public: explicit DepthToSpaceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + string data_format_str; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); + OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), + errors::InvalidArgument("Invalid data format")); + + OP_REQUIRES(ctx, data_format_ == FORMAT_NCHW || data_format_ == FORMAT_NHWC, + errors::InvalidArgument("Unsupported data format ", + ToString(data_format_), + "; expected formats NHWC or NCHW")); + OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_)); OP_REQUIRES( ctx, block_size_ > 1, @@ -31,18 +42,79 @@ class DepthToSpaceOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { const TensorShape input_tensor_shape = ctx->InputShape(0); - // The input is presumed to be [batch, height, width, depth] int input_rank = input_tensor_shape.dims(); static const int kRequiredDims = 4; OP_REQUIRES(ctx, kRequiredDims == input_rank, - errors::InvalidArgument("Input rank should be: ", kRequiredDims, - " instead of: ", input_rank)); + errors::InvalidArgument("Input rank should be ", kRequiredDims, + "; got: ", 
input_rank)); const gtl::InlinedVector input_shape = input_tensor_shape.dim_sizes(); xla::ComputationBuilder* b = ctx->builder(); xla::ComputationDataHandle input = ctx->Input(0); + int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_); + int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format_); + + std::vector reshaped_shape; + std::vector transpose_order; + std::vector output_shape; + reshaped_shape.reserve(input_rank); + transpose_order.reserve(input_rank); + output_shape.reserve(input_rank); + if (data_format_ == FORMAT_NHWC) { + reshaped_shape.push_back(input_shape[0]); + for (int i = 0; i < num_spatial_dims; ++i) { + reshaped_shape.push_back(input_shape[1 + i]); + } + int64 block_elems = 1; + for (int i = 0; i < num_spatial_dims; ++i) { + reshaped_shape.push_back(block_size_); + block_elems *= block_size_; + } + reshaped_shape.push_back(input_shape[feature_dim] / block_elems); + + transpose_order.push_back(0); + for (int i = 0; i < num_spatial_dims; ++i) { + transpose_order.push_back(i + 1); + transpose_order.push_back(i + 1 + num_spatial_dims); + } + transpose_order.push_back(feature_dim + num_spatial_dims); + + output_shape.push_back(input_shape[0]); + for (int i = 0; i < num_spatial_dims; ++i) { + output_shape.push_back(input_shape[1 + i] * block_size_); + } + output_shape.push_back(input_shape[feature_dim] / block_elems); + } else { + // NCHW format. 
+ reshaped_shape.push_back(input_shape[0]); + int64 block_elems = 1; + for (int i = 0; i < num_spatial_dims; ++i) { + reshaped_shape.push_back(block_size_); + block_elems *= block_size_; + } + reshaped_shape.push_back(input_shape[feature_dim] / block_elems); + for (int i = 0; i < num_spatial_dims; ++i) { + reshaped_shape.push_back(input_shape[2 + i]); + } + + transpose_order.push_back(0); + transpose_order.push_back(1 + num_spatial_dims); + for (int i = 0; i < num_spatial_dims; ++i) { + transpose_order.push_back(2 + num_spatial_dims + i); + transpose_order.push_back(1 + i); + } + + output_shape.push_back(input_shape[0]); + output_shape.push_back(input_shape[feature_dim] / block_elems); + for (int i = 0; i < num_spatial_dims; ++i) { + output_shape.push_back(input_shape[2 + i] * block_size_); + } + } + + // Note: comments are given in NHWC format; NCHW is similar with a different + // dimension order. // 1. Reshape `input` to `reshaped` of shape: // // [batch, @@ -51,14 +123,14 @@ class DepthToSpaceOp : public XlaOpKernel { // block_size_, // block_size_, // depth / (block_size_ * block_size_)] - OP_REQUIRES(ctx, input_shape[3] % (block_size_ * block_size_) == 0, + OP_REQUIRES(ctx, + input_shape[feature_dim] % (block_size_ * block_size_) == 0, errors::InvalidArgument( "Input depth dimension (", input_shape[3], ") is not divisible by square of the block size (", block_size_, ")")); - xla::ComputationDataHandle reshaped = b->Reshape( - input, {input_shape[0], input_shape[1], input_shape[2], block_size_, - block_size_, input_shape[3] / (block_size_ * block_size_)}); + + xla::ComputationDataHandle reshaped = b->Reshape(input, reshaped_shape); // 2. 
Permute dimensions of `reshaped` to produce // `permuted_reshaped` of shape: @@ -70,7 +142,7 @@ class DepthToSpaceOp : public XlaOpKernel { // block_size_, // depth / (block_size_ * block_size_)] xla::ComputationDataHandle permuted_reshaped = - b->Transpose(reshaped, {0, 1, 3, 2, 4, 5}); + b->Transpose(reshaped, transpose_order); // 3. Reshape `permuted_reshaped` to flatten `block_shape` into the // batch dimension, producing an output tensor of shape: @@ -80,15 +152,14 @@ class DepthToSpaceOp : public XlaOpKernel { // input_shape[2] * block_size_, // depth / (block_size_ * block_size_)] // - xla::ComputationDataHandle output = b->Reshape( - permuted_reshaped, {input_shape[0], input_shape[1] * block_size_, - input_shape[2] * block_size_, - input_shape[3] / (block_size_ * block_size_)}); + xla::ComputationDataHandle output = + b->Reshape(permuted_reshaped, output_shape); ctx->SetOutput(0, output); } private: + TensorFormat data_format_; int block_size_; }; REGISTER_XLA_OP(Name("DepthToSpace"), DepthToSpaceOp); diff --git a/tensorflow/compiler/tf2xla/kernels/diag_op.cc b/tensorflow/compiler/tf2xla/kernels/diag_op.cc index ec5017f6ab96bd3fc273a746b77fbb7e74fd9f35..765ea922a532a085a552192348ab360c4c30ff0a 100644 --- a/tensorflow/compiler/tf2xla/kernels/diag_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/diag_op.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -22,6 +24,62 @@ limitations under the License. namespace tensorflow { namespace { +// Create a diagonal / batch diagonal matrix with 'input' on the diagonal. 
+xla::StatusOr CreateDiagonal( + const xla::ComputationDataHandle& input, int64 last_dim_size, + tensorflow::gtl::ArraySlice other_dims, XlaOpKernelContext* ctx, + xla::ComputationBuilder* builder) { + // Create two matrices that have the following forms, and compare them: + // + // [[0, 0, 0, 0] [[0, 1, 2, 3] + // [1, 1, 1, 1] [0, 1, 2, 3] + // [2, 2, 2, 2] [0, 1, 2, 3] + // [3, 3, 3, 3]] [0, 1, 2, 3]] + // + // This produces a predicate matrix of the right size, with "true" on the + // diagonal. + xla::ComputationDataHandle iota; + TF_RETURN_IF_ERROR( + XlaHelpers::Iota(builder, DataType::DT_INT32, last_dim_size, &iota)); + xla::ComputationDataHandle iota_broadcast = + builder->Broadcast(iota, {last_dim_size}); + xla::ComputationDataHandle mask = builder->Eq(iota_broadcast, iota, {0}); + + // If this is a batched diagonal, broadcast the mask across the other + // dimensions. + if (!other_dims.empty()) { + mask = builder->Broadcast(mask, other_dims); + } + + // Broadcast the input, and then use the mask computed above to select the + // diagonal: + // e.g, in 2D: + // [[t, f, f] [[1, 1, 1] [[0, 0, 0] [[1, 0, 0] + // select( [f, t, f] , [4, 4, 4] , [0, 0, 0] ) = [0, 4, 0] + // [f, f, t]] [9, 9, 9]] [0, 0, 0]] [0, 0, 9]] + // + // Broadcasting the input is less-than-trivial, since we need to broadcast + // into a "middle" dimension. We can do this with a reshape + implicit + // broadcast. + // TODO(b/30112114): Replace with in-dim broadcast when those are supported. 
+ std::vector broadcast_dims(other_dims.begin(), other_dims.end()); + broadcast_dims.push_back(1LL); + broadcast_dims.push_back(last_dim_size); + xla::ComputationDataHandle input_broadcast = + builder->Reshape(input, broadcast_dims); + + broadcast_dims[broadcast_dims.size() - 2] = last_dim_size; + xla::PrimitiveType element_type; + TF_RETURN_IF_ERROR( + DataTypeToPrimitiveType(ctx->input_type(0), &element_type)); + auto broadcast_shape = + xla::ShapeUtil::MakeShape(element_type, broadcast_dims); + xla::ComputationDataHandle zeros = Zeros(builder, broadcast_shape); + + input_broadcast = builder->Add(input_broadcast, zeros); + return builder->Select(mask, input_broadcast, zeros); +} + class DiagOp : public XlaOpKernel { public: explicit DiagOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} @@ -29,6 +87,8 @@ class DiagOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::ComputationBuilder* builder = ctx->builder(); + OP_REQUIRES(ctx, ctx->num_inputs() >= 1, + errors::InvalidArgument("Diag op must have at an input")); const TensorShape input_shape = ctx->InputShape(0); auto dims = input_shape.dim_sizes(); @@ -36,7 +96,7 @@ class DiagOp : public XlaOpKernel { errors::InvalidArgument("Expected 1 <= dims, got shape ", input_shape.DebugString())); - xla::ComputationDataHandle diag = ctx->Input(0); + xla::ComputationDataHandle input = ctx->Input(0); // Picture: // tf.diag([1, 2, 3, 4]) ==> [[1, 0, 0, 0] @@ -46,13 +106,13 @@ class DiagOp : public XlaOpKernel { // Flattens the input to 1D. int64 size = input_shape.num_elements(); - diag = builder->Reshape(diag, {size}); + input = builder->Reshape(input, {size}); - // Adds inter-element padding of 'size'. - xla::PaddingConfig config; - auto* dim = config.add_dimensions(); - dim->set_interior_padding(size); - diag = builder->Pad(diag, XlaHelpers::Zero(builder, input_type(0)), config); + // Create an R2 with the R1 diagonal. 
+ auto diag_or_status = + CreateDiagonal(input, size, /*other_dims=*/{}, ctx, builder); + OP_REQUIRES_OK(ctx, diag_or_status.status()); + xla::ComputationDataHandle diag = diag_or_status.ValueOrDie(); // Reshapes to the final shape. std::vector new_dims(dims.size() * 2); @@ -141,6 +201,8 @@ class MatrixDiagOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::ComputationBuilder* builder = ctx->builder(); + OP_REQUIRES(ctx, ctx->num_inputs() >= 1, + errors::InvalidArgument("MatrixDiag op must have at an input")); const TensorShape input_shape = ctx->InputShape(0); auto dims = input_shape.dim_sizes(); @@ -152,17 +214,13 @@ class MatrixDiagOp : public XlaOpKernel { int last_dim = dims.size() - 1; int64 last_dim_size = input_shape.dim_size(last_dim); + tensorflow::gtl::ArraySlice other_dims(dims); + other_dims.pop_back(); - // Adds inter-element padding of 'last_dim_size' to the last dimension. - xla::PaddingConfig config = xla::MakeNoPaddingConfig(dims.size()); - auto* dim = config.mutable_dimensions(last_dim); - dim->set_interior_padding(last_dim_size); - diag = builder->Pad(diag, XlaHelpers::Zero(builder, input_type(0)), config); - - // Reshapes to the final shape. 
- dims.push_back(last_dim_size); - diag = builder->Reshape(diag, dims); - + auto diag_or_status = + CreateDiagonal(diag, last_dim_size, other_dims, ctx, builder); + OP_REQUIRES_OK(ctx, diag_or_status.status()); + diag = diag_or_status.ValueOrDie(); ctx->SetOutput(0, diag); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc index 7349dcb987cd88c423570889c0502d1a0bd12c52..f2cd21ffb9ce88747c04f3c71e66dadeb1faf0f9 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc @@ -72,22 +72,24 @@ class DynamicStitchOp : public XlaOpKernel { XLAShapeToTensorShape(indices_input[input_num].shape(), &indices_shape)); const TensorShape& data_shape = data_shapes[input_num]; - OP_REQUIRES(ctx, TensorShapeUtils::StartsWith(data_shape, indices_shape), - errors::InvalidArgument( - "data[", input_num, "].shape = ", - data_shape.DebugString(), " does not start with indices[", - input_num, "].shape = ", indices_shape.DebugString())); - OP_REQUIRES(ctx, - input_num == 0 || SameExtraShape(data0_shape, indices0_shape, - data_shape, indices_shape), - errors::InvalidArgument( - "Need data[0].shape[", indices0_shape.dims(), - ":] = data[", input_num, "].shape[", indices_shape.dims(), - ":], got data[0].shape = ", data0_shape.DebugString(), - ", data[", input_num, "].shape = ", - data_shape.DebugString(), ", indices[0].shape = ", - indices0_shape.DebugString(), ", indices[", input_num, - "].shape = ", indices_shape.DebugString())); + OP_REQUIRES( + ctx, TensorShapeUtils::StartsWith(data_shape, indices_shape), + errors::InvalidArgument("data[", input_num, + "].shape = ", data_shape.DebugString(), + " does not start with indices[", input_num, + "].shape = ", indices_shape.DebugString())); + OP_REQUIRES( + ctx, + input_num == 0 || SameExtraShape(data0_shape, indices0_shape, + data_shape, indices_shape), + errors::InvalidArgument( + "Need 
data[0].shape[", indices0_shape.dims(), ":] = data[", + input_num, "].shape[", indices_shape.dims(), + ":], got data[0].shape = ", data0_shape.DebugString(), ", data[", + input_num, "].shape = ", data_shape.DebugString(), + ", indices[0].shape = ", indices0_shape.DebugString(), + ", indices[", input_num, + "].shape = ", indices_shape.DebugString())); OP_REQUIRES_OK(ctx, XlaHelpers::ReshapeLiteral(indices_input[input_num], @@ -159,8 +161,8 @@ class DynamicStitchOp : public XlaOpKernel { indices0_shape.dims()); std::vector slice_limit(1 + data0_shape.dims() - indices0_shape.dims()); - std::vector stride(1 + data0_shape.dims() - - indices0_shape.dims(), 1); + std::vector stride(1 + data0_shape.dims() - indices0_shape.dims(), + 1); for (int d = indices0_shape.dims(); d < data0_shape.dims(); d++) { slice_limit[1 + d - indices0_shape.dims()] = data0_shape.dim_size(d); } @@ -198,8 +200,10 @@ class DynamicStitchOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("DynamicStitch"), DynamicStitchOp); -REGISTER_XLA_OP(Name("ParallelDynamicStitch"), DynamicStitchOp); +REGISTER_XLA_OP(Name("DynamicStitch").CompileTimeConstInput("indices"), + DynamicStitchOp); +REGISTER_XLA_OP(Name("ParallelDynamicStitch").CompileTimeConstInput("indices"), + DynamicStitchOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..b2970eae20a3fb71f06619f476a49d41b22bca56 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/extract_image_patches_op.cc @@ -0,0 +1,169 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { + +namespace { + +class ExtractImagePatchesOp : public XlaOpKernel { + public: + explicit ExtractImagePatchesOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("ksizes", &ksizes_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("rates", &dilations_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + const TensorFormat data_format = FORMAT_NHWC; + const int num_dims = ksizes_.size(); + + OP_REQUIRES( + ctx, num_dims >= 3, + errors::InvalidArgument("Kernel size must have at least 3 dimensions")); + const int num_spatial_dims = num_dims - 2; + + OP_REQUIRES(ctx, strides_.size() == num_dims, + errors::InvalidArgument("Sliding window strides field must " + "specify ", + num_dims, " dimensions")); + OP_REQUIRES(ctx, dilations_.size() == num_dims, + errors::InvalidArgument("Dilations field must " + "specify ", + num_dims, " dimensions")); + + int batch_dim = GetTensorBatchDimIndex(num_dims, data_format); + int feature_dim = GetTensorFeatureDimIndex(num_dims, data_format); + OP_REQUIRES( + ctx, ksizes_[batch_dim] == 1 && ksizes_[feature_dim] == 1, + 
errors::Unimplemented("Current implementation does not yet support " + "kernel sizes > 1 in the batch and depth " + "dimensions.")); + OP_REQUIRES( + ctx, strides_[batch_dim] == 1 && strides_[feature_dim] == 1, + errors::Unimplemented("Current implementation does not yet support " + "strides in the batch and depth dimensions.")); + OP_REQUIRES( + ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1, + errors::Unimplemented("Current implementation does not support " + "dilations in the batch and depth dimensions.")); + + for (int i = 0; i < num_spatial_dims; ++i) { + int input_dim = GetTensorSpatialDimIndex(num_dims, data_format, i); + OP_REQUIRES( + ctx, ksizes_[input_dim] >= 0, + errors::Unimplemented("Kernel size values must be non-negative; ", i, + "th spatial dimension had dilation ", + dilations_[input_dim])); + OP_REQUIRES(ctx, strides_[input_dim] >= 1, + errors::Unimplemented("Stride values must be positive; ", i, + "th spatial dimension had dilation ", + dilations_[input_dim])); + OP_REQUIRES(ctx, dilations_[input_dim] >= 1, + errors::Unimplemented("Dilation values must be positive; ", i, + "th spatial dimension had dilation ", + dilations_[input_dim])); + } + + xla::PrimitiveType type; + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(ctx->input_type(0), &type)); + + const TensorShape input_shape = ctx->InputShape(0); + OP_REQUIRES( + ctx, input_shape.dims() == num_dims, + errors::InvalidArgument("input must be ", num_dims, "-dimensional", + input_shape.DebugString())); + const int64 depth = input_shape.dim_size(feature_dim); + + xla::ComputationBuilder* builder = ctx->builder(); + + // The following code is equivalent to: + // eye = np.eye(kH * kW * D).reshape([kH, kW, D, kH * kW * kD]) + int64 kernel_size = 1; + std::vector lhs_shape(num_dims, 1); + for (int i = 0; i < num_spatial_dims; ++i) { + int input_dim = GetTensorSpatialDimIndex(num_dims, data_format, i); + lhs_shape[i] = ksizes_[input_dim]; + kernel_size *= ksizes_[input_dim]; + } + 
lhs_shape[num_spatial_dims] = depth; + lhs_shape[num_spatial_dims + 1] = 1; + + // Builds an identity matrix as a broadcast equality of iotas. + // iota = np.arange(np.prod(ksize), depth) + // filter = np.equal(np.reshape(iota, [-1, 1]), iota).astype(np.float32) + xla::ComputationDataHandle iota; + TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, + kernel_size * depth, &iota)); + + auto lhs = builder->Reshape(iota, lhs_shape); + auto filter = builder->ConvertElementType( + builder->Eq(lhs, iota, {num_spatial_dims + 1}), type); + + xla::ConvolutionDimensionNumbers dims; + std::vector window_strides(num_spatial_dims); + std::vector lhs_dilation(num_spatial_dims, 1); + std::vector rhs_dilation(num_spatial_dims); + std::vector> padding(num_spatial_dims); + + dims.set_input_batch_dimension(batch_dim); + dims.set_output_batch_dimension(batch_dim); + dims.set_input_feature_dimension(feature_dim); + dims.set_output_feature_dimension(feature_dim); + dims.set_kernel_input_feature_dimension(num_spatial_dims); + dims.set_kernel_output_feature_dimension(num_spatial_dims + 1); + + for (int i = 0; i < num_spatial_dims; ++i) { + const int64 dim = GetTensorSpatialDimIndex(num_dims, data_format, i); + dims.add_input_spatial_dimensions(dim); + dims.add_kernel_spatial_dimensions(i); + dims.add_output_spatial_dimensions(dim); + window_strides[i] = strides_.at(dim); + rhs_dilation[i] = dilations_.at(dim); + + int64 unused_output_size; + OP_REQUIRES_OK( + ctx, GetWindowedOutputSizeVerboseV2( + input_shape.dim_size(dim), ksizes_[dim], rhs_dilation[i], + window_strides[i], padding_, &unused_output_size, + &padding[i].first, &padding[i].second)); + } + + xla::ComputationDataHandle conv = + builder->ConvGeneralDilated(ctx->Input(0), filter, window_strides, + padding, lhs_dilation, rhs_dilation, dims); + ctx->SetOutput(0, conv); + } + + protected: + std::vector ksizes_; + std::vector dilations_; + std::vector strides_; + Padding padding_; + + private: + 
TF_DISALLOW_COPY_AND_ASSIGN(ExtractImagePatchesOp); +}; + +REGISTER_XLA_OP(Name("ExtractImagePatches"), ExtractImagePatchesOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..453a32c494b42e9922bc35fc526f3306530054fd --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc @@ -0,0 +1,289 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/core/platform/macros.h" + +namespace tensorflow { +namespace { + +// Gymnastics with nudged zero point is to ensure that the real zero maps to +// an integer, which is required for e.g. zero-padding in convolutional layers. 
+void CpuNudge(const float min, const float max, const float quant_min, + const float quant_max, float* nudged_min, float* nudged_max, + float* scale) { + *scale = (max - min) / (quant_max - quant_min); + + const float zero_point_from_min = quant_min - min / *scale; + float nudged_zero_point; + if (zero_point_from_min <= quant_min) { + nudged_zero_point = quant_min; + } else if (zero_point_from_min >= quant_max) { + nudged_zero_point = quant_max; + } else { + nudged_zero_point = std::round(zero_point_from_min); + } + + *nudged_min = (quant_min - nudged_zero_point) * (*scale); + *nudged_max = (quant_max - nudged_zero_point) * (*scale); +} + +// An XLA version of CpuNudge(). +void XlaNudge(xla::ComputationBuilder* b, const DataType data_type, + const xla::ComputationDataHandle& min, + const xla::ComputationDataHandle& max, + const float quant_min_value, const float quant_max_value, + xla::ComputationDataHandle* nudged_min, + xla::ComputationDataHandle* nudged_max, + xla::ComputationDataHandle* scale) { + *scale = b->Div(b->Sub(max, min), + XlaHelpers::FloatLiteral(b, data_type, + quant_max_value - quant_min_value)); + xla::ComputationDataHandle quant_min = + XlaHelpers::FloatLiteral(b, data_type, quant_min_value); + xla::ComputationDataHandle zero_point_from_min = + b->Sub(quant_min, b->Div(min, *scale)); + xla::ComputationDataHandle quant_max = + XlaHelpers::FloatLiteral(b, data_type, quant_max_value); + xla::ComputationDataHandle nudged_zero_point = + b->Select(b->Le(zero_point_from_min, quant_min), quant_min, + b->Select(b->Ge(zero_point_from_min, quant_max), quant_max, + b->Round(zero_point_from_min))); + *nudged_min = b->Mul(b->Sub(quant_min, nudged_zero_point), *scale); + *nudged_max = b->Mul(b->Sub(quant_max, nudged_zero_point), *scale); +} + +xla::ComputationDataHandle Quantize( + xla::ComputationBuilder* b, const xla::ComputationDataHandle& input, + const DataType data_type, + const xla::ComputationDataHandle& nudged_input_min, + const 
xla::ComputationDataHandle& nudged_input_max, + const xla::ComputationDataHandle& input_scale) { + xla::ComputationDataHandle one = XlaHelpers::FloatLiteral(b, data_type, 1.0f); + xla::ComputationDataHandle inv_scale = b->Div(one, input_scale); + xla::ComputationDataHandle half = + XlaHelpers::FloatLiteral(b, data_type, 0.5f); + + xla::ComputationDataHandle clamped = + b->Clamp(nudged_input_min, input, nudged_input_max); + xla::ComputationDataHandle clamped_shifted = + b->Sub(clamped, nudged_input_min); + xla::ComputationDataHandle rounded = + b->Floor(b->Add(b->Mul(clamped_shifted, inv_scale), half)); + return b->Add(b->Mul(rounded, input_scale), nudged_input_min); +} + +class FakeQuantWithMinMaxArgsOp : public XlaOpKernel { + public: + explicit FakeQuantWithMinMaxArgsOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + int num_bits; + OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits)); + OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16, + errors::InvalidArgument("num_bits is out of range, expected " + "between 2 and 16, was: ", + num_bits)); + bool narrow_range; + OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range)); + quant_min_ = narrow_range ? 
1 : 0; + quant_max_ = (1 << num_bits) - 1; + + float input_min, input_max; + OP_REQUIRES_OK(ctx, ctx->GetAttr("min", &input_min)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("max", &input_max)); + CpuNudge(input_min, input_max, quant_min_, quant_max_, &nudged_input_min_, + &nudged_input_max_, &input_scale_); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationDataHandle input = ctx->Input(0); + const DataType data_type = ctx->input_type(0); + + xla::ComputationBuilder* b = ctx->builder(); + xla::ComputationDataHandle nudged_input_min = + XlaHelpers::FloatLiteral(b, data_type, nudged_input_min_); + xla::ComputationDataHandle nudged_input_max = + XlaHelpers::FloatLiteral(b, data_type, nudged_input_max_); + xla::ComputationDataHandle input_scale = + XlaHelpers::FloatLiteral(b, data_type, input_scale_); + xla::ComputationDataHandle output = Quantize( + b, input, data_type, nudged_input_min, nudged_input_max, input_scale); + ctx->SetOutput(0, output); + } + + private: + float quant_min_; + float quant_max_; + float nudged_input_min_; + float nudged_input_max_; + float input_scale_; +}; + +REGISTER_XLA_OP(Name("FakeQuantWithMinMaxArgs"), FakeQuantWithMinMaxArgsOp); + +class FakeQuantWithMinMaxArgsGradOp : public XlaOpKernel { + public: + explicit FakeQuantWithMinMaxArgsGradOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + int num_bits; + OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits)); + OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16, + errors::InvalidArgument("num_bits is out of range, expected " + "between 2 and 16, was: ", + num_bits)); + bool narrow_range; + OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range)); + const float quant_min = narrow_range ? 
1 : 0; + const float quant_max = (1 << num_bits) - 1; + + float input_min, input_max, scale; + OP_REQUIRES_OK(ctx, ctx->GetAttr("min", &input_min)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("max", &input_max)); + CpuNudge(input_min, input_max, quant_min, quant_max, &nudged_input_min_, + &nudged_input_max_, &scale); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationDataHandle gradient = ctx->Input(0); + const TensorShape gradient_shape = ctx->InputShape(0); + xla::ComputationDataHandle input = ctx->Input(1); + const DataType data_type = ctx->input_type(1); + + xla::ComputationBuilder* b = ctx->builder(); + xla::ComputationDataHandle nudged_input_min = + XlaHelpers::FloatLiteral(b, data_type, nudged_input_min_); + xla::ComputationDataHandle nudged_input_max = + XlaHelpers::FloatLiteral(b, data_type, nudged_input_max_); + + xla::ComputationDataHandle between_nudged_min_max = + b->And(b->Le(nudged_input_min, input), b->Le(input, nudged_input_max)); + xla::ComputationDataHandle zeroes = b->Broadcast( + XlaHelpers::Zero(b, data_type), gradient_shape.dim_sizes()); + xla::ComputationDataHandle output = + b->Select(between_nudged_min_max, gradient, zeroes); + ctx->SetOutput(0, output); + } + + private: + float nudged_input_min_; + float nudged_input_max_; +}; + +REGISTER_XLA_OP(Name("FakeQuantWithMinMaxArgsGradient"), + FakeQuantWithMinMaxArgsGradOp); + +class FakeQuantWithMinMaxVarsOp : public XlaOpKernel { + public: + explicit FakeQuantWithMinMaxVarsOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + int num_bits; + OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits)); + OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16, + errors::InvalidArgument("num_bits is out of range, expected " + "between 2 and 16, was: ", + num_bits)); + bool narrow_range; + OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range)); + quant_min_ = narrow_range ? 
1 : 0; + quant_max_ = (1 << num_bits) - 1; + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationDataHandle input = ctx->Input(0); + const DataType data_type = ctx->input_type(0); + xla::ComputationDataHandle input_min = ctx->Input(1); + xla::ComputationDataHandle input_max = ctx->Input(2); + + xla::ComputationBuilder* b = ctx->builder(); + xla::ComputationDataHandle nudged_input_min, nudged_input_max, input_scale; + XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_, + &nudged_input_min, &nudged_input_max, &input_scale); + + xla::ComputationDataHandle output = Quantize( + b, input, data_type, nudged_input_min, nudged_input_max, input_scale); + ctx->SetOutput(0, output); + } + + private: + float quant_min_; + float quant_max_; +}; + +REGISTER_XLA_OP(Name("FakeQuantWithMinMaxVars"), FakeQuantWithMinMaxVarsOp); + +class FakeQuantWithMinMaxVarsGradOp : public XlaOpKernel { + public: + explicit FakeQuantWithMinMaxVarsGradOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + int num_bits; + OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits)); + OP_REQUIRES(ctx, num_bits >= 2 && num_bits <= 16, + errors::InvalidArgument("num_bits is out of range, expected " + "between 2 and 16, was: ", + num_bits)); + bool narrow_range; + OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range)); + quant_min_ = narrow_range ? 
1 : 0; + quant_max_ = (1 << num_bits) - 1; + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationDataHandle gradient = ctx->Input(0); + const TensorShape gradient_shape = ctx->InputShape(0); + xla::ComputationDataHandle input = ctx->Input(1); + const DataType data_type = ctx->input_type(1); + xla::ComputationDataHandle input_min = ctx->Input(2); + xla::ComputationDataHandle input_max = ctx->Input(3); + + xla::ComputationBuilder* b = ctx->builder(); + xla::ComputationDataHandle nudged_input_min, nudged_input_max, input_scale; + XlaNudge(b, data_type, input_min, input_max, quant_min_, quant_max_, + &nudged_input_min, &nudged_input_max, &input_scale); + + xla::ComputationDataHandle between_nudged_min_max = + b->And(b->Le(nudged_input_min, input), b->Le(input, nudged_input_max)); + xla::ComputationDataHandle zero = XlaHelpers::Zero(b, data_type); + xla::ComputationDataHandle zeroes = + b->Broadcast(zero, gradient_shape.dim_sizes()); + xla::ComputationDataHandle output0 = + b->Select(between_nudged_min_max, gradient, zeroes); + ctx->SetOutput(0, output0); + + xla::ComputationDataHandle below_min = b->Lt(input, nudged_input_min); + xla::ComputationDataHandle output1 = + b->ReduceAll(b->Select(below_min, gradient, zeroes), zero, + *ctx->GetOrCreateAdd(data_type)); + ctx->SetOutput(1, output1); + + xla::ComputationDataHandle above_max = b->Gt(input, nudged_input_max); + xla::ComputationDataHandle output2 = + b->ReduceAll(b->Select(above_max, gradient, zeroes), zero, + *ctx->GetOrCreateAdd(data_type)); + ctx->SetOutput(2, output2); + } + + private: + float quant_min_; + float quant_max_; +}; + +REGISTER_XLA_OP(Name("FakeQuantWithMinMaxVarsGradient"), + FakeQuantWithMinMaxVarsGradOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/fft_ops.cc b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..a4f3c1c3ad9a928e0552c388a25ed9fcb08edabb --- 
/dev/null +++ b/tensorflow/compiler/tf2xla/kernels/fft_ops.cc @@ -0,0 +1,122 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA-specific Ops for FFT. + +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_slice.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/kernels/conv_grad_ops.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/util/padding.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { + +namespace { + +using xla::FftType; + +class GenericFftOp : public XlaOpKernel { + public: + explicit GenericFftOp(OpKernelConstruction* ctx, FftType fft_type, + int fft_rank) + : XlaOpKernel(ctx), fft_type_(fft_type), fft_rank_(fft_rank) {} + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape input_shape = ctx->InputShape(0); + OP_REQUIRES( + ctx, TensorShapeUtils::IsVectorOrHigher(input_shape), + errors::InvalidArgument("input must be at 
least 1 dimensional")); + + std::vector fft_length; + if (fft_type_ == FftType::RFFT || fft_type_ == FftType::IRFFT) { + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &fft_length)); + OP_REQUIRES(ctx, fft_length.size() == fft_rank_, + errors::InvalidArgument("fft_length must be length ", + fft_rank_, " vector")); + } else { + // Innermost axis provides the FFT length. + for (int i = 0; i < fft_rank_; i++) { + fft_length.push_back( + input_shape.dim_size(input_shape.dims() - fft_rank_ + i)); + } + } + + xla::ComputationBuilder* b = ctx->builder(); + xla::ComputationDataHandle fft = + b->Fft(ctx->Input(0), fft_type_, fft_length); + ctx->SetOutput(0, fft); + } + + protected: + const FftType fft_type_; + const int fft_rank_; + + private: + TF_DISALLOW_COPY_AND_ASSIGN(GenericFftOp); +}; + +template +class FFTOp : public GenericFftOp { + public: + explicit FFTOp(OpKernelConstruction* ctx) + : GenericFftOp(ctx, /*fft_type=*/FftType::FFT, /*fft_rank=*/FFTRank) {} +}; +REGISTER_XLA_OP(Name("FFT"), FFTOp<1>); +REGISTER_XLA_OP(Name("FFT2D"), FFTOp<2>); +REGISTER_XLA_OP(Name("FFT3D"), FFTOp<3>); + +template +class IFFTOp : public GenericFftOp { + public: + explicit IFFTOp(OpKernelConstruction* ctx) + : GenericFftOp(ctx, /*fft_type=*/FftType::IFFT, /*fft_rank=*/FFTRank) {} +}; +REGISTER_XLA_OP(Name("IFFT"), IFFTOp<1>); +REGISTER_XLA_OP(Name("IFFT2D"), IFFTOp<2>); +REGISTER_XLA_OP(Name("IFFT3D"), IFFTOp<3>); + +template +class RFFTOp : public GenericFftOp { + public: + explicit RFFTOp(OpKernelConstruction* ctx) + : GenericFftOp(ctx, /*fft_type=*/FftType::RFFT, /*fft_rank=*/FFTRank) {} +}; +REGISTER_XLA_OP(Name("RFFT").CompileTimeConstInput("fft_length"), RFFTOp<1>); +REGISTER_XLA_OP(Name("RFFT2D").CompileTimeConstInput("fft_length"), RFFTOp<2>); +REGISTER_XLA_OP(Name("RFFT3D").CompileTimeConstInput("fft_length"), RFFTOp<3>); + +template +class IRFFTOp : public GenericFftOp { + public: + explicit IRFFTOp(OpKernelConstruction* ctx) + : GenericFftOp(ctx, 
/*fft_type=*/FftType::IRFFT, /*fft_rank=*/FFTRank) {} +}; +REGISTER_XLA_OP(Name("IRFFT").CompileTimeConstInput("fft_length"), IRFFTOp<1>); +REGISTER_XLA_OP(Name("IRFFT2D").CompileTimeConstInput("fft_length"), + IRFFTOp<2>); +REGISTER_XLA_OP(Name("IRFFT3D").CompileTimeConstInput("fft_length"), + IRFFTOp<3>); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/fill_op.cc b/tensorflow/compiler/tf2xla/kernels/fill_op.cc index 9e090fe01cbfd4dab81b0de21e3a44e42c2ef18e..eaa13b8dfacce9aaca42ce5fcdfa467ce7fa7b7f 100644 --- a/tensorflow/compiler/tf2xla/kernels/fill_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/fill_op.cc @@ -69,7 +69,7 @@ class FillOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Fill"), FillOp); +REGISTER_XLA_OP(Name("Fill").CompileTimeConstInput("dims"), FillOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index e420f21ca33fe7de9b33f404ce04eae62d9c041e..7945c05af40df21a798a2cff51fe7f8e935793f6 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/tf2xla/kernels/gather_op.h" #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" +#include "tensorflow/compiler/tf2xla/lib/while_loop.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" @@ -26,46 +26,48 @@ limitations under the License. namespace tensorflow { -xla::ComputationDataHandle XlaComputeGatherDynamicSlice( - XlaOpKernelContext* context, const xla::ComputationDataHandle& input, - const TensorShape& input_shape, const xla::ComputationDataHandle& indices, - const TensorShape& indices_shape, int64 axis, DataType dtype, - DataType index_type, xla::ComputationBuilder* builder) { +Status XlaGather(const xla::ComputationDataHandle& input, + const TensorShape& input_shape, + const xla::ComputationDataHandle& indices, + TensorShape indices_shape, int64 axis, bool indices_are_nd, + DataType dtype, DataType index_type, + xla::ComputationBuilder* builder, + xla::ComputationDataHandle* gather_output) { + // If the indices are N-dimensional, then the minor dimension of indices + // should be of size N and correspond to the N indices. 
+ int64 num_index_dims = 1; + if (indices_are_nd) { + CHECK_GE(indices_shape.dims(), 1); + num_index_dims = indices_shape.dim_size(indices_shape.dims() - 1); + indices_shape.RemoveLastDims(1); + } + // Although the indices Tensor is flattened into rank 1 during the lookup, // and each scalar entry is used as an index into the first dimension of the // input, the output is returned with shape: // input.shape[:axis] + indices.shape + input.shape[axis+1:] - const int num_indices = indices_shape.num_elements(); + + const int64 num_indices = indices_shape.num_elements(); TensorShape input_shape_pre_axis(input_shape); input_shape_pre_axis.RemoveDimRange(axis, input_shape.dims()); TensorShape input_shape_post_axis(input_shape); - input_shape_post_axis.RemoveDimRange(0, axis + 1); - + input_shape_post_axis.RemoveDimRange(0, axis + num_index_dims); // Each slice of the input tensor has shape: - // [, 1, ] + // [, 1, ..., 1, ] TensorShape slice_shape(input_shape); - slice_shape.set_dim(axis, 1); - - // TODO(b/37575001) The tensor in which we construct the output during - // the loop must have rank >= 3 as a workaround for lowering issues. - int64 extra_dims = 0; - if (input_shape.dims() < 3) extra_dims = 3 - input_shape.dims(); + for (int64 i = 0; i < num_index_dims; ++i) { + slice_shape.set_dim(axis + i, 1); + } TensorShape loop_out_shape; - for (int64 k = 0; k < extra_dims; ++k) loop_out_shape.AddDim(1); loop_out_shape.AppendShape(input_shape_pre_axis); loop_out_shape.AddDim(num_indices); loop_out_shape.AppendShape(input_shape_post_axis); - - // Slices are reshaped into the rank >= 3 shape of the loop carried output. TensorShape loop_out_slice_shape; - for (int64 k = 0; k < extra_dims; ++k) loop_out_slice_shape.AddDim(1); loop_out_slice_shape.AppendShape(input_shape_pre_axis); loop_out_slice_shape.AddDim(1); loop_out_slice_shape.AppendShape(input_shape_post_axis); - // Finally, the loop-carried rank >= 3 output is reshaped to the op's - // specified result shape. 
TensorShape out_shape; out_shape.AppendShape(input_shape_pre_axis); out_shape.AppendShape(indices_shape); @@ -73,131 +75,176 @@ xla::ComputationDataHandle XlaComputeGatherDynamicSlice( // Degenerate case: empty indices. if (num_indices == 0) { - return builder->Broadcast(XlaHelpers::Zero(builder, dtype), - out_shape.dim_sizes()); + *gather_output = builder->Broadcast(XlaHelpers::Zero(builder, dtype), + out_shape.dim_sizes()); + return Status::OK(); + } + + for (int64 i = 0; i < num_index_dims; ++i) { + if (input_shape.dim_size(axis + i) == 0) { + return errors::InvalidArgument("Gather dimension ", axis + i, + " is of size zero in tensor with shape ", + input_shape.DebugString()); + } + } + + // Flatten the major dimensions of indices into a single dimension for ease of + // iteration. If there is an axis dimension, we must leave it alone. + std::vector flat_indices_shape = {num_indices}; + if (indices_are_nd) { + flat_indices_shape.push_back(num_index_dims); } // Specify the shape of the loop-carried Tensor tuple. - xla::PrimitiveType ptype; - TF_CHECK_OK(DataTypeToPrimitiveType(dtype, &ptype)); - xla::PrimitiveType idxtype; - TF_CHECK_OK(DataTypeToPrimitiveType(index_type, &idxtype)); - std::vector tuple_shapes( - {// The iteration counter i is a scalar, incremented each iteration. - xla::ShapeUtil::MakeShape(idxtype, {}), - // The input array has shape input_shape. Loop invariant. - xla::ShapeUtil::MakeShape(ptype, input_shape.dim_sizes()), - // The gather indices are reshaped to rank 1. Loop invariant. - xla::ShapeUtil::MakeShape(idxtype, {num_indices}), - // The output array is rank >= 3, and is updated on each loop iteration. - xla::ShapeUtil::MakeShape(ptype, loop_out_shape.dim_sizes())}); - xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); // Construct the initial values of the loop-carried Tensors. 
- auto init_i = XlaHelpers::Zero(builder, index_type); + auto flat_indices = builder->Reshape(indices, flat_indices_shape); auto init_out = builder->Broadcast(XlaHelpers::Zero(builder, dtype), loop_out_shape.dim_sizes()); - // Flatten the indices into 1-D for ease of iteration. - auto indices_1d = builder->Reshape(indices, {num_indices}); - auto init = builder->Tuple({init_i, input, indices_1d, init_out}); - - // Construct the while loop condition (i < num_indices) - xla::ComputationBuilder condb(context->builder()->client(), - "GatherWhileCond"); - condb.Lt(condb.GetTupleElement( - condb.Parameter(0, tuple_shape, "GatherWhileTuple"), 0), - XlaHelpers::IntegerLiteral(&condb, index_type, num_indices)); - auto cond_status = condb.Build(); - auto cond = cond_status.ConsumeValueOrDie(); + auto init = {input, flat_indices, init_out}; // Construct the while loop body's function. The implementation of gather is: // for i in range(num_indices): // index = dynamic-slice(indices, i) // xi = dynamic-slice(input, index) // output = dynamic-update-slice(output, xi, i) - xla::ComputationBuilder bodyb(context->builder()->client(), - "GatherWhileBody"); - { - // The four loop carried values. - auto loop_tuple = bodyb.Parameter(0, tuple_shape, "GatherWhileTuple"); - auto i = bodyb.GetTupleElement(loop_tuple, 0); - auto input = bodyb.GetTupleElement(loop_tuple, 1); - auto indices = bodyb.GetTupleElement(loop_tuple, 2); - auto output = bodyb.GetTupleElement(loop_tuple, 3); - - // Slice from the input array. 
- auto index = bodyb.DynamicSlice(indices, bodyb.Reshape(i, {1}), {1}); - auto start_indices = bodyb.Pad( - bodyb.Reshape(index, {1}), XlaHelpers::Zero(&bodyb, index_type), + auto body_fn = [&](xla::ComputationDataHandle i, + gtl::ArraySlice loop_vars, + xla::ComputationBuilder* bodyb) { + auto input = loop_vars[0]; + auto indices = loop_vars[1]; + auto output = loop_vars[2]; + + auto zero_index = XlaHelpers::Zero(bodyb, index_type); + + // Slice the i-th index from the indices array. + xla::ComputationDataHandle index; + auto indices_offset = bodyb->Reshape(i, {1}); + if (indices_are_nd) { + // Slice out the entire nd index, if applicable. + indices_offset = bodyb->Pad(indices_offset, zero_index, + xla::MakeEdgePaddingConfig({{0, 1}})); + index = bodyb->DynamicSlice(indices, indices_offset, {1, num_index_dims}); + index = bodyb->Collapse(index, {0, 1}); + } else { + index = bodyb->DynamicSlice(indices, indices_offset, {1}); + } + + // Slice the corresponding data from the input array. + auto start_indices = bodyb->Pad( + index, zero_index, xla::MakeEdgePaddingConfig( {{input_shape_pre_axis.dims(), input_shape_post_axis.dims()}})); - auto slice_i = bodyb.Reshape( - bodyb.DynamicSlice(input, start_indices, slice_shape.dim_sizes()), + auto slice_i = bodyb->Reshape( + bodyb->DynamicSlice(input, start_indices, slice_shape.dim_sizes()), loop_out_slice_shape.dim_sizes()); - // Construct the index into the R3+ output Tensor 0, ..., , 0, ... + // Construct the index into the output Tensor 0, ..., , 0, ... 
std::vector out_index_vals( - loop_out_shape.dims(), - bodyb.Reshape(XlaHelpers::Zero(&bodyb, index_type), {1})); - out_index_vals[input_shape_pre_axis.dims() + extra_dims] = - bodyb.Reshape(i, {1}); - auto out_index = bodyb.ConcatInDim(out_index_vals, 0); + loop_out_shape.dims(), bodyb->Reshape(zero_index, {1})); + out_index_vals[input_shape_pre_axis.dims()] = bodyb->Reshape(i, {1}); + auto out_index = bodyb->ConcatInDim(out_index_vals, 0); // Update the output Tensor - auto updated_output = bodyb.DynamicUpdateSlice(output, slice_i, out_index); + auto updated_output = bodyb->DynamicUpdateSlice(output, slice_i, out_index); - bodyb.Tuple({bodyb.Add(i, XlaHelpers::One(&bodyb, index_type)), input, - indices, updated_output}); - } - auto body_status = bodyb.Build(); - auto body = body_status.ConsumeValueOrDie(); + return std::vector{input, indices, + updated_output}; + }; // Construct the While loop, extract and reshape the output. - auto gather_while = builder->While(cond, body, init); - auto gather_output = builder->GetTupleElement(gather_while, 3); - return builder->Reshape(gather_output, out_shape.dim_sizes()); + xla::PrimitiveType ptype; + TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(index_type, &ptype)); + TF_ASSIGN_OR_RETURN(auto outputs, XlaForEachIndex(num_indices, ptype, body_fn, + init, "gather", builder)); + *gather_output = builder->Reshape(outputs[2], out_shape.dim_sizes()); + return Status::OK(); } -GatherOpDynamicSlice::GatherOpDynamicSlice(OpKernelConstruction* context) - : XlaOpKernel(context) {} - -void GatherOpDynamicSlice::Compile(XlaOpKernelContext* context) { - xla::ComputationBuilder* builder = context->builder(); - auto input = context->Input(0); - auto input_shape = context->InputShape(0); - auto indices = context->Input(1); - auto indices_shape = context->InputShape(1); - int64 axis = 0; - if (context->num_inputs() == 3) { - const TensorShape axis_shape = context->InputShape(2); - OP_REQUIRES(context, TensorShapeUtils::IsScalar(axis_shape), - 
errors::InvalidArgument("axis must be scalar")); - DataType axis_type = input_type(2); - OP_REQUIRES(context, axis_type == DT_INT32 || axis_type == DT_INT64, - errors::InvalidArgument("axis must be int32 or int64")); - - OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &axis)); - const auto params_dims = input_shape.dims(); - if (axis < 0) { - axis += params_dims; +class GatherOp : public XlaOpKernel { + public: + explicit GatherOp(OpKernelConstruction* context) : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + xla::ComputationBuilder* builder = context->builder(); + auto input = context->Input(0); + auto input_shape = context->InputShape(0); + auto indices = context->Input(1); + auto indices_shape = context->InputShape(1); + int64 axis = 0; + if (context->num_inputs() == 3) { + const TensorShape axis_shape = context->InputShape(2); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(axis_shape), + errors::InvalidArgument("axis must be scalar")); + DataType axis_type = input_type(2); + OP_REQUIRES(context, axis_type == DT_INT32 || axis_type == DT_INT64, + errors::InvalidArgument("axis must be int32 or int64")); + + OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &axis)); + const auto params_dims = input_shape.dims(); + if (axis < 0) { + axis += params_dims; + } + OP_REQUIRES( + context, 0 <= axis && axis < params_dims, + errors::InvalidArgument("Expected axis in the range [", -params_dims, + ", ", params_dims, "), but got ", axis)); } - OP_REQUIRES( - context, 0 <= axis && axis < params_dims, - errors::InvalidArgument("Expected axis in the range [", -params_dims, - ", ", params_dims, "), but got ", axis)); - } - DataType index_type = input_type(1); - OP_REQUIRES(context, index_type == DT_INT32 || index_type == DT_INT64, - errors::InvalidArgument("indices must be int32 or int64")); + DataType index_type = input_type(1); + OP_REQUIRES(context, index_type == DT_INT32 || index_type == DT_INT64, + 
errors::InvalidArgument("indices must be int32 or int64")); - xla::ComputationDataHandle gather = XlaComputeGatherDynamicSlice( - context, input, input_shape, indices, indices_shape, axis, input_type(0), - index_type, builder); - context->SetOutput(0, gather); -} + xla::ComputationDataHandle gather; + OP_REQUIRES_OK( + context, XlaGather(input, input_shape, indices, indices_shape, axis, + /*indices_are_nd=*/false, input_type(0), index_type, + builder, &gather)); + context->SetOutput(0, gather); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(GatherOp); +}; + +REGISTER_XLA_OP(Name("Gather"), GatherOp); +REGISTER_XLA_OP(Name("GatherV2").CompileTimeConstInput("axis"), GatherOp); + +class GatherNdOp : public XlaOpKernel { + public: + explicit GatherNdOp(OpKernelConstruction* context) : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + DataType params_type = context->input_type(0); + DataType indices_type = context->input_type(1); + + TensorShape params_shape = context->InputShape(0); + TensorShape indices_shape = context->InputShape(1); + OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(params_shape), + errors::InvalidArgument("params must be at least a vector")); + OP_REQUIRES(context, TensorShapeUtils::IsVectorOrHigher(indices_shape), + errors::InvalidArgument("indices must be at least a vector")); + const int64 num_index_dims = + indices_shape.dim_size(indices_shape.dims() - 1); + OP_REQUIRES( + context, num_index_dims <= params_shape.dims(), + errors::InvalidArgument( + "index innermost dimension length must be <= params rank; saw: ", + indices_shape.dim_size(indices_shape.dims() - 1), " vs. 
", + params_shape.dims())); + + xla::ComputationBuilder* builder = context->builder(); + auto params = context->Input(0); + auto indices = context->Input(1); + xla::ComputationDataHandle gather; + OP_REQUIRES_OK(context, XlaGather(params, params_shape, indices, + indices_shape, /*axis=*/0, + /*indices_are_nd=*/true, params_type, + indices_type, builder, &gather)); + context->SetOutput(0, gather); + } +}; -REGISTER_XLA_OP(Name("Gather"), GatherOpDynamicSlice); -REGISTER_XLA_OP(Name("GatherV2"), GatherOpDynamicSlice); +REGISTER_XLA_OP(Name("GatherNd"), GatherNdOp); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h index 2c80395c56d73adad7dc1679ba6423fbe103605a..bd8b92c22d71fe89ab8951ec79f411feef6505e3 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h +++ b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h @@ -30,11 +30,16 @@ namespace tensorflow { // shape input_shape) keyed on indices (of shape indices_shape). // // index_type must be must be DT_INT32 or DT_INT64. -xla::ComputationDataHandle XlaComputeGatherDynamicSlice( - XlaOpKernelContext* ctx, const xla::ComputationDataHandle& input, - const TensorShape& input_shape, const xla::ComputationDataHandle& indices, - const TensorShape& indices_shape, int64 axis, DataType dtype, - DataType index_type, xla::ComputationBuilder* builder); +// If `indices_are_nd` is true, the last dimension of `indices` are treated as +// a multidimensional index values. Otherwise, `indices` is treated as a tensor +// of scalar indices. 
+Status XlaGather(const xla::ComputationDataHandle& input, + const TensorShape& input_shape, + const xla::ComputationDataHandle& indices, + TensorShape indices_shape, int64 axis, bool indices_are_nd, + DataType dtype, DataType index_type, + xla::ComputationBuilder* builder, + xla::ComputationDataHandle* gather_output); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/image_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..f22f384256a8ddd8c05de4a1322aba741dc4d7fd --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/image_ops.cc @@ -0,0 +1,305 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" + +namespace tensorflow { +namespace { + +// Converts 'input' from RGB format to HSV format. +// 'shape' is the shape of the red/green/blue tensors. 
+// 'rgb' holds the red, green and blue planes, in that order. The returned
+// array holds hue, saturation and value, in that order. 'ctx' is not used.
+std::array<xla::ComputationDataHandle, 3> RGBToHSV(
+    XlaOpKernelContext* ctx, xla::ComputationBuilder* b,
+    const std::array<xla::ComputationDataHandle, 3>& rgb, DataType dtype,
+    const TensorShape& shape) {
+  auto zero = XlaHelpers::Zero(b, dtype);
+  auto one = XlaHelpers::One(b, dtype);
+
+  auto red = rgb[0];
+  auto green = rgb[1];
+  auto blue = rgb[2];
+  // Value is the largest channel; range is the spread between the largest and
+  // the smallest channel.
+  auto value = b->Max(b->Max(red, green), blue);
+  auto minimum = b->Min(b->Min(red, green), blue);
+  auto range = b->Sub(value, minimum);
+
+  // Saturation is range / value where value > 0, and zero elsewhere.
+  auto zeros = b->Broadcast(zero, shape.dim_sizes());
+  auto saturation = b->Select(b->Gt(value, zero), b->Div(range, value), zeros);
+
+  // norm is computed for every element, including those where range is zero;
+  // the hue of those elements is replaced by zero further down.
+  auto norm = b->Div(XlaHelpers::FloatLiteral(b, dtype, 1.0 / 6.0), range);
+
+  // Pick the hue formula by which channel equals the maximum: green first,
+  // blue as the fallback, then red overrides both when it is the maximum.
+  auto hue = b->Select(b->Eq(green, value),
+                       b->Add(b->Mul(norm, b->Sub(blue, red)),
+                              XlaHelpers::FloatLiteral(b, dtype, 2.0 / 6.0)),
+                       b->Add(b->Mul(norm, b->Sub(red, green)),
+                              XlaHelpers::FloatLiteral(b, dtype, 4.0 / 6.0)));
+  hue = b->Select(b->Eq(red, value), b->Mul(norm, b->Sub(green, blue)), hue);
+  // Elements with no spread between channels get hue zero.
+  hue = b->Select(b->Gt(range, zero), hue, zeros);
+  // Wrap negative hues by adding one.
+  hue = b->Select(b->Lt(hue, zero), b->Add(hue, one), hue);
+  return {hue, saturation, value};
+}
+
+// Converts 'input' from HSV format to RGB format.
+std::array HSVToRGB( + xla::ComputationBuilder* b, + const std::array& hsv, DataType dtype) { + xla::ComputationDataHandle hue = hsv[0]; + xla::ComputationDataHandle saturation = hsv[1]; + xla::ComputationDataHandle value = hsv[2]; + auto zero = XlaHelpers::Zero(b, dtype); + auto one = XlaHelpers::FloatLiteral(b, dtype, 1.0); + auto two = XlaHelpers::FloatLiteral(b, dtype, 2.0); + auto three = XlaHelpers::FloatLiteral(b, dtype, 3.0); + auto four = XlaHelpers::FloatLiteral(b, dtype, 4.0); + auto six = XlaHelpers::FloatLiteral(b, dtype, 6.0); + + auto dh = b->Mul(hue, six); + auto dr = b->Clamp(zero, b->Sub(b->Abs(b->Sub(dh, three)), one), one); + auto dg = b->Clamp(zero, b->Sub(two, b->Abs(b->Sub(dh, two))), one); + auto db = b->Clamp(zero, b->Sub(two, b->Abs(b->Sub(dh, four))), one); + auto one_minus_s = b->Sub(one, saturation); + + auto red = b->Mul(b->Add(one_minus_s, b->Mul(saturation, dr)), value); + auto green = b->Mul(b->Add(one_minus_s, b->Mul(saturation, dg)), value); + auto blue = b->Mul(b->Add(one_minus_s, b->Mul(saturation, db)), value); + return {red, green, blue}; +} + +class RGBToHSVOp : public XlaOpKernel { + public: + explicit RGBToHSVOp(OpKernelConstruction* context) : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + const TensorShape input_shape = context->InputShape(0); + OP_REQUIRES(context, input_shape.dims() >= 1, + errors::InvalidArgument("input must be at least 1D", + input_shape.DebugString())); + int channel_dim = input_shape.dims() - 1; + int64 channels = input_shape.dim_size(channel_dim); + OP_REQUIRES( + context, channels == 3, + errors::FailedPrecondition("input must have 3 channels but input has ", + channels, " channels.")); + + xla::ComputationBuilder* b = context->builder(); + xla::ComputationDataHandle input = context->Input(0); + + xla::ComputationDataHandle red = + b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1, + /*dimno=*/channel_dim); + xla::ComputationDataHandle 
green = + b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1, + /*dimno=*/channel_dim); + xla::ComputationDataHandle blue = + b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1, + /*dimno=*/channel_dim); + TensorShape channel_shape = input_shape; + channel_shape.set_dim(channel_dim, 1); + auto hsv = RGBToHSV(context, b, {red, green, blue}, context->input_type(0), + channel_shape); + + context->SetOutput(0, b->ConcatInDim(hsv, channel_dim)); + } +}; +REGISTER_XLA_OP(Name("RGBToHSV"), RGBToHSVOp); + +class HSVToRGBOp : public XlaOpKernel { + public: + explicit HSVToRGBOp(OpKernelConstruction* context) : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + const TensorShape input_shape = context->InputShape(0); + OP_REQUIRES(context, input_shape.dims() >= 1, + errors::InvalidArgument("input must be at least 1D", + input_shape.DebugString())); + int channel_dim = input_shape.dims() - 1; + int64 channels = input_shape.dim_size(channel_dim); + OP_REQUIRES( + context, channels == 3, + errors::FailedPrecondition("input must have 3 channels but input has ", + channels, " channels.")); + + xla::ComputationBuilder* b = context->builder(); + xla::ComputationDataHandle input = context->Input(0); + xla::ComputationDataHandle hue = + b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1, + /*dimno=*/channel_dim); + xla::ComputationDataHandle saturation = + b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1, + /*dimno=*/channel_dim); + xla::ComputationDataHandle value = + b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1, + /*dimno=*/channel_dim); + + auto rgb = HSVToRGB(context->builder(), {hue, saturation, value}, + context->input_type(0)); + + context->SetOutput(0, b->ConcatInDim(rgb, channel_dim)); + } +}; +REGISTER_XLA_OP(Name("HSVToRGB"), HSVToRGBOp); + +class AdjustContrastOpV2 : public XlaOpKernel { + public: + explicit 
AdjustContrastOpV2(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + const TensorShape& input_shape = context->InputShape(0); + const TensorShape& factor_shape = context->InputShape(1); + OP_REQUIRES(context, input_shape.dims() >= 3, + errors::InvalidArgument("input must be at least 3-D, got shape", + input_shape.DebugString())); + int height_dim = input_shape.dims() - 3; + int width_dim = input_shape.dims() - 2; + int channel_dim = input_shape.dims() - 1; + const int64 height = input_shape.dim_size(height_dim); + const int64 width = input_shape.dim_size(width_dim); + + OP_REQUIRES(context, TensorShapeUtils::IsScalar(factor_shape), + errors::InvalidArgument("contrast_factor must be scalar: ", + factor_shape.DebugString())); + + xla::ComputationBuilder* b = context->builder(); + xla::ComputationDataHandle input = context->Input(0); + xla::ComputationDataHandle factor = context->Input(1); + + DataType type = context->input_type(0); + + auto output = b->Reduce(input, /*init_value=*/XlaHelpers::Zero(b, type), + /*computation=*/*context->GetOrCreateAdd(type), + {height_dim, width_dim}); + output = b->Div(output, XlaHelpers::FloatLiteral(b, type, height * width)); + + std::vector broadcast_dims(input_shape.dims() - 2); + std::iota(broadcast_dims.begin(), broadcast_dims.end(), 0); + broadcast_dims.back() = channel_dim; + output = b->Add(b->Mul(input, factor), + b->Mul(output, b->Sub(XlaHelpers::One(b, type), factor)), + broadcast_dims); + context->SetOutput(0, output); + } +}; +REGISTER_XLA_OP(Name("AdjustContrastv2"), AdjustContrastOpV2); + +class AdjustSaturationOp : public XlaOpKernel { + public: + explicit AdjustSaturationOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + const TensorShape& input_shape = context->InputShape(0); + const TensorShape& scale_shape = context->InputShape(1); + OP_REQUIRES(context, input_shape.dims() >= 3, 
+ errors::InvalidArgument("input must be at least 3-D, got shape", + input_shape.DebugString())); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(scale_shape), + errors::InvalidArgument("scale must be scalar: ", + scale_shape.DebugString())); + const int channel_dim = input_shape.dims() - 1; + const int64 channels = input_shape.dim_size(channel_dim); + OP_REQUIRES( + context, channels == 3, + errors::InvalidArgument("input must have 3 channels but instead has ", + channels, " channels.")); + + xla::ComputationBuilder* b = context->builder(); + xla::ComputationDataHandle input = context->Input(0); + xla::ComputationDataHandle scale = context->Input(1); + + DataType type = context->input_type(0); + + xla::ComputationDataHandle red = + b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1, + /*dimno=*/channel_dim); + xla::ComputationDataHandle green = + b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1, + /*dimno=*/channel_dim); + xla::ComputationDataHandle blue = + b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1, + /*dimno=*/channel_dim); + TensorShape channel_shape = input_shape; + channel_shape.set_dim(channel_dim, 1); + auto hsv = RGBToHSV(context, b, {red, green, blue}, context->input_type(0), + channel_shape); + + hsv[1] = b->Clamp(XlaHelpers::Zero(b, type), b->Mul(hsv[1], scale), + XlaHelpers::One(b, type)); + + auto rgb = HSVToRGB(context->builder(), hsv, context->input_type(0)); + + context->SetOutput(0, b->ConcatInDim(rgb, channel_dim)); + } +}; +REGISTER_XLA_OP(Name("AdjustSaturation"), AdjustSaturationOp); + +class AdjustHueOp : public XlaOpKernel { + public: + explicit AdjustHueOp(OpKernelConstruction* context) : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + const TensorShape& input_shape = context->InputShape(0); + const TensorShape& delta_shape = context->InputShape(1); + OP_REQUIRES(context, input_shape.dims() >= 3, + errors::InvalidArgument("input 
must be at least 3-D, got shape", + input_shape.DebugString())); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(delta_shape), + errors::InvalidArgument("delta must be scalar: ", + delta_shape.DebugString())); + const int channel_dim = input_shape.dims() - 1; + const int64 channels = input_shape.dim_size(channel_dim); + OP_REQUIRES( + context, channels == 3, + errors::InvalidArgument("input must have 3 channels but instead has ", + channels, " channels.")); + + xla::ComputationBuilder* b = context->builder(); + xla::ComputationDataHandle input = context->Input(0); + xla::ComputationDataHandle delta = context->Input(1); + + DataType type = context->input_type(0); + + xla::ComputationDataHandle red = + b->SliceInDim(input, /*start_index=*/0, /*limit_index=*/1, /*stride=*/1, + /*dimno=*/channel_dim); + xla::ComputationDataHandle green = + b->SliceInDim(input, /*start_index=*/1, /*limit_index=*/2, /*stride=*/1, + /*dimno=*/channel_dim); + xla::ComputationDataHandle blue = + b->SliceInDim(input, /*start_index=*/2, /*limit_index=*/3, /*stride=*/1, + /*dimno=*/channel_dim); + TensorShape channel_shape = input_shape; + channel_shape.set_dim(channel_dim, 1); + auto hsv = RGBToHSV(context, b, {red, green, blue}, context->input_type(0), + channel_shape); + + auto zero = XlaHelpers::Zero(b, type); + auto one = XlaHelpers::One(b, type); + + auto& hue = hsv[0]; + hue = b->Rem(b->Add(hsv[0], delta), one); + hue = b->Select(b->Lt(hue, zero), b->Rem(b->Add(one, hue), one), hue); + + auto rgb = HSVToRGB(context->builder(), hsv, context->input_type(0)); + + context->SetOutput(0, b->ConcatInDim(rgb, channel_dim)); + } +}; +REGISTER_XLA_OP(Name("AdjustHue"), AdjustHueOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..f36b3f594826c27b7866d956c855aa3638db9cb4 --- /dev/null +++ 
b/tensorflow/compiler/tf2xla/kernels/image_resize_ops.cc @@ -0,0 +1,449 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/array4d.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/lib/math/math_util.h" + +namespace tensorflow { +namespace { + +// We implement bilinear interpolation by upsampling followed by convolution. +// The basic idea is as follows. To scale from NxN to RxR: +// +// 1. S := (N - 1) / gcd(N-1, R-1) +// 2. k := (R - 1) / gcd(N-1, R-1) +// 3. Convolution(kxk, stride=S, lhs_dilation=k, padding=k-1) +// +// For example, to scale from 7x7 -> 15x15: +// +// 1. S := (7-1) / gcd(7-1, 15-1) = 6 / gcd(6, 14) = 6 / 2 = 3 +// 2. k := (15 - 1) / gcd(7-1, 15-1) = 14 / gcd(6, 14) = 14 / 2 = 7 +// 3. Convolution(7x7, stride=3, lhs_dilation=7, padding=6) +// +// +// The 7x7 -> 15x15 case is much too large to write out in full as an +// example. The smallest interesting example is 3x3 -> 4x4. 
+// +// S := 2 +// k := 3 +// +// 00 03 06 00 00 00 00 00 00 00 00 00 00 00 00 02 04 06 +// 09 12 15 -> 00 00 00 00 00 00 00 00 00 00 00 -> 06 08 10 12 +// 18 21 24 00 00 00 00 00 03 00 00 06 00 00 12 14 16 18 +// 00 00 00 00 00 00 00 00 00 00 00 18 20 22 24 +// 00 00 00 00 00 00 00 00 00 00 00 +// 00 00 09 00 00 12 00 00 15 00 00 +// 00 00 00 00 00 00 00 00 00 00 00 +// 00 00 00 00 00 00 00 00 00 00 00 +// 00 00 18 00 00 21 00 00 24 00 00 +// 00 00 00 00 00 00 00 00 00 00 00 +// 00 00 00 00 00 00 00 00 00 00 00 +// +// with the following convolutional kernel, with stride [2, 2]: +// 1 2 3 2 1 +// 2 4 6 4 2 +// 1/9 * 3 6 9 6 3 +// 2 4 6 4 2 +// 1 2 3 2 1 + +// Computes the size of the convolutional kernel and stride to use when resizing +// from in_size to out_size. +struct ResizeConvolutionDims { + // Size of the kernel to use. + std::vector kernel_size; + + // Stride of the convolution to use. + std::vector stride; +}; +ResizeConvolutionDims ComputeResizeConvolutionParameters( + gtl::ArraySlice in_size, gtl::ArraySlice out_size) { + CHECK_EQ(in_size.size(), out_size.size()); + int num_spatial_dims = in_size.size(); + ResizeConvolutionDims dims; + dims.kernel_size.resize(num_spatial_dims); + dims.stride.resize(num_spatial_dims); + for (int i = 0; i < num_spatial_dims; ++i) { + if (in_size[i] == 1) { + // We must handle input size 1 specially because XLA convolution does + // not allow stride 0. + dims.stride[i] = dims.kernel_size[i] = 1; + } else if (out_size[i] == 1) { + // If in_size[i] > 1 but out_size[i] == 1, then we slice out the first + // entry before resizing. 
+ dims.stride[i] = dims.kernel_size[i] = 1; + } else { + int64 gcd = MathUtil::GCD(static_cast(in_size[i] - 1), + static_cast(out_size[i] - 1)); + dims.stride[i] = (in_size[i] - 1) / gcd; + dims.kernel_size[i] = (out_size[i] - 1) / gcd; + } + } + return dims; +} + +xla::ComputationDataHandle MakeBilinearResizeKernel( + xla::ComputationBuilder* builder, gtl::ArraySlice kernel_size, + int64 channels) { + // Form a 2D convolution kernel like: + // 1 2 3 2 1 + // 2 4 6 4 2 + // 1/9 * 3 6 9 6 3 + // 2 4 6 4 2 + // 1 2 3 2 1 + // by multiplying two 1D kernels of the form: + // 1/3 * [1 2 3 2 1] + auto make_1d_kernel = [](int64 n) { + std::vector kernel(n * 2 - 1); + for (int64 i = 0; i < n; ++i) { + float v = (i + 1.0f) / n; + kernel[i] = v; + kernel[n * 2 - 2 - i] = v; + } + return kernel; + }; + + xla::ComputationDataHandle channels_iota; + // DT_INT32 Iota will always return status::OK(). + TF_CHECK_OK( + XlaHelpers::Iota(builder, DataType::DT_INT32, channels, &channels_iota)); + + auto diag = builder->ConvertElementType( + builder->Eq( + builder->Broadcast(channels_iota, {2 * kernel_size[0] - 1, + 2 * kernel_size[1] - 1, channels}), + channels_iota, /*broadcast_dimensions=*/{2}), + xla::PrimitiveType::F32); + return builder->Mul( + builder->Mul(diag, + builder->ConstantR1(make_1d_kernel(kernel_size[1])), + /*broadcast_dimensions=*/{1}), + builder->ConstantR1(make_1d_kernel(kernel_size[0])), + /*broadcast_dimensions=*/{0}); +} + +xla::ComputationDataHandle ResizeUsingDilationAndConvolution( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& input, + const int num_spatial_dims, std::vector in_size, + std::vector out_size, const int64 channels) { + // Picture for a 1x3 to 1x4 resize: + // stride = 2, kernel size = 3 + // Input: + // 3 6 9 + // Input with dilation and padding: + // 0 0 3 0 0 6 0 0 9 0 0 + // Convolution kernel: + // 1/3 * [1 2 3 2 1] + // Output: + // 3 5 7 9 + xla::ConvolutionDimensionNumbers dimension_numbers; + 
dimension_numbers.set_input_batch_dimension(0); + dimension_numbers.set_output_batch_dimension(0); + dimension_numbers.set_input_feature_dimension(3); + dimension_numbers.set_output_feature_dimension(3); + for (int i = 0; i < num_spatial_dims; ++i) { + dimension_numbers.add_input_spatial_dimensions(1 + i); + dimension_numbers.add_output_spatial_dimensions(1 + i); + dimension_numbers.add_kernel_spatial_dimensions(i); + } + dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims); + dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims + 1); + + ResizeConvolutionDims dims = + ComputeResizeConvolutionParameters(in_size, out_size); + xla::ComputationDataHandle kernel = + MakeBilinearResizeKernel(builder, dims.kernel_size, channels); + xla::ComputationDataHandle output = builder->ConvGeneralDilated( + input, kernel, dims.stride, + /*padding=*/ + {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, + {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, + /*lhs_dilation=*/dims.kernel_size, + /*rhs_dilation=*/{1, 1}, dimension_numbers); + + // Add broadcasts to handle expanding from a size == 1 dimension to a + // size > 1 dimension. + for (int i = 0; i < num_spatial_dims; ++i) { + if (in_size[i] == 1 && out_size[i] > 1) { + output = builder->Add(output, builder->ConstantR1(out_size[i], 0), + /*broadcast_dimensions=*/{1 + i}); + } + } + return output; +} + +xla::ComputationDataHandle ResizeUsingDilationAndConvolutionGradOp( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& grad, + const int num_spatial_dims, std::vector in_size, + std::vector grad_size, const int64 channels) { + ResizeConvolutionDims dims = + ComputeResizeConvolutionParameters(in_size, grad_size); + + // To form the backward convolution, we keep the kernel unchanged (it is + // already symmetric) and swap the roles of strides and LHS dilation. 
+ xla::ConvolutionDimensionNumbers dimension_numbers; + dimension_numbers.set_input_batch_dimension(0); + dimension_numbers.set_output_batch_dimension(0); + dimension_numbers.set_input_feature_dimension(3); + dimension_numbers.set_output_feature_dimension(3); + for (int i = 0; i < num_spatial_dims; ++i) { + dimension_numbers.add_input_spatial_dimensions(1 + i); + dimension_numbers.add_output_spatial_dimensions(1 + i); + dimension_numbers.add_kernel_spatial_dimensions(i); + } + dimension_numbers.set_kernel_input_feature_dimension(num_spatial_dims); + dimension_numbers.set_kernel_output_feature_dimension(num_spatial_dims + 1); + xla::ComputationDataHandle kernel = + MakeBilinearResizeKernel(builder, dims.kernel_size, channels); + + // Broadcast the input kernel where the forward op expanded from a size == 1 + // dimension to a size > 1 dimension. This has the effect of summing the + // gradient contributions in that dimension. + for (int i = 0; i < num_spatial_dims; ++i) { + if (in_size[i] == 1 && grad_size[i] > 1) { + kernel = builder->Add(kernel, builder->ConstantR1(grad_size[i], 0), + /*broadcast_dimensions=*/{i}); + } + } + + xla::ComputationDataHandle output = builder->ConvGeneralDilated( + grad, kernel, /*window_strides=*/dims.kernel_size, + /*padding=*/ + {{dims.kernel_size[0] - 1, dims.kernel_size[0] - 1}, + {dims.kernel_size[1] - 1, dims.kernel_size[1] - 1}}, + /*lhs_dilation=*/dims.stride, + /*rhs_dilation=*/{1, 1}, dimension_numbers); + + // If in_size[i] > 1 and grad_size[i] == 1, pad the output in dimension i. + // Opposite of the slice performed by the forward op. 
+ xla::PaddingConfig padding = xla::MakeNoPaddingConfig(4); + bool pad_output = false; + for (int i = 0; i < num_spatial_dims; ++i) { + if (in_size[i] > 1 && grad_size[i] == 1) { + pad_output = true; + padding.mutable_dimensions(1 + i)->set_edge_padding_high(in_size[i] - 1); + } + } + if (pad_output) { + output = builder->Pad(output, builder->ConstantR0(0.0f), padding); + } + return output; +} + +class ResizeBilinearOp : public XlaOpKernel { + public: + explicit ResizeBilinearOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_)); + OP_REQUIRES( + ctx, align_corners_ == true, + errors::Unimplemented( + "ResizeBilinear with align_corners=False is not yet implemented")); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* b = ctx->builder(); + + TensorShape input_shape = ctx->InputShape(0); + OP_REQUIRES(ctx, input_shape.dims() == 4, + errors::InvalidArgument("input must be 4-dimensional", + input_shape.DebugString())); + const int64 batch = input_shape.dim_size(0); + std::vector in_size = {input_shape.dim_size(1), + input_shape.dim_size(2)}; + const int64 channels = input_shape.dim_size(3); + OP_REQUIRES(ctx, in_size[0] > 0 && in_size[1] > 0, + errors::InvalidArgument("input size must be positive, got [", + in_size[0], ",", in_size[1], "]")); + + std::vector out_size; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &out_size)); + OP_REQUIRES(ctx, out_size.size() == 2, + errors::InvalidArgument("output size must be length 2, got ", + out_size.size())); + OP_REQUIRES(ctx, out_size[0] > 0 && out_size[1] > 0, + errors::InvalidArgument("output size must be positive, got [", + out_size[0], ",", out_size[1], "]")); + + const int num_spatial_dims = 2; + + xla::ComputationDataHandle input = ctx->Input(0); + + // If in_size[i] > 1 and out_size[i] == 1, slice out the first input in + // dimension i. 
+ std::vector slice_size = in_size; + bool slice_input = false; + for (int i = 0; i < num_spatial_dims; ++i) { + if (in_size[i] > 1 && out_size[i] == 1) { + // If in_size[i] > 1 but out_size[i] == 1, then we slice out the first + // entry before resizing. + slice_input = true; + slice_size[i] = 1; + } + } + if (slice_input) { + input = b->Slice(input, {0, 0, 0, 0}, + {batch, slice_size[0], slice_size[1], channels}, + {1, 1, 1, 1}); + } + + // Output is always type float. + input = b->ConvertElementType(input, xla::F32); + + // Special Case: + // Instead of doing a ResizeUsingDilationAndConvolution directly, + // while (out_size[0]-1) = c * 2^x * (in_size[0]-1) for x>1 c>1, resize the + // image to 2*(in_size[0]-1)+1 x-times and then resize by scale c(int here). + // Instead of resizing directly we resize it iteratively. + // + // Since bilinear resize can be broken down as 2 sequential linear + // operations along different dimensions. + // Given sufficient numerical stability and a < e < c and b < f < d, resizing + // axb -> cxd is the same as resizing axb -> exf -> cxd. + // + // This makes the convolution kernels smaller and the operation faster. 
+ xla::ComputationDataHandle output = input; + while (in_size != out_size) { + if (in_size[0] != 1 && in_size[1] != 1) { + std::vector k = { + (static_cast(out_size[0]) - 1) / ((in_size[0] - 1) * 2), + (static_cast(out_size[1]) - 1) / ((in_size[1] - 1) * 2)}; + if ((k[0] == std::floor(k[0])) && (k[1] == std::floor(k[1])) && + k[0] > 1 && k[1] > 1) { + std::vector next_out_size = {(in_size[0] - 1) * 2 + 1, + (in_size[1] - 1) * 2 + 1}; + output = ResizeUsingDilationAndConvolution( + b, input, num_spatial_dims, in_size, next_out_size, channels); + input = output; + in_size = next_out_size; + } else { + output = ResizeUsingDilationAndConvolution( + b, input, num_spatial_dims, in_size, out_size, channels); + in_size = out_size; + } + } else { + output = ResizeUsingDilationAndConvolution(b, input, num_spatial_dims, + in_size, out_size, channels); + in_size = out_size; + } + } + + ctx->SetOutput(0, output); + } + + private: + bool align_corners_; +}; + +REGISTER_XLA_OP(Name("ResizeBilinear").CompileTimeConstInput("size"), + ResizeBilinearOp); + +class ResizeBilinearGradOp : public XlaOpKernel { + public: + explicit ResizeBilinearGradOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("align_corners", &align_corners_)); + OP_REQUIRES( + ctx, align_corners_ == true, + errors::Unimplemented("ResizeBilinearGrad with align_corners=False is " + "not yet implemented")); + + DataType output_dtype; + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &output_dtype)); + OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(output_dtype, &output_type_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + xla::ComputationBuilder* b = ctx->builder(); + + TensorShape input_shape = ctx->InputShape(1); + OP_REQUIRES(ctx, input_shape.dims() == 4, + errors::InvalidArgument("input must be 4-dimensional", + input_shape.DebugString())); + const int64 batch = input_shape.dim_size(0); + std::vector in_size = {input_shape.dim_size(1), + input_shape.dim_size(2)}; + const int64 
channels = input_shape.dim_size(3); + OP_REQUIRES(ctx, in_size[0] > 0 && in_size[1] > 0, + errors::InvalidArgument("input size must be positive, got [", + in_size[0], ",", in_size[1], "]")); + + TensorShape grad_shape = ctx->InputShape(0); + OP_REQUIRES(ctx, grad_shape.dims() == 4, + errors::InvalidArgument("gradient must be 4-dimensional", + grad_shape.DebugString())); + const int64 grad_batch = grad_shape.dim_size(0); + const std::vector grad_size = {grad_shape.dim_size(1), + grad_shape.dim_size(2)}; + const int64 grad_channels = grad_shape.dim_size(3); + OP_REQUIRES(ctx, batch == grad_batch, + errors::InvalidArgument( + "activations and gradients must have the same batch size (", + batch, " vs. ", grad_batch, ")")); + OP_REQUIRES(ctx, grad_size[0] > 0 && grad_size[1] > 0, + errors::InvalidArgument("gradient size must be positive, got [", + grad_size[0], ",", grad_size[1], "]")); + OP_REQUIRES( + ctx, channels == grad_channels, + errors::InvalidArgument( + "activations and gradients must have the same number of channels (", + channels, " vs. 
", grad_channels, ")")); + + const int num_spatial_dims = 2; + + xla::ComputationDataHandle grad = ctx->Input(0); + + xla::ComputationDataHandle output = grad; + while (in_size != grad_size) { + if (in_size[0] != 1 && in_size[1] != 1) { + std::vector k = { + (static_cast(grad_size[0]) - 1) / ((in_size[0] - 1) * 2), + (static_cast(grad_size[1]) - 1) / ((in_size[1] - 1) * 2)}; + if ((k[0] == std::floor(k[0])) && (k[1] == std::floor(k[1])) && + k[0] > 1 && k[1] > 1) { + std::vector next_grad_size = {(in_size[0] - 1) * 2 + 1, + (in_size[1] - 1) * 2 + 1}; + output = ResizeUsingDilationAndConvolutionGradOp( + b, grad, num_spatial_dims, in_size, next_grad_size, channels); + grad = output; + in_size = next_grad_size; + } else { + output = ResizeUsingDilationAndConvolutionGradOp( + b, grad, num_spatial_dims, in_size, grad_size, channels); + in_size = grad_size; + } + } else { + output = ResizeUsingDilationAndConvolutionGradOp( + b, grad, num_spatial_dims, in_size, grad_size, channels); + in_size = grad_size; + } + } + + output = b->ConvertElementType(output, output_type_); + ctx->SetOutput(0, output); + } + + private: + bool align_corners_; + xla::PrimitiveType output_type_; +}; + +REGISTER_XLA_OP(Name("ResizeBilinearGrad"), ResizeBilinearGradOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.cc b/tensorflow/compiler/tf2xla/kernels/index_ops.cc index e0dc1870f2a4934c35163f0cc10196e8fcbed9be..7bf4b435f526afa93d8a218b191928acb932cd6b 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops.cc @@ -80,7 +80,10 @@ void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) { XlaArgMaxOp::XlaArgMaxOp(OpKernelConstruction* ctx) : XlaArgMinMaxOp(ctx, /*is_min=*/false) {} -REGISTER_XLA_OP(Name("ArgMax").Device(DEVICE_GPU_XLA_JIT), XlaArgMaxOp); +REGISTER_XLA_OP(Name("ArgMax") + .Device(DEVICE_GPU_XLA_JIT) + .CompileTimeConstInput("dimension"), + XlaArgMaxOp); namespace { @@ 
-90,7 +93,7 @@ class XlaArgMinOp : public XlaArgMinMaxOp { }; XlaArgMinOp::XlaArgMinOp(OpKernelConstruction* ctx) : XlaArgMinMaxOp(ctx, /*is_min=*/true) {} -REGISTER_XLA_OP(Name("ArgMin"), XlaArgMinOp); +REGISTER_XLA_OP(Name("ArgMin").CompileTimeConstInput("dimension"), XlaArgMinOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc index 20946e247a9459d7c8a0d8a666fef24bd32838f2..b1f3c3c298ce0cadf38b9bda715761fe7e2896d7 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc @@ -56,10 +56,10 @@ class ArgMaxCustomCallOp : public XlaOpKernel { errors::InvalidArgument("dim must be < input rank (", input_shape.dims(), "), but got: ", dim)); const int64 dim_size = input_shape.dim_size(dim); - OP_REQUIRES( - ctx, dim_size > 0, - errors::InvalidArgument("Reduction axis ", dim, " is empty in shape: ", - input_shape.DebugString())); + OP_REQUIRES(ctx, dim_size > 0, + errors::InvalidArgument( + "Reduction axis ", dim, + " is empty in shape: ", input_shape.DebugString())); // The output shape is the input shape contracted along dim. 
TensorShape output_shape; @@ -113,9 +113,11 @@ class ArgMaxCustomCallOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(ArgMaxCustomCallOp); }; -REGISTER_XLA_OP( - Name("ArgMax").TypeConstraint("T", DT_FLOAT).Device(DEVICE_CPU_XLA_JIT), - ArgMaxCustomCallOp); +REGISTER_XLA_OP(Name("ArgMax") + .TypeConstraint("T", DT_FLOAT) + .Device(DEVICE_CPU_XLA_JIT) + .CompileTimeConstInput("dimension"), + ArgMaxCustomCallOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc index fcef497e5845d9080bc83b54e92dcf2fdecf5f12..886baf8115243a22b7255a3961c914d4cf6c2ed5 100644 --- a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc @@ -23,16 +23,18 @@ limitations under the License. namespace tensorflow { namespace { -constexpr std::array kMatmulTypes = { - {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64}}; +constexpr std::array kMatmulTypes = { + {DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64}}; class MatMulOp : public XlaOpKernel { public: explicit MatMulOp(OpKernelConstruction* ctx, bool is_sparse = false) - : XlaOpKernel(ctx) { + : XlaOpKernel(ctx), is_sparse_(is_sparse) { OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_a", &transpose_a_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("transpose_b", &transpose_b_)); if (is_sparse) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("Ta", &a_type_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("Tb", &b_type_)); // SparseMatMul is actually dense matmul with a hint that one or // both of the inputs may contain a lot of zeroes. 
On CPU these // inputs are dynamically converted to sparse representation @@ -66,14 +68,25 @@ class MatMulOp : public XlaOpKernel { xla::ComputationDataHandle a = ctx->Input(0); xla::ComputationDataHandle b = ctx->Input(1); + if (is_sparse_) { + if (a_type_ == DT_BFLOAT16) { + a = ctx->builder()->ConvertElementType(a, xla::F32); + } + if (b_type_ == DT_BFLOAT16) { + b = ctx->builder()->ConvertElementType(b, xla::F32); + } + } auto lhs = (transpose_a_) ? ctx->builder()->Transpose(a, {1, 0}) : a; auto rhs = (transpose_b_) ? ctx->builder()->Transpose(b, {1, 0}) : b; ctx->SetOutput(0, ctx->builder()->Dot(lhs, rhs)); } private: + bool is_sparse_; bool transpose_a_; bool transpose_b_; + DataType a_type_; + DataType b_type_; }; REGISTER_XLA_OP(Name("MatMul").TypeConstraint("T", kMatmulTypes), MatMulOp); @@ -85,10 +98,7 @@ class SparseMatMulOp : public MatMulOp { ~SparseMatMulOp() override = default; }; -REGISTER_XLA_OP(Name("SparseMatMul") - .TypeConstraint("Ta", kFloatTypes) - .TypeConstraint("Tb", kFloatTypes), - SparseMatMulOp); +REGISTER_XLA_OP(Name("SparseMatMul"), SparseMatMulOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..faa415a97b053b4b11d015fefcd430210b98118a --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/matrix_band_part_op.cc @@ -0,0 +1,98 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/tensor_shape.h" + +namespace tensorflow { +namespace { + +class MatrixBandPartOp : public XlaOpKernel { + public: + explicit MatrixBandPartOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + const TensorShape input_shape = context->InputShape(0); + // Preliminary validation of sizes. + OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape), + errors::InvalidArgument( + "input must be at least 2-dim, received shape: ", + input_shape.DebugString())); + + const TensorShape num_lower_in_shape = context->InputShape(1); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_lower_in_shape), + errors::InvalidArgument("num_lower must be scalar, got shape ", + num_lower_in_shape.DebugString())); + + const TensorShape num_upper_in_shape = context->InputShape(2); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_upper_in_shape), + errors::InvalidArgument("num_upper must be scalar, got shape ", + num_upper_in_shape.DebugString())); + + xla::ComputationBuilder* builder = context->builder(); + xla::ComputationDataHandle input = context->Input(0); + xla::ComputationDataHandle num_lower = context->Input(1); + xla::ComputationDataHandle num_upper = context->Input(2); + DataType input_type = context->input_type(0); + DataType index_type = context->input_type(1); + + TensorShape batch_shape = input_shape; + batch_shape.RemoveLastDims(2); + const int64 m = input_shape.dim_size(input_shape.dims() - 2); + const int64 n = input_shape.dim_size(input_shape.dims() - 1); + + // Compute 'offset', which is how many diagonals we 
are above/below the + // diagonal. + xla::ComputationDataHandle iota_m; + OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, m, &iota_m)); + + xla::ComputationDataHandle iota_n; + OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, index_type, n, &iota_n)); + + auto offset = builder->Sub(builder->Broadcast(iota_n, {m}), iota_m, + /*broadcast_dimensions=*/{0}); + + // If num_lower or num_upper are negative, include all lower/upper + // diagonals. + auto zero_index = XlaHelpers::Zero(builder, index_type); + num_lower = builder->Select( + builder->Lt(num_lower, zero_index), + XlaHelpers::IntegerLiteral(builder, index_type, m), num_lower); + num_upper = builder->Select( + builder->Lt(num_upper, zero_index), + XlaHelpers::IntegerLiteral(builder, index_type, n), num_upper); + + auto indicator = builder->And(builder->Le(builder->Neg(num_lower), offset), + builder->Le(offset, num_upper)); + indicator = builder->Broadcast(indicator, batch_shape.dim_sizes()); + + auto zero_input = XlaHelpers::Zero(builder, input_type); + auto output = builder->Select( + indicator, input, + builder->Broadcast(zero_input, input_shape.dim_sizes())); + + context->SetOutput(0, output); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(MatrixBandPartOp); +}; +REGISTER_XLA_OP(Name("MatrixBandPart"), MatrixBandPartOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..b2940bdcff75a087c914fdad0cb2426276e41aff --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/matrix_set_diag_op.cc @@ -0,0 +1,93 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" + +namespace tensorflow { + +class MatrixSetDiagOp : public XlaOpKernel { + public: + explicit MatrixSetDiagOp(OpKernelConstruction* context) + : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + const TensorShape input_shape = context->InputShape(0); + const TensorShape diag_shape = context->InputShape(1); + + const int rank = input_shape.dims(); + + // Preliminary validation of sizes. + OP_REQUIRES(context, TensorShapeUtils::IsMatrixOrHigher(input_shape), + errors::InvalidArgument( + "input must be at least 2-dim, received shape: ", + input_shape.DebugString())); + + // Check to make sure the last dimension of diag is equal to the smaller of + // the last two dimensions of input. 
+ const int64 m = input_shape.dim_size(rank - 2); + const int64 n = input_shape.dim_size(rank - 1); + const int64 min_dim = std::min(m, n); + + TensorShape batch_shape = input_shape; + batch_shape.RemoveLastDims(2); + + TensorShape expected_diag_shape = batch_shape; + expected_diag_shape.AddDim(min_dim); + OP_REQUIRES(context, expected_diag_shape == diag_shape, + errors::InvalidArgument( + "must have diagonal.shape == input.shape[:-2] + " + "min(input.shape[-2:]), but received input shape: ", + input_shape.DebugString(), + " and diagonal shape: ", diag_shape.DebugString())); + + xla::ComputationBuilder* builder = context->builder(); + xla::ComputationDataHandle input = context->Input(0); + xla::ComputationDataHandle diag = context->Input(1); + + auto zero = XlaHelpers::Zero(builder, context->input_type(0)); + + // Create an indicator tensor that is true only on the diagonal. + xla::ComputationDataHandle iota_m; + OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, m, &iota_m)); + xla::ComputationDataHandle iota_n; + OP_REQUIRES_OK(context, XlaHelpers::Iota(builder, DT_INT32, n, &iota_n)); + auto indicator = builder->Eq(iota_m, + builder->Broadcast(iota_n, {m}), + /*broadcast_dimensions=*/{0}); + indicator = builder->Broadcast(indicator, batch_shape.dim_sizes()); + + // Broadcast diag up to the input shape. Use an implicit broadcast (Add) + // because we need to broadcast on the right. 
+ std::vector diag_broadcast_dims(rank - 1); + std::iota(diag_broadcast_dims.begin(), diag_broadcast_dims.end(), 0); + if (min_dim != m) { + diag_broadcast_dims.back() = rank - 1; + } + diag = builder->Add(diag, builder->Broadcast(zero, input_shape.dim_sizes()), + /*broadcast_dimensions=*/diag_broadcast_dims); + + auto output = builder->Select(indicator, diag, input); + context->SetOutput(0, output); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(MatrixSetDiagOp); +}; + +REGISTER_XLA_OP(Name("MatrixSetDiag"), MatrixSetDiagOp); + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..eaed93146460de5a6e8328432302cc75bf36a534 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/matrix_triangular_solve_op.cc @@ -0,0 +1,50 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/triangular_solve.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" + +namespace tensorflow { +namespace { + +class MatrixTriangularSolveOp : public XlaOpKernel { + public: + explicit MatrixTriangularSolveOp(OpKernelConstruction* ctx) + : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("lower", &lower_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("adjoint", &adjoint_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + auto result = TriangularSolve( + ctx->builder(), ctx->Input(0), ctx->Input(1), /*left_side=*/true, + /*lower=*/lower_, /*transpose_a=*/adjoint_, /*conjugate_a=*/adjoint_); + if (!result.ok()) { + ctx->SetStatus(result.status()); + return; + } + ctx->SetOutput(0, result.ValueOrDie()); + } + + private: + bool lower_; + bool adjoint_; +}; + +REGISTER_XLA_OP(Name("MatrixTriangularSolve"), MatrixTriangularSolveOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc index bea1d1600b5b5fc0c44f0208d394f25061ecbb68..05a36a031ad73be289604da1b7e56203ff12fbf5 100644 --- a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc @@ -92,7 +92,8 @@ class MirrorPadOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(MirrorPadOp); }; -REGISTER_XLA_OP(Name("MirrorPad"), MirrorPadOp); +REGISTER_XLA_OP(Name("MirrorPad").CompileTimeConstInput("paddings"), + MirrorPadOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc b/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc index 2a9cfcb2eb86399bd446db8d591012a7a2f3d667..9f7c9913802d311895479b914b66553e135aa426 100644 --- a/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc +++ 
b/tensorflow/compiler/tf2xla/kernels/one_hot_op.cc @@ -76,7 +76,7 @@ class OneHotOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(OneHotOp); }; -REGISTER_XLA_OP(Name("OneHot"), OneHotOp); +REGISTER_XLA_OP(Name("OneHot").CompileTimeConstInput("depth"), OneHotOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/pad_op.cc b/tensorflow/compiler/tf2xla/kernels/pad_op.cc index d841bd37b33c31dbc156fa824ff62a58169a99cb..791351637aee61c5fdd911dd8a48959990514395 100644 --- a/tensorflow/compiler/tf2xla/kernels/pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/pad_op.cc @@ -83,8 +83,8 @@ class PadOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Pad"), PadOp); -REGISTER_XLA_OP(Name("PadV2"), PadOp); +REGISTER_XLA_OP(Name("Pad").CompileTimeConstInput("paddings"), PadOp); +REGISTER_XLA_OP(Name("PadV2").CompileTimeConstInput("paddings"), PadOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index 2b6053d19dd64a0c893b3613133c8f4691f9cd27..d4fb5dd4e06c7c70591262c0d63a91c383a2a6e0 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -37,21 +37,23 @@ class PoolingOp : public XlaOpKernel { public: PoolingOp(OpKernelConstruction* ctx, int num_spatial_dims) : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) { - std::vector ksize_int; - std::vector stride_int; - OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_int)); - OP_REQUIRES(ctx, ksize_int.size() == num_dims(), - errors::InvalidArgument("Sliding window ksize field must " - "specify ", - num_dims(), " dimensions")); - OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_int)); - OP_REQUIRES(ctx, stride_int.size() == num_dims(), - errors::InvalidArgument("Sliding window stride field must " - "specify ", - num_dims(), " dimensions")); - for (int i = 0; i < num_dims(); ++i) { - 
ksize_.push_back(ksize_int[i]); - stride_.push_back(stride_int[i]); + if (ctx->num_inputs() == 1) { + std::vector ksize_int; + std::vector stride_int; + OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_int)); + OP_REQUIRES(ctx, ksize_int.size() == num_dims(), + errors::InvalidArgument("Sliding window ksize field must " + "specify ", + num_dims(), " dimensions")); + OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_int)); + OP_REQUIRES(ctx, stride_int.size() == num_dims(), + errors::InvalidArgument("Sliding window stride field must " + "specify ", + num_dims(), " dimensions")); + for (int i = 0; i < num_dims(); ++i) { + ksize_.push_back(ksize_int[i]); + stride_.push_back(stride_int[i]); + } } Padding padding; OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding)); @@ -77,6 +79,33 @@ class PoolingOp : public XlaOpKernel { xla::ComputationDataHandle input = ctx->Input(0); const TensorShape input_shape = ctx->InputShape(0); + std::vector ksize = ksize_; + std::vector stride = stride_; + if (ctx->num_inputs() != 1) { + const TensorShape ksize_shape = ctx->InputShape(1); + // Validate input sizes. + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape), + errors::InvalidArgument("ksize must be a vector, not shape ", + ksize_shape.DebugString())); + OP_REQUIRES(ctx, ksize_shape.num_elements() == num_dims(), + errors::InvalidArgument("Sliding window ksize field must " + "specify ", + num_dims(), " dimensions")); + ksize.clear(); + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &ksize)); + + const TensorShape stride_shape = ctx->InputShape(2); + // Validate input sizes. 
+ OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape), + errors::InvalidArgument("stride must be a vector, not shape ", + stride_shape.DebugString())); + OP_REQUIRES(ctx, stride_shape.num_elements() == num_dims(), + errors::InvalidArgument("Sliding window stride field must " + "specify ", + num_dims(), " dimensions")); + stride.clear(); + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &stride)); + } OP_REQUIRES(ctx, input_shape.dims() == num_dims(), errors::InvalidArgument("Input to ", type_string(), " operator must have ", num_dims(), @@ -84,8 +113,8 @@ class PoolingOp : public XlaOpKernel { const DataType type = input_type(0); xla::ComputationDataHandle pooled = ctx->builder()->ReduceWindow( - input, InitValue(ctx->builder(), type), *Reduction(ctx, type), ksize_, - stride_, padding_); + input, InitValue(ctx->builder(), type), *Reduction(ctx, type), ksize, + stride, padding_); ctx->SetOutput(0, PostProcessOutput(ctx, pooled, type, input_shape)); } @@ -130,6 +159,10 @@ class MaxPool2DOp : public MaxPoolOp { } }; REGISTER_XLA_OP(Name("MaxPool"), MaxPool2DOp); +REGISTER_XLA_OP(Name("MaxPoolV2") + .CompileTimeConstInput("ksize") + .CompileTimeConstInput("strides"), + MaxPool2DOp); class MaxPool3DOp : public MaxPoolOp { public: @@ -243,22 +276,44 @@ class MaxPoolGradOp : public XlaOpKernel { public: MaxPoolGradOp(OpKernelConstruction* ctx, int num_spatial_dims) : XlaOpKernel(ctx), num_spatial_dims_(num_spatial_dims) { - OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_)); + if (ctx->num_inputs() == 3) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("ksize", &ksize_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_)); + } + OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); + } + + int num_dims() const { return num_spatial_dims_ + 2; } + + void Compile(XlaOpKernelContext* ctx) override { + if (ctx->num_inputs() != 3) { + OP_REQUIRES( + ctx, ctx->num_inputs() == 5, + errors::InvalidArgument("Must supply ksize and stride arguments.")); + const TensorShape 
ksize_shape = ctx->InputShape(3); + // Validate input sizes. + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ksize_shape), + errors::InvalidArgument("ksize must be a vector, not shape ", + ksize_shape.DebugString())); + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(3, &ksize_)); + + const TensorShape stride_shape = ctx->InputShape(4); + // Validate input sizes. + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(stride_shape), + errors::InvalidArgument("stride must be a vector, not shape ", + stride_shape.DebugString())); + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(4, &stride_)); + } + OP_REQUIRES(ctx, ksize_.size() == num_dims(), errors::InvalidArgument("Sliding window ksize field must " "specify ", num_dims(), " dimensions")); - OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &stride_)); OP_REQUIRES(ctx, stride_.size() == num_dims(), errors::InvalidArgument("Sliding window strides field must " "specify ", num_dims(), " dimensions")); - OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_)); - } - int num_dims() const { return num_spatial_dims_ + 2; } - - void Compile(XlaOpKernelContext* ctx) override { const TensorShape tensor_in_shape = ctx->InputShape(0); const TensorShape tensor_out_shape = ctx->InputShape(1); const TensorShape out_backprop_shape = ctx->InputShape(2); @@ -315,6 +370,10 @@ class MaxPool2DGradOp : public MaxPoolGradOp { } }; REGISTER_XLA_OP(Name("MaxPoolGrad"), MaxPool2DGradOp); +REGISTER_XLA_OP(Name("MaxPoolGradV2") + .CompileTimeConstInput("ksize") + .CompileTimeConstInput("strides"), + MaxPool2DGradOp); class MaxPool3DGradOp : public MaxPoolGradOp { public: @@ -455,14 +514,16 @@ class AvgPool2DGradOp : public AvgPoolGradOp { errors::InvalidArgument("Invalid data format")); } }; -REGISTER_XLA_OP(Name("AvgPoolGrad"), AvgPool2DGradOp); +REGISTER_XLA_OP(Name("AvgPoolGrad").CompileTimeConstInput("orig_input_shape"), + AvgPool2DGradOp); class AvgPool3DGradOp : public AvgPoolGradOp { public: explicit AvgPool3DGradOp(OpKernelConstruction* ctx) : 
AvgPoolGradOp(ctx, /*num_spatial_dims=*/3) {} }; -REGISTER_XLA_OP(Name("AvgPool3DGrad"), AvgPool3DGradOp); +REGISTER_XLA_OP(Name("AvgPool3DGrad").CompileTimeConstInput("orig_input_shape"), + AvgPool3DGradOp); } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 2421825ead17a3acee9f145f00904d382fb656f4..c0994c434bca5174eaee7b9e63e10432d9c2ed8d 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -52,7 +52,8 @@ class RandomUniformOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(RandomUniformOp); }; -REGISTER_XLA_OP(Name("RandomUniform"), RandomUniformOp); +REGISTER_XLA_OP(Name("RandomUniform").CompileTimeConstInput("shape"), + RandomUniformOp); class RandomUniformIntOp : public XlaOpKernel { public: @@ -83,7 +84,8 @@ class RandomUniformIntOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(RandomUniformIntOp); }; -REGISTER_XLA_OP(Name("RandomUniformInt"), RandomUniformIntOp); +REGISTER_XLA_OP(Name("RandomUniformInt").CompileTimeConstInput("shape"), + RandomUniformIntOp); class RandomStandardNormalOp : public XlaOpKernel { public: @@ -111,7 +113,8 @@ class RandomStandardNormalOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(RandomStandardNormalOp); }; -REGISTER_XLA_OP(Name("RandomStandardNormal"), RandomStandardNormalOp); +REGISTER_XLA_OP(Name("RandomStandardNormal").CompileTimeConstInput("shape"), + RandomStandardNormalOp); class TruncatedNormalOp : public XlaOpKernel { public: @@ -183,7 +186,8 @@ class TruncatedNormalOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("TruncatedNormal"), TruncatedNormalOp); +REGISTER_XLA_OP(Name("TruncatedNormal").CompileTimeConstInput("shape"), + TruncatedNormalOp); } // anonymous namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index 
647b6274083cf8886af6c451b746416445a4a2b2..03b13b2924f4b81c1017804c91d5ffb81c44ea0b 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -35,7 +35,7 @@ class SumOp : public XlaReductionOp { } }; -REGISTER_XLA_OP(Name("Sum"), SumOp); +REGISTER_XLA_OP(Name("Sum").CompileTimeConstInput("reduction_indices"), SumOp); class ProdOp : public XlaReductionOp { public: @@ -53,7 +53,8 @@ class ProdOp : public XlaReductionOp { } }; -REGISTER_XLA_OP(Name("Prod"), ProdOp); +REGISTER_XLA_OP(Name("Prod").CompileTimeConstInput("reduction_indices"), + ProdOp); class MinOp : public XlaReductionOp { public: @@ -73,7 +74,7 @@ class MinOp : public XlaReductionOp { } }; -REGISTER_XLA_OP(Name("Min"), MinOp); +REGISTER_XLA_OP(Name("Min").CompileTimeConstInput("reduction_indices"), MinOp); class MaxOp : public XlaReductionOp { public: @@ -93,7 +94,7 @@ class MaxOp : public XlaReductionOp { } }; -REGISTER_XLA_OP(Name("Max"), MaxOp); +REGISTER_XLA_OP(Name("Max").CompileTimeConstInput("reduction_indices"), MaxOp); class MeanOp : public XlaReductionOp { public: @@ -115,7 +116,8 @@ class MeanOp : public XlaReductionOp { } }; -REGISTER_XLA_OP(Name("Mean"), MeanOp); +REGISTER_XLA_OP(Name("Mean").CompileTimeConstInput("reduction_indices"), + MeanOp); class AllOp : public XlaReductionOp { public: @@ -133,7 +135,7 @@ class AllOp : public XlaReductionOp { } }; -REGISTER_XLA_OP(Name("All"), AllOp); +REGISTER_XLA_OP(Name("All").CompileTimeConstInput("reduction_indices"), AllOp); class AnyOp : public XlaReductionOp { public: @@ -151,7 +153,7 @@ class AnyOp : public XlaReductionOp { } }; -REGISTER_XLA_OP(Name("Any"), AnyOp); +REGISTER_XLA_OP(Name("Any").CompileTimeConstInput("reduction_indices"), AnyOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index 
5952e752724d1e6953dd4dbb6a8099b847c64d08..af4d64b159c09ed7e01017f25a2b23e58542dc3c 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -95,7 +95,7 @@ class ReshapeOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Reshape"), ReshapeOp); +REGISTER_XLA_OP(Name("Reshape").CompileTimeConstInput("shape"), ReshapeOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc index 7489321f72f50c8f55f8da9dabb9f4b5c7797195..e51d386926763ecbb5a943dfb6f872e78901dc69 100644 --- a/tensorflow/compiler/tf2xla/kernels/reverse_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reverse_op.cc @@ -16,7 +16,6 @@ limitations under the License. // XLA-specific reverse Op. #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -53,7 +52,8 @@ class ReverseOp : public XlaOpKernel { xla::Literal lax; OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {x_shape.dims()}, &lax)); std::vector revdims(x_shape.dims()); - std::copy(lax.preds().begin(), lax.preds().end(), revdims.begin()); + std::copy(lax.data().begin(), lax.data().end(), + revdims.begin()); std::vector dimensions; for (int d = 0; d < x_shape.dims(); ++d) { @@ -66,7 +66,7 @@ class ReverseOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Reverse"), ReverseOp); +REGISTER_XLA_OP(Name("Reverse").CompileTimeConstInput("dims"), ReverseOp); class ReverseV2Op : public XlaOpKernel { public: @@ -104,7 +104,7 @@ class ReverseV2Op : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("ReverseV2"), ReverseV2Op); +REGISTER_XLA_OP(Name("ReverseV2").CompileTimeConstInput("axis"), ReverseV2Op); } // namespace } // namespace tensorflow diff --git 
a/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..6bc5d3adb091cd238974c5b69b7a2f8fe639cc68 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/reverse_sequence_op.cc @@ -0,0 +1,182 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/tensor_shape.h" + +namespace tensorflow { +namespace { + +class ReverseSequenceOp : public XlaOpKernel { + public: + explicit ReverseSequenceOp(OpKernelConstruction* context) + : XlaOpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("batch_dim", &batch_dim_)); + OP_REQUIRES_OK(context, context->GetAttr("seq_dim", &seq_dim_)); + } + + void Compile(XlaOpKernelContext* context) override { + const TensorShape input_shape = context->InputShape(0); + const TensorShape seq_lens_shape = context->InputShape(1); + + OP_REQUIRES(context, TensorShapeUtils::IsVector(seq_lens_shape), + errors::InvalidArgument("seq_lens input must be 1-dim, not ", + seq_lens_shape.dims())); + OP_REQUIRES(context, batch_dim_ != seq_dim_, + 
errors::InvalidArgument("batch_dim == seq_dim == ", seq_dim_)); + OP_REQUIRES( + context, seq_dim_ < input_shape.dims(), + errors::InvalidArgument("seq_dim must be < input.dims()", "( ", + seq_dim_, " vs. ", input_shape.dims(), ")")); + OP_REQUIRES( + context, batch_dim_ < input_shape.dims(), + errors::InvalidArgument("batch_dim must be < input.dims()", "( ", + batch_dim_, " vs. ", input_shape.dims(), ")")); + OP_REQUIRES( + context, + seq_lens_shape.num_elements() == input_shape.dim_size(batch_dim_), + errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim_, + "), ", "(", seq_lens_shape.num_elements(), + " vs. ", input_shape.dim_size(batch_dim_))); + + xla::ComputationBuilder* builder = context->builder(); + const auto input = context->Input(0); + const auto seq_lens = context->Input(1); + + const int64 batch_size = input_shape.dim_size(batch_dim_); + + const DataType input_type = context->input_type(0); + const DataType seq_lens_type = context->input_type(1); + const int64 max_seq_len = input_shape.dim_size(seq_dim_); + + xla::Shape input_xla_shape; + OP_REQUIRES_OK(context, TensorShapeToXLAShape(input_type, input_shape, + &input_xla_shape)); + xla::Shape seq_lens_xla_shape; + OP_REQUIRES_OK(context, TensorShapeToXLAShape(seq_lens_type, seq_lens_shape, + &seq_lens_xla_shape)); + + const auto tuple_shape = xla::ShapeUtil::MakeTupleShape({ + xla::ShapeUtil::MakeShape(seq_lens_xla_shape.element_type(), {}), + seq_lens_xla_shape, + input_xla_shape, + }); + + // For each entry in the batch, reverse the sequence. + // TODO(b/65689298): generalize the Map() operator to non-scalar cases and + // use it here, instead of a While loop. 
+ + // Condition: lambda (i, _, _): i < batch_size + auto condition_builder = + builder->CreateSubBuilder("reverse_sequence_condition"); + { + auto param = condition_builder->Parameter(0, tuple_shape, "param"); + auto i = condition_builder->GetTupleElement(param, 0); + condition_builder->Lt( + i, XlaHelpers::IntegerLiteral(condition_builder.get(), seq_lens_type, + batch_size)); + } + auto condition = condition_builder->Build(); + OP_REQUIRES_OK(context, condition.status()); + + auto body_builder = builder->CreateSubBuilder("reverse_sequence_body"); + { + auto param = body_builder->Parameter(0, tuple_shape, "param"); + auto i = body_builder->GetTupleElement(param, 0); + auto seq_lens = body_builder->GetTupleElement(param, 1); + auto output = body_builder->GetTupleElement(param, 2); + + // seq_len is the sequence length of the current batch element (rank 1) + auto seq_len = body_builder->DynamicSlice( + seq_lens, body_builder->Reshape(i, {1}), {1}); + + // Indices is the offset of the batch element in the input. + auto indices = body_builder->Broadcast( + XlaHelpers::Zero(body_builder.get(), seq_lens_type), + {input_shape.dims()}); + indices = body_builder->DynamicUpdateSlice( + indices, body_builder->Reshape(i, {1}), + body_builder->Reshape( + XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type, + batch_dim_), + {1})); + + // slice_indices is the offset of the start of the reversed sequence in + // the input. + auto slice_indices = body_builder->DynamicUpdateSlice( + indices, + body_builder->Sub(XlaHelpers::IntegerLiteral( + body_builder.get(), seq_lens_type, max_seq_len), + seq_len), + body_builder->Reshape( + XlaHelpers::IntegerLiteral(body_builder.get(), seq_lens_type, + seq_dim_), + {1})); + + // Slice out the reversed sequence. The slice will overflow the end of the + // sequence, and the contents of the overflow are implementation-defined. 
+ // However, we will mask off these elements and replace them with elements + // from the original input so their values do not matter. + TensorShape slice_shape = input_shape; + slice_shape.set_dim(batch_dim_, 1); + auto slice = body_builder->DynamicSlice(output, slice_indices, + slice_shape.dim_sizes()); + + // Shift the reversed sequence to the left. + output = body_builder->DynamicUpdateSlice(output, slice, indices); + + body_builder->Tuple( + {body_builder->Add( + i, XlaHelpers::One(body_builder.get(), seq_lens_type)), + seq_lens, output}); + } + auto body = body_builder->Build(); + OP_REQUIRES_OK(context, body.status()); + + auto loop_output = builder->While( + condition.ValueOrDie(), body.ValueOrDie(), + builder->Tuple({XlaHelpers::Zero(builder, seq_lens_type), seq_lens, + builder->Rev(input, {seq_dim_})})); + auto output = builder->GetTupleElement(loop_output, 2); + + // Mask out elements after the sequence length. + xla::ComputationDataHandle iota; + OP_REQUIRES_OK( + context, XlaHelpers::Iota(builder, seq_lens_type, max_seq_len, &iota)); + std::vector dims(input_shape.dims(), 1); + dims[batch_dim_] = batch_size; + auto mask = builder->Lt(iota, builder->Reshape(seq_lens, dims), {seq_dim_}); + + // Broadcast the mask up to the input shape. + mask = + builder->Or(mask, builder->Broadcast(builder->ConstantR0(false), + input_shape.dim_sizes())); + + output = builder->Select(mask, output, input); + context->SetOutput(0, output); + } + + private: + int32 batch_dim_; + int32 seq_dim_; +}; + +REGISTER_XLA_OP(Name("ReverseSequence"), ReverseSequenceOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..ee4a94164c4a43828eb4feedbfa9d1a9e231ef8f --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc @@ -0,0 +1,147 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/kernels/concat_lib.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace { + +// TODO(phawkins): implement double-sized windowed reductions in XLA and remove +// the type constraint. 
+constexpr std::array kScanOpTypes = { + {DT_HALF, DT_BFLOAT16, DT_FLOAT}}; + +class ScanOp : public XlaOpKernel { + public: + ScanOp(OpKernelConstruction* ctx, bool sum) : XlaOpKernel(ctx), sum_(sum) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("reverse", &reverse_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("exclusive", &exclusive_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape input_shape = ctx->InputShape(0); + const TensorShape tensor_axis_shape = ctx->InputShape(1); + + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(tensor_axis_shape), + errors::InvalidArgument("ScanOp: axis must be a scalar, not ", + tensor_axis_shape.DebugString())); + + int64 axis; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &axis)); + if (axis < 0) { + axis += input_shape.dims(); + } + OP_REQUIRES( + ctx, FastBoundsCheck(axis, input_shape.dims()), + errors::InvalidArgument("ScanOp: Expected scan axis in the range [", + -input_shape.dims(), ", ", input_shape.dims(), + "), but got ", axis)); + + DataType dtype = ctx->input_type(0); + + if (input_shape.num_elements() == 0) { + // Exit early if there is nothing to compute. + ctx->SetOutput(0, ctx->Input(0)); + return; + } + + xla::ComputationBuilder* builder = ctx->builder(); + + std::vector window_strides(input_shape.dims(), 1); + std::vector window_dims(input_shape.dims(), 1); + window_dims[axis] = input_shape.dim_size(axis); + + std::vector> padding(input_shape.dims(), {0, 0}); + padding[axis].first = input_shape.dim_size(axis) - 1; + // In exclusive mode, add an extra padding element so there is a complete + // window of padding before the data starts. 
+ if (exclusive_) { + ++padding[axis].first; + } + if (reverse_) { + std::swap(padding[axis].first, padding[axis].second); + } + + xla::ComputationDataHandle input = ctx->Input(0); + xla::ComputationDataHandle init; + const xla::Computation* reducer; + if (sum_) { + init = XlaHelpers::Zero(builder, dtype); + reducer = ctx->GetOrCreateAdd(dtype); + } else { + init = XlaHelpers::One(builder, dtype); + reducer = ctx->GetOrCreateMul(dtype); + } + auto output = builder->ReduceWindowWithGeneralPadding( + ctx->Input(0), init, *reducer, window_dims, window_strides, padding); + + // In exclusive mode, we have computed an extra element containing the sum + // of all the input elements. Slice off this extra "last" element. + if (exclusive_) { + if (reverse_) { + output = builder->SliceInDim(output, 1, input_shape.dim_size(axis) + 1, + 1, axis); + + } else { + output = + builder->SliceInDim(output, 0, input_shape.dim_size(axis), 1, axis); + } + } + ctx->SetOutput(0, output); + } + + private: + const bool sum_; // True=cumulative sum. False=cumulative product. 
+ bool reverse_; + bool exclusive_; +}; + +class CumsumOp : public ScanOp { + public: + explicit CumsumOp(OpKernelConstruction* ctx) : ScanOp(ctx, /*sum=*/true) {} +}; +REGISTER_XLA_OP(Name("Cumsum") + .TypeConstraint("T", kScanOpTypes) + .CompileTimeConstInput("axis"), + CumsumOp); + +class CumprodOp : public ScanOp { + public: + explicit CumprodOp(OpKernelConstruction* ctx) : ScanOp(ctx, /*sum=*/false) {} +}; +REGISTER_XLA_OP(Name("Cumprod") + .TypeConstraint("T", kScanOpTypes) + .CompileTimeConstInput("axis"), + CumprodOp); + +} // anonymous namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..8433a29c4e203cac726ee6bf7f67a863447326ed --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/scatter_nd_op.cc @@ -0,0 +1,121 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/scatter.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace { + +// Check whether updates.shape = indices.shape[:batch_dim] + +// buffer_shape[num_index_dims:] +Status ValidateUpdateShape(const TensorShape& buffer_shape, + const TensorShape& indices_shape, + const TensorShape& updates_shape) { + if (indices_shape.dims() < 1) { + return errors::InvalidArgument( + "indices shape must have >= 1 dimension; got ", + indices_shape.DebugString()); + } + + const int64 num_index_dims = indices_shape.dim_size(indices_shape.dims() - 1); + const int64 batch_dim = indices_shape.dims() - 1; + + auto shape_err = [&]() { + return errors::InvalidArgument( + "Must have updates.shape = indices.shape[:batch_dim] + ", + "buffer_shape[num_index_dims:], got updates.shape: ", + updates_shape.DebugString(), + ", indices.shape: ", indices_shape.DebugString(), + ", buffer_shape: ", buffer_shape.DebugString(), + ", num_index_dims: ", num_index_dims, ", and batch_dim: ", batch_dim); + }; + + if (updates_shape.dims() < batch_dim) return shape_err(); + if (buffer_shape.dims() < + num_index_dims + (updates_shape.dims() - batch_dim)) { + return shape_err(); + } + if (updates_shape.dims() != + batch_dim + buffer_shape.dims() - num_index_dims) { + return shape_err(); + } + for (int d = 0; d < batch_dim; ++d) { + if (updates_shape.dim_size(d) != indices_shape.dim_size(d)) { + return shape_err(); + } + } + for (int d = 0; d < updates_shape.dims() - batch_dim; ++d) { + 
if (updates_shape.dim_size(d + batch_dim) != + buffer_shape.dim_size(d + num_index_dims)) { + return shape_err(); + } + } + return Status::OK(); +} + +class ScatterNdOp : public XlaOpKernel { + public: + explicit ScatterNdOp(OpKernelConstruction* context) : XlaOpKernel(context) {} + + void Compile(XlaOpKernelContext* context) override { + DataType dtype = context->input_type(1); + + TensorShape indices_shape = context->InputShape(0); + TensorShape updates_shape = context->InputShape(1); + + TensorShape buffer_shape; + OP_REQUIRES_OK(context, context->ConstantInputAsShape(2, &buffer_shape)); + + OP_REQUIRES( + context, TensorShapeUtils::IsVectorOrHigher(buffer_shape), + errors::InvalidArgument("Output must be at least 1-D, ", + "got shape: ", buffer_shape.DebugString())); + + OP_REQUIRES( + context, + buffer_shape.num_elements() > 0 || (indices_shape.num_elements() == 0 && + updates_shape.num_elements() == 0), + errors::InvalidArgument( + "Indices and updates specified for empty output. indices shape: ", + indices_shape.DebugString())); + + OP_REQUIRES_OK(context, ValidateUpdateShape(buffer_shape, indices_shape, + updates_shape)); + + xla::ComputationBuilder* builder = context->builder(); + auto buffer = builder->Broadcast(XlaHelpers::Zero(builder, dtype), + buffer_shape.dim_sizes()); + auto indices = context->Input(0); + auto updates = context->Input(1); + auto result = + XlaScatter(buffer, updates, indices, + /*indices_are_vectors=*/true, /*combiner=*/{}, builder); + OP_REQUIRES_OK(context, result.status()); + context->SetOutput(0, result.ValueOrDie()); + } +}; + +REGISTER_XLA_OP(Name("ScatterNd").CompileTimeConstInput("shape"), ScatterNdOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/scatter_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/scatter_op_helpers.h deleted file mode 100644 index a5ab7de17adb734014fe2dcbd60ae5c219c8e486..0000000000000000000000000000000000000000 --- 
a/tensorflow/compiler/tf2xla/kernels/scatter_op_helpers.h +++ /dev/null @@ -1,39 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -// Helper methods for XLA Scatter Ops. -#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_SCATTER_OP_HELPERS_H_ -#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_SCATTER_OP_HELPERS_H_ - -#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" -#include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/util/bcast.h" - -namespace tensorflow { - -// Adds to builder an XLA computation that performs a scatter-add of input (of -// shape input_shape) keyed on indices (of shape indices_shape). 
The shape -// of the Tensor returned by this is num_segments input_shape[indices.dims():] -// -static xla::ComputationDataHandle XlaComputeScatterAddDynamicSlice( - XlaOpKernelContext* ctx, const xla::ComputationDataHandle& input, - const TensorShape& input_shape, const xla::ComputationDataHandle& indices, - const TensorShape& indices_shape, int64 num_segments, DataType dtype, - xla::ComputationBuilder* builder); - -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_SCATTER_OP_HELPERS_H_ diff --git a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc index 8a67c0b67fcd95f4841c5e011a4e51638eea5b0f..80d6df6c48b0141734dcee1c2a3c413926931feb 100644 --- a/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/segment_reduction_ops.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,143 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include -#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" -#include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/lib/scatter.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/computation_builder.h" -#include "tensorflow/compiler/xla/literal_util.h" -#include "tensorflow/core/framework/kernel_def_builder.h" -#include "tensorflow/core/framework/types.h" namespace tensorflow { - -xla::ComputationDataHandle XlaComputeScatterAddDynamicSlice( - XlaOpKernelContext* ctx, const xla::ComputationDataHandle& input, - const TensorShape& input_shape, const xla::ComputationDataHandle& indices, - const TensorShape& indices_shape, int64 num_segments, DataType dtype, - xla::ComputationBuilder* builder) { - // Flatten data for dynamic indexing via indices_1d. - TensorShape input_shape_i(input_shape); - for (int64 d = 0; d < indices_shape.dims(); ++d) { - input_shape_i.RemoveDim(0); - } - TensorShape flat_shape({indices_shape.num_elements()}); - flat_shape.AppendShape(input_shape_i); - - // output is same as flattened input shape with dim_size(0) = num_segments. - TensorShape out_shape(flat_shape); - out_shape.set_dim(0, num_segments); - - // TODO(b/37575001) The tensor in which we construct the output during - // the loop must have rank >= 3 as a workaround for lowering issues. - int64 extra_dims = 0; - if (out_shape.dims() < 3) { - extra_dims = 3 - out_shape.dims(); - } - TensorShape loop_out_shape; - for (int64 k = 0; k < extra_dims; ++k) { - loop_out_shape.AddDim(1); - } - loop_out_shape.AppendShape(out_shape); - - // Slices from the input data are same shape as the input data, except dim 0. 
- TensorShape slice_shape(flat_shape); - slice_shape.set_dim(0, 1); - // slices are reshaped into the rank >= 3 shape of the loop-carried output - TensorShape loop_out_slice_shape(loop_out_shape); - loop_out_slice_shape.set_dim(extra_dims, 1); - - // Construct the initial values of the loop-carried variables - // Flatten the indices into 1-D for ease of iteration. - auto indices_1d = builder->Reshape(indices, {indices_shape.num_elements()}); - // Flatten the data for ease of indexing via values in indices_1d. - auto data_flat = builder->Reshape(input, flat_shape.dim_sizes()); - - auto init_i = builder->ConstantR0(0); - auto init_out = builder->Broadcast(XlaHelpers::Zero(builder, dtype), - loop_out_shape.dim_sizes()); - - xla::PrimitiveType ptype; - TF_CHECK_OK(DataTypeToPrimitiveType(dtype, &ptype)); - - std::vector tuple_shapes( - {// The loop iteration counter is a scalar, incremented each iteration. - xla::ShapeUtil::MakeShape(xla::S32, {}), - // The flattened input data is loop invariant. - xla::ShapeUtil::MakeShape(ptype, flat_shape.dim_sizes()), - // The scatter indices tensor is loop invariant. - xla::ShapeUtil::MakeShape(xla::S32, {indices_shape.num_elements()}), - // The output data array is updated each loop iteration. - xla::ShapeUtil::MakeShape(ptype, loop_out_shape.dim_sizes())}); - xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); - - auto init = builder->Tuple({init_i, data_flat, indices_1d, init_out}); - - // Construct the while loop condition (i < num_indices) - xla::ComputationBuilder condb(ctx->builder()->client(), - "ScatterAddWhileCond"); - condb.Lt(condb.GetTupleElement( - condb.Parameter(0, tuple_shape, "ScatterAddWhileTuple"), 0), - condb.ConstantR0(indices_shape.num_elements())); - auto cond_status = condb.Build(); - // TF_CHECK_OK(cond_status); - auto cond = cond_status.ConsumeValueOrDie(); - - // Construct the while loop body's function. 
The implementation of scatter is: - // for i in range(num_indices): - // index = dynamic-slice(indices, i) - // xi = dynamic-slice(input, i) - // output = dynamic-update-slice(output, xi, index) - xla::ComputationBuilder bodyb(ctx->builder()->client(), - "ScatterAddWhileBody"); - { - auto input_tuple = bodyb.Parameter(0, tuple_shape, "ScatterAddWhileTuple"); - auto i = bodyb.GetTupleElement(input_tuple, 0); - auto data = bodyb.GetTupleElement(input_tuple, 1); - auto idcs = bodyb.GetTupleElement(input_tuple, 2); - auto output = bodyb.GetTupleElement(input_tuple, 3); - - // Index into the data array at i. - auto zero = bodyb.ConstantR1({0}); - std::vector index_vals(flat_shape.dims(), zero); - index_vals[0] = bodyb.Reshape(i, {1}); - auto index = bodyb.ConcatInDim(index_vals, 0); - - auto data_slice = - bodyb.Reshape(bodyb.DynamicSlice(data, index, slice_shape.dim_sizes()), - loop_out_slice_shape.dim_sizes()); - - // Index into the output array. - // Construct the index into the R3+ output array 0, ..., , 0, ... - std::vector out_index_vals( - loop_out_shape.dims(), zero); - out_index_vals[extra_dims] = - bodyb.DynamicSlice(idcs, bodyb.Reshape(i, {1}), {1}); - auto out_index = bodyb.ConcatInDim(out_index_vals, 0); - - // Slice the output array, update value, and update the output slice. 
- auto updated_output = bodyb.DynamicUpdateSlice( - output, - bodyb.Add(data_slice, - bodyb.DynamicSlice(output, out_index, - loop_out_slice_shape.dim_sizes())), - out_index); - - auto ip1 = bodyb.Add(i, bodyb.ConstantR0(1)); - bodyb.Tuple({ip1, data, idcs, updated_output}); - } - auto body_status = bodyb.Build(); - // TF_CHECK_OK(body_status); - auto body = body_status.ConsumeValueOrDie(); - - auto gather_while = builder->While(cond, body, init); - auto updated_output = builder->GetTupleElement(gather_while, 3); - return builder->Reshape(updated_output, out_shape.dim_sizes()); -} - namespace { class UnsortedSegmentSum : public XlaOpKernel { @@ -171,10 +41,10 @@ class UnsortedSegmentSum : public XlaOpKernel { // as data with the first indices.rank dimensions are replaced // by a single dimension with size num_segments. auto data = ctx->Input(0); - auto data_shape = ctx->InputShape(0); + TensorShape data_shape = ctx->InputShape(0); auto indices = ctx->Input(1); - auto indices_shape = ctx->InputShape(1); + TensorShape indices_shape = ctx->InputShape(1); int64 num_segments; OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(2, &num_segments)); @@ -192,10 +62,21 @@ class UnsortedSegmentSum : public XlaOpKernel { d, " differs ", data_shape.dim_size(d), " vs. 
", indices_shape.dim_size(d))); } - auto result = XlaComputeScatterAddDynamicSlice( - ctx, data, data_shape, indices, indices_shape, num_segments, dtype_, - ctx->builder()); - ctx->SetOutput(0, result); + xla::ComputationBuilder* builder = ctx->builder(); + TensorShape buffer_shape = data_shape; + buffer_shape.RemoveDimRange(0, indices_shape.dims()); + buffer_shape.InsertDim(0, num_segments); + auto buffer = builder->Broadcast(XlaHelpers::Zero(builder, dtype_), + buffer_shape.dim_sizes()); + + auto combiner = + [](xla::ComputationDataHandle a, xla::ComputationDataHandle b, + xla::ComputationBuilder* builder) { return builder->Add(a, b); }; + + auto result = XlaScatter(buffer, /*updates=*/data, indices, + /*indices_are_vectors=*/false, combiner, builder); + OP_REQUIRES_OK(ctx, result.status()); + ctx->SetOutput(0, result.ValueOrDie()); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc index c2b0e1bb4c1a141d0ab3f5b3ff5397d9da620bd8..2c31f8d90891924f6f86a54ccf548de4df87f3bd 100644 --- a/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sequence_ops.cc @@ -138,7 +138,11 @@ class RangeOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Range"), RangeOp); +REGISTER_XLA_OP(Name("Range") + .CompileTimeConstInput("start") + .CompileTimeConstInput("limit") + .CompileTimeConstInput("delta"), + RangeOp); class LinSpaceOp : public XlaOpKernel { public: @@ -207,7 +211,11 @@ class LinSpaceOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("LinSpace"), LinSpaceOp); +REGISTER_XLA_OP(Name("LinSpace") + .CompileTimeConstInput("start") + .CompileTimeConstInput("stop") + .CompileTimeConstInput("num"), + LinSpaceOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index 24a99f253d6dc8bb699fff587c363b12c227e821..05354bca5bb089703fdcceb6f44648bbb98d004b 100644 --- 
a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -15,6 +15,7 @@ limitations under the License. // XLA-specific Shape Ops. +#include "tensorflow/compiler/tf2xla/kernels/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -27,56 +28,42 @@ namespace { class ShapeOp : public XlaOpKernel { public: - explicit ShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + explicit ShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_)); + } void Compile(XlaOpKernelContext* ctx) override { const TensorShape input_shape = ctx->InputShape(0); - const int rank = input_shape.dims(); - Tensor shape_constant(DT_INT32, TensorShape({rank})); - auto vec = shape_constant.vec(); - // TODO(dga): support int64. b/28119922. - for (int i = 0; i < rank; ++i) { - int64 dim_size = input_shape.dim_size(i); - OP_REQUIRES( - ctx, FastBoundsCheck(dim_size, std::numeric_limits::max()), - errors::InvalidArgument("Shape does not support tensors > int32max", - " but dim ", i, " is ", dim_size)); - vec(i) = static_cast(dim_size); - } - + Tensor shape_constant(out_dtype_, TensorShape({input_shape.dims()})); + OP_REQUIRES_OK(ctx, TensorShapeToConstant(input_shape, &shape_constant)); ctx->SetConstantOutput(0, shape_constant); } + + private: + DataType out_dtype_; }; REGISTER_XLA_OP(Name("Shape"), ShapeOp); class ShapeNOp : public XlaOpKernel { public: - explicit ShapeNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + explicit ShapeNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_)); + } void Compile(XlaOpKernelContext* ctx) override { for (int i = 0; i < ctx->num_inputs(); ++i) { - const TensorShape shape = ctx->InputShape(i); - const int dims = shape.dims(); - Tensor shape_constant(DT_INT32, 
TensorShape({dims})); - auto vec = shape_constant.vec(); - - // TODO(dga): support int64. b/28119922. - for (int j = 0; j < dims; ++j) { - int64 dim_size = shape.dim_size(j); - OP_REQUIRES( - ctx, FastBoundsCheck(dim_size, std::numeric_limits::max()), - errors::InvalidArgument("Shape does not support tensors > int32max", - " but shape ", i, " dim ", j, " is ", - dim_size)); - vec(j) = static_cast(dim_size); - } - + const TensorShape input_shape = ctx->InputShape(i); + Tensor shape_constant(out_dtype_, TensorShape({input_shape.dims()})); + OP_REQUIRES_OK(ctx, TensorShapeToConstant(input_shape, &shape_constant)); ctx->SetConstantOutput(i, shape_constant); } } bool IsExpensive() override { return false; } + + private: + DataType out_dtype_; }; REGISTER_XLA_OP(Name("ShapeN"), ShapeNOp); @@ -134,7 +121,7 @@ class ExpandDimsOp : public XlaOpKernel { xla::Literal literal; OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {1}, &literal)); - int dim = literal.s32s(0); + int dim = literal.data()[0]; OP_REQUIRES(ctx, (dim >= -1 - input_shape.dims() && dim <= input_shape.dims()), @@ -163,7 +150,7 @@ class ExpandDimsOp : public XlaOpKernel { ctx->SetOutput(0, ctx->builder()->Reshape(ctx->Input(0), new_shape)); } }; -REGISTER_XLA_OP(Name("ExpandDims"), ExpandDimsOp); +REGISTER_XLA_OP(Name("ExpandDims").CompileTimeConstInput("dim"), ExpandDimsOp); class SqueezeOp : public XlaOpKernel { public: diff --git a/tensorflow/compiler/tf2xla/kernels/shape_util.cc b/tensorflow/compiler/tf2xla/kernels/shape_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..76ea5f525598f511f295eb5a30f3cf603fbf57aa --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/shape_util.cc @@ -0,0 +1,48 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/kernels/shape_util.h" + +#include + +#include "tensorflow/core/kernels/bounds_check.h" + +namespace tensorflow { + +Status TensorShapeToConstant(const TensorShape& input_shape, + Tensor* shape_constant) { + const int dims = input_shape.dims(); + if (shape_constant->dtype() == DT_INT32) { + auto vec = shape_constant->vec(); + for (int i = 0; i < dims; ++i) { + int64 dim_size = input_shape.dim_size(i); + if (!FastBoundsCheck(dim_size, std::numeric_limits::max())) { + return errors::InvalidArgument( + "Shape with out_type=int32 does not support tensors > int32max", + " but dim ", i, " is ", dim_size); + } + vec(i) = static_cast(dim_size); + } + } else { + auto vec = shape_constant->vec(); + for (int i = 0; i < dims; ++i) { + int64 dim_size = input_shape.dim_size(i); + vec(i) = dim_size; + } + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/shape_util.h b/tensorflow/compiler/tf2xla/kernels/shape_util.h new file mode 100644 index 0000000000000000000000000000000000000000..ca57be3d47b95d71b07746e50256070e0a4f4c09 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/shape_util.h @@ -0,0 +1,34 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_SHAPE_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_SHAPE_UTIL_H_ + +#include + +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { + +// Converts a TensorShape to a constant Tensor. +// +// The input TensorShape input_shape is used to populate the elements of +// shape_constant, which is modified in place. +Status TensorShapeToConstant(const TensorShape& input_shape, + Tensor* shape_constant); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_SHAPE_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc index fbe8c78d8fb5f800967942555531a50937cad0ca..be1e97bf26fa4cde1b741c8d0b843a85ce33a59c 100644 --- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc @@ -112,7 +112,9 @@ class SliceOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Slice"), SliceOp); +REGISTER_XLA_OP( + Name("Slice").CompileTimeConstInput("begin").CompileTimeConstInput("size"), + SliceOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc index 83a87f19a718ce86a105e3c33ab9eaf0faff3a76..01b46e160d1f1f10a43faf7ca35afb42dfde6e33 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc @@ -162,7 +162,10 @@ class SpaceToBatchNDOp : 
public XlaOpKernel { block_shape, paddings); } }; -REGISTER_XLA_OP(Name("SpaceToBatchND"), SpaceToBatchNDOp); +REGISTER_XLA_OP(Name("SpaceToBatchND") + .CompileTimeConstInput("paddings") + .CompileTimeConstInput("block_shape"), + SpaceToBatchNDOp); class SpaceToBatchOp : public XlaOpKernel { public: @@ -184,7 +187,8 @@ class SpaceToBatchOp : public XlaOpKernel { private: int block_size_; }; -REGISTER_XLA_OP(Name("SpaceToBatch"), SpaceToBatchOp); +REGISTER_XLA_OP(Name("SpaceToBatch").CompileTimeConstInput("paddings"), + SpaceToBatchOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc index 89befda346ec06fec23ab1d1c9d910ded8cd806d..806fda632cde64c1b37ae3b9199028d6b6b0a215 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/util/tensor_format.h" namespace tensorflow { namespace { @@ -23,6 +24,16 @@ namespace { class SpaceToDepthOp : public XlaOpKernel { public: explicit SpaceToDepthOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + string data_format_str; + OP_REQUIRES_OK(ctx, ctx->GetAttr("data_format", &data_format_str)); + OP_REQUIRES(ctx, FormatFromString(data_format_str, &data_format_), + errors::InvalidArgument("Invalid data format")); + + OP_REQUIRES(ctx, data_format_ == FORMAT_NCHW || data_format_ == FORMAT_NHWC, + errors::InvalidArgument("Unsupported data format ", + ToString(data_format_), + "; expected formats NHWC or NCHW")); + OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_)); OP_REQUIRES( ctx, block_size_ > 1, @@ -31,34 +42,100 @@ class SpaceToDepthOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override 
{ const TensorShape input_tensor_shape = ctx->InputShape(0); - // The input is presumed to be [batch, height, width, depth] int input_rank = input_tensor_shape.dims(); static const int kRequiredDims = 4; OP_REQUIRES(ctx, kRequiredDims == input_rank, - errors::InvalidArgument("Input rank should be: ", kRequiredDims, - " instead of: ", input_rank)); + errors::InvalidArgument("Input rank should be ", kRequiredDims, + "; got ", input_rank)); const gtl::InlinedVector input_shape = input_tensor_shape.dim_sizes(); xla::ComputationBuilder* b = ctx->builder(); xla::ComputationDataHandle input = ctx->Input(0); + int feature_dim = GetTensorFeatureDimIndex(input_rank, data_format_); + int num_spatial_dims = GetTensorSpatialDims(input_rank, data_format_); + + std::vector reshaped_shape; + std::vector transpose_order; + std::vector output_shape; + reshaped_shape.reserve(input_rank); + transpose_order.reserve(input_rank); + output_shape.reserve(input_rank); + if (data_format_ == FORMAT_NHWC) { + int64 block_elems = 1; + for (int i = 0; i < num_spatial_dims; ++i) { + OP_REQUIRES(ctx, input_shape[1 + i] % block_size_ == 0, + errors::InvalidArgument( + "input shape[", 1 + i, "]=", input_shape[1 + i], + " is not divisible by block_size=", block_size_)); + block_elems *= block_size_; + } + + reshaped_shape.push_back(input_shape[0]); + for (int i = 0; i < num_spatial_dims; ++i) { + reshaped_shape.push_back(input_shape[1 + i] / block_size_); + reshaped_shape.push_back(block_size_); + } + reshaped_shape.push_back(input_shape[feature_dim]); + + transpose_order.push_back(0); + for (int i = 0; i < num_spatial_dims; ++i) { + transpose_order.push_back(i * 2 + 1); + } + for (int i = 0; i < num_spatial_dims; ++i) { + transpose_order.push_back(i * 2 + 2); + } + transpose_order.push_back(feature_dim + num_spatial_dims); + + output_shape.push_back(input_shape[0]); + for (int i = 0; i < num_spatial_dims; ++i) { + output_shape.push_back(input_shape[1 + i] / block_size_); + } + 
output_shape.push_back(input_shape[feature_dim] * block_elems); + } else { + // FORMAT_NCHW + int64 block_elems = 1; + for (int i = 0; i < num_spatial_dims; ++i) { + OP_REQUIRES(ctx, input_shape[2 + i] % block_size_ == 0, + errors::InvalidArgument( + "input shape[", 2 + i, "]=", input_shape[2 + i], + " is not divisible by block_size=", block_size_)); + block_elems *= block_size_; + } + + reshaped_shape.push_back(input_shape[0]); + reshaped_shape.push_back(input_shape[feature_dim]); + for (int i = 0; i < num_spatial_dims; ++i) { + reshaped_shape.push_back(input_shape[2 + i] / block_size_); + reshaped_shape.push_back(block_size_); + } + + transpose_order.push_back(0); + for (int i = 0; i < num_spatial_dims; ++i) { + transpose_order.push_back(i * 2 + 3); + } + transpose_order.push_back(feature_dim); + for (int i = 0; i < num_spatial_dims; ++i) { + transpose_order.push_back(i * 2 + 2); + } + + output_shape.push_back(input_shape[0]); + output_shape.push_back(input_shape[feature_dim] * block_elems); + for (int i = 0; i < num_spatial_dims; ++i) { + output_shape.push_back(input_shape[2 + i] / block_size_); + } + } + + // Note: comments are given in NHWC format; NCHW is similar with a different + // dimension order. // 1. Reshape `input` to `reshaped` of shape: // // [batch, // input_shape[1] / block_size_, block_size_, // input_shape[2] / block_size_, block_size_, // depth] - const int block_rank = 2; - for (int i = 0; i < block_rank; ++i) { - OP_REQUIRES(ctx, input_shape[1 + i] % block_size_ == 0, - errors::InvalidArgument( - "input shape[", 1 + i, "]=", input_shape[1 + i], - " is not divisible by block_size=", block_size_)); - } - xla::ComputationDataHandle reshaped = b->Reshape( - input, {input_shape[0], input_shape[1] / block_size_, block_size_, - input_shape[2] / block_size_, block_size_, input_shape[3]}); + xla::ComputationDataHandle reshaped = b->Reshape(input, reshaped_shape); // 2. 
Permute dimensions of `reshaped` to produce // `permuted_reshaped` of shape: @@ -69,7 +146,7 @@ class SpaceToDepthOp : public XlaOpKernel { // block_size_, block_size_, // depth] xla::ComputationDataHandle permuted_reshaped = - b->Transpose(reshaped, {0, 1, 3, 2, 4, 5}); + b->Transpose(reshaped, transpose_order); // 3. Reshape `permuted_reshaped` to flatten `block_shape` into the // batch dimension, producing an output tensor of shape: @@ -79,15 +156,14 @@ class SpaceToDepthOp : public XlaOpKernel { // input_shape[2] / block_size_, // block_size_ * block_size_ * depth] // - xla::ComputationDataHandle output = b->Reshape( - permuted_reshaped, {input_shape[0], input_shape[1] / block_size_, - input_shape[2] / block_size_, - block_size_ * block_size_ * input_shape[3]}); + xla::ComputationDataHandle output = + b->Reshape(permuted_reshaped, output_shape); ctx->SetOutput(0, output); } private: + TensorFormat data_format_; int block_size_; }; REGISTER_XLA_OP(Name("SpaceToDepth"), SpaceToDepthOp); diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc index 795eb1794f577e0f7fd2a2068878e540ff0c1a1d..79c435c90a1f57250be90c2c2523bf3d7d231461 100644 --- a/tensorflow/compiler/tf2xla/kernels/split_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -103,7 +103,7 @@ class SplitOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Split"), SplitOp); +REGISTER_XLA_OP(Name("Split").CompileTimeConstInput("split_dim"), SplitOp); class SplitVOp : public XlaOpKernel { public: @@ -142,8 +142,9 @@ class SplitVOp : public XlaOpKernel { int neg_one_dim = -1; std::vector split_sizes_vec(num_split, -1); const TensorShape split_size_shape = ctx->InputShape(1); - OP_REQUIRES(ctx, split_size_shape.dims() == 1 && - split_size_shape.num_elements() == num_split, + OP_REQUIRES(ctx, + split_size_shape.dims() == 1 && + split_size_shape.num_elements() == num_split, errors::InvalidArgument( "shape of tensor describing " " the output must 
have dimension 1 and the same " @@ -171,10 +172,11 @@ class SplitVOp : public XlaOpKernel { } OP_REQUIRES( - ctx, (neg_one_dim == -1 && - total_split_size == input_shape.dim_size(split_dim)) || - (neg_one_dim >= 0 && - total_split_size <= input_shape.dim_size(split_dim)), + ctx, + (neg_one_dim == -1 && + total_split_size == input_shape.dim_size(split_dim)) || + (neg_one_dim >= 0 && + total_split_size <= input_shape.dim_size(split_dim)), errors::InvalidArgument("Determined shape must either match " "input shape along split_dim exactly if " "fully specified, or be less than the size of " @@ -206,7 +208,10 @@ class SplitVOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("SplitV"), SplitVOp); +REGISTER_XLA_OP(Name("SplitV") + .CompileTimeConstInput("split_dim") + .CompileTimeConstInput("size_splits"), + SplitVOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc index bb7891b31f6d52fd84cf72579c343f50473e1632..1a78c7ab9be701d3d02285ed21604f0f856b3f1f 100644 --- a/tensorflow/compiler/tf2xla/kernels/stack_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/stack_ops.cc @@ -40,7 +40,7 @@ namespace { Status GetStackShape(xla::ComputationBuilder* builder, XlaResource* resource, TensorShape* stack_shape) { - auto shape_or_status = builder->GetShape(resource->value); + auto shape_or_status = builder->GetShape(resource->value()); if (!shape_or_status.ok()) { return shape_or_status.status(); } @@ -63,22 +63,22 @@ Status GetStackShape(xla::ComputationBuilder* builder, XlaResource* resource, Status MaybeInitializeStack(xla::ComputationBuilder* builder, XlaResource* resource, DataType dtype, const TensorShape& elem_shape) { - if (resource->type != dtype) { + if (resource->type() != dtype) { return errors::InvalidArgument( - "Stack dtype is ", DataTypeString(resource->type), " but op has dtype ", - DataTypeString(dtype), "."); + "Stack dtype is ", DataTypeString(resource->type()), + 
" but op has dtype ", DataTypeString(dtype), "."); } TensorShape stack_shape; - stack_shape.AddDim(resource->tensor_array_size); + stack_shape.AddDim(resource->tensor_array_size()); stack_shape.AppendShape(elem_shape); - if (resource->value.handle() == 0) { + if (!resource->initialized()) { // Stack has not been initialized. - xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, resource->type); - resource->value = - builder->Tuple({builder->Broadcast(zero, stack_shape.dim_sizes()), - builder->ConstantR0(0)}); + xla::ComputationDataHandle zero = + XlaHelpers::Zero(builder, resource->type()); + TF_RETURN_IF_ERROR(resource->SetTypeAndShape(dtype, elem_shape)); + TF_RETURN_IF_ERROR(resource->SetZeroValue(builder)); } else { // Checks the expected shape matches the actual shape. TensorShape actual_shape; @@ -105,7 +105,9 @@ class StackOp : public XlaOpKernel { OP_REQUIRES( ctx, size >= 0, errors::InvalidArgument( - "XLA compilation requires a fixed stack size upper bound.")); + "XLA compilation requires a fixed stack size upper bound. If " + "you are using tf.while_loop, set the maximum_iterations parameter " + "to fix this issue.")); // We defer initializing the Stack resource until we see the first push. // Otherwise we do not know the shape of the stack elements. 
@@ -115,8 +117,8 @@ class StackOp : public XlaOpKernel { string name = strings::StrCat("Stack: ", stack_name_); OP_REQUIRES_OK( ctx, xc.CreateResource(XlaResource::kStack, -1, std::move(name), dtype_, - value, &resource)); - resource->tensor_array_size = size; + TensorShape(), value, /*tensor_array_size=*/size, + /*tensor_array_gradients=*/{}, &resource)); ctx->SetResourceOutput(0, resource); } @@ -127,7 +129,7 @@ class StackOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(StackOp); }; -REGISTER_XLA_OP(Name("StackV2"), StackOp); +REGISTER_XLA_OP(Name("StackV2").CompileTimeConstInput("max_size"), StackOp); class StackPushOp : public XlaOpKernel { public: @@ -145,8 +147,8 @@ class StackPushOp : public XlaOpKernel { // Initializes the Stack, if the element shape was not already known. OP_REQUIRES_OK(ctx, MaybeInitializeStack(b, resource, dtype_, elem_shape)); - xla::ComputationDataHandle ta = b->GetTupleElement(resource->value, 0); - xla::ComputationDataHandle index = b->GetTupleElement(resource->value, 1); + xla::ComputationDataHandle ta = b->GetTupleElement(resource->value(), 0); + xla::ComputationDataHandle index = b->GetTupleElement(resource->value(), 1); xla::ComputationDataHandle value = ctx->Input(1); // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. @@ -160,9 +162,9 @@ class StackPushOp : public XlaOpKernel { // TODO(phawkins): We don't check the index is in bounds --- there is no // error mechanism in XLA. 
- resource->value = - b->Tuple({b->DynamicUpdateSlice(ta, update, start_indices), - b->Add(index, b->ConstantR0(1))}); + OP_REQUIRES_OK(ctx, resource->SetValue(b->Tuple( + {b->DynamicUpdateSlice(ta, update, start_indices), + b->Add(index, b->ConstantR0(1))}))); ctx->SetOutput(0, value); } @@ -187,27 +189,22 @@ class StackPopOp : public XlaOpKernel { XlaResource* resource; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); - OP_REQUIRES(ctx, resource->type == dtype_, - errors::InvalidArgument( - "Stack dtype is ", DataTypeString(resource->type), - " but Op requested dtype ", DataTypeString(dtype_), ".")); - // There is a somewhat subtle issue here: here "uninitialized" means we have // not yet seen a pop in the order that we compile operators, not the order // that we run them. However, in practice the two orders should be the same // for the sole user of the stack operators (loop gradients). - OP_REQUIRES(ctx, resource->value.handle() != 0, + OP_REQUIRES(ctx, resource->initialized(), errors::InvalidArgument("Stack pop on uninitialized stack")); TensorShape stack_shape; OP_REQUIRES_OK(ctx, GetStackShape(b, resource, &stack_shape)); - xla::ComputationDataHandle state = resource->value; + xla::ComputationDataHandle state = resource->value(); xla::ComputationDataHandle ta = b->GetTupleElement(state, 0); xla::ComputationDataHandle index = b->GetTupleElement(state, 1); index = b->Sub(index, b->ConstantR0(1)); - resource->value = b->Tuple({ta, index}); + OP_REQUIRES_OK(ctx, resource->SetValue(b->Tuple({ta, index}))); // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. 
auto start_indices = diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 6af4bd0496e0da926726e3f74376281f539e925a..91c169428c7a88a8d107a97445aeea999946e3e9 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -106,7 +106,11 @@ class StridedSliceOp : public XlaOpKernel { DataType index_type_; }; -REGISTER_XLA_OP(Name("StridedSlice"), StridedSliceOp); +REGISTER_XLA_OP(Name("StridedSlice") + .CompileTimeConstInput("begin") + .CompileTimeConstInput("end") + .CompileTimeConstInput("strides"), + StridedSliceOp); class StridedSliceGradOp : public XlaOpKernel { public: @@ -211,7 +215,12 @@ class StridedSliceGradOp : public XlaOpKernel { DataType index_type_; }; -REGISTER_XLA_OP(Name("StridedSliceGrad"), StridedSliceGradOp); +REGISTER_XLA_OP(Name("StridedSliceGrad") + .CompileTimeConstInput("shape") + .CompileTimeConstInput("begin") + .CompileTimeConstInput("end") + .CompileTimeConstInput("strides"), + StridedSliceGradOp); class StridedSliceAssignOp : public XlaOpKernel { public: @@ -222,6 +231,7 @@ class StridedSliceAssignOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->GetAttr("new_axis_mask", &new_axis_mask_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("shrink_axis_mask", &shrink_axis_mask_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("Index", &index_type_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_)); } void Compile(XlaOpKernelContext* ctx) override { @@ -243,9 +253,9 @@ class StridedSliceAssignOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, LiteralToHostTensor(strides_literal, index_type_, &strides_tensor)); - DataType lhs_type; TensorShape lhs_shape; - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &lhs_type, &lhs_shape)); + xla::ComputationDataHandle lhs; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &lhs_shape, &lhs)); const TensorShape rhs_shape = ctx->InputShape(4); @@ -273,9 +283,6 @@ class 
StridedSliceAssignOp : public XlaOpKernel { " does not match r-value shape ", rhs_shape.DebugString(), ". Automatic broadcasting not yet implemented.")); - xla::ComputationDataHandle lhs; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &lhs)); - xla::ComputationDataHandle rhs = ctx->Input(4); gtl::InlinedVector dimensions_to_reverse; @@ -311,16 +318,21 @@ class StridedSliceAssignOp : public XlaOpKernel { lhs, rhs, ctx->builder()->ConstantR1(slice_begin)); } - OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, lhs_type, lhs)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, dtype_, lhs)); } private: int32 begin_mask_, end_mask_; int32 ellipsis_mask_, new_axis_mask_, shrink_axis_mask_; DataType index_type_; + DataType dtype_; }; -REGISTER_XLA_OP(Name("ResourceStridedSliceAssign"), StridedSliceAssignOp); +REGISTER_XLA_OP(Name("ResourceStridedSliceAssign") + .CompileTimeConstInput("begin") + .CompileTimeConstInput("end") + .CompileTimeConstInput("strides"), + StridedSliceAssignOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index 351fda251798e43b607fb445f2c98abd57b3d86b..000b50af6bd86b7268c016865fb0856c16053ece 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -21,10 +21,10 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/tf2xla/xla_resource.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/partial_tensor_shape.h" @@ -50,35 +50,38 @@ namespace { Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder, XlaResource* resource, DataType dtype, const TensorShape& elem_shape) { - if (resource->kind != XlaResource::kTensorArray) { + if (resource->kind() != XlaResource::kTensorArray) { return errors::InvalidArgument("Unexpected non-TensorArray resource"); } - if (resource->type != dtype) { + if (resource->type() != dtype) { return errors::InvalidArgument( - "TensorArray dtype is ", DataTypeString(resource->type), + "TensorArray dtype is ", DataTypeString(resource->type()), " but op has dtype ", DataTypeString(dtype), "."); } - TF_RET_CHECK(resource->tensor_array_size >= 0) - << resource->name << " size " << resource->tensor_array_size; - TensorShape ta_shape; - ta_shape.AddDim(resource->tensor_array_size); - ta_shape.AppendShape(elem_shape); + TF_RET_CHECK(resource->tensor_array_size() >= 0) + << resource->name() << " size " << resource->tensor_array_size(); - if (resource->value.handle() == 0) { - // TensorArray has not been initialized. 
- xla::ComputationDataHandle zero = XlaHelpers::Zero(builder, resource->type); - resource->value = builder->Broadcast(zero, ta_shape.dim_sizes()); + if (!resource->initialized()) { + xla::ComputationDataHandle zero = + XlaHelpers::Zero(builder, resource->type()); + + TF_RETURN_IF_ERROR(resource->SetTypeAndShape(dtype, elem_shape)); + TF_RETURN_IF_ERROR(resource->SetZeroValue(builder)); } else { // Checks the elem_shape matches the TensorArray shape. - auto shape_or_status = builder->GetShape(resource->value); + auto shape_or_status = builder->GetShape(resource->value()); if (!shape_or_status.ok()) { return shape_or_status.status(); } TensorShape shape; TF_RETURN_IF_ERROR( XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), &shape)); + + TensorShape ta_shape; + ta_shape.AddDim(resource->tensor_array_size()); + ta_shape.AppendShape(elem_shape); if (ta_shape != shape) { return errors::InvalidArgument( "Mismatched TensorArray sizes: ", ta_shape.DebugString(), " vs ", @@ -93,19 +96,17 @@ Status MaybeInitializeTensorArray(xla::ComputationBuilder* builder, Status CheckTensorArrayIsInitialized(const string& op_name, const XlaResource* resource, DataType dtype) { - if (resource->kind != XlaResource::kTensorArray) { + if (resource->kind() != XlaResource::kTensorArray) { return errors::InvalidArgument( - "Unexpected non-TensorArray resource passed " - "to ", - op_name); + "Unexpected non-TensorArray resource passed to ", op_name); } - if (resource->value.handle() == 0) { + if (!resource->initialized()) { return errors::InvalidArgument("Uninitialized TensorArray passed to ", op_name); } - if (resource->type != dtype) { + if (resource->type() != dtype) { return errors::InvalidArgument( - "TensorArray dtype is ", DataTypeString(resource->type), + "TensorArray dtype is ", DataTypeString(resource->type()), " but op has dtype ", DataTypeString(dtype), "."); } @@ -115,10 +116,8 @@ Status CheckTensorArrayIsInitialized(const string& op_name, Status GetTensorArrayShape(const 
XlaResource* resource, xla::ComputationBuilder* builder, TensorShape* shape) { - TF_RETURN_IF_ERROR(resource->GetShape(builder, shape)); - if (shape->dims() < 1) { - return errors::InvalidArgument("TensorArray rank must be >= 1"); - } + *shape = resource->shape(); + shape->InsertDim(0, resource->tensor_array_size()); return Status::OK(); } @@ -161,8 +160,8 @@ class TensorArrayOp : public XlaOpKernel { // Initializes the TensorArray value if we know the element shape. // Otherwise, defer initialization to the first write. xla::ComputationDataHandle value; + TensorShape shape; if (element_shape_.IsFullyDefined()) { - TensorShape shape; CHECK(element_shape_.AsTensorShape(&shape)); TensorShape ta_shape; ta_shape.AddDim(size); @@ -176,8 +175,8 @@ class TensorArrayOp : public XlaOpKernel { string name = strings::StrCat("TensorArray: ", tensor_array_name_); OP_REQUIRES_OK( ctx, xc.CreateResource(XlaResource::kTensorArray, -1, std::move(name), - dtype_, value, &var)); - var->tensor_array_size = size; + dtype_, shape, value, /*tensor_array_size=*/size, + /*tensor_array_gradients=*/{}, &var)); ctx->SetResourceOutput(0, var); Tensor flow(DT_FLOAT, TensorShape({})); @@ -193,7 +192,8 @@ class TensorArrayOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(TensorArrayOp); }; -REGISTER_XLA_OP(Name("TensorArrayV3"), TensorArrayOp); +REGISTER_XLA_OP(Name("TensorArrayV3").CompileTimeConstInput("size"), + TensorArrayOp); class TensorArrayWriteOp : public XlaOpKernel { public: @@ -213,7 +213,7 @@ class TensorArrayWriteOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, resource, dtype_, elem_shape)); - xla::ComputationDataHandle ta = resource->value; + xla::ComputationDataHandle ta = resource->value(); xla::ComputationDataHandle index = ctx->Input(1); xla::ComputationDataHandle value = ctx->Input(2); xla::ComputationDataHandle flow = ctx->Input(3); @@ -230,7 +230,7 @@ class TensorArrayWriteOp : public XlaOpKernel { xla::ComputationDataHandle written = 
DynamicAddSlice(b, ta, update, slice_shape.dim_sizes(), start_indices); - resource->value = written; + OP_REQUIRES_OK(ctx, resource->SetValue(written)); ctx->SetOutput(0, flow); } @@ -259,7 +259,7 @@ class TensorArrayReadOp : public XlaOpKernel { TensorShape ta_shape; OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape)); - xla::ComputationDataHandle ta = resource->value; + xla::ComputationDataHandle ta = resource->value(); xla::ComputationDataHandle index = ctx->Input(1); // start_indices of the DynamicSlice are [index, 0, 0, ..., 0]. @@ -309,10 +309,39 @@ class TensorArrayGatherOp : public XlaOpKernel { auto indices = ctx->Input(1); DataType index_type = ctx->input_type(1); - xla::ComputationDataHandle ta = resource->value; + xla::ComputationDataHandle ta = resource->value(); + + // Look for the case where the gather takes a simple slice from the + // tensor array (0, 1, 2, 3, 4, ..., N) + std::vector const_indices; + Status status = ctx->ConstantInputAsIntVector(1, &const_indices); + if (status.ok()) { + bool gather_is_dense_slice = true; + for (auto i = 0; i < const_indices.size(); i++) { + if (const_indices[i] != i) { + gather_is_dense_slice = false; + break; + } + } + + if (gather_is_dense_slice) { + std::vector begin(ta_shape.dims(), 0); + std::vector strides(ta_shape.dims(), 1); + std::vector end(ta_shape.dims(), 1); + end[0] = const_indices.size(); + for (auto i = 1; i < ta_shape.dims(); i++) { + end[i] = ta_shape.dim_size(i); + } + ctx->SetOutput(0, b->Slice(ta, begin, end, strides)); + return; + } + } - xla::ComputationDataHandle gather = XlaComputeGatherDynamicSlice( - ctx, ta, ta_shape, indices, indices_shape, 0, dtype_, index_type, b); + xla::ComputationDataHandle gather; + OP_REQUIRES_OK( + ctx, + XlaGather(ta, ta_shape, indices, indices_shape, /*axis=*/0, + /*indices_are_nd=*/false, dtype_, index_type, b, &gather)); ctx->SetOutput(0, gather); } @@ -348,35 +377,54 @@ class TensorArrayScatterOp : public XlaOpKernel { const int num_indices = 
indices_shape.dim_size(0); const xla::ComputationDataHandle indices = ctx->Input(1); - xla::ComputationDataHandle ta = resource->value; + xla::ComputationDataHandle ta = resource->value(); const xla::ComputationDataHandle value = ctx->Input(2); const xla::ComputationDataHandle flow = ctx->Input(3); - auto slice_dims = value_shape.dim_sizes(); - slice_dims[0] = 1LL; - - std::vector value_starts(value_shape.dims(), 0); - auto value_ends = value_shape.dim_sizes(); - - std::vector value_strides(value_shape.dims(), 1); - - // For every (index, value) pair, update the corresponding TensorArray - // storage. - for (int i = 0; i < num_indices; ++i) { - // Slice out part of the value. - value_starts[0] = i; - value_ends[0] = i + 1; - auto slice = b->Slice(value, value_starts, value_ends, value_strides); + // Look for the case where the scatter is for each sub-tensor in order. The + // tensor array implementation allows for this to be a straight addition. + bool scatter_all_elements_in_order = false; + std::vector const_indices; + Status status = ctx->ConstantInputAsIntVector(1, &const_indices); + if (status.ok() && num_indices == value_shape.dim_size(0)) { + scatter_all_elements_in_order = true; + for (auto i = 0; i < num_indices; i++) { + if (const_indices[i] != i) { + scatter_all_elements_in_order = false; + break; + } + } + } - // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. 
- auto index = b->Slice(indices, {i}, {i + 1}, {1}); - auto start_indices = - b->Pad(b->Reshape(index, {1}), b->ConstantR0(0), - xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); - ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices); + if (scatter_all_elements_in_order) { + ta = b->Add(ta, value); + } else { + auto slice_dims = value_shape.dim_sizes(); + slice_dims[0] = 1LL; + + std::vector value_starts(value_shape.dims(), 0); + auto value_ends = value_shape.dim_sizes(); + + std::vector value_strides(value_shape.dims(), 1); + + // For every (index, value) pair, update the corresponding TensorArray + // storage. + for (int i = 0; i < num_indices; ++i) { + // Slice out part of the value. + value_starts[0] = i; + value_ends[0] = i + 1; + auto slice = b->Slice(value, value_starts, value_ends, value_strides); + + // start_indices of the DynamicUpdateSlice are [index, 0, 0, ..., 0]. + auto index = b->Slice(indices, {i}, {i + 1}, {1}); + auto start_indices = + b->Pad(b->Reshape(index, {1}), b->ConstantR0(0), + xla::MakeEdgePaddingConfig({{0, elem_shape.dims()}})); + ta = DynamicAddSlice(b, ta, slice, slice_dims, start_indices); + } } - resource->value = ta; + OP_REQUIRES_OK(ctx, resource->SetValue(ta)); ctx->SetOutput(0, flow); } @@ -405,7 +453,7 @@ class TensorArrayConcatOp : public XlaOpKernel { TensorShape ta_shape; OP_REQUIRES_OK(ctx, GetTensorArrayShape(resource, b, &ta_shape)); - xla::ComputationDataHandle ta = resource->value; + xla::ComputationDataHandle ta = resource->value(); auto ta_dims = ta_shape.dim_sizes(); std::vector shape(ta_dims.begin() + 1, ta_dims.end()); @@ -460,16 +508,17 @@ class TensorArraySplitOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); OP_REQUIRES_OK(ctx, MaybeInitializeTensorArray(b, resource, dtype_, elem_shape)); - xla::ComputationDataHandle ta = resource->value; + xla::ComputationDataHandle ta = resource->value(); TensorShape ta_shape; - ta_shape.AddDim(resource->tensor_array_size); + 
ta_shape.AddDim(resource->tensor_array_size()); ta_shape.AppendShape(elem_shape); - OP_REQUIRES(ctx, lengths.size() == resource->tensor_array_size, - errors::InvalidArgument( - "TensorArray's size is not equal to the size of lengths (", - lengths.size(), " vs. ", resource->tensor_array_size, ")")); + OP_REQUIRES( + ctx, lengths.size() == resource->tensor_array_size(), + errors::InvalidArgument( + "TensorArray's size is not equal to the size of lengths (", + lengths.size(), " vs. ", resource->tensor_array_size(), ")")); const xla::ComputationDataHandle value = ctx->Input(1); const xla::ComputationDataHandle flow = ctx->Input(3); @@ -479,7 +528,8 @@ class TensorArraySplitOp : public XlaOpKernel { value_shape.DebugString(), " vs. ", ta_shape.DebugString())); - resource->value = b->Add(ta, b->Reshape(value, ta_shape.dim_sizes())); + OP_REQUIRES_OK(ctx, resource->SetValue(b->Add( + ta, b->Reshape(value, ta_shape.dim_sizes())))); ctx->SetOutput(0, flow); } @@ -490,7 +540,8 @@ class TensorArraySplitOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(TensorArraySplitOp); }; -REGISTER_XLA_OP(Name("TensorArraySplitV3"), TensorArraySplitOp); +REGISTER_XLA_OP(Name("TensorArraySplitV3").CompileTimeConstInput("lengths"), + TensorArraySplitOp); class TensorArraySizeOp : public XlaOpKernel { public: @@ -500,7 +551,8 @@ class TensorArraySizeOp : public XlaOpKernel { XlaResource* var; OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &var)); Tensor size_tensor(DT_INT32, {}); - size_tensor.scalar()() = static_cast(var->tensor_array_size); + size_tensor.scalar()() = + static_cast(var->tensor_array_size()); ctx->SetConstantOutput(0, size_tensor); } @@ -523,7 +575,7 @@ class TensorArrayGradOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); OP_REQUIRES_OK( - ctx, CheckTensorArrayIsInitialized(name(), resource, resource->type)); + ctx, CheckTensorArrayIsInitialized(name(), resource, resource->type())); TensorShape ta_shape; OP_REQUIRES_OK(ctx, 
GetTensorArrayShape(resource, b, &ta_shape)); diff --git a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc index 9ee6bd892504e683a191484fb09259619759f36d..9aefcd4fc7f94a1dba1c56273c55d0b98fbbfaf2 100644 --- a/tensorflow/compiler/tf2xla/kernels/tile_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tile_ops.cc @@ -122,7 +122,7 @@ class TileOp : public XlaOpKernel { TF_DISALLOW_COPY_AND_ASSIGN(TileOp); }; -REGISTER_XLA_OP(Name("Tile"), TileOp); +REGISTER_XLA_OP(Name("Tile").CompileTimeConstInput("multiples"), TileOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc index 5534d1bfa1338c7fe3647cd6aa281c4907dfdf8c..f750f7003be288461f5f10455e58932d1b4e4524 100644 --- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -32,9 +32,24 @@ class ResourceApplyGradientDescent : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::ComputationDataHandle handle; xla::ComputationBuilder* b = ctx->builder(); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle)); + DataType type = ctx->input_type(1); + TensorShape var_shape; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &handle)); + + TensorShape alpha_shape = ctx->InputShape(1); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(alpha_shape), + errors::InvalidArgument("alpha is not a scalar: ", + alpha_shape.DebugString())); + + TensorShape delta_shape = ctx->InputShape(2); + OP_REQUIRES( + ctx, var_shape.IsSameSize(delta_shape), + errors::InvalidArgument("var and delta do not have the same shape: ", + var_shape.DebugString(), " vs ", + delta_shape.DebugString())); + handle = b->Sub(handle, b->Mul(ctx->Input(1), ctx->Input(2))); - OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, ctx->input_type(1), handle)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle)); } }; 
REGISTER_XLA_OP( @@ -52,18 +67,10 @@ class ResourceApplyMomentum : public XlaOpKernel { DataType type = ctx->input_type(2); - DataType var_type, accum_type; TensorShape var_shape, accum_shape; - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape)); - OP_REQUIRES_OK(ctx, - ctx->GetVariableTypeAndShape(1, &accum_type, &accum_shape)); - - OP_REQUIRES( - ctx, type == var_type && type == accum_type, - errors::InvalidArgument( - "Types of variable arguments to ResourceApplyMomentum must match: ", - DataTypeString(type), " vs. ", DataTypeString(var_type), " and ", - DataTypeString(accum_type))); + xla::ComputationDataHandle var, accum; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum)); OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape), errors::InvalidArgument( @@ -86,10 +93,6 @@ class ResourceApplyMomentum : public XlaOpKernel { errors::InvalidArgument("momentum is not a scalar: ", momentum_shape.DebugString())); - xla::ComputationDataHandle var, accum; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var)); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &accum)); - xla::ComputationDataHandle lr = ctx->Input(2); xla::ComputationDataHandle grad = ctx->Input(3); xla::ComputationDataHandle momentum = ctx->Input(4); @@ -122,18 +125,10 @@ class ResourceApplyAdagrad : public XlaOpKernel { DataType type = ctx->input_type(2); - DataType var_type, accum_type; TensorShape var_shape, accum_shape; - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape)); - OP_REQUIRES_OK(ctx, - ctx->GetVariableTypeAndShape(1, &accum_type, &accum_shape)); - - OP_REQUIRES( - ctx, type == var_type && type == accum_type, - errors::InvalidArgument( - "Types of variable arguments to ResourceApplyAdagrad must match: ", - DataTypeString(type), " vs. 
", DataTypeString(var_type), " and ", - DataTypeString(accum_type))); + xla::ComputationDataHandle var, accum; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &accum_shape, &accum)); OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape), errors::InvalidArgument( @@ -151,9 +146,6 @@ class ResourceApplyAdagrad : public XlaOpKernel { "var and grad do not have the same shape", var_shape.DebugString(), " ", grad_shape.DebugString())); - xla::ComputationDataHandle var, accum; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var)); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &accum)); xla::ComputationDataHandle lr = ctx->Input(2); xla::ComputationDataHandle grad = ctx->Input(3); @@ -175,18 +167,11 @@ class ResourceApplyAdam : public XlaOpKernel { } void Compile(XlaOpKernelContext* ctx) override { - DataType var_type, m_type, v_type; TensorShape var_shape, m_shape, v_shape; - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape)); - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(1, &m_type, &m_shape)); - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(2, &v_type, &v_shape)); - - OP_REQUIRES( - ctx, dtype_ == var_type && dtype_ == m_type && dtype_ == v_type, - errors::InvalidArgument( - "Types of variable arguments to ResourceApplyRMSProp must match: ", - DataTypeString(dtype_), " vs. ", DataTypeString(var_type), " vs. ", - DataTypeString(m_type), " vs. 
", DataTypeString(v_type))); + xla::ComputationDataHandle var, m, v; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype_, &var_shape, &var)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype_, &m_shape, &m)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype_, &v_shape, &v)); TensorShape beta1_power_shape = ctx->InputShape(3); TensorShape beta2_power_shape = ctx->InputShape(4); @@ -228,10 +213,6 @@ class ResourceApplyAdam : public XlaOpKernel { "var and grad do not have the same shape", var_shape.DebugString(), " ", grad_shape.DebugString())); - xla::ComputationDataHandle var, m, v; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var)); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &m)); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, &v)); xla::ComputationDataHandle beta1_power = ctx->Input(3); xla::ComputationDataHandle beta2_power = ctx->Input(4); xla::ComputationDataHandle lr = ctx->Input(5); @@ -278,18 +259,11 @@ class ResourceApplyRMSProp : public XlaOpKernel { DataType type = ctx->input_type(3); - DataType var_type, ms_type, mom_type; TensorShape var_shape, ms_shape, mom_shape; - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape)); - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(1, &ms_type, &ms_shape)); - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(2, &mom_type, &mom_shape)); - - OP_REQUIRES( - ctx, type == var_type && type == ms_type && type == mom_type, - errors::InvalidArgument( - "Types of variable arguments to ResourceApplyRMSProp must match: ", - DataTypeString(type), " vs. ", DataTypeString(var_type), " vs. ", - DataTypeString(ms_type), " vs. 
", DataTypeString(mom_type))); + xla::ComputationDataHandle var, ms, mom; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &var_shape, &var)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, type, &ms_shape, &ms)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, type, &mom_shape, &mom)); TensorShape lr_shape = ctx->InputShape(3); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape), @@ -323,10 +297,6 @@ class ResourceApplyRMSProp : public XlaOpKernel { "var and grad do not have the same shape", var_shape.DebugString(), " ", grad_shape.DebugString())); - xla::ComputationDataHandle var, ms, mom; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var)); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &ms)); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, &mom)); xla::ComputationDataHandle lr = ctx->Input(3); xla::ComputationDataHandle rho = ctx->Input(4); xla::ComputationDataHandle momentum = ctx->Input(5); @@ -373,20 +343,11 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype, bool has_l2_shrinkage) { xla::ComputationBuilder* b = ctx->builder(); - DataType var_type, accum_type, linear_type; TensorShape var_shape, accum_shape, linear_shape; - OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape)); - OP_REQUIRES_OK(ctx, - ctx->GetVariableTypeAndShape(1, &accum_type, &accum_shape)); - OP_REQUIRES_OK(ctx, - ctx->GetVariableTypeAndShape(2, &linear_type, &linear_shape)); - - OP_REQUIRES( - ctx, dtype == var_type && dtype == accum_type && dtype == linear_type, - errors::InvalidArgument( - "Types of variable arguments to ResourceApplyFtrlV2 must match: ", - DataTypeString(dtype), " vs. 
", DataTypeString(var_type), " and ", - DataTypeString(accum_type), " and ", DataTypeString(linear_type))); + xla::ComputationDataHandle var, accum, linear; + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, dtype, &var_shape, &var)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, dtype, &accum_shape, &accum)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, dtype, &linear_shape, &linear)); OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape), errors::InvalidArgument( @@ -438,10 +399,6 @@ void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype, errors::InvalidArgument("lr_power is not a scalar: ", lr_power_shape.DebugString())); - xla::ComputationDataHandle var, accum, linear; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var)); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &accum)); - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(2, &linear)); xla::ComputationDataHandle grad = ctx->Input(3); xla::ComputationDataHandle lr = ctx->Input(4); xla::ComputationDataHandle l1 = ctx->Input(5); diff --git a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc index 2fc5d40d1059b868eef0a632071e7cccdecaf9f4..c167642174b328a968d7f7ce1f0ad6e0ab8a7a68 100644 --- a/tensorflow/compiler/tf2xla/kernels/transpose_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/transpose_op.cc @@ -54,7 +54,8 @@ class TransposeOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->ConstantInputReshaped(1, {dims}, &literal)); std::vector perm(dims); - std::copy(literal.s32s().begin(), literal.s32s().end(), perm.begin()); + std::copy(literal.data().begin(), literal.data().end(), + perm.begin()); std::vector transposed_order; // Check whether permutation is a permutation of integers of [0 .. dims). 
@@ -72,8 +73,9 @@ class TransposeOp : public XlaOpKernel { } } for (int i = 0; i < dims; ++i) { - OP_REQUIRES(ctx, bits[i], errors::InvalidArgument( - i, " is missing from 'perm' argument.")); + OP_REQUIRES( + ctx, bits[i], + errors::InvalidArgument(i, " is missing from 'perm' argument.")); } // 0-D, 1-D, and identity transposes do nothing. @@ -87,7 +89,7 @@ class TransposeOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("Transpose"), TransposeOp); +REGISTER_XLA_OP(Name("Transpose").CompileTimeConstInput("perm"), TransposeOp); // InvertPermutation frequently forms part of the gradient of Transpose. // @@ -103,8 +105,9 @@ class InvertPermutationOp : public XlaOpKernel { explicit InvertPermutationOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - OP_REQUIRES(ctx, FastBoundsCheck(ctx->InputShape(0).num_elements(), - std::numeric_limits::max()), + OP_REQUIRES(ctx, + FastBoundsCheck(ctx->InputShape(0).num_elements(), + std::numeric_limits::max()), errors::InvalidArgument("permutation of nonnegative int32s " "must have <= int32 max elements")); @@ -128,7 +131,9 @@ class InvertPermutationOp : public XlaOpKernel { } }; -REGISTER_XLA_OP(Name("InvertPermutation").TypeConstraint("T", DT_INT32), +REGISTER_XLA_OP(Name("InvertPermutation") + .TypeConstraint("T", DT_INT32) + .CompileTimeConstInput("x"), InvertPermutationOp); } // namespace diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index a266e9013c41b88788dbc99849f01c09f3d61348..0c5ad9e5255ffc3dfcfb83335060ae833937b3ce 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -50,18 +50,41 @@ XLAJIT_MAKE_UNARY(Conj, b->Conj(x)); // Return x if x>0, otherwise -x. 
XLAJIT_MAKE_UNARY(Abs, b->Abs(x)); +// acos(x) = 2 * atan(sqrt(1 - x^2) / (1 + x)) +XLAJIT_MAKE_UNARY( + Acos, + b->Mul(XlaHelpers::FloatLiteral(b, input_type(0), 2.0), + b->Atan2(b->Pow(b->Sub(XlaHelpers::One(b, input_type(0)), + b->Mul(x, x)), + XlaHelpers::FloatLiteral(b, input_type(0), 0.5)), + b->Add(XlaHelpers::One(b, input_type(0)), x)))); + // acosh(x) = log(x + sqrt(x^2 - 1)) XLAJIT_MAKE_UNARY( Acosh, b->Log(b->Add(x, b->Pow(b->Sub(b->Mul(x, x), XlaHelpers::One(b, input_type(0))), XlaHelpers::FloatLiteral(b, input_type(0), 0.5))))); + +// asin(x) = 2 * atan(x / (1 + sqrt(1 - x^2))) +XLAJIT_MAKE_UNARY( + Asin, + b->Mul(XlaHelpers::FloatLiteral(b, input_type(0), 2.0), + b->Atan2(x, b->Add(XlaHelpers::One(b, input_type(0)), + b->Pow(b->Sub(XlaHelpers::One(b, input_type(0)), + b->Mul(x, x)), + XlaHelpers::FloatLiteral(b, input_type(0), + 0.5)))))); + // asinh(x) = log(x + sqrt(x^2 + 1)) XLAJIT_MAKE_UNARY( Asinh, b->Log(b->Add(x, b->Pow(b->Add(b->Mul(x, x), XlaHelpers::One(b, input_type(0))), XlaHelpers::FloatLiteral(b, input_type(0), 0.5))))); + +XLAJIT_MAKE_UNARY(Atan, b->Atan2(x, XlaHelpers::One(b, input_type(0)))); + // atanh(x) = 0.5 * log((1 + x) / (1 - x)) XLAJIT_MAKE_UNARY( Atanh, b->Mul(b->Log(b->Div(b->Add(XlaHelpers::One(b, input_type(0)), x), diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index b19ea22f50d2dd44e8d1d81f5930263f364030e1..71173f5aead47702f0ed9e95b827a6fefd9b7efd 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" +#include "tensorflow/compiler/tf2xla/kernels/shape_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" @@ -22,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/no_op.h" namespace tensorflow { @@ -31,21 +33,29 @@ class VarIsInitializedOp : public XlaOpKernel { public: explicit VarIsInitializedOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - xla::ComputationDataHandle handle; - bool initialized = ctx->ReadVariableInput(0, &handle).ok(); - ctx->SetOutput(0, ctx->builder()->ConstantR0(initialized)); + XlaResource* variable; + OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &variable)); + ctx->SetOutput(0, + ctx->builder()->ConstantR0(variable->initialized())); } }; REGISTER_XLA_OP(Name("VarIsInitializedOp"), VarIsInitializedOp); class ReadVariableOp : public XlaOpKernel { public: - explicit ReadVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + explicit ReadVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); + } + void Compile(XlaOpKernelContext* ctx) override { xla::ComputationDataHandle handle; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle)); + OP_REQUIRES_OK( + ctx, ctx->ReadVariableInput(0, dtype_, /*shape=*/nullptr, &handle)); ctx->SetOutput(0, handle); } + + private: + DataType dtype_; }; REGISTER_XLA_OP(Name("ReadVariableOp"), ReadVariableOp); @@ -63,10 +73,12 @@ class AssignAddVariableOp : public XlaOpKernel { public: explicit AssignAddVariableOp(OpKernelConstruction* 
ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { + DataType type = ctx->input_type(1); xla::ComputationDataHandle handle; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle)); + OP_REQUIRES_OK(ctx, + ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle)); handle = ctx->builder()->Add(handle, ctx->Input(1)); - OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, ctx->input_type(1), handle)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle)); } }; REGISTER_XLA_OP( @@ -77,10 +89,12 @@ class AssignSubVariableOp : public XlaOpKernel { public: explicit AssignSubVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { + DataType type = ctx->input_type(1); xla::ComputationDataHandle handle; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle)); + OP_REQUIRES_OK(ctx, + ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle)); handle = ctx->builder()->Sub(handle, ctx->Input(1)); - OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, ctx->input_type(1), handle)); + OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle)); } }; REGISTER_XLA_OP( @@ -93,33 +107,47 @@ class ResourceGatherOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { xla::ComputationBuilder* builder = ctx->builder(); - // Get the shape of the resource tensor. 
- TensorShape resource_shape; - DataType resource_dtype; - OP_REQUIRES_OK( - ctx, ctx->GetVariableTypeAndShape(0, &resource_dtype, &resource_shape)); - - DataType expected_output_dtype = ctx->expected_output_dtype(0); - OP_REQUIRES(ctx, resource_dtype == expected_output_dtype, - errors::InvalidArgument( - "Variable dtype is ", DataTypeString(resource_dtype), - " but expected output dtype is ", - DataTypeString(expected_output_dtype), ".")); + DataType type = ctx->expected_output_dtype(0); + TensorShape resource_shape; xla::ComputationDataHandle resource_handle; - OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &resource_handle)); + OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &resource_shape, + &resource_handle)); auto indices = ctx->Input(1); auto indices_shape = ctx->InputShape(1); DataType index_type = ctx->input_type(1); - xla::ComputationDataHandle gather = XlaComputeGatherDynamicSlice( - ctx, resource_handle, resource_shape, indices, indices_shape, 0, - resource_dtype, index_type, builder); + xla::ComputationDataHandle gather; + OP_REQUIRES_OK( + ctx, XlaGather(resource_handle, resource_shape, indices, indices_shape, + /*axis=*/0, /*indices_are_nd=*/false, type, index_type, + builder, &gather)); ctx->SetOutput(0, gather); } }; REGISTER_XLA_OP(Name("ResourceGather").TypeConstraint("dtype", kNumericTypes), ResourceGatherOp); +class VariableShapeOp : public XlaOpKernel { + public: + explicit VariableShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + DataType variable_dtype; + TensorShape shape; + OP_REQUIRES_OK(ctx, + ctx->GetVariableTypeAndShape(0, &variable_dtype, &shape)); + Tensor shape_constant(out_dtype_, TensorShape({shape.dims()})); + OP_REQUIRES_OK(ctx, TensorShapeToConstant(shape, &shape_constant)); + ctx->SetConstantOutput(0, shape_constant); + } + + private: + DataType out_dtype_; +}; + 
+REGISTER_XLA_OP(Name("VariableShape"), VariableShapeOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index ead26478ff2a3a1302e95e4ee5dbbf366b04efc6..0ff1b65ae9179d506e453f98097cd88083eb2be7 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -39,7 +39,7 @@ Status MakeXlaCompilerArgumentsFromInputs( *has_uninitialized_vars = false; *has_tensor_arrays = false; for (int i = 0; i < ctx->num_inputs(); ++i) { - VLOG(2) << " Input " << i + VLOG(2) << " Input " << i << " type: " << DataTypeString(ctx->input_type(i)) << " shape: " << ctx->InputShape(i).DebugString(); XlaCompiler::Argument& arg = (*args)[i]; @@ -50,34 +50,32 @@ Status MakeXlaCompilerArgumentsFromInputs( XlaResource* resource; TF_RETURN_IF_ERROR(ctx->GetResourceInput(i, &resource)); - arg.initialized = resource->value.handle() > 0; + arg.initialized = resource->initialized(); arg.kind = XlaCompiler::Argument::kResource; - arg.resource_kind = resource->kind; + arg.resource_kind = resource->kind(); if (arg.resource_kind == XlaResource::kTensorArray) { *has_tensor_arrays = true; } - arg.type = resource->type; - if (arg.initialized) { - TF_RETURN_IF_ERROR(resource->PackedShape(ctx->builder(), &arg.shape)); - } else { + arg.type = resource->type(); + arg.shape = resource->shape(); + if (!arg.initialized) { *has_uninitialized_vars = true; } - arg.tensor_array_size = resource->tensor_array_size; - for (const auto& gradient : resource->tensor_array_gradients) { + arg.tensor_array_size = resource->tensor_array_size(); + for (const auto& gradient : resource->tensor_array_gradients()) { arg.tensor_array_gradients.insert(gradient.first); } - arg.name = resource->name; - VLOG(2) << " resource " << resource->name + arg.name = resource->name(); + VLOG(2) << " resource " << resource->name() << " type: " << DataTypeString(arg.type) - << " shape: " << 
xla::ShapeUtil::HumanString(arg.shape) + << " shape: " << arg.shape.DebugString() << " initialized: " << arg.initialized; } else { arg.kind = XlaCompiler::Argument::kParameter; arg.type = ctx->input_type(i); - TF_RETURN_IF_ERROR( - TensorShapeToXLAShape(arg.type, ctx->InputShape(i), &arg.shape)); + arg.shape = ctx->InputShape(i); } } return Status::OK(); @@ -120,6 +118,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { body_options.use_tuple_arg = true; body_options.return_updated_values_for_all_resources = true; body_options.resolve_compile_time_constants = false; + body_options.is_entry_computation = false; XlaCompiler::CompilationResult body; OP_REQUIRES_OK(ctx, compiler->CompileFunction(body_options, body_name_attr_, arguments, &body)); @@ -153,22 +152,20 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { XlaCompiler::Argument& arg = arguments[update.input_index]; if (!arg.initialized) { VLOG(2) << "Update shape for argument " << update.input_index << " " - << xla::ShapeUtil::HumanString(update.shape); + << update.shape.DebugString(); arg.initialized = true; - xla::Shape shape = update.shape; - if (!update.tensor_array_gradients_accessed.empty()) { - shape = xla::ShapeUtil::GetTupleElementShape(shape, 0); - } - std::unique_ptr zero = - xla::Literal::CreateFromShape(shape); - resource->value = builder->ConstantLiteral(*zero); + arg.shape = update.shape; + OP_REQUIRES_OK(ctx, + resource->SetTypeAndShape(update.type, update.shape)); + + OP_REQUIRES_OK(ctx, resource->SetZeroValue(builder)); } // Add any TensorArray gradients touched by the body to the enclosing // graph. 
for (const string& grad_source : update.tensor_array_gradients_accessed) { - VLOG(4) << "TensorArray " << resource->name << " accessed gradient " + VLOG(4) << "TensorArray " << resource->name() << " accessed gradient " << grad_source; XlaResource* gradient; OP_REQUIRES_OK(ctx, resource->GetOrCreateTensorArrayGradient( @@ -177,12 +174,9 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { // Add all of the TensorArray gradients to the argument. For simplicity, // we always pass all known gradients. - for (const auto& gradient : resource->tensor_array_gradients) { + for (const auto& gradient : resource->tensor_array_gradients()) { arg.tensor_array_gradients.insert(gradient.first); } - - // Recompute the argument shape. - OP_REQUIRES_OK(ctx, resource->PackedShape(ctx->builder(), &arg.shape)); } // Recompile the body with the "correct" resource shapes. VLOG(1) << "Recompiling body with corrected resource shapes"; @@ -196,14 +190,21 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { XlaCompiler::CompileOptions cond_options; cond_options.use_tuple_arg = true; cond_options.resolve_compile_time_constants = false; + cond_options.is_entry_computation = false; XlaCompiler::CompilationResult cond; OP_REQUIRES_OK(ctx, compiler->CompileFunction(cond_options, cond_name_attr_, arguments, &cond)); - xla::Shape body_input_shape = - xla::ShapeUtil::MakeTupleShape(body.xla_input_shapes); - xla::Shape cond_input_shape = - xla::ShapeUtil::MakeTupleShape(cond.xla_input_shapes); + OP_REQUIRES(ctx, body.xla_input_shapes.size() == 1, + errors::FailedPrecondition("Expected one input shape")); + xla::Shape body_input_shape = body.xla_input_shapes[0]; + OP_REQUIRES(ctx, xla::ShapeUtil::IsTuple(body_input_shape), + errors::FailedPrecondition("Expected tuple shape")); + OP_REQUIRES(ctx, cond.xla_input_shapes.size() == 1, + errors::FailedPrecondition("Expected one input shape")); + xla::Shape cond_input_shape = cond.xla_input_shapes[0]; + OP_REQUIRES(ctx, 
xla::ShapeUtil::IsTuple(cond_input_shape), + errors::FailedPrecondition("Expected tuple shape")); VLOG(2) << "Body shape: " << xla::ShapeUtil::HumanString(body_input_shape) << " -> " << xla::ShapeUtil::HumanString(body.xla_output_shape); @@ -286,9 +287,9 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { builder->GetTupleElement(while_result, pos), builder)); } VLOG(2) << "Loop-carried variable: pos: " << update.input_index - << " name: " << resource->name << " modified: " << update.modified + << " name: " << resource->name() << " modified: " << update.modified << " type: " << DataTypeString(update.type) - << " shape: " << xla::ShapeUtil::HumanString(update.shape); + << " shape: " << update.shape.DebugString(); // Copies the identity of the resource variable from input to output // unchanged, even if the variable was not modified. ctx->op_kernel_context()->set_output( diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index 21ad21f73737a289390ed1ea767db1078d05b466..488fda74bf7b5c1d66f8d706a1be3cc1fc29a492 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -49,6 +49,25 @@ cc_library( ], ) +cc_library( + name = "scatter", + srcs = ["scatter.cc"], + hdrs = ["scatter.h"], + deps = [ + ":util", + ":while_loop", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/core:lib", + ], +) + cc_library( name = "triangular_solve", srcs = ["triangular_solve.cc"], @@ -60,6 +79,8 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + 
"//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/core:lib", @@ -105,6 +126,21 @@ cc_library( ], ) +cc_library( + name = "while_loop", + srcs = ["while_loop.cc"], + hdrs = ["while_loop.h"], + deps = [ + ":util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/core:lib", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.cc b/tensorflow/compiler/tf2xla/lib/batch_dot.cc index 28a5e6a58bb312f4c4821bcce484a08160009d56..798f0fa78055e800038e8bf41b4f410b670be7dd 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.cc +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.cc @@ -25,12 +25,10 @@ limitations under the License. namespace tensorflow { -// The current implementation simply unrolls the computation along the batch -// dimension. -// TODO(andydavis): add batching support to XLA's Dot operator. xla::StatusOr BatchDot( xla::ComputationBuilder* builder, xla::ComputationDataHandle x, - xla::ComputationDataHandle y, bool transpose_x, bool transpose_y) { + xla::ComputationDataHandle y, bool transpose_x, bool transpose_y, + bool conjugate_x, bool conjugate_y) { TF_ASSIGN_OR_RETURN(std::unique_ptr x_shape, builder->GetShape(x)); TF_ASSIGN_OR_RETURN(std::unique_ptr y_shape, @@ -52,26 +50,20 @@ xla::StatusOr BatchDot( // The batch dimensions must be equal and the matrix dimensions must be // valid. 
- std::vector dimensions; - int64 batch_count = 1; + std::vector batch_dimension_numbers; for (int i = 0; i < ndims - 2; ++i) { - int64 x_size = x_shape->dimensions(i); - int64 y_size = y_shape->dimensions(i); - if (x_size != y_size) { + if (x_shape->dimensions(i) != y_shape->dimensions(i)) { return errors::InvalidArgument( "Dimension ", i, " of inputs to BatchedDot must be equal: ", xla::ShapeUtil::HumanString(*x_shape), " vs ", xla::ShapeUtil::HumanString(*y_shape)); } - dimensions.push_back(x_size); - batch_count *= x_size; + batch_dimension_numbers.push_back(i); } int x_inner_dim = transpose_x ? (ndims - 2) : (ndims - 1); int y_inner_dim = transpose_y ? (ndims - 1) : (ndims - 2); - int64 x_inner_dim_size = x_shape->dimensions(x_inner_dim); - int64 y_inner_dim_size = y_shape->dimensions(y_inner_dim); - if (x_inner_dim_size != y_inner_dim_size) { + if (x_shape->dimensions(x_inner_dim) != y_shape->dimensions(y_inner_dim)) { return errors::InvalidArgument( "Dimensions ", x_inner_dim, " and ", y_inner_dim, " of arguments to BatchedDot must be equal: ", @@ -80,75 +72,46 @@ xla::StatusOr BatchDot( " transpose: ", transpose_y); } - // If there are no batch dimensions, use a regular Dot. This case exists - // to improve the readability of the emitted graphs. - if (dimensions.empty()) { - auto lhs = transpose_x ? builder->Transpose(x, {1, 0}) : x; - auto rhs = transpose_y ? builder->Transpose(y, {1, 0}) : y; - return builder->Dot(lhs, rhs); + // Check for zero lhs/rhs dim size. + if (xla::ShapeUtil::HasZeroElements(*x_shape) || + xla::ShapeUtil::HasZeroElements(*y_shape)) { + std::vector dimensions(batch_dimension_numbers.size()); + for (int i = 0; i < batch_dimension_numbers.size(); ++i) { + dimensions[i] = x_shape->dimensions(batch_dimension_numbers[i]); + } + int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2); + int y_outer_dim = transpose_y ? 
(ndims - 2) : (ndims - 1); + dimensions.push_back(x_shape->dimensions(x_outer_dim)); + dimensions.push_back(y_shape->dimensions(y_outer_dim)); + return builder->Broadcast( + builder->ConstantLiteral(xla::Literal::Zero(x_shape->element_type())), + dimensions); } - int x_outer_dim = transpose_x ? (ndims - 1) : (ndims - 2); - int y_outer_dim = transpose_y ? (ndims - 2) : (ndims - 1); - dimensions.push_back(x_shape->dimensions(x_outer_dim)); - dimensions.push_back(y_shape->dimensions(y_outer_dim)); - - if (x_shape->element_type() == xla::C64 && transpose_x) { + if (x_shape->element_type() == xla::C64 && conjugate_x) { x = builder->Conj(x); } - if (y_shape->element_type() == xla::C64 && transpose_y) { + if (y_shape->element_type() == xla::C64 && conjugate_y) { y = builder->Conj(y); } - // Reshape input tensors into 3D tensors by flattening the batch - // dimensions. This makes it easier to unroll the batch dimension. - auto x_flat = - builder->Reshape(x, {batch_count, x_shape->dimensions(ndims - 2), - x_shape->dimensions(ndims - 1)}); - auto y_flat = - builder->Reshape(y, {batch_count, y_shape->dimensions(ndims - 2), - y_shape->dimensions(ndims - 1)}); - - // Slice batches into individual matrices and multiply them. - std::vector out_slices; - for (int64 i = 0; i < batch_count; ++i) { - // Slice off individual matrices and reshape to 2D tensors. - auto x_slice = builder->Slice( - x_flat, {i, 0, 0}, - {i + 1, x_shape->dimensions(ndims - 2), x_shape->dimensions(ndims - 1)}, - {1, 1, 1}); - x_slice = builder->Reshape(x_slice, {x_shape->dimensions(ndims - 2), - x_shape->dimensions(ndims - 1)}); - auto y_slice = builder->Slice( - y_flat, {i, 0, 0}, - {i + 1, y_shape->dimensions(ndims - 2), y_shape->dimensions(ndims - 1)}, - {1, 1, 1}); - y_slice = builder->Reshape(y_slice, {y_shape->dimensions(ndims - 2), - y_shape->dimensions(ndims - 1)}); - - // Transpose if needed. - auto lhs = transpose_x ? builder->Transpose(x_slice, {1, 0}) : x_slice; - auto rhs = transpose_y ? 
builder->Transpose(y_slice, {1, 0}) : y_slice; - - // Multiply matrices and add an outer singleton dimension to the output - // so we can concatenate along the flattened batch dimension later. - auto out = builder->Dot(lhs, rhs); - out = builder->Reshape(out, - {1, dimensions[ndims - 2], dimensions[ndims - 1]}); - out_slices.push_back(out); + // If there are no batch dimensions, use a regular Dot. + // TODO(b/69062148) Remove this code when Dot emitters can be passed + // dimensions to transpose directly (i.e. without requiring a Transpose HLO). + if (batch_dimension_numbers.empty()) { + auto lhs = transpose_x ? builder->Transpose(x, {1, 0}) : x; + auto rhs = transpose_y ? builder->Transpose(y, {1, 0}) : y; + return builder->Dot(lhs, rhs); } - // Concatenate output slices and reshape to original number of dimensions. - xla::ComputationDataHandle data; - if (out_slices.empty()) { - // It is illegal to pass an empty list to ConcatInDim. - // The batch count is empty, so both inputs must have zero elements. - // Arbitrarily use the left input as the argument to Reshape(). 
- data = x; - } else { - data = builder->ConcatInDim(out_slices, 0); + xla::DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(x_inner_dim); + dot_dnums.add_rhs_contracting_dimensions(y_inner_dim); + for (auto batch_dimension_number : batch_dimension_numbers) { + dot_dnums.add_lhs_batch_dimensions(batch_dimension_number); + dot_dnums.add_rhs_batch_dimensions(batch_dimension_number); } - return builder->Reshape(data, dimensions); + return builder->DotGeneral(x, y, dot_dnums); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/batch_dot.h b/tensorflow/compiler/tf2xla/lib/batch_dot.h index b46bc7417d29dc5b7e9649ac28cc78b57d4b619c..b230e885f10f45a78cdd6e455da3ba55ce589b96 100644 --- a/tensorflow/compiler/tf2xla/lib/batch_dot.h +++ b/tensorflow/compiler/tf2xla/lib/batch_dot.h @@ -27,7 +27,10 @@ namespace tensorflow { // viewed as an element of a batch), and arranges the individual results // in a single output tensor of the same batch size. Each of the // individual slices can optionally be transposed before multiplication by -// setting the `transpose_x` or `transpose_y` flag to `true`. +// setting the `transpose_x` or `transpose_y` flag to `true`. Similarly, each +// can be elementwise-complex-conjugated by setting the `conjugate_x` or +// `conjugate_y` flag to `true`. To apply a Hermitian adjoint to `x`, set both +// `transpose_x` and `conjugate_x` to `true`, and analogously for `y`. // // The input tensors `x` and `y` are 2-D or higher with shape `[..., r_x, c_x]` // and `[..., r_y, c_y]`. @@ -40,11 +43,10 @@ namespace tensorflow { // It is computed as: // // output[..., :, :] = matrix(x[..., :, :]) * matrix(y[..., :, :]) -// TODO(phawkins): add an option to take the complex conjugate of the LHS or -// RHS. 
xla::StatusOr BatchDot( xla::ComputationBuilder* builder, xla::ComputationDataHandle x, - xla::ComputationDataHandle y, bool transpose_x, bool transpose_y); + xla::ComputationDataHandle y, bool transpose_x, bool transpose_y, + bool conjugate_x = false, bool conjugate_y = false); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/cholesky.cc b/tensorflow/compiler/tf2xla/lib/cholesky.cc index b3cc489adf6042acb3f56b3a0a6c8fbe43bde629..e795701181dd80a2ff544743d513bffd52fd2399 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.cc +++ b/tensorflow/compiler/tf2xla/lib/cholesky.cc @@ -71,11 +71,14 @@ xla::StatusOr CholeskyUnblocked( SliceInMinorDims(builder, l, {j + 1, 0}, {n, j})); TF_ASSIGN_OR_RETURN(auto r_squared, BatchDot(builder, r, r, /*transpose_x=*/false, - /*transpose_y=*/true)); + /*transpose_y=*/true, /*conjugate_x=*/false, + /*conjugate_y=*/false)); new_d_squared = builder->Sub(new_d_squared, r_squared); TF_ASSIGN_OR_RETURN(br, BatchDot(builder, b, r, /*transpose_x=*/false, - /*transpose_y=*/true)); + /*transpose_y=*/true, + /*conjugate_x=*/false, + /*conjugate_y=*/false)); } auto new_d_inv = builder->Pow( new_d_squared, FloatLiteral(builder, shape->element_type(), -0.5)); @@ -134,7 +137,8 @@ xla::StatusOr Cholesky( SliceInMinorDims(builder, l, {i, 0}, {i + k, i})); TF_ASSIGN_OR_RETURN(auto delta, BatchDot(builder, lhs, rhs, /*transpose_x=*/false, - /*transpose_y=*/true)); + /*transpose_y=*/true, /*conjugate_x=*/false, + /*conjugate_y=*/false)); TF_ASSIGN_OR_RETURN(auto before, SliceInMinorDims(builder, a, {i, i}, {n, i + k})); TF_ASSIGN_OR_RETURN( @@ -155,6 +159,10 @@ xla::StatusOr Cholesky( SliceInMinorDims(builder, a, {i + k, i}, {n, i + k})); TF_ASSIGN_OR_RETURN(auto update, TriangularSolve(builder, factorized, panel, + /*left_side=*/false, + /*lower=*/true, + /*transpose_a=*/true, + /*conjugate_a=*/false, /*block_size=*/8)); TF_ASSIGN_OR_RETURN( l, UpdateSliceInMinorDims(builder, l, update, {i + k, i})); diff --git 
a/tensorflow/compiler/tf2xla/lib/cholesky.h b/tensorflow/compiler/tf2xla/lib/cholesky.h index 2bead7359baaf3582c1230adf0cd4a90046859d2..e083a383be4be0d1b556b63214fe5f70323b4149 100644 --- a/tensorflow/compiler/tf2xla/lib/cholesky.h +++ b/tensorflow/compiler/tf2xla/lib/cholesky.h @@ -29,6 +29,7 @@ namespace tensorflow { // the block size to use. // TODO(phawkins): check for negative values on the diagonal and return an // error, instead of silently yielding NaNs. +// TODO(mattjj): handle the complex Hermitian case xla::StatusOr Cholesky( xla::ComputationBuilder* builder, xla::ComputationDataHandle a, int64 block_size = 256); diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc new file mode 100644 index 0000000000000000000000000000000000000000..6009243f9774eea24e8049e2bd50fe32f291132f --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/scatter.cc @@ -0,0 +1,189 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/scatter.h" + +#include +#include + +#include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/tf2xla/lib/while_loop.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace tensorflow { + +xla::StatusOr XlaScatter( + const xla::ComputationDataHandle& buffer, + const xla::ComputationDataHandle& updates, + const xla::ComputationDataHandle& indices, bool indices_are_vectors, + const std::function& combiner, + xla::ComputationBuilder* builder) { + TF_ASSIGN_OR_RETURN(std::unique_ptr buffer_shape, + builder->GetShape(buffer)); + TF_ASSIGN_OR_RETURN(std::unique_ptr updates_shape, + builder->GetShape(updates)); + TF_ASSIGN_OR_RETURN(std::unique_ptr indices_shape, + builder->GetShape(indices)); + gtl::ArraySlice indices_dims = + xla::AsInt64Slice(indices_shape->dimensions()); + gtl::ArraySlice buffer_dims = + xla::AsInt64Slice(buffer_shape->dimensions()); + + // If the indices are N-dimensional, the minor dimension of indices contains + // the indices to update. Otherwise the indices are all scalars. 
+ int64 num_index_dims = 1; + if (indices_are_vectors) { + TF_RET_CHECK(!indices_dims.empty()); + num_index_dims = indices_dims.back(); + if (num_index_dims > xla::ShapeUtil::Rank(*buffer_shape)) { + return errors::InvalidArgument( + "The size of the minor dimension of the indices (shape: ", + xla::ShapeUtil::HumanString(*indices_shape), + ") must be <= the rank of the buffer (shape: ", + xla::ShapeUtil::HumanString(*buffer_shape), ")"); + } + indices_dims.pop_back(); + } + + int64 num_indices = 1; + for (int64 dim : indices_dims) { + num_indices *= dim; + } + + // Degenerate case: nothing to update. Return the buffer unchanged. + if (num_indices == 0) { + return buffer; + } + + // If any of the indexed dimensions are zero in the buffer, the update cannot + // succeed since it updates a slice of size 1. + for (int64 i = 0; i < num_index_dims; ++i) { + if (xla::ShapeUtil::GetDimension(*buffer_shape, i) == 0) { + return errors::InvalidArgument( + "Scatter dimension ", i, " is of size zero in tensor with shape ", + xla::ShapeUtil::HumanString(*buffer_shape)); + } + } + + // Shape of the non-indexed dimensions of the buffer. + std::vector buffer_shape_post_axes( + buffer_dims.begin() + num_index_dims, buffer_dims.end()); + + // Flatten the major dimensions of indices and updates into a single dimension + // for ease of iteration. + std::vector flat_indices_shape({num_indices}); + if (indices_are_vectors) { + flat_indices_shape.push_back(num_index_dims); + } + + std::vector flat_updates_shape({num_indices}); + flat_updates_shape.insert(flat_updates_shape.end(), + buffer_shape_post_axes.begin(), + buffer_shape_post_axes.end()); + + // Construct the initial values of the loop-carried Tensors. + auto flat_indices = builder->Reshape(indices, flat_indices_shape); + auto flat_updates = builder->Reshape(updates, flat_updates_shape); + auto init = {flat_indices, flat_updates, buffer}; + + // Constructs the loop body. 
The implementation of scatter is essentially: + // for i in range(num_indices): + // index = dynamic-slice(indices, i) + // update = dynamic-slice(updates, i) + // buffer = dynamic-update-slice(buffer, update, index) + auto body_fn = [&](xla::ComputationDataHandle i, + gtl::ArraySlice loop_vars, + xla::ComputationBuilder* body_builder) { + auto indices = loop_vars[0]; + auto updates = loop_vars[1]; + auto buffer = loop_vars[2]; + + auto zero_index = body_builder->ConstantLiteral( + xla::Literal::Zero(indices_shape->element_type())); + + // Slice the i-th index from the indices array. + xla::ComputationDataHandle index; + auto indices_offset = body_builder->Reshape(i, {1}); + if (indices_are_vectors) { + indices_offset = body_builder->Pad(indices_offset, zero_index, + xla::MakeEdgePaddingConfig({{0, 1}})); + + index = body_builder->DynamicSlice(indices, indices_offset, + {1, num_index_dims}); + index = body_builder->Collapse(index, {0, 1}); + } else { + index = body_builder->DynamicSlice(indices, indices_offset, {1}); + } + + // Discard updates with negative indices, since some users expect this. + auto index_in_range = + body_builder->ReduceAll(body_builder->Le(zero_index, index), + body_builder->ConstantR0(true), + xla::CreateScalarAndComputation(body_builder)); + + index = body_builder->Pad( + index, zero_index, + xla::MakeEdgePaddingConfig({{0, buffer_shape_post_axes.size()}})); + + // Slice the i-th index from the updates array. 
+ auto updates_offset = body_builder->Reshape(i, {1}); + updates_offset = body_builder->Pad( + updates_offset, zero_index, + xla::MakeEdgePaddingConfig({{0, buffer_shape_post_axes.size()}})); + std::vector flat_updates_slice_shape({1}); + flat_updates_slice_shape.insert(flat_updates_slice_shape.end(), + buffer_shape_post_axes.begin(), + buffer_shape_post_axes.end()); + auto update = body_builder->DynamicSlice(updates, updates_offset, + flat_updates_slice_shape); + + // Unflatten the major (iteration) dimensions of the slice to their original + // shape. + std::vector updates_slice_shape(num_index_dims, 1); + updates_slice_shape.insert(updates_slice_shape.end(), + buffer_shape_post_axes.begin(), + buffer_shape_post_axes.end()); + update = body_builder->Reshape(update, updates_slice_shape); + + // Apply the update to the buffer. If there is a combiner, use it to merge + // the current values with the update. + if (combiner) { + auto current_value = + body_builder->DynamicSlice(buffer, index, updates_slice_shape); + update = combiner(current_value, update, body_builder); + } + // Apply the update if it is in range. + buffer = body_builder->Select( + index_in_range, body_builder->DynamicUpdateSlice(buffer, update, index), + buffer); + + return std::vector{indices, updates, buffer}; + }; + + TF_ASSIGN_OR_RETURN( + auto outputs, XlaForEachIndex(num_indices, indices_shape->element_type(), + body_fn, init, "scatter", builder)); + return outputs[2]; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/scatter.h b/tensorflow/compiler/tf2xla/lib/scatter.h new file mode 100644 index 0000000000000000000000000000000000000000..41e6d3b195ebf90662c7b9b42c53fcb0133ab29e --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/scatter.h @@ -0,0 +1,53 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_SCATTER_H_ +#define TENSORFLOW_COMPILER_TF2XLA_LIB_SCATTER_H_ + +#include + +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace tensorflow { + +// Builds an XLA computation that performs a scatter operation on `buffer`, +// returning an updated buffer. +// For each i0, i1, ..., sets +// buffer[indices[i0, i1, ...], ...] := updates[i0, i1, ...] +// +// If `indices_are_vectors` is false, then each index in indices is a scalar, +// and the shape of `indices` must be a prefix of the shape of updates. +// Otherwise, if `indices_are_vectors` is true, the indices are multidimensional and the +// minor dimension of `indices` represents a vector of indices. +// +// If any indices are negative, the corresponding update is discarded. +// +// If a `combiner` is provided, updates are combined with the existing values in +// the buffer using the combiner function. Otherwise, the updates replace the +// existing values. The order of updates is implementation-defined.
+xla::StatusOr XlaScatter( + const xla::ComputationDataHandle& buffer, + const xla::ComputationDataHandle& updates, + const xla::ComputationDataHandle& indices, bool indices_are_vectors, + const std::function& combiner, + xla::ComputationBuilder* builder); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_SCATTER_H_ diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc index 579944c3a381e7018b7fee5013d0509158ce21cc..7f72a6073df218b9e2bd4cc0c0b5bb10b5cd4b84 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc @@ -24,13 +24,15 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/errors.h" namespace tensorflow { xla::StatusOr TriangularSolve( xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a, - xla::ComputationDataHandle b, int64 block_size) { + xla::ComputationDataHandle b, bool left_side, bool lower, bool transpose_a, + bool conjugate_a, int64 block_size) { TF_ASSIGN_OR_RETURN(std::unique_ptr a_shape, builder->GetShape(a)); TF_ASSIGN_OR_RETURN(std::unique_ptr b_shape, @@ -60,14 +62,15 @@ xla::StatusOr TriangularSolve( batch_dimensions.push_back(a_size); } - const int64 n = xla::ShapeUtil::GetDimension(*a_shape, -1); - const int64 m = xla::ShapeUtil::GetDimension(*b_shape, -2); - if (n != xla::ShapeUtil::GetDimension(*a_shape, -2)) { + if (xla::ShapeUtil::GetDimension(*a_shape, -1) != + xla::ShapeUtil::GetDimension(*a_shape, -2)) { return errors::InvalidArgument( "The 'a' arguments to TriangularSolve must be square matrices: ", xla::ShapeUtil::HumanString(*a_shape)); } - if (n != xla::ShapeUtil::GetDimension(*b_shape, -1)) { + const int64 m = xla::ShapeUtil::GetDimension(*b_shape, -2); + const 
int64 n = xla::ShapeUtil::GetDimension(*b_shape, -1); + if ((left_side ? m : n) != xla::ShapeUtil::GetDimension(*a_shape, -1)) { return errors::InvalidArgument( "Arguments to TriangularSolve have incompatible matrix shapes: ", xla::ShapeUtil::HumanString(*a_shape), " vs ", @@ -89,6 +92,14 @@ xla::StatusOr TriangularSolve( return output; }; + // Applies a complex conjugation operation if `a` is complex and `conjugate_a` + // is true, otherwise returns its argument. + auto maybe_conj = [&](xla::ComputationBuilder* builder, + xla::ComputationDataHandle x) { + auto perform_conj = a_shape->element_type() == xla::C64 && conjugate_a; + return perform_conj ? builder->Conj(x) : x; + }; + std::map base_computations; auto get_base_triangular_solve = [&](int k) -> xla::StatusOr { @@ -103,19 +114,35 @@ xla::StatusOr TriangularSolve( prepend_batch_dims({k, k})), "a"); + std::array b_lastd; + if (left_side) { + b_lastd = {k, n}; + } else { + b_lastd = {m, k}; + } auto b_param = sub->Parameter(1, xla::ShapeUtil::MakeShape(b_shape->element_type(), - prepend_batch_dims({m, k})), + prepend_batch_dims(b_lastd)), "b"); - // TODO(phawkins): it might make sense to use a while loop here, rather - // than unrolling. - // TODO(phawkins): the left-looking variant of the algorithm might be more - // efficient at block size 1. - TF_RETURN_IF_ERROR(TriangularSolve(sub.get(), a_param, b_param, - /*block_size=*/1) - .status()); + // We use a left-looking subroutine on the block diagonal in some common + // cases, while falling back to a recursive call in unsupported cases. The + // left-looking subroutine is written with a While loop and so yields much + // faster compile times. Moreover, the left-looking variant can give + // higher performance on smaller (sub)problems. 
+ if (left_side && lower) { + TF_RETURN_IF_ERROR(TriangularSolveLeftLooking(sub.get(), a_param, + b_param, transpose_a, + conjugate_a) + .status()); + } else { + TF_RETURN_IF_ERROR(TriangularSolve(sub.get(), a_param, b_param, + left_side, lower, transpose_a, + conjugate_a, + /*block_size=*/1) + .status()); + } TF_ASSIGN_OR_RETURN(computation, sub->Build()); } @@ -129,47 +156,396 @@ xla::StatusOr TriangularSolve( // Goto, Kazushige, and Robert Van De Geijn. "High-performance implementation // of the level-3 BLAS." ACM Transactions on Mathematical Software (TOMS) 35.1 // (2008): 4. - for (int64 i = 0; i < n; i += block_size) { - int64 k = std::min(block_size, n - i); - // if k > 1: - // output[..., :, i:i+k] = triangular_solve( - // a[..., i:i+k, ..., i:i+k], b[..., :, i:i+k], side='Right', - // kind='Lower', transpose=True, block_size=1) - // else: - // output[..., :, i] = b[..., :, i] / a[..., i, i] + // In the code comments below, T = lambda x: np.swapaxes(x, -1, -2) if + // conjugate_a is False, or T = lambda x: np.conj(np.swapaxes(x, -1, -2)) if + // conjugate_a is True. 
+ + if (!left_side && lower == transpose_a) { + // for i in range(0, a.shape[-1], block_size): + for (int64 i = 0; i < n; i += block_size) { + int64 k = std::min(block_size, n - i); + + // output[..., :, i:i+k] = triangular_solve( + // a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1) + TF_ASSIGN_OR_RETURN(auto a_slice, + SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); + TF_ASSIGN_OR_RETURN(auto b_slice, + SliceInMinorDims(builder, b, {0, i}, {m, i + k})); + xla::ComputationDataHandle update; + if (k > 1) { + TF_ASSIGN_OR_RETURN(xla::Computation * solve, + get_base_triangular_solve(k)); + update = builder->Call(*solve, {a_slice, b_slice}); + } else { + update = builder->Div(b_slice, maybe_conj(builder, a_slice)); + } + TF_ASSIGN_OR_RETURN( + output, UpdateSliceInMinorDims(builder, output, update, {0, i})); + + // if i + k < a.shape[-1]: + // a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:] + // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 + // b[..., :, i+k:] -= np.matmul(output[..., :, i:i+k], a_slice_2) + if (i + k < n) { + xla::ComputationDataHandle a_slice_2; + if (lower) { + TF_ASSIGN_OR_RETURN( + a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {n, i + k})); + } else { + TF_ASSIGN_OR_RETURN( + a_slice_2, SliceInMinorDims(builder, a, {i, i + k}, {i + k, n})); + } + + TF_ASSIGN_OR_RETURN(auto b_update, + BatchDot(builder, update, a_slice_2, + /*transpose_x=*/false, + /*transpose_y=*/transpose_a, + /*conjugate_x=*/false, + /*conjugate_y=*/conjugate_a)); + TF_ASSIGN_OR_RETURN(auto b_slice_2, + SliceInMinorDims(builder, b, {0, i + k}, {m, n})); + b_update = builder->Sub(b_slice_2, b_update); + TF_ASSIGN_OR_RETURN( + b, UpdateSliceInMinorDims(builder, b, b_update, {0, i + k})); + } + } + + } else if (left_side && lower != transpose_a) { + // for i in range(0, a.shape[-1], block_size): + for (int64 i = 0; i < m; i += block_size) { + int64 k = std::min(block_size, m - i); + + // output[..., i:i+k, :] = triangular_solve( + 
// a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1) + TF_ASSIGN_OR_RETURN(auto a_slice, + SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); + TF_ASSIGN_OR_RETURN(auto b_slice, + SliceInMinorDims(builder, b, {i, 0}, {i + k, n})); + xla::ComputationDataHandle update; + if (k > 1) { + TF_ASSIGN_OR_RETURN(xla::Computation * solve, + get_base_triangular_solve(k)); + update = builder->Call(*solve, {a_slice, b_slice}); + } else { + update = builder->Div(b_slice, maybe_conj(builder, a_slice)); + } + TF_ASSIGN_OR_RETURN( + output, UpdateSliceInMinorDims(builder, output, update, {i, 0})); + + // if i + k < a.shape[-1]: + // a_slice_2 = a[..., i+k:, i:i+k] if lower else a[..., i:i+k, i+k:] + // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 + // b[..., i+k:, :] -= np.matmul(a_slice_2, output[..., i:i+k, :]) + if (i + k < m) { + xla::ComputationDataHandle a_slice_2; + if (lower) { + TF_ASSIGN_OR_RETURN( + a_slice_2, SliceInMinorDims(builder, a, {i + k, i}, {m, i + k})); + } else { + TF_ASSIGN_OR_RETURN( + a_slice_2, SliceInMinorDims(builder, a, {i, i + k}, {i + k, m})); + } + + TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(builder, a_slice_2, update, + /*transpose_x=*/transpose_a, + /*transpose_y=*/false, + /*conjugate_x=*/conjugate_a, + /*conjugate_y=*/false)); + TF_ASSIGN_OR_RETURN(auto b_slice_2, + SliceInMinorDims(builder, b, {i + k, 0}, {m, n})); + b_update = builder->Sub(b_slice_2, b_update); + TF_ASSIGN_OR_RETURN( + b, UpdateSliceInMinorDims(builder, b, b_update, {i + k, 0})); + } + } + } else if (!left_side && lower != transpose_a) { + // for i in reversed(range(0, a.shape[-1], block_size)): + const int64 last_blk_ix = xla::RoundUpToNearest(n, block_size) - block_size; + for (int64 i = last_blk_ix; i >= 0; i -= block_size) { + int64 k = std::min(block_size, n - i); + + // output[..., :, i:i+k] triangular_solve( + // a[..., i:i+k, i:i+k], b[..., :, i:i+k], ..., block_size=1) + TF_ASSIGN_OR_RETURN(auto a_slice, + SliceInMinorDims(builder, a, {i, i}, 
{i + k, i + k})); + TF_ASSIGN_OR_RETURN(auto b_slice, + SliceInMinorDims(builder, b, {0, i}, {m, i + k})); + xla::ComputationDataHandle update; + if (k > 1) { + TF_ASSIGN_OR_RETURN(xla::Computation * solve, + get_base_triangular_solve(k)); + update = builder->Call(*solve, {a_slice, b_slice}); + } else { + update = builder->Div(b_slice, maybe_conj(builder, a_slice)); + } + TF_ASSIGN_OR_RETURN( + output, UpdateSliceInMinorDims(builder, output, update, {0, i})); + + // if i - k >= 0: + // a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k] + // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 + // b[..., :, :i] -= np.matmul(out[..., :, i:i+k], a_slice_2) + if (i - k >= 0) { + xla::ComputationDataHandle a_slice_2; + if (lower) { + TF_ASSIGN_OR_RETURN(a_slice_2, + SliceInMinorDims(builder, a, {i, 0}, {i + k, i})); + } else { + TF_ASSIGN_OR_RETURN(a_slice_2, + SliceInMinorDims(builder, a, {0, i}, {i, i + k})); + } + + TF_ASSIGN_OR_RETURN(auto b_update, + BatchDot(builder, update, a_slice_2, + /*transpose_x=*/false, + /*transpose_y=*/transpose_a, + /*conjugate_x=*/false, + /*conjugate_y=*/conjugate_a)); + TF_ASSIGN_OR_RETURN(auto b_slice_2, + SliceInMinorDims(builder, b, {0, 0}, {m, i})); + b_update = builder->Sub(b_slice_2, b_update); + TF_ASSIGN_OR_RETURN( + b, UpdateSliceInMinorDims(builder, b, b_update, {0, 0})); + } + } + } else { // left_side && lower == transpose_a + // for i in reversed(range(0, a.shape[-1], block_size)): + const int64 last_blk_ix = xla::RoundUpToNearest(m, block_size) - block_size; + for (int64 i = last_blk_ix; i >= 0; i -= block_size) { + int64 k = std::min(block_size, m - i); + + // output[..., i:i+k, :] triangular_solve( + // a[..., i:i+k, i:i+k], b[..., i:i+k, :], ..., block_size=1) + TF_ASSIGN_OR_RETURN(auto a_slice, + SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); + TF_ASSIGN_OR_RETURN(auto b_slice, + SliceInMinorDims(builder, b, {i, 0}, {i + k, n})); + xla::ComputationDataHandle update; + if (k > 1) { + 
TF_ASSIGN_OR_RETURN(xla::Computation * solve, + get_base_triangular_solve(k)); + update = builder->Call(*solve, {a_slice, b_slice}); + } else { + update = builder->Div(b_slice, maybe_conj(builder, a_slice)); + } + TF_ASSIGN_OR_RETURN( + output, UpdateSliceInMinorDims(builder, output, update, {i, 0})); + + // if i - k >= 0: + // a_slice_2 = a[..., i:i+k, :i] if lower else a[..., :i, i:i+k] + // a_slice_2 = T(a_slice_2) if transpose_a else a_slice_2 + // b[..., :i, :] -= np.matmul(a_slice_2, out[..., i:i+k, :]) + if (i - k >= 0) { + xla::ComputationDataHandle a_slice_2; + if (lower) { + TF_ASSIGN_OR_RETURN(a_slice_2, + SliceInMinorDims(builder, a, {i, 0}, {i + k, i})); + } else { + TF_ASSIGN_OR_RETURN(a_slice_2, + SliceInMinorDims(builder, a, {0, i}, {i, i + k})); + } + + TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(builder, a_slice_2, update, + /*transpose_x=*/transpose_a, + /*transpose_y=*/false, + /*conjugate_x=*/conjugate_a, + /*conjugate_y=*/false)); + TF_ASSIGN_OR_RETURN(auto b_slice_2, + SliceInMinorDims(builder, b, {0, 0}, {i, n})); + b_update = builder->Sub(b_slice_2, b_update); + TF_ASSIGN_OR_RETURN( + b, UpdateSliceInMinorDims(builder, b, b_update, {0, 0})); + } + } + } + + return output; +} + +xla::StatusOr TriangularSolveLeftLooking( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a, + const xla::ComputationDataHandle& b, bool transpose_a, bool conjugate_a) { + TF_ASSIGN_OR_RETURN(std::unique_ptr a_shape, + builder->GetShape(a)); + TF_ASSIGN_OR_RETURN(std::unique_ptr b_shape, + builder->GetShape(b)); + const int64 m = xla::ShapeUtil::GetDimension(*b_shape, -2); + const int64 n = xla::ShapeUtil::GetDimension(*b_shape, -1); + const int64 ndims = xla::ShapeUtil::Rank(*a_shape); + + std::vector batch_dimensions; + for (int i = 0; i < ndims - 2; ++i) { + int64 a_size = a_shape->dimensions(i); + batch_dimensions.push_back(a_size); + } + + auto prepend_batch_dims = [&](std::array indices) { + std::vector output(ndims); + 
std::copy(batch_dimensions.begin(), batch_dimensions.end(), output.begin()); + std::copy(indices.begin(), indices.end(), + output.begin() + batch_dimensions.size()); + return output; + }; + + auto maybe_conj = [&](xla::ComputationBuilder* builder, + xla::ComputationDataHandle x) { + auto perform_conj = a_shape->element_type() == xla::C64 && conjugate_a; + return perform_conj ? builder->Conj(x) : x; + }; + + // The main computation is performed in a While loop. + + // Allocate the output and set its first or last row, + // output = np.zeros_like(b) + // if transpose_a: + // output[..., m-1:, :] = b[..., m-1:, :] / a[..., m-1:, m-1:] + // else: + // output[..., :1, :] = b[..., :1, :] / a[..., :1, :1] + xla::ComputationDataHandle output = Zeros(builder, *b_shape); + { + auto i = transpose_a ? m - 1 : 0; TF_ASSIGN_OR_RETURN(auto a_slice, - SliceInMinorDims(builder, a, {i, i}, {i + k, i + k})); + SliceInMinorDims(builder, a, {i, i}, {i + 1, i + 1})); TF_ASSIGN_OR_RETURN(auto b_slice, - SliceInMinorDims(builder, b, {0, i}, {m, i + k})); - xla::ComputationDataHandle update; - if (k > 1) { - TF_ASSIGN_OR_RETURN(xla::Computation * solve, - get_base_triangular_solve(k)); - update = builder->Call(*solve, {a_slice, b_slice}); + SliceInMinorDims(builder, b, {i, 0}, {i + 1, n})); + auto update = builder->Div(b_slice, maybe_conj(builder, a_slice)); + TF_ASSIGN_OR_RETURN( + output, UpdateSliceInMinorDims(builder, output, update, {i, 0})); + } + + // Construct the initial loop carry tuple, + // if transpose_a: + // init = (m-2, output, a, b) + // else: + // init = (1, output, a, b) + std::vector tuple_shapes = { + // The loop iteration counter is a scalar, incremented each iteration. + xla::ShapeUtil::MakeShape(xla::S32, {}), + // The output has the shape of b, with one row updated each iteration. + *b_shape, + // The coefficient matrix a is a loop invariant. + *a_shape, + // The right-hand-side matrix b is a loop invariant. 
+ *b_shape}; + xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); + auto init_i = builder->ConstantR0(transpose_a ? m - 2 : 1); + auto init = builder->Tuple({init_i, output, a, b}); + + // Construct the loop condition function, + // def cond_fun(loop_carry): + // i, output, a, b = loop_carry + // return i >= 0 if transpose_a else i < m + std::unique_ptr condb = + builder->CreateSubBuilder("TriangularSolveLeftLookingWhileCond"); + { + auto i = condb->GetTupleElement( + condb->Parameter(0, tuple_shape, + "TriangularSolveLeftLookingWhileTuple"), + 0); + if (transpose_a) { + condb->Ge(i, condb->ConstantR0(0)); } else { - update = builder->Div(b_slice, a_slice); + condb->Lt(i, condb->ConstantR0(m)); } + } + TF_ASSIGN_OR_RETURN(auto cond, condb->Build()); - TF_ASSIGN_OR_RETURN( - output, UpdateSliceInMinorDims(builder, output, update, {0, i})); - // b[..., :, i+k:] -= np.dot(output[..., :, i:i+k], - // np.transpose(..., a[i+k:, i:i+k])) - if (i + k < n) { - TF_ASSIGN_OR_RETURN(auto a_slice_2, - SliceInMinorDims(builder, a, {i + k, i}, {n, i + k})); - TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(builder, update, a_slice_2, - /*transpose_x=*/false, - /*transpose_y=*/true)); - - TF_ASSIGN_OR_RETURN(auto b_slice_2, - SliceInMinorDims(builder, b, {0, i + k}, {m, n})); - b_update = builder->Sub(b_slice_2, b_update); - TF_ASSIGN_OR_RETURN( - b, UpdateSliceInMinorDims(builder, b, b_update, {0, i + k})); + // Construct the loop body function, + // def body_fun(loop_carry): + // i, output, a, b = loop_carry + // if transpose_a: + // a_row = np.swapaxes(a[..., i+1:, i:i+1], -1, -2) + // else: + // a_row = a[..., i:i+1, :i] + // result_row = b[..., i:i+1, :] - np.matmul(a_row, output[..., :, :]) + // output[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1] + // if transpose_a: + // return (i - 1, output, a, b) + // else: + // return (i + 1, output, a, b) + // We have to do some extra FLOPs propagating zeros in the matrix multiply + // because we can't have the
size of its arguments depend on the loop counter. + std::unique_ptr bodyb = + builder->CreateSubBuilder("TriangularSolveLeftLookingWhileBody"); + { + auto input_tuple = bodyb->Parameter(0, tuple_shape, + "TriangularSolveLeftLookingWhileTuple"); + + // i, output, a, b = loop_carry + auto i = bodyb->GetTupleElement(input_tuple, 0); + auto body_out = bodyb->GetTupleElement(input_tuple, 1); + auto body_a = bodyb->GetTupleElement(input_tuple, 2); + auto body_b = bodyb->GetTupleElement(input_tuple, 3); + auto zero = bodyb->ConstantR0(0); + + // Set up some helper functions. + auto prepend_zeros = [&](std::array starts) { + auto zero = bodyb->Reshape(bodyb->ConstantR0(0), {1}); + std::vector padded_starts(ndims, zero); + padded_starts[ndims - 2] = bodyb->Reshape(starts[0], {1}); + padded_starts[ndims - 1] = bodyb->Reshape(starts[1], {1}); + return bodyb->ConcatInDim(padded_starts, 0); + }; + + auto dynamic_slice = [&](xla::ComputationDataHandle x, + std::array starts, + std::array sizes) { + auto padded_starts = prepend_zeros(starts); + auto padded_sizes = prepend_batch_dims(sizes); + return bodyb->DynamicSlice(x, padded_starts, padded_sizes); + }; + + auto update = [&](xla::ComputationDataHandle x, + xla::ComputationDataHandle update, + std::array starts) { + auto padded_starts = prepend_zeros(starts); + return bodyb->DynamicUpdateSlice(x, update, padded_starts); + }; + + // We'd like to implement this: + // if transpose_a: + // a_row = T(a[..., i+1:, i:i+1]) + // result_row = (b[..., i:i+1, :] + // - np.matmul(a_row, body_out[..., i+1:, :])) + // else: + // result_row = (b[..., i:i+1, :] + // - np.matmul(a[..., i:i+1, :i], body_out[..., :i, :])) + // But since we can't have intermediate array sizes depend on the loop + // counter, we instead exploit the fact that we initialized the output to + // all zeros and use that as zero-padding (doing unnecessary FLOPs). 
+ xla::ComputationDataHandle a_row; + if (transpose_a) { + a_row = dynamic_slice(body_a, {zero, i}, {m, 1}); + } else { + a_row = dynamic_slice(body_a, {i, zero}, {1, m}); } + TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(bodyb.get(), a_row, body_out, + /*transpose_x=*/transpose_a, + /*transpose_y=*/false, + /*conjugate_x=*/conjugate_a, + /*conjugate_y=*/false)); + auto result_row = + bodyb->Sub(dynamic_slice(body_b, {i, zero}, {1, n}), b_update); + + // body_out[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1] + auto a_elt = dynamic_slice(body_a, {i, i}, {1, 1}); + auto div_result = bodyb->Div(result_row, maybe_conj(bodyb.get(), a_elt)); + body_out = update(body_out, div_result, {i, zero}); + + // if transpose_a: + // return (i - 1, body_out, a, b) + // else: + // return (i + 1, body_out, a, b) + auto next_i = bodyb->Add(i, bodyb->ConstantR0(transpose_a ? -1 : 1)); + bodyb->Tuple({next_i, body_out, body_a, body_b}); } - return output; + TF_ASSIGN_OR_RETURN(auto body, bodyb->Build()); + + // Construct the While loop and return the result, + // return while_loop(cond_fun, body_fun, init)[1] + auto triangular_solve_left_looking_while = builder->While(cond, body, init); + return builder->GetTupleElement(triangular_solve_left_looking_while, 1); } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.h b/tensorflow/compiler/tf2xla/lib/triangular_solve.h index 501d026411c80359c7efa406ece5929a2e46ac1f..e32223bfdddda800b1fd4de3e4f0c8061e0f81d8 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve.h +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.h @@ -21,25 +21,50 @@ limitations under the License. namespace tensorflow { -// Solves systems of linear equations with upper or lower triangular matrices by -// backsubstitution. +// Solves systems of linear equations with lower or upper triangular coefficient +// matrices by forward- or back-substitution. 
Broadcasting along leading +// dimensions, this routine solves one of the matrix systems +// `op(a) * x = b`, or `x * op(a) = b`, +// for the variable `x` given `a` and `b`, where `op(a)` is either +// `op(a) = a`, or `op(a) = transpose(a)`, or `op(a) = conj(transpose(a))`. +// That is, the innermost matrices in the output satisfy a scalar system +// depending on the value of the value of (left_side, transpose_a, conjugate_a) +// according to: +// (F, F, F) => `output[..., i, k] a[..., k, j] = b[..., i, j]`, +// (F, F, T) => `output[..., i, k] a*[..., k, j] = b[..., i, j]`, +// (F, T, F) => `output[..., i, k] a[..., j, k] = b[..., i, j]`, +// (F, T, T) => `output[..., i, k] a*[..., j, k] = b[..., i, j]`, +// (T, F, F) => ` a[..., i, k] output[..., k, j] = b[..., i, j]`, +// (T, F, T) => `a*[..., i, k] output[..., k, j] = b[..., i, j]`, +// (T, T, F) => ` a[..., i, k] output[..., j, k] = b[..., i, j]`, +// (T, T, T) => `a*[..., i, k] output[..., j, k] = b[..., i, j]`, +// where * denotes complex conjugation and where the index `k` is summed over. // -// `a` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions form -// square matrices. The strictly upper triangular part of each inner-most matrix -// is assumed to be zero and not accessed. -// `b` is a tensor of shape `[..., M, K]`. -// -// The innermost matrices in the output satisfy matrix equations -// `output[..., i, j] * adjoint(a[..., k, j]) = b[..., i, k]`. +// `a` is a tensor of shape `[..., M, M]` whose innermost 2 dimensions form +// square matrices. If lower is true (false), then the strictly upper (lower) +// triangular part of each innermost matrix in `a` is assumed to be zero and is +// not accessed. +// `b` is a tensor of shape `[..., M, K]` if left_side is true, otherwise a +// tensor of shape `[..., K, M]`. +// `left_side` is a boolean, indicating whether to solve a system of the form +// op(a) * x = b (true) or x * op(a) = b (false). 
+// `lower` is a boolean, indicating whether the argument `a` is lower-triangular +// (true) or upper-triangular (false). +// `transpose_a` is a boolean indicating whether the matrix `a` is transposed. +// `conjugate_a` is a boolean indicating whether the entries of `a` are complex +// conjugated (independently of whether they are transposed), so that when both +// transpose_a and conjugate_a are true the effect is a Hermitian adjoint. // // Uses a blocked algorithm if `block_size` is > 1; if block_size == 1 then no // blocking is used. -// TODO(phawkins): equivalent to the BLAS TRSM routine with side=right, -// kind=lower, and transposed_a=true. Implement the other possible combinations -// of side, kind and transposed_a. xla::StatusOr TriangularSolve( xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a, - xla::ComputationDataHandle b, int64 block_size = 256); + xla::ComputationDataHandle b, bool left_side, bool lower, bool transpose_a, + bool conjugate_a, int64 block_size = 256); + +xla::StatusOr TriangularSolveLeftLooking( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& a, + const xla::ComputationDataHandle& b, bool transpose_a, bool conjugate_a); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc index 671d9aa4fe0c042a3cc44468074653d51c2be75d..661707062916263fd0d5d935ce41698a7655df02 100644 --- a/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc +++ b/tensorflow/compiler/tf2xla/lib/triangular_solve_test.cc @@ -27,32 +27,134 @@ limitations under the License. 
#include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace tensorflow { namespace { using TriangularSolveTest = xla::ClientLibraryTestBase; +using TriangularSolveLeftLookingTest = xla::ClientLibraryTestBase; +using complex64 = xla::complex64; -XLA_TEST_F(TriangularSolveTest, Simple) { +xla::Array2D AValsLower() { + return {{2, 0, 0, 0}, {3, 6, 0, 0}, {4, 7, 9, 0}, {5, 8, 10, 11}}; +} + +xla::Array2D AValsUpper() { + return {{2, 3, 4, 5}, {0, 6, 7, 8}, {0, 0, 9, 10}, {0, 0, 0, 11}}; +} + +xla::Array2D BValsRight() { + return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}; +} + +xla::Array2D BValsLeft() { + return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}}; +} + +xla::Array2D AValsLowerComplex() { + return {{2, 0, 0, 0}, + {complex64(3, 1), 6, 0, 0}, + {4, complex64(7, 2), 9, 0}, + {5, 8, complex64(10, 3), 11}}; +} + +xla::Array2D AValsUpperComplex() { + return {{2, 3, complex64(4, 3), 5}, + {0, 6, complex64(7, 2), 8}, + {0, 0, complex64(9, 1), 10}, + {0, 0, 0, 11}}; +} + +xla::Array2D BValsRightComplex() { + return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}; +} + +xla::Array2D BValsLeftComplex() { + return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}}; +} + +xla::Array2D AValsFull() { + return {{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 7, 9, 0}, {5, 8, 10, 11}}; +} + +XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTranspose) { xla::ComputationBuilder builder(client_, TestName()); - xla::Array2D a_vals({ - {2, 0, 0, 0}, - {3, 6, 0, 0}, - {4, 7, 9, 0}, - {5, 8, 10, 11}, + xla::ComputationDataHandle a, b; + auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); + auto result = TriangularSolve(&builder, a, b, + /*left_side=*/false, /*lower=*/true, + 
/*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); + TF_ASSERT_OK(result.status()); + + xla::Array2D expected({ + {0.5, 0.08333334, 0.04629629, 0.03367003}, + {2.5, -0.25, -0.1388889, -0.1010101}, + {4.5, -0.58333331, -0.32407406, -0.23569024}, + }); + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + xla::ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveTest, SimpleRightLowerNotranspose) { + xla::ComputationBuilder builder(client_, TestName()); + + xla::ComputationDataHandle a, b; + auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); + auto result = TriangularSolve(&builder, a, b, + /*left_side=*/false, /*lower=*/true, + /*transpose_a=*/false, /*conjugate_a=*/false, + /*block_size=*/2); + TF_ASSERT_OK(result.status()); + + xla::Array2D expected({ + {-0.16414141, -0.06902357, -0.07070707, 0.36363636}, + {0.64393939, 0.06565657, -0.03030303, 0.72727273}, + {1.4520202, 0.2003367, 0.01010101, 1.09090909}, }); - xla::Array2D b_vals({ - {1, 2, 3, 4}, - {5, 6, 7, 8}, - {9, 10, 11, 12}, + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + xla::ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveTest, SimpleRightUpperTranspose) { + xla::ComputationBuilder builder(client_, TestName()); + + xla::ComputationDataHandle a, b; + auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); + auto result = TriangularSolve(&builder, a, b, + /*left_side=*/false, /*lower=*/false, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); + TF_ASSERT_OK(result.status()); + + xla::Array2D expected({ + {-0.16414141, -0.06902357, -0.07070707, 0.36363636}, + {0.64393939, 0.06565657, -0.03030303, 0.72727273}, + {1.4520202, 0.2003367, 0.01010101, 1.09090909}, }); + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + 
xla::ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveTest, SimpleRightUpperNotranspose) { + xla::ComputationBuilder builder(client_, TestName()); + xla::ComputationDataHandle a, b; - auto a_data = CreateR2Parameter(a_vals, 0, "a", &builder, &a); - auto b_data = CreateR2Parameter(b_vals, 1, "b", &builder, &b); - auto result = TriangularSolve(&builder, a, b, /*block_size=*/2); + auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsRight(), 1, "b", &builder, &b); + auto result = TriangularSolve(&builder, a, b, + /*left_side=*/false, /*lower=*/false, + /*transpose_a=*/false, /*conjugate_a=*/false, + /*block_size=*/2); TF_ASSERT_OK(result.status()); xla::Array2D expected({ @@ -62,7 +164,201 @@ XLA_TEST_F(TriangularSolveTest, Simple) { }); ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, - xla::ErrorSpec(2e-3, 2e-3)); + xla::ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerTranspose) { + xla::ComputationBuilder builder(client_, TestName()); + + xla::ComputationDataHandle a, b; + auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); + auto result = TriangularSolve(&builder, a, b, + /*left_side=*/true, /*lower=*/true, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); + TF_ASSERT_OK(result.status()); + + xla::Array2D expected({ + {-0.89646465, -0.69444444, -0.49242424}, + {-0.27441077, -0.24074074, -0.20707071}, + {-0.23232323, -0.22222222, -0.21212121}, + {0.90909091, 1., 1.09090909}, + }); + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + xla::ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveTest, SimpleLeftLowerNotranspose) { + xla::ComputationBuilder builder(client_, TestName()); + + xla::ComputationDataHandle a, b; + auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); + auto b_data = 
CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); + auto result = TriangularSolve(&builder, a, b, + /*left_side=*/true, /*lower=*/true, + /*transpose_a=*/false, /*conjugate_a=*/false, + /*block_size=*/2); + TF_ASSERT_OK(result.status()); + + xla::Array2D expected({ + {0.5, 1.0, 1.5}, + {0.41666667, 0.33333333, 0.25}, + {0.23148148, 0.18518519, 0.13888889}, + {0.16835017, 0.13468013, 0.1010101}, + }); + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + xla::ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTranspose) { + xla::ComputationBuilder builder(client_, TestName()); + + xla::ComputationDataHandle a, b; + auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); + auto result = TriangularSolve(&builder, a, b, + /*left_side=*/true, /*lower=*/false, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); + TF_ASSERT_OK(result.status()); + + xla::Array2D expected({ + {0.5, 1.0, 1.5}, + {0.41666667, 0.33333333, 0.25}, + {0.23148148, 0.18518519, 0.13888889}, + {0.16835017, 0.13468013, 0.1010101}, + }); + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + xla::ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperNotranspose) { + xla::ComputationBuilder builder(client_, TestName()); + + xla::ComputationDataHandle a, b; + auto a_data = CreateR2Parameter(AValsUpper(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); + auto result = TriangularSolve(&builder, a, b, + /*left_side=*/true, /*lower=*/false, + /*transpose_a=*/false, /*conjugate_a=*/false, + /*block_size=*/2); + TF_ASSERT_OK(result.status()); + + xla::Array2D expected({ + {-0.89646465, -0.69444444, -0.49242424}, + {-0.27441077, -0.24074074, -0.20707071}, + {-0.23232323, -0.22222222, -0.21212121}, + {0.90909091, 1., 1.09090909}, + }); + + ComputeAndCompareR2(&builder, 
expected, {a_data.get(), b_data.get()}, + xla::ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveTest, SimpleRightLowerTransposeConjugate) { + xla::ComputationBuilder builder(client_, TestName()); + + xla::ComputationDataHandle a, b; + auto a_data = + CreateR2Parameter(AValsLowerComplex(), 0, "a", &builder, &a); + auto b_data = + CreateR2Parameter(BValsRightComplex(), 1, "b", &builder, &b); + auto result = TriangularSolve(&builder, a, b, + /*left_side=*/false, /*lower=*/true, + /*transpose_a=*/true, /*conjugate_a=*/true, + /*block_size=*/2); + TF_ASSERT_OK(result.status()); + + xla::Array2D expected({ + {0.5, complex64(0.08333333, 0.08333333), + complex64(0.02777778, -0.0462963), complex64(0.06313131, -0.01094276)}, + {2.5, complex64(-0.25, 0.41666667), complex64(-0.23148148, -0.37962963), + complex64(0.08670034, -0.02104377)}, + {4.5, complex64(-0.58333333, 0.75), complex64(-0.49074074, -0.71296296), + complex64(0.11026936, -0.03114478)}, + }); + + ComputeAndCompareR2(&builder, expected, + {a_data.get(), b_data.get()}, + xla::ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveTest, SimpleLeftUpperTransposeNoconjugate) { + xla::ComputationBuilder builder(client_, TestName()); + + xla::ComputationDataHandle a, b; + auto a_data = + CreateR2Parameter(AValsUpperComplex(), 0, "a", &builder, &a); + auto b_data = + CreateR2Parameter(BValsLeftComplex(), 1, "b", &builder, &b); + auto result = TriangularSolve(&builder, a, b, + /*left_side=*/true, /*lower=*/false, + /*transpose_a=*/true, /*conjugate_a=*/false, + /*block_size=*/2); + TF_ASSERT_OK(result.status()); + + xla::Array2D expected({ + {0.5, 1., 1.5}, + {0.41666667, 0.33333333, 0.25}, + {complex64(0.20020325, -2.81504065e-01), + complex64(0.13821138, -4.22764228e-01), + complex64(0.07621951, -5.64024390e-01)}, + {complex64(0.19678492, 2.55912786e-01), + complex64(0.17738359, 3.84331116e-01), + complex64(0.15798226, 5.12749446e-01)}, + }); + + ComputeAndCompareR2(&builder, expected, + {a_data.get(), 
b_data.get()}, + xla::ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveLeftLookingTest, Simple) { + xla::ComputationBuilder builder(client_, TestName()); + + xla::ComputationDataHandle a, b; + auto a_data = CreateR2Parameter(AValsLower(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); + auto result = TriangularSolveLeftLooking(&builder, a, b, + /*transpose_a=*/false, + /*conjugate_a=*/false); + TF_ASSERT_OK(result.status()); + + xla::Array2D expected({ + {0.5, 1.0, 1.5}, + {0.41666667, 0.33333333, 0.25}, + {0.23148148, 0.18518519, 0.13888889}, + {0.16835017, 0.13468013, 0.1010101}, + }); + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + xla::ErrorSpec(1e-2, 1e-2)); +} + +XLA_TEST_F(TriangularSolveLeftLookingTest, NonzeroUpperTriangle) { + xla::ComputationBuilder builder(client_, TestName()); + + xla::ComputationDataHandle a, b; + auto a_data = CreateR2Parameter(AValsFull(), 0, "a", &builder, &a); + auto b_data = CreateR2Parameter(BValsLeft(), 1, "b", &builder, &b); + auto result = TriangularSolveLeftLooking(&builder, a, b, + /*transpose_a=*/false, + /*conjugate_a=*/false); + TF_ASSERT_OK(result.status()); + + xla::Array2D expected({ + {0.5, 1.0, 1.5}, + {0.41666667, 0.33333333, 0.25}, + {0.23148148, 0.18518519, 0.13888889}, + {0.16835017, 0.13468013, 0.1010101}, + }); + + ComputeAndCompareR2(&builder, expected, {a_data.get(), b_data.get()}, + xla::ErrorSpec(1e-2, 1e-2)); } } // namespace diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc index 7ffe0aa6df9b21c4311eb6c8d311fba1e115b3f4..f579669bbd852b514e021ce71d635f8ce5e4fe4d 100644 --- a/tensorflow/compiler/tf2xla/lib/util.cc +++ b/tensorflow/compiler/tf2xla/lib/util.cc @@ -28,7 +28,7 @@ limitations under the License. 
namespace tensorflow { xla::ComputationDataHandle Zeros(xla::ComputationBuilder* builder, - xla::Shape& shape) { + const xla::Shape& shape) { return builder->Broadcast( builder->ConstantLiteral(xla::Literal::Zero(shape.element_type())), xla::AsInt64Slice(shape.dimensions())); @@ -40,6 +40,9 @@ xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder, case xla::F16: return builder->ConstantR0(static_cast(value)); break; + case xla::BF16: + return builder->ConstantR0(static_cast(value)); + break; case xla::F32: return builder->ConstantR0(static_cast(value)); break; @@ -54,6 +57,61 @@ xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder, } } +xla::ComputationDataHandle IntegerLiteral(xla::ComputationBuilder* builder, + xla::PrimitiveType type, + int64 value) { + xla::Literal literal; + switch (type) { + case xla::U8: + literal = std::move(*xla::Literal::CreateR0(value)); + break; + case xla::U32: + literal = std::move(*xla::Literal::CreateR0(value)); + break; + case xla::U64: + literal = std::move(*xla::Literal::CreateR0(value)); + break; + case xla::S8: + literal = std::move(*xla::Literal::CreateR0(value)); + break; + case xla::S32: + literal = std::move(*xla::Literal::CreateR0(value)); + break; + case xla::S64: + literal = std::move(*xla::Literal::CreateR0(value)); + break; + case xla::F32: + literal = std::move(*xla::Literal::CreateR0(value)); + break; + case xla::F64: + literal = std::move(*xla::Literal::CreateR0(value)); + break; + case xla::C64: + literal = std::move(*xla::Literal::CreateR0(value)); + break; + case xla::PRED: + LOG(FATAL) << "pred element type is not integral"; + case xla::S16: + case xla::U16: + LOG(FATAL) << "u16/s16 literals not yet implemented"; + case xla::BF16: + literal = std::move( + *xla::Literal::CreateR0(static_cast(value))); + break; + case xla::F16: + literal = std::move( + *xla::Literal::CreateR0(static_cast(value))); + break; + case xla::TUPLE: + LOG(FATAL) << "tuple element type is not 
integral"; + case xla::OPAQUE: + LOG(FATAL) << "opaque element type is not integral"; + default: + LOG(FATAL) << "unhandled element type " << type; + } + return builder->ConstantLiteral(literal); +} + xla::StatusOr SliceInMinorDims( xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, gtl::ArraySlice start, gtl::ArraySlice end) { @@ -104,4 +162,15 @@ xla::StatusOr UpdateSliceInMinorDims( return UpdateSlice(builder, x, update, padded_start); } +xla::StatusOr TransposeInMinorDims( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x) { + TF_ASSIGN_OR_RETURN(std::unique_ptr shape, builder->GetShape(x)); + const int64 n_dims = xla::ShapeUtil::Rank(*shape); + TF_RET_CHECK(n_dims >= 2); + std::vector permutation(n_dims); + std::iota(permutation.begin(), permutation.end(), 0); + std::swap(permutation[n_dims - 1], permutation[n_dims - 2]); + return builder->Transpose(x, permutation); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/util.h b/tensorflow/compiler/tf2xla/lib/util.h index 8fba6b5cf247e9b2c26533c53ece8b0d7d4f4c36..51f8baaf00bd8fd25baa1a87be8cb0089dfb22b5 100644 --- a/tensorflow/compiler/tf2xla/lib/util.h +++ b/tensorflow/compiler/tf2xla/lib/util.h @@ -25,13 +25,18 @@ namespace tensorflow { // Returns a zero-filled tensor with shape `shape`. xla::ComputationDataHandle Zeros(xla::ComputationBuilder* builder, - xla::Shape& shape); + const xla::Shape& shape); // Returns a floating point scalar constant of 'type' with 'value'. // If 'type' is complex, returns a real value with zero imaginary component. xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder, xla::PrimitiveType type, double value); +// Returns a integer scalar constant of 'type' with 'value'. +// If 'type' is complex, returns a real value with zero imaginary component. 
+xla::ComputationDataHandle IntegerLiteral(xla::ComputationBuilder* builder, + xla::PrimitiveType type, int64 value); + // Performs a slice in the minor dimensions of a Tensor. xla::StatusOr SliceInMinorDims( xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, @@ -49,6 +54,10 @@ xla::StatusOr UpdateSliceInMinorDims( xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x, const xla::ComputationDataHandle& update, gtl::ArraySlice start); +// Transposes a stack of matrices `x` by swapping the last two dimensions. +xla::StatusOr TransposeInMinorDims( + xla::ComputationBuilder* builder, const xla::ComputationDataHandle& x); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_LIB_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.cc b/tensorflow/compiler/tf2xla/lib/while_loop.cc new file mode 100644 index 0000000000000000000000000000000000000000..86c02ac2e65c12d3527c4022df0cc603e522ef7a --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/while_loop.cc @@ -0,0 +1,125 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/lib/while_loop.h" +#include "tensorflow/compiler/tf2xla/lib/util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" + +namespace tensorflow { + +xla::StatusOr> XlaWhileLoop( + const LoopConditionFunction& condition_function, + const LoopBodyFunction& body_function, + gtl::ArraySlice initial_values, + StringPiece name, xla::ComputationBuilder* builder) { + int arity = initial_values.size(); + std::vector var_shapes; + var_shapes.reserve(arity); + for (const xla::ComputationDataHandle& input : initial_values) { + TF_ASSIGN_OR_RETURN(auto shape, builder->GetShape(input)); + var_shapes.push_back(std::move(*shape)); + } + xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(var_shapes); + + // Unpacks a tuple into its component parts. + auto unpack_tuple = [](xla::ComputationDataHandle tuple, int arity, + xla::ComputationBuilder* builder) { + std::vector elements(arity); + for (int i = 0; i < arity; ++i) { + elements[i] = builder->GetTupleElement(tuple, i); + } + return elements; + }; + + // Build the condition. + std::unique_ptr cond_builder = + builder->CreateSubBuilder(strings::StrCat(name, "_condition")); + { + auto parameter = cond_builder->Parameter(0, tuple_shape, "parameter"); + + TF_ASSIGN_OR_RETURN( + auto result, + condition_function(unpack_tuple(parameter, arity, cond_builder.get()), + cond_builder.get())); + TF_RETURN_IF_ERROR(cond_builder->SetReturnValue(result)); + } + TF_ASSIGN_OR_RETURN(auto cond, cond_builder->Build()); + + // Build the body. 
+ std::unique_ptr body_builder = + builder->CreateSubBuilder(strings::StrCat(name, "_body")); + { + auto parameter = body_builder->Parameter(0, tuple_shape, "parameter"); + + TF_ASSIGN_OR_RETURN( + auto result, + body_function(unpack_tuple(parameter, arity, body_builder.get()), + body_builder.get())); + + TF_RET_CHECK(result.size() == initial_values.size()); + body_builder->Tuple(result); + } + TF_ASSIGN_OR_RETURN(auto body, body_builder->Build()); + + auto outputs = builder->While(cond, body, builder->Tuple(initial_values)); + + return unpack_tuple(outputs, arity, builder); +} + +xla::StatusOr> XlaForEachIndex( + int64 num_iterations, xla::PrimitiveType num_iterations_type, + const ForEachIndexBodyFunction& body_function, + gtl::ArraySlice initial_values, + StringPiece name, xla::ComputationBuilder* builder) { + auto while_cond_fn = [&](gtl::ArraySlice values, + xla::ComputationBuilder* cond_builder) + -> xla::StatusOr { + return cond_builder->Lt( + values[0], + IntegerLiteral(cond_builder, num_iterations_type, num_iterations)); + }; + auto while_body_fn = [&](gtl::ArraySlice values, + xla::ComputationBuilder* body_builder) + -> xla::StatusOr> { + xla::ComputationDataHandle iteration = values[0]; + + std::vector updated_values; + updated_values.reserve(values.size()); + updated_values.push_back(body_builder->Add( + iteration, + body_builder->ConstantLiteral(xla::Literal::One(num_iterations_type)))); + + values.remove_prefix(1); + TF_ASSIGN_OR_RETURN(std::vector body_outputs, + body_function(iteration, values, body_builder)); + updated_values.insert(updated_values.end(), body_outputs.begin(), + body_outputs.end()); + return updated_values; + }; + + std::vector values; + values.reserve(initial_values.size() + 1); + values.push_back( + builder->ConstantLiteral(xla::Literal::Zero(num_iterations_type))); + values.insert(values.end(), initial_values.begin(), initial_values.end()); + + TF_ASSIGN_OR_RETURN(values, XlaWhileLoop(while_cond_fn, while_body_fn, values, + name, 
builder)); + values.erase(values.begin(), values.begin() + 1); + return values; +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/lib/while_loop.h b/tensorflow/compiler/tf2xla/lib/while_loop.h new file mode 100644 index 0000000000000000000000000000000000000000..2e67a0c99b6deb65fa16ab2dec1727f5cb5fcb92 --- /dev/null +++ b/tensorflow/compiler/tf2xla/lib/while_loop.h @@ -0,0 +1,74 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_ +#define TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_ + +#include +#include + +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace tensorflow { + +// Function that builds a loop condition. Takes as input a sequence of input +// values, and returns a boolean value representing if the condition succeeds. +typedef std::function( + gtl::ArraySlice, xla::ComputationBuilder*)> + LoopConditionFunction; + +// Function that builds a loop body. Takes as input a sequence of input values +// and returns a sequence of output values. 
+typedef std::function>( + gtl::ArraySlice, xla::ComputationBuilder*)> + LoopBodyFunction; + +// Helper function for building an XLA while loop, where the values carried by +// the loop are a tuple of values, e.g., (a, b, c): +// while( +// condition: (a, b, c) -> bool, +// body: (a, b, c) -> (a, b, c) +// init: (a, b, c) +// ) +// 'name' is a descriptive name for the loop. +xla::StatusOr> XlaWhileLoop( + const LoopConditionFunction& condition_function, + const LoopBodyFunction& body_function, + gtl::ArraySlice initial_values, + StringPiece name, xla::ComputationBuilder* builder); + +// Builds an XLA loop that repeats a computation `num_iterations` times. +// +// The body function (ForEachIndexBodyFunction) takes as input a pair of +// (current iteration number, loop-carried values), and returns an updated +// vector of the loop-carried values. +typedef std::function>( + xla::ComputationDataHandle, gtl::ArraySlice, + xla::ComputationBuilder*)> + ForEachIndexBodyFunction; + +xla::StatusOr> XlaForEachIndex( + int64 num_iterations, xla::PrimitiveType num_iterations_type, + const ForEachIndexBodyFunction& body_function, + gtl::ArraySlice initial_values, + StringPiece name, xla::ComputationBuilder* builder); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_LIB_WHILE_LOOP_H_ diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index 576cd9bf9abb43e29d9eb8f706e0f42ac2d038e9..fcbd157c6191655865d5e250fdf71338780bc2a6 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -23,17 +23,17 @@ limitations under the License. 
namespace tensorflow { Status HostTensorToLiteral(const Tensor& host_tensor, xla::Literal* literal) { - literal->Clear(); + xla::Shape literal_shape; TF_RETURN_IF_ERROR(TensorShapeToXLAShape( - host_tensor.dtype(), host_tensor.shape(), literal->mutable_shape())); + host_tensor.dtype(), host_tensor.shape(), &literal_shape)); - literal->Reserve(host_tensor.NumElements()); + *literal = xla::Literal(literal_shape); // memcpy over the payload ... // TODO(phawkins): handle string types. size_t total_bytes = host_tensor.TotalBytes(); if (total_bytes > 0) { - void* dst_ptr = literal->MutableInternalData(); + void* dst_ptr = literal->untyped_data(); const void* src_ptr = DMAHelper::base(&host_tensor); memcpy(dst_ptr, src_ptr, total_bytes); } @@ -56,7 +56,7 @@ Status LiteralToHostTensor(const xla::Literal& literal, DataType target_type, *host_tensor = Tensor(target_type, shape); size_t total_bytes = host_tensor->TotalBytes(); if (total_bytes > 0) { - const void* src_ptr = literal.InternalData(); + const void* src_ptr = literal.untyped_data(); void* dst_ptr = DMAHelper::base(host_tensor); memcpy(dst_ptr, src_ptr, total_bytes); } diff --git a/tensorflow/compiler/tf2xla/sharding_util.cc b/tensorflow/compiler/tf2xla/sharding_util.cc index d9c839b61019b92b6de3a77a7bec610ae848a9a4..1a0e09758f7cc6714793300c6ece14093a8ad246 100644 --- a/tensorflow/compiler/tf2xla/sharding_util.cc +++ b/tensorflow/compiler/tf2xla/sharding_util.cc @@ -14,34 +14,59 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/tf2xla/sharding_util.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/util/device_name_utils.h" namespace tensorflow { +namespace { +const char kDeviceSuffixReplicatedCore[] = "REPLICATED_CORE"; +const char kShardingAttribute[] = "_XlaSharding"; +} // namespace -static const char DEVICE_SUFFIX_REPLICATED_CORE[] = "REPLICATED_CORE"; +namespace { +xla::StatusOr> +GetShardingFromNodeDef(const NodeDef& node_def) { + if (!HasNodeAttr(node_def, kShardingAttribute)) { + return tensorflow::gtl::optional(); + } + string value; + xla::OpSharding sharding; + TF_RETURN_IF_ERROR(GetNodeAttr(node_def, kShardingAttribute, &value)); + if (!sharding.ParseFromString(value)) { + return xla::InvalidArgument( + "Experimental _XlaSharding attribute was not a valid encoded " + "xla::OpSharding proto."); + } + return tensorflow::gtl::optional(sharding); +} -static Status CoreOutOfRangeError(int core, int num_cores_per_replica) { +Status CoreOutOfRangeError(int core, int num_cores_per_replica) { return errors::InvalidArgument( "Invalid replicated core id: ", core, "; num_cores_per_replica=", num_cores_per_replica); } +} // namespace xla::StatusOr> -ParseShardingFromDevice(const string& device_name, int num_cores_per_replica) { +ParseShardingFromDevice( + const string& device_name, int num_cores_per_replica, + tensorflow::gtl::optional explicit_sharding) { if (device_name.empty()) { return tensorflow::gtl::optional(); } - DeviceNameUtils::ParsedName parsed_device; if (!DeviceNameUtils::ParseFullName(device_name, &parsed_device)) { return errors::InvalidArgument("Malformed assigned device '", device_name, "'"); } - if (!parsed_device.has_type || - !StringPiece(parsed_device.type) - .ends_with(DEVICE_SUFFIX_REPLICATED_CORE)) { + + if 
(explicit_sharding.has_value()) { + return explicit_sharding; + } else if (!parsed_device.has_type || !parsed_device.has_id || + !StringPiece(parsed_device.type) + .contains(kDeviceSuffixReplicatedCore)) { return tensorflow::gtl::optional(); } else { const int core = parsed_device.id; @@ -49,24 +74,38 @@ ParseShardingFromDevice(const string& device_name, int num_cores_per_replica) { return CoreOutOfRangeError(core, num_cores_per_replica); } return tensorflow::gtl::optional( - xla::ShardingBuilder::AssignDevice(core)); + xla::sharding_builder::AssignDevice(core)); } } +xla::StatusOr> +ParseShardingFromDevice(const NodeDef& node_def, int num_cores_per_replica) { + const string& device_name = node_def.device(); + TF_ASSIGN_OR_RETURN(tensorflow::gtl::optional sharding, + GetShardingFromNodeDef(node_def)); + return ParseShardingFromDevice(device_name, num_cores_per_replica, sharding); +} + xla::StatusOr> ParseShardingFromDevice(const Node& node, int num_cores_per_replica) { string device_name = node.assigned_device_name(); if (device_name.empty()) { device_name = node.requested_device(); } - return ParseShardingFromDevice(device_name, num_cores_per_replica); + TF_ASSIGN_OR_RETURN(tensorflow::gtl::optional sharding, + GetShardingFromNodeDef(node.def())); + return ParseShardingFromDevice(device_name, num_cores_per_replica, sharding); } + void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst) { string device_name = src.assigned_device_name(); if (device_name.empty()) { device_name = src.requested_device(); } dst->set_assigned_device_name(device_name); + if (const AttrValue* attr = src.attrs().Find(kShardingAttribute)) { + dst->AddAttr(kShardingAttribute, *attr); + } } } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/sharding_util.h b/tensorflow/compiler/tf2xla/sharding_util.h index f6468bba9f950fec88dcc6b3ec760f014d3a0ef3..b1c817bdcc211648b16e395313ca171d1acb9ea9 100644 --- a/tensorflow/compiler/tf2xla/sharding_util.h +++ 
b/tensorflow/compiler/tf2xla/sharding_util.h @@ -17,7 +17,7 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/sharding_builder.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/status.h" @@ -29,14 +29,21 @@ namespace tensorflow { // - if the device name is invalid. // - the core is parsed and is out of the range [0, num_cores_per_replica). // -// Otherwise, returns either a non-value or a sharding set as per -// xla:ShardingBuilder::AssignDevice. +// Otherwise, returns either: +// - explicit_sharding if explicit_sharding.has_value() +// - a non-value if there is no assigned core or +// - a sharding set as per xla::sharding_builder::AssignDevice. xla::StatusOr> -ParseShardingFromDevice(const string& device_name, int num_cores_per_replica); +ParseShardingFromDevice(const string& device_name, int num_cores_per_replica, + tensorflow::gtl::optional + explicit_sharding = tensorflow::gtl::nullopt); xla::StatusOr> ParseShardingFromDevice(const Node& node, int num_cores_per_replica); +xla::StatusOr> +ParseShardingFromDevice(const NodeDef& node_def, int num_cores_per_replica); + void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index a14c93a2b9494b89f579bc20ee0510c136f8f01b..6051d7dffd7493d8cffb07c1b5d10500e7e75522 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -241,9 +241,7 @@ Status CreateXlaArgs(const Graph& graph, XlaCompiler::Argument arg; arg.kind = XlaCompiler::Argument::kParameter; TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &arg.type)); - TensorShape shape; - TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &shape)); - TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, &arg.shape)); + 
TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kShapeAttr, &arg.shape)); TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), kDebugNameAttr, &arg.name)); xla_args->push_back(arg); } @@ -253,8 +251,7 @@ Status CreateXlaArgs(const Graph& graph, // Converts the TensorFlow graph into an XLA computation, by executing the // graph symbolically, with each op building up the XLA HLO. Status ConvertGraphToXla(std::unique_ptr graph, xla::Client* client, - xla::Computation* computation, - bool* requires_runtime_context) { + xla::Computation* computation) { XlaOpRegistry::RegisterCompilationKernels(); for (Node* node : graph->nodes()) { node->set_assigned_device_name( @@ -277,7 +274,6 @@ Status ConvertGraphToXla(std::unique_ptr graph, xla::Client* client, TF_RETURN_IF_ERROR(compiler.CompileGraph(XlaCompiler::CompileOptions(), "tfcompile", std::move(graph), xla_args, &result)); - *requires_runtime_context = result.requires_runtime_context; *computation = std::move(*result.computation); int num_const_results = 0; @@ -352,12 +348,10 @@ Status InitGraph(const GraphDef& graph_def, const tf2xla::Config& config, Status ConvertGraphDefToXla(const GraphDef& graph_def, const tf2xla::Config& config, xla::Client* client, - xla::Computation* computation, - bool* requires_runtime_context) { + xla::Computation* computation) { std::unique_ptr graph; TF_RETURN_IF_ERROR(InitGraph(graph_def, config, &graph)); - TF_RETURN_IF_ERROR(ConvertGraphToXla(std::move(graph), client, computation, - requires_runtime_context)); + TF_RETURN_IF_ERROR(ConvertGraphToXla(std::move(graph), client, computation)); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/tf2xla.h b/tensorflow/compiler/tf2xla/tf2xla.h index ab99beebf7946237425d4d304a858ac6817177b8..473c431b12d441c652f1d0d6c11c5e87836ab36d 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.h +++ b/tensorflow/compiler/tf2xla/tf2xla.h @@ -30,13 +30,9 @@ namespace tensorflow { // // The computation is built in the context of the given `client`, which may // 
subsequently be used to compile or execute the computation. -// -// If `requires_runtime_context` is filled with true, this indicates the last -// argument of the computation is XlaLocalRuntimeContext*. Status ConvertGraphDefToXla(const GraphDef& graph_def, const tf2xla::Config& config, xla::Client* client, - xla::Computation* computation, - bool* requires_runtime_context); + xla::Computation* computation); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc b/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..7aca889a266439538c4cd1c153460e6cc871b246 --- /dev/null +++ b/tensorflow/compiler/tf2xla/tf2xla_supported_ops.cc @@ -0,0 +1,97 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/tf2xla_supported_ops.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/kernel_def.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { +namespace tf2xla { +namespace { + +void PrintSupportedOps(const string& device, const string& regen_run) { + XlaOpRegistry::RegisterCompilationKernels(); + + std::vector kdefs = + XlaOpRegistry::DeviceKernels(device, + /*include_compilation_only_kernels=*/true); + std::sort( + kdefs.begin(), kdefs.end(), + [](const KernelDef* a, const KernelDef* b) { return a->op() < b->op(); }); + + std::cout << "**Supported operators for device: " << device << "**\n\n" + << "Operator | Type Constraint\n" + << "-------- | ---------------" << std::endl; + for (const KernelDef* kdef : kdefs) { + std::vector constraints; + for (const KernelDef::AttrConstraint& constraint : kdef->constraint()) { + std::vector types; + for (int type : constraint.allowed_values().list().type()) { + types.push_back(DataTypeString(static_cast(type))); + } + std::sort(types.begin(), types.end()); + constraints.push_back("`" + constraint.name() + "={" + + str_util::Join(types, ",") + "}`"); + } + std::cout << "`" << kdef->op() << "` | " + << str_util::Join(constraints, "
") << std::endl; + } + + std::cout << "\nTo regenerate this table, run:\n\n```shell\n" + << regen_run << " --device=" << device << "\n```" << std::endl; +} + +} // namespace + +void SupportedOpsMain(int argc, char** argv, const char* regen_run) { + std::vector device_names = XlaOpRegistry::BackendNames(); + std::sort(device_names.begin(), device_names.end()); + + // Set up and parse flags. + string device; + std::vector flag_list = { + {"device", &device, + "Name of the compilation device for which to print supported ops, " + "one of: " + + str_util::Join(device_names, ",")}, + }; + string usage = Flags::Usage(argv[0], flag_list); + bool parsed_flags_ok = Flags::Parse(&argc, argv, flag_list); + QCHECK(parsed_flags_ok) << "\n" << usage; + QCHECK(XlaOpRegistry::IsBackendRegistered(device)) + << "\nUnknown device: " << device << "\n" + << usage; + + // Run the program. + port::InitMain(usage.c_str(), &argc, &argv); + QCHECK(argc == 1) << "\nERROR: This command does not take any arguments " + "other than flags\n\n" + << usage; + PrintSupportedOps(device, regen_run); +} + +} // namespace tf2xla +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla_supported_ops.h b/tensorflow/compiler/tf2xla/tf2xla_supported_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..1b45fb4cdd3b0173b04e130b7416874a9a406dc5 --- /dev/null +++ b/tensorflow/compiler/tf2xla/tf2xla_supported_ops.h @@ -0,0 +1,33 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_TF2XLA_SUPPORTED_OPS_H_ +#define TENSORFLOW_COMPILER_TF2XLA_TF2XLA_SUPPORTED_OPS_H_ + +namespace tensorflow { +namespace tf2xla { + +// The implementation of a main function for a binary that prints a table of +// supported tf2xla operators for a given device, along with their type +// constraints, to stdout. +// +// Pass the argc and argv from main, unmodified. Use regen_run to specify the +// command used to regenerate the table. +void SupportedOpsMain(int argc, char** argv, const char* regen_run); + +} // namespace tf2xla +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_SUPPORTED_OPS_H_ diff --git a/tensorflow/compiler/tf2xla/tf2xla_supported_ops_main.cc b/tensorflow/compiler/tf2xla/tf2xla_supported_ops_main.cc new file mode 100644 index 0000000000000000000000000000000000000000..690666c2400d45e33c1a5d1818b68a86a70a5be3 --- /dev/null +++ b/tensorflow/compiler/tf2xla/tf2xla_supported_ops_main.cc @@ -0,0 +1,22 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/tf2xla_supported_ops.h" + +int main(int argc, char** argv) { + const char* regen_run = + "bazel run -c opt -- tensorflow/compiler/tf2xla:tf2xla_supported_ops"; + tensorflow::tf2xla::SupportedOpsMain(argc, argv, regen_run); +} diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc index ecd15652fe84b0c19d2f7fc18f877236547f9be9..a9978e697b091715ce120f0d18fdddd259e08b32 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc @@ -70,10 +70,7 @@ TEST(ConvertGraphDefToXla, Sum) { xla::LocalClient* client = xla::ClientLibrary::LocalClientOrDie(); xla::Computation computation; - bool requires_runtime_context; - TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation, - &requires_runtime_context)); - ASSERT_FALSE(requires_runtime_context); + TF_EXPECT_OK(ConvertGraphDefToXla(graph_def, config, client, &computation)); // Set up arguments. 
auto x_literal = xla::Literal::CreateR0(10); diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index 55f2f3149c6ba7bfa18608f961c8a76103a50756..f428a194328935fec1210ea96245344de859e611 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -88,8 +88,8 @@ Status ValidateConfig(const tf2xla::Config& config) { TF_RETURN_IF_ERROR(CheckNameDuplicates("fetch", fetch.name(), &names)); } TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("fetch", names)); - if (config.feed().empty() || config.fetch().empty()) { - return errors::InvalidArgument("feeds and fetches must be specified"); + if (config.fetch().empty()) { + return errors::InvalidArgument("fetches must be specified"); } return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc index 436039e154842443f779aba276bc571fc2ab7537..ed10d80609641b090cf78bf2e17364fe2fa89c31 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc @@ -58,24 +58,14 @@ TEST(ValidateConfig, Good) { TEST(ValidateConfig, BadEmpty) { tf2xla::Config config; - ExpectErrorContains(ValidateConfig(config), - "feeds and fetches must be specified"); -} - -TEST(ValidateConfig, BadNoFeed) { - tf2xla::Config config; - tf2xla::Fetch* fetch = config.add_fetch(); - fetch->mutable_id()->set_node_name("foo"); - ExpectErrorContains(ValidateConfig(config), - "feeds and fetches must be specified"); + ExpectErrorContains(ValidateConfig(config), "fetches must be specified"); } TEST(ValidateConfig, BadNoFetch) { tf2xla::Config config; tf2xla::Feed* feed = config.add_feed(); feed->mutable_id()->set_node_name("foo"); - ExpectErrorContains(ValidateConfig(config), - "feeds and fetches must be specified"); + ExpectErrorContains(ValidateConfig(config), "fetches must be specified"); } TEST(ValidateConfig, BadFeedNodeName) { diff --git 
a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index 4f32c29954b2d809d31ef8c584b6a6c3dcdf5cef..fcb0a4e63814b4afc114bdaea312a92dd8396a2e 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -100,7 +100,7 @@ void XlaCompilationDevice::Compute(OpKernel* op_kernel, b->SetOpMetadata(metadata); auto sharding_parse_result = ParseShardingFromDevice( - op_kernel->requested_device(), std::numeric_limits::max()); + op_kernel->def(), std::numeric_limits::max()); OP_REQUIRES_OK(context, sharding_parse_result.status()); tensorflow::gtl::optional op_sharding = sharding_parse_result.ValueOrDie(); @@ -135,98 +135,4 @@ void XlaExpression::set_constant_value(Tensor value) { constant_value_ = std::move(value); } -Status XlaResource::GetXlaShape(xla::ComputationBuilder* builder, - xla::Shape* shape) const { - auto shape_or_status = builder->GetShape(value); - if (!shape_or_status.ok()) { - return shape_or_status.status(); - } - *shape = *shape_or_status.ValueOrDie(); - return Status::OK(); -} - -Status XlaResource::GetShape(xla::ComputationBuilder* builder, - TensorShape* shape) const { - xla::Shape xla_shape; - TF_RETURN_IF_ERROR(GetXlaShape(builder, &xla_shape)); - TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, shape)); - return Status::OK(); -} - -Status XlaResource::GetOrCreateTensorArrayGradient( - const string& source, xla::ComputationBuilder* builder, - XlaResource** gradient_out) { - VLOG(2) << "Gradient lookup for resource: " << name - << " gradient: " << source; - TF_RET_CHECK(kind == kTensorArray); - std::unique_ptr& gradient = tensor_array_gradients[source]; - if (!gradient) { - gradient.reset(new XlaResource); - gradient->kind = XlaResource::kTensorArray; - gradient->name = strings::StrCat("TensorArrayGrad: ", name); - gradient->type = type; - gradient->tensor_array_size = tensor_array_size; - - TensorShape ta_shape; - 
TF_RETURN_IF_ERROR(GetShape(builder, &ta_shape)); - gradient->value = builder->Broadcast(XlaHelpers::Zero(builder, type), - ta_shape.dim_sizes()); - gradient->initial_value = gradient->value; - } - *gradient_out = gradient.get(); - return Status::OK(); -} - -Status XlaResource::PackedShape(xla::ComputationBuilder* builder, - xla::Shape* packed_shape) const { - if (tensor_array_gradients.empty()) { - return GetXlaShape(builder, packed_shape); - } - TF_RET_CHECK(kind == kTensorArray); - std::vector elem_shapes(1 + tensor_array_gradients.size()); - int pos = 0; - TF_RETURN_IF_ERROR(GetXlaShape(builder, &elem_shapes[pos++])); - for (const auto& gradient : tensor_array_gradients) { - TF_RETURN_IF_ERROR( - gradient.second->GetXlaShape(builder, &elem_shapes[pos++])); - } - *packed_shape = xla::ShapeUtil::MakeTupleShape(elem_shapes); - return Status::OK(); -} - -Status XlaResource::Pack(xla::ComputationDataHandle* pack, - xla::ComputationBuilder* builder) const { - if (tensor_array_gradients.empty()) { - *pack = value; - } else { - TF_RET_CHECK(kind == kTensorArray); - std::vector elems; - elems.push_back(value); - for (const auto& gradient : tensor_array_gradients) { - elems.push_back(gradient.second->value); - } - *pack = builder->Tuple(elems); - } - return Status::OK(); -} - -Status XlaResource::SetFromPack(const std::set& gradient_sources, - const xla::ComputationDataHandle& pack, - xla::ComputationBuilder* builder) { - if (gradient_sources.empty()) { - value = pack; - } else { - TF_RET_CHECK(kind == kTensorArray); - int pos = 0; - value = builder->GetTupleElement(pack, pos++); - for (const auto& source : gradient_sources) { - XlaResource* gradient; - TF_RETURN_IF_ERROR( - GetOrCreateTensorArrayGradient(source, builder, &gradient)); - gradient->value = builder->GetTupleElement(pack, pos++); - } - } - return Status::OK(); -} - } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.h b/tensorflow/compiler/tf2xla/xla_compilation_device.h 
index 6230acd718bc330f178007b575b5119de5b3d4f4..0243ee332fbdca0fe5e28b1a7d9530df4417f807 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.h +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "tensorflow/compiler/tf2xla/xla_resource.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/common_runtime/local_device.h" @@ -66,87 +67,6 @@ class XlaCompilationDevice : public LocalDevice { std::unique_ptr allocator_; }; -// Represents a resource, such as a Variable or TensorArray. -// TODO(phawkins): make this into a properly abstracted class. -struct XlaResource { - enum Kind { - kInvalid, - kVariable, - kTensorArray, - kStack, - }; - - Kind kind = kInvalid; - - // If this resource is visible externally, what was its argument number? - int arg_num = -1; - - // A descriptive name for the resource, used in error messages. - string name; - - // Current type and value of the resource. Uninitialized resources are - // represented by a default (zero) handle and type DT_INVALID. - // While the type of a resource is notionally fixed during execution, when - // a resource is first initialized we do not yet know its type, so we keep - // track of its type dynamically. - DataType type = DT_INVALID; - xla::ComputationDataHandle value; - - // Value of the resource at computation entry. Used to detect which - // variables have new values that need to be written back. - xla::ComputationDataHandle initial_value; - - // TensorArray-specific fields - - // 'tensor_array_size' stores the expected size of the TensorArray. We need - // to store this since sometimes TensorArrays must be initialized lazily since - // we do not know the element shape at construction time. 
- int64 tensor_array_size = -1; - - // 'tensor_array_gradient' is a map from TensorArrayGradV3 'source' attributes - // to an XlaResource containing the gradient TensorArrays. We store a pointer - // here since there should only be one gradient TensorArray per 'source' - // string, irrespective of the number of calls to TensorArrayGrad. The map - // is ordered since values are packed into tuples by Pack() sorted by name - // order. - std::map> tensor_array_gradients; - - // Returns the shape of the resource as an xla::Shape. - Status GetXlaShape(xla::ComputationBuilder* builder, xla::Shape* shape) const; - - // Returns the shape of the resource as an TensorShape. Fails if the shape is - // not representable as a TensorShape. - Status GetShape(xla::ComputationBuilder* builder, TensorShape* shape) const; - - // Looks up the gradient for `source`, or creates it if it does not already - // exist. The call target must be an initialized TensorArray resource. A - // TensorArray can have multiple named gradients; see the operator - // documentation for TensorArrayGradV3 for details. - Status GetOrCreateTensorArrayGradient(const string& source, - xla::ComputationBuilder* builder, - XlaResource** gradient_out); - - // Packs a resource into a single XLA value `pack`, suitable for use as - // an XlaCompiler::Argument. For non-TensorArrays or TensorArrays without - // gradients, sets `*pack` to `value`. - // For TensorArrays with gradients, packs the value and its gradient values in - // a tuple; the gradients values are packed in order by source name. - Status Pack(xla::ComputationDataHandle* pack, - xla::ComputationBuilder* builder) const; - - // Returns the shape of the `pack` value computed by `Pack()`. - Status PackedShape(xla::ComputationBuilder* builder, - xla::Shape* packed_shape) const; - - // Updates the resource with values from `pack`. 
If `gradient_sources` is - // non-empty, treats `pack` as a tuple that represents a TensorArray and - // its gradients, and unpacks and updates the gradient resources. Opposite - // of Pack(). - Status SetFromPack(const std::set& gradient_sources, - const xla::ComputationDataHandle& pack, - xla::ComputationBuilder* builder); -}; - // A XlaExpression wraps an XLA computation. Each Tensor on an // XlaCompilationDevice contains an XlaExpression, and the shape of the Tensor // matches the shape of the subcomputation in the ComputationDataHandle. Each diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc index b5c17c5273bb15e20184b2fefd93880d4828105e..672e19bd93449ccc31f4af5ded23257b197a3c39 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc @@ -28,9 +28,10 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data, temps_(new void*[static_data.num_temps]), arg_names_(static_data.arg_names), result_names_(static_data.result_names), - program_shape_(static_data.program_shape) { + program_shape_(static_data.program_shape), + hlo_profile_printer_data_(static_data.hlo_profile_printer_data) { // Allocate arg and temp buffers. - if (alloc_mode == AllocMode::ARGS_RESULTS_AND_TEMPS) { + if (alloc_mode == AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS) { alloc_args_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers( static_data.arg_sizes, static_data.num_args, args_, /*annotate_initialized=*/false); @@ -39,9 +40,13 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data, static_data.temp_sizes, static_data.num_temps, temps_, /*annotate_initialized=*/true); - // The runtime context is always the last arg, if it is required. 
- if (static_data.requires_runtime_context) { - args_[static_data.num_args - 1] = &context_; + // If Hlo profiling is enabled the generated code expects an appropriately + // sized buffer to be passed in as the last argument. If Hlo profiling is + // disabled the last function argument is still present in the function + // signature, but it is ignored by the generated code and we pass in null for + // it. + if (hlo_profiling_enabled()) { + profile_counters_ = new int64[static_data.profile_counters_size](); } } @@ -50,6 +55,7 @@ XlaCompiledCpuFunction::~XlaCompiledCpuFunction() { tensorflow::tfcompile::runtime::FreeContiguous(alloc_temps_); delete[] args_; delete[] temps_; + delete[] profile_counters_; } namespace { diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h index f49a7889222ff989144217ab10b27595f89e4311..48a8c083cacf2f6ecf9dc1817b6174c01385d035 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h @@ -16,10 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_ #define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_ -#include +#include #include -#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/core/platform/types.h" @@ -27,6 +26,7 @@ limitations under the License. // never use this functionality. namespace xla { class ProgramShape; +class HloProfilePrinterData; } namespace tensorflow { @@ -48,12 +48,10 @@ namespace tensorflow { class XlaCompiledCpuFunction { public: // Type of the raw function, produced by either JIT or AOT. - // - // TODO(toddw): Add support for hlo profiling, and replace std::function with - // a raw function pointer, for some codesize savings. 
- using RawFunction = std::function; + using RawFunction = void (*)(void* result, + const xla::ExecutableRunOptions* run_options, + const void** args, void** temps, + int64* profile_counters); // StaticData represents the state necessary to run an XLA-compiled // function. For JIT this is backed by data in XlaJitCompiledCpuFunction; for @@ -71,9 +69,6 @@ class XlaCompiledCpuFunction { // The 0-based index of the result tuple, in the temp buffers. size_t result_index = 0; - // Is the final arg XlaLocalRuntimeContext? - bool requires_runtime_context = false; - // [Optional] Arrays of arg and result names. These are arrays of C-style // strings, where the array is terminated by nullptr. const char** arg_names = nullptr; @@ -81,21 +76,31 @@ class XlaCompiledCpuFunction { // [Optional] Arg and result shapes. const xla::ProgramShape* program_shape = nullptr; + + // [Optional] Profile printer data. Null if profiling is disabled. + const xla::HloProfilePrinterData* hlo_profile_printer_data = nullptr; + + // [Optional] The number of profile counters expected in the profile counter + // buffer by the generated code and hlo_profile_printer. 0 if profiling is + // disabled. This information is already present in + // hlo_profile_printer_data but xla::HloProfilePrinterData is forward + // declared so we don't have access to that information here. + int64 profile_counters_size = 0; }; // AllocMode controls the buffer allocation mode. enum class AllocMode { - // Allocate all buffers - args, results and temps. - ARGS_RESULTS_AND_TEMPS, + // Allocate all buffers - args, results, profile and temps. + ARGS_RESULTS_PROFILES_AND_TEMPS, - // Only allocate result and temp buffers. + // Only allocate result, profile and temp buffers. // Use set_arg_data to set argument buffers before Run is called. 
- RESULTS_AND_TEMPS_ONLY, + RESULTS_PROFILES_AND_TEMPS_ONLY, }; XlaCompiledCpuFunction( const StaticData& static_data, - AllocMode alloc_mode = AllocMode::ARGS_RESULTS_AND_TEMPS); + AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS); virtual ~XlaCompiledCpuFunction(); XlaCompiledCpuFunction(const XlaCompiledCpuFunction&) = delete; @@ -104,21 +109,22 @@ class XlaCompiledCpuFunction { // Sets the intra-op thread pool used to run individual ops concurrently. void set_thread_pool(const Eigen::ThreadPoolDevice* pool) { run_options_.set_intra_op_thread_pool(pool); - context_.thread_pool = pool; } // Runs the computation, with inputs read from arg buffers, and outputs // written to result buffers. Returns true on success and false on failure. bool Run() { - context_.error = false; - context_.error_msg.clear(); raw_function_(temps_[result_index_], &run_options_, - const_cast(args_), temps_); - return !context_.error; + const_cast(args_), temps_, profile_counters_); + return true; } // Returns the error message from the previous failed Run call. - const string& error_msg() const { return context_.error_msg; } + // + // TODO(fschneider): For now this always returns an empty string because there + // is no support for error reporting in XLA. Remove this once all callers are + // updated. + string error_msg() const { return {}; } // ------------------------------ // Arg methods for managing input buffers. Buffers are in row-major order. @@ -141,10 +147,6 @@ class XlaCompiledCpuFunction { // tensorflow::tfcompile::runtime::kAlign. If possible, use the functions in // tensorflow/compiler/aot/runtime.h to ensure correct alignment. // - // If StaticData.requires_runtime_context==true, the final argument is an - // XlaLocalRuntimeContext, which is managed internally by this class, and - // should not be changed. - // // Aliasing of argument and result buffers is not allowed, and results in // undefined behavior. 
void set_arg_data(size_t index, void* data) { args_[index] = data; } @@ -162,6 +164,16 @@ class XlaCompiledCpuFunction { return static_cast(temps_[result_index_]); } + // Profile counters for this XLA computation. + // + // When Hlo profiling is enabled (`hlo_profiling_enabled()` return true in + // this case) these counters are non-null and are automatically populated by + // `Run`. The counters can then be pretty-printed using + // `hlo_profile_printer()`. + // + // When Hlo profiling is disabled, this accessor returns null. + const int64* profile_counters() const { return profile_counters_; } + // Returns the buffer for the positional result at the given `index`. void* result_data(size_t index) { return results()[index]; } const void* result_data(size_t index) const { return results()[index]; } @@ -195,6 +207,14 @@ class XlaCompiledCpuFunction { // program shape isn't available. const xla::ProgramShape* ProgramShape() const { return program_shape_; } + bool hlo_profiling_enabled() const { + return hlo_profile_printer_data_ != nullptr; + } + const xla::HloProfilePrinterData& hlo_profile_printer_data() const { + assert(hlo_profiling_enabled()); + return *hlo_profile_printer_data_; + } + private: const RawFunction raw_function_; const size_t result_index_; @@ -208,14 +228,17 @@ class XlaCompiledCpuFunction { void* alloc_args_ = nullptr; void* alloc_temps_ = nullptr; + // Backing memory for profiling counters. + int64* profile_counters_ = nullptr; + // Options and context passed to the compiled function. xla::ExecutableRunOptions run_options_; - tensorflow::XlaLocalRuntimeContext context_; // Optional metadata. 
const char** arg_names_ = nullptr; const char** result_names_ = nullptr; const xla::ProgramShape* program_shape_ = nullptr; + const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr; }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 48cebdf74c71f974bf075e0255626ec57eb9a149..59e88304422eaeaaf3f63cc4d476a8ec7ce95623 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -66,13 +66,14 @@ Status CheckSignature(const DataTypeVector& types, bool XlaCompiler::Argument::operator==( const XlaCompiler::Argument& other) const { - if (std::tie(kind, resource_kind, type, name, tensor_array_size, + if (std::tie(kind, resource_kind, type, name, initialized, tensor_array_size, tensor_array_gradients) != std::tie(other.kind, other.resource_kind, other.type, other.name, - other.tensor_array_size, other.tensor_array_gradients)) { + other.initialized, other.tensor_array_size, + other.tensor_array_gradients)) { return false; } - if (!xla::ShapeUtil::Equal(shape, other.shape)) { + if (shape != other.shape) { return false; } if (constant_value.shape() != other.constant_value.shape()) { @@ -152,7 +153,8 @@ std::unique_ptr XlaCompiler::GetGraph(const FunctionBody* fbody) { std::unique_ptr graph(new Graph(options_.flib_def)); CopyGraph(*fbody->graph, graph.get()); OptimizerOptions opts; - opts.set_do_common_subexpression_elimination(true); + opts.set_opt_level(OptimizerOptions::L0); + opts.set_do_common_subexpression_elimination(false); opts.set_do_function_inlining(true); opts.set_do_constant_folding(true); GraphOptimizer optimizer(opts); @@ -183,8 +185,7 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, CheckSignature(fbody->arg_types, args), "Signature check failure while compiling: ", function.name()); - std::unique_ptr graph(new Graph(options_.flib_def)); - CopyGraph(*fbody->graph, graph.get()); + 
std::unique_ptr graph = GetGraph(fbody); // _Arg and _Retval nodes don't exist in the stored subgraph for the function; // they are added by the function body looked up. Therefore, they don't have @@ -212,15 +213,6 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, *graph); } - // Optimize the graph before running the compiler. - OptimizerOptions opts; - opts.set_do_common_subexpression_elimination(true); - opts.set_do_function_inlining(true); - opts.set_do_constant_folding(true); - GraphOptimizer optimizer(opts); - optimizer.Optimize(flib_runtime_, flib_runtime_->env(), - /*device=*/nullptr, &graph, /*shape_map=*/nullptr); - VLOG(1) << "===================================================="; TF_RETURN_IF_ERROR( CompileGraph(options, function_id, std::move(graph), args, result)); @@ -230,6 +222,64 @@ Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, return Status::OK(); } +// Computes the XLA shape for argument 'arg'. +/*static*/ Status XlaCompiler::XLAShapeForArgument( + const XlaCompiler::Argument& arg, xla::Shape* xla_shape) { + switch (arg.kind) { + case XlaCompiler::Argument::kConstant: + return TensorShapeToXLAShape(arg.type, arg.constant_value.shape(), + xla_shape); + case XlaCompiler::Argument::kParameter: + return TensorShapeToXLAShape(arg.type, arg.shape, xla_shape); + case XlaCompiler::Argument::kResource: { + TF_RET_CHECK(arg.initialized); + + switch (arg.resource_kind) { + case XlaResource::kVariable: + return TensorShapeToXLAShape(arg.type, arg.shape, xla_shape); + case XlaResource::kTensorArray: { + if (arg.tensor_array_size < 0) { + return errors::InvalidArgument( + "Negative tensor_array_size in XLAShapeForArgument"); + } + TensorShape shape; + shape.AddDim(arg.tensor_array_size); + shape.AppendShape(arg.shape); + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(arg.type, shape, xla_shape)); + + if (!arg.tensor_array_gradients.empty()) { + std::vector tuple_shape( + 
arg.tensor_array_gradients.size() + 1, *xla_shape); + *xla_shape = xla::ShapeUtil::MakeTupleShape(tuple_shape); + } + return Status::OK(); + } + case XlaResource::kStack: { + if (arg.tensor_array_size < 0) { + return errors::InvalidArgument( + "Negative tensor_array_size in XLAShapeForArgument"); + } + TensorShape shape; + shape.AddDim(arg.tensor_array_size); + shape.AppendShape(arg.shape); + xla::Shape buffer_shape; + TF_RETURN_IF_ERROR( + TensorShapeToXLAShape(arg.type, shape, &buffer_shape)); + *xla_shape = xla::ShapeUtil::MakeTupleShape( + {buffer_shape, xla::ShapeUtil::MakeShape(xla::S32, {})}); + return Status::OK(); + } + + case XlaResource::kInvalid: + return errors::Internal( + "Invalid resource type in XLAShapeForArgument()"); + } + } + case XlaCompiler::Argument::kInvalid: + return errors::Internal("Invalid argument type in XLAShapeForArgument()"); + } +} + namespace { Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, @@ -268,14 +318,16 @@ Status BuildArguments(const Graph& graph, XlaContext* context, std::vector* arg_cores, std::vector* arg_expressions, std::vector* input_mapping, - std::vector* input_shapes) { + std::vector* input_shapes, + bool is_entry_computation) { arg_expressions->resize(args.size()); *arg_cores = std::vector(args.size(), -1); // Argument numbers of arguments and resources that are to be passed to the // XLA computation as runtime parameters. - std::vector parameters, resources; - parameters.reserve(args.size()); + input_mapping->clear(); + input_mapping->reserve(args.size()); + std::vector resources; resources.reserve(args.size()); // Fills in constant arguments, and computes non-constant argument order. @@ -289,18 +341,20 @@ Status BuildArguments(const Graph& graph, // TODO(phawkins): this code assumes that resource arguments do not // alias. 
XlaResource* resource; - TF_RETURN_IF_ERROR( - context->CreateResource(arg.resource_kind, i, arg.name, arg.type, - xla::ComputationDataHandle(), &resource)); - resource->tensor_array_size = arg.tensor_array_size; + TF_RETURN_IF_ERROR(context->CreateResource( + arg.resource_kind, i, arg.name, arg.type, arg.shape, + xla::ComputationDataHandle(), + /*tensor_array_size=*/arg.tensor_array_size, + /*tensor_array_gradients=*/arg.tensor_array_gradients, &resource)); arg_expression.set_resource(resource); if (arg.initialized) { resources.push_back(i); } break; - case XlaCompiler::Argument::kParameter: - parameters.push_back(i); + case XlaCompiler::Argument::kParameter: { + input_mapping->push_back(i); break; + } case XlaCompiler::Argument::kConstant: arg_expression.set_constant_value(arg.constant_value); break; @@ -311,18 +365,23 @@ Status BuildArguments(const Graph& graph, // Append parameters containing variable values after the other runtime // parameters. - parameters.insert(parameters.end(), resources.begin(), resources.end()); - if (parameters.empty()) { + input_mapping->insert(input_mapping->end(), resources.begin(), + resources.end()); + if (input_mapping->empty()) { return Status::OK(); } - input_shapes->resize(parameters.size()); - input_mapping->resize(parameters.size()); - for (std::vector::size_type i = 0; i < parameters.size(); ++i) { - const XlaCompiler::Argument& arg = args[parameters[i]]; + std::vector arg_shapes(input_mapping->size()); + for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { // Computes the shapes of non-constant arguments. - (*input_shapes)[i] = arg.shape; - (*input_mapping)[i] = parameters[i]; + TF_RETURN_IF_ERROR(XlaCompiler::XLAShapeForArgument( + args[(*input_mapping)[i]], &arg_shapes[i])); + } + + if (use_tuple_arg) { + input_shapes->push_back(xla::ShapeUtil::MakeTupleShape(arg_shapes)); + } else { + *input_shapes = arg_shapes; } // Use the _Arg nodes in the graph to resolve core assignments. 
@@ -346,24 +405,38 @@ Status BuildArguments(const Graph& graph, } // Build parameter handles for non-constant arguments. - std::vector arg_handles(parameters.size()); + std::vector arg_handles(input_mapping->size()); if (use_tuple_arg) { - xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(*input_shapes); - xla::ComputationDataHandle tuple = - builder->Parameter(0, tuple_shape, "arg_tuple"); - for (std::vector::size_type i = 0; i < parameters.size(); ++i) { - const int core = (*arg_cores)[parameters[i]]; + xla::ComputationDataHandle tuple; + if (is_entry_computation) { + xla::OpSharding tuple_sharding; + tuple_sharding.set_type(xla::OpSharding::Type::OpSharding_Type_TUPLE); + for (int64 parameter : *input_mapping) { + const int core = (*arg_cores)[parameter]; + const int root_device = 0; + *tuple_sharding.add_tuple_shardings() = + core == -1 ? xla::sharding_builder::AssignDevice(root_device) + : xla::sharding_builder::AssignDevice(core); + } + xla::ScopedShardingAssignment assign_tuple_sharding(builder, + tuple_sharding); + tuple = builder->Parameter(0, (*input_shapes)[0], "arg_tuple"); + } else { + tuple = builder->Parameter(0, (*input_shapes)[0], "arg_tuple"); + } + for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { + const int core = (*arg_cores)[input_mapping->at(i)]; xla::ScopedShardingAssignment assign_sharding( builder, core == -1 ? tensorflow::gtl::optional() - : xla::ShardingBuilder::AssignDevice(core)); + : xla::sharding_builder::AssignDevice(core)); arg_handles[i] = builder->GetTupleElement(tuple, i); } } else { - for (std::vector::size_type i = 0; i < parameters.size(); ++i) { - const int core = (*arg_cores)[parameters[i]]; + for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { + const int core = (*arg_cores)[input_mapping->at(i)]; xla::ScopedShardingAssignment assign_sharding( builder, core == -1 ? 
tensorflow::gtl::optional() - : xla::ShardingBuilder::AssignDevice(core)); + : xla::sharding_builder::AssignDevice(core)); arg_handles[i] = builder->Parameter(i, (*input_shapes)[i], strings::StrCat("arg", i)); } @@ -371,12 +444,12 @@ Status BuildArguments(const Graph& graph, // Fill in the handles in non-constant arguments. VLOG(2) << "XLA computation inputs:"; - for (std::vector::size_type i = 0; i < parameters.size(); ++i) { - const XlaCompiler::Argument& arg = args[parameters[i]]; + for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { + const XlaCompiler::Argument& arg = args[input_mapping->at(i)]; VLOG(2) << " XLA arg " << i - << " shape: " << xla::ShapeUtil::HumanString((*input_shapes)[i]) - << " name: " << arg.name << " TF arg " << parameters[i]; - XlaExpression& arg_expression = (*arg_expressions)[parameters[i]]; + << " shape: " << xla::ShapeUtil::HumanString(arg_shapes[i]) + << " name: " << arg.name << " TF arg " << input_mapping->at(i); + XlaExpression& arg_expression = (*arg_expressions)[input_mapping->at(i)]; switch (arg.kind) { case XlaCompiler::Argument::kResource: { TF_RET_CHECK(arg.initialized); @@ -385,10 +458,6 @@ Status BuildArguments(const Graph& graph, arg_handles[i], builder)); VLOG(2) << " resource: num_gradients: " << arg.tensor_array_gradients.size(); - resource->initial_value = resource->value; - for (const auto& gradient : resource->tensor_array_gradients) { - gradient.second->initial_value = gradient.second->value; - } break; } case XlaCompiler::Argument::kParameter: @@ -439,43 +508,44 @@ Status BuildComputation( std::vector arg_resources; arg_resources.reserve(resources.size()); for (const auto& resource : resources) { - if (resource->arg_num >= 0) { + if (resource->arg_num() >= 0) { arg_resources.push_back(resource.get()); } } std::sort(arg_resources.begin(), arg_resources.end(), [](const XlaResource* a, const XlaResource* b) { - return a->arg_num < b->arg_num; + return a->arg_num() < b->arg_num(); }); for (const 
XlaResource* resource : arg_resources) { - const XlaCompiler::Argument& arg = args[resource->arg_num]; - const int core = arg_cores[resource->arg_num]; - DCHECK_LT(resource->arg_num, arg_cores.size()); + const XlaCompiler::Argument& arg = args[resource->arg_num()]; + const int core = arg_cores[resource->arg_num()]; + DCHECK_LT(resource->arg_num(), arg_cores.size()); bool modified = - resource->value.handle() != resource->initial_value.handle(); + resource->value().handle() != resource->initial_value().handle(); // TensorArray gradients were modified if their values changed or there are // any newly created gradients. - for (const auto& grad : resource->tensor_array_gradients) { - modified = - modified || - grad.second->value.handle() != grad.second->initial_value.handle() || - arg.tensor_array_gradients.count(grad.first) == 0; + for (const auto& grad : resource->tensor_array_gradients()) { + modified = modified || + grad.second->value().handle() != + grad.second->initial_value().handle() || + arg.tensor_array_gradients.count(grad.first) == 0; } if (return_updated_values_for_all_resources || modified) { resource_updates->emplace_back(); XlaCompiler::ResourceUpdate& update = resource_updates->back(); - update.input_index = resource->arg_num; - update.type = resource->type; + update.input_index = resource->arg_num(); + update.type = resource->type(); + update.shape = resource->shape(); update.modified = modified; - for (const auto& grad : resource->tensor_array_gradients) { + for (const auto& grad : resource->tensor_array_gradients()) { update.tensor_array_gradients_accessed.insert(grad.first); } // Request that the value be returned on a specific core. xla::ScopedShardingAssignment assign_sharding( builder, core == -1 ? 
tensorflow::gtl::optional() - : xla::ShardingBuilder::AssignDevice(core)); + : xla::sharding_builder::AssignDevice(core)); xla::ComputationDataHandle handle; TF_RETURN_IF_ERROR(resource->Pack(&handle, builder)); @@ -502,18 +572,6 @@ Status BuildComputation( return Status::OK(); } -void AssignMajorToMinorLayout(xla::Shape* shape) { - if (xla::ShapeUtil::IsTuple(*shape)) { - for (xla::Shape& elem_shape : *shape->mutable_tuple_shapes()) { - AssignMajorToMinorLayout(&elem_shape); - } - } else { - auto& minor_to_major = *shape->mutable_layout()->mutable_minor_to_major(); - minor_to_major.Resize(xla::ShapeUtil::Rank(*shape), 0); - std::iota(minor_to_major.rbegin(), minor_to_major.rend(), 0); - } -} - } // namespace Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, @@ -543,13 +601,12 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, options.resolve_compile_time_constants); core::ScopedUnref context_unref(context); - result->tuple_arg = options.use_tuple_arg; - std::vector arg_expressions; std::vector arg_cores; - TF_RETURN_IF_ERROR(BuildArguments( - *graph, args, options.use_tuple_arg, &builder, context, &arg_cores, - &arg_expressions, &result->input_mapping, &result->xla_input_shapes)); + TF_RETURN_IF_ERROR( + BuildArguments(*graph, args, options.use_tuple_arg, &builder, context, + &arg_cores, &arg_expressions, &result->input_mapping, + &result->xla_input_shapes, options.is_entry_computation)); context->set_args(std::move(arg_expressions)); TF_RETURN_IF_ERROR(ExecuteGraph(context, std::move(graph), device_, @@ -564,11 +621,6 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, result->computation.get(), &num_computation_outputs, &num_nonconst_outputs, &result->resource_updates)); - result->requires_runtime_context = context->has_context_parameter(); - - // Tuple arguments and runtime context parameters are incompatible. 
- TF_RET_CHECK(!(options.use_tuple_arg && result->requires_runtime_context)); - VLOG(2) << "Outputs: total: " << context->retvals().size() << " nonconstant: " << num_nonconst_outputs; result->outputs.resize(context->retvals().size()); @@ -596,7 +648,7 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, << xla::ShapeUtil::HumanString(result->xla_output_shape); // Tensorflow expects a major-to-minor order of results. - AssignMajorToMinorLayout(&result->xla_output_shape); + xla::LayoutUtil::SetToDefaultLayout(&result->xla_output_shape); // Converts the output shapes to TensorShapes. int computation_output = 0; @@ -615,13 +667,6 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, ++computation_output; } } - - for (std::vector::size_type i = 0; - i < result->resource_updates.size(); ++i) { - result->resource_updates[i].shape = xla::ShapeUtil::GetTupleElementShape( - result->xla_output_shape, computation_output); - ++computation_output; - } return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index ac7d4cfb127d1de8c92f3a855191c45af77888ad..b86c82c0ab5ce379d35a13043857f459199e2ad2 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -54,8 +54,6 @@ namespace tensorflow { // +---------------------+-----------------------------------------+ // Within each block, the arguments are arranged by the _Arg index from which // they were derived. -// If `Options::requires_runtime_context` is true, then an additional runtime -// context argument is passed as a final argument. // // The run-time outputs of the XLA computation are arranged in the following // order: @@ -106,9 +104,17 @@ class XlaCompiler { // is the type of the variable's value, not DT_RESOURCE. DataType type; - // The shape of the argument. If the argument is a resource, this is the - // shape of the resource's value. 
- xla::Shape shape; + // The shape of the argument. For: + // * a parameter: the shape of the parameter. + // * a constant: ignored; the shape given by constant_value is used + // instead. + // * an uninitialized resource: ignored. We don't yet know the shape of an + // uninitialized resource (otherwise we would have initialized it!) + // * an initialized variable: the shape of the variable's value. + // * an initialized TensorArray or Stack resource: the shape of an entry in + // the TensorArray/Stack. Note this is the size of a single entry, not the + // XLA data structure that represents the complete stack/array. + TensorShape shape; // The value of the argument, if it is a compile-time constant. Must be a // host-memory tensor. @@ -154,6 +160,10 @@ class XlaCompiler { // as Tensors at compile-time, rather than as run-time outputs of the // computation. bool resolve_compile_time_constants = true; + + // True when compiling the entry computation, false for subcomputations + // (while, call, etc.) + bool is_entry_computation = true; }; struct OutputDescription { @@ -173,8 +183,9 @@ class XlaCompiler { int input_index; // Type and shape of the tensor to be written back. + // The `shape` field has the same meaning as the Argument::shape field. DataType type; - xla::Shape shape; + TensorShape shape; // Was the value of the variable modified by the computation? // (Always true, unless `return_updated_values_for_all_resources` is true.) @@ -191,16 +202,9 @@ class XlaCompiler { // original arguments, and are not necessarily in the same order.) std::vector input_mapping; - // Does the computation require the local runtime context to be passed as - // the last argument? - bool requires_runtime_context = false; - // Input shapes of the computation. std::vector xla_input_shapes; - // Should the arguments be packed into a single tuple? - bool tuple_arg; - // Output shape in XLA format. The output shape is always a tuple. 
xla::Shape xla_output_shape; @@ -232,8 +236,7 @@ class XlaCompiler { int graph_def_version = TF_GRAPH_DEF_VERSION; // If 'allow_cpu_custom_calls' is true, kernels may make use of CustomCall() - // for CPU; additionally, an optional XlaLocalRuntimeContext* may be passed - // to the computation. + // for CPU. bool allow_cpu_custom_calls = false; // If not nullptr, populate_resource_manager is called with the @@ -241,6 +244,19 @@ class XlaCompiler { // device is created, and can be used to create metadata objects // that can be accessed by XLA op kernels. std::function* populate_resource_manager = nullptr; + + // If not nullptr, this memory allocator can be used by the compiler for + // temporary allocations it might want to make during compilation. + // + // For example, the compiler may want to try out different algorithms and + // choose the fastest one, and it might run those algorithms over buffers + // created using this allocator. + // + // The compiler can function correctly without an explicit allocator given + // here, but on some devices (notably, GPUs), TensorFlow tends to eagerly + // allocate most or all available memory on the device, leaving none for the + // compiler to access, unless it can use TensorFlow's allocator. + xla::DeviceMemoryAllocator* device_allocator = nullptr; }; explicit XlaCompiler(Options options); @@ -259,11 +275,10 @@ class XlaCompiler { const std::vector& args, CompilationResult* result); - Status PrepareArguments(xla::ComputationBuilder* builder, NameAttrList func, - const std::vector& types, - const std::vector& shapes, - const std::vector& expressions, - std::vector* args); + // Returns the shape of the XLA parameter for an argument 'arg'. + // See the class comment for more details about the argument passing + // convention. + static Status XLAShapeForArgument(const Argument& arg, xla::Shape* xla_shape); // Retrieves the channel handle associated with `key`. Allocates // a new channel handle if none exists. 
diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 93aae8485d157cd4afbf804d695d5c0ab8d7946c..65de4dbad75b7fb76a041bc799fc31dc5cb80d74 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -191,10 +191,10 @@ TEST_F(XlaCompilerTest, Simple) { std::vector args(2); args[0].kind = XlaCompiler::Argument::kParameter; args[0].type = DT_INT32; - args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2}); + args[0].shape = TensorShape({2}); args[1].kind = XlaCompiler::Argument::kParameter; args[1].type = DT_INT32; - args[1].shape = xla::ShapeUtil::MakeShape(xla::S32, {2}); + args[1].shape = TensorShape({2}); // Compiles the graph. XlaCompiler compiler(DefaultOptions()); @@ -227,6 +227,42 @@ TEST_F(XlaCompilerTest, Simple) { xla::LiteralTestUtil::ExpectEqual(*expected_literal, *actual_literal); } +TEST_F(XlaCompilerTest, HasSaneErrorOnNonCompileTimeConstantInputToReshape) { + // Builds a graph that adds reshapes a tensor, but with the shape not + // statically known. + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1); + auto c = ops::Reshape(scope.WithOpName("C"), a, b); + auto d = ops::_Retval(scope.WithOpName("D"), c, 0); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the arguments. + std::vector args(2); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = TensorShape({2}); + args[1].kind = XlaCompiler::Argument::kParameter; + args[1].type = DT_INT32; + args[1].shape = TensorShape({2}); + + // Compiles the graph. 
+ XlaCompiler compiler(DefaultOptions()); + + XlaCompiler::CompilationResult result; + Status status = + compiler.CompileGraph(XlaCompiler::CompileOptions(), "reshape", + std::move(graph), args, &result); + EXPECT_FALSE(status.ok()); + EXPECT_TRUE( + StringPiece(status.error_message()).contains("depends on a parameter")) + << status.error_message(); + EXPECT_TRUE( + StringPiece(status.error_message()).contains("[[Node: C = Reshape")) + << status.error_message(); +} + // Tests handling of compile-time constant outputs. TEST_F(XlaCompilerTest, ConstantOutputs) { // Builds a graph with one compile-time constant output and one data-dependent @@ -245,7 +281,7 @@ TEST_F(XlaCompilerTest, ConstantOutputs) { std::vector args(1); args[0].kind = XlaCompiler::Argument::kParameter; args[0].type = DT_INT32; - args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2}); + args[0].shape = TensorShape({2}); XlaCompiler::Options options = DefaultOptions(); XlaCompiler compiler(options); @@ -337,7 +373,7 @@ TEST_F(XlaCompilerTest, ResourceManager) { std::vector args(1); args[0].kind = XlaCompiler::Argument::kParameter; args[0].type = DT_INT32; - args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2}); + args[0].shape = TensorShape({2}); DummyResourceForTest* resource = new DummyResourceForTest(); @@ -384,7 +420,7 @@ TEST_F(XlaCompilerTest, DeterministicCompilation) { std::vector args(1); args[0].kind = XlaCompiler::Argument::kParameter; args[0].type = DT_INT32; - args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2}); + args[0].shape = TensorShape({2}); // Compiles the graph. 
auto options = DefaultOptions(); @@ -436,9 +472,7 @@ TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) { args[0].resource_kind = XlaResource::kTensorArray; args[0].initialized = true; args[0].type = DT_INT32; - args[0].shape = xla::ShapeUtil::MakeTupleShape( - {xla::ShapeUtil::MakeShape(xla::S32, {2}), - xla::ShapeUtil::MakeShape(xla::S32, {2})}); + args[0].shape = TensorShape({}); args[0].tensor_array_size = 2; args[0].tensor_array_gradients = {"grad2"}; @@ -504,9 +538,7 @@ TEST_F(XlaCompilerTest, UnwrittenTensorArrayGradientsAreNotComputationOutputs) { args[0].resource_kind = XlaResource::kTensorArray; args[0].initialized = true; args[0].type = DT_INT32; - args[0].shape = xla::ShapeUtil::MakeTupleShape( - {xla::ShapeUtil::MakeShape(xla::S32, {2}), - xla::ShapeUtil::MakeShape(xla::S32, {2})}); + args[0].shape = TensorShape({}); args[0].tensor_array_size = 2; args[0].tensor_array_gradients = {"grad1"}; @@ -538,9 +570,7 @@ TEST_F(XlaCompilerTest, NewTensorArrayGradientsAreComputationOutputs) { args[0].resource_kind = XlaResource::kTensorArray; args[0].initialized = true; args[0].type = DT_INT32; - args[0].shape = xla::ShapeUtil::MakeTupleShape( - {xla::ShapeUtil::MakeShape(xla::S32, {2}), - xla::ShapeUtil::MakeShape(xla::S32, {2})}); + args[0].shape = TensorShape({}); args[0].tensor_array_size = 2; args[0].tensor_array_gradients = {"grad1"}; diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 651bafd6c5d946adfedd63ebbe93e4ea016f0b37..73878955e3fd54c103c0b07faf7f5ee5bcd84de0 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -70,24 +70,6 @@ XlaContext::XlaContext(XlaCompiler* compiler, xla::ComputationBuilder* builder, allow_cpu_custom_calls_(allow_cpu_custom_calls), resolve_compile_time_constants_(resolve_compile_time_constants) {} -const xla::ComputationDataHandle& -XlaContext::GetOrCreateRuntimeContextParameter() { - 
CHECK(allow_cpu_custom_calls_); - if (has_context_parameter_) return context_parameter_; - has_context_parameter_ = true; - - // Allocate the next available parameter for the context parameter. - int num_parameters = 0; - for (const XlaExpression& arg : args_) { - if (!arg.has_constant_value()) { - ++num_parameters; - } - } - context_parameter_ = builder_->Parameter( - num_parameters, xla::ShapeUtil::MakeOpaqueShape(), "tf_context"); - return context_parameter_; -} - string XlaContext::DebugString() { return "TLA JIT context"; } // This is called by the Retval Op to associate a computed value @@ -121,18 +103,15 @@ Status XlaContext::AddConstRetval(int retval_index, DataType dtype, xla::ComputationBuilder* XlaContext::builder() { return builder_; } -Status XlaContext::CreateResource(XlaResource::Kind kind, int arg_num, - string name, DataType type, - const xla::ComputationDataHandle& handle, - XlaResource** resource) { - resources_.emplace_back(new XlaResource); +Status XlaContext::CreateResource( + XlaResource::Kind kind, int arg_num, string name, DataType type, + TensorShape shape, const xla::ComputationDataHandle& handle, + int64 tensor_array_size, const std::set& tensor_array_gradients, + XlaResource** resource) { + resources_.emplace_back( + new XlaResource(kind, arg_num, std::move(name), type, std::move(shape), + handle, tensor_array_size, tensor_array_gradients)); *resource = resources_.back().get(); - XlaResource& r = **resource; - r.kind = kind; - r.arg_num = arg_num; - r.name = std::move(name); - r.type = type; - r.initial_value = r.value = handle; return Status::OK(); } @@ -178,6 +157,20 @@ const xla::Computation* XlaContext::GetOrCreateAdd(const DataType type) { }); } +const xla::Computation* XlaContext::GetOrCreateMul(const DataType type) { + return LookupOrCreate(type, &mul_func_, [this, type] { + const string type_string = DataTypeString(type); + VLOG(1) << "Building Mul() for " << type_string; + xla::ComputationBuilder b(builder()->client(), "mul<" + 
type_string + ">"); + xla::PrimitiveType xla_type; + TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type)); + auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x"); + auto y = b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y"); + b.Mul(x, y); + return b.Build().ConsumeValueOrDie(); + }); +} + const xla::Computation* XlaContext::LookupOrCreate( DataType type, ComputationMap* out, const std::function& create) { diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index de8aafa3628e6eebdabbc508cd95a2ac86e3472f..fac0352ae81e24597e1045981ac47a7cd09481da 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -56,15 +56,10 @@ class XlaContext : public ResourceBase { xla::ComputationBuilder* builder(); bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; } - bool has_context_parameter() const { return has_context_parameter_; } const std::vector& args() const { return args_; } void set_args(std::vector args); - // Get the runtime context parameter, adding one if it does not already exist. - // Dies if not compiling a local executable. - const xla::ComputationDataHandle& GetOrCreateRuntimeContextParameter(); - const std::vector& retvals() { return retvals_; } // This is called by the Retval Op to associate a computed value @@ -76,11 +71,15 @@ class XlaContext : public ResourceBase { Status AddConstRetval(int retval_index, DataType dtype, const xla::Literal& literal); - // Creates a resource with resource `kind` and initial type `type` and - // value `handle`. `name` is a descriptive name for use in error messages. + // Creates a resource with resource `kind` and initial value `handle`. `name` + // is a descriptive name for use in error messages. See the `XlaResource` + // constructor for a description of the remaining arguments. // Fails if the resource already exists. 
Status CreateResource(XlaResource::Kind kind, int arg_num, string name, - DataType type, const xla::ComputationDataHandle& handle, + DataType type, TensorShape shape, + const xla::ComputationDataHandle& handle, + int64 tensor_array_size, + const std::set& tensor_array_gradients, XlaResource** resource); const std::vector>& resources() { @@ -102,6 +101,11 @@ class XlaContext : public ResourceBase { // separate specialization of the computation for each DataType. const xla::Computation* GetOrCreateAdd(const DataType type); + // Get an XLA lambda to compute Mul. This is cached in the + // XlaContext since it may be used by multiple Ops. There is a + // separate specialization of the computation for each DataType. + const xla::Computation* GetOrCreateMul(const DataType type); + // The name of the XlaContext resource during symbolic graph execution. static const char kXlaContextResourceName[]; @@ -116,16 +120,9 @@ class XlaContext : public ResourceBase { const bool allow_cpu_custom_calls_; // If true, constant return values are returned as Tensors instead of - // run-time computation outptus. + // run-time computation outputs. const bool resolve_compile_time_constants_; - // When 'has_context_parameter_' is true, this is the computation handle - // for an additional final parameter to the computation, through which will be - // passed a XlaLocalRuntimeContext* at runtime. Created on demand by - // GetOrCreateRuntimeContextParameter(). - bool has_context_parameter_ = false; - xla::ComputationDataHandle context_parameter_; - // Arguments to the Tensorflow graph, indexed by _Arg index. // Includes both compile-time constant arguments and runtime parameters. std::vector args_; @@ -155,6 +152,9 @@ class XlaContext : public ResourceBase { // Cached computation to compute Sum of two elements, specialized by type. ComputationMap add_func_; + // Cached computation to compute Mul of two elements, specialized by type. 
+ ComputationMap mul_func_; + // Cached computation to compute Sigmoid of an element, specialized by type. ComputationMap sigmoid_func_; diff --git a/tensorflow/compiler/tf2xla/xla_gpu_backend.cc b/tensorflow/compiler/tf2xla/xla_gpu_backend.cc index d504613d232c779e47a506657d2825d052e726dc..8ca757e72355d890c13b8b448d35c327d3986696 100644 --- a/tensorflow/compiler/tf2xla/xla_gpu_backend.cc +++ b/tensorflow/compiler/tf2xla/xla_gpu_backend.cc @@ -21,8 +21,6 @@ namespace tensorflow { bool GpuOpFilter(KernelDef* kdef) { // TODO(b/31361304): The GPU backend does not parallelize PRNG ops, leading to // slow code. - // TODO(b/34969189) The implementation of TruncatedNormal generates illegal - // code on GPU. if (kdef->op() == "RandomStandardNormal" || kdef->op() == "RandomUniform" || kdef->op() == "RandomUniformInt" || kdef->op() == "TruncatedNormal") { return false; diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index 9c3e15d2fa4c84af94d137f2e03107bcc980f4cd..f048662953e20b2a612271e2daeef6e370c4822a 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// This file defines helper routines for Tla JIT compilation. +// This file defines helper routines for XLA compilation. 
#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/lib/util.h" @@ -121,6 +121,8 @@ xla::ComputationDataHandle XlaHelpers::One(xla::ComputationBuilder* b, xla::ComputationDataHandle XlaHelpers::Epsilon(xla::ComputationBuilder* b, DataType data_type) { switch (data_type) { + case DT_BFLOAT16: + return b->ConstantR0(bfloat16::epsilon()); case DT_FLOAT: return b->ConstantR0(std::numeric_limits::epsilon()); case DT_DOUBLE: @@ -133,54 +135,9 @@ xla::ComputationDataHandle XlaHelpers::Epsilon(xla::ComputationBuilder* b, xla::ComputationDataHandle XlaHelpers::IntegerLiteral( xla::ComputationBuilder* b, DataType data_type, int64 value) { - xla::Literal literal; xla::PrimitiveType type; TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type)); - switch (type) { - case xla::U8: - literal = *xla::Literal::CreateR0(value); - break; - case xla::U32: - literal = *xla::Literal::CreateR0(value); - break; - case xla::U64: - literal = *xla::Literal::CreateR0(value); - break; - case xla::S8: - literal = *xla::Literal::CreateR0(value); - break; - case xla::S32: - literal = *xla::Literal::CreateR0(value); - break; - case xla::S64: - literal = *xla::Literal::CreateR0(value); - break; - case xla::F32: - literal = *xla::Literal::CreateR0(value); - break; - case xla::F64: - literal = *xla::Literal::CreateR0(value); - break; - case xla::C64: - literal = *xla::Literal::CreateR0(value); - break; - case xla::PRED: - LOG(FATAL) << "pred element type is not integral"; - case xla::S16: - case xla::U16: - LOG(FATAL) << "u16/s16 literals not yet implemented"; - case xla::F16: - literal = - *xla::Literal::CreateR0(static_cast(value)); - break; - case xla::TUPLE: - LOG(FATAL) << "tuple element type is not integral"; - case xla::OPAQUE: - LOG(FATAL) << "opaque element type is not integral"; - default: - LOG(FATAL) << "unhandled element type " << type; - } - return b->ConstantLiteral(literal); + return ::tensorflow::IntegerLiteral(b, type, value); } 
xla::ComputationDataHandle XlaHelpers::FloatLiteral(xla::ComputationBuilder* b, @@ -207,8 +164,8 @@ xla::ComputationDataHandle XlaHelpers::FloatLiteral(xla::ComputationBuilder* b, "elements."); } - *output = input; - output->mutable_shape()->Swap(&shape); + *output = input.Clone(); + output->mutable_shape_do_not_use()->Swap(&shape); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc index 1dd454ea8d57e21526e5bcde0c8efc5514983b93..1fe6e69ff2dc838152032ac3d7b21de41684c6f6 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc @@ -37,27 +37,14 @@ namespace { // Returns a vector of positional argument buffer sizes. xla::StatusOr> ComputeArgSizes( - const xla::ProgramShape& program_shape, bool requires_runtime_context) { + const xla::ProgramShape& program_shape) { std::vector arg_sizes; const size_t num_args = program_shape.parameters_size(); arg_sizes.reserve(num_args); for (int i = 0; i < num_args; ++i) { const xla::Shape& arg_shape = program_shape.parameters(i); - if (i == num_args - 1 && requires_runtime_context) { - // If the compiled function needs an XlaLocalRuntimeContext* arg, it's - // always last, and must be represented as an opaque type. 
- const xla::PrimitiveType type = arg_shape.element_type(); - if (type != xla::OPAQUE) { - return errors::InvalidArgument( - "expected final context arg to be opaque, but got type: ", - xla::PrimitiveType_Name(type), ", from program shape: ", - xla::ShapeUtil::HumanString(program_shape)); - } - arg_sizes.push_back(-1); - } else { - constexpr size_t kPointerSize = sizeof(void*); - arg_sizes.push_back(xla::ShapeUtil::ByteSizeOf(arg_shape, kPointerSize)); - } + constexpr size_t kPointerSize = sizeof(void*); + arg_sizes.push_back(xla::ShapeUtil::ByteSizeOf(arg_shape, kPointerSize)); } return std::move(arg_sizes); } @@ -90,21 +77,6 @@ xla::StatusOr ComputeResultIndex( return result_slice.index(); } -// Adapt ComputeFunctionType, which includes a final profile_counters arg, to -// RawFunction, which doesn't include that final arg. -// -// TODO(toddw): Change RawFunction and AOT to also pass the final -// profile_counters arg, and remove this adapter. -XlaCompiledCpuFunction::RawFunction RawFunctionAdapter( - xla::cpu::CpuExecutable::ComputeFunctionType compute_function) { - return [compute_function](void* result, - const xla::ExecutableRunOptions* run_options, - const void** args, void** temps) { - return compute_function(result, run_options, args, temps, - /*profile_counters=*/nullptr); - }; -} - // Collect names from `entries`, where T is one of tf2xla::{Feed,Fetch}. We hold // the actual strings in nonempty_names, and hold arrays of pointers in // name_ptrs, terminated by a nullptr entry. @@ -144,9 +116,8 @@ XlaJitCompiledCpuFunction::Compile( TF_ASSIGN_OR_RETURN(xla::LocalClient * client, xla::ClientLibrary::GetOrCreateLocalClient()); xla::Computation computation; - bool requires_runtime_context; - TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToXla( - graph_def, config, client, &computation, &requires_runtime_context)); + TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToXla(graph_def, config, client, + &computation)); // Get and verify the program shape. 
TF_ASSIGN_OR_RETURN(std::unique_ptr program_shape, @@ -177,14 +148,13 @@ XlaJitCompiledCpuFunction::Compile( const xla::cpu::CpuExecutable* cpu_executable = static_cast(executable->executable()); XlaCompiledCpuFunction::RawFunction raw_function = - RawFunctionAdapter(cpu_executable->compute_function()); + cpu_executable->compute_function(); const xla::BufferAssignment& buffer_assignment = cpu_executable->buffer_assignment(); // Compute buffer sizes and the result index, needed to run the raw function. - TF_ASSIGN_OR_RETURN( - std::vector arg_sizes, - ComputeArgSizes(*program_shape, requires_runtime_context)); + TF_ASSIGN_OR_RETURN(std::vector arg_sizes, + ComputeArgSizes(*program_shape)); TF_ASSIGN_OR_RETURN(std::vector temp_sizes, ComputeTempSizes(buffer_assignment)); TF_ASSIGN_OR_RETURN(size_t result_index, @@ -203,7 +173,6 @@ XlaJitCompiledCpuFunction::Compile( jit->static_data_.temp_sizes = jit->temp_sizes_.data(); jit->static_data_.num_temps = jit->temp_sizes_.size(); jit->static_data_.result_index = result_index; - jit->static_data_.requires_runtime_context = requires_runtime_context; // Optional metadata is collected and set below. 
CollectNames(config.feed(), &jit->nonempty_arg_names_, &jit->arg_names_); CollectNames(config.fetch(), &jit->nonempty_result_names_, @@ -211,6 +180,14 @@ XlaJitCompiledCpuFunction::Compile( jit->static_data_.arg_names = jit->arg_names_.data(); jit->static_data_.result_names = jit->result_names_.data(); jit->static_data_.program_shape = jit->program_shape_.get(); + + if (cpu_executable->hlo_profiling_enabled()) { + jit->static_data_.hlo_profile_printer_data = + &cpu_executable->hlo_profile_printer_data(); + jit->static_data_.profile_counters_size = + cpu_executable->hlo_profile_printer_data().profile_counters_size(); + } + return std::move(jit_unique_ptr); } diff --git a/tensorflow/compiler/tf2xla/xla_local_runtime_context.h b/tensorflow/compiler/tf2xla/xla_local_runtime_context.h deleted file mode 100644 index dca420d6ee3fec45f88ac3b450ab0cb4fb83d38a..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/xla_local_runtime_context.h +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_LOCAL_RUNTIME_CONTEXT_H_ -#define TENSORFLOW_COMPILER_TF2XLA_XLA_LOCAL_RUNTIME_CONTEXT_H_ - -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/types.h" - -// Forward-declare the ThreadPoolDevice so that it can be ignored unless it's -// actually used. E.g. some ahead-of-time compiled computations don't need a -// thread pool. -namespace Eigen { -struct ThreadPoolDevice; -} - -namespace tensorflow { - -// An instance of this class is passed to each call from tensorflow into a -// compiled XLA computation. See xla_launch_ops.cc. -struct XlaLocalRuntimeContext { - public: - XlaLocalRuntimeContext() {} - - // Kernels implemented using custom call ops set this if they encounter an - // error. The error is checked after the entire XLA computation is - // complete. - // - // error+error_msg are used instead of Status to reduce the binary size - // overhead for ahead-of-time compiled binaries. - bool error = false; - string error_msg; - - // Kernels that need a thread pool can get it from here. 
- const Eigen::ThreadPoolDevice* thread_pool = nullptr; - - private: - TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalRuntimeContext); -}; - -} // namespace tensorflow - -#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_LOCAL_RUNTIME_CONTEXT_H_ diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 2b4cc9ba2d62b0e559e1456e6bfe6ab1e094e1df..ee29158646fa96fe554d089e11d50afb47e3e300 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -118,13 +118,36 @@ Status XlaOpKernelContext::ConstantInputReshaped( std::iota(layout_indices.rbegin(), layout_indices.rend(), 0); xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices); + xla::StatusOr is_constant = builder()->IsConstant(handle); + if (!is_constant.ok()) { + Status status = is_constant.status(); + errors::AppendToMessage(&status, "while evaluating input ", index, " of ", + context_->op_kernel().type_string(), + " operator as a compile-time constant."); + return status; + } + + if (!is_constant.ValueOrDie()) { + return errors::InvalidArgument( + "Input ", index, " to ", context_->op_kernel().type_string(), + " operator must be a compile-time constant.\n" + "\n" + "XLA compilation requires that operator arguments that represent " + "shapes or dimensions be evaluated to concrete values at compile time. " + "This error means that a shape or dimension argument could not be " + "evaluated at compile time, usually because the value of the argument " + "depends on a parameter to the computation, on a variable, or on a " + "stateful operation such as a random number generator."); + } + // Ask the XLA compiler to evaluate the data handle to a literal. 
xla::StatusOr> computed = builder()->ComputeConstant(handle, &layout); if (!computed.ok()) { - return errors::InvalidArgument( - "Error evaluating ", context_->op_kernel().name(), " input ", index, - ": ", computed.status().error_message()); + return errors::Internal("Error evaluating ", context_->op_kernel().name(), + " input ", index, + "as a compile-time constant.\nError: ", + computed.status().error_message()); } *constant_literal = std::move(*computed.ValueOrDie()); @@ -206,15 +229,15 @@ Status XlaOpKernelContext::ConstantInputAsInt64Literal(int index, xla::Literal literal; TF_RETURN_IF_ERROR(ConstantInput(index, &literal)); switch (literal.shape().element_type()) { - case xla::S32: - out->Clear(); - *out->mutable_shape() = literal.shape(); - out->mutable_shape()->set_element_type(xla::S64); - for (int32 x : literal.s32s()) { - out->add_s64s(x); + case xla::S32: { + *out = xla::Literal( + xla::ShapeUtil::ChangeElementType(literal.shape(), xla::S64)); + auto src_data = literal.data(); + for (int64 i = 0; i < src_data.size(); ++i) { + out->data()[i] = src_data[i]; } return Status::OK(); - + } case xla::S64: *out = std::move(literal); return Status::OK(); @@ -263,17 +286,26 @@ Status XlaOpKernelContext::ConstantInputList( } Status XlaOpKernelContext::ReadVariableInput( - int index, xla::ComputationDataHandle* value) { + int index, DataType type, TensorShape* shape, + xla::ComputationDataHandle* value) { const Tensor& tensor = context_->input(index); const XlaExpression* expression = CastExpressionFromTensor(tensor); XlaResource* variable = expression->resource(); TF_RET_CHECK(variable != nullptr); - TF_RET_CHECK(variable->kind == XlaResource::kVariable); - if (variable->value.handle() == 0) { + TF_RET_CHECK(variable->kind() == XlaResource::kVariable); + if (!variable->initialized()) { return errors::InvalidArgument("Read of uninitialized variable ", - variable->name); + variable->name()); + } + if (variable->type() != type) { + return errors::InvalidArgument( + 
"Type mismatch for read of variable ", variable->name(), ". Expected ", + DataTypeString(type), "; got ", DataTypeString(variable->type())); + } + *value = variable->value(); + if (shape) { + *shape = variable->shape(); } - *value = variable->value; return Status::OK(); } @@ -283,18 +315,13 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type, const XlaExpression* expression = CastExpressionFromTensor(tensor); XlaResource* variable = expression->resource(); TF_RET_CHECK(variable != nullptr); - TF_RET_CHECK(variable->kind == XlaResource::kVariable); - if (variable->value.handle() == 0) { + TF_RET_CHECK(variable->kind() == XlaResource::kVariable); + if (!variable->initialized()) { return errors::InvalidArgument("Read of uninitialized variable ", - variable->name); - } - *type = variable->type; - auto shape_or_status = builder()->GetShape(variable->value); - if (!shape_or_status.ok()) { - return shape_or_status.status(); + variable->name()); } - TF_RETURN_IF_ERROR( - XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), shape)); + *type = variable->type(); + *shape = variable->shape(); return Status::OK(); } @@ -381,26 +408,38 @@ Status XlaOpKernelContext::AssignVariable( CastExpressionFromTensor(context_->input(input_index)); XlaResource* variable = expression->resource(); TF_RET_CHECK(variable != nullptr); - TF_RET_CHECK(variable->kind == XlaResource::kVariable); - if (!((variable->type == DT_INVALID && type != DT_INVALID) || - (variable->type == type))) { - return errors::InvalidArgument( - "Types of variables cannot change after initialization: old type was ", - DataTypeString(variable->type), ", new type is ", DataTypeString(type)); + TF_RET_CHECK(variable->kind() == XlaResource::kVariable); + + auto shape_or_status = builder()->GetShape(handle); + if (!shape_or_status.ok()) { + return shape_or_status.status(); } - variable->type = type; - variable->value = handle; - return Status::OK(); + TensorShape shape; + TF_RETURN_IF_ERROR( + 
XLAShapeToTensorShape(*shape_or_status.ValueOrDie(), &shape)); + + TF_RETURN_IF_ERROR(variable->SetTypeAndShape(type, shape)); + return variable->SetValue(handle); } XlaCompiler* XlaOpKernelContext::compiler() const { return XlaContext::Get(context_).compiler(); } -void XlaOpKernelContext::CtxFailure(Status s) { context_->CtxFailure(s); } -void XlaOpKernelContext::CtxFailureWithWarning(Status s) { +void XlaOpKernelContext::CtxFailure(const Status& s) { + context_->CtxFailure(s); +} +void XlaOpKernelContext::CtxFailureWithWarning(const Status& s) { context_->CtxFailureWithWarning(s); } +void XlaOpKernelContext::CtxFailure(const char* file, int line, + const Status& s) { + context_->CtxFailure(file, line, s); +} +void XlaOpKernelContext::CtxFailureWithWarning(const char* file, int line, + const Status& s) { + context_->CtxFailureWithWarning(file, line, s); +} const xla::Computation* XlaOpKernelContext::GetOrCreateMax( const DataType type) { @@ -417,6 +456,11 @@ const xla::Computation* XlaOpKernelContext::GetOrCreateAdd( return XlaContext::Get(context_).GetOrCreateAdd(type); } +const xla::Computation* XlaOpKernelContext::GetOrCreateMul( + const DataType type) { + return XlaContext::Get(context_).GetOrCreateMul(type); +} + XlaOpKernel::XlaOpKernel(OpKernelConstruction* context) : OpKernel(context) {} void XlaOpKernel::Compute(OpKernelContext* context) { diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 76bcf594e6a0601763844847583c18ee26d8adf3..e1fd0f55c6d2501b4813c90171630a8df567f78a 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -164,21 +164,28 @@ class XlaOpKernelContext { TensorShape* shape) const; // Reads the current value of the resouce variable referred to by input - // 'index'. - Status ReadVariableInput(int index, xla::ComputationDataHandle* value); + // 'index'. If `shape` is not nullptr, sets `*shape` to the shape of the + // variable. 
Returns an error if the variable has not been initialized, or if + // its type does not match `type`. + Status ReadVariableInput(int index, DataType type, TensorShape* shape, + xla::ComputationDataHandle* value); // Assigns the value `handle` to the variable referenced by input - // `input_index`. Marks the operator as having side effects. + // `input_index`. The variable must be of `type`. Returns an error if the + // variable has been initialized with a different type or with a + // different shape. Status AssignVariable(int input_index, DataType type, const xla::ComputationDataHandle& handle); // Helper routines for the OP_REQUIRES macros - void CtxFailure(Status s); - void CtxFailureWithWarning(Status s); + void CtxFailure(const Status& s); + void CtxFailureWithWarning(const Status& s); + void CtxFailure(const char* file, int line, const Status& s); + void CtxFailureWithWarning(const char* file, int line, const Status& s); // If this kernel invocation is within a function execution, // call_frame() returns the call frame for the function call. - FunctionCallFrame* call_frame() const { return context_->call_frame(); } + CallFrameInterface* call_frame() const { return context_->call_frame(); } FunctionLibraryRuntime* function_library() const { return context_->function_library(); @@ -210,6 +217,11 @@ class XlaOpKernelContext { // separate specialization of the computation for each DataType. const xla::Computation* GetOrCreateAdd(const DataType type); + // Gets an XLA lambda to compute Mul. This is cached in the + // XlaContext since it may be used by multiple Ops. There is a + // separate specialization of the computation for each DataType. 
+ const xla::Computation* GetOrCreateMul(const DataType type); + private: OpKernelContext* const context_; }; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index 02318cf7fa1d4edc12507f6b4d66a8e897cbe100..0dde6a986c61bdd5b0b2e6d7a16b29ab95be98ab 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/kernel_def.pb.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_def_util.h" #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -82,6 +83,11 @@ XlaOpRegistry::~XlaOpRegistry() = default; return false; } } + if (x.compile_time_constant_inputs != y.compile_time_constant_inputs) { + LOG(WARNING) << "Registrations of " << x.name + << " have incompatible compile time constant inputs."; + return false; + } return true; } @@ -155,7 +161,14 @@ void XlaOpRegistry::RegisterCompilationKernels() { const string& op_name = op.first; const std::unique_ptr& op_registration = op.second; const OpDef* op_def; - TF_CHECK_OK(op_registry->LookUpOpDef(op_name, &op_def)); + Status lookup_status = op_registry->LookUpOpDef(op_name, &op_def); + if (!lookup_status.ok()) { + LOG(ERROR) << lookup_status.error_message(); + XLA_LOG_LINES( + ERROR, "Ops registered: \n" + + dynamic_cast(op_registry)->DebugString(true)); + } + TF_CHECK_OK(lookup_status); std::unordered_set type_attrs; for (const OpDef::AttrDef& attr_def : op_def->attr()) { @@ -187,22 +200,39 @@ void XlaOpRegistry::RegisterCompilationKernels() { // Constrain each type attribute to the intersection of: // a) the types supported by the backend, and - // b) the attribute's type constraints. 
- // TODO(phawkins): it may be necessary to also take the intersection with - // the set of types supported by the OpDef. + // b) the types allowed by the OpDef, and + // c) the type constraints. for (const string& type_attr : type_attrs) { KernelDef::AttrConstraint* attr_constraint = kdef->add_constraint(); attr_constraint->set_name(type_attr); auto* allowed_values = attr_constraint->mutable_allowed_values()->mutable_list(); - auto it = op_registration->type_constraints.find(type_attr); + const OpDef::AttrDef& op_def_attr = *FindAttr(type_attr, *op_def); + const auto* op_def_allowed_types = + op_def_attr.has_allowed_values() + ? &op_def_attr.allowed_values().list().type() + : nullptr; + auto constraint_it = op_registration->type_constraints.find(type_attr); + const std::set* type_constraints = + constraint_it != op_registration->type_constraints.end() + ? &constraint_it->second + : nullptr; for (DataType dtype : backend.second.supported_types) { - if (it == op_registration->type_constraints.end() || - (it != op_registration->type_constraints.end() && - it->second.find(dtype) != it->second.end())) { - allowed_values->add_type(dtype); + // Filter out types that aren't allowed by the OpDef. + if (op_def_allowed_types != nullptr && + std::find(op_def_allowed_types->begin(), + op_def_allowed_types->end(), + dtype) == op_def_allowed_types->end()) { + continue; + } + // Filter out types based on the type constraints. + if (type_constraints != nullptr && + type_constraints->find(dtype) == type_constraints->end()) { + continue; } + // Passed all the filters, this type is allowed. 
+ allowed_values->add_type(dtype); } if (op_registration->allow_resource_types) { allowed_values->add_type(DT_RESOURCE); @@ -245,6 +275,33 @@ std::vector XlaOpRegistry::DeviceKernels( return kernels; } +/* static */ const std::unordered_set* +XlaOpRegistry::CompileTimeConstantInputs(const string& op) { + XlaOpRegistry& registry = Instance(); + mutex_lock lock(registry.mutex_); + auto it = registry.ops_.find(op); + if (it == registry.ops_.end()) { + return nullptr; + } + return &it->second->compile_time_constant_inputs; +} + +std::vector XlaOpRegistry::BackendNames() { + std::vector names; + XlaOpRegistry& registry = Instance(); + mutex_lock lock(registry.mutex_); + for (const auto& backend_pair : registry.backends_) { + names.push_back(backend_pair.first); + } + return names; +} + +bool XlaOpRegistry::IsBackendRegistered(const string& name) { + XlaOpRegistry& registry = Instance(); + mutex_lock lock(registry.mutex_); + return registry.backends_.find(name) != registry.backends_.end(); +} + XlaOpRegistry& XlaOpRegistry::Instance() { static XlaOpRegistry* r = new XlaOpRegistry; return *r; @@ -303,6 +360,12 @@ XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::TypeConstraint( return *this; } +XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::CompileTimeConstInput( + StringPiece input_name) { + registration_->compile_time_constant_inputs.insert(input_name.ToString()); + return *this; +} + std::unique_ptr XlaOpRegistrationBuilder::Build( XlaOpRegistry::Factory factory) { registration_->factory = factory; diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index 6aee8c91cc01b4382ef867fa8e438eede008ac73..ff7453194af3a85bded86a5ce298f8779422dccb 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -45,11 +45,11 @@ extern const char* const DEVICE_GPU_XLA_JIT; // "GPU_XLA_JIT" extern const char* const DEVICE_XLA_CPU; extern const char* const DEVICE_XLA_GPU; 
-constexpr std::array kFloatTypes = { - {DT_HALF, DT_FLOAT, DT_DOUBLE}}; -constexpr std::array kNumericTypes = { +constexpr std::array kFloatTypes = { + {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}}; +constexpr std::array kNumericTypes = { {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, - DT_COMPLEX64}}; + DT_COMPLEX64, DT_BFLOAT16}}; constexpr std::array kCpuAllTypes = { {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, @@ -97,6 +97,12 @@ class XlaOpRegistry { gtl::ArraySlice supported_types, BackendOpFilter op_filter); + // Returns the names of the registered backends. + static std::vector BackendNames(); + + // Returns true iff a backend with the given name is registered. + static bool IsBackendRegistered(const string& name); + // Registers `device_name` for XLA compilation, using information from // `registration`. static void RegisterCompilationDevice(const string& device_name, @@ -116,12 +122,17 @@ class XlaOpRegistry { static void RegisterCompilationKernels(); // Returns KernelDefs for compilation ops registered on - // 'compilation_device_name'. - // Does not include kernels registered as CompilationOnly. + // 'compilation_device_name'. Does not include kernels registered as + // CompilationOnly, iff include_compilation_only_kernels=false. static std::vector DeviceKernels( const string& compilation_device_name, bool include_compilation_only_kernels); + // Returns the set of compile-time constant inputs to 'op'. Returns nullptr + // if the op is not registered. + static const std::unordered_set* CompileTimeConstantInputs( + const string& op); + private: friend class XlaBackendRegistrar; friend class XlaOpRegistrar; @@ -175,6 +186,9 @@ class XlaOpRegistry { bool has_device_whitelist = false; std::unordered_set device_whitelist; + // Names of arguments that must be compile-time constants. + std::unordered_set compile_time_constant_inputs; + // Factory used to build OpKernels that perform symbolic execution. 
Factory factory; }; @@ -236,6 +250,9 @@ class XlaOpRegistrationBuilder { // Allow DT_RESOURCE types for type parameters. XlaOpRegistrationBuilder& AllowResourceTypes(); + // Mark 'input_name' as an argument whose value must be known at compile-time. + XlaOpRegistrationBuilder& CompileTimeConstInput(StringPiece input_name); + std::unique_ptr Build( XlaOpRegistry::Factory factory); diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc new file mode 100644 index 0000000000000000000000000000000000000000..c2075b44b82ba279d1246ec6bfcf305d12c418a6 --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_resource.cc @@ -0,0 +1,194 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_resource.h" + +#include +#include + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/sharding_util.h" +#include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" + +namespace tensorflow { + +XlaResource::XlaResource(Kind kind, int arg_num, string name, DataType type, + TensorShape shape, + const xla::ComputationDataHandle& initial_value, + int64 tensor_array_size, + const std::set& tensor_array_gradients) + : kind_(kind), + arg_num_(arg_num), + name_(std::move(name)), + type_(type), + shape_(std::move(shape)), + value_(initial_value), + initial_value_(initial_value), + tensor_array_size_(tensor_array_size) { + CHECK(kind_ != kInvalid); + + for (const string& gradient : tensor_array_gradients) { + tensor_array_gradients_[gradient].reset( + new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1, + /*name=*/strings::StrCat("TensorArrayGrad: ", name_), + type_, shape_, xla::ComputationDataHandle(), + tensor_array_size_, /*tensor_array_gradients=*/{})); + } +} + +Status XlaResource::SetTypeAndShape(DataType type, const TensorShape& shape) { + if (type == DT_INVALID) { + return errors::InvalidArgument("Attempted to set type of resource '", name_, + "'' to an invalid type"); + } + if (initialized() && type_ != type) { + return errors::InvalidArgument("Type of resource ", name_, + " cannot be changed after initialization: " + "old type was ", + DataTypeString(type_), ", new type is ", + DataTypeString(type)); + } + if (initialized() && shape_ != shape) { + return errors::InvalidArgument("Shape of resource ", name_, + " cannot be changed after initialization: " + "old shape was ", + shape_.DebugString(), ", new shape is ", + shape.DebugString()); + } + type_ = type; + shape_ = shape; + return Status::OK(); +} + +Status XlaResource::SetValue(const 
xla::ComputationDataHandle& value) { + if (type_ == DT_INVALID) { + return errors::InvalidArgument( + "Resource '", name_, + "' must be initialized with a valid type before use."); + } + value_ = value; + return Status::OK(); +} + +Status XlaResource::SetZeroValue(xla::ComputationBuilder* builder) { + if (type_ == DT_INVALID) { + return errors::InvalidArgument( + "Resource '", name_, + "' must be initialized with a valid type before use."); + } + switch (kind_) { + case kVariable: { + value_ = builder->Broadcast(XlaHelpers::Zero(builder, type_), + shape_.dim_sizes()); + break; + } + case kTensorArray: { + TensorShape ta_shape; + ta_shape.AddDim(tensor_array_size_); + ta_shape.AppendShape(shape_); + value_ = builder->Broadcast(XlaHelpers::Zero(builder, type_), + ta_shape.dim_sizes()); + break; + } + case kStack: { + TensorShape ta_shape; + ta_shape.AddDim(tensor_array_size_); + ta_shape.AppendShape(shape_); + value_ = + builder->Tuple({builder->Broadcast(XlaHelpers::Zero(builder, type_), + ta_shape.dim_sizes()), + builder->ConstantR0(0)}); + break; + } + + case kInvalid: + default: + LOG(FATAL) << "Invalid resource type"; + } + return Status::OK(); +} + +Status XlaResource::GetOrCreateTensorArrayGradient( + const string& source, xla::ComputationBuilder* builder, + XlaResource** gradient_out) { + VLOG(2) << "Gradient lookup for resource: " << name_ + << " gradient: " << source; + TF_RET_CHECK(kind_ == kTensorArray); + std::unique_ptr& gradient = tensor_array_gradients_[source]; + if (!gradient) { + TensorShape ta_shape; + ta_shape.AddDim(tensor_array_size_); + ta_shape.AppendShape(shape_); + xla::ComputationDataHandle gradient_value = builder->Broadcast( + XlaHelpers::Zero(builder, type_), ta_shape.dim_sizes()); + gradient.reset( + new XlaResource(/*kind=*/kTensorArray, /*arg_num=*/-1, + /*name=*/strings::StrCat("TensorArrayGrad: ", name_), + type_, shape_, gradient_value, tensor_array_size_, + /*tensor_array_gradients=*/{})); + } + *gradient_out = gradient.get(); + 
return Status::OK(); +} + +Status XlaResource::Pack(xla::ComputationDataHandle* pack, + xla::ComputationBuilder* builder) const { + if (tensor_array_gradients_.empty()) { + *pack = value_; + } else { + TF_RET_CHECK(kind_ == kTensorArray); + std::vector elems; + elems.push_back(value_); + for (const auto& gradient : tensor_array_gradients_) { + elems.push_back(gradient.second->value_); + } + *pack = builder->Tuple(elems); + } + return Status::OK(); +} + +Status XlaResource::SetFromPack(const std::set& gradient_sources, + const xla::ComputationDataHandle& pack, + xla::ComputationBuilder* builder) { + if (gradient_sources.empty()) { + if (!initialized()) { + initial_value_ = pack; + } + value_ = pack; + } else { + TF_RET_CHECK(kind_ == kTensorArray); + int pos = 0; + auto v = builder->GetTupleElement(pack, pos++); + if (!initialized()) { + initial_value_ = v; + } + value_ = v; + + for (const auto& source : gradient_sources) { + XlaResource* gradient; + TF_RETURN_IF_ERROR( + GetOrCreateTensorArrayGradient(source, builder, &gradient)); + auto v = builder->GetTupleElement(pack, pos++); + if (!gradient->initialized()) { + gradient->initial_value_ = v; + } + gradient->value_ = v; + } + } + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_resource.h b/tensorflow/compiler/tf2xla/xla_resource.h new file mode 100644 index 0000000000000000000000000000000000000000..1bb2c7274ecdf0954768fd96def51194e52deee8 --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_resource.h @@ -0,0 +1,157 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_RESOURCE_H_ +#define TENSORFLOW_COMPILER_TF2XLA_XLA_RESOURCE_H_ + +#include + +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Represents a resource, such as a Variable or TensorArray. +class XlaResource { + public: + enum Kind { + kInvalid, + kVariable, + kTensorArray, + kStack, + }; + + XlaResource(Kind kind, int arg_num, string name, DataType type, + TensorShape shape, + const xla::ComputationDataHandle& initial_value, + int64 tensor_array_size, + const std::set& tensor_array_gradients); + + XlaResource(const XlaResource&) = delete; + XlaResource(XlaResource&&) = delete; + XlaResource& operator=(const XlaResource&) = delete; + XlaResource& operator=(XlaResource&&) = delete; + + Kind kind() const { return kind_; } + + // If this resource is visible externally to the computation, what was its + // argument number? + // < 0 means "not visible externally". + int arg_num() const { return arg_num_; } + + // A descriptive name for the resource, used in error messages. + const string& name() const { return name_; } + + // Current type and value of the resource. Uninitialized resources are + // represented by a default (zero) handle and type DT_INVALID. 
+ // While the type of a resource is notionally fixed during execution, when + // a resource is first initialized we do not yet know its type, so we keep + // track of its type dynamically. + DataType type() const { return type_; } + + // Shape of the resource. For an uninitialized resource, this is ignored. + // For a Variable, this is the shape of the value. For a TensorArray or Stack + // this is the shape of each entry in the TensorArray/Stack. + const TensorShape& shape() const { return shape_; } + + const xla::ComputationDataHandle& value() const { return value_; } + + // Value of the resource at computation entry. Used to detect which + // variables have new values that need to be written back. + const xla::ComputationDataHandle& initial_value() const { + return initial_value_; + } + + // A variable is initialized if it has a value. + bool initialized() const { return value_.handle() > 0; } + + // Sets the type and shape of the resource. The type and shape of a resource + // must not change once the variable has been initialized. + Status SetTypeAndShape(DataType type, const TensorShape& shape); + + // Sets the current value of the resource. Returns an error if the type is not + // set to a valid value. + Status SetValue(const xla::ComputationDataHandle& value); + + // Sets the current value of the resource to an all-zero value. + Status SetZeroValue(xla::ComputationBuilder* builder); + + // Looks up the gradient for `source`, or creates it if it does not already + // exist. The call target must be an initialized TensorArray resource. A + // TensorArray can have multiple named gradients; see the operator + // documentation for TensorArrayGradV3 for details. + Status GetOrCreateTensorArrayGradient(const string& source, + xla::ComputationBuilder* builder, + XlaResource** gradient_out); + + // Packs a resource into a single XLA value `pack`, suitable for use as + // an XlaCompiler::Argument. 
For non-TensorArrays or TensorArrays without + // gradients, sets `*pack` to `value`. + // For TensorArrays with gradients, packs the value and its gradient values in + // a tuple; the gradient values are packed in order by source name. + Status Pack(xla::ComputationDataHandle* pack, + xla::ComputationBuilder* builder) const; + + // Updates the resource with values from `pack`. If `gradient_sources` is + // non-empty, treats `pack` as a tuple that represents a TensorArray and + // its gradients, and unpacks and updates the gradient resources. + // If the resource or one of its gradients is not yet initialized, also + // sets its initial value. + // Opposite of Pack(). + Status SetFromPack(const std::set& gradient_sources, + const xla::ComputationDataHandle& pack, + xla::ComputationBuilder* builder); + + // TensorArray and Stack specific fields + + // 'tensor_array_size' stores the expected size of the TensorArray or Stack. + // We need to store this since sometimes TensorArrays must be initialized + // lazily since we do not know the element shape at construction time. + // Used by both TensorArrays and Stacks. + int64 tensor_array_size() const { return tensor_array_size_; } + void set_tensor_array_size(int64 size) { tensor_array_size_ = size; } + + // 'tensor_array_gradients' maps TensorArrayGradV3 'source' attributes + // to an XlaResource containing the gradient TensorArrays. We store a pointer + // here since there should only be one gradient TensorArray per 'source' + // string, irrespective of the number of calls to TensorArrayGrad. The map + // is ordered since values are packed into tuples by Pack() sorted by name + // order. 
+ const std::map>& tensor_array_gradients() + const { + return tensor_array_gradients_; + } + + private: + const Kind kind_; + const int arg_num_; + const string name_; + + DataType type_; + TensorShape shape_; + xla::ComputationDataHandle value_; + xla::ComputationDataHandle initial_value_; + + int64 tensor_array_size_ = -1; + + std::map> tensor_array_gradients_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_RESOURCE_H_ diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index d3f292207fee396fb4248dede5c0eeb5cd2b87c9..34e733bc8d80b364cec1783006eba0a5468b55ea 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -20,6 +20,10 @@ package_group( load("//tensorflow:tensorflow.bzl", "cc_header_only_library") load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow/compiler/xla:xla.bzl", "xla_proto_library") +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_proto_library_py", +) # Filegroup used to collect source files for dependency checking. 
filegroup( @@ -36,6 +40,12 @@ xla_proto_library( visibility = ["//visibility:public"], ) +tf_proto_library_py( + name = "xla_data_proto", # bzl adds a _py suffix + srcs = ["xla_data.proto"], + visibility = ["//visibility:public"], +) + xla_proto_library( name = "xla_proto", srcs = ["xla.proto"], @@ -78,7 +88,6 @@ cc_library( visibility = [":friends"], deps = [ "//tensorflow/core:framework_lite", - "//tensorflow/core:lib", "//third_party/eigen3", ], ) @@ -172,6 +181,7 @@ cc_library( deps = [ ":status", ":status_macros", + ":statusor", ":types", ":xla_data_proto", "//tensorflow/core:lib", @@ -250,6 +260,7 @@ tf_cc_test( srcs = ["shape_util_test.cc"], deps = [ ":shape_util", + ":status_macros", ":test", ":test_helpers", ":types", @@ -290,7 +301,9 @@ cc_library( ":array2d", ":array3d", ":array4d", + ":shape_tree", ":shape_util", + ":sparse_index_array", ":status_macros", ":types", ":util", @@ -617,6 +630,28 @@ tf_cc_test( ], ) +cc_library( + name = "sparse_index_array", + srcs = ["sparse_index_array.cc"], + hdrs = ["sparse_index_array.h"], + deps = [ + ":array2d", + ":shape_util", + ":xla_data_proto", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "sparse_index_array_test", + srcs = ["sparse_index_array_test.cc"], + deps = [ + ":sparse_index_array", + ":test", + "//tensorflow/core:test_main", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/array.h b/tensorflow/compiler/xla/array.h index 213e0bac6c77e9972de8d4dd7dfc8c7cf3a1b865..71aa057cd3a1c273c0e851497a78f94ba37c778e 100644 --- a/tensorflow/compiler/xla/array.h +++ b/tensorflow/compiler/xla/array.h @@ -22,6 +22,7 @@ limitations under the License. 
#include #include #include +#include #include #include #include diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index f953407a567b91fdf6ae727d6982a2a778c5873e..02356699a25e47be50eb15872df4c9c302fc289b 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -80,6 +80,18 @@ cc_library( ], ) +cc_library( + name = "executable_build_options", + srcs = ["executable_build_options.cc"], + hdrs = ["executable_build_options.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:device_memory_allocator", + "//tensorflow/core:lib", + ], +) + cc_library( name = "local_client", srcs = ["local_client.cc"], @@ -87,6 +99,7 @@ cc_library( deps = [ ":client", ":computation", + ":executable_build_options", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -98,6 +111,7 @@ cc_library( "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:local_service", "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/compiler/xla/service:source_map_util", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@llvm//:support", @@ -186,6 +200,20 @@ cc_library( ], ) +cc_library( + name = "sharding_builder", + srcs = ["sharding_builder.cc"], + hdrs = ["sharding_builder.h"], + deps = [ + "//tensorflow/compiler/xla:array", + "//tensorflow/compiler/xla:shape_tree", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 
66937d64aff18817bbd5310e0c24e19556e9d727..d15ccb0c28522c647617153aaa8e738d029dfaba 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -60,7 +60,7 @@ StatusOr> Client::Transfer( "server provided response without a literal in " "TransferToClient request"); } - return MakeUnique(response.literal()); + return Literal::CreateFromProto(*response.mutable_literal()); } StatusOr> Client::TransferToServer( @@ -142,7 +142,7 @@ StatusOr> Client::TransferFromOutfeed( "TransferToClient request"); } - return MakeUnique(response.literal()); + return Literal::CreateFromProto(response.literal()); } Status Client::ResetDevice() { diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc index cce931000331e98b00f57025cb13a5d3982c2845..b1dcad6a49a270935b07e26de2d3945b912359d1 100644 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ b/tensorflow/compiler/xla/client/computation_builder.cc @@ -34,25 +34,9 @@ limitations under the License. 
namespace xla { -ComputationDataHandle ComputationBuilder::ParseOpResponse( - const Status& status, OpResponse* response) { - VLOG(2) << "done with op request"; - - if (!status.ok()) { - NoteError(status); - return ComputationDataHandle(); - } - - if (response->output().handle() == 0) { - NoteError(InternalError("No output handle")); - return ComputationDataHandle(); - } - return response->output(); -} - ComputationBuilder::ComputationBuilder(Client* client, const string& computation_name) - : name_(computation_name), first_error_(Status::OK()), client_(client) {} + : name_(computation_name), client_(client) {} ComputationBuilder::~ComputationBuilder() {} @@ -76,9 +60,8 @@ std::unique_ptr ComputationBuilder::CreateSubBuilder( } Status ComputationBuilder::PrepareComputation() { - if (!first_error_.ok()) { - return first_error_; - } + TF_RETURN_IF_ERROR(first_error_); + if (!computation_.IsNull()) { return Status::OK(); } @@ -100,6 +83,49 @@ Status ComputationBuilder::PrepareComputation() { return Status::OK(); } +Status ComputationBuilder::RunOp(OpRequest* op_request, + OpResponse* op_response) { + TF_RETURN_IF_ERROR(first_error_); + TF_RETURN_IF_ERROR(PrepareComputation()); + + // Fill in fields that are set on every OpRequest. 
+ *op_request->mutable_computation() = computation_.handle(); + *op_request->mutable_metadata() = metadata_; + if (sharding_) { + *op_request->mutable_sharding() = *sharding_; + } + + const string& op_name = + OpRequest::descriptor()->FindFieldByNumber(op_request->op_case())->name(); + VLOG(2) << "running op request: " << op_name; + Status status = client_->stub()->Op(op_request, op_response); + VLOG(2) << "done with op request: " << op_name; + return status; +} + +void ComputationBuilder::RunOpAndNoteError(OpRequest* op_request) { + OpResponse op_response; + Status status = RunOp(op_request, &op_response); + if (!status.ok()) { + NoteError(status); + } +} + +ComputationDataHandle ComputationBuilder::RunOpAndParseResponse( + OpRequest* op_request) { + OpResponse op_response; + Status status = RunOp(op_request, &op_response); + if (!status.ok()) { + NoteError(status); + return ComputationDataHandle(); + } + if (op_response.output().handle() == 0) { + NoteError(InternalError("No output handle")); + return ComputationDataHandle(); + } + return op_response.output(); +} + bool ComputationBuilder::MakeWindow( tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, @@ -158,81 +184,75 @@ bool ComputationBuilder::MakeWindow( return true; } -ComputationDataHandle ComputationBuilder::ConstantOp( - const PopulateLiteral& populate) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - ConstantRequest request; - Literal literal; - populate(&literal); - *request.mutable_literal() = literal.ToProto(); - VLOG(3) << "created constant: " << request.literal().ShortDebugString(); - OpRequest op_request; - *op_request.mutable_constant_request() = request; - *op_request.mutable_computation() = computation_.handle(); - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making constant request"; - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, 
&response); -} - ComputationDataHandle ComputationBuilder::ConstantLiteral( const Literal& literal) { - return ConstantOp( - [literal](Literal* mutable_literal) { *mutable_literal = literal; }); + OpRequest op_request; + ConstantRequest* request = op_request.mutable_constant_request(); + *request->mutable_literal() = literal.ToProto(); + VLOG(3) << "created constant: " << request->literal().ShortDebugString(); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::Parameter(int64 parameter_number, const Shape& shape, const string& name) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - ParameterRequest request; - *request.mutable_shape() = shape; - request.set_parameter(parameter_number); - request.set_name(name); OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_parameter_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making parameter request"; - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + ParameterRequest* request = op_request.mutable_parameter_request(); + *request->mutable_shape() = shape; + request->set_parameter(parameter_number); + request->set_name(name); + return RunOpAndParseResponse(&op_request); } -StatusOr> ComputationBuilder::GetShape( +StatusOr> ComputationBuilder::GetShapeWithoutNoteError( const ComputationDataHandle& operand) { - if (!first_error_.ok()) { - return first_error_; - } - GetLocalShapeRequest request; *request.mutable_computation() = computation_.handle(); *request.mutable_operand() = operand; GetLocalShapeResponse response; VLOG(2) << "making get-shape request"; - Status s = client_->stub()->GetLocalShape(&request, &response); + TF_RETURN_IF_ERROR(client_->stub()->GetLocalShape(&request, &response)); VLOG(2) << "done with request"; - if (!s.ok()) { - NoteError(s); - return first_error_; - 
} TF_RET_CHECK(response.has_shape()); std::unique_ptr shape = WrapUnique(response.release_shape()); TF_RET_CHECK(shape != nullptr); return std::move(shape); } +StatusOr> ComputationBuilder::GetShape( + const ComputationDataHandle& operand) { + TF_RETURN_IF_ERROR(first_error_); + + auto status_or_shape = GetShapeWithoutNoteError(operand); + if (!status_or_shape.ok()) { + NoteError(status_or_shape.status()); + return first_error_; + } + return status_or_shape; +} + +StatusOr ComputationBuilder::GetProgramShape() { + TF_RETURN_IF_ERROR(first_error_); + + GetComputationShapeRequest request; + *request.mutable_computation() = computation_.handle(); + GetComputationShapeResponse response; + + VLOG(2) << "making get-program-shape-request"; + Status status = client_->stub()->GetComputationShape(&request, &response); + VLOG(2) << "done with get-program-shape-request"; + + if (!status.ok()) { + first_error_ = status; + return status; + } + + TF_RET_CHECK(response.has_program_shape()); + return std::move(*response.mutable_program_shape()); +} + ComputationDataHandle ComputationBuilder::CheckShape( const ComputationDataHandle& operand, const Shape& expected_shape) { std::unique_ptr actual_shape = GetShape(operand).ConsumeValueOrDie(); @@ -258,30 +278,19 @@ ComputationDataHandle ComputationBuilder::Slice( tensorflow::gtl::ArraySlice start_indices, tensorflow::gtl::ArraySlice limit_indices, tensorflow::gtl::ArraySlice strides) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - SliceRequest request; - *request.mutable_operand() = operand; + OpRequest op_request; + SliceRequest* request = op_request.mutable_slice_request(); + *request->mutable_operand() = operand; for (int64 index : start_indices) { - request.add_start_indices(index); + request->add_start_indices(index); } for (int64 index : limit_indices) { - request.add_limit_indices(index); + request->add_limit_indices(index); } for (int64 index : strides) { - 
request.add_strides(index); + request->add_strides(index); } - OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_slice_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making slice request"; - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::SliceInDim( @@ -307,143 +316,78 @@ ComputationDataHandle ComputationBuilder::DynamicSlice( const ComputationDataHandle& operand, const ComputationDataHandle& start_indices, tensorflow::gtl::ArraySlice slice_sizes) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - DynamicSliceRequest request; - *request.mutable_operand() = operand; - *request.mutable_start_indices() = start_indices; + OpRequest op_request; + DynamicSliceRequest* request = op_request.mutable_dynamic_slice_request(); + *request->mutable_operand() = operand; + *request->mutable_start_indices() = start_indices; for (int64 index : slice_sizes) { - request.add_slice_sizes(index); + request->add_slice_sizes(index); } - OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_dynamic_slice_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making dynamic slice request"; - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::DynamicUpdateSlice( const ComputationDataHandle& operand, const ComputationDataHandle& update, const ComputationDataHandle& start_indices) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - DynamicUpdateSliceRequest request; - *request.mutable_operand() = operand; - 
*request.mutable_update() = update; - *request.mutable_start_indices() = start_indices; OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_dynamic_update_slice_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making dynamic update slice request"; - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + DynamicUpdateSliceRequest* request = + op_request.mutable_dynamic_update_slice_request(); + *request->mutable_operand() = operand; + *request->mutable_update() = update; + *request->mutable_start_indices() = start_indices; + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::ConcatInDim( tensorflow::gtl::ArraySlice operands, int64 dimension) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - ConcatenateRequest request; + OpRequest op_request; + ConcatenateRequest* request = op_request.mutable_concatenate_request(); for (const ComputationDataHandle& operand : operands) { - *request.add_operands() = operand; + *request->add_operands() = operand; } - request.set_dimension(dimension); - OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_concatenate_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making concatenate request"; - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + request->set_dimension(dimension); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::Broadcast( const ComputationDataHandle& operand, tensorflow::gtl::ArraySlice broadcast_sizes) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - BroadcastRequest request; - *request.mutable_operand() = operand; + OpRequest op_request; + 
BroadcastRequest* request = op_request.mutable_broadcast_request(); + *request->mutable_operand() = operand; for (int64 size : broadcast_sizes) { - request.add_broadcast_sizes(size); + request->add_broadcast_sizes(size); } - OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_broadcast_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making broadcast request"; - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::Pad( const ComputationDataHandle& operand, const ComputationDataHandle& padding_value, const PaddingConfig& padding_config) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - PadRequest request; - *request.mutable_operand() = operand; - *request.mutable_padding_value() = padding_value; - *request.mutable_padding_config() = padding_config; OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_pad_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making pad request"; - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + PadRequest* request = op_request.mutable_pad_request(); + *request->mutable_operand() = operand; + *request->mutable_padding_value() = padding_value; + *request->mutable_padding_config() = padding_config; + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::Reshape( const ComputationDataHandle& operand, tensorflow::gtl::ArraySlice dimensions, tensorflow::gtl::ArraySlice new_sizes) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - ReshapeRequest request; - *request.mutable_operand() = operand; + OpRequest op_request; + 
ReshapeRequest* request = op_request.mutable_reshape_request(); + *request->mutable_operand() = operand; for (int64 dimension : dimensions) { - request.add_dimensions(dimension); + request->add_dimensions(dimension); } for (int64 new_size : new_sizes) { - request.add_new_sizes(new_size); + request->add_new_sizes(new_size); } - OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_reshape_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making reshape request"; - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::Reshape( @@ -455,7 +399,6 @@ ComputationDataHandle ComputationBuilder::Reshape( StatusOr> shape = GetShape(operand); if (!shape.ok()) { - first_error_ = shape.status(); return ComputationDataHandle(); } std::vector dimensions(shape.ValueOrDie()->dimensions().size()); @@ -485,7 +428,6 @@ ComputationDataHandle ComputationBuilder::Collapse( // dimensions by the product of their sizes. 
StatusOr> shape_or_status = GetShape(operand); if (!shape_or_status.ok()) { - first_error_ = shape_or_status.status(); return ComputationDataHandle(); } std::unique_ptr original_shape = shape_or_status.ConsumeValueOrDie(); @@ -517,26 +459,11 @@ ComputationDataHandle ComputationBuilder::Collapse( void ComputationBuilder::Trace(const string& tag, const ComputationDataHandle& operand) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return; - } - - TraceRequest request; - request.set_tag(tag); - *request.mutable_operand() = operand; OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_trace_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making trace request"; - Status s = client_->stub()->Op(&op_request, &response); - VLOG(2) << "done with request"; - - if (!s.ok()) { - NoteError(s); - } + TraceRequest* request = op_request.mutable_trace_request(); + request->set_tag(tag); + *request->mutable_operand() = operand; + RunOpAndNoteError(&op_request); } ComputationDataHandle ComputationBuilder::Select( @@ -547,44 +474,23 @@ ComputationDataHandle ComputationBuilder::Select( ComputationDataHandle ComputationBuilder::Tuple( tensorflow::gtl::ArraySlice elements) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - VariadicOpRequest request; - request.set_varop(VAROP_TUPLE); + OpRequest op_request; + VariadicOpRequest* request = op_request.mutable_variadic_op_request(); + request->set_varop(VAROP_TUPLE); for (const ComputationDataHandle& operand : elements) { - *request.add_operands() = operand; + *request->add_operands() = operand; } - OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_variadic_op_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making variadic op request"; - Status s = 
client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::GetTupleElement( const ComputationDataHandle& tuple_data, int64 index) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - GetTupleElementRequest request; - *request.mutable_operand() = tuple_data; - request.set_index(index); OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_get_tuple_element_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making get tuple element op request"; - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + GetTupleElementRequest* request = + op_request.mutable_get_tuple_element_request(); + *request->mutable_operand() = tuple_data; + request->set_index(index); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::Eq( @@ -625,16 +531,33 @@ ComputationDataHandle ComputationBuilder::Lt( ComputationDataHandle ComputationBuilder::Dot( const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) { - return BinaryOp(BINOP_DOT, lhs, rhs, /*broadcast_dimensions=*/{}); + StatusOr> lhs_shape_or_status = GetShape(lhs); + if (!lhs_shape_or_status.ok()) { + return ComputationDataHandle(); + } + std::unique_ptr lhs_shape = lhs_shape_or_status.ConsumeValueOrDie(); + + DotDimensionNumbers dimension_numbers; + dimension_numbers.add_lhs_contracting_dimensions( + lhs_shape->dimensions_size() == 1 ? 
0 : 1); + dimension_numbers.add_rhs_contracting_dimensions(0); + return DotGeneral(lhs, rhs, dimension_numbers); +} + +ComputationDataHandle ComputationBuilder::DotGeneral( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + const DotDimensionNumbers& dimension_numbers) { + OpRequest op_request; + DotRequest* request = op_request.mutable_dot_request(); + *request->mutable_lhs() = lhs; + *request->mutable_rhs() = rhs; + *request->mutable_dimension_numbers() = dimension_numbers; + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::Conv( const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, tensorflow::gtl::ArraySlice window_strides, Padding padding) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - return ConvWithGeneralDimensions( lhs, rhs, window_strides, padding, CreateDefaultConvDimensionNumbers(window_strides.size())); @@ -644,10 +567,6 @@ ComputationDataHandle ComputationBuilder::ConvWithGeneralPadding( const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice> padding) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - return ConvGeneral(lhs, rhs, window_strides, padding, CreateDefaultConvDimensionNumbers(window_strides.size())); } @@ -715,13 +634,11 @@ ComputationDataHandle ComputationBuilder::ConvWithGeneralDimensions( StatusOr> lhs_shape_or_status = GetShape(lhs); if (!lhs_shape_or_status.ok()) { - first_error_ = lhs_shape_or_status.status(); return ComputationDataHandle(); } StatusOr> rhs_shape_or_status = GetShape(rhs); if (!rhs_shape_or_status.ok()) { - first_error_ = rhs_shape_or_status.status(); return ComputationDataHandle(); } @@ -776,13 +693,11 @@ ComputationDataHandle ComputationBuilder::ConvGeneralDilated( StatusOr> lhs_shape_or_status = GetShape(lhs); if (!lhs_shape_or_status.ok()) { - first_error_ = 
lhs_shape_or_status.status(); return ComputationDataHandle(); } StatusOr> rhs_shape_or_status = GetShape(rhs); if (!rhs_shape_or_status.ok()) { - first_error_ = rhs_shape_or_status.status(); return ComputationDataHandle(); } @@ -800,122 +715,78 @@ ComputationDataHandle ComputationBuilder::ConvGeneralDilated( rhs_shape->dimensions(dimension_numbers.kernel_spatial_dimensions(i)); } - ConvolveRequest request; - *request.mutable_lhs() = lhs; - *request.mutable_rhs() = rhs; - *request.mutable_dimension_numbers() = dimension_numbers; + OpRequest op_request; + ConvolveRequest* request = op_request.mutable_convolve_request(); + *request->mutable_lhs() = lhs; + *request->mutable_rhs() = rhs; + *request->mutable_dimension_numbers() = dimension_numbers; if (!MakeWindow(window_dimensions, window_strides, padding, lhs_dilation, - rhs_dilation, request.mutable_window())) { + rhs_dilation, request->mutable_window())) { // Error is recorded in MakeWindow. return ComputationDataHandle(); } - OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_convolve_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - VLOG(2) << "making convolve request"; - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + return RunOpAndParseResponse(&op_request); } -ComputationDataHandle ComputationBuilder::Infeed(const Shape& shape, - const string& config) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); +ComputationDataHandle ComputationBuilder::Fft( + const ComputationDataHandle& operand, const FftType fft_type, + const tensorflow::gtl::ArraySlice fft_length) { + OpRequest op_request; + FftRequest* request = op_request.mutable_fft_request(); + *request->mutable_operand() = operand; + request->set_fft_type(fft_type); + for (int64 dim_len : fft_length) { + request->add_fft_length(dim_len); } + return 
RunOpAndParseResponse(&op_request); +} - InfeedRequest request; - *request.mutable_shape() = shape; - *request.mutable_config() = config; +ComputationDataHandle ComputationBuilder::Infeed(const Shape& shape, + const string& config) { OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_infeed_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making infeed op request"; - Status s = client_->stub()->Op(&op_request, &response); - - return ParseOpResponse(s, &response); + InfeedRequest* request = op_request.mutable_infeed_request(); + *request->mutable_shape() = shape; + *request->mutable_config() = config; + return RunOpAndParseResponse(&op_request); } void ComputationBuilder::Outfeed(const ComputationDataHandle& operand, const Shape& shape, const string& outfeed_config) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return; - } - - OutfeedRequest request; - request.set_outfeed_config(outfeed_config); - *request.mutable_operand() = operand; - *request.mutable_shape() = shape; OpRequest op_request; - *op_request.mutable_outfeed_request() = request; - *op_request.mutable_computation() = computation_.handle(); - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making outfeed op request"; - tensorflow::Status s = client_->stub()->Op(&op_request, &response); - - if (!s.ok()) { - NoteError(s); - return; - } + OutfeedRequest* request = op_request.mutable_outfeed_request(); + request->set_outfeed_config(outfeed_config); + *request->mutable_operand() = operand; + *request->mutable_shape() = shape; + RunOpAndNoteError(&op_request); } ComputationDataHandle ComputationBuilder::Call( const Computation& computation, tensorflow::gtl::ArraySlice operands) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - CallRequest request; - *request.mutable_to_apply() = computation.handle(); + 
OpRequest op_request; + CallRequest* request = op_request.mutable_call_request(); + *request->mutable_to_apply() = computation.handle(); for (const ComputationDataHandle& operand : operands) { - *request.add_operands() = operand; + *request->add_operands() = operand; } - OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_call_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making call op request"; - Status s = client_->stub()->Op(&op_request, &response); - - return ParseOpResponse(s, &response); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::CustomCall( const string& call_target_name, tensorflow::gtl::ArraySlice operands, const Shape& shape) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - CustomCallRequest request; - request.set_call_target_name(call_target_name); + OpRequest op_request; + CustomCallRequest* request = op_request.mutable_custom_call_request(); + request->set_call_target_name(call_target_name); for (const ComputationDataHandle& operand : operands) { - *request.add_operands() = operand; + *request->add_operands() = operand; } - *request.mutable_shape() = shape; - OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_custom_call_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making custom call op request"; - Status s = client_->stub()->Op(&op_request, &response); - - return ParseOpResponse(s, &response); + *request->mutable_shape() = shape; + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::Complex( @@ -1080,47 +951,25 @@ ComputationDataHandle ComputationBuilder::IsFinite( ComputationDataHandle ComputationBuilder::Transpose( const ComputationDataHandle& operand, tensorflow::gtl::ArraySlice permutation) { - if 
(!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); TransposeRequest* request = op_request.mutable_transpose_request(); *request->mutable_operand() = operand; for (int64 dimension : permutation) { request->add_dimensions(dimension); } - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making transpose request"; - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::Rev( const ComputationDataHandle& operand, tensorflow::gtl::ArraySlice dimensions) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - ReverseRequest request; - *request.mutable_operand() = operand; + OpRequest op_request; + ReverseRequest* request = op_request.mutable_reverse_request(); + *request->mutable_operand() = operand; for (int64 dimension : dimensions) { - request.add_dimensions(dimension); + request->add_dimensions(dimension); } - OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_reverse_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making reverse op request"; - Status s = client_->stub()->Op(&op_request, &response); - - return ParseOpResponse(s, &response); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::Sort( @@ -1148,24 +997,15 @@ ComputationDataHandle ComputationBuilder::ConvertElementType( StatusOr> shape_status = GetShape(operand); if (!shape_status.ok()) { - first_error_ = shape_status.status(); return ComputationDataHandle(); } std::unique_ptr original = shape_status.ConsumeValueOrDie(); - ConvertRequest request; - *request.mutable_operand() = operand; - request.set_new_element_type(new_element_type); 
OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_convert_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making convert request"; - Status s = client_->stub()->Op(&op_request, &response); - - return ParseOpResponse(s, &response); + ConvertRequest* request = op_request.mutable_convert_request(); + *request->mutable_operand() = operand; + request->set_new_element_type(new_element_type); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::BitcastConvertType( @@ -1176,24 +1016,15 @@ ComputationDataHandle ComputationBuilder::BitcastConvertType( StatusOr> shape_status = GetShape(operand); if (!shape_status.ok()) { - first_error_ = shape_status.status(); return ComputationDataHandle(); } std::unique_ptr original = shape_status.ConsumeValueOrDie(); - ConvertRequest request; - *request.mutable_operand() = operand; - request.set_new_element_type(new_element_type); OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_bitcast_convert_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making bitcast convert request"; - Status s = client_->stub()->Op(&op_request, &response); - - return ParseOpResponse(s, &response); + ConvertRequest* request = op_request.mutable_bitcast_convert_request(); + *request->mutable_operand() = operand; + request->set_new_element_type(new_element_type); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::SquareF32( @@ -1221,107 +1052,57 @@ ComputationDataHandle ComputationBuilder::Clamp( ComputationDataHandle ComputationBuilder::UnaryOp( UnaryOperation unop, const ComputationDataHandle& operand) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - UnaryOpRequest request; - request.set_unop(unop); - 
*request.mutable_operand() = operand; OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_unary_op_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making unop request"; - Status s = client_->stub()->Op(&op_request, &response); - - return ParseOpResponse(s, &response); + UnaryOpRequest* request = op_request.mutable_unary_op_request(); + request->set_unop(unop); + *request->mutable_operand() = operand; + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::BinaryOp( BinaryOperation binop, const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - BinaryOpRequest request; - request.set_binop(binop); - *request.mutable_lhs() = lhs; - *request.mutable_rhs() = rhs; + OpRequest op_request; + BinaryOpRequest* request = op_request.mutable_binary_op_request(); + request->set_binop(binop); + *request->mutable_lhs() = lhs; + *request->mutable_rhs() = rhs; for (int64 dimension : broadcast_dimensions) { - request.add_broadcast_dimensions(dimension); + request->add_broadcast_dimensions(dimension); } - OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_binary_op_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making binop request"; - Status s = client_->stub()->Op(&op_request, &response); - - return ParseOpResponse(s, &response); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::RngOp( RandomDistribution distribution, tensorflow::gtl::ArraySlice parameters, const Shape& shape) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - RngRequest request; - request.set_distribution(distribution); 
+ OpRequest op_request; + RngRequest* request = op_request.mutable_rng_request(); + request->set_distribution(distribution); for (const ComputationDataHandle& param : parameters) { - *request.add_parameter() = param; + *request->add_parameter() = param; } - *request.mutable_shape() = shape; - OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_rng_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making rngop request"; - Status s = client_->stub()->Op(&op_request, &response); - - return ParseOpResponse(s, &response); + *request->mutable_shape() = shape; + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::TernaryOp( TernaryOperation triop, const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, const ComputationDataHandle& ehs) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - TernaryOpRequest request; - request.set_triop(triop); - *request.mutable_lhs() = lhs; - *request.mutable_rhs() = rhs; - *request.mutable_ehs() = ehs; OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_ternary_op_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making triop request"; - Status s = client_->stub()->Op(&op_request, &response); - - return ParseOpResponse(s, &response); + TernaryOpRequest* request = op_request.mutable_ternary_op_request(); + request->set_triop(triop); + *request->mutable_lhs() = lhs; + *request->mutable_rhs() = rhs; + *request->mutable_ehs() = ehs; + return RunOpAndParseResponse(&op_request); } Status ComputationBuilder::SetReturnValue( const ComputationDataHandle& operand) { - if (!first_error_.ok()) { - return first_error_; - } + TF_RETURN_IF_ERROR(first_error_); SetReturnValueRequest request; *request.mutable_computation() = computation_.handle(); @@ 
-1343,9 +1124,7 @@ Status ComputationBuilder::SetReturnValue( StatusOr ComputationBuilder::IsConstant( const ComputationDataHandle& operand, int64 num_parameters) { - if (!first_error_.ok()) { - return first_error_; - } + TF_RETURN_IF_ERROR(first_error_); IsConstantRequest request; *request.mutable_computation() = computation_.handle(); @@ -1366,9 +1145,7 @@ StatusOr ComputationBuilder::IsConstant( StatusOr> ComputationBuilder::ComputeConstant( const ComputationDataHandle& operand, const Layout* output_layout, tensorflow::gtl::ArraySlice parameters) { - if (!first_error_.ok()) { - return first_error_; - } + TF_RETURN_IF_ERROR(first_error_); ComputeConstantRequest request; *request.mutable_computation() = computation_.handle(); @@ -1397,7 +1174,7 @@ StatusOr> ComputationBuilder::ComputeConstant( "no computed literal in the provided response in ComputeConstant " "request"); } - return MakeUnique(response.literal()); + return Literal::CreateFromProto(response.literal()); } ComputationDataHandle ComputationBuilder::Map( @@ -1405,30 +1182,19 @@ ComputationDataHandle ComputationBuilder::Map( const Computation& computation, tensorflow::gtl::ArraySlice dimensions, tensorflow::gtl::ArraySlice static_operands) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - MapRequest request; + OpRequest op_request; + MapRequest* request = op_request.mutable_map_request(); for (const ComputationDataHandle& operand : operands) { - *request.add_operands() = operand; + *request->add_operands() = operand; } - *request.mutable_to_apply() = computation.handle(); + *request->mutable_to_apply() = computation.handle(); for (int64 dimension : dimensions) { - request.add_dimensions(dimension); + request->add_dimensions(dimension); } for (const ComputationDataHandle& sop : static_operands) { - *request.add_static_operands() = sop; + *request->add_static_operands() = sop; } - OpRequest op_request; - *op_request.mutable_computation() = 
computation_.handle(); - *op_request.mutable_map_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making Map request"; - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::RngNormal( @@ -1443,57 +1209,46 @@ ComputationDataHandle ComputationBuilder::RngUniform( return RngOp(RandomDistribution::RNG_UNIFORM, {a, b}, shape); } -ComputationDataHandle ComputationBuilder::RngBernoulli( - const ComputationDataHandle& mean, const Shape& shape) { - return RngOp(RandomDistribution::RNG_BERNOULLI, {mean}, shape); -} - ComputationDataHandle ComputationBuilder::While( const Computation& condition, const Computation& body, const ComputationDataHandle& init) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - WhileRequest request; - *request.mutable_condition() = condition.handle(); - *request.mutable_body() = body.handle(); - *request.mutable_init() = init; OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_while_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making while request"; - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + WhileRequest* request = op_request.mutable_while_request(); + *request->mutable_condition() = condition.handle(); + *request->mutable_body() = body.handle(); + *request->mutable_init() = init; + return RunOpAndParseResponse(&op_request); +} + +ComputationDataHandle ComputationBuilder::Conditional( + const ComputationDataHandle& predicate, + const ComputationDataHandle& true_operand, + const Computation& true_computation, + const ComputationDataHandle& false_operand, + const Computation& false_computation) { + OpRequest op_request; + ConditionalRequest* request = 
op_request.mutable_conditional_request(); + *request->mutable_predicate() = predicate; + *request->mutable_true_operand() = true_operand; + *request->mutable_true_computation() = true_computation.handle(); + *request->mutable_false_operand() = false_operand; + *request->mutable_false_computation() = false_computation.handle(); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::Reduce( const ComputationDataHandle& operand, const ComputationDataHandle& init_value, const Computation& computation, tensorflow::gtl::ArraySlice dimensions_to_reduce) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - ReduceRequest request; - *request.mutable_operand() = operand; - *request.mutable_init_value() = init_value; + OpRequest op_request; + ReduceRequest* request = op_request.mutable_reduce_request(); + *request->mutable_operand() = operand; + *request->mutable_init_value() = init_value; for (int64 dimension : dimensions_to_reduce) { - request.add_dimensions(dimension); + request->add_dimensions(dimension); } - *request.mutable_to_apply() = computation.handle(); - OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_reduce_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making reduce request"; - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + *request->mutable_to_apply() = computation.handle(); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::ReduceAll( @@ -1505,7 +1260,6 @@ ComputationDataHandle ComputationBuilder::ReduceAll( StatusOr> shape = GetShape(operand); if (!shape.ok()) { - first_error_ = shape.status(); return ComputationDataHandle(); } @@ -1525,7 +1279,6 @@ ComputationDataHandle ComputationBuilder::ReduceWindow( StatusOr> shape = GetShape(operand); if (!shape.ok()) { - first_error_ = 
shape.status(); return ComputationDataHandle(); } @@ -1551,84 +1304,50 @@ ComputationDataHandle ComputationBuilder::ReduceWindowWithGeneralPadding( tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, tensorflow::gtl::ArraySlice> padding) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - ReduceWindowRequest request; - *request.mutable_operand() = operand; - *request.mutable_to_apply() = computation.handle(); - *request.mutable_init_value() = init_value; + OpRequest op_request; + ReduceWindowRequest* request = op_request.mutable_reduce_window_request(); + *request->mutable_operand() = operand; + *request->mutable_to_apply() = computation.handle(); + *request->mutable_init_value() = init_value; if (!MakeWindow(window_dimensions, window_strides, padding, {}, {}, - request.mutable_window())) { + request->mutable_window())) { NoteError(InternalError("failed to make window")); return ComputationDataHandle(); } - OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_reduce_window_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - VLOG(2) << "making reduce-window request"; - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::BatchNormTraining( const ComputationDataHandle& operand, const ComputationDataHandle& scale, const ComputationDataHandle& offset, float epsilon, int64 feature_index) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - BatchNormTrainingRequest request; - *request.mutable_operand() = operand; - *request.mutable_scale() = scale; - *request.mutable_offset() = offset; - request.set_epsilon(epsilon); - request.set_feature_index(feature_index); - OpRequest op_request; - 
*op_request.mutable_batch_norm_training_request() = request; - *op_request.mutable_computation() = computation_.handle(); - AddCommonFieldsToOpRequest(&op_request); - - OpResponse response; - - VLOG(2) << "making BatchNormTraining request"; - - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + BatchNormTrainingRequest* request = + op_request.mutable_batch_norm_training_request(); + *request->mutable_operand() = operand; + *request->mutable_scale() = scale; + *request->mutable_offset() = offset; + request->set_epsilon(epsilon); + request->set_feature_index(feature_index); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::BatchNormInference( const ComputationDataHandle& operand, const ComputationDataHandle& scale, const ComputationDataHandle& offset, const ComputationDataHandle& mean, const ComputationDataHandle& variance, float epsilon, int64 feature_index) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - BatchNormInferenceRequest request; - *request.mutable_operand() = operand; - *request.mutable_scale() = scale; - *request.mutable_offset() = offset; - *request.mutable_mean() = mean; - *request.mutable_variance() = variance; - request.set_epsilon(epsilon); - request.set_feature_index(feature_index); - OpRequest op_request; - *op_request.mutable_batch_norm_inference_request() = request; - *op_request.mutable_computation() = computation_.handle(); - AddCommonFieldsToOpRequest(&op_request); - - OpResponse response; - - VLOG(2) << "making BatchNormInference request"; - - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + BatchNormInferenceRequest* request = + op_request.mutable_batch_norm_inference_request(); + *request->mutable_operand() = operand; + *request->mutable_scale() = scale; + *request->mutable_offset() = offset; + *request->mutable_mean() = mean; + 
*request->mutable_variance() = variance; + request->set_epsilon(epsilon); + request->set_feature_index(feature_index); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::BatchNormGrad( @@ -1636,49 +1355,25 @@ ComputationDataHandle ComputationBuilder::BatchNormGrad( const ComputationDataHandle& mean, const ComputationDataHandle& var, const ComputationDataHandle& grad_output, float epsilon, int64 feature_index) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - BatchNormGradRequest request; - *request.mutable_operand() = operand; - *request.mutable_scale() = scale; - *request.mutable_mean() = mean; - *request.mutable_variance() = var; - *request.mutable_grad_output() = grad_output; - request.set_epsilon(epsilon); - request.set_feature_index(feature_index); - OpRequest op_request; - *op_request.mutable_batch_norm_grad_request() = request; - *op_request.mutable_computation() = computation_.handle(); - AddCommonFieldsToOpRequest(&op_request); - - OpResponse response; - - VLOG(2) << "making BatchNormGrad request"; - - Status s = client_->stub()->Op(&op_request, &response); - - return ParseOpResponse(s, &response); + BatchNormGradRequest* request = op_request.mutable_batch_norm_grad_request(); + *request->mutable_operand() = operand; + *request->mutable_scale() = scale; + *request->mutable_mean() = mean; + *request->mutable_variance() = var; + *request->mutable_grad_output() = grad_output; + request->set_epsilon(epsilon); + request->set_feature_index(feature_index); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::CrossReplicaSum( const ComputationDataHandle& operand) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - CrossReplicaSumRequest request; - *request.mutable_operand() = operand; OpRequest op_request; - *op_request.mutable_cross_replica_sum_request() = request; - 
*op_request.mutable_computation() = computation_.handle(); - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making cross-replica-sum request"; - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + CrossReplicaSumRequest* request = + op_request.mutable_cross_replica_sum_request(); + *request->mutable_operand() = operand; + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::SelectAndScatter( @@ -1693,7 +1388,6 @@ ComputationDataHandle ComputationBuilder::SelectAndScatter( StatusOr> shape = GetShape(operand); if (!shape.ok()) { - first_error_ = shape.status(); return ComputationDataHandle(); } return SelectAndScatterWithGeneralPadding( @@ -1710,98 +1404,53 @@ ComputationDataHandle ComputationBuilder::SelectAndScatterWithGeneralPadding( tensorflow::gtl::ArraySlice> padding, const ComputationDataHandle& source, const ComputationDataHandle& init_value, const Computation& scatter) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - SelectAndScatterRequest request; - *request.mutable_operand() = operand; - *request.mutable_select() = select.handle(); - *request.mutable_source() = source; - *request.mutable_init_value() = init_value; - *request.mutable_scatter() = scatter.handle(); + OpRequest op_request; + SelectAndScatterRequest* request = + op_request.mutable_select_and_scatter_request(); + *request->mutable_operand() = operand; + *request->mutable_select() = select.handle(); + *request->mutable_source() = source; + *request->mutable_init_value() = init_value; + *request->mutable_scatter() = scatter.handle(); if (!MakeWindow(window_dimensions, window_strides, padding, {}, {}, - request.mutable_window())) { + request->mutable_window())) { NoteError(InternalError("failed to make window")); return ComputationDataHandle(); } - OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - 
*op_request.mutable_select_and_scatter_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - VLOG(2) << "making select-and-scatter request"; - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + return RunOpAndParseResponse(&op_request); } ComputationDataHandle ComputationBuilder::ReducePrecision( const ComputationDataHandle& operand, const int exponent_bits, const int mantissa_bits) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - ReducePrecisionRequest request; - *request.mutable_operand() = operand; - request.set_exponent_bits(exponent_bits); - request.set_mantissa_bits(mantissa_bits); OpRequest op_request; - *op_request.mutable_computation() = computation_.handle(); - *op_request.mutable_reduce_precision_request() = request; - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making reduce-precision request"; - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + ReducePrecisionRequest* request = + op_request.mutable_reduce_precision_request(); + *request->mutable_operand() = operand; + request->set_exponent_bits(exponent_bits); + request->set_mantissa_bits(mantissa_bits); + return RunOpAndParseResponse(&op_request); } void ComputationBuilder::Send(const ComputationDataHandle& operand, const ChannelHandle& handle) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return; - } - - SendRequest request; - *request.mutable_operand() = operand; - *request.mutable_channel_handle() = handle; OpRequest op_request; - *op_request.mutable_send_request() = request; + SendRequest* request = op_request.mutable_send_request(); + *request->mutable_operand() = operand; + *request->mutable_channel_handle() = handle; *op_request.mutable_computation() = computation_.handle(); - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making 
send request"; - Status s = client_->stub()->Op(&op_request, &response); - VLOG(2) << "done with op request"; - - if (!s.ok()) { - NoteError(s); - return; - } + RunOpAndNoteError(&op_request); } ComputationDataHandle ComputationBuilder::Recv(const Shape& shape, const ChannelHandle& handle) { - if (!first_error_.ok() || !PrepareComputation().ok()) { - return ComputationDataHandle(); - } - - RecvRequest request; - *request.mutable_shape() = shape; - *request.mutable_channel_handle() = handle; OpRequest op_request; - *op_request.mutable_recv_request() = request; - *op_request.mutable_computation() = computation_.handle(); - AddCommonFieldsToOpRequest(&op_request); - OpResponse response; - - VLOG(2) << "making recv request"; - Status s = client_->stub()->Op(&op_request, &response); - return ParseOpResponse(s, &response); + RecvRequest* request = op_request.mutable_recv_request(); + *request->mutable_shape() = shape; + *request->mutable_channel_handle() = handle; + return RunOpAndParseResponse(&op_request); } Computation ComputationBuilder::BuildAndNoteError() { @@ -1830,13 +1479,6 @@ StatusOr ComputationBuilder::Build() { return {std::move(computation_)}; } -void ComputationBuilder::AddCommonFieldsToOpRequest(OpRequest* request) const { - *request->mutable_metadata() = metadata_; - if (sharding_) { - *request->mutable_sharding() = *sharding_; - } -} - /* static */ ConvolutionDimensionNumbers ComputationBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) { ConvolutionDimensionNumbers dimension_numbers; diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index d2dbbbbebbd5a9386f8841576de33a1fdb767000..7cae91e9e04bba8f28f2348c552a941e4f7a36b4 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -43,59 +43,6 @@ limitations under the License. 
namespace xla { -class ShardingBuilder { - public: - // A shaped array used to describe the assignment of tiles to devices. - using TileAssignment = Array; - - // Creates a replicated sharding - replicate a tensor on every device. - static OpSharding Replicate() { - OpSharding result; - result.set_type(OpSharding::Type::OpSharding_Type_REPLICATED); - return result; - } - // Creates a sharding that assigns a tensor to just one device. - static OpSharding AssignDevice(int device) { - OpSharding result; - result.set_type(OpSharding::Type::OpSharding_Type_MAXIMAL); - result.add_tile_assignment_dimensions(1); - result.add_tile_assignment_devices(device); - return result; - } - // Creates a tiled sharding with the given tile shape and assignment of tiles - // to devices. - static OpSharding Tile(Shape tile_shape, - const TileAssignment& tile_assignment) { - OpSharding result; - result.set_type(OpSharding::Type::OpSharding_Type_OTHER); - *result.mutable_tile_shape() = tile_shape; - for (int64 dim : tile_assignment.dimensions()) { - result.add_tile_assignment_dimensions(dim); - } - for (uint32 device : tile_assignment) { - result.add_tile_assignment_devices(device); - } - return result; - } - // Creates a sharding in one dimension, with the given tile shape which must - // be rank 1 and using devices 0..num_tiles. - static OpSharding Tile1D(Shape tile_shape, int64 num_tiles) { - OpSharding result; - result.set_type(OpSharding::Type::OpSharding_Type_OTHER); - - CHECK_EQ(ShapeUtil::Rank(tile_shape), 1); - std::vector dimensions(1, num_tiles); - auto& tile_dimension = (*tile_shape.mutable_dimensions())[0]; - tile_dimension = CeilOfRatio(static_cast(tile_dimension), num_tiles); - *result.mutable_tile_shape() = tile_shape; - result.add_tile_assignment_dimensions(num_tiles); - for (int64 i = 0; i < num_tiles; ++i) { - result.add_tile_assignment_devices(i); - } - return result; - } -}; - // Wraps an XLA client with a convenient interface for building up // computations. 
Any errors encountered in building up the computation are // deferred from being handled until Build() is called. @@ -120,7 +67,7 @@ class ComputationBuilder { // OpMetadata is often applied to a series of XLA HLO instructions. As a // result, OpMetadata is set on the Computation Builder. All subsequent // instructions generated via this Computation Builder will have the same - // OpMetadata attached until a call to ClearOpMetdata. + // OpMetadata attached until a call to ClearOpMetadata. void SetOpMetadata(const OpMetadata& metadata) { metadata_ = metadata; } // Clears the HloMetadata state. @@ -154,6 +101,9 @@ class ComputationBuilder { StatusOr> GetShape( const ComputationDataHandle& operand); + // Retrieves the (inferred) result for the current computation's shape. + StatusOr GetProgramShape(); + // Checks that the operand has the given expected shape. Returns the operand // if yes, fails with a CHECK error if no. ComputationDataHandle CheckShape(const ComputationDataHandle& operand, @@ -393,6 +343,11 @@ class ComputationBuilder { ComputationDataHandle Dot(const ComputationDataHandle& lhs, const ComputationDataHandle& rhs); + // Enqueues a general dot instruction onto the computation. + ComputationDataHandle DotGeneral( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + const DotDimensionNumbers& dimension_numbers); + // Default dimension numbers used for a 2D convolution. static constexpr int64 kConvBatchDimension = 0; static constexpr int64 kConvFeatureDimension = 1; @@ -458,14 +413,24 @@ class ComputationBuilder { tensorflow::gtl::ArraySlice rhs_dilation, const ConvolutionDimensionNumbers& dimension_numbers); + // Enqueues an FFT instruction onto the computation, of the given type and + // with the given FFT length. 
+ ComputationDataHandle Fft(const ComputationDataHandle& operand, + FftType fft_type, + tensorflow::gtl::ArraySlice fft_length); + // Enqueues an infeed instruction onto the computation, which writes data of // the given shape to the infeed buffer of the device. ComputationDataHandle Infeed(const Shape& shape, const string& config = ""); // Enqueues an outfeed instruction onto the computation. This instruction // generates outgoing data transfers for the given data. - void Outfeed(const ComputationDataHandle& operand, const Shape& shape, - const string& outfeed_config); + // + // shape_with_layout communicates the laid out shape that we want to outfeed + // -- if !ShapeUtil::Compatible(GetShape(operand), shape_with_layout) an error + // will occur. + void Outfeed(const ComputationDataHandle& operand, + const Shape& shape_with_layout, const string& outfeed_config); // Enqueues a call instruction onto the computation. ComputationDataHandle Call( @@ -726,16 +691,18 @@ class ComputationBuilder { const ComputationDataHandle& b, const Shape& shape); - // Enqueues a B(1, p) random number generation instruction onto the - // computation. - ComputationDataHandle RngBernoulli(const ComputationDataHandle& mean, - const Shape& shape); - // Enqueues a while node onto the computation. ComputationDataHandle While(const Computation& condition, const Computation& body, const ComputationDataHandle& init); + // Enqueues a conditional node onto the computation. + ComputationDataHandle Conditional(const ComputationDataHandle& predicate, + const ComputationDataHandle& true_operand, + const Computation& true_computation, + const ComputationDataHandle& false_operand, + const Computation& false_computation); + // Enqueues a ReducePrecision node onto the computation. 
ComputationDataHandle ReducePrecision(const ComputationDataHandle& operand, const int exponent_bits, @@ -751,7 +718,7 @@ class ComputationBuilder { ComputationDataHandle Recv(const Shape& shape, const ChannelHandle& handle); // Returns true if 'operand' is a compile-time constant. A compile-time - // constant does not depend on parameters with higher index then + // constant does not depend on parameters with index greater than or equal to // `num_parameters`, or on stateful operators such as `RngNormal` or `Infeed`. // Unlike `ComputeConstant`, `IsConstant` tests whether a computation is a // compile-time constant without evaluating the computation. @@ -811,7 +778,7 @@ class ComputationBuilder { // The operand must represent a constant value, which in this case // means that it must not statically depend on any parameter of the // computation that is being built other then the ones specified on the - // paramtere list. The parameters in the list will be indexed by their + // parameter list. The parameters in the list will be indexed by their // parameter id property so the number of parameters specified should be at // least as many as the largest used parameter index. // @@ -870,8 +837,6 @@ class ComputationBuilder { Status first_error() const { return first_error_; } private: - using PopulateLiteral = std::function; - // Limited checking of convolution parameters. Returns false on // error. bool VerifyConvolution(const Shape& lhs_shape, const Shape& rhs_shape, @@ -890,11 +855,6 @@ class ComputationBuilder { tensorflow::gtl::ArraySlice rhs_dilation, Window* window); - // Internal helper method that makes a request for a constant operation -- the - // provided function is used to populate the literal before sending the - // request. - ComputationDataHandle ConstantOp(const PopulateLiteral& populate); - // Internal helper method that does the building for an arbitrary unary op. 
ComputationDataHandle UnaryOp(UnaryOperation binop, const ComputationDataHandle& operand); @@ -924,19 +884,28 @@ class ComputationBuilder { // This is used before any given operation is enqueued. Status PrepareComputation(); - // Helper function for parsing a method response and either returning the - // output computation data handle (on success) or a vacuous computation data - // handle (on failure). - ComputationDataHandle ParseOpResponse(const Status& status, - OpResponse* response); - // Notes that the error occurred by: // * storing it internally and capturing a backtrace if it's the first error // (this deferred value will be produced on the call to Build()) // * dying if die_immediately_on_error_ is true void NoteError(const Status& error); - void AddCommonFieldsToOpRequest(OpRequest* request) const; + // Helper function that runs the given op_request, filling in op_response. + // Before the op is run, PrepareComputation is called, and common fields in + // the op_request are filled in. + Status RunOp(OpRequest* op_request, OpResponse* op_response); + + // Helper function that calls RunOp and calls NoteError on failures. + void RunOpAndNoteError(OpRequest* op_request); + + // Helper function that calls RunOp and either returns the output computation + // data handle (on success) or a vacuous computation data handle (on failure). + ComputationDataHandle RunOpAndParseResponse(OpRequest* op_request); + + // Helper function that implements GetShape without noting errors. This makes + // it easier to ensure the real GetShape will note errors on every error path. + StatusOr> GetShapeWithoutNoteError( + const ComputationDataHandle& operand); string name_; // Name to use for the built computation. 
@@ -970,68 +939,66 @@ class ComputationBuilder { template ComputationDataHandle ComputationBuilder::ConstantR0(NativeT value) { - return ConstantOp([value](Literal* literal) { literal->PopulateR0(value); }); + return ConstantLiteral(*Literal::CreateR0(value)); } template ComputationDataHandle ComputationBuilder::ConstantR1( tensorflow::gtl::ArraySlice values) { - return ConstantOp( - [&values](Literal* literal) { literal->PopulateR1(values); }); + return ConstantLiteral(*Literal::CreateR1(values)); } template ComputationDataHandle ComputationBuilder::ConstantR1(int64 length, NativeT value) { - return ConstantOp([length, value](Literal* literal) { - literal->PopulateWithValue(value, {length}); - }); + Literal literal(ShapeUtil::MakeShape( + primitive_util::NativeToPrimitiveType(), {length})); + literal.PopulateWithValue(value); + return ConstantLiteral(literal); } inline ComputationDataHandle ComputationBuilder::ConstantR1( const tensorflow::core::Bitmap& values) { - return ConstantOp( - [&values](Literal* literal) { literal->PopulateR1(values); }); + return ConstantLiteral(*Literal::CreateR1(values)); } template ComputationDataHandle ComputationBuilder::ConstantR2( std::initializer_list> values) { - return ConstantOp( - [&values](Literal* literal) { literal->PopulateR2(values); }); + return ConstantLiteral(*Literal::CreateR2(values)); } template ComputationDataHandle ComputationBuilder::ConstantFromArrayWithLayout( const Array& values, const Layout& layout) { - return ConstantOp([&values, &layout](Literal* literal) { - literal->PopulateFromArrayWithLayout(values, layout); - }); + return ConstantLiteral( + *Literal::CreateFromArrayWithLayout(values, layout)); } template ComputationDataHandle ComputationBuilder::ConstantFromArray( const Array& values) { - return ConstantOp( - [&values](Literal* literal) { literal->PopulateFromArray(values); }); + return ConstantLiteral(*Literal::CreateFromArray(values)); } template ComputationDataHandle 
ComputationBuilder::ConstantR2FromArray2DWithLayout( const Array2D& values, const Layout& layout) { - return ConstantFromArrayWithLayout(values, layout); + return ConstantLiteral( + *Literal::CreateFromArrayWithLayout(values, layout)); } template ComputationDataHandle ComputationBuilder::ConstantR2FromArray2D( const Array2D& values) { - return ConstantFromArray(values); + return ConstantLiteral(*Literal::CreateR2FromArray2D(values)); } template ComputationDataHandle ComputationBuilder::ConstantR3FromArray3DWithLayout( const Array3D& values, const Layout& layout) { - return ConstantFromArrayWithLayout(values, layout); + return ConstantLiteral( + *Literal::CreateR3FromArray3DWithLayout(values, layout)); } template diff --git a/tensorflow/compiler/xla/client/executable_build_options.cc b/tensorflow/compiler/xla/client/executable_build_options.cc new file mode 100644 index 0000000000000000000000000000000000000000..804e34f5e75ce2d153ac7627b94a543fda88e810 --- /dev/null +++ b/tensorflow/compiler/xla/client/executable_build_options.cc @@ -0,0 +1,79 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/executable_build_options.h" + +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/lib/strings/stringprintf.h" + +namespace xla { + +ExecutableBuildOptions& ExecutableBuildOptions::set_device_allocator( + DeviceMemoryAllocator* allocator) { + device_allocator_ = allocator; + return *this; +} + +DeviceMemoryAllocator* ExecutableBuildOptions::device_allocator() const { + return device_allocator_; +} + +ExecutableBuildOptions& ExecutableBuildOptions::set_device_ordinal( + int device_ordinal) { + CHECK_GE(device_ordinal, 0); + device_ordinal_ = device_ordinal; + return *this; +} + +int ExecutableBuildOptions::device_ordinal() const { return device_ordinal_; } + +ExecutableBuildOptions& ExecutableBuildOptions::set_result_layout( + const Shape& shape_with_layout) { + result_layout_set_ = true; + result_layout_ = shape_with_layout; + return *this; +} + +const Shape* ExecutableBuildOptions::result_layout() const { + return result_layout_set_ ? 
&result_layout_ : nullptr; +} + +string ExecutableBuildOptions::ToString() const { + string result_layout = "nullopt"; + if (result_layout_set_) { + result_layout = ShapeUtil::HumanStringWithLayout(result_layout_); + } + string generate_hlo_graph = "nullopt"; + if (generate_hlo_graph_.has_value()) { + generate_hlo_graph = generate_hlo_graph_.value(); + } + return tensorflow::strings::Printf( + "ExecutableBuildOptions{device_ordinal=%d, result_layout=%s, " + "generate_hlo_graph=%s}", + device_ordinal_, result_layout.c_str(), generate_hlo_graph.c_str()); +} + +ExecutableBuildOptions& ExecutableBuildOptions::set_generate_hlo_graph( + string regex) { + generate_hlo_graph_ = std::move(regex); + return *this; +} + +const tensorflow::gtl::optional& +ExecutableBuildOptions::generate_hlo_graph() const { + return generate_hlo_graph_; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/client/executable_build_options.h b/tensorflow/compiler/xla/client/executable_build_options.h new file mode 100644 index 0000000000000000000000000000000000000000..3a52dbac9adb155ad9a7d91a8102707f70fe2fbf --- /dev/null +++ b/tensorflow/compiler/xla/client/executable_build_options.h @@ -0,0 +1,74 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_ + +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/optional.h" + +namespace xla { + +// Class containing options for building a LocalExecutable with +// LocalClient::Compile. +class ExecutableBuildOptions { + public: + // If set, this is the device to build the computation for. Valid + // device_ordinal values are: 0 to # of devices - 1. These values are + // identical to the device ordinal values used by StreamExecutor. The built + // executable will be executable on any device equivalent to the specified + // device as determined by Backend::devices_equivalent(). A value of -1 + // indicates this option has not been set. + ExecutableBuildOptions& set_device_ordinal(int device_ordinal); + int device_ordinal() const; + + // If set, this specifies the layout of the result of the computation. If not + // set, the service will choose the layout of the result. A Shape is used to + // store the layout to accommodate tuple result shapes. A value of nullptr + // indicates the option has not been set. + ExecutableBuildOptions& set_result_layout(const Shape& shape_with_layout); + const Shape* result_layout() const; + + // If set, this specifies an allocator that can be used to allocate temporary + // space on the device during compilation. For example, the compiler might + // want to run various algorithms on the device and pick the fastest one -- it + // might allocate buffers for use by these algorithms using this allocator. + // + // This does not need to be the same as the DeviceMemoryAllocator passed when + // running the executable.
+ ExecutableBuildOptions& set_device_allocator( + DeviceMemoryAllocator* allocator); + DeviceMemoryAllocator* device_allocator() const; + + // If set, specifies a regexp of HLO graphs to dump (as in DebugOptions). + ExecutableBuildOptions& set_generate_hlo_graph(string regex); + const tensorflow::gtl::optional& generate_hlo_graph() const; + + // Returns a string representation of the build options, suitable for + // debugging. + string ToString() const; + + private: + int device_ordinal_ = -1; + Shape result_layout_; + bool result_layout_set_ = false; + tensorflow::gtl::optional generate_hlo_graph_; + DeviceMemoryAllocator* device_allocator_ = nullptr; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_CLIENT_EXECUTABLE_BUILD_OPTIONS_H_ diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index 5f2b55713e342aa3d0251386d57cb52481fe748d..b63a1465ea755b906853860d47768ecbeaa0dcdd 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -31,14 +31,43 @@ limitations under the License. namespace xla { namespace { +// Calculates the number of bytes required to store the data within the +// specified shape. In case of a (nested) tuple shape this is the total byte +// size of all sub-shapes within the tuple. +int64 DataSizeOfShape(const Shape& shape) { + if (ShapeUtil::IsArray(shape)) { + return ShapeUtil::ByteSizeOf(shape); + } + + int64 total_size = 0; + for (const Shape& s : shape.tuple_shapes()) { + total_size += DataSizeOfShape(s); + } + return total_size; +} + +// Create a ComputationDataHandle for an op that generates fake data with the +// given shape.
+ComputationDataHandle BuildFakeDataOpOnDevice(const Shape& shape, + ComputationBuilder* builder) { + if (ShapeUtil::IsArray(shape)) { + return builder->Broadcast( + builder->ConstantLiteral(Literal::One(shape.element_type())), + AsInt64Slice(shape.dimensions())); + } + std::vector parts; + for (const Shape& s : shape.tuple_shapes()) { + parts.push_back(BuildFakeDataOpOnDevice(s, builder)); + } + return builder->Tuple(parts); +} + std::unique_ptr MakeFakeDataViaDeviceOrDie(const Shape& shape, Client* client) { ComputationBuilder b( client, tensorflow::strings::StrCat("make_fake_", ShapeUtil::HumanString(shape))); - // TODO(b/26811613): Replace this when RNG is supported on all backends. - b.Broadcast(b.ConstantLiteral(Literal::One(shape.element_type())), - AsInt64Slice(shape.dimensions())); + BuildFakeDataOpOnDevice(shape, &b); Computation computation = b.Build().ConsumeValueOrDie(); auto execution_options = CreateDefaultExecutionOptions(); @@ -51,7 +80,7 @@ std::unique_ptr MakeFakeDataViaDeviceOrDie(const Shape& shape, std::unique_ptr MakeFakeDataOrDie(const Shape& shape, Client* client) { - if (ShapeUtil::ByteSizeOf(shape) < (1LL << 20)) { + if (DataSizeOfShape(shape) < (1LL << 20)) { StatusOr> literal_status = MakeFakeLiteral(shape); if (!literal_status.ok()) { // If we got an Unimplemented error, fall back to making the fake data via diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index b051955f0fd85b7ca886bc0238068aeb94427209..ef98dbb6403beedb0c08ab9a0fc9e7d4ee31ab3b 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -21,30 +21,14 @@ limitations under the License. 
#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/service_executable_run_options.h" +#include "tensorflow/compiler/xla/service/source_map_util.h" #include "tensorflow/compiler/xla/status_macros.h" namespace se = ::perftools::gputools; -namespace xla { - -ExecutableBuildOptions& ExecutableBuildOptions::set_device_ordinal( - int device_ordinal) { - device_ordinal_ = device_ordinal; - return *this; -} +using xla::source_map_util::InvalidParameterArgument; -int ExecutableBuildOptions::device_ordinal() const { return device_ordinal_; } - -ExecutableBuildOptions& ExecutableBuildOptions::set_result_layout( - const Shape& shape_with_layout) { - result_layout_set_ = true; - result_layout_ = shape_with_layout; - return *this; -} - -const Shape* ExecutableBuildOptions::result_layout() const { - return result_layout_set_ ? &result_layout_ : nullptr; -} +namespace xla { namespace { StatusOr BorrowStreamForDevice(int device_ordinal, @@ -57,16 +41,18 @@ StatusOr BorrowStreamForDevice(int device_ordinal, } // namespace LocalExecutable::LocalExecutable(std::unique_ptr executable, - Backend* backend, int device_ordinal, - const ExecutableBuildOptions& build_options) + Backend* backend, + ExecutableBuildOptions build_options) : executable_(std::move(executable)), backend_(backend), - build_device_ordinal_(device_ordinal), - build_options_(build_options) {} + build_options_(std::move(build_options)) { + CHECK_GE(build_options_.device_ordinal(), 0) + << "Must have a valid device ordinal that the executable was built for."; +} tensorflow::Status LocalExecutable::ValidateExecutionOptions( const tensorflow::gtl::ArraySlice arguments, - const ExecutableRunOptions& options, const Backend& backend) { + const ExecutableRunOptions& run_options, const Backend& backend) { const ComputationLayout& computation_layout = executable_->module_config().entry_computation_layout(); @@ -78,25 +64,26 @@ 
tensorflow::Status LocalExecutable::ValidateExecutionOptions( } for (int i = 0; i < arguments.size(); ++i) { if (!computation_layout.parameter_layout(i).MatchesLayoutInShape( - arguments[i]->shape())) { - return InvalidArgument( - "argument does not match shape or layout of computation parameter " - "%d: expected %s, got %s", + arguments[i]->on_host_shape())) { + return InvalidParameterArgument( + executable_.get(), i, + "Argument does not match shape or layout of computation parameter " + "%d: want %s, got %s", i, ShapeUtil::HumanString(computation_layout.parameter_layout(i).shape()) .c_str(), - ShapeUtil::HumanString(arguments[i]->shape()).c_str()); + ShapeUtil::HumanString(arguments[i]->on_host_shape()).c_str()); } } - if (options.stream() != nullptr) { - if (!options.stream()->ok()) { + if (run_options.stream() != nullptr) { + if (!run_options.stream()->ok()) { return InvalidArgument("stream is uninitialized or in an error state"); } // Check stream matches service platform. const se::Platform* stream_platform = - options.stream()->parent()->platform(); + run_options.stream()->parent()->platform(); if (stream_platform != backend_->platform()) { return InvalidArgument( "stream is for platform %s, but service targets platform %s", @@ -106,7 +93,7 @@ tensorflow::Status LocalExecutable::ValidateExecutionOptions( // Cannot specify device_ordinal with a stream. The stream determines these // values. - if (options.device_ordinal() != -1) { + if (run_options.device_ordinal() != -1) { return InvalidArgument( "cannot set both device ordinal and stream options in " "ExecutableRunOptions; the stream determines the device ordinal"); @@ -115,34 +102,34 @@ tensorflow::Status LocalExecutable::ValidateExecutionOptions( // Verify that the device the executable was built for is equivalent to the // device it will run on. - int run_device_ordinal = options.device_ordinal() == -1 + int run_device_ordinal = run_options.device_ordinal() == -1 ? 
backend_->default_device_ordinal() - : options.device_ordinal(); - TF_ASSIGN_OR_RETURN( - bool devices_equivalent, - backend_->devices_equivalent(run_device_ordinal, build_device_ordinal_)); + : run_options.device_ordinal(); + TF_ASSIGN_OR_RETURN(bool devices_equivalent, + backend_->devices_equivalent( + run_device_ordinal, build_options_.device_ordinal())); if (!devices_equivalent) { TF_ASSIGN_OR_RETURN(se::StreamExecutor * run_executor, backend_->stream_executor(run_device_ordinal)); TF_ASSIGN_OR_RETURN(se::StreamExecutor * build_executor, - backend_->stream_executor(build_device_ordinal_)); + backend_->stream_executor(build_device_ordinal())); return InvalidArgument( "executable is built for device %s of type \"%s\"; cannot run it on " "device %s of type \"%s\"", - backend_->device_name(build_device_ordinal_).c_str(), + backend_->device_name(build_device_ordinal()).c_str(), build_executor->GetDeviceDescription().name().c_str(), backend_->device_name(run_device_ordinal).c_str(), run_executor->GetDeviceDescription().name().c_str()); } - if (!options.allocator()) { + if (!run_options.allocator()) { return InvalidArgument("an allocator must be provided to ExecuteLocally"); } - if (options.allocator()->platform() != backend.platform()) { + if (run_options.allocator()->platform() != backend.platform()) { return InvalidArgument( "allocator platform (%s) does not match service platform (%s)", - options.allocator()->platform()->Name().c_str(), + run_options.allocator()->platform()->Name().c_str(), backend.platform()->Name().c_str()); } @@ -151,23 +138,22 @@ tensorflow::Status LocalExecutable::ValidateExecutionOptions( StatusOr> LocalExecutable::Run( const tensorflow::gtl::ArraySlice arguments, - const ExecutableRunOptions& options) { - TF_RETURN_IF_ERROR(ValidateExecutionOptions(arguments, options, *backend_)); - - ExecutableRunOptions actual_options = options; + ExecutableRunOptions run_options) { + TF_RETURN_IF_ERROR( + ValidateExecutionOptions(arguments, run_options, 
*backend_)); Backend::StreamPtr stream; - if (options.stream() == nullptr) { + if (run_options.stream() == nullptr) { // NB! The lifetime of `stream` needs to match the lifetime of // `actual_options` (otherwise we will end up using a returned stream in // ExecuteOnStreamWrapper), which is why it isn't declared in the inner "if" // scope. TF_ASSIGN_OR_RETURN( - stream, BorrowStreamForDevice(options.device_ordinal(), backend_)); - actual_options.set_stream(stream.get()); + stream, BorrowStreamForDevice(run_options.device_ordinal(), backend_)); + run_options.set_stream(stream.get()); } - if (options.allocator() == nullptr) { - actual_options.set_allocator(backend_->memory_allocator()); + if (run_options.allocator() == nullptr) { + run_options.set_allocator(backend_->memory_allocator()); } // For local client execution on CPU backends: @@ -176,7 +162,7 @@ StatusOr> LocalExecutable::Run( // *) The thread pool used for XLA CPU ops is from // backend_->eigen_intra_op_thread_pool(). ServiceExecutableRunOptions service_options( - actual_options, backend_->StreamBorrower(), + run_options, backend_->StreamBorrower(), backend_->eigen_intra_op_thread_pool()); if (executable_->dumping()) { @@ -184,10 +170,9 @@ StatusOr> LocalExecutable::Run( } TF_ASSIGN_OR_RETURN( std::unique_ptr result, - executable_->ExecuteOnStreamWrapper>( - &service_options, options.execution_profile(), arguments)); - return ScopedShapedBuffer::MakeScoped(result.get(), - actual_options.allocator()); + executable_->ExecuteOnStreamWrapper( + &service_options, run_options.execution_profile(), arguments)); + return ScopedShapedBuffer::MakeScoped(result.get(), run_options.allocator()); } StatusOr> LocalExecutable::ExecuteAndDump( @@ -263,16 +248,19 @@ StatusOr> LocalClient::Compile( const Computation& computation, const tensorflow::gtl::ArraySlice argument_layouts, const ExecutableBuildOptions& options) { - int device_ordinal = options.device_ordinal() == -1 - ? 
default_device_ordinal() - : options.device_ordinal(); - TF_ASSIGN_OR_RETURN(std::unique_ptr executable, - local_service_->CompileExecutable( - computation.handle(), argument_layouts, - options.result_layout(), device_ordinal)); + ExecutableBuildOptions updated_options = options; + if (options.device_ordinal() == -1) { + updated_options.set_device_ordinal(default_device_ordinal()); + VLOG(3) << "Set device ordinal to default value of: " + << updated_options.device_ordinal(); + } + TF_ASSIGN_OR_RETURN( + std::unique_ptr executable, + local_service_->CompileExecutable(computation.handle(), argument_layouts, + updated_options)); return WrapUnique(new LocalExecutable(std::move(executable), local_service_->mutable_backend(), - device_ordinal, options)); + updated_options)); } StatusOr> @@ -281,13 +269,9 @@ LocalClient::LiteralToShapedBuffer(const Literal& literal, int device_ordinal, if (allocator == nullptr) { allocator = backend().memory_allocator(); } - TF_ASSIGN_OR_RETURN( - auto scoped_buffer, - ScopedShapedBuffer::Allocate( - literal.shape(), allocator, device_ordinal, - [this](const Shape& shape) { - return backend().transfer_manager()->GetByteSizeRequirement(shape); - })); + TF_ASSIGN_OR_RETURN(auto scoped_buffer, + backend().transfer_manager()->AllocateScopedShapedBuffer( + literal.shape(), allocator, device_ordinal)); TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, backend().stream_executor(device_ordinal)); TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( @@ -322,4 +306,8 @@ StatusOr> LocalClient::TransferFromOutfeedLocal( return std::move(literal); } +StatusOr LocalClient::ReplicaNumberToDeviceOrdinal(int replica_number) { + return local_service_->ReplicaNumberToDeviceOrdinal(replica_number); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index 3ca0d2ef5513cfb6b0dbfbc63b311f81a318356e..b52a30f5a0b92e0094e6b0de3241c10a5a909cad 100644 --- 
a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client.h" #include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/executable_run_options.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -33,39 +34,13 @@ limitations under the License. namespace xla { -// Class containing options for building an LocalExecutable with -// LocalClient::Compile. -class ExecutableBuildOptions { - public: - // If set, this is the device to build the computation for. Valid - // device_ordinal values are: 0 to # of devices - 1. These values are - // identical to the device ordinal values used by StreamExecutor. The built - // executable will be executable on any device equivalent to the specified - // device as determined by Backend::devices_equivalent(). A value of -1 - // indicates this option has not been set. - ExecutableBuildOptions& set_device_ordinal(int device_ordinal); - int device_ordinal() const; - - // If set, this specifies the layout of the result of the computation. If not - // set, the service will chose the layout of the result. A Shape is used to - // store the layout to accommodate tuple result shapes. A value of nullptr - // indicates the option has not been set. - ExecutableBuildOptions& set_result_layout(const Shape& shape_with_layout); - const Shape* result_layout() const; - - private: - int device_ordinal_ = -1; - Shape result_layout_; - bool result_layout_set_ = false; -}; - class LocalExecutable { public: // Run the compiled computation with the given arguments and options and // return the result. 
StatusOr> Run( const tensorflow::gtl::ArraySlice arguments, - const ExecutableRunOptions& options); + ExecutableRunOptions run_options); // Return the layout (contained in a shape) of the result produced by the // computation. @@ -88,8 +63,7 @@ class LocalExecutable { // Constructor invoked by LocalClient. LocalExecutable(std::unique_ptr executable, Backend* backend, - int device_ordinal, - const ExecutableBuildOptions& build_options); + ExecutableBuildOptions build_options); // Validates that the given arguments and options satisfy various constraints // of the computation. @@ -117,19 +91,19 @@ class LocalExecutable { StatusOr> LiteralFromShapedBuffer( const ShapedBuffer& shaped_buffer); + // The ordinal of the device which this executable was compiled for. The + // executable can run on all equivalent devices (as determined by + // Backend::devices_equivalent). + int build_device_ordinal() const { return build_options_.device_ordinal(); } + // Compiled computation. std::unique_ptr executable_; // Execution backend. - Backend* backend_; - - // The ordinal of the device which this executable was compiled for. The - // executable can run on all equivalent devices (as determined by - // Backend::devices_equivalent). - int build_device_ordinal_; + Backend* backend_ = nullptr; // Options used to build the executable. - const ExecutableBuildOptions& build_options_; + const ExecutableBuildOptions build_options_; }; // An XLA Client specialization for use when the client and service run in @@ -176,6 +150,13 @@ class LocalClient : public Client { StatusOr> TransferFromOutfeedLocal( const Shape& shape, int device_ordinal); + // Returns the device ordinal that corresponds to the given replica number. + // + // This returns an error if there is not a one-to-one correspondence of + // replicas to device ordinals, but is useful as a short term mechanism for + // the "easy" case where a single replica is a single device. 
+ StatusOr ReplicaNumberToDeviceOrdinal(int replica_number); + // Returns the platform that the underlying service targets. perftools::gputools::Platform* platform() const; diff --git a/tensorflow/compiler/xla/client/sharding_builder.cc b/tensorflow/compiler/xla/client/sharding_builder.cc new file mode 100644 index 0000000000000000000000000000000000000000..176802b33ef824a1f898255a19e44def3c1fc982 --- /dev/null +++ b/tensorflow/compiler/xla/client/sharding_builder.cc @@ -0,0 +1,76 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/sharding_builder.h" + +namespace xla { +namespace sharding_builder { + +OpSharding Replicate() { + OpSharding result; + result.set_type(OpSharding::Type::OpSharding_Type_REPLICATED); + return result; +} + +OpSharding AssignDevice(int device) { + OpSharding result; + result.set_type(OpSharding::Type::OpSharding_Type_MAXIMAL); + result.add_tile_assignment_dimensions(1); + result.add_tile_assignment_devices(device); + return result; +} + +OpSharding Tile(const Shape& tile_shape, + const TileAssignment& tile_assignment) { + OpSharding result; + result.set_type(OpSharding::Type::OpSharding_Type_OTHER); + *result.mutable_tile_shape() = tile_shape; + for (int64 dim : tile_assignment.dimensions()) { + result.add_tile_assignment_dimensions(dim); + } + for (uint32 device : tile_assignment) { + result.add_tile_assignment_devices(device); + } + return result; +} + +OpSharding Tile1D(const Shape& tile_shape, int64 num_tiles) { + OpSharding result; + result.set_type(OpSharding::Type::OpSharding_Type_OTHER); + + CHECK_EQ(ShapeUtil::Rank(tile_shape), 1); + std::vector dimensions(1, num_tiles); + *result.mutable_tile_shape() = tile_shape; + auto& tile_dimension = + (*result.mutable_tile_shape()->mutable_dimensions())[0]; + tile_dimension = CeilOfRatio(static_cast(tile_dimension), num_tiles); + result.add_tile_assignment_dimensions(num_tiles); + for (int64 i = 0; i < num_tiles; ++i) { + result.add_tile_assignment_devices(i); + } + return result; +} + +OpSharding Tuple(const ShapeTree& shardings) { + OpSharding result; + result.set_type(OpSharding::Type::OpSharding_Type_TUPLE); + for (const auto& index_to_sharding : shardings.leaves()) { + *result.add_tuple_shardings() = index_to_sharding.second; + } + return result; +} + +} // namespace sharding_builder +} // namespace xla diff --git a/tensorflow/compiler/xla/client/sharding_builder.h 
b/tensorflow/compiler/xla/client/sharding_builder.h new file mode 100644 index 0000000000000000000000000000000000000000..34763e54d946690289ff42a7712b980168933eee --- /dev/null +++ b/tensorflow/compiler/xla/client/sharding_builder.h @@ -0,0 +1,59 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_CLIENT_SHARDING_BUILDER_H_ +#define TENSORFLOW_COMPILER_XLA_CLIENT_SHARDING_BUILDER_H_ + +#include + +#include "tensorflow/compiler/xla/array.h" +#include "tensorflow/compiler/xla/shape_tree.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace sharding_builder { +// A shaped array used to describe the assignment of tiles to devices. +using TileAssignment = Array; + +// Creates a replicated sharding - replicate a tensor on every device. +OpSharding Replicate(); + +// Creates a sharding that assigns a tensor to just one device. +OpSharding AssignDevice(int device); + +// Creates a tiled sharding with the given tile shape and assignment of tiles +// to devices. +// +// If tile_shape is not evenly divisible by the number of devices in +// tile_assignment, operations behave as if implicit padding had been inserted. +// The value of this padding is undefined. 
+OpSharding Tile(const Shape& tile_shape, const TileAssignment& tile_assignment); + +// Creates a sharding in one dimension, with the given tile shape which must +// be rank 1 and using devices [0..num_tiles). +// +// This is simply a convenience wrapper for Tile(). +OpSharding Tile1D(const Shape& tile_shape, int64 num_tiles); + +// Creates a tuple sharding from the given ShapeTree of element shardings. +OpSharding Tuple(const ShapeTree& shardings); + +} // namespace sharding_builder +} // namespace xla + +#endif diff --git a/tensorflow/compiler/xla/executable_run_options.cc b/tensorflow/compiler/xla/executable_run_options.cc index 33d5b6f1d4d15d5143a3421c87eab9b7a7d11345..392ad9010ab81923a089c7b00a79ddc281af92bb 100644 --- a/tensorflow/compiler/xla/executable_run_options.cc +++ b/tensorflow/compiler/xla/executable_run_options.cc @@ -83,7 +83,7 @@ ExecutableRunOptions& ExecutableRunOptions::set_device_assignment( return *this; } -DeviceAssignment* ExecutableRunOptions::device_assignment() const { +const DeviceAssignment* ExecutableRunOptions::device_assignment() const { return device_assignment_; } diff --git a/tensorflow/compiler/xla/executable_run_options.h b/tensorflow/compiler/xla/executable_run_options.h index deb3ddb203d263d25bef0499a8a53a6098d0de0c..d4fcbf0493c936ebcd0639a432e56b62ee15672c 100644 --- a/tensorflow/compiler/xla/executable_run_options.h +++ b/tensorflow/compiler/xla/executable_run_options.h @@ -82,7 +82,7 @@ class ExecutableRunOptions { ExecutableRunOptions& set_device_assignment( DeviceAssignment* device_assignment); - DeviceAssignment* device_assignment() const; + const DeviceAssignment* device_assignment() const; private: DeviceMemoryAllocator* allocator_ = nullptr; diff --git a/tensorflow/compiler/xla/execution_options_util.h b/tensorflow/compiler/xla/execution_options_util.h index 562da78e837ea6c4a01f0d1170797340fd421ad8..a8ca27ec8dfdc01267ccc9efa6c39093c43d4e2d 100644 --- a/tensorflow/compiler/xla/execution_options_util.h +++ 
b/tensorflow/compiler/xla/execution_options_util.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_EXECUTION_OPTIONS_UTIL_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_EXECUTION_OPTIONS_UTIL_H_ +#ifndef TENSORFLOW_COMPILER_XLA_EXECUTION_OPTIONS_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_EXECUTION_OPTIONS_UTIL_H_ #include "tensorflow/compiler/xla/xla.pb.h" @@ -26,4 +26,4 @@ ExecutionOptions CreateDefaultExecutionOptions(); } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_EXECUTION_OPTIONS_UTIL_H_ +#endif // TENSORFLOW_COMPILER_XLA_EXECUTION_OPTIONS_UTIL_H_ diff --git a/tensorflow/compiler/xla/index_util.cc b/tensorflow/compiler/xla/index_util.cc index 76c0168f370ff1f0749759705b7ecff359a80341..ffd1fb79e986f82e1c2721f0eefbf3b4c0838e41 100644 --- a/tensorflow/compiler/xla/index_util.cc +++ b/tensorflow/compiler/xla/index_util.cc @@ -78,7 +78,7 @@ namespace xla { int64 scale = 1; int64 linear_index = 0; bool first = true; - for (auto dimension : shape.layout().minor_to_major()) { + for (auto dimension : LayoutUtil::MinorToMajor(shape)) { if (first) { // Avoid two multiplies on the first loop iteration linear_index = multi_index[dimension]; @@ -110,7 +110,7 @@ namespace xla { // Accumulated product D{L(0)} * D{L(1)} * ... 
int64 divisor = 1; - for (auto dimension : shape.layout().minor_to_major()) { + for (auto dimension : LayoutUtil::MinorToMajor(shape)) { multi_index[dimension] = (linear_index / divisor) % shape.dimensions(dimension); divisor *= shape.dimensions(dimension); @@ -133,21 +133,49 @@ namespace xla { /* static */ int64 IndexUtil::GetDimensionStride(const Shape& shape, int64 dimension) { - const Layout& layout = shape.layout(); - int64 pdim_size = layout.padded_dimensions_size(); + int64 pdim_size = LayoutUtil::PaddedDimensions(shape).size(); int64 stride = 1; DCHECK(pdim_size == 0 || pdim_size == shape.dimensions_size()); - for (auto dim : layout.minor_to_major()) { + for (auto dim : LayoutUtil::MinorToMajor(shape)) { if (dim == dimension) { break; } if (pdim_size == 0) { stride *= shape.dimensions(dim); } else { - stride *= layout.padded_dimensions(dim); + stride *= LayoutUtil::PaddedDimension(shape, dim); } } return stride; } +/* static */ bool IndexUtil::IndexInBounds( + const Shape& shape, tensorflow::gtl::ArraySlice index) { + int64 rank = ShapeUtil::Rank(shape); + if (rank != index.size()) { + return false; + } + for (int64 d = 0; d < rank; ++d) { + if (index[d] >= shape.dimensions(d)) { + return false; + } + } + return true; +} + +/* static */ int IndexUtil::CompareIndices( + tensorflow::gtl::ArraySlice lhs, + tensorflow::gtl::ArraySlice rhs) { + int64 rank = lhs.size(); + CHECK_EQ(rhs.size(), rank); + for (int64 dim = 0; dim < rank; ++dim) { + if (lhs[dim] < rhs[dim]) { + return -1; + } else if (lhs[dim] > rhs[dim]) { + return 1; + } + } + return 0; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/index_util.h b/tensorflow/compiler/xla/index_util.h index c9838966a5b67397eb5fc4afe3ab9d98e82eb2b1..142006f2626e83d3254f2de65fc28fd5d6694e53 100644 --- a/tensorflow/compiler/xla/index_util.h +++ b/tensorflow/compiler/xla/index_util.h @@ -37,7 +37,7 @@ class IndexUtil { static int64 MultidimensionalIndexToLinearIndex( const Shape& shape, 
tensorflow::gtl::ArraySlice multi_index); - // Coverts a linear index into multidimensional index (eg {x, y, z}) based on + // Converts a linear index into multidimensional index (eg {x, y, z}) based on // the shape and its layout. The first index in the returned multidimensional // index is dimension 0. static std::vector LinearIndexToMultidimensionalIndex( @@ -69,6 +69,18 @@ class IndexUtil { // sizeof(dimension(3)) * sizeof(dimension(2)) == 4 * 10 static int64 GetDimensionStride(const Shape& shape, int64 dimension); + // Returns true iff the given multi-index is contained in the bounds for the + // shape. + static bool IndexInBounds(const Shape& shape, + tensorflow::gtl::ArraySlice index); + + // Compares the given indices in lexicographic order. lhs[0] and rhs[0] are + // compared first, and lhs[rank-1] and rhs[rank-1] last. If lhs is smaller, + // then -1 is returned. If lhs is larger, then 1 is returned. Otherwise, 0 is + // returned. + static int CompareIndices(tensorflow::gtl::ArraySlice lhs, + tensorflow::gtl::ArraySlice rhs); + private: TF_DISALLOW_COPY_AND_ASSIGN(IndexUtil); }; diff --git a/tensorflow/compiler/xla/iterator_util.h b/tensorflow/compiler/xla/iterator_util.h index a39999705eddc5728dce028dab64b7358395757e..a8bb8c7a7e6784e555f4e9dad73ecc78c668ac42 100644 --- a/tensorflow/compiler/xla/iterator_util.h +++ b/tensorflow/compiler/xla/iterator_util.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_ITERATOR_UTIL_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_ITERATOR_UTIL_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ITERATOR_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_ITERATOR_UTIL_H_ #include #include @@ -95,4 +95,4 @@ UnwrappingIterator MakeUnwrappingIterator(NestedIter iter) { } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_ITERATOR_UTIL_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_ITERATOR_UTIL_H_ diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index 5c2cc2a7a99cc51ded3d98c9dd5903e4b3078548..fdc4bbdd8b162b7115788e267c2a53e73c186123 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -57,17 +57,26 @@ void SetDefaultLayoutToContainer( /* static */ Layout LayoutUtil::MakeLayout( tensorflow::gtl::ArraySlice minor_to_major) { Layout layout; + layout.set_format(DENSE); for (int64 dimension_number : minor_to_major) { layout.add_minor_to_major(dimension_number); } return layout; } +/* static */ Layout LayoutUtil::MakeSparseLayout(int64 max_sparse_elements) { + Layout layout; + layout.set_format(SPARSE); + layout.set_max_sparse_elements(max_sparse_elements); + return layout; +} + namespace { // Internal helper that creates a default layout for an array of the given rank. 
Layout CreateDefaultLayoutForRank(int64 rank) { Layout layout; + layout.set_format(DENSE); tensorflow::protobuf::RepeatedField* minor_to_major = layout.mutable_minor_to_major(); minor_to_major->Resize(rank, 0); @@ -105,7 +114,11 @@ Layout CreateDefaultLayoutForRank(int64 rank) { for (auto& element_shape : *shape->mutable_tuple_shapes()) { SetToDefaultLayout(&element_shape); } + shape->clear_layout(); + } else if (ShapeUtil::IsOpaque(*shape)) { + shape->clear_layout(); } else { + shape->mutable_layout()->set_format(DENSE); tensorflow::protobuf::RepeatedField* minor_to_major = shape->mutable_layout()->mutable_minor_to_major(); minor_to_major->Resize(shape->dimensions_size(), 0); @@ -137,8 +150,10 @@ Layout CreateDefaultLayoutForRank(int64 rank) { TF_RETURN_IF_ERROR(ValidateLayoutInShape(element_shape)); } return tensorflow::Status::OK(); - } else if (ShapeUtil::Rank(shape) == 0 && !shape.has_layout()) { - // A scalar without a layout is ok. + } else if (ShapeUtil::IsOpaque(shape)) { + if (shape.has_layout()) { + return InvalidArgument("opaque should not have a layout field"); + } return tensorflow::Status::OK(); } else { // Array shape. 
@@ -156,46 +171,59 @@ Layout CreateDefaultLayoutForRank(int64 rank) { return InvalidArgument("a single Layout is not valid for tuple shapes"); } - if (layout.minor_to_major_size() != ShapeUtil::Rank(shape)) { + if (ShapeUtil::IsOpaque(shape)) { + return tensorflow::Status::OK(); + } + + if (layout.format() == INVALID_FORMAT) { return InvalidArgument( - "layout minor_to_major field contains %d elements, " - "but shape is rank %lld: {%s}; shape: %s", - layout.minor_to_major_size(), ShapeUtil::Rank(shape), - tensorflow::str_util::Join(layout.minor_to_major(), ", ").c_str(), - shape.ShortDebugString().c_str()); + "Layout does not have a valid format: layout {%s}, shape {%s}", + layout.ShortDebugString().c_str(), shape.ShortDebugString().c_str()); } - std::vector dimensions_in_layout(ShapeUtil::Rank(shape), false); - for (int64 i = 0; i < ShapeUtil::Rank(shape); ++i) { - int64 dim = layout.minor_to_major(i); - if (dim < 0 || dim >= ShapeUtil::Rank(shape)) { + if (layout.format() == DENSE) { + if (layout.minor_to_major_size() != ShapeUtil::Rank(shape)) { return InvalidArgument( - "layout minor_to_major field has out-of-bounds value: %s", - HumanString(layout).c_str()); + "layout minor_to_major field contains %d elements, " + "but shape is rank %lld: {%s}; shape: %s", + layout.minor_to_major_size(), ShapeUtil::Rank(shape), + tensorflow::str_util::Join(layout.minor_to_major(), ", ").c_str(), + shape.ShortDebugString().c_str()); } - if (dimensions_in_layout[dim]) { - return InvalidArgument( - "layout minor_to_major field has duplicate values: {%s}", - HumanString(layout).c_str()); - } - dimensions_in_layout[dim] = true; - } - if (layout.padded_dimensions_size() > 0) { - if (layout.padded_dimensions_size() != ShapeUtil::Rank(shape)) { - return InvalidArgument( - "layout has %d padded dimensions, but shape is rank %lld", - layout.padded_dimensions_size(), ShapeUtil::Rank(shape)); + std::vector dimensions_in_layout(ShapeUtil::Rank(shape), false); + for (int64 i = 0; i < 
ShapeUtil::Rank(shape); ++i) { + int64 dim = layout.minor_to_major(i); + if (dim < 0 || dim >= ShapeUtil::Rank(shape)) { + return InvalidArgument( + "layout minor_to_major field has out-of-bounds value: %s", + HumanString(layout).c_str()); + } + if (dimensions_in_layout[dim]) { + return InvalidArgument( + "layout minor_to_major field has duplicate values: {%s}", + HumanString(layout).c_str()); + } + dimensions_in_layout[dim] = true; } - for (int i = 0; i < layout.padded_dimensions_size(); ++i) { - if (layout.padded_dimensions(i) < shape.dimensions(i)) { + + if (layout.padded_dimensions_size() > 0) { + if (layout.padded_dimensions_size() != ShapeUtil::Rank(shape)) { return InvalidArgument( - "for dimension %d, dimension padding (%lld) is smaller than " - "the dimension size (%lld) of the shape", - i, layout.padded_dimensions(i), shape.dimensions(i)); + "layout has %d padded dimensions, but shape is rank %lld", + layout.padded_dimensions_size(), ShapeUtil::Rank(shape)); + } + for (int i = 0; i < layout.padded_dimensions_size(); ++i) { + if (layout.padded_dimensions(i) < shape.dimensions(i)) { + return InvalidArgument( + "for dimension %d, dimension padding (%lld) is smaller than " + "the dimension size (%lld) of the shape", + i, layout.padded_dimensions(i), shape.dimensions(i)); + } } } } + return tensorflow::Status::OK(); } @@ -213,12 +241,23 @@ Layout CreateDefaultLayoutForRank(int64 rank) { LayoutUtil::ClearLayout(program_shape->mutable_result()); } +/* static */ bool LayoutUtil::IsDenseArray(const Shape& shape) { + return ShapeUtil::IsArray(shape) && shape.has_layout() && + IsDense(shape.layout()); +} + +/* static */ bool LayoutUtil::IsDense(const Layout& layout) { + return layout.format() == DENSE; +} + /* static */ bool LayoutUtil::IsMonotonicWithDim0Minor(const Layout& layout) { + CHECK(layout.format() == DENSE); return std::is_sorted(layout.minor_to_major().begin(), layout.minor_to_major().end()); } /* static */ bool LayoutUtil::IsMonotonicWithDim0Major(const 
Layout& layout) { + CHECK(layout.format() == DENSE); return std::is_sorted(layout.minor_to_major().begin(), layout.minor_to_major().end(), std::greater()); } @@ -228,6 +267,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { shape.layout().padded_dimensions_size() == 0) { return false; } + CHECK(IsDenseArray(shape)); CHECK_EQ(shape.dimensions_size(), shape.layout().padded_dimensions_size()); for (int64 i = 0; i < shape.dimensions_size(); ++i) { if (shape.layout().padded_dimensions(i) > shape.dimensions(i)) { @@ -237,15 +277,46 @@ Layout CreateDefaultLayoutForRank(int64 rank) { return false; } +/* static */ tensorflow::gtl::ArraySlice LayoutUtil::PaddedDimensions( + const Shape& shape) { + CHECK(IsDenseArray(shape)); + return AsInt64Slice(shape.layout().padded_dimensions()); +} + +/* static */ int64 LayoutUtil::PaddedDimension(const Shape& shape, + int64 index) { + CHECK(IsDenseArray(shape)); + return shape.layout().padded_dimensions(index); +} + +/* static */ PaddingValue LayoutUtil::GetPaddingValue(const Shape& shape) { + CHECK(IsDenseArray(shape)); + return shape.layout().padding_value(); +} + +/* static */ bool LayoutUtil::IsSparseArray(const Shape& shape) { + return ShapeUtil::IsArray(shape) && shape.has_layout() && + IsSparse(shape.layout()); +} + +/* static */ bool LayoutUtil::IsSparse(const Layout& layout) { + return layout.format() == SPARSE; +} + +/* static */ int64 LayoutUtil::MaxSparseElements(const Layout& layout) { + CHECK(IsSparse(layout)); + return layout.max_sparse_elements(); +} + /* static */ bool LayoutUtil::HasLayout(const Shape& shape) { if (ShapeUtil::IsTuple(shape)) { // Tuple shape: all subshapes must have a layout. return std::all_of(shape.tuple_shapes().begin(), shape.tuple_shapes().end(), [](const Shape& s) { return HasLayout(s); }); + } else if (ShapeUtil::IsOpaque(shape)) { + return true; } - // A scalar trivially always has a layout. 
- return (ShapeUtil::Rank(shape) == 0 || - (shape.has_layout() && (shape.layout().minor_to_major_size() > 0))); + return shape.has_layout() && shape.layout().format() != INVALID_FORMAT; } /* static */ bool LayoutUtil::HasLayout(const ProgramShape& program_shape) { @@ -261,6 +332,18 @@ Layout CreateDefaultLayoutForRank(int64 rank) { return protobuf_util::ProtobufEquals(lhs, rhs); } +/* static */ tensorflow::gtl::ArraySlice LayoutUtil::MinorToMajor( + const Shape& shape) { + CHECK(IsDenseArray(shape)); + return AsInt64Slice(shape.layout().minor_to_major()); +} + +/* static */ tensorflow::gtl::ArraySlice LayoutUtil::MinorToMajor( + const Layout& layout) { + CHECK(layout.format() == DENSE); + return AsInt64Slice(layout.minor_to_major()); +} + /* static */ int64 LayoutUtil::Major(const Layout& layout, int64 physical_dimension_number) { CHECK_LE(0, physical_dimension_number); @@ -271,6 +354,7 @@ Layout CreateDefaultLayoutForRank(int64 rank) { /* static */ int64 LayoutUtil::Minor(const Layout& layout, int64 physical_dimension_number) { + CHECK_EQ(layout.format(), DENSE); CHECK_LE(0, physical_dimension_number); CHECK_LT(physical_dimension_number, layout.minor_to_major_size()); return layout.minor_to_major(physical_dimension_number); @@ -287,6 +371,11 @@ Layout CreateDefaultLayoutForRank(int64 rank) { } /* static */ string LayoutUtil::HumanString(const Layout& layout) { + if (IsSparse(layout)) { + return tensorflow::strings::StrCat("sparse{", layout.max_sparse_elements(), + "}"); + } + CHECK(IsDense(layout)); return tensorflow::strings::StrCat( "{", tensorflow::str_util::Join(layout.minor_to_major(), ","), "}"); } @@ -356,6 +445,7 @@ tensorflow::Status LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, /* static */ bool LayoutUtil::AreDimensionsConsecutive( const Layout& layout, tensorflow::gtl::ArraySlice dims) { + CHECK(IsDense(layout)); std::vector positions_in_layout; for (int64 dim : dims) { positions_in_layout.push_back( @@ -370,4 +460,9 @@ tensorflow::Status 
LayoutUtil::CopyLayoutBetweenShapes(const Shape& src, return true; } +std::ostream& operator<<(std::ostream& out, const Layout& layout) { + out << LayoutUtil::HumanString(layout); + return out; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index bc42e222292933be35e82d1fe50802e8830d16b3..6c54eb2201b66a4a0c5695bceb14bb2367133935 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -36,6 +36,10 @@ class LayoutUtil { // convenience function for protobuf construction.) static Layout MakeLayout(tensorflow::gtl::ArraySlice minor_to_major); + // Creates a sparse layout with the given maximum number of elements. (This is + // a convenience function for protobuf construction.) + static Layout MakeSparseLayout(int64 max_sparse_elements); + // Returns default layout for the given shape. static Layout GetDefaultLayoutForShape(const Shape& shape); @@ -71,6 +75,12 @@ class LayoutUtil { // Clears the layout on all Shapes within the given ProgramShape. static void ClearLayout(ProgramShape* program_shape); + // Returns whether the given Shape is an array and has a dense format layout. + static bool IsDenseArray(const Shape& shape); + + // Returns whether the given Layout has a dense format. + static bool IsDense(const Layout& layout); + // Returns whether the layout is monotonic and dim 0 is minor in the layout. // * R0 and R1: this is always trivially true. // * R2+: equivalent to column-major. Dimension 0 is the minor, dimension 1 is @@ -88,6 +98,30 @@ class LayoutUtil { // dimension size). static bool IsPadded(const Shape& shape); + // Returns the padded_dimensions array for the given Shape. Requires that the + // shape is an array and has a dense layout. + static tensorflow::gtl::ArraySlice PaddedDimensions( + const Shape& shape); + + // Returns the given index of the padded_dimensions array for the given Shape. 
+ // Requires that the shape is an array and has a dense layout. + static int64 PaddedDimension(const Shape& shape, int64 index); + + // Returns the padding_value for the given Shape. Requires that the shape is + // an array and has a dense layout. + static PaddingValue GetPaddingValue(const Shape& shape); + + // Returns whether the given Shape is an array (i.e. not a tuple) and has a + // sparse format layout. + static bool IsSparseArray(const Shape& shape); + + // Returns whether the given Layout has a sparse format. + static bool IsSparse(const Layout& layout); + + // Returns the maximum number of elements that can be stored in a sparse + // layout. + static int64 MaxSparseElements(const Layout& layout); + // Returns whether the given shape has a layout. For tuple shapes, true is // returned only if all elements have layouts. static bool HasLayout(const Shape& shape); @@ -98,7 +132,12 @@ class LayoutUtil { // Returns whether lhs and rhs are identical. static bool Equal(const Layout& lhs, const Layout& rhs); - // Major(0) is the most major logical dimension number, major(1) is the + // Returns the minor_to_major array for the given Shape. Requires that the + // shape is an array and has a dense layout. + static tensorflow::gtl::ArraySlice MinorToMajor(const Shape& shape); + static tensorflow::gtl::ArraySlice MinorToMajor(const Layout& layout); + + // Major(0) is the most major logical dimension number, Major(1) is the // second-most-major logical dimension number and so on. 
// // This can be used to translate physical dimension numbers to logical @@ -160,6 +199,8 @@ class LayoutUtil { TF_DISALLOW_COPY_AND_ASSIGN(LayoutUtil); }; +std::ostream& operator<<(std::ostream& out, const Layout& layout); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_LAYOUT_UTIL_H_ diff --git a/tensorflow/compiler/xla/layout_util_test.cc b/tensorflow/compiler/xla/layout_util_test.cc index 331bb9afa94e9e7c97d9c880dbac31c60ac0da18..4fd1d818e3e3b417eee9f6b14bb598bfb9480c6e 100644 --- a/tensorflow/compiler/xla/layout_util_test.cc +++ b/tensorflow/compiler/xla/layout_util_test.cc @@ -14,6 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/layout_util.h" + +#include + #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -30,6 +33,14 @@ class LayoutUtilTest : public ::testing::Test { *shape.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major); return shape; } + + Shape MakeShapeWithSparseLayout(PrimitiveType element_type, + tensorflow::gtl::ArraySlice dimensions, + int64 max_sparse_elements) { + Shape shape = ShapeUtil::MakeShape(element_type, dimensions); + *shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements); + return shape; + } }; TEST_F(LayoutUtilTest, TupleLayoutComparison) { @@ -81,6 +92,29 @@ TEST_F(LayoutUtilTest, CopyLayoutArray) { EXPECT_FALSE(dst.has_layout()); } +TEST_F(LayoutUtilTest, CopyLayoutSparse) { + Shape src = MakeShapeWithSparseLayout(F32, {2, 3}, 2); + Shape dst = MakeShapeWithLayout(F32, {2, 3}, {1, 0}); + + EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + + // Should work if destination has no layout. 
+ dst.clear_layout(); + EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + + // If source is cleared, then destination should be cleared. + src.clear_layout(); + EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_TRUE(dst.has_layout()); + EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_FALSE(dst.has_layout()); +} + TEST_F(LayoutUtilTest, CopyLayoutTuple) { Shape src = ShapeUtil::MakeTupleShape( {MakeShapeWithLayout(F32, {2, 3}, {0, 1}), @@ -100,6 +134,25 @@ TEST_F(LayoutUtilTest, CopyLayoutTuple) { EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); } +TEST_F(LayoutUtilTest, CopyLayoutTupleSparse) { + Shape src = ShapeUtil::MakeTupleShape( + {MakeShapeWithSparseLayout(F32, {2, 3}, 4), + MakeShapeWithSparseLayout(F32, {42, 123}, 4), + ShapeUtil::MakeTupleShape( + {MakeShapeWithLayout(F32, {}, {}), + MakeShapeWithSparseLayout(F32, {1, 2, 3}, 6)})}); + Shape dst = ShapeUtil::MakeTupleShape( + {MakeShapeWithLayout(F32, {2, 3}, {1, 0}), + MakeShapeWithLayout(F32, {42, 123}, {1, 0}), + ShapeUtil::MakeTupleShape( + {MakeShapeWithLayout(F32, {}, {}), + MakeShapeWithLayout(F32, {1, 2, 3}, {1, 2, 0})})}); + + EXPECT_FALSE(LayoutUtil::LayoutsInShapesEqual(src, dst)); + EXPECT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); +} + TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleSameRank) { Shape src = MakeShapeWithLayout(F32, {123, 42, 7}, {2, 0, 1}); Shape dst = MakeShapeWithLayout(F32, {2, 3, 5}, {1, 0}); @@ -107,6 +160,13 @@ TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleSameRank) { EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); } +TEST_F(LayoutUtilTest, CopyLayoutSparseNotCompatibleSameRank) { + Shape src = MakeShapeWithSparseLayout(F32, {123, 42, 7}, 6); + Shape dst = 
MakeShapeWithLayout(F32, {2, 3, 5}, {1, 0}); + ASSERT_IS_OK(LayoutUtil::CopyLayoutBetweenShapes(src, &dst)); + EXPECT_TRUE(LayoutUtil::LayoutsInShapesEqual(src, dst)); +} + TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleDifferentRank) { Shape src = MakeShapeWithLayout(F32, {123, 42, 7}, {2, 0, 1}); Shape dst = MakeShapeWithLayout(F32, {2, 3}, {1, 0}); @@ -116,6 +176,15 @@ TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleDifferentRank) { ::testing::ContainsRegex("cannot copy layout from shape")); } +TEST_F(LayoutUtilTest, CopyLayoutSparseNotCompatibleDifferentRank) { + Shape src = MakeShapeWithLayout(F32, {123, 42, 7}, {2, 0, 1}); + Shape dst = MakeShapeWithSparseLayout(F32, {2, 3}, 4); + auto status = LayoutUtil::CopyLayoutBetweenShapes(src, &dst); + EXPECT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + ::testing::ContainsRegex("cannot copy layout from shape")); +} + TEST_F(LayoutUtilTest, CopyLayoutNotCompatibleTuple) { Shape src = ShapeUtil::MakeTupleShape({MakeShapeWithLayout(F32, {2, 3}, {0, 1}), @@ -221,5 +290,16 @@ TEST_F(LayoutUtilTest, DefaultLayoutGettersMajorToMinor) { ShapeUtil::MakeShape(F32, {10, 20, 30, 15, 25})))); } +TEST_F(LayoutUtilTest, SparseLayoutMaxElements) { + EXPECT_EQ(LayoutUtil::MaxSparseElements(LayoutUtil::MakeSparseLayout(101)), + 101); +} + +TEST_F(LayoutUtilTest, StreamOut) { + std::ostringstream oss; + oss << LayoutUtil::MakeLayout({0, 1, 2}); + EXPECT_EQ(oss.str(), "{0,1,2}"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc index bfafef0a40f55e13ac94b2d1750df25146081784..c8ed3e3a2b009ddffdfb79a9a6ced8d5e736bee6 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc @@ -40,6 +40,10 @@ void SetDebugOptionsDefaults(DebugOptions* flags) { flags->set_xla_cpu_multi_thread_eigen(true); 
flags->set_xla_gpu_cuda_data_dir("./cuda_sdk_lib"); flags->set_xla_eliminate_hlo_implicit_broadcast(true); + + // Set cudnn batchnorm off by default; it does not provide a performance win + // on average. + flags->set_xla_gpu_use_cudnn_batchnorm(false); } // Allocates flag_values and flag_objects; this function must not be called more @@ -96,179 +100,195 @@ void AllocateFlags() { option_proto, reduce_precision_option_value); }; - flag_objects = new std::vector( - {tensorflow::Flag( - "xla_generate_hlo_graph", - flag_values->mutable_xla_generate_hlo_graph(), - "HLO modules matching this regex will be dumped to a .dot file " - "throughout various stages in compilation."), - tensorflow::Flag( - "xla_hlo_graph_addresses", - bool_setter_for(&DebugOptions::set_xla_hlo_graph_addresses), - flag_values->xla_hlo_graph_addresses(), - "With xla_generate_hlo_graph, show addresses of HLO ops in " - "graph dump."), - tensorflow::Flag( - "xla_hlo_graph_path", flag_values->mutable_xla_hlo_graph_path(), - "With xla_generate_hlo_graph, dump the graphs into this path."), - tensorflow::Flag( - "xla_hlo_dump_as_graphdef", - bool_setter_for(&DebugOptions::set_xla_hlo_dump_as_graphdef), - flag_values->xla_hlo_dump_as_graphdef(), - "Dump HLO graphs as TensorFlow GraphDefs."), - tensorflow::Flag( - "xla_hlo_graph_sharding_color", - bool_setter_for(&DebugOptions::set_xla_hlo_graph_sharding_color), - flag_values->xla_hlo_graph_sharding_color(), - "Assign colors based on sharding assignments when generating the " - "HLO graphs."), - tensorflow::Flag( - "xla_hlo_tfgraph_device_scopes", - bool_setter_for(&DebugOptions::set_xla_hlo_tfgraph_device_scopes), - flag_values->xla_hlo_tfgraph_device_scopes(), - "When generating TensorFlow HLO graphs, if the HLO instructions " - "are assigned to a specific device, prefix the name scope with " - "\"devX\" with X being the device ordinal."), - tensorflow::Flag( - "xla_log_hlo_text", flag_values->mutable_xla_log_hlo_text(), - "HLO modules matching this 
regex will be dumped to LOG(INFO)."), - tensorflow::Flag( - "xla_generate_hlo_text_to", - flag_values->mutable_xla_generate_hlo_text_to(), - "Dump all HLO modules as text into the provided directory path."), - tensorflow::Flag( - "xla_enable_fast_math", - bool_setter_for(&DebugOptions::set_xla_enable_fast_math), - flag_values->xla_enable_fast_math(), - "Enable unsafe fast-math optimizations in the compiler; " - "this may produce faster code at the expense of some accuracy."), - tensorflow::Flag( - "xla_llvm_enable_alias_scope_metadata", - bool_setter_for( - &DebugOptions::set_xla_llvm_enable_alias_scope_metadata), - flag_values->xla_llvm_enable_alias_scope_metadata(), - "In LLVM-based backends, enable the emission of " - "!alias.scope metadata in the generated IR."), - tensorflow::Flag( - "xla_llvm_enable_noalias_metadata", - bool_setter_for(&DebugOptions::set_xla_llvm_enable_noalias_metadata), - flag_values->xla_llvm_enable_noalias_metadata(), - "In LLVM-based backends, enable the emission of " - "!noalias metadata in the generated IR."), - tensorflow::Flag( - "xla_llvm_enable_invariant_load_metadata", - bool_setter_for( - &DebugOptions::set_xla_llvm_enable_invariant_load_metadata), - flag_values->xla_llvm_enable_invariant_load_metadata(), - "In LLVM-based backends, enable the emission of " - "!invariant.load metadata in " - "the generated IR."), - tensorflow::Flag( - "xla_llvm_disable_expensive_passes", - bool_setter_for( - &DebugOptions::set_xla_llvm_disable_expensive_passes), - flag_values->xla_llvm_disable_expensive_passes(), - "In LLVM-based backends, disable a custom set of " - "expensive optimization passes."), - tensorflow::Flag( - "xla_backend_optimization_level", - int32_setter_for(&DebugOptions::set_xla_backend_optimization_level), - flag_values->xla_backend_optimization_level(), - "Numerical optimization level for the XLA compiler backend."), - tensorflow::Flag( - "xla_disable_hlo_passes", setter_for_xla_disable_hlo_passes, "", - "Comma-separated list 
of hlo passes to be disabled. These names " - "must exactly match the passes' names; no whitespace around " - "commas."), - tensorflow::Flag( - "xla_embed_ir_in_executable", - bool_setter_for(&DebugOptions::set_xla_embed_ir_in_executable), - flag_values->xla_embed_ir_in_executable(), - "Embed the compiler IR as a string in the executable."), - tensorflow::Flag( - "xla_dump_ir_to", flag_values->mutable_xla_dump_ir_to(), - "Dump the compiler IR into this directory as individual files."), - tensorflow::Flag( - "xla_eliminate_hlo_implicit_broadcast", - bool_setter_for( - &DebugOptions::set_xla_eliminate_hlo_implicit_broadcast), - flag_values->xla_eliminate_hlo_implicit_broadcast(), - "Eliminate implicit broadcasts when lowering user " - "computations to HLO instructions; use explicit " - "broadcast instead."), - tensorflow::Flag( - "xla_cpu_multi_thread_eigen", - bool_setter_for(&DebugOptions::set_xla_cpu_multi_thread_eigen), - flag_values->xla_cpu_multi_thread_eigen(), - "When generating calls to Eigen in the CPU backend, " - "use multi-threaded Eigen mode."), - tensorflow::Flag("xla_gpu_cuda_data_dir", - flag_values->mutable_xla_gpu_cuda_data_dir(), - "If non-empty, speficies a local directory containing " - "ptxas and nvvm libdevice files; otherwise we use " - "those from runfile directories."), - tensorflow::Flag("xla_gpu_ftz", - bool_setter_for(&DebugOptions::set_xla_gpu_ftz), - flag_values->xla_gpu_ftz(), - "If true, flush-to-zero semantics are enabled in the " - "code generated for GPUs."), - tensorflow::Flag( - "xla_gpu_disable_multi_streaming", - bool_setter_for(&DebugOptions::set_xla_gpu_disable_multi_streaming), - flag_values->xla_gpu_disable_multi_streaming(), - "If true, multi-streaming in the GPU backend is disabled."), - tensorflow::Flag( - "xla_dump_hlo_proto_to", - flag_values->mutable_xla_dump_hlo_proto_to(), - "Dump compilation artifacts as proto binary into this directory."), - tensorflow::Flag( - "xla_test_all_output_layouts", - 
bool_setter_for(&DebugOptions::set_xla_test_all_output_layouts), - flag_values->xla_test_all_output_layouts(), - "Let ClientLibraryTestBase::ComputeAndCompare* test " - "all permutations of output layouts. For example, with " - "a 3D shape, all permutations of the set {0, 1, 2} are " - "tried."), - tensorflow::Flag( - "xla_test_all_input_layouts", - bool_setter_for(&DebugOptions::set_xla_test_all_input_layouts), - flag_values->xla_test_all_input_layouts(), - "Let ClientLibraryTestBase::ComputeAndCompare* test " - "all permutations of *input* layouts. For example, for " - "2 input arguments with 2D shape and 4D shape, the " - "computation will run 2! * 4! times for every possible " - "layouts"), - tensorflow::Flag( - "xla_hlo_profile", - bool_setter_for(&DebugOptions::set_xla_hlo_profile), - flag_values->xla_hlo_profile(), - "Instrument the computation to collect per-HLO cycle counts"), - tensorflow::Flag("xla_dump_computations_to", - flag_values->mutable_xla_dump_computations_to(), - "Dump computations that XLA executes into the provided " - "directory path"), - tensorflow::Flag("xla_dump_executions_to", - flag_values->mutable_xla_dump_executions_to(), - "Dump parameters and results of computations that XLA " - "executes into the provided directory path"), - tensorflow::Flag("xla_backend_extra_options", - setter_for_xla_backend_extra_options, "", - "Extra options to pass to a backend; " - "comma-separated list of 'key=val' strings (=val " - "may be omitted); no whitespace around commas."), - tensorflow::Flag("xla_reduce_precision", setter_for_xla_reduce_precision, - "", - "Directions for adding reduce-precision operations. 
" - "Format is 'LOCATION=E,M:OPS;NAMES' where LOCATION is " - "the class of locations in which to insert the " - "operations (e.g., 'OP_OUTPUTS'), E and M are the " - "exponent and matissa bit counts respectively, and " - "OPS and NAMES are comma-separated (no spaces) lists " - "of the operation types and names to which to attach " - "the reduce-precision operations. The NAMES string " - "and its preceding ';' may be omitted. This option " - "may be repeated to define multiple sets of added " - "reduce-precision operations.")}); + flag_objects = new std::vector({ + tensorflow::Flag( + "xla_generate_hlo_graph", + flag_values->mutable_xla_generate_hlo_graph(), + "HLO modules matching this regex will be dumped to a .dot file " + "throughout various stages in compilation."), + tensorflow::Flag( + "xla_hlo_graph_addresses", + bool_setter_for(&DebugOptions::set_xla_hlo_graph_addresses), + flag_values->xla_hlo_graph_addresses(), + "With xla_generate_hlo_graph, show addresses of HLO ops in " + "graph dump."), + tensorflow::Flag( + "xla_hlo_graph_path", flag_values->mutable_xla_hlo_graph_path(), + "With xla_generate_hlo_graph, dump the graphs into this path."), + tensorflow::Flag( + "xla_hlo_dump_as_graphdef", + bool_setter_for(&DebugOptions::set_xla_hlo_dump_as_graphdef), + flag_values->xla_hlo_dump_as_graphdef(), + "Dump HLO graphs as TensorFlow GraphDefs."), + tensorflow::Flag( + "xla_hlo_graph_sharding_color", + bool_setter_for(&DebugOptions::set_xla_hlo_graph_sharding_color), + flag_values->xla_hlo_graph_sharding_color(), + "Assign colors based on sharding assignments when generating the " + "HLO graphs."), + tensorflow::Flag( + "xla_hlo_tfgraph_device_scopes", + bool_setter_for(&DebugOptions::set_xla_hlo_tfgraph_device_scopes), + flag_values->xla_hlo_tfgraph_device_scopes(), + "When generating TensorFlow HLO graphs, if the HLO instructions " + "are assigned to a specific device, prefix the name scope with " + "\"devX\" with X being the device ordinal."), + 
tensorflow::Flag( + "xla_log_hlo_text", flag_values->mutable_xla_log_hlo_text(), + "HLO modules matching this regex will be dumped to LOG(INFO)."), + tensorflow::Flag( + "xla_generate_hlo_text_to", + flag_values->mutable_xla_generate_hlo_text_to(), + "Dump all HLO modules as text into the provided directory path."), + tensorflow::Flag( + "xla_enable_fast_math", + bool_setter_for(&DebugOptions::set_xla_enable_fast_math), + flag_values->xla_enable_fast_math(), + "Enable unsafe fast-math optimizations in the compiler; " + "this may produce faster code at the expense of some accuracy."), + tensorflow::Flag( + "xla_llvm_enable_alias_scope_metadata", + bool_setter_for( + &DebugOptions::set_xla_llvm_enable_alias_scope_metadata), + flag_values->xla_llvm_enable_alias_scope_metadata(), + "In LLVM-based backends, enable the emission of " + "!alias.scope metadata in the generated IR."), + tensorflow::Flag( + "xla_llvm_enable_noalias_metadata", + bool_setter_for(&DebugOptions::set_xla_llvm_enable_noalias_metadata), + flag_values->xla_llvm_enable_noalias_metadata(), + "In LLVM-based backends, enable the emission of " + "!noalias metadata in the generated IR."), + tensorflow::Flag( + "xla_llvm_enable_invariant_load_metadata", + bool_setter_for( + &DebugOptions::set_xla_llvm_enable_invariant_load_metadata), + flag_values->xla_llvm_enable_invariant_load_metadata(), + "In LLVM-based backends, enable the emission of " + "!invariant.load metadata in " + "the generated IR."), + tensorflow::Flag( + "xla_llvm_disable_expensive_passes", + bool_setter_for(&DebugOptions::set_xla_llvm_disable_expensive_passes), + flag_values->xla_llvm_disable_expensive_passes(), + "In LLVM-based backends, disable a custom set of " + "expensive optimization passes."), + tensorflow::Flag( + "xla_backend_optimization_level", + int32_setter_for(&DebugOptions::set_xla_backend_optimization_level), + flag_values->xla_backend_optimization_level(), + "Numerical optimization level for the XLA compiler backend."), + 
tensorflow::Flag( + "xla_disable_hlo_passes", setter_for_xla_disable_hlo_passes, "", + "Comma-separated list of hlo passes to be disabled. These names " + "must exactly match the passes' names; no whitespace around " + "commas."), + tensorflow::Flag( + "xla_embed_ir_in_executable", + bool_setter_for(&DebugOptions::set_xla_embed_ir_in_executable), + flag_values->xla_embed_ir_in_executable(), + "Embed the compiler IR as a string in the executable."), + tensorflow::Flag( + "xla_dump_ir_to", flag_values->mutable_xla_dump_ir_to(), + "Dump the compiler IR into this directory as individual files."), + tensorflow::Flag( + "xla_eliminate_hlo_implicit_broadcast", + bool_setter_for( + &DebugOptions::set_xla_eliminate_hlo_implicit_broadcast), + flag_values->xla_eliminate_hlo_implicit_broadcast(), + "Eliminate implicit broadcasts when lowering user " + "computations to HLO instructions; use explicit " + "broadcast instead."), + tensorflow::Flag( + "xla_cpu_multi_thread_eigen", + bool_setter_for(&DebugOptions::set_xla_cpu_multi_thread_eigen), + flag_values->xla_cpu_multi_thread_eigen(), + "When generating calls to Eigen in the CPU backend, " + "use multi-threaded Eigen mode."), + tensorflow::Flag("xla_gpu_cuda_data_dir", + flag_values->mutable_xla_gpu_cuda_data_dir(), + "If non-empty, speficies a local directory containing " + "ptxas and nvvm libdevice files; otherwise we use " + "those from runfile directories."), + tensorflow::Flag("xla_gpu_ftz", + bool_setter_for(&DebugOptions::set_xla_gpu_ftz), + flag_values->xla_gpu_ftz(), + "If true, flush-to-zero semantics are enabled in the " + "code generated for GPUs."), + tensorflow::Flag( + "xla_gpu_disable_multi_streaming", + bool_setter_for(&DebugOptions::set_xla_gpu_disable_multi_streaming), + flag_values->xla_gpu_disable_multi_streaming(), + "If true, multi-streaming in the GPU backend is disabled."), + tensorflow::Flag( + "xla_dump_optimized_hlo_proto_to", + flag_values->mutable_xla_dump_optimized_hlo_proto_to(), + "Dump Hlo 
after all hlo passes are executed as proto binary into " + "this directory."), + tensorflow::Flag( + "xla_dump_unoptimized_hlo_proto_to", + flag_values->mutable_xla_dump_unoptimized_hlo_proto_to(), + "Dump HLO before any hlo passes are executed as proto binary into " + "this directory."), + tensorflow::Flag("xla_dump_per_pass_hlo_proto_to", + flag_values->mutable_xla_dump_per_pass_hlo_proto_to(), + "Dump HLO after each pass as an HloProto in binary file " + "format into this directory."), + tensorflow::Flag( + "xla_test_all_output_layouts", + bool_setter_for(&DebugOptions::set_xla_test_all_output_layouts), + flag_values->xla_test_all_output_layouts(), + "Let ClientLibraryTestBase::ComputeAndCompare* test " + "all permutations of output layouts. For example, with " + "a 3D shape, all permutations of the set {0, 1, 2} are " + "tried."), + tensorflow::Flag( + "xla_test_all_input_layouts", + bool_setter_for(&DebugOptions::set_xla_test_all_input_layouts), + flag_values->xla_test_all_input_layouts(), + "Let ClientLibraryTestBase::ComputeAndCompare* test " + "all permutations of *input* layouts. For example, for " + "2 input arguments with 2D shape and 4D shape, the " + "computation will run 2! * 4! 
times for every possible " + "layouts"), + tensorflow::Flag( + "xla_hlo_profile", + bool_setter_for(&DebugOptions::set_xla_hlo_profile), + flag_values->xla_hlo_profile(), + "Instrument the computation to collect per-HLO cycle counts"), + tensorflow::Flag("xla_dump_computations_to", + flag_values->mutable_xla_dump_computations_to(), + "Dump computations that XLA executes into the provided " + "directory path"), + tensorflow::Flag("xla_dump_executions_to", + flag_values->mutable_xla_dump_executions_to(), + "Dump parameters and results of computations that XLA " + "executes into the provided directory path"), + tensorflow::Flag("xla_backend_extra_options", + setter_for_xla_backend_extra_options, "", + "Extra options to pass to a backend; " + "comma-separated list of 'key=val' strings (=val " + "may be omitted); no whitespace around commas."), + tensorflow::Flag("xla_reduce_precision", setter_for_xla_reduce_precision, + "", + "Directions for adding reduce-precision operations. " + "Format is 'LOCATION=E,M:OPS;NAMES' where LOCATION is " + "the class of locations in which to insert the " + "operations (e.g., 'OP_OUTPUTS'), E and M are the " + "exponent and matissa bit counts respectively, and " + "OPS and NAMES are comma-separated (no spaces) lists " + "of the operation types and names to which to attach " + "the reduce-precision operations. The NAMES string " + "and its preceding ';' may be omitted. 
This option " + "may be repeated to define multiple sets of added " + "reduce-precision operations."), + tensorflow::Flag( + "xla_gpu_use_cudnn_batchnorm", + bool_setter_for(&DebugOptions::set_xla_gpu_use_cudnn_batchnorm), + flag_values->xla_gpu_use_cudnn_batchnorm(), + "Allows the GPU backend to implement batchnorm HLOs using cudnn, " + "rather than expanding them to a soup of HLOs."), + }); ParseFlagsFromEnv(*flag_objects); } diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.h b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.h index d0ef8e66ab0bcbf88035ae31fe32eb161e32e998..b53157f59c61cf4e0850e006ad3656f4be63a936 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.h +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_FLAGS_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_FLAGS_H_ +#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_FLAGS_H_ +#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_FLAGS_H_ #include @@ -35,4 +35,4 @@ xla::DebugOptions GetDebugOptionsFromFlags(); } // namespace legacy_flags } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_FLAGS_H_ +#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_FLAGS_H_ diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h b/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h index 0c238e6a5decffb0339f428e4ea676944479cf1b..e9cf435d83d8345e974d83f8e5340dafeba8e3b2 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_parsers.h @@ -13,8 +13,8 @@ See the License for the specific language governing 
permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_PARSERS_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_PARSERS_H_ +#ifndef TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_PARSERS_H_ +#define TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_PARSERS_H_ #include #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -148,4 +148,4 @@ inline bool parse_xla_reduce_precision_option( } // namespace legacy_flags } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_PARSERS_H_ +#endif // TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_DEBUG_OPTIONS_PARSERS_H_ diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 93d3cd425f0a868b51677058796e9c40c2d3dff8..e0a9b148b443e90a0c4f3e19660b6234d49eef84 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -27,14 +27,20 @@ limitations under the License. 
#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" + +using tensorflow::strings::Printf; +using tensorflow::strings::StrCat; + +namespace xla { + namespace { -using tensorflow::int64; constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__; @@ -46,9 +52,8 @@ void ConvertEndianShort(char* bytes, int64 size) { std::swap(bytes[i], bytes[i + 1]); } } -} // namespace -namespace xla { +} // namespace std::ostream& operator<<(std::ostream& out, const Literal& literal) { out << literal.ToString(); @@ -64,12 +69,12 @@ Literal::StrideConfig::StrideConfig( if (!dimensions.empty()) { // Selects the shape with the largest minor dimension as the one upon // which to run the tight stride loop. 
- if (dimensions[source_shape.layout().minor_to_major()[0]] >= - dimensions[dest_shape.layout().minor_to_major()[0]]) { - minor_dimension = source_shape.layout().minor_to_major()[0]; + if (dimensions[LayoutUtil::Minor(source_shape.layout(), 0)] >= + dimensions[LayoutUtil::Minor(dest_shape.layout(), 0)]) { + minor_dimension = LayoutUtil::Minor(source_shape.layout(), 0); dest_stride = IndexUtil::GetDimensionStride(dest_shape, minor_dimension); } else { - minor_dimension = dest_shape.layout().minor_to_major()[0]; + minor_dimension = LayoutUtil::Minor(dest_shape.layout(), 0); source_stride = IndexUtil::GetDimensionStride(source_shape, minor_dimension); } @@ -78,52 +83,134 @@ Literal::StrideConfig::StrideConfig( } } +Literal::Literal(const Shape& shape) + : Literal(shape, /*allocate_arrays=*/true) {} + +Literal::Literal(const Shape& shape, bool allocate_arrays) + : shape_(shape), pieces_(shape), owns_buffers_(true) { + CHECK(LayoutUtil::HasLayout(shape)); + for (auto& pair : pieces_) { + const ShapeIndex& index = pair.first; + Piece& piece = pair.second; + + piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); + const Shape& subshape = piece.subshape(); + if (ShapeUtil::IsArray(subshape)) { + if (allocate_arrays) { + piece.set_buffer(new char[piece.size_bytes()]); + if (LayoutUtil::IsSparseArray(subshape)) { + piece.set_sparse_indices(new SparseIndexArray( + LayoutUtil::MaxSparseElements(subshape.layout()), + ShapeUtil::Rank(subshape))); + } + } else { + piece.set_buffer(nullptr); + } + } + } +} + +Literal::~Literal() { DeallocateBuffers(); } + +void Literal::DeallocateBuffers() { + if (owns_buffers_) { + for (auto& pair : pieces_) { + Piece& piece = pair.second; + if (piece.buffer() != nullptr) { + delete[] piece.buffer(); + delete piece.sparse_indices(); + } + } + } +} + +Literal::Literal(Literal&& other) { + shape_ = std::move(other.shape_); + pieces_ = std::move(other.pieces_); + // We need to iterate through the pieces to set the subshape pointer + // 
properly. It must refer to subshapes within shape_. + for (auto& pair : pieces_) { + const ShapeIndex& index = pair.first; + Piece& piece = pair.second; + piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); + } + owns_buffers_ = other.owns_buffers_; + + other.shape_ = ShapeUtil::MakeNil(); + other.pieces_ = ShapeTree(other.shape_); + other.piece({}).set_subshape(&other.shape_); +} + +Literal& Literal::operator=(Literal&& other) { + DeallocateBuffers(); + shape_ = std::move(other.shape_); + pieces_ = std::move(other.pieces_); + // We need to iterate through the pieces to set the subshape pointer + // properly. It must refer to subshapes within shape_. + for (auto& pair : pieces_) { + const ShapeIndex& index = pair.first; + Piece& piece = pair.second; + piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); + } + owns_buffers_ = other.owns_buffers_; + + other.shape_ = ShapeUtil::MakeNil(); + other.pieces_ = ShapeTree(other.shape_); + other.piece({}).set_subshape(&other.shape_); + return *this; +} + std::unique_ptr Literal::CreateFromShape(const Shape& shape) { - auto literal = MakeUnique(); - *literal->mutable_shape() = shape; - if (ShapeUtil::IsTuple(shape)) { - int64 num_elements = ShapeUtil::TupleElementCount(shape); - literal->tuple_literals_.resize(num_elements); - for (int i = 0; i < num_elements; ++i) { - std::unique_ptr elem = - CreateFromShape(ShapeUtil::GetTupleElementShape(shape, i)); - literal->tuple_literals_[i] = std::move(*elem); + auto literal = MakeUnique(shape); + for (auto& pair : literal->pieces_) { + Piece& piece = pair.second; + if (ShapeUtil::IsArray(piece.subshape())) { + memset(piece.untyped_data(), 0, piece.size_bytes()); } - } else { - literal->Reserve(ShapeUtil::ElementsIn(literal->shape())); } return literal; } +const SparseIndexArray* Literal::sparse_indices( + const ShapeIndex& shape_index) const { + return piece(shape_index).sparse_indices(); +} + +SparseIndexArray* Literal::sparse_indices(const ShapeIndex& shape_index) { 
+ return piece(shape_index).sparse_indices(); +} + /* static */ std::unique_ptr Literal::CreateFromDimensions( PrimitiveType primitive_type, tensorflow::gtl::ArraySlice dimensions) { return CreateFromShape(ShapeUtil::MakeShape(primitive_type, dimensions)); } -template -Status Literal::CopyRange(const Literal& src_literal, - tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size) { - const Shape& src_shape = src_literal.shape(); - const Shape& dest_shape = shape(); - tensorflow::gtl::ArraySlice src_data = src_literal.GetArraySlice(); - tensorflow::gtl::MutableArraySlice dest_data = GetMutableArraySlice(); - - TF_RET_CHECK(ShapeUtil::Rank(src_shape) == src_base.size()); - TF_RET_CHECK(ShapeUtil::Rank(dest_shape) == dest_base.size()); +template +Status Literal::CopySliceFromInternal( + const Literal& src_literal, tensorflow::gtl::ArraySlice src_base, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size) { + TF_RET_CHECK(ShapeUtil::Rank(src_literal.shape()) == src_base.size()); + TF_RET_CHECK(ShapeUtil::Rank(shape()) == dest_base.size()); + + auto linear_index = [](const Shape& shape, + tensorflow::gtl::ArraySlice multi_index) { + return IndexUtil::MultidimensionalIndexToLinearIndex(shape, multi_index); + }; - if (ShapeUtil::Rank(src_shape) == 0 || ShapeUtil::Rank(dest_shape) == 0) { + if (ShapeUtil::Rank(src_literal.shape()) == 0 || + ShapeUtil::Rank(shape()) == 0) { // If any of the two shapes are scalars, we can just call the StridedCopy() // directly, and we know we will be copying only one value. TF_RET_CHECK(copy_size.empty()); - StridedCopy(dest_data, LinearIndex(dest_base), 0, src_data, - src_literal.LinearIndex(src_base), 0, 1); - } else if (!ShapeUtil::HasZeroElements(dest_shape) && - !ShapeUtil::HasZeroElements(src_shape)) { - // Perform copy if neither src literal nor dest literal has dimensions with - // zero element, otherwise it's a no-op. 
+ StridedCopy(data(), linear_index(shape(), dest_base), 0, + src_literal.data(), + linear_index(src_literal.shape(), src_base), 0, 1); + } else if (!ShapeUtil::HasZeroElements(shape()) && + !ShapeUtil::HasZeroElements(src_literal.shape())) { + // Perform copy if neither src nor dest has dimensions with zero element, + // otherwise it's a no-op. TF_RET_CHECK(src_base.size() == dest_base.size()); TF_RET_CHECK(src_base.size() == copy_size.size()); @@ -133,7 +220,8 @@ Status Literal::CopyRange(const Literal& src_literal, // proper stride size at the matching dimension. DimensionVector src_indexes(src_base.size(), 0); DimensionVector dest_indexes(dest_base.size(), 0); - StrideConfig stride_config(src_shape, dest_shape, copy_size); + Literal::StrideConfig stride_config(src_literal.shape(), shape(), + copy_size); auto copy_proc = [&](const std::vector& indexes) { // Map from multi-dimensional index, to source index. @@ -143,89 +231,296 @@ Status Literal::CopyRange(const Literal& src_literal, std::transform(indexes.begin(), indexes.end(), dest_base.begin(), dest_indexes.begin(), std::plus()); - int64 src_index = src_literal.LinearIndex(src_indexes); - int64 dest_index = LinearIndex(dest_indexes); + int64 src_index = linear_index(src_literal.shape(), src_indexes); + int64 dest_index = linear_index(shape(), dest_indexes); - StridedCopy(dest_data, dest_index, stride_config.dest_stride, src_data, - src_index, stride_config.source_stride, - stride_config.minor_loop_size); + // `this->` is needed to workaround MSVC bug: #16882 + StridedCopy(this->data(), dest_index, stride_config.dest_stride, + src_literal.data(), src_index, + stride_config.source_stride, stride_config.minor_loop_size); return true; }; - ShapeUtil::ForEachIndex(src_shape, stride_config.base, + ShapeUtil::ForEachIndex(src_literal.shape(), stride_config.base, stride_config.dimensions, stride_config.step, copy_proc); } return Status::OK(); } -Status Literal::Copy(const Literal& src_literal, - 
tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size) { +std::vector Literal::DecomposeTuple() { + CHECK(ShapeUtil::IsTuple(shape())); + std::vector elements; + for (int i = 0; i < ShapeUtil::TupleElementCount(shape()); ++i) { + elements.push_back(Literal(ShapeUtil::GetSubshape(shape(), {i}), + /*allocate_arrays=*/false)); + Literal& element = elements.back(); + for (auto& pair : element.pieces_) { + const ShapeIndex& index = pair.first; + Piece& dest_piece = pair.second; + ShapeIndex src_index = {i}; + for (int64 j : index) { + src_index.push_back(j); + } + Piece& src_piece = piece(src_index); + + // Move the respective buffer and sparse indices over to the element + // Literal. + dest_piece.set_buffer(src_piece.buffer()); + src_piece.set_buffer(nullptr); + dest_piece.set_sparse_indices(src_piece.sparse_indices()); + src_piece.set_sparse_indices(nullptr); + } + } + // Set this literal to be nil-shaped. + *this = Literal(); + return elements; +} + +/* static */ Literal Literal::MoveIntoTuple( + tensorflow::gtl::MutableArraySlice elements) { + std::vector element_shapes; + for (const Literal& element : elements) { + element_shapes.push_back(element.shape()); + } + Literal literal(ShapeUtil::MakeTupleShape(element_shapes), + /*allocate_arrays=*/false); + for (int i = 0; i < elements.size(); ++i) { + TF_CHECK_OK( + literal.MoveFrom(std::move(elements[i]), /*dest_shape_index=*/{i})); + } + return literal; +} + +namespace { + +// Copies the elements in 'src' to 'dest'. The shape and layout of the data in +// the array slices are indicated by dest_shape and src_shape respectively. 
+template +void CopyElementsBetween(tensorflow::gtl::MutableArraySlice dest, + tensorflow::gtl::ArraySlice src, + const Shape& dest_shape, const Shape& src_shape) { + CHECK(ShapeUtil::Compatible(dest_shape, src_shape)); + if (ShapeUtil::HasZeroElements(dest_shape)) { + return; + } + std::vector index(ShapeUtil::Rank(dest_shape)); + do { + dest[IndexUtil::MultidimensionalIndexToLinearIndex(dest_shape, index)] = + src[IndexUtil::MultidimensionalIndexToLinearIndex(src_shape, index)]; + } while (IndexUtil::BumpIndices(dest_shape, &index)); +} + +} // namespace + +Status Literal::Piece::CopyFrom(const Literal::Piece& src) { + if (ShapeUtil::Equal(subshape(), src.subshape())) { + // If the layouts are equal it's faster just to memcpy. + memcpy(buffer(), src.buffer(), src.size_bytes()); + } else { + TF_RET_CHECK(ShapeUtil::Compatible(src.subshape(), subshape())); + std::vector origin(ShapeUtil::Rank(subshape()), 0); + switch (subshape().element_type()) { +#define COPY_ELEMENTS(XLA_T, NATIVE_T) \ + case (XLA_T): \ + CopyElementsBetween(data(), src.data(), \ + subshape(), src.subshape()); \ + break; + COPY_ELEMENTS(U8, uint8); + COPY_ELEMENTS(U16, uint16); + COPY_ELEMENTS(U32, uint32); + COPY_ELEMENTS(U64, uint64); + COPY_ELEMENTS(S8, int8); + COPY_ELEMENTS(S16, int16); + COPY_ELEMENTS(S32, int32); + COPY_ELEMENTS(S64, int64); + COPY_ELEMENTS(F16, half); + COPY_ELEMENTS(BF16, bfloat16); + COPY_ELEMENTS(F32, float); + COPY_ELEMENTS(F64, double); + COPY_ELEMENTS(C64, complex64); + COPY_ELEMENTS(PRED, bool); +#undef COPY_ELEMENTS + default: + return Unimplemented( + "Unhandled primitive type %s", + PrimitiveType_Name(subshape().element_type()).c_str()); + } + } + return Status::OK(); +} + +Status Literal::CopyFrom(const Literal& src_literal, + const ShapeIndex& dest_shape_index, + const ShapeIndex& src_shape_index) { + const Shape& dest_subshape = + ShapeUtil::GetSubshape(shape(), dest_shape_index); + const Shape& src_subshape = + ShapeUtil::GetSubshape(src_literal.shape(), 
src_shape_index); + if (!ShapeUtil::Compatible(dest_subshape, src_subshape)) { + return InvalidArgument( + "Destination subshape incompatible with source subshape: %s vs %s", + ShapeUtil::HumanString(dest_subshape).c_str(), + ShapeUtil::HumanString(src_subshape).c_str()); + } + + for (auto& pair : pieces_) { + const ShapeIndex& index = pair.first; + Piece& piece = pair.second; + if (!ShapeUtil::IsArray(piece.subshape())) { + continue; + } + + // Determine if this index is in the part of this literal that we want to + // copy over from src_literal. + bool in_subtree_to_copy = true; + for (int i = 0; i < dest_shape_index.size(); ++i) { + if (index[i] != dest_shape_index[i]) { + in_subtree_to_copy = false; + break; + } + } + if (!in_subtree_to_copy) { + continue; + } + + // Construct the index of the corresponding piece in the source literal. + ShapeIndex src_piece_index = src_shape_index; + for (int64 i = dest_shape_index.size(); i < index.size(); ++i) { + src_piece_index.push_back(index[i]); + } + + TF_RETURN_IF_ERROR(piece.CopyFrom(src_literal.piece(src_piece_index))); + } + return Status::OK(); +} + +Status Literal::MoveFrom(Literal&& src_literal, + const ShapeIndex& dest_shape_index) { + const Shape& dest_subshape = + ShapeUtil::GetSubshape(shape(), dest_shape_index); + if (!ShapeUtil::Equal(dest_subshape, src_literal.shape())) { + return InvalidArgument( + "Destination subshape not equal to source shape: %s vs %s", + ShapeUtil::HumanString(dest_subshape).c_str(), + ShapeUtil::HumanString(src_literal.shape()).c_str()); + } + + if (!(owns_buffers_ && src_literal.owns_buffers_)) { + return InvalidArgument( + "Source and destination literals must both own their buffers (ie, not " + "be views)"); + } + + for (auto& pair : src_literal.pieces_) { + const ShapeIndex& src_index = pair.first; + Piece& src_piece = pair.second; + if (!ShapeUtil::IsArray(src_piece.subshape())) { + continue; + } + + ShapeIndex dest_index = dest_shape_index; + for (int64 i : src_index) { + 
dest_index.push_back(i); + } + Piece& dest_piece = piece(dest_index); + delete[] dest_piece.buffer(); + dest_piece.set_buffer(src_piece.buffer()); + delete dest_piece.sparse_indices(); + dest_piece.set_sparse_indices(src_piece.sparse_indices()); + } + + src_literal.shape_ = ShapeUtil::MakeNil(); + src_literal.pieces_ = ShapeTree(src_literal.shape_); + src_literal.piece({}).set_subshape(&src_literal.shape_); + return Status::OK(); +} + +Status Literal::CopySliceFrom(const Literal& src_literal, + tensorflow::gtl::ArraySlice src_base, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size) { + TF_RET_CHECK(ShapeUtil::IsArray(shape())) << ShapeUtil::HumanString(shape()); + TF_RET_CHECK(ShapeUtil::IsArray(src_literal.shape())) + << ShapeUtil::HumanString(src_literal.shape()); TF_RET_CHECK(ShapeUtil::SameElementType(src_literal.shape(), shape())); - switch (src_literal.shape().element_type()) { + + switch (shape().element_type()) { case U8: - return CopyRange(src_literal, src_base, dest_base, copy_size); + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); case U16: - return CopyRange(src_literal, src_base, dest_base, copy_size); + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); case U32: - return CopyRange(src_literal, src_base, dest_base, copy_size); + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); case U64: - return CopyRange(src_literal, src_base, dest_base, copy_size); + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); case S8: - return CopyRange(src_literal, src_base, dest_base, copy_size); + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); case S16: - return CopyRange(src_literal, src_base, dest_base, copy_size); + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); case S32: - return CopyRange(src_literal, src_base, dest_base, copy_size); + return 
CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); case S64: - return CopyRange(src_literal, src_base, dest_base, copy_size); + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); case F16: - return CopyRange(src_literal, src_base, dest_base, copy_size); + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); case BF16: - return CopyRange(src_literal, src_base, dest_base, copy_size); + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); case F32: - return CopyRange(src_literal, src_base, dest_base, copy_size); + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); case F64: - return CopyRange(src_literal, src_base, dest_base, copy_size); + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); case C64: - return CopyRange(src_literal, src_base, dest_base, copy_size); + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); case PRED: - return CopyRange(src_literal, src_base, dest_base, copy_size); + return CopySliceFromInternal(src_literal, src_base, dest_base, + copy_size); default: break; } - return Unimplemented("Unhandled primitive type %d", - src_literal.shape().element_type()); + return Unimplemented("Unhandled primitive type %d", shape().element_type()); } /* static */ Literal Literal::Zero(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return *Literal::CreateR0(0); + return std::move(*Literal::CreateR0(0)); case U32: - return *Literal::CreateR0(0); + return std::move(*Literal::CreateR0(0)); case U64: - return *Literal::CreateR0(0); + return std::move(*Literal::CreateR0(0)); case S8: - return *Literal::CreateR0(0); + return std::move(*Literal::CreateR0(0)); case S32: - return *Literal::CreateR0(0); + return std::move(*Literal::CreateR0(0)); case S64: - return *Literal::CreateR0(0); + return std::move(*Literal::CreateR0(0)); case F16: - return 
*Literal::CreateR0(static_cast(0.0f)); + return std::move(*Literal::CreateR0(static_cast(0.0f))); case BF16: - return *Literal::CreateR0(static_cast(0.0f)); + return std::move( + *Literal::CreateR0(static_cast(0.0f))); case F32: - return *Literal::CreateR0(0); + return std::move(*Literal::CreateR0(0)); case F64: - return *Literal::CreateR0(0); + return std::move(*Literal::CreateR0(0)); case C64: - return *Literal::CreateR0(0); + return std::move(*Literal::CreateR0(0)); case PRED: - return *Literal::CreateR0(false); + return std::move(*Literal::CreateR0(false)); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; @@ -241,30 +536,33 @@ Status Literal::Copy(const Literal& src_literal, /* static */ Literal Literal::One(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return *Literal::CreateR0(1); + return std::move(*Literal::CreateR0(1)); case U32: - return *Literal::CreateR0(1); + return std::move(*Literal::CreateR0(1)); case U64: - return *Literal::CreateR0(1); + return std::move(*Literal::CreateR0(1)); case S8: - return *Literal::CreateR0(1); + return std::move(*Literal::CreateR0(1)); case S32: - return *Literal::CreateR0(1); + return std::move(*Literal::CreateR0(1)); case S64: - return *Literal::CreateR0(1); + return std::move(*Literal::CreateR0(1)); + case F16: + return std::move(*Literal::CreateR0(static_cast(1.0f))); + case BF16: + return std::move( + *Literal::CreateR0(static_cast(1.0f))); case F32: - return *Literal::CreateR0(1); + return std::move(*Literal::CreateR0(1)); case F64: - return *Literal::CreateR0(1); + return std::move(*Literal::CreateR0(1)); case C64: - return *Literal::CreateR0(1); + return std::move(*Literal::CreateR0(1)); case PRED: - return *Literal::CreateR0(true); + return std::move(*Literal::CreateR0(true)); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; - case F16: - return *Literal::CreateR0(static_cast(1.0f)); case TUPLE: LOG(FATAL) << "tuple element type cannot take 
on value of 1"; case OPAQUE: @@ -277,35 +575,42 @@ Status Literal::Copy(const Literal& src_literal, /* static */ Literal Literal::MinValue(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return *Literal::CreateR0(std::numeric_limits::min()); + return std::move( + *Literal::CreateR0(std::numeric_limits::min())); case U32: - return *Literal::CreateR0(std::numeric_limits::min()); + return std::move( + *Literal::CreateR0(std::numeric_limits::min())); case U64: - return *Literal::CreateR0(std::numeric_limits::min()); + return std::move( + *Literal::CreateR0(std::numeric_limits::min())); case S8: - return *Literal::CreateR0(std::numeric_limits::min()); + return std::move( + *Literal::CreateR0(std::numeric_limits::min())); case S32: - return *Literal::CreateR0(std::numeric_limits::min()); + return std::move( + *Literal::CreateR0(std::numeric_limits::min())); case S64: - return *Literal::CreateR0(std::numeric_limits::min()); + return std::move( + *Literal::CreateR0(std::numeric_limits::min())); case F32: - return *Literal::CreateR0(-std::numeric_limits::infinity()); + return std::move( + *Literal::CreateR0(-std::numeric_limits::infinity())); case F64: - return *Literal::CreateR0( - -std::numeric_limits::infinity()); + return std::move( + *Literal::CreateR0(-std::numeric_limits::infinity())); case C64: LOG(FATAL) << "C64 element type has no minimum value"; case PRED: - return *Literal::CreateR0(false); + return std::move(*Literal::CreateR0(false)); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: - return *Literal::CreateR0( - static_cast(-std::numeric_limits::infinity())); + return std::move(*Literal::CreateR0( + static_cast(-std::numeric_limits::infinity()))); case BF16: - return *Literal::CreateR0( - static_cast(-std::numeric_limits::infinity())); + return std::move(*Literal::CreateR0( + static_cast(-std::numeric_limits::infinity()))); case TUPLE: LOG(FATAL) << "tuple element type has no minimum value"; case OPAQUE: 
@@ -318,33 +623,40 @@ Status Literal::Copy(const Literal& src_literal, /* static */ Literal Literal::MaxValue(PrimitiveType primitive_type) { switch (primitive_type) { case U8: - return *Literal::CreateR0(std::numeric_limits::max()); + return std::move( + *Literal::CreateR0(std::numeric_limits::max())); case U32: - return *Literal::CreateR0(std::numeric_limits::max()); + return std::move( + *Literal::CreateR0(std::numeric_limits::max())); case U64: - return *Literal::CreateR0(std::numeric_limits::max()); + return std::move( + *Literal::CreateR0(std::numeric_limits::max())); case S8: - return *Literal::CreateR0(std::numeric_limits::max()); + return std::move( + *Literal::CreateR0(std::numeric_limits::max())); case S32: - return *Literal::CreateR0(std::numeric_limits::max()); + return std::move( + *Literal::CreateR0(std::numeric_limits::max())); case S64: - return *Literal::CreateR0(std::numeric_limits::max()); + return std::move( + *Literal::CreateR0(std::numeric_limits::max())); case F32: - return *Literal::CreateR0(std::numeric_limits::infinity()); + return std::move( + *Literal::CreateR0(std::numeric_limits::infinity())); case F64: - return *Literal::CreateR0( - std::numeric_limits::infinity()); + return std::move( + *Literal::CreateR0(std::numeric_limits::infinity())); case PRED: - return *Literal::CreateR0(true); + return std::move(*Literal::CreateR0(true)); case S16: case U16: LOG(FATAL) << "u16/s16 literals not yet implemented"; case F16: - return *Literal::CreateR0( - static_cast(std::numeric_limits::infinity())); + return std::move(*Literal::CreateR0( + static_cast(std::numeric_limits::infinity()))); case BF16: - return *Literal::CreateR0( - static_cast(std::numeric_limits::infinity())); + return std::move(*Literal::CreateR0( + static_cast(std::numeric_limits::infinity()))); case TUPLE: LOG(FATAL) << "tuple element type has no maximum value"; case OPAQUE: @@ -356,17 +668,29 @@ Status Literal::Copy(const Literal& src_literal, /* static */ std::unique_ptr 
Literal::CreateR1( const tensorflow::core::Bitmap& values) { - auto literal = MakeUnique(); + auto literal = MakeUnique( + ShapeUtil::MakeShape(PRED, {static_cast(values.bits())})); literal->PopulateR1(values); return literal; } +void Literal::PopulateR1(const tensorflow::core::Bitmap& values) { + CHECK(ShapeUtil::IsArray(shape())); + CHECK_EQ(ShapeUtil::Rank(shape()), 1); + CHECK_EQ(element_count(), values.bits()); + CHECK_EQ(shape().element_type(), PRED); + for (int64 i = 0; i < static_cast(values.bits()); ++i) { + Set({i}, values.get(i)); + } +} + /* static */ std::unique_ptr Literal::CreateR1U8( tensorflow::StringPiece value) { - auto literal = MakeUnique(); - *literal->mutable_shape() = - ShapeUtil::MakeShape(U8, {static_cast(value.size())}); - literal->set_u8s(tensorflow::StringPiece(value.ToString())); + auto literal = MakeUnique( + ShapeUtil::MakeShape(U8, {static_cast(value.size())})); + for (int i = 0; i < value.size(); ++i) { + literal->Set({i}, value[i]); + } return literal; } @@ -380,46 +704,50 @@ Status Literal::Copy(const Literal& src_literal, std::unique_ptr Literal::Relayout( const Layout& new_layout, const ShapeIndex& shape_index) const { - std::unique_ptr outer_result = CloneToUnique(); - - const Literal* copy_from = this; - Literal* copy_to = outer_result.get(); - for (int64 i = 0; i < shape_index.size(); i++) { - *ShapeUtil::GetMutableSubshape(copy_to->mutable_shape(), {shape_index, i}) - ->mutable_layout() = new_layout; - copy_from = ©_from->tuple_literals_[shape_index[i]]; - copy_to = ©_to->tuple_literals_[shape_index[i]]; - } - - DimensionVector base(ShapeUtil::Rank(copy_from->shape()), 0); - DimensionVector copy_size(copy_from->shape().dimensions().begin(), - copy_from->shape().dimensions().end()); + // Create new shape with 'new_layout' set at the given shape index. 
+ Shape new_shape = shape(); + Shape* subshape = ShapeUtil::GetMutableSubshape(&new_shape, shape_index); + TF_CHECK_OK(LayoutUtil::ValidateLayoutForShape(new_layout, *subshape)); + *subshape->mutable_layout() = new_layout; + auto result = MakeUnique(new_shape); + TF_CHECK_OK(result->CopyFrom(*this)); + return result; +} - CHECK(ShapeUtil::IsArray(copy_from->shape())); - CHECK(ShapeUtil::IsArray(copy_to->shape())); - *copy_to->mutable_shape()->mutable_layout() = new_layout; - TF_CHECK_OK(copy_to->Copy(*copy_from, base, base, copy_size)); - return outer_result; +std::unique_ptr Literal::Relayout( + const Shape& shape_with_layout) const { + CHECK(ShapeUtil::Compatible(shape_with_layout, shape())) + << "Given shape_with_layout " << ShapeUtil::HumanString(shape_with_layout) + << " not compatible with literal shape " + << ShapeUtil::HumanString(shape()); + std::unique_ptr result = CreateFromShape(shape_with_layout); + ShapeUtil::ForEachSubshape( + result->shape(), + [this, &result](const Shape& subshape, const ShapeIndex& index) { + if (ShapeUtil::IsArray(subshape)) { + TF_CHECK_OK(result->CopyFrom(*this, + /*dest_shape_index=*/index, + /*src_shape_index=*/index)); + } + }); + return result; } StatusOr> Literal::Reshape( tensorflow::gtl::ArraySlice dimensions) const { - if (ShapeUtil::IsTuple(shape())) { + if (!ShapeUtil::IsArray(shape())) { return InvalidArgument("Reshape does not support tuples."); } std::unique_ptr output; if (!LayoutUtil::IsMonotonicWithDim0Major(shape().layout())) { - std::vector minor_to_major(ShapeUtil::Rank(shape())); - std::iota(minor_to_major.rbegin(), minor_to_major.rend(), - static_cast(0)); - output = Relayout(LayoutUtil::MakeLayout(minor_to_major)); + output = + Relayout(LayoutUtil::GetDefaultLayoutForRank(ShapeUtil::Rank(shape()))); } else { output = CloneToUnique(); } // Because the layout is monotonic, we can simply reuse the same sequence of // values without changing their order. 
- *output->mutable_shape() = - ShapeUtil::MakeShape(shape().element_type(), dimensions); + output->shape_ = ShapeUtil::MakeShape(shape().element_type(), dimensions); int64 elements_before = ShapeUtil::ElementsIn(shape()); int64 elements_after = ShapeUtil::ElementsIn(output->shape()); @@ -435,7 +763,7 @@ StatusOr> Literal::Reshape( std::unique_ptr Literal::Transpose( tensorflow::gtl::ArraySlice permutation) const { - CHECK(!ShapeUtil::IsTuple(shape())) << "Tuple is not supported for transpose"; + CHECK(ShapeUtil::IsArray(shape())) << "Tuple is not supported for transpose"; CHECK(IsPermutation(permutation, ShapeUtil::Rank(shape()))) << "Given permutation is not a permutation of dimension numbers"; // To transpose the array, we just permute the dimensions and layout, and @@ -458,23 +786,24 @@ std::unique_ptr Literal::Transpose( // dimension has within the transposed array, a layout is affine if // MinMaj(Di) == TMinMaj(T(Di)), with TMinMaj() being the minor to major // vector of the affine layout. 
+ CHECK(LayoutUtil::IsDenseArray(permuted_shape)); Layout* layout = permuted_shape.mutable_layout(); layout->clear_minor_to_major(); - for (auto index : shape().layout().minor_to_major()) { + for (auto index : LayoutUtil::MinorToMajor(shape())) { layout->add_minor_to_major(inverse_permutation[index]); } std::unique_ptr new_literal = CreateFromShape(permuted_shape); DCHECK_GE(ShapeUtil::ByteSizeOf(new_literal->shape()), ShapeUtil::ByteSizeOf(shape())); - std::memcpy(new_literal->MutableInternalData(), InternalData(), - ShapeUtil::ByteSizeOf(shape())); + std::memcpy(new_literal->root_piece().buffer(), root_piece().buffer(), + root_piece().size_bytes()); return new_literal; } std::unique_ptr Literal::Slice( tensorflow::gtl::ArraySlice start_indices, tensorflow::gtl::ArraySlice limit_indices) const { - CHECK(!ShapeUtil::IsTuple(shape())) << "tuple is not supported for reshape"; + CHECK(ShapeUtil::IsArray(shape())) << "tuple is not supported for slice"; DimensionVector result_dimensions; for (int64 dnum = 0; dnum < ShapeUtil::Rank(shape()); ++dnum) { @@ -484,13 +813,11 @@ std::unique_ptr Literal::Slice( CHECK_GT(dimension, 0); result_dimensions.push_back(dimension); } - const auto result_shape = ShapeUtil::MakeShapeWithLayout( - shape().element_type(), result_dimensions, - AsInt64Slice(shape().layout().minor_to_major())); + const auto result_shape = + ShapeUtil::MakeShapeWithLayout(shape().element_type(), result_dimensions, + LayoutUtil::MinorToMajor(shape())); - auto result_literal = MakeUnique(); - *result_literal->mutable_shape() = result_shape; - result_literal->Reserve(ShapeUtil::ElementsIn(result_shape)); + auto result_literal = MakeUnique(result_shape); DimensionVector new_indices(ShapeUtil::Rank(result_shape)); switch (result_shape.element_type()) { @@ -504,6 +831,16 @@ std::unique_ptr Literal::Slice( result_literal->Set(indices, value); }); return result_literal; + case C64: + result_literal->EachCell( + [&](tensorflow::gtl::ArraySlice indices, complex64 
/*value*/) { + for (int64 i = 0; i < ShapeUtil::Rank(result_shape); ++i) { + new_indices[i] = indices[i] + start_indices[i]; + } + complex64 value = Get(new_indices); + result_literal->Set(indices, value); + }); + return result_literal; case S32: result_literal->EachCell( [&](tensorflow::gtl::ArraySlice indices, int32 /*value*/) { @@ -530,48 +867,116 @@ std::unique_ptr Literal::Slice( } } +Literal Literal::Clone() const { + Literal result(shape()); + TF_CHECK_OK(result.CopyFrom(*this)); + return result; +} + std::unique_ptr Literal::CloneToUnique() const { - auto unique = MakeUnique(); - *unique = *this; - return unique; + auto result = MakeUnique(shape()); + TF_CHECK_OK(result->CopyFrom(*this)); + return result; } -string Literal::GetAsString( - tensorflow::gtl::ArraySlice multi_index) const { - switch (shape().element_type()) { +string Literal::GetAsString(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index) const { + const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); + CHECK(LayoutUtil::IsDenseArray(subshape)); + switch (subshape.element_type()) { case PRED: - return Get(multi_index) ? "true" : "false"; - case U8: - return tensorflow::strings::StrCat(Get(multi_index)); + return Get(multi_index, shape_index) ? 
"true" : "false"; + case S8: + return StrCat(Get(multi_index, shape_index)); + case S16: + return StrCat(Get(multi_index, shape_index)); case S32: - return tensorflow::strings::StrCat(Get(multi_index)); + return StrCat(Get(multi_index, shape_index)); case S64: - return tensorflow::strings::StrCat(Get(multi_index)); + return StrCat(Get(multi_index, shape_index)); + case U8: + return StrCat(Get(multi_index, shape_index)); + case U16: + return StrCat(Get(multi_index, shape_index)); case U32: - return tensorflow::strings::StrCat(Get(multi_index)); + return StrCat(Get(multi_index, shape_index)); case U64: - return tensorflow::strings::StrCat(Get(multi_index)); + return StrCat(Get(multi_index, shape_index)); + case F16: + return StrCat(Get(multi_index, shape_index)); case F32: - return tensorflow::strings::StrCat(Get(multi_index)); + return StrCat(Get(multi_index, shape_index)); + case BF16: + return StrCat( + static_cast(Get(multi_index, shape_index))); case F64: - return tensorflow::strings::StrCat(Get(multi_index)); + return StrCat(Get(multi_index, shape_index)); case C64: { - complex64 c = Get(multi_index); - return tensorflow::strings::StrCat("(", c.real(), ", ", c.imag(), ")"); + complex64 c = Get(multi_index, shape_index); + return StrCat("(", c.real(), ", ", c.imag(), ")"); } + default: + LOG(FATAL) << PrimitiveType_Name(subshape.element_type()); + } +} + +string Literal::GetSparseElementAsString(int64 sparse_element_number, + const ShapeIndex& shape_index) const { + const Shape& subshape = ShapeUtil::GetSubshape(shape(), shape_index); + CHECK(LayoutUtil::IsSparseArray(subshape)); + switch (subshape.element_type()) { + case PRED: + return GetSparseElement(sparse_element_number, shape_index) + ? 
"true" + : "false"; + case S8: + return StrCat(GetSparseElement(sparse_element_number, shape_index)); + case S16: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); + case S32: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); + case S64: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); + case U8: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); + case U16: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); + case U32: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); + case U64: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); case F16: - return tensorflow::strings::StrCat(Get(multi_index)); + return StrCat(GetSparseElement(sparse_element_number, shape_index)); + case F32: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); case BF16: - return tensorflow::strings::StrCat( - static_cast(Get(multi_index))); + return StrCat(static_cast( + GetSparseElement(sparse_element_number, shape_index))); + case F64: + return StrCat( + GetSparseElement(sparse_element_number, shape_index)); + case C64: { + complex64 c = + GetSparseElement(sparse_element_number, shape_index); + return StrCat("(", c.real(), ", ", c.imag(), ")"); + } default: - return tensorflow::strings::StrCat( - "[", PrimitiveType_Name(shape().element_type()), "]"); + LOG(FATAL) << "Invalid element type for sparse arrays: " + << PrimitiveType_Name(subshape.element_type()); } } StatusOr Literal::GetIntegralAsS64( tensorflow::gtl::ArraySlice multi_index) const { + CHECK(LayoutUtil::IsDenseArray(shape())); switch (shape().element_type()) { case PRED: return Get(multi_index); @@ -592,13 +997,83 @@ StatusOr Literal::GetIntegralAsS64( } } -int64 Literal::LinearIndex( - tensorflow::gtl::ArraySlice multi_index) const { - return IndexUtil::MultidimensionalIndexToLinearIndex(shape(), multi_index); +tensorflow::gtl::ArraySlice 
Literal::GetSparseIndex( + int64 sparse_element_number, const ShapeIndex& shape_index) const { + const Piece& p = piece(shape_index); + CHECK_GE(sparse_element_number, 0); + CHECK_LT(sparse_element_number, p.sparse_indices()->index_count()); + return p.sparse_indices()->At(sparse_element_number); } -string Literal::ToString(bool print_layout) const { - std::vector pieces; +void Literal::SortSparseElements(const ShapeIndex& shape_index) { + piece(shape_index).SortSparseElements(); +} + +void Literal::Piece::SortSparseElements() { + switch (subshape().element_type()) { + case PRED: + SortSparseElementsInternal(); + break; + case S8: + SortSparseElementsInternal(); + break; + case U8: + SortSparseElementsInternal(); + break; + case S16: + SortSparseElementsInternal(); + break; + case U16: + SortSparseElementsInternal(); + break; + case S32: + SortSparseElementsInternal(); + break; + case U32: + SortSparseElementsInternal(); + break; + case S64: + SortSparseElementsInternal(); + break; + case U64: + SortSparseElementsInternal(); + break; + case F32: + SortSparseElementsInternal(); + break; + case F64: + SortSparseElementsInternal(); + break; + case C64: + SortSparseElementsInternal(); + break; + case F16: + SortSparseElementsInternal(); + break; + case BF16: + SortSparseElementsInternal(); + break; + default: + LOG(FATAL) << "Element type not valid for sparse array: " + << PrimitiveType_Name(subshape().element_type()); + } +} + +template +void Literal::Piece::SortSparseElementsInternal() { + CHECK(LayoutUtil::IsSparseArray(subshape())); + int64 num_elements = sparse_indices()->index_count(); + auto values = data(); + CHECK_LE(num_elements, values.size()); + sparse_indices()->SortWithValues( + tensorflow::gtl::MutableArraySlice(values.data(), num_elements)); +} + +namespace { + +void ToStringHelper(const Literal& literal, const ShapeIndex& shape_index, + bool print_layout, std::vector* pieces) { + const Shape& subshape = ShapeUtil::GetSubshape(literal.shape(), 
shape_index); auto shape_to_string = [print_layout](const Shape& shape) { if (print_layout) { @@ -608,322 +1083,236 @@ string Literal::ToString(bool print_layout) const { } }; + // TODO(b/32894291): refactor this code to reduce code duplication. + if (ShapeUtil::IsTuple(subshape)) { + pieces->push_back(shape_to_string(subshape)); + pieces->push_back(" (\n"); + std::vector tuple_pieces; + for (int i = 0; i < ShapeUtil::TupleElementCount(subshape); ++i) { + ShapeIndex element_index = shape_index; + element_index.push_back(i); + std::vector element_pieces; + ToStringHelper(literal, element_index, print_layout, &element_pieces); + tuple_pieces.push_back(tensorflow::str_util::Join(element_pieces, "")); + } + pieces->push_back(tensorflow::str_util::Join(tuple_pieces, ",\n")); + pieces->push_back("\n)"); + return; + } + + if (LayoutUtil::IsSparseArray(subshape)) { + pieces->push_back(shape_to_string(subshape)); + pieces->push_back("{"); + int64 rank = ShapeUtil::Rank(subshape); + int64 num_elements = literal.sparse_element_count(); + for (int64 i = 0; i < num_elements; ++i) { + if (i > 0) { + pieces->push_back(", "); + } + if (rank == 1) { + pieces->push_back(StrCat(literal.GetSparseIndex(i)[0])); + pieces->push_back(": "); + } else { + pieces->push_back("["); + pieces->push_back( + tensorflow::str_util::Join(literal.GetSparseIndex(i), ", ")); + pieces->push_back("]: "); + } + pieces->push_back(literal.GetSparseElementAsString(i)); + } + pieces->push_back("}"); + return; + } + + CHECK(LayoutUtil::IsDenseArray(subshape)); + auto element_to_string = - [this](tensorflow::gtl::ArraySlice indices) -> string { - PrimitiveType element_type = shape().element_type(); + [&](tensorflow::gtl::ArraySlice indices) -> string { + PrimitiveType element_type = subshape.element_type(); if (element_type == PRED) { // We display predicates in a densely packed form. - return Get(indices) ? "1" : "0"; + return literal.Get(indices, shape_index) ? 
"1" : "0"; } return ((!indices.empty() && indices.back() > 0) ? ", " : "") + - GetAsString(indices); + literal.GetAsString(indices, shape_index); }; - // TODO(b/32894291): refactor this code to reduce code duplication. - if (ShapeUtil::IsTuple(shape())) { - pieces.push_back(shape_to_string(shape())); - pieces.push_back(" (\n"); - pieces.push_back(tensorflow::str_util::Join( - tuple_literals(), ",\n", [](string* out, const Literal& element) { - tensorflow::strings::StrAppend(out, element.ToString()); - })); - pieces.push_back("\n)"); - } else if (ShapeUtil::Rank(shape()) == 0) { - pieces.push_back(GetAsString({})); - } else if (ShapeUtil::Rank(shape()) == 1) { - pieces.push_back("{"); - for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { - pieces.push_back(element_to_string({i0})); + if (ShapeUtil::Rank(subshape) == 0) { + pieces->push_back(literal.GetAsString({}, shape_index)); + } else if (ShapeUtil::Rank(subshape) == 1) { + pieces->push_back("{"); + for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { + pieces->push_back(element_to_string({i0})); } - pieces.push_back("}"); - } else if (ShapeUtil::Rank(shape()) == 2) { - pieces.push_back(shape_to_string(shape())); - pieces.push_back(" {\n"); - for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { - pieces.push_back(" { "); - for (int64 i1 = 0; i1 < shape().dimensions(1); ++i1) { - pieces.push_back(element_to_string({i0, i1})); + pieces->push_back("}"); + } else if (ShapeUtil::Rank(subshape) == 2) { + pieces->push_back(shape_to_string(subshape)); + pieces->push_back(" {\n"); + for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { + pieces->push_back(" { "); + for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { + pieces->push_back(element_to_string({i0, i1})); } - pieces.push_back(" "); - pieces.push_back(i0 == shape().dimensions(0) - 1 ? "}\n" : "},\n"); + pieces->push_back(" "); + pieces->push_back(i0 == subshape.dimensions(0) - 1 ? 
"}\n" : "},\n"); } - pieces.push_back("}"); - } else if (ShapeUtil::Rank(shape()) == 3) { - pieces.push_back(shape_to_string(shape())); - pieces.push_back(" {\n"); - for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { - pieces.push_back(i0 > 0 ? ",\n{" : "{"); - for (int64 i1 = 0; i1 < shape().dimensions(1); ++i1) { - pieces.push_back(i1 > 0 ? ",\n { " : " { "); - for (int64 i2 = 0; i2 < shape().dimensions(2); ++i2) { - pieces.push_back(element_to_string({i0, i1, i2})); + pieces->push_back("}"); + } else if (ShapeUtil::Rank(subshape) == 3) { + pieces->push_back(shape_to_string(subshape)); + pieces->push_back(" {\n"); + for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { + pieces->push_back(i0 > 0 ? ",\n{" : "{"); + for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { + pieces->push_back(i1 > 0 ? ",\n { " : " { "); + for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { + pieces->push_back(element_to_string({i0, i1, i2})); } - pieces.push_back(" }"); + pieces->push_back(" }"); } - pieces.push_back(" }"); + pieces->push_back(" }"); } - pieces.push_back("\n}"); - } else if (ShapeUtil::Rank(shape()) == 4) { - pieces.push_back(shape_to_string(shape())); - pieces.push_back(" {\n"); - for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { - pieces.push_back(tensorflow::strings::Printf(" { /*i0=%lld*/\n", i0)); - for (int64 i1 = 0; i1 < shape().dimensions(1); ++i1) { - pieces.push_back( - tensorflow::strings::Printf(" { /*i1=%lld*/\n", i1)); - for (int64 i2 = 0; i2 < shape().dimensions(2); ++i2) { - pieces.push_back(" {"); - for (int64 i3 = 0; i3 < shape().dimensions(3); ++i3) { - pieces.push_back(element_to_string({i0, i1, i2, i3})); + pieces->push_back("\n}"); + } else if (ShapeUtil::Rank(subshape) == 4) { + pieces->push_back(shape_to_string(subshape)); + pieces->push_back(" {\n"); + for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { + pieces->push_back(Printf(" { /*i0=%lld*/\n", i0)); + for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { + 
pieces->push_back(Printf(" { /*i1=%lld*/\n", i1)); + for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { + pieces->push_back(" {"); + for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) { + pieces->push_back(element_to_string({i0, i1, i2, i3})); } - pieces.push_back(i2 == shape().dimensions(2) - 1 ? "}\n" : "},\n"); + pieces->push_back(i2 == subshape.dimensions(2) - 1 ? "}\n" : "},\n"); } - pieces.push_back(i1 == shape().dimensions(1) - 1 ? " }\n" - : " },\n"); + pieces->push_back(i1 == subshape.dimensions(1) - 1 ? " }\n" + : " },\n"); } - pieces.push_back(i0 == shape().dimensions(0) - 1 ? " }\n" : " },\n"); + pieces->push_back(i0 == subshape.dimensions(0) - 1 ? " }\n" : " },\n"); } - pieces.push_back("}"); - } else if (ShapeUtil::Rank(shape()) == 5) { - pieces.push_back(shape_to_string(shape())); - pieces.push_back(" {\n"); - for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { - pieces.push_back(tensorflow::strings::Printf(" { /*i0=%lld*/\n", i0)); - for (int64 i1 = 0; i1 < shape().dimensions(1); ++i1) { - pieces.push_back( - tensorflow::strings::Printf(" { /*i1=%lld*/\n", i1)); - for (int64 i2 = 0; i2 < shape().dimensions(2); ++i2) { - pieces.push_back( - tensorflow::strings::Printf(" { /*i2=%lld*/\n", i2)); - for (int64 i3 = 0; i3 < shape().dimensions(3); ++i3) { - pieces.push_back(" {"); - for (int64 i4 = 0; i4 < shape().dimensions(4); ++i4) { - pieces.push_back(element_to_string({i0, i1, i2, i3, i4})); + pieces->push_back("}"); + } else if (ShapeUtil::Rank(subshape) == 5) { + pieces->push_back(shape_to_string(subshape)); + pieces->push_back(" {\n"); + for (int64 i0 = 0; i0 < subshape.dimensions(0); ++i0) { + pieces->push_back(Printf(" { /*i0=%lld*/\n", i0)); + for (int64 i1 = 0; i1 < subshape.dimensions(1); ++i1) { + pieces->push_back(Printf(" { /*i1=%lld*/\n", i1)); + for (int64 i2 = 0; i2 < subshape.dimensions(2); ++i2) { + pieces->push_back(Printf(" { /*i2=%lld*/\n", i2)); + for (int64 i3 = 0; i3 < subshape.dimensions(3); ++i3) { + 
pieces->push_back(" {"); + for (int64 i4 = 0; i4 < subshape.dimensions(4); ++i4) { + pieces->push_back(element_to_string({i0, i1, i2, i3, i4})); } - pieces.push_back(i3 == shape().dimensions(3) - 1 ? "}\n" : "},\n"); + pieces->push_back(i3 == subshape.dimensions(3) - 1 ? "}\n" + : "},\n"); } - pieces.push_back(i2 == shape().dimensions(2) - 1 ? " }\n" - : " },\n"); + pieces->push_back(i2 == subshape.dimensions(2) - 1 ? " }\n" + : " },\n"); } - pieces.push_back(i1 == shape().dimensions(1) - 1 ? " }\n" - : " },\n"); + pieces->push_back(i1 == subshape.dimensions(1) - 1 ? " }\n" + : " },\n"); } - pieces.push_back(i0 == shape().dimensions(0) - 1 ? " }\n" : " },\n"); + pieces->push_back(i0 == subshape.dimensions(0) - 1 ? " }\n" : " },\n"); } - pieces.push_back("}"); + pieces->push_back("}"); } else { - pieces.push_back(shape_to_string(shape())); - pieces.push_back(" {...}"); + pieces->push_back(shape_to_string(subshape)); + pieces->push_back(" {"); + literal.EachCellAsString( + [&](tensorflow::gtl::ArraySlice indices, const string& value) { + pieces->push_back(" "); + pieces->push_back(value); + }); + pieces->push_back("}"); } +} + +} // namespace +int64 Literal::sparse_element_count() const { + CHECK(LayoutUtil::IsSparseArray(shape())); + return sparse_indices()->index_count(); +} + +string Literal::ToString(bool print_layout) const { + std::vector pieces; + ToStringHelper(*this, {}, print_layout, &pieces); return tensorflow::str_util::Join(pieces, ""); } /* static */ std::unique_ptr Literal::MakeTuple( tensorflow::gtl::ArraySlice elements) { - auto literal = MakeUnique(); - std::vector shape; - for (const Literal* tuple_element : elements) { - *literal->add_tuple_literals() = *tuple_element; - shape.push_back(tuple_element->shape()); + std::vector element_shapes; + for (const Literal* element : elements) { + element_shapes.push_back(element->shape()); + } + auto literal = MakeUnique(ShapeUtil::MakeTupleShape(element_shapes)); + for (int i = 0; i < elements.size(); ++i) 
{ + TF_CHECK_OK(literal->CopyFrom(*elements[i], /*dest_shape_index=*/{i})); } - *literal->mutable_shape() = ShapeUtil::MakeTupleShape(shape); return literal; } /* static */ std::unique_ptr Literal::MakeTupleOwned( std::vector> elements) { - auto literal = MakeUnique(); - std::vector shape; - for (auto& tuple_element : elements) { - shape.push_back(tuple_element->shape()); - *literal->add_tuple_literals() = std::move(*tuple_element); + std::vector element_shapes; + element_shapes.reserve(elements.size()); + for (const auto& element : elements) { + element_shapes.push_back(element->shape()); + } + auto literal = MakeUnique(ShapeUtil::MakeTupleShape(element_shapes)); + for (int64 i = 0; i < elements.size(); ++i) { + TF_CHECK_OK( + literal->MoveFrom(std::move(*elements[i]), /*dest_shape_index=*/{i})); } - *literal->mutable_shape() = ShapeUtil::MakeTupleShape(shape); return literal; } -const void* Literal::InternalData() const { - return const_cast( - const_cast(this)->MutableInternalData()); +void Literal::EachCellAsString( + const std::function indices, + const string& value)>& per_cell) const { + if (ShapeUtil::HasZeroElements(shape())) { + return; + } + std::vector indices = IndexUtil::LinearIndexToMultidimensionalIndex( + shape(), /*linear_index=*/0); + do { + per_cell(indices, GetAsString(indices)); + } while (IndexUtil::BumpIndices(shape(), &indices)); } -void* Literal::MutableInternalData() { - // NOTE: We access the vectors directly to avoid the const reference - // created by the accessor functions. 
- switch (shape().element_type()) { - case PRED: - case U8: - return reinterpret_cast(u8s_.data()); - case S32: - return reinterpret_cast(s32s_.data()); - case S64: - return reinterpret_cast(s64s_.data()); - case U32: - return reinterpret_cast(u32s_.data()); - case U64: - return reinterpret_cast(u64s_.data()); - case F32: - return reinterpret_cast(f32s_.data()); - case F64: - return reinterpret_cast(f64s_.data()); - case C64: - return reinterpret_cast(c64s_.data()); - case F16: - return reinterpret_cast(f16s_.data()); - case BF16: - return reinterpret_cast(bf16s_.data()); - default: - LOG(FATAL) << "primitive type not supported in literals: " - << PrimitiveType_Name(shape().element_type()); +namespace { +template +std::unique_ptr ConvertBetweenNativeTypes(const Literal& src_literal) { + CHECK(ShapeUtil::IsArray(src_literal.shape())); + auto result_literal = MakeUnique(ShapeUtil::ChangeElementType( + src_literal.shape(), + primitive_util::NativeToPrimitiveType())); + auto src_data = src_literal.data(); + auto dest_data = result_literal->template data(); + int64 num_elements = src_literal.element_count(); + + for (int64 i = 0; i < num_elements; ++i) { + dest_data[i] = static_cast(src_data[i]); } -} - -void Literal::Reserve(int64 num_elements) { - CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); - switch (shape().element_type()) { - case PRED: - Resize(num_elements, false); - break; - case S8: - Resize(num_elements, 0); - break; - case U8: - Resize(num_elements, 0); - break; - case S32: - Resize(num_elements, 0); - break; - case S64: - Resize(num_elements, 0); - break; - case U32: - Resize(num_elements, 0); - break; - case U64: - Resize(num_elements, 0); - break; - case F32: - Resize(num_elements, 0); - break; - case F64: - Resize(num_elements, 0); - break; - case C64: - Resize(num_elements, 0); - break; - case F16: - Resize(num_elements, static_cast(0.0f)); - break; - case BF16: - Resize(num_elements, static_cast(0.0f)); - break; - default: - LOG(FATAL) << 
"primitive type not supported in literals: " - << PrimitiveType_Name(shape().element_type()); - } -} - -tensorflow::Status Literal::ValidateLiteral() const { - TF_CHECK_OK(ShapeUtil::ValidateShape(shape())); - int64 expected = ShapeUtil::ElementsIn(shape()); - int64 actual = -1; - switch (shape().element_type()) { - case PRED: - case U8: - actual = u8s_size(); - break; - case S32: - actual = s32s_size(); - break; - case U32: - actual = u32s_size(); - break; - case S64: - actual = s64s_size(); - break; - case U64: - actual = u64s_size(); - break; - case F32: - actual = f32s_size(); - break; - case F64: - actual = f64s_size(); - break; - case C64: - actual = c64s_size(); - break; - case F16: - actual = f16s().size() / sizeof(half); - break; - case BF16: - actual = bf16s().size(); - break; - default: - return tensorflow::errors::Unimplemented( - "unhandled element type for literal validation: " + - PrimitiveType_Name(shape().element_type())); - } - - if (expected != actual) { - return tensorflow::errors::InvalidArgument(tensorflow::strings::Printf( - "literal has bad number of elements for its shape %s: want %lld " - "got %lld", - ShapeUtil::HumanString(shape()).c_str(), expected, actual)); - } - - return tensorflow::Status::OK(); -} - -void Literal::EachCellAsString( - const std::function indices, - const string& value)>& per_cell) const { - if (ShapeUtil::HasZeroElements(shape())) { - return; - } - std::vector indices = IndexUtil::LinearIndexToMultidimensionalIndex( - shape(), /*linear_index=*/0); - do { - per_cell(indices, GetAsString(indices)); - } while (IndexUtil::BumpIndices(shape(), &indices)); -} - -namespace { -template -std::unique_ptr ConvertBetweenNativeTypes(const Literal& src_literal) { - auto result_literal = MakeUnique(); - Shape* result_shape = result_literal->mutable_shape(); - *result_shape = src_literal.shape(); - result_shape->set_element_type( - primitive_util::NativeToPrimitiveType()); - 
result_literal->Reserve(ShapeUtil::ElementsIn(*result_shape)); - tensorflow::gtl::ArraySlice src_data = - src_literal.GetArraySlice(); - tensorflow::gtl::MutableArraySlice dest_data = - result_literal->GetMutableArraySlice(); - int64 num_elements = ShapeUtil::ElementsIn(src_literal.shape()); - - for (int64 i = 0; i < num_elements; ++i) { - dest_data[i] = static_cast(src_data[i]); - } - return result_literal; + return result_literal; } template std::unique_ptr ConvertToC64(const Literal& src_literal) { - auto result_literal = MakeUnique(); - Shape* result_shape = result_literal->mutable_shape(); - *result_shape = src_literal.shape(); - result_shape->set_element_type(C64); - result_literal->Reserve(ShapeUtil::ElementsIn(*result_shape)); + CHECK(ShapeUtil::IsArray(src_literal.shape())); + auto result_literal = MakeUnique( + ShapeUtil::ChangeElementType(src_literal.shape(), C64)); using NativeSrcT = typename primitive_util::PrimitiveTypeToNative::type; tensorflow::gtl::ArraySlice src_data = - src_literal.GetArraySlice(); + src_literal.data(); tensorflow::gtl::MutableArraySlice dest_data = - result_literal->GetMutableArraySlice(); - int64 num_elements = ShapeUtil::ElementsIn(src_literal.shape()); + result_literal->data(); + int64 num_elements = src_literal.element_count(); for (int64 i = 0; i < num_elements; ++i) { dest_data[i] = complex64(static_cast(src_data[i]), 0); } @@ -968,10 +1357,12 @@ StatusOr> ConvertIfDestTypeMatches( PrimitiveType_Name(primitive_dest_type).c_str()); } } + } // namespace StatusOr> Literal::Convert( PrimitiveType primitive_dest_type) const { + TF_RET_CHECK(ShapeUtil::IsArray(shape())); switch (shape().element_type()) { #define CONVERT_IF_DEST_TYPE_MATCHES(type) \ case (type): \ @@ -996,356 +1387,192 @@ StatusOr> Literal::Convert( } } -namespace { - -// Helper function which compares whether the elements of literal1 are equal to -// the elements of literal2. 
Recursively iterates through the entire -// multidimensional index space and compares the literal elements -// one-by-one. literal1 and literal2 must be compatible (same dimensions and -// type). template -bool EqualElements(const Literal& literal1, const Literal& literal2, - int dimension, std::vector* multi_index) { - if (dimension == ShapeUtil::Rank(literal1.shape())) { - return (literal1.Get(*multi_index) == - literal2.Get(*multi_index)); - } - for (int64 i = 0; i < literal1.shape().dimensions(dimension); ++i) { - (*multi_index)[dimension] = i; - if (!EqualElements(literal1, literal2, dimension + 1, - multi_index)) { +bool Literal::Piece::EqualElementsInternal( + const Literal::Piece& other, std::vector* multi_index) const { + if (multi_index->size() == ShapeUtil::Rank(subshape())) { + return (Get(*multi_index) == other.Get(*multi_index)); + } + for (int64 i = 0; i < subshape().dimensions(multi_index->size()); ++i) { + multi_index->push_back(i); + if (!EqualElementsInternal(other, multi_index)) { return false; } + multi_index->pop_back(); } return true; } -} // namespace +bool Literal::Piece::EqualElements(const Literal::Piece& other) const { + DCHECK(ShapeUtil::Compatible(subshape(), other.subshape())); + + std::vector multi_index; + switch (subshape().element_type()) { + case PRED: + return EqualElementsInternal(other, &multi_index); + case U8: + return EqualElementsInternal(other, &multi_index); + case S32: + return EqualElementsInternal(other, &multi_index); + case S64: + return EqualElementsInternal(other, &multi_index); + case U32: + return EqualElementsInternal(other, &multi_index); + case U64: + return EqualElementsInternal(other, &multi_index); + case F32: + return EqualElementsInternal(other, &multi_index); + case F64: + return EqualElementsInternal(other, &multi_index); + case F16: + return EqualElementsInternal(other, &multi_index); + case BF16: + return EqualElementsInternal(other, &multi_index); + case C64: + return EqualElementsInternal(other, 
&multi_index); + default: + LOG(FATAL) << "Unimplemented: Literal::Piece::EqualElements for type " + << PrimitiveType_Name(subshape().element_type()); + } +} bool Literal::operator==(const Literal& other) const { if (!ShapeUtil::Compatible(shape(), other.shape())) { return false; } - if (ShapeUtil::IsTuple(shape())) { - // Because the shapes are compatible, they must have the same number of - // tuple elements. - CHECK_EQ(tuple_literals_size(), other.tuple_literals_size()); - for (int i = 0; i < tuple_literals_size(); ++i) { - if (tuple_literals(i) != other.tuple_literals(i)) { - return false; - } + for (const auto& pair : pieces_) { + const ShapeIndex& index = pair.first; + const Piece& piece = pair.second; + if (!ShapeUtil::IsArray(piece.subshape())) { + continue; } - return true; - } else { - std::vector multi_index(ShapeUtil::Rank(shape()), 0); - switch (shape().element_type()) { - case PRED: - return EqualElements(*this, other, 0, &multi_index); - case U8: - return EqualElements(*this, other, 0, &multi_index); - case S32: - return EqualElements(*this, other, 0, &multi_index); - case S64: - return EqualElements(*this, other, 0, &multi_index); - case U32: - return EqualElements(*this, other, 0, &multi_index); - case U64: - return EqualElements(*this, other, 0, &multi_index); - case F32: - return EqualElements(*this, other, 0, &multi_index); - case F64: - return EqualElements(*this, other, 0, &multi_index); - case F16: - return EqualElements(*this, other, 0, &multi_index); - case BF16: - return EqualElements(*this, other, 0, &multi_index); - case C64: - return EqualElements(*this, other, 0, &multi_index); - default: - LOG(FATAL) << "Unimplemented: Literal::Equal for type " - << PrimitiveType_Name(shape().element_type()); + + const Piece& other_piece = other.piece(index); + if (!piece.EqualElements(other_piece)) { + return false; } } + return true; } -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { - auto values = mutable_preds(); 
- return tensorflow::gtl::MutableArraySlice( - reinterpret_cast(values->data()), values->size()); -} - -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { - auto values = mutable_u8s(); - return tensorflow::gtl::MutableArraySlice( - reinterpret_cast(values->data()), values->size()); -} - -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { - auto values = mutable_u8s(); - return tensorflow::gtl::MutableArraySlice(values->data(), - values->size()); -} - -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { - auto values = mutable_s16s(); - return tensorflow::gtl::MutableArraySlice(values->data(), - values->size()); -} - -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { - auto values = mutable_u16s(); - return tensorflow::gtl::MutableArraySlice(values->data(), - values->size()); -} - -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { - auto values = mutable_s32s(); - return tensorflow::gtl::MutableArraySlice(values->data(), - values->size()); -} - -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { - auto values = mutable_u32s(); - return tensorflow::gtl::MutableArraySlice(values->data(), - values->size()); -} - -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { - static_assert(sizeof(int64) == sizeof(tensorflow::protobuf_int64) && - alignof(int64) == alignof(tensorflow::protobuf_int64), - "The int64 and tensorflow::protobuf_int64 types are not " - "compatible"); - auto values = mutable_s64s(); - // Because of the fact that tensorflow::protobuf_int64 is defined as int64_t - // while tensorflow::int64 is defined as long long, a reinterpret_cast<> is - // necessary from the raw data pointer returned by the mutable_data() API. 
- return tensorflow::gtl::MutableArraySlice( - reinterpret_cast(values->data()), values->size()); -} - -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { - static_assert(sizeof(uint64) == sizeof(tensorflow::protobuf_uint64) && - alignof(uint64) == alignof(tensorflow::protobuf_uint64), - "The uint64 and tensorflow::protobuf_uint64 types are not " - "compatible"); - auto values = mutable_u64s(); - // Because of the fact that tensorflow::protobuf_uint64 is defined as uint64_t - // while tensorflow::uint64 is defined as unsigned long long, a - // reinterpret_cast<> is necessary from the raw data pointer returned by the - // mutable_data() API. - return tensorflow::gtl::MutableArraySlice( - reinterpret_cast(values->data()), values->size()); -} - -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { - auto values = mutable_f32s(); - return tensorflow::gtl::MutableArraySlice(values->data(), - values->size()); -} - -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { - auto values = mutable_f64s(); - return tensorflow::gtl::MutableArraySlice(values->data(), - values->size()); -} - -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { - auto values = mutable_c64s(); - return {values->data(), values->size()}; -} - -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { - auto values = mutable_f16s(); - return tensorflow::gtl::MutableArraySlice(values->data(), - values->size()); -} - -template <> -tensorflow::gtl::MutableArraySlice -Literal::GetMutableArraySlice() { - auto values = mutable_bf16s(); - return {values->data(), values->size()}; -} - -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { - CHECK_EQ(shape().element_type(), PRED); - return tensorflow::gtl::ArraySlice( - reinterpret_cast(preds().data()), preds().size()); -} - -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { - 
CHECK_EQ(shape().element_type(), U8); - return tensorflow::gtl::ArraySlice( - reinterpret_cast(u8s().data()), u8s().size()); -} - -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { - CHECK_EQ(shape().element_type(), S8); - return tensorflow::gtl::ArraySlice( - reinterpret_cast(u8s().data()), u8s().size()); -} - -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { - CHECK_EQ(shape().element_type(), U16); - return tensorflow::gtl::ArraySlice(u16s().data(), u16s().size()); -} - -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { - CHECK_EQ(shape().element_type(), S16); - return tensorflow::gtl::ArraySlice(s16s().data(), s16s().size()); -} - -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { - CHECK_EQ(shape().element_type(), U32); - return u32s(); -} - -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { - CHECK_EQ(shape().element_type(), U64); - return u64s(); -} - -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { - CHECK_EQ(shape().element_type(), S32); - return s32s(); -} - -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { - CHECK_EQ(shape().element_type(), S64); - return s64s(); -} - -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { - CHECK_EQ(shape().element_type(), F64); - return f64s(); -} - -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { - CHECK_EQ(shape().element_type(), F16); - return tensorflow::gtl::ArraySlice(f16s().data(), - f16s().size() / sizeof(half)); -} - -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { - CHECK_EQ(shape().element_type(), BF16); - return {bf16s().data(), bf16s().size()}; -} - -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() - const { - CHECK_EQ(shape().element_type(), C64); - return c64s(); -} +namespace { template -static bool AllElementsEqualValue(const Literal& literal, 
NativeT value) { - for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) { - auto multi_index = - IndexUtil::LinearIndexToMultidimensionalIndex(literal.shape(), i); - if (literal.Get(multi_index) != value) { +static bool AllElementsEqualValue(tensorflow::gtl::ArraySlice data, + NativeT value) { + for (int64 i = 0; i < data.size(); ++i) { + if (data[i] != value) { return false; } } return true; } +} // namespace + bool Literal::IsAll(int8 value) const { - switch (shape().element_type()) { - case U8: - if (value >= 0) { - return AllElementsEqualValue(*this, value); - } - return false; - case U32: - if (value >= 0) { - return AllElementsEqualValue(*this, value); - } - return false; - case U64: - if (value >= 0) { - return AllElementsEqualValue(*this, value); - } - return false; - case S8: - return AllElementsEqualValue(*this, value); - case S32: - return AllElementsEqualValue(*this, value); - case S64: - return AllElementsEqualValue(*this, value); - case F32: - return AllElementsEqualValue(*this, value); - case F64: - return AllElementsEqualValue(*this, value); - case F16: - return AllElementsEqualValue(*this, static_cast(value)); - case BF16: - return AllElementsEqualValue(*this, - static_cast(value)); - case PRED: - if (value == 0) { - return AllElementsEqualValue(*this, false); - } - if (value == 1) { - return AllElementsEqualValue(*this, true); + for (const auto& pair : pieces_) { + const Piece& piece = pair.second; + if (!ShapeUtil::IsArray(piece.subshape())) { + continue; + } + + auto piece_is_all = [&]() { + switch (shape().element_type()) { + case U8: + if (value >= 0) { + return AllElementsEqualValue(piece.data(), value); + } + return false; + case U32: + if (value >= 0) { + return AllElementsEqualValue(piece.data(), value); + } + return false; + case U64: + if (value >= 0) { + return AllElementsEqualValue(piece.data(), value); + } + return false; + case S8: + return AllElementsEqualValue(piece.data(), value); + case S32: + return 
AllElementsEqualValue(piece.data(), value); + case S64: + return AllElementsEqualValue(piece.data(), value); + case F32: + return AllElementsEqualValue(piece.data(), value); + case F64: + return AllElementsEqualValue(piece.data(), value); + case F16: + return AllElementsEqualValue(piece.data(), + static_cast(value)); + case BF16: + return AllElementsEqualValue(piece.data(), + static_cast(value)); + case PRED: + if (value == 0) { + return AllElementsEqualValue(piece.data(), false); + } + if (value == 1) { + return AllElementsEqualValue(piece.data(), true); + } + return false; + default: + return false; } return false; - default: + }; + + if (!piece_is_all()) { return false; + } } + return true; } bool Literal::IsAllFloat(float value) const { - switch (shape().element_type()) { - case F32: - return AllElementsEqualValue(*this, value); - case F64: - return AllElementsEqualValue(*this, value); - case F16: - return AllElementsEqualValue(*this, static_cast(value)); - case BF16: - return AllElementsEqualValue(*this, - static_cast(value)); - default: + for (const auto& pair : pieces_) { + const Piece& piece = pair.second; + if (!ShapeUtil::IsArray(piece.subshape())) { + continue; + } + + auto piece_is_all = [&]() { + switch (shape().element_type()) { + case F32: + return AllElementsEqualValue(piece.data(), value); + case F64: + return AllElementsEqualValue(piece.data(), value); + case F16: + return AllElementsEqualValue(piece.data(), + static_cast(value)); + case BF16: + return AllElementsEqualValue(piece.data(), + static_cast(value)); + default: + return false; + } + }; + if (!piece_is_all()) { return false; + } } + return true; } bool Literal::IsAllComplex(complex64 value) const { switch (shape().element_type()) { case C64: - return AllElementsEqualValue(*this, value); + return AllElementsEqualValue(root_piece().data(), + value); default: return false; } } bool Literal::IsZero(tensorflow::gtl::ArraySlice indices) const { + CHECK(ShapeUtil::IsArray(shape())); switch 
(shape().element_type()) { case U8: return Get(indices) == 0; @@ -1376,247 +1603,294 @@ bool Literal::IsZero(tensorflow::gtl::ArraySlice indices) const { } } -template <> -/* static */ void Literal::Resize(int64 num_elements, bool value) { - CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); - mutable_preds()->resize(num_elements, value); -} - -template <> -void Literal::Resize(int64 num_elements, int8 value) { - CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); - mutable_u8s()->resize(num_elements, value); -} - -template <> -void Literal::Resize(int64 num_elements, uint8 value) { - CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); - mutable_u8s()->resize(num_elements, value); -} - -template <> -void Literal::Resize(int64 num_elements, int32 value) { - CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); - mutable_s32s()->resize(num_elements, value); -} - -template <> -void Literal::Resize(int64 num_elements, uint32 value) { - CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); - mutable_u32s()->resize(num_elements, value); -} - -template <> -void Literal::Resize(int64 num_elements, int64 value) { - CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); - mutable_s64s()->resize(num_elements, value); -} - -template <> -void Literal::Resize(int64 num_elements, uint64 value) { - CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); - mutable_u64s()->resize(num_elements, value); -} - -template <> -void Literal::Resize(int64 num_elements, float value) { - CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); - mutable_f32s()->resize(num_elements, value); -} - -template <> -void Literal::Resize(int64 num_elements, double value) { - CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); - mutable_f64s()->resize(num_elements, value); -} - -template <> -void Literal::Resize(int64 num_elements, half value) { - CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); - mutable_f16s()->resize(num_elements, value); -} - -template <> -void 
Literal::Resize(int64 num_elements, bfloat16 value) { - CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); - mutable_bf16s()->resize(num_elements, value); -} - -template <> -void Literal::Resize(int64 num_elements, complex64 value) { - CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); - mutable_c64s()->resize(num_elements, value); -} +namespace { template void CopyToRepeatedField(RepeatedFieldT* dest, - const std::vector& src) { + const tensorflow::gtl::ArraySlice src) { *dest = RepeatedFieldT(src.begin(), src.end()); } -template <> -void CopyToRepeatedField, complex64>( - tensorflow::protobuf::RepeatedField* dest, - const std::vector& src) { - *dest = tensorflow::protobuf::RepeatedField( - reinterpret_cast(src.data()), - reinterpret_cast(src.data()) + src.size() * 2); -} +} // namespace -LiteralProto Literal::ToProto() const { - LiteralProto proto; - proto.Clear(); - *proto.mutable_shape() = shape(); - switch (shape().element_type()) { +void Literal::Piece::WriteToProto(LiteralProto* proto) const { + *proto->mutable_shape() = subshape(); + switch (subshape().element_type()) { case PRED: - CopyToRepeatedField(proto.mutable_preds(), preds()); + CopyToRepeatedField(proto->mutable_preds(), data()); break; case U8: - *proto.mutable_u8s() = u8s_string(); - break; - case S32: - CopyToRepeatedField(proto.mutable_s32s(), s32s()); - break; - case S64: - CopyToRepeatedField(proto.mutable_s64s(), s64s()); + proto->set_u8s(static_cast(data().data()), + element_count()); break; case U32: - CopyToRepeatedField(proto.mutable_u32s(), u32s()); + CopyToRepeatedField(proto->mutable_u32s(), data()); break; case U64: - CopyToRepeatedField(proto.mutable_u64s(), u64s()); + CopyToRepeatedField(proto->mutable_u64s(), data()); + break; + case S32: + CopyToRepeatedField(proto->mutable_s32s(), data()); + break; + case S64: + CopyToRepeatedField(proto->mutable_s64s(), data()); break; case F16: - *proto.mutable_f16s() = - string(reinterpret_cast(f16s_.data()), - f16s_.size() * 
sizeof(half)); + *proto->mutable_f16s() = string( + reinterpret_cast(data().data()), size_bytes()); if (!kLittleEndian) { - ConvertEndianShort(const_cast(proto.mutable_f16s()->data()), - proto.f16s().size()); + ConvertEndianShort(const_cast(proto->mutable_f16s()->data()), + proto->f16s().size()); } break; case BF16: - *proto.mutable_bf16s() = - string(reinterpret_cast(bf16s_.data()), - bf16s_.size() * sizeof(bfloat16)); + *proto->mutable_bf16s() = string( + reinterpret_cast(data().data()), size_bytes()); if (!kLittleEndian) { - ConvertEndianShort(const_cast(proto.mutable_bf16s()->data()), - proto.bf16s().size()); + ConvertEndianShort(const_cast(proto->mutable_bf16s()->data()), + proto->bf16s().size()); } break; case F32: - CopyToRepeatedField(proto.mutable_f32s(), f32s()); + CopyToRepeatedField(proto->mutable_f32s(), data()); break; case F64: - CopyToRepeatedField(proto.mutable_f64s(), f64s()); + CopyToRepeatedField(proto->mutable_f64s(), data()); break; case C64: - CopyToRepeatedField(proto.mutable_c64s(), c64s()); - break; - case TUPLE: - for (const auto& tuple : tuple_literals()) { - *proto.add_tuple_literals() = tuple.ToProto(); + for (complex64 value : data()) { + proto->add_c64s(value.real()); + proto->add_c64s(value.imag()); } break; + case TUPLE: + // Nothing to do but assign the shape which is done above. 
+ return; default: - LOG(FATAL) << "Unhandled primitive type " << shape().element_type(); + LOG(FATAL) << "Unhandled primitive type " << subshape().element_type(); } - - return proto; } -template -void CopyFromRepeatedField(std::vector* dest, - const RepeatedFieldT& src) { - *dest = std::vector(src.begin(), src.end()); +const void* Literal::Piece::untyped_data() const { + CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); + return buffer(); } -template <> -void CopyFromRepeatedField, - complex64>( - std::vector* dest, - const tensorflow::protobuf::RepeatedField& src) { - *dest = std::vector( - reinterpret_cast(src.data()), - reinterpret_cast(src.data()) + src.size() / 2); +void* Literal::Piece::untyped_data() { + CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); + return buffer(); } -void Literal::CopyFromProto(const LiteralProto& literal_proto) { - if (!literal_proto.has_shape()) { - return; +namespace { + +template +Status CopyFromRepeatedField(tensorflow::gtl::MutableArraySlice dest, + const RepeatedFieldT& src) { + if (dest.size() != src.size()) { + return InvalidArgument( + "Expected %lu elements in LiteralProto repeated field, has %d", + dest.size(), src.size()); } + std::copy(src.begin(), src.end(), dest.begin()); + return Status::OK(); +} - *mutable_shape() = literal_proto.shape(); - switch (shape().element_type()) { +} // namespace + +Status Literal::Piece::CopyFromProto(const LiteralProto& proto) { + // These conditions should have been checked in Literal::CreateFromProto. 
+ TF_RET_CHECK(proto.has_shape()); + TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape())); + TF_RET_CHECK(ShapeUtil::Equal(proto.shape(), subshape())); + + switch (subshape().element_type()) { case PRED: - CopyFromRepeatedField(mutable_preds(), literal_proto.preds()); - break; - case U8: - set_u8s(literal_proto.u8s()); + TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.preds())); break; + case U8: { + auto u8_data = data(); + TF_RET_CHECK(proto.u8s().size() == u8_data.size()); + std::copy(proto.u8s().begin(), proto.u8s().end(), u8_data.begin()); + } break; case S32: - CopyFromRepeatedField(mutable_s32s(), literal_proto.s32s()); + TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.s32s())); break; case S64: - CopyFromRepeatedField(mutable_s64s(), literal_proto.s64s()); + TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.s64s())); break; case U32: - CopyFromRepeatedField(mutable_u32s(), literal_proto.u32s()); + TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.u32s())); break; case U64: - CopyFromRepeatedField(mutable_u64s(), literal_proto.u64s()); + TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.u64s())); break; case F16: { - const string& s(literal_proto.f16s()); - CHECK_EQ(0, s.size() % sizeof(half)); - f16s_ = std::vector(s.size() / sizeof(half)); - memcpy(f16s_.data(), s.data(), s.size()); - + const string& s(proto.f16s()); + TF_RET_CHECK(data().size() * sizeof(half) == s.size()); + memcpy(untyped_data(), s.data(), s.size()); if (!kLittleEndian) { - ConvertEndianShort(reinterpret_cast(f16s_.data()), s.size()); + ConvertEndianShort(reinterpret_cast(untyped_data()), s.size()); } - break; - } - case BF16: { - const string& s(literal_proto.bf16s()); - CHECK_EQ(0, s.size() % sizeof(bfloat16)); - bf16s_ = std::vector(s.size() / sizeof(bfloat16)); - memcpy(bf16s_.data(), s.data(), s.size()); + } break; + case BF16: { + const string& s(proto.bf16s()); + TF_RET_CHECK(data().size() * sizeof(bfloat16) == s.size()); + memcpy(untyped_data(), 
s.data(), s.size()); if (!kLittleEndian) { - ConvertEndianShort(reinterpret_cast(bf16s_.data()), s.size()); + ConvertEndianShort(reinterpret_cast(untyped_data()), s.size()); } - break; - } + } break; case F32: - CopyFromRepeatedField(mutable_f32s(), literal_proto.f32s()); + TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.f32s())); break; case F64: - CopyFromRepeatedField(mutable_f64s(), literal_proto.f64s()); - break; - case C64: - CopyFromRepeatedField(mutable_c64s(), literal_proto.c64s()); + TF_RETURN_IF_ERROR(CopyFromRepeatedField(data(), proto.f64s())); break; - case TUPLE: - for (const auto& proto : literal_proto.tuple_literals()) { - mutable_tuple_literals()->push_back(Literal(proto)); + case C64: { + auto complex_data = data(); + TF_RET_CHECK(proto.c64s_size() == complex_data.size() * 2); + for (int64 i = 0; i < complex_data.size(); ++i) { + complex_data[i] = complex64{proto.c64s(i * 2), proto.c64s(i * 2 + 1)}; } + } break; + case TUPLE: + LOG(FATAL) << "Should not be called on tuple shapes: " + << ShapeUtil::HumanString(subshape()); break; default: - LOG(FATAL) << "Unhandled primitive type " << shape().element_type(); + LOG(FATAL) << "Unhandled primitive type " << subshape().element_type(); + } + return Status::OK(); +} + +LiteralProto Literal::ToProto() const { + LiteralProto proto; + for (const auto& pair : pieces_) { + const ShapeIndex& index = pair.first; + const Piece& piece = pair.second; + + LiteralProto* proto_piece = &proto; + for (int64 i : index) { + while (proto_piece->tuple_literals_size() <= i) { + proto_piece->add_tuple_literals(); + } + proto_piece = proto_piece->mutable_tuple_literals(i); + } + piece.WriteToProto(proto_piece); + } + + if (LayoutUtil::IsSparseArray(shape())) { + CopyToRepeatedField(proto.mutable_sparse_indices(), + sparse_indices()->data()); + } + + return proto; +} + +/* static */ +StatusOr> Literal::CreateFromProto( + const LiteralProto& proto) { + if (!proto.has_shape()) { + return InvalidArgument("LiteralProto has 
no shape"); + } + if (!LayoutUtil::HasLayout(proto.shape())) { + return InvalidArgument("LiteralProto has no layout"); + } + + auto literal = MakeUnique(proto.shape()); + + for (auto& pair : literal->pieces_) { + const ShapeIndex& index = pair.first; + Piece& piece = pair.second; + const LiteralProto* proto_element = &proto; + for (int64 i : index) { + TF_RET_CHECK(i < proto_element->tuple_literals_size()); + proto_element = &proto_element->tuple_literals(i); + } + + if (ShapeUtil::IsTuple(piece.subshape())) { + if (proto_element->tuple_literals_size() != + ShapeUtil::TupleElementCount(piece.subshape())) { + return InvalidArgument( + "Expected %lld tuple elements in LiteralProto, has %d", + ShapeUtil::TupleElementCount(piece.subshape()), + proto_element->tuple_literals_size()); + } + continue; + } + + TF_RET_CHECK(ShapeUtil::IsArray(piece.subshape())); + TF_RETURN_IF_ERROR(piece.CopyFromProto(*proto_element)); } + return std::move(literal); } -const Literal& Literal::GetSubliteral(const ShapeIndex& index) const { - return const_cast(this)->GetSubliteral(index); +const void* Literal::untyped_data(const ShapeIndex& shape_index) const { + return piece(shape_index).untyped_data(); +} + +void* Literal::untyped_data(const ShapeIndex& shape_index) { + return piece(shape_index).untyped_data(); +} + +int64 Literal::size_bytes(const ShapeIndex& shape_index) const { + return piece(shape_index).size_bytes(); +} + +string Literal::GetR1U8AsString() const { + CHECK(ShapeUtil::IsArray(shape())); + CHECK_EQ(ShapeUtil::Rank(shape()), 1); + CHECK_EQ(shape().element_type(), U8); + return string(tensorflow::bit_cast(data().data()), + ShapeUtil::ElementsIn(shape())); +} + +/* static */ const LiteralView LiteralView::Create( + const Literal& literal, const ShapeIndex& view_root) { + return LiteralView(literal, view_root); +} + +LiteralView::LiteralView(const Literal& literal, const ShapeIndex& view_root) { + shape_ = ShapeUtil::GetSubshape(literal.shape(), view_root); + pieces_ = 
ShapeTree(shape_); + owns_buffers_ = false; + for (auto& pair : pieces_) { + const ShapeIndex& index = pair.first; + Piece& piece = pair.second; + + ShapeIndex src_index = view_root; + for (int64 i : index) { + src_index.push_back(i); + } + const Piece& src_piece = literal.piece(src_index); + piece.set_buffer(src_piece.buffer()); + piece.set_sparse_indices(src_piece.sparse_indices()); + piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); + } +} + +LiteralView::~LiteralView() {} + +LiteralView::LiteralView(const LiteralView& other) { CopyFrom(other); } + +LiteralView& LiteralView::operator=(const LiteralView& other) { + CopyFrom(other); + return *this; } -Literal& Literal::GetSubliteral(const ShapeIndex& index) { - Literal* subliteral = this; - for (int64 i : index) { - subliteral = &subliteral->tuple_literals_.at(i); +void LiteralView::CopyFrom(const LiteralView& other) { + // We can't use the default copy-constructor/copy-assignment because + // Piece::subshape_ points to subshapes within the Shape of the owning + // Literal/LiteralView. + shape_ = other.shape(); + pieces_ = other.pieces_; + for (auto& pair : pieces_) { + const ShapeIndex& index = pair.first; + Piece& piece = pair.second; + piece.set_subshape(&ShapeUtil::GetSubshape(shape_, index)); } - return *subliteral; + owns_buffers_ = false; } } // namespace xla diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index f37e529caf54e3aded1a418d1f01c1440cd0f284..d996004888ab521790b4c5a10da2a93f0d98d12f 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -34,7 +34,9 @@ limitations under the License. 
#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/shape_tree.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/sparse_index_array.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" @@ -50,152 +52,70 @@ limitations under the License. namespace xla { -// Utility class for dealing with XLA literal values. Most methods are -// templated by native (host) type which corresponds to a unique XLA -// PrimitiveType. See ComputationBuilder for details. Not all primitive types -// defined in xla_data.proto have a corresponding native type or even have a -// storage location in the Literal proto yet (for example, primitive type F16). +// Class representing literal values in XLA. +// +// TODO(b/67651157): The methods in this class should be reduced to a minimal +// set of methods which construct Literals and accessors methods. Other methods +// which perform computation on Literals (Reshape, Slice, etc) should be moved +// elsewhere, and perhaps combined with evaluator code which operates on +// Literals. class Literal { public: - Literal() {} + Literal() : Literal(ShapeUtil::MakeNil()) {} - Literal(const Literal& other) = default; - Literal(Literal&&) = default; + // Create a literal of the given shape. The literal is allocated sufficient + // memory to hold the shape. Memory is uninitialized. + explicit Literal(const Shape& shape); + virtual ~Literal(); - explicit Literal(const LiteralProto& other) { CopyFromProto(other); } - - Literal& operator=(const Literal& other) = default; - Literal& operator=(Literal&&) = default; + // Literals are moveable, but not copyable. To copy a literal use + // Literal::Clone or Literal::CloneToUnique. This prevents inadvertent copies + // of literals which can be expensive. 
+ Literal(const Literal& other) = delete; + Literal& operator=(const Literal& other) = delete; + Literal(Literal&& other); + Literal& operator=(Literal&& other); // Literals are equal if they have compatible shapes and the same data - // values. Layout is not checked. + // values. Layout is not compared. bool operator==(const Literal& other) const; bool operator!=(const Literal& other) const { return !(*this == other); } + // Serialize to and from a proto. + static StatusOr> CreateFromProto( + const LiteralProto& proto); LiteralProto ToProto() const; - bool has_shape() const { - return shape_.element_type() != PRIMITIVE_TYPE_INVALID; - } - - // Basic accessor functions. Names mirror the original protobuf - // functions for convenience. - string DebugString() const { return ToProto().DebugString(); } - string ShortDebugString() const { return ToProto().ShortDebugString(); } - - // Return the nested literal at the given shape index. - const Literal& GetSubliteral(const ShapeIndex& index) const; - Literal& GetSubliteral(const ShapeIndex& index); - - void Clear() { - shape_.Clear(); - u8s_.clear(); - s16s_.clear(); - s32s_.clear(); - s64s_.clear(); - u16s_.clear(); - u32s_.clear(); - u64s_.clear(); - f16s_.clear(); - f32s_.clear(); - f64s_.clear(); - tuple_literals_.clear(); - } - - int preds_size() const { return u8s().size(); } - const std::vector& preds() const { - static_assert(sizeof(uint8) == sizeof(bool), - "The uint8 and bool types should be the same size"); - return u8s_; - } - std::vector* mutable_preds() { - static_assert(sizeof(uint8) == sizeof(bool), - "The uint8 and bool types should be the same size"); - return &u8s_; - } - - int s16s_size() const { return s16s().size(); } - int32 s16s(int i) const { return s16s_[i]; } - const std::vector& s16s() const { return s16s_; } - std::vector* mutable_s16s() { return &s16s_; } - - int s32s_size() const { return s32s().size(); } - int32 s32s(int i) const { return s32s_[i]; } - const std::vector& s32s() const { 
return s32s_; } - std::vector* mutable_s32s() { return &s32s_; } - - int s64s_size() const { return s64s().size(); } - void add_s64s(int64 value) { s64s_.push_back(value); } - const std::vector& s64s() const { return s64s_; } - std::vector* mutable_s64s() { return &s64s_; } - - int u16s_size() const { return u16s().size(); } - uint32 u16s(int i) const { return u16s_[i]; } - const std::vector& u16s() const { return u16s_; } - std::vector* mutable_u16s() { return &u16s_; } - - int u32s_size() const { return u32s().size(); } - uint32 u32s(int i) const { return u32s_[i]; } - const std::vector& u32s() const { return u32s_; } - std::vector* mutable_u32s() { return &u32s_; } - - int u64s_size() const { return u64s().size(); } - const std::vector& u64s() const { return u64s_; } - std::vector* mutable_u64s() { return &u64s_; } - - int f16s_size() const { return f16s().size(); } - half f16s(int i) const { return f16s_[i]; } - const std::vector& f16s() const { return f16s_; } - std::vector* mutable_f16s() { return &f16s_; } - - int f32s_size() const { return f32s().size(); } - float f32s(int i) const { return f32s_[i]; } - void add_f32s(float value) { f32s_.push_back(value); } - const std::vector& f32s() const { return f32s_; } - std::vector& f32s() { return f32s_; } - std::vector* mutable_f32s() { return &f32s_; } - - int f64s_size() const { return f64s().size(); } - const std::vector& f64s() const { return f64s_; } - std::vector* mutable_f64s() { return &f64s_; } - - int c64s_size() const { return c64s().size(); } - const std::vector& c64s() const { return c64s_; } - std::vector* mutable_c64s() { return &c64s_; } - - int bf16s_size() const { return bf16s().size(); } - bfloat16 bf16s(int i) const { return bf16s_[i]; } - const std::vector& bf16s() const { return bf16s_; } - std::vector* mutable_bf16s() { return &bf16s_; } - - int tuple_literals_size() const { return tuple_literals().size(); } - const Literal& tuple_literals(int i) const { return tuple_literals_[i]; } - 
Literal* add_tuple_literals() { - tuple_literals_.push_back(Literal()); - return &tuple_literals_.back(); - } - std::vector* mutable_tuple_literals() { return &tuple_literals_; } - const std::vector& tuple_literals() const { return tuple_literals_; } - - int u8s_size() const { return u8s().size(); } - const std::vector& u8s() const { return u8s_; } - void set_u8s(const std::vector& value) { u8s_ = value; } - void set_u8s(tensorflow::StringPiece value) { - u8s_ = std::vector(value.size()); - u8s_.clear(); - append_u8s(value); - } - - void append_u8s(tensorflow::StringPiece value) { - u8s_.insert(u8s_.end(), value.begin(), value.end()); - } - - string u8s_string() const { return string(u8s().begin(), u8s().end()); } + // Return the shape of the literal. + const Shape& shape() const { return shape_; } - std::vector* mutable_u8s() { return &u8s_; } + // TODO(b/67651157): Remove this accessor. Literal users should not be able to + // mutate the shape as this can produce malformed Literals. + Shape* mutable_shape_do_not_use() { return &shape_; } - const Shape& shape() const { return shape_; } - Shape* mutable_shape() { return &shape_; } + // Returns a (Mutable)ArraySlice view of the array for this literal for the + // given NativeT (e.g., float). CHECKs if the subshape of the literal at the + // given ShapeIndex is not array. See primitive_util.h for the mapping from + // XLA type to native type. + template + tensorflow::gtl::ArraySlice data( + const ShapeIndex& shape_index = {}) const; + template + tensorflow::gtl::MutableArraySlice data( + const ShapeIndex& shape_index = {}); + + // Returns a pointer to the sparse index array. Returns nullptr if the literal + // is not a sparse array. + const SparseIndexArray* sparse_indices( + const ShapeIndex& shape_index = {}) const; + SparseIndexArray* sparse_indices(const ShapeIndex& shape_index = {}); + + // Returns a pointer to (or size of) the underlying buffer holding the array + // at the given shape index. 
CHECKs if the subshape of the literal at the + // given ShapeIndex is not array. + const void* untyped_data(const ShapeIndex& shape_index = {}) const; + void* untyped_data(const ShapeIndex& shape_index = {}); + int64 size_bytes(const ShapeIndex& shape_index = {}) const; // Creates a new literal of a given rank. To minimize ambiguity (for users // and the compiler) these CreateR[0-2] methods should explicitly specify the @@ -243,6 +163,60 @@ class Literal { values, const Layout& layout); + // Returns this literal's data as a string. This literal must be a rank-1 U8 + // array. + string GetR1U8AsString() const; + + // Creates a literal with a sparse layout and the given indices and values. + // The shape is initialized from the given dimensions. The minor dimension of + // the indices array must equal the rank of the shape (i.e. size of the + // dimensions array). The major dimension of the indices array must equal the + // number of elements in the values array. The maximum number of elements in + // the array is taken from the max_indices() value of the index array. + // + // XLA assumes that sparse literals are in sorted order for all operations. If + // the `sort` argument is true, then the indices and values will be sorted + // while copying them into the literal. If you have ensured that the indices + // and values are already sorted, then you may set the `sort` argument to + // false to skip the sorting step. 
+ // + // For example: + // + // CreateSparse( + // {12, 12, 12}, + // SparseIndexArray(10, 3, + // Array2D{ + // {0, 1, 2}, + // {3, 4, 5}, + // {6, 7, 8}, + // {9, 10, 11}, + // }), + // {1.0, 2.0 3.0, 4.0}) + // + // This creates an array with shape F64[12,12,12]sparse{10}, that has the + // following non-zero values: + // + // [0, 1, 2]: 1.0 + // [3, 4, 5]: 2.0 + // [6, 7, 8]: 3.0 + // [9, 10, 11]: 4.0 + // + template + static std::unique_ptr CreateSparse( + tensorflow::gtl::ArraySlice dimensions, SparseIndexArray indices, + tensorflow::gtl::ArraySlice values, bool sort = true); + + // Populates a literal with a sparse layout with the given indices and values. + // Each index in the indices array is CHECKed against the dimensions in the + // literal's shape. If sort is true, then the indices and values will be + // sorted. If sort is false, then the indices and values are assumed to + // already be in sorted order. See CreateSparse for an example of how data + // are populated. + template + void PopulateSparse(SparseIndexArray indices, + tensorflow::gtl::ArraySlice values, + bool sort = true); + // Creates a new Literal object with the shape specified as parameter. // The content of the literal values is the default value of the primitive // type of literal itself (0 for numeric types, and false for predicates). @@ -256,6 +230,23 @@ class Literal { PrimitiveType primitive_type, tensorflow::gtl::ArraySlice dimensions); + // Copy values from 'src_literal' rooted at 'src_shape_index' into this + // literal rooted at 'dest_shape_index'. The subshape of this literal rooted + // at 'dest_shape_index' must be compatible with the subshape of 'src_literal' + // rooted at 'src_shape_index', but need not be arrays. + Status CopyFrom(const Literal& src_literal, + const ShapeIndex& dest_shape_index = {}, + const ShapeIndex& src_shape_index = {}); + + // Similar to CopyFrom, but with move semantincs. 
The subshape of this literal + // rooted at 'dest_shape_index' must be *equal* to the shape 'src_literal' + // (layouts and shapes must match), but need not be arrays. The memory + // allocated in this literal for the subshape at dest_shape_index is + // deallocated, and the respective buffers are replaced with those in + // src_literal. Upon return, src_literal is set to a nil shape (empty tuple). + Status MoveFrom(Literal&& src_literal, + const ShapeIndex& dest_shape_index = {}); + // Copies the values from src_literal, starting at src_base shape indexes, // to this literal, starting at dest_base, where the copy size in each // dimension is specified by copy_size. @@ -265,10 +256,24 @@ class Literal { // Note: if either src_literal or this literal contains dimensions with zero // element, then copy_size must be 0 in these dimensions while the // corresponding base indices being 0. - Status Copy(const Literal& src_literal, - tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size); + // This literal and 'src_literal' must be arrays. + Status CopySliceFrom(const Literal& src_literal, + tensorflow::gtl::ArraySlice src_base, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size); + + // Returns a vector containing the tuple elements of this Literal as separate + // Literals. This Literal must be tuple-shaped and can be a nested tuple. The + // elements are moved into the new Literals; no data is copied. Upon return + // this Literal is set to a nil shape (empty tuple) + std::vector DecomposeTuple(); + + // This operation is the inverse of DecomposeTuple. The given elements are + // moved into the tuple elements of a new tuple-shaped Literal which is + // returned. Upon return, each of the Literals in 'elements' is set to a nil + // shape (empty tuple). 
+ static Literal MoveIntoTuple( + tensorflow::gtl::MutableArraySlice elements); // Creates a new value that has the equivalent value as this literal, but // conforms to new_layout; e.g. a literal matrix that was in {0, 1} @@ -285,11 +290,16 @@ class Literal { std::unique_ptr Relayout(const Layout& new_layout, const ShapeIndex& shape_index = {}) const; - // Creates a new literal by reshaping this literal to have 'shape'. Both the - // original shape and 'shape' must contain the same number of elements. The + // An overload of Relayout which changes the layout of the entire shape rather + // than being limited to a single array within the shape. + std::unique_ptr Relayout(const Shape& shape_with_layout) const; + + // Creates a new literal by reshaping this literal to have the given + // dimensions. The total number of elements must not change; The // implementation currently only supports monotonic dim0-major layouts. + // This literal must be an array. StatusOr> Reshape( - tensorflow::gtl::ArraySlice shape) const; + tensorflow::gtl::ArraySlice dimensions) const; // Creates a new literal by reordering the dimensions of this literal. // The given `permutation` must be a permutation of the dimension numbers @@ -297,6 +307,7 @@ class Literal { // in the result literal (i.e., new_order[i] = old_order[permutation[i]]). // For example, a transpose call on a literal of shape [3 x 8 x 4] and // `permutation` = {2, 0, 1} returns a new literal of shape [4 x 3 x 8]. + // This literal must be an array. std::unique_ptr Transpose( tensorflow::gtl::ArraySlice permutation) const; @@ -305,6 +316,7 @@ class Literal { // same rank and layout as for the given literal. The number of indices in // start_indices and limit_indices must be the rank of the literal, and the // indices follow the order of the dimensions. + // This literal must be an array. 
std::unique_ptr Slice( tensorflow::gtl::ArraySlice start_indices, tensorflow::gtl::ArraySlice limit_indices) const; @@ -312,34 +324,35 @@ class Literal { // Creates a literal with a prepended dimension with bound "times"; e.g. a // f32[3x2] with times=4 will produce a f32[4x3x2] with the 3x2 from this // literal replicated four times. + // This literal must be an array. template std::unique_ptr Replicate(int64 times) const; // Converts this literal to another primitive type. Returns an error if the - // conversion is not possible. + // conversion is not possible. This literal must be array-shaped. StatusOr> Convert( PrimitiveType primitive_dest_type) const; - // Creates a literal value zero of the given primitive type. + // Creates a scalar literal value zero of the given primitive type. static Literal Zero(PrimitiveType primitive_type); - // Creates a literal value one of the given primitive type. + // Creates a scalar literal value one of the given primitive type. static Literal One(PrimitiveType primitive_type); - // Creates a literal value containing the minimum value of the given + // Creates a scalar literal value containing the minimum value of the given // primitive type. For floating-point types, returns -inf. static Literal MinValue(PrimitiveType primitive_type); - // Creates a literal value containing the maximum value of the given + // Creates a scalar literal value containing the maximum value of the given // primitive type. For floating-point types, returns inf. static Literal MaxValue(PrimitiveType primitive_type); // Creates a literal of the given shape where each element is `value`. template - static std::unique_ptr CreateFullWithMonotonicDim0MajorLayout( + static std::unique_ptr CreateFullWithDescendingLayout( tensorflow::gtl::ArraySlice dimensions, NativeT value); - // Creates a new literal from an array. The variants not ending with + // Creates a new literal from an Array type. 
The variants not ending with // WithLayout use the default XLA layout for the literal's linear // representation in memory. template @@ -388,35 +401,50 @@ class Literal { std::initializer_list> values, int64 projection_p, int64 projection_z); - // Clones this literal into an owned unique_ptr version. + // Clones this literal into a new Literal, or new std::unique_ptr. + Literal Clone() const; std::unique_ptr CloneToUnique() const; - // Returns the linear index of the given index within this literal's - // element_type repeated field. - int64 LinearIndex(tensorflow::gtl::ArraySlice multi_index) const; + // Gets or sets an element in the literal at the given index. The multi_index + // is CHECKed against the dimension sizes. + template + NativeT Get(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index) const; + template + void Set(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index, NativeT value); - // Gets or sets an element in the literal at the given index. The index is - // CHECKed against the dimension sizes. + // Overloads of Get and Set for array literals. CHECKs if the literal is not + // array-shaped and dense. template NativeT Get(tensorflow::gtl::ArraySlice multi_index) const; template void Set(tensorflow::gtl::ArraySlice multi_index, NativeT value); - // Returns a (Mutable)ArraySlice view of the array for this literal for the - // given NativeT (e.g., float). These functions map native type to XLA - // PrimitiveType via template specialization. The unspecialized forms below - // aborts to handle the error case where the given native type does not map to - // an XLA primitive type. + // Returns the multi-index of the element in a sparse literal at the given + // sparse element number. The sparse element number is the position with in + // the sparse array's list of (index, value) pairs, and is checked against the + // total number of (index, value) pairs in the sparse array. 
+ tensorflow::gtl::ArraySlice GetSparseIndex( + int64 sparse_element_number, const ShapeIndex& shape_index = {}) const; + + // Returns the value of the element in a sparse literal at the given sparse + // element number. The sparse element number is the position with in the + // sparse array's list of (index, value) pairs, and is checked against the + // total number of (index, value) pairs in the sparse array. template - tensorflow::gtl::ArraySlice GetArraySlice() const { - static_assert(!std::is_same::value, - "Cannot map native type to primitive type."); - } + NativeT GetSparseElement(int64 sparse_element_number, + const ShapeIndex& shape_index = {}) const; + + // Appends the given element to the literal. If the elements are not appended + // in sorted order, then SortSparseElements should be called before calling + // other methods. This literal must have a sparse layout. template - tensorflow::gtl::MutableArraySlice GetMutableArraySlice() { - static_assert(!std::is_same::value, - "Cannot map native type to primitive type."); - } + void AppendSparseElement(tensorflow::gtl::ArraySlice multi_index, + NativeT value, const ShapeIndex& shape_index = {}); + + // Sorts the elements in a sparse array. + void SortSparseElements(const ShapeIndex& shape_index = {}); // Returns the element value at index (0, ..., 0), however many zeroes are // required for that index. @@ -425,10 +453,16 @@ class Literal { // As Get(), but determines the correct type and converts the value // into text. - string GetAsString(tensorflow::gtl::ArraySlice multi_index) const; + string GetAsString(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index = {}) const; + + // As GetSparseElement(), but determines the correct type and converts the + // value into text. + string GetSparseElementAsString(int64 sparse_element_number, + const ShapeIndex& shape_index = {}) const; // As Get(), but determines the correct type and converts the value into - // int64. + // int64. 
This literal must be an array. StatusOr GetIntegralAsS64( tensorflow::gtl::ArraySlice multi_index) const; @@ -436,7 +470,8 @@ class Literal { template static std::unique_ptr MakeIdentityR2(int64 size); - // Returns a tuple literal composed of given literals. + // Returns a tuple literal composed of given literals. Data is copied from the + // given elements into the returned literal. static std::unique_ptr MakeTuple( tensorflow::gtl::ArraySlice elements); @@ -450,11 +485,29 @@ class Literal { static std::unique_ptr MakeTupleOwned( std::vector> elements); - // Validates that the data payload of the literal matches the literal shape; - // if it does not, an appropriate status is returned. - tensorflow::Status ValidateLiteral() const; + // This overload lets you pass a braced list of unique_ptrs to + // MakeTupleOwned: + // + // Literal::MakeTupleOwned(Literal::CreateR1(...), ...). + // + // Simply relying on the MakeTupleOwned(std::vector>) + // overload doesn't work because std::initializer_list's elements are always + // const. + // + // The arguments to this function must all be unique_ptr. + template + static std::unique_ptr MakeTupleOwned( + std::unique_ptr... elements) { + std::array, sizeof...(Ts)> arr{ + std::move(elements)...}; + std::vector> v; + v.insert(v.begin(), std::make_move_iterator(arr.begin()), + std::make_move_iterator(arr.end())); + return MakeTupleOwned(std::move(v)); + } // Returns a string representation of the literal value. + // Warning: this function can take minutes for multi-million element Literals. string ToString(bool print_layout = false) const; // Invokes the "per cell" callback for each element in the provided @@ -464,6 +517,8 @@ class Literal { // This function is useful if you want a polymorphic representation // of the tensor's elements (turning it to a string for something // like representation in a protobuf). + // + // This literal must have a dense layout. 
void EachCellAsString( const std::function indices, const string& value)>& per_cell) const; @@ -472,80 +527,45 @@ class Literal { NativeT value)> per_cell) const; - // Templated methods which populate the given repeated field in this literal - // with the given value(s). The Shape field of this literal is set - // to match the array dimensions and type. Examples: + // Populate this literal with the given values. Examples: // // // Populate with floats. // Array2D float_values = ... // literal.PopulateR2FromArray2D(values); // // // Populate with int32s. - // literal.PopulateR2({{1, 2}, {3, 4}}); + // literal.PopulateR2({{1, 2}, {3, 4}}); // - template - void PopulateR0(NativeT values); + // The shape and element type of this literal must match given values. For + // example, in the call above to literal.PopulateR2(), 'literal' must be a 2x2 + // array of S32. template void PopulateR1(tensorflow::gtl::ArraySlice values); void PopulateR1(const tensorflow::core::Bitmap& values); template void PopulateR2(std::initializer_list> values); template - void PopulateR2WithLayout( - std::initializer_list> values, - const Layout& layout); - template void PopulateFromArray(const Array& values); template - void PopulateFromArrayWithLayout(const Array& values, - const Layout& layout); - template void PopulateR2FromArray2D(const Array2D& values); template - void PopulateR2FromArray2DWithLayout(const Array2D& values, - const Layout& layout); - template void PopulateR3FromArray3D(const Array3D& values); template - void PopulateR3FromArray3DWithLayout(const Array3D& values, - const Layout& layout); - template void PopulateR4FromArray4D(const Array4D& values); - template - void PopulateR4FromArray4DWithLayout(const Array4D& values, - const Layout& layout); // Populates literal values by calling the generator function for every cell // in this literal object. // // generator must be a callable of the type // NativeT(tensorflow::gtl::ArraySlice indexes) or compatible. 
+ // + // This literal must have a dense layout. template Status Populate(const FnType& generator); - // Creates a Literal of the given dimensions with all elements set to the - // given value. - template - void PopulateWithValue(NativeT value, - tensorflow::gtl::ArraySlice dimensions); - - // Returns a pointer to the underlying vector corresponding to the Literal's - // shape. - const void* InternalData() const; - void* MutableInternalData(); - - // Allocates space in the underlying vector of this literal sufficient to hold - // num_elements of this literal's primitive type. Values in the vector are set - // to zero. num_elements must equal the number of elements in the literal's - // shape. - void Reserve(int64 num_elements); - - // Allocates space in the underlying vector of this literal sufficient to hold - // num_elements of this literal's primitive type and sets each element in this - // literal to the given value. num_elements must equal the number of elements - // in this literal's shape. + // Fills this literal with the given value. template - void Resize(int64 num_elements, NativeT value); + void PopulateWithValue(NativeT value); // Returns whether every element in this literal is equal to value. // @@ -555,7 +575,7 @@ class Literal { // // If value doesn't fit in this literal's type, returns false. Values of 1/0 // are considered equal to true/false; other values are not considered equal - // to true. + // to true. Also if this literal is not array-shaped false is returned. bool IsAll(int8 value) const; // Like IsAll(const Literal&, int8), except we check whether the literal is @@ -566,7 +586,7 @@ class Literal { // This casts value to the type of literal, then compares using ==. The usual // admonishments about floating-point equality checks apply. We expect you to // use this to check for values that can be expressed precisely as a float, - // e.g. -0.5. + // e.g. -0.5. Also if this literal is not array-shaped false is returned. 
bool IsAllFloat(float value) const; // Like IsAll(const Literal&, int8), except we check whether the literal is @@ -578,23 +598,38 @@ class Literal { // admonishments about floating-point equality checks apply. We expect you to // use this to check for complex values that can be expressed precisely as // float pairs e.g. (-0.5, 1.0). + // + // This literal must have a dense layout. bool IsAllComplex(complex64 value) const; // Returns whether this literal is zero at the specified index. This literal - // must be an array. + // must be an array with a dense layout. bool IsZero(tensorflow::gtl::ArraySlice indices) const; - private: - // Copy from a LiteralProto instance. - void CopyFromProto(const LiteralProto& literal_proto); + // Return the count of the elements in the array at the given shape index in + // this literal. + int64 element_count(const ShapeIndex& index = {}) const { + return ShapeUtil::ElementsIn(ShapeUtil::GetSubshape(shape(), index)); + } + + // Return the count of the elements in the sparse array at the given shape + // index in this literal, which will be no larger than + // LayoutUtil::MaxSparseElements(SetSubshape(shape(), index).layout()). + int64 sparse_element_count() const; + + protected: + // 'allocate_arrays' indicates whether to allocate memory for the arrays in + // the shape. If false, buffer pointers inside of the Literal::Pieces are set + // to nullptr. + Literal(const Shape& shape, bool allocate_arrays); - // Internal template helper for the Copy() API, matching its arguments one by - // one. - template - Status CopyRange(const Literal& src_literal, - tensorflow::gtl::ArraySlice src_base, - tensorflow::gtl::ArraySlice dest_base, - tensorflow::gtl::ArraySlice copy_size); + // Internal template helper for the Literal::CopySliceFrom(), matching its + // arguments one by one. 
+ template + Status CopySliceFromInternal(const Literal& src_literal, + tensorflow::gtl::ArraySlice src_base, + tensorflow::gtl::ArraySlice dest_base, + tensorflow::gtl::ArraySlice copy_size); // Utility structure which is used to create the optimal configuration for // a ShapeUtil::ForEachIndex() scan across two literals. @@ -619,163 +654,243 @@ class Literal { int64 minor_loop_size = 1; }; - Shape shape_; - std::vector u8s_; - std::vector s16s_; - std::vector s32s_; - std::vector s64s_; - std::vector u16s_; - std::vector u32s_; - std::vector u64s_; - std::vector bf16s_; - std::vector f16s_; - std::vector f32s_; - std::vector f64s_; - std::vector c64s_; - std::vector tuple_literals_; -}; - -std::ostream& operator<<(std::ostream& out, const Literal& literal); - -// Declarations of template specializations for GetArraySlice and -// GetMutableArraySlice. The specializations map native type to XLA primitive -// type. -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; - -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; - -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; - -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; - -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; + // A data structure representing a subshape at a particular ShapeIndex within + // the literal. For array-shaped ShapeIndexes, this data structure holds the + // pointer to the memory allocated for the array data. + class Piece { + public: + // Return the buffer holding the array data for this piece as an array + // slice. This piece must be array-shaped. + template + tensorflow::gtl::ArraySlice data() const; + template + tensorflow::gtl::MutableArraySlice data(); + + // Return the buffer holding the array data for this piece as a void*. This + // piece must be array-shaped. 
+ void* untyped_data(); + const void* untyped_data() const; + + // Gets or sets an element in the array at the given index. The multi_index + // is CHECKed against the dimension sizes of the array. This piece must be + // array-shaped. + template + NativeT Get(tensorflow::gtl::ArraySlice index) const; + template + void Set(tensorflow::gtl::ArraySlice index, NativeT value); + + // Gets/sets the buffer holding the array data. + char* buffer() const { return buffer_; } + void set_buffer(char* buffer) { buffer_ = buffer; } + + // The array of multi-indices that provide the locations of non-zero + // elements in a sparse array. Only used if + // LayoutUtil::IsSparseArray(shape()) is true. + SparseIndexArray* sparse_indices() const { return sparse_indices_; } + void set_sparse_indices(SparseIndexArray* sparse_indices) { + sparse_indices_ = sparse_indices; + } -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; + // Gets or sets the subshape of this piece. This reference points to a + // subshape within the shape in the containing Literal (Literal::shape_). + const Shape& subshape() const { return *subshape_; } + void set_subshape(const Shape* subshape) { subshape_ = subshape; } -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; + // Returns the size in bytes of the buffer holding the array data. + int64 size_bytes() const { return ShapeUtil::ByteSizeOf(subshape()); } -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; + // Returns the number of elements in this piece's array. + int64 element_count() const { return ShapeUtil::ElementsIn(subshape()); } -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; + // Copy the data from 'src' into this piece's buffer. Shapes of this piece + // and src must be compatible. 
+ Status CopyFrom(const Piece& src); -template <> -inline tensorflow::gtl::ArraySlice Literal::GetArraySlice() - const { - DCHECK(shape().element_type() == F32); - return f32s(); -} + // Returns true if this piece and 'other' contain the same data. This piece + // and 'other' must be array-shaped and compatible. + bool EqualElements(const Piece& other) const; -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; + // Writes the shape and data (if array-shaped) into the given proto. + void WriteToProto(LiteralProto* proto) const; -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; + // Copies the data from the given proto into this piece. The shape of this + // piece must be equal (not just compatible) to the shape of the proto. + Status CopyFromProto(const LiteralProto& proto); -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; + // Sorts the elements in a sparse array. + void SortSparseElements(); -template <> -tensorflow::gtl::ArraySlice Literal::GetArraySlice() - const; + private: + // Recursive helper for EqualElements. + template + bool EqualElementsInternal(const Piece& other, + std::vector* multi_index) const; -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); + // Helper for SortSparseElements that has the element type as a template + // parameter. + template + void SortSparseElementsInternal(); -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); + // For array-shaped pieces, this is the buffer holding the literal data. + char* buffer_ = nullptr; -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); + // For sparse arrays, this is the array of indices. + SparseIndexArray* sparse_indices_ = nullptr; -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); + // The shape of piece. This points into the shape of the containing Literal + // (Literal::shape_). 
+ const Shape* subshape_ = nullptr; + }; -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); + // Returns the piece at the given ShapeIndex. + Piece& piece(const ShapeIndex& shape_index) { + return *pieces_.mutable_element(shape_index); + } + const Piece& piece(const ShapeIndex& shape_index) const { + return pieces_.element(shape_index); + } -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); + // Returns the piece at the root of the shape (empty ShapeIndex). + Piece& root_piece() { return piece({}); } + const Piece& root_piece() const { return piece({}); } -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); + // Deallocate the buffers held by this literal (if the literal owns the + // buffer). + void DeallocateBuffers(); -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); + Shape shape_; + ShapeTree pieces_; -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); + // Whether the buffers held in pieces_ are owned by this Literal. + bool owns_buffers_; -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); + // LiteralView must access and manipulate Pieces of other Literals. + friend class LiteralView; +}; // namespace xla -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); +std::ostream& operator<<(std::ostream& out, const Literal& literal); -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); +// A read-only view of a Literal. A LiteralView contains pointers to buffers +// owned by the viewed Literal. +// +// TODO(b/71550060): Replace LiteralView with Literal slice classes (immutable +// and mutable) similar to (Mutable)ArraySlice. +class LiteralView : public Literal { + public: + // Create and return a view of the given literal rooted at the given shape + // index within the given literal. 
A factory is used rather than a public + // constructor because only const LiteralViews are supported. It's still + // possible to create non-const LiteralViews via the copy constructors, but + // the factory method makes it a bit less likely. Implementing literal slices + // will fix this undesirable situation (b/71550060). + static const LiteralView Create(const Literal& literal, + const ShapeIndex& view_root = {}); -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); + LiteralView(const LiteralView& other); + LiteralView& operator=(const LiteralView& other); -template <> -tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); + virtual ~LiteralView(); -template <> -void Literal::Resize(int64 num_elements, bool value); + private: + LiteralView(const Literal& literal, const ShapeIndex& view_root); -template <> -void Literal::Resize(int64 num_elements, int8 value); + // Helper for the copy constructor and copy assignment operator. + void CopyFrom(const LiteralView& other); +}; -template <> -void Literal::Resize(int64 num_elements, uint8 value); +template +tensorflow::gtl::ArraySlice Literal::Piece::data() const { + CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); + CHECK_EQ(subshape().element_type(), + primitive_util::NativeToPrimitiveType()) + << "Attempting to access " + << PrimitiveType_Name(primitive_util::NativeToPrimitiveType()) + << " type, but literal element type is " + << PrimitiveType_Name(subshape().element_type()); + return tensorflow::gtl::ArraySlice( + reinterpret_cast(buffer()), + ShapeUtil::ElementsIn(subshape())); +} -template <> -void Literal::Resize(int64 num_elements, int32 value); +template +tensorflow::gtl::MutableArraySlice Literal::Piece::data() { + CHECK(ShapeUtil::IsArray(subshape())) << ShapeUtil::HumanString(subshape()); + CHECK_EQ(subshape().element_type(), + primitive_util::NativeToPrimitiveType()) + << "Attempting to access " + << 
PrimitiveType_Name(primitive_util::NativeToPrimitiveType()) + << " type, but literal element type is " + << PrimitiveType_Name(subshape().element_type()); + return tensorflow::gtl::MutableArraySlice( + reinterpret_cast(buffer()), ShapeUtil::ElementsIn(subshape())); +} -template <> -void Literal::Resize(int64 num_elements, uint32 value); +template +NativeT Literal::Piece::Get( + tensorflow::gtl::ArraySlice multi_index) const { + CHECK(LayoutUtil::IsDenseArray(subshape())); + return data()[IndexUtil::MultidimensionalIndexToLinearIndex( + subshape(), multi_index)]; +} -template <> -void Literal::Resize(int64 num_elements, int64 value); +template +void Literal::Piece::Set(tensorflow::gtl::ArraySlice multi_index, + NativeT value) { + CHECK(LayoutUtil::IsDenseArray(subshape())); + data()[IndexUtil::MultidimensionalIndexToLinearIndex( + subshape(), multi_index)] = value; +} -template <> -void Literal::Resize(int64 num_elements, uint64 value); +template +tensorflow::gtl::ArraySlice Literal::data( + const ShapeIndex& shape_index) const { + return piece(shape_index).data(); +} -template <> -void Literal::Resize(int64 num_elements, float value); +template +tensorflow::gtl::MutableArraySlice Literal::data( + const ShapeIndex& shape_index) { + return piece(shape_index).data(); +} -template <> -void Literal::Resize(int64 num_elements, double value); +template +inline NativeT Literal::Get(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index) const { + return piece(shape_index).Get(multi_index); +} -template <> -void Literal::Resize(int64 num_elements, half value); +template +inline NativeT Literal::Get( + tensorflow::gtl::ArraySlice multi_index) const { + return root_piece().Get(multi_index); +} -template <> -void Literal::Resize(int64 num_elements, bfloat16 value); +template +inline void Literal::Set(tensorflow::gtl::ArraySlice multi_index, + const ShapeIndex& shape_index, NativeT value) { + return piece(shape_index).Set(multi_index, value); +} -template <> 
-void Literal::Resize(int64 num_elements, complex64 value); +template +inline void Literal::Set(tensorflow::gtl::ArraySlice multi_index, + NativeT value) { + return root_piece().Set(multi_index, value); +} template /* static */ std::unique_ptr Literal::CreateR0(NativeT value) { - auto literal = MakeUnique(); - literal->PopulateR0(value); + auto literal = MakeUnique(ShapeUtil::MakeShape( + primitive_util::NativeToPrimitiveType(), {})); + literal->Set({}, value); return literal; } template /* static */ std::unique_ptr Literal::CreateR1( tensorflow::gtl::ArraySlice values) { - auto literal = MakeUnique(); + auto literal = MakeUnique( + ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), + {static_cast(values.size())})); literal->PopulateR1(values); return literal; } @@ -784,8 +899,12 @@ template /* static */ std::unique_ptr Literal::CreateR2WithLayout( std::initializer_list> values, const Layout& layout) { - auto literal = MakeUnique(); - literal->PopulateR2WithLayout(values, layout); + auto literal = MakeUnique(ShapeUtil::MakeShapeWithLayout( + primitive_util::NativeToPrimitiveType(), + {static_cast(values.size()), + static_cast(values.begin()->size())}, + AsInt64Slice(layout.minor_to_major()))); + literal->PopulateR2(values); return literal; } @@ -858,6 +977,21 @@ template return CreateR4FromArray4DWithLayout(tmp, layout); } +template +/* static */ std::unique_ptr Literal::CreateSparse( + tensorflow::gtl::ArraySlice dimensions, SparseIndexArray indices, + tensorflow::gtl::ArraySlice values, bool sort) { + int64 num_elements = values.size(); + int64 rank = dimensions.size(); + CHECK_EQ(num_elements, indices.index_count()); + CHECK_EQ(rank, indices.rank()); + auto literal = MakeUnique(ShapeUtil::MakeShapeWithSparseLayout( + primitive_util::NativeToPrimitiveType(), dimensions, + indices.max_indices())); + literal->PopulateSparse(indices, values, sort); + return literal; +} + template /* static */ std::unique_ptr Literal::CreateR4( std::initializer_list 
template /* static */ std::unique_ptr Literal::CreateFromArrayWithLayout( const Array& values, const Layout& layout) { - auto literal = MakeUnique(); - literal->PopulateFromArrayWithLayout(values, layout); + auto literal = MakeUnique(ShapeUtil::MakeShapeWithLayout( + primitive_util::NativeToPrimitiveType(), values.dimensions(), + AsInt64Slice(layout.minor_to_major()))); + literal->PopulateFromArray(values); return literal; } @@ -970,81 +1106,33 @@ template return CreateFromArrayWithLayout(values, layout); } -template -NativeT Literal::Get(tensorflow::gtl::ArraySlice multi_index) const { - int64 linear_index = LinearIndex(multi_index); - return GetArraySlice().at(linear_index); -} - template NativeT Literal::GetFirstElement() const { - return GetArraySlice().at(0); -} - -template <> -inline uint8 Literal::Get( - tensorflow::gtl::ArraySlice multi_index) const { - CHECK(shape().element_type() == U8); - int64 linear_index = LinearIndex(multi_index); - return u8s()[linear_index]; -} - -template <> -inline int8 Literal::Get( - tensorflow::gtl::ArraySlice multi_index) const { - CHECK(shape().element_type() == S8); - int64 linear_index = LinearIndex(multi_index); - return u8s()[linear_index]; -} - -template <> -inline half Literal::Get( - tensorflow::gtl::ArraySlice multi_index) const { - CHECK(shape().element_type() == F16); - int64 linear_index = LinearIndex(multi_index); - return GetArraySlice()[linear_index]; -} - -template <> -inline bfloat16 Literal::Get( - tensorflow::gtl::ArraySlice multi_index) const { - CHECK(shape().element_type() == BF16); - int64 linear_index = LinearIndex(multi_index); - return GetArraySlice()[linear_index]; + return data().at(0); } template -void Literal::Set(tensorflow::gtl::ArraySlice multi_index, - NativeT value) { - int64 linear_index = LinearIndex(multi_index); - GetMutableArraySlice().at(linear_index) = value; +NativeT Literal::GetSparseElement(int64 sparse_element_number, + const ShapeIndex& shape_index) const { + CHECK( + 
LayoutUtil::IsSparseArray(ShapeUtil::GetSubshape(shape(), shape_index))); + return data(shape_index)[sparse_element_number]; } -template <> -inline void Literal::Set(tensorflow::gtl::ArraySlice multi_index, - uint8 value) { - int64 linear_index = LinearIndex(multi_index); - (*mutable_u8s())[linear_index] = value; -} - -template <> -inline void Literal::Set(tensorflow::gtl::ArraySlice multi_index, - int8 value) { - return Set(multi_index, value); -} - -template <> -inline void Literal::Set(tensorflow::gtl::ArraySlice multi_index, - int64 value) { - int64 linear_index = LinearIndex(multi_index); - (*mutable_s64s())[linear_index] = value; -} - -template <> -/* static */ inline void Literal::Set( - tensorflow::gtl::ArraySlice multi_index, uint64 value) { - int64 linear_index = LinearIndex(multi_index); - (*mutable_u64s())[linear_index] = value; +template +void Literal::AppendSparseElement( + tensorflow::gtl::ArraySlice multi_index, NativeT value, + const ShapeIndex& shape_index) { + Piece& p = piece(shape_index); + const Shape& subshape = p.subshape(); + CHECK(LayoutUtil::IsSparseArray(subshape)); + int64 rank = ShapeUtil::Rank(subshape); + CHECK_EQ(multi_index.size(), rank); + int64 last_element = p.sparse_indices()->index_count(); + CHECK_LT(last_element, LayoutUtil::MaxSparseElements(subshape.layout())); + p.sparse_indices()->Append(multi_index); + CHECK_LT(last_element, p.data().size()); + p.data()[last_element] = value; } // Returns an identity matrix (rank 2) with the given row and column count. 
@@ -1071,51 +1159,31 @@ void Literal::EachCell( } while (IndexUtil::BumpIndices(shape(), &indices)); } -template -inline void Literal::PopulateR0(NativeT value) { - *mutable_shape() = ShapeUtil::MakeShape( - primitive_util::NativeToPrimitiveType(), {}); - Resize(1, value); -} - template inline void Literal::PopulateR1(tensorflow::gtl::ArraySlice values) { - *mutable_shape() = - ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), - {static_cast(values.size())}); - Reserve(values.size()); + CHECK(ShapeUtil::IsArray(shape())); + CHECK_EQ(ShapeUtil::Rank(shape()), 1); + CHECK_EQ(ShapeUtil::ElementsIn(shape()), values.size()); + CHECK_EQ(shape().element_type(), + primitive_util::NativeToPrimitiveType()); for (int64 i = 0; i < values.size(); ++i) { Set({i}, values[i]); } } -inline void Literal::PopulateR1(const tensorflow::core::Bitmap& values) { - *mutable_shape() = - ShapeUtil::MakeShape(PRED, {static_cast(values.bits())}); - Reserve(values.bits()); - for (int64 i = 0; i < static_cast(values.bits()); ++i) { - Set({i}, values.get(i)); - } -} - template -void Literal::PopulateR2WithLayout( - std::initializer_list> values, - const Layout& layout) { - *mutable_shape() = ShapeUtil::MakeShapeWithLayout( - primitive_util::NativeToPrimitiveType(), - {static_cast(values.size()), - static_cast(values.begin()->size())}, - AsInt64Slice(layout.minor_to_major())); +void Literal::PopulateR2( + std::initializer_list> values) { + CHECK(ShapeUtil::IsArray(shape())); + CHECK_EQ(ShapeUtil::Rank(shape()), 2); + CHECK_EQ(shape().element_type(), + primitive_util::NativeToPrimitiveType()); const int64 dim0_size = values.size(); const int64 dim1_size = values.begin()->size(); CHECK_EQ(dim0_size, shape().dimensions(0)); CHECK_EQ(dim1_size, shape().dimensions(1)); - const int64 num_elements = dim1_size * dim0_size; - Reserve(num_elements); - int64 dim0 = 0; for (auto inner_list : values) { int64 dim1 = 0; @@ -1129,69 +1197,65 @@ void Literal::PopulateR2WithLayout( } template -void 
Literal::PopulateR2( - std::initializer_list> values) { - PopulateR2WithLayout(values, LayoutUtil::GetDefaultLayoutForR2()); -} - -template -void Literal::PopulateFromArrayWithLayout(const Array& values, - const Layout& layout) { - *mutable_shape() = ShapeUtil::MakeShapeWithLayout( - primitive_util::NativeToPrimitiveType(), values.dimensions(), - AsInt64Slice(layout.minor_to_major())); - Reserve(values.num_elements()); +void Literal::PopulateFromArray(const Array& values) { + CHECK(ShapeUtil::IsArray(shape())); + CHECK_EQ(shape().element_type(), + primitive_util::NativeToPrimitiveType()); + CHECK_EQ(ShapeUtil::Rank(shape()), values.num_dimensions()); + for (int dim = 0; dim < values.num_dimensions(); ++dim) { + CHECK_EQ(values.dim(dim), shape().dimensions(dim)); + } values.Each([this](tensorflow::gtl::ArraySlice indices, NativeT value) { this->Set(indices, value); }); } -template -void Literal::PopulateFromArray(const Array& values) { - PopulateFromArrayWithLayout( - values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions())); -} - -template -void Literal::PopulateR2FromArray2DWithLayout(const Array2D& values, - const Layout& layout) { - PopulateFromArrayWithLayout(values, layout); -} - template void Literal::PopulateR2FromArray2D(const Array2D& values) { PopulateFromArray(values); } -template -void Literal::PopulateR3FromArray3DWithLayout(const Array3D& values, - const Layout& layout) { - PopulateFromArrayWithLayout(values, layout); -} - template void Literal::PopulateR3FromArray3D(const Array3D& values) { PopulateFromArray(values); } template -void Literal::PopulateR4FromArray4DWithLayout(const Array4D& values, - const Layout& layout) { - PopulateFromArrayWithLayout(values, layout); +void Literal::PopulateR4FromArray4D(const Array4D& values) { + PopulateFromArray(values); } template -void Literal::PopulateR4FromArray4D(const Array4D& values) { - PopulateFromArray(values); +void Literal::PopulateSparse(SparseIndexArray indices, + 
tensorflow::gtl::ArraySlice values, + bool sort) { + CHECK(LayoutUtil::IsSparseArray(shape())); + int rank = ShapeUtil::Rank(shape()); + CHECK_EQ(indices.rank(), rank); + int64 max_elements = LayoutUtil::MaxSparseElements(shape().layout()); + CHECK_LE(indices.max_indices(), max_elements); + int64 num_elements = values.size(); + CHECK_LE(num_elements, max_elements); + CHECK_EQ(num_elements, indices.index_count()); + auto root_data = root_piece().data(); + root_data.remove_suffix(max_elements - values.size()); + std::copy(values.begin(), values.end(), root_data.begin()); + *this->root_piece().sparse_indices() = std::move(indices); + if (sort) { + auto root_data = this->root_piece().data(); + root_data.remove_suffix(root_data.size() - num_elements); + this->root_piece().sparse_indices()->SortWithValues(root_data); + } + DCHECK(this->root_piece().sparse_indices()->Validate(shape())); } template Status Literal::Populate(const FnType& generator) { const Shape& this_shape = shape(); const int64 rank = ShapeUtil::Rank(this_shape); + TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape)); TF_RET_CHECK(this_shape.element_type() == primitive_util::NativeToPrimitiveType()); - tensorflow::gtl::MutableArraySlice data = - GetMutableArraySlice(); + tensorflow::gtl::MutableArraySlice literal_data = data(); if (rank > 0) { StrideConfig stride_config(this_shape, this_shape, AsInt64Slice(this_shape.dimensions())); @@ -1200,11 +1264,12 @@ Status Literal::Populate(const FnType& generator) { ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension); auto init_function = [&](const std::vector& indexes) { - const int64 index = LinearIndex(indexes); + const int64 index = + IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes); std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin()); for (int64 i = 0; i < minor_dimension_size; ++i) { minor_scan_indexes[stride_config.minor_dimension] = i; - data.at(index + i) = generator(minor_scan_indexes); + 
literal_data.at(index + i) = generator(minor_scan_indexes); } return true; }; @@ -1213,32 +1278,27 @@ Status Literal::Populate(const FnType& generator) { init_function); } else { // For scalars. - data.at(0) = generator({}); + literal_data.at(0) = generator({}); } return Status::OK(); } template -void Literal::PopulateWithValue(NativeT value, - tensorflow::gtl::ArraySlice dimensions) { - *mutable_shape() = ShapeUtil::MakeShape( - primitive_util::NativeToPrimitiveType(), dimensions); - Resize(ShapeUtil::ElementsIn(shape()), value); +void Literal::PopulateWithValue(NativeT value) { + CHECK(ShapeUtil::IsArray(shape())); + CHECK_EQ(shape().element_type(), + primitive_util::NativeToPrimitiveType()); + for (NativeT& element : data()) { + element = value; + } } template -/* static */ std::unique_ptr -Literal::CreateFullWithMonotonicDim0MajorLayout( +/* static */ std::unique_ptr Literal::CreateFullWithDescendingLayout( tensorflow::gtl::ArraySlice dimensions, NativeT value) { - Shape this_shape = ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( - primitive_util::NativeToPrimitiveType(), dimensions); - auto literal = MakeUnique(); - *literal->mutable_shape() = this_shape; - literal->Reserve(ShapeUtil::ElementsIn(this_shape)); - std::vector index(dimensions.size(), 0); - do { - literal->Set(index, value); - } while (IndexUtil::BumpIndices(this_shape, &index)); + auto literal = MakeUnique(ShapeUtil::MakeShapeWithDescendingLayout( + primitive_util::NativeToPrimitiveType(), dimensions)); + literal->PopulateWithValue(value); return literal; } @@ -1249,14 +1309,12 @@ std::unique_ptr Literal::Replicate(int64 times) const { for (int64 bound : shape().dimensions()) { bounds.push_back(bound); } - auto literal = MakeUnique(); - *literal->mutable_shape() = - ShapeUtil::MakeShape(shape().element_type(), bounds); + auto literal = + MakeUnique(ShapeUtil::MakeShape(shape().element_type(), bounds)); int64 elements = ShapeUtil::ElementsIn(literal->shape()); if (elements == 0) { return 
literal; } - literal->Reserve(elements); DimensionVector output_indices(bounds.size(), 0); tensorflow::gtl::ArraySlice input_indices = output_indices; diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index 816bb3c549eaae4e8fc2b7d438627266603272f9..b3583c2eb75de8297d5e7507430491f119bd4462 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -31,6 +31,7 @@ namespace xla { namespace { using ::testing::ElementsAre; +using ::testing::HasSubstr; class LiteralUtilTest : public ::testing::Test { protected: @@ -192,6 +193,34 @@ TEST_F(LiteralUtilTest, CreateR3FromArray3d) { ASSERT_EQ(expected, result); } +TEST_F(LiteralUtilTest, CreateSparse) { + std::vector dimensions = {8, 8, 8}; + Array2D indices = { + {3, 4, 5}, + {1, 2, 3}, + {2, 3, 4}, + {3, 5, 6}, + }; + std::vector values = {7, 8, 9, 10}; + auto literal = Literal::CreateSparse( + dimensions, SparseIndexArray(indices.n1() + 3, indices), values); + + Array2D expected_indices = { + {1, 2, 3}, + {2, 3, 4}, + {3, 4, 5}, + {3, 5, 6}, + }; + std::vector expected_values = {8, 9, 7, 10}; + + EXPECT_EQ(literal->sparse_indices()->data(), + tensorflow::gtl::ArraySlice( + expected_indices.data(), expected_indices.num_elements())); + EXPECT_EQ(tensorflow::gtl::ArraySlice(literal->data().data(), + expected_values.size()), + tensorflow::gtl::ArraySlice(expected_values)); +} + TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { // clang-format off auto literal = Literal::CreateR4Projected({ @@ -293,29 +322,28 @@ TEST_F(LiteralUtilTest, NonScalarEquality) { auto matrix_different = Literal::CreateR2({{4.0, 3.0}, {1.0, 2.0}}); auto vector_literal = Literal::CreateR1({1.0, 2.0, 3.0, 4.0}); auto scalar = Literal::CreateR0(1.0); + Literal nil(ShapeUtil::MakeNil()); EXPECT_EQ(*matrix, *matrix); EXPECT_EQ(*matrix, *matrix_clone); EXPECT_NE(*matrix, *matrix_different); EXPECT_NE(*matrix, *vector_literal); 
EXPECT_NE(*matrix, *scalar); + EXPECT_NE(*matrix, nil); + EXPECT_EQ(nil, nil); } TEST_F(LiteralUtilTest, DifferentLayoutEquality) { // Test equality with literals which have different layouts. - auto colmajor = MakeUnique(); - *colmajor->mutable_shape() = ShapeUtil::MakeShape(F32, {2, 2}); - *colmajor->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); - colmajor->Reserve(4); + auto colmajor = + MakeUnique(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})); colmajor->Set({0, 0}, 1.0); colmajor->Set({0, 1}, 2.0); colmajor->Set({1, 0}, 3.0); colmajor->Set({1, 1}, 4.0); - auto rowmajor = MakeUnique(); - *rowmajor->mutable_shape() = ShapeUtil::MakeShape(F32, {2, 2}); - *rowmajor->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0}); - rowmajor->Reserve(4); + auto rowmajor = + MakeUnique(ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})); rowmajor->Set({0, 0}, 1.0); rowmajor->Set({0, 1}, 2.0); rowmajor->Set({1, 0}, 3.0); @@ -515,7 +543,7 @@ TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) { TEST_F(LiteralUtilTest, ReshapeR0) { auto original = Literal::CreateR0(1.7f); - auto reshape = original->Reshape(/*shape=*/{}).ConsumeValueOrDie(); + auto reshape = original->Reshape(/*dimensions=*/{}).ConsumeValueOrDie(); EXPECT_EQ(*original, *reshape); } @@ -597,24 +625,26 @@ TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) { TEST_F(LiteralUtilTest, TestR2LinearLayout) { // Test expected memory layout of R2 dim0-minor (column-major) literal. - auto mat_dim0minor = Literal::CreateR2WithLayout({{1, 2, 3}, {4, 5, 6}}, - layout_r2_dim0minor_); - EXPECT_EQ(mat_dim0minor->s32s_size(), 6); - EXPECT_THAT(mat_dim0minor->s32s(), ElementsAre(1, 4, 2, 5, 3, 6)); + auto mat_dim0minor = Literal::CreateR2WithLayout( + {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0minor_); + EXPECT_EQ(mat_dim0minor->element_count(), 6); + EXPECT_THAT(mat_dim0minor->data(), ElementsAre(1, 4, 2, 5, 3, 6)); // Test expected memory layout when using Relayout to row major. 
auto relaid_mat_to_dim0major = mat_dim0minor->Relayout(layout_r2_dim0major_); - EXPECT_THAT(relaid_mat_to_dim0major->s32s(), ElementsAre(1, 2, 3, 4, 5, 6)); + EXPECT_THAT(relaid_mat_to_dim0major->data(), + ElementsAre(1, 2, 3, 4, 5, 6)); // Test expected memory layout of R2 created with dim0-major (row-major). - auto mat_dim0major = Literal::CreateR2WithLayout({{1, 2, 3}, {4, 5, 6}}, - layout_r2_dim0major_); - EXPECT_EQ(mat_dim0major->s32s_size(), 6); - EXPECT_THAT(mat_dim0major->s32s(), ElementsAre(1, 2, 3, 4, 5, 6)); + auto mat_dim0major = Literal::CreateR2WithLayout( + {{1, 2, 3}, {4, 5, 6}}, layout_r2_dim0major_); + EXPECT_EQ(mat_dim0major->element_count(), 6); + EXPECT_THAT(mat_dim0major->data(), ElementsAre(1, 2, 3, 4, 5, 6)); // Test expected memory layout when using Relayout to column major. auto relaid_mat_to_dim0minor = mat_dim0major->Relayout(layout_r2_dim0minor_); - EXPECT_THAT(relaid_mat_to_dim0minor->s32s(), ElementsAre(1, 4, 2, 5, 3, 6)); + EXPECT_THAT(relaid_mat_to_dim0minor->data(), + ElementsAre(1, 4, 2, 5, 3, 6)); } TEST_F(LiteralUtilTest, TestR3LinearLayout) { @@ -634,27 +664,27 @@ TEST_F(LiteralUtilTest, TestR3LinearLayout) { auto lit_dim0minor = Literal::CreateR3FromArray3DWithLayout(arr3d, layout_r3_dim0minor_); - EXPECT_EQ(lit_dim0minor->s32s_size(), 12); + EXPECT_EQ(lit_dim0minor->element_count(), 12); std::vector expected_dim0minor{1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12}; - EXPECT_THAT(lit_dim0minor->s32s(), + EXPECT_THAT(lit_dim0minor->data(), testing::ElementsAreArray(expected_dim0minor)); // Test expected memory layout when using Relayout to row major. auto relaid_lit_to_dim0major = lit_dim0minor->Relayout(layout_r3_dim0major_); std::vector expected_dim0major{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}; - EXPECT_THAT(relaid_lit_to_dim0major->s32s(), + EXPECT_THAT(relaid_lit_to_dim0major->data(), testing::ElementsAreArray(expected_dim0major)); // Test expected memory layout of R3 created with dim0-major (row-major). 
auto lit_dim0major = Literal::CreateR3FromArray3DWithLayout(arr3d, layout_r3_dim0major_); - EXPECT_EQ(lit_dim0major->s32s_size(), 12); - EXPECT_THAT(lit_dim0major->s32s(), + EXPECT_EQ(lit_dim0major->element_count(), 12); + EXPECT_THAT(lit_dim0major->data(), testing::ElementsAreArray(expected_dim0major)); // Test expected memory layout when using Relayout to column major. auto relaid_lit_to_dim0minor = lit_dim0major->Relayout(layout_r3_dim0minor_); - EXPECT_THAT(relaid_lit_to_dim0minor->s32s(), + EXPECT_THAT(relaid_lit_to_dim0minor->data(), testing::ElementsAreArray(expected_dim0minor)); } @@ -687,28 +717,28 @@ TEST_F(LiteralUtilTest, SliceR3U32Full) { } TEST_F(LiteralUtilTest, PopulateR1S64) { - Literal output; + Literal output(ShapeUtil::MakeShape(S64, {1})); output.PopulateR1({77}); auto expected = Literal::CreateR1({77}); EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateR1U64) { - Literal output; + Literal output(ShapeUtil::MakeShape(U64, {2})); output.PopulateR1({{77, 88}}); auto expected = Literal::CreateR1({{77, 88}}); EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateR1C64) { - Literal output; + Literal output(ShapeUtil::MakeShape(C64, {1})); output.PopulateR1({{77, 88}}); auto expected = Literal::CreateR1({{77, 88}}); EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateR2C64) { - Literal output; + Literal output(ShapeUtil::MakeShape(C64, {2, 2})); output.PopulateR2({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}}); auto expected = Literal::CreateR2({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}}); @@ -716,78 +746,78 @@ TEST_F(LiteralUtilTest, PopulateR2C64) { } TEST_F(LiteralUtilTest, PopulateWithValueR0BF16) { - Literal output; + Literal output(ShapeUtil::MakeShape(BF16, {})); bfloat16 h(0.25f); - output.PopulateWithValue(h, {}); + output.PopulateWithValue(h); auto expected = Literal::CreateR0(h); EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateWithValueR1BF16) { - Literal output; + Literal 
output(ShapeUtil::MakeShape(BF16, {3})); bfloat16 h(0.5f); - output.PopulateWithValue(h, {3}); + output.PopulateWithValue(h); auto expected = Literal::CreateR1({h, h, h}); EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateWithValueR2BF16) { - Literal output; + Literal output(ShapeUtil::MakeShape(BF16, {2, 2})); bfloat16 h(2.0f); - output.PopulateWithValue(h, {2, 2}); + output.PopulateWithValue(h); auto expected = Literal::CreateR2({{h, h}, {h, h}}); EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateWithValueR0F32) { - Literal output; - output.PopulateWithValue(2.5f, {}); + Literal output(ShapeUtil::MakeShape(F32, {})); + output.PopulateWithValue(2.5f); auto expected = Literal::CreateR0(2.5f); EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateWithValueR1S64) { - Literal output; - output.PopulateWithValue(-7, {3}); + Literal output(ShapeUtil::MakeShape(S64, {3})); + output.PopulateWithValue(-7); auto expected = Literal::CreateR1({-7, -7, -7}); EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateWithValueR2U64) { - Literal output; - output.PopulateWithValue(42, {2, 2}); + Literal output(ShapeUtil::MakeShape(U64, {2, 2})); + output.PopulateWithValue(42); auto expected = Literal::CreateR2({{42, 42}, {42, 42}}); EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateWithValueR2C64) { - Literal output; - output.PopulateWithValue({4, 2}, {2, 2}); + Literal output(ShapeUtil::MakeShape(C64, {2, 2})); + output.PopulateWithValue({4, 2}); auto expected = Literal::CreateR2({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}}); EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateWithValueR0F16) { - Literal output; + Literal output(ShapeUtil::MakeShape(F16, {})); half h(0.25f); - output.PopulateWithValue(h, {}); + output.PopulateWithValue(h); auto expected = Literal::CreateR0(h); EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateWithValueR1F16) { - Literal output; + Literal output(ShapeUtil::MakeShape(F16, 
{3})); half h(0.5f); - output.PopulateWithValue(h, {3}); + output.PopulateWithValue(h); auto expected = Literal::CreateR1({h, h, h}); EXPECT_EQ(output, *expected); } TEST_F(LiteralUtilTest, PopulateWithValueR2F16) { - Literal output; + Literal output(ShapeUtil::MakeShape(F16, {2, 2})); half h(2.0f); - output.PopulateWithValue(h, {2, 2}); + output.PopulateWithValue(h); auto expected = Literal::CreateR2({{h, h}, {h, h}}); EXPECT_EQ(output, *expected); } @@ -803,7 +833,7 @@ TEST_F(LiteralUtilTest, ReplicateR2U32) { EXPECT_EQ(*output, *expected); } -TEST_F(LiteralUtilTest, Copy) { +TEST_F(LiteralUtilTest, CopySliceFrom) { const int64 dimensions[] = {17, 15, 34, 21}; const int64 layouts[][4] = { {3, 2, 1, 0}, {0, 2, 1, 3}, {0, 1, 2, 3}, {2, 0, 3, 1}, {1, 3, 0, 2}}; @@ -826,7 +856,7 @@ TEST_F(LiteralUtilTest, Copy) { const int64 src_base[] = {3, 1, 5, 7}; const int64 dest_base[] = {6, 4, 12, 2}; const int64 copy_size[] = {7, 8, 11, 9}; - TF_EXPECT_OK(blank->Copy(*source, src_base, dest_base, copy_size)); + TF_EXPECT_OK(blank->CopySliceFrom(*source, src_base, dest_base, copy_size)); std::vector source_indexes(TF_ARRAYSIZE(dimensions), 0); std::vector blank_indexes(TF_ARRAYSIZE(dimensions), 0); @@ -849,16 +879,16 @@ TEST_F(LiteralUtilTest, Copy) { } } -TEST_F(LiteralUtilTest, CopyScalars) { +TEST_F(LiteralUtilTest, CopyFromScalars) { auto zero = Literal::CreateR0(0); auto nine = Literal::CreateR0(9); - TF_EXPECT_OK(zero->Copy(*nine, {}, {}, {})); + TF_EXPECT_OK(zero->CopyFrom(*nine)); EXPECT_EQ(*zero, *nine); auto vect = Literal::CreateR1({3, 4, 9, 12, 5, 17, 21}); - TF_EXPECT_OK(zero->Copy(*vect, {5}, {}, {})); + TF_EXPECT_OK(zero->CopySliceFrom(*vect, {5}, {}, {})); EXPECT_EQ(zero->Get({}), 17); - TF_EXPECT_OK(vect->Copy(*zero, {}, {4}, {})); + TF_EXPECT_OK(vect->CopySliceFrom(*zero, {}, {4}, {})); EXPECT_EQ(vect->Get({4}), 17); } @@ -872,7 +902,7 @@ TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) { const auto empty = Literal::CreateFromShape(empty_r1_shape); auto nine 
= Literal::CreateR1({9}); - TF_EXPECT_OK(nine->Copy(*empty, {0}, {0}, {0})); + TF_EXPECT_OK(nine->CopySliceFrom(*empty, {0}, {0}, {0})); EXPECT_EQ(*nine, *const_nine); } @@ -881,18 +911,101 @@ TEST_F(LiteralUtilTest, CopyFromAndToZeroElement) { const auto empty = Literal::CreateFromShape(empty_r1_shape); auto nine = Literal::CreateR1({9}); - TF_EXPECT_OK(empty->Copy(*nine, {0}, {0}, {0})); + TF_EXPECT_OK(empty->CopySliceFrom(*nine, {0}, {0}, {0})); EXPECT_EQ(*empty, *const_empty); } } +TEST_F(LiteralUtilTest, CopyFromNilShape) { + Literal nil_literal0(ShapeUtil::MakeNil()); + Literal nil_literal1(ShapeUtil::MakeNil()); + // This doesn't actually do any copying, but it should succeed. + TF_ASSERT_OK(nil_literal0.CopyFrom(nil_literal1)); +} + +TEST_F(LiteralUtilTest, CopyFromArrays) { + auto scalar_42 = Literal::CreateR0(42.0); + auto scalar_123 = Literal::CreateR0(123.0); + EXPECT_NE(*scalar_42, *scalar_123); + TF_ASSERT_OK(scalar_42->CopyFrom(*scalar_123, /*dest_shape_index=*/{}, + /*src_shape_index=*/{})); + EXPECT_EQ(*scalar_42, *scalar_123); + EXPECT_EQ(scalar_42->Get({}), 123.0f); + + auto matrix_1234 = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto matrix_5678 = Literal::CreateR2({{5.0, 6.0}, {7.0, 8.0}}); + EXPECT_NE(*matrix_1234, *matrix_5678); + EXPECT_EQ(matrix_1234->Get({0, 0}), 1.0f); + TF_ASSERT_OK(matrix_1234->CopyFrom(*matrix_5678, /*dest_shape_index=*/{}, + /*src_shape_index=*/{})); + EXPECT_EQ(*matrix_1234, *matrix_5678); + EXPECT_EQ(matrix_1234->Get({0, 0}), 5.0f); +} + +TEST_F(LiteralUtilTest, CopyFromTuples) { + auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + Literal nil_literal(ShapeUtil::MakeNil()); + auto nested_tuple = Literal::MakeTuple( + {matrix.get(), + Literal::MakeTuple({Literal::CreateR0(42).get(), + Literal::CreateR1({23.0, 44.0}).get(), + &nil_literal}) + .get()}); + // Create a tuple the same shape as the inner tuple of nested_tuple but with + // different values.. 
+ auto tuple = Literal::MakeTuple({Literal::CreateR0(-5).get(), + Literal::CreateR1({2.0, 4.0}).get(), + &nil_literal}); + + EXPECT_EQ(*matrix, LiteralView::Create(*nested_tuple, {0})); + EXPECT_EQ(nested_tuple->Get({}, {1, 0}), 42); + EXPECT_EQ(nested_tuple->Get({0}, {1, 1}), 23.0); + EXPECT_EQ(nested_tuple->Get({1}, {1, 1}), 44.0); + + // Overwrite the inner tuple element of nested_tuple with the contents of + // 'tuple'. + TF_ASSERT_OK(nested_tuple->CopyFrom(*tuple, /*dest_shape_index=*/{1}, + /*src_shape_index=*/{})); + + // The matrix element should be unchanged. + EXPECT_EQ(*matrix, LiteralView::Create(*nested_tuple, {0})); + + // The tuple element should have been copied from 'tuple'. + EXPECT_EQ(nested_tuple->Get({}, {1, 0}), -5); + EXPECT_EQ(nested_tuple->Get({0}, {1, 1}), 2.0); + EXPECT_EQ(nested_tuple->Get({1}, {1, 1}), 4.0); +} +TEST_F(LiteralUtilTest, CopyBetweenSameTuple) { + auto tuple = Literal::MakeTuple( + {Literal::CreateR0(-2).get(), Literal::CreateR0(4).get()}); + + EXPECT_EQ(tuple->Get({}, {0}), -2); + EXPECT_EQ(tuple->Get({}, {1}), 4); + + // Copy from one element to the other. 
+ TF_ASSERT_OK(tuple->CopyFrom(*tuple, /*dest_shape_index=*/{1}, + /*src_shape_index=*/{0})); + + EXPECT_EQ(tuple->Get({}, {0}), -2); + EXPECT_EQ(tuple->Get({}, {1}), -2); +} + +TEST_F(LiteralUtilTest, CopyFromDifferentShapes) { + auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto vector = Literal::CreateR1({5.0, 7.0}); + Status status = matrix->CopyFrom(*vector); + ASSERT_FALSE(status.ok()); + ASSERT_THAT(status.error_message(), + HasSubstr("Destination subshape incompatible")); +} + TEST_F(LiteralUtilTest, F16) { // Verify that the internal data views are consistent and that they // are in little endian format // TODO - modify if we make the data format machine endianess dependent auto m1 = Literal::CreateFromShape(ShapeUtil::MakeShape(F16, {2, 2})); Literal* l1 = m1.get(); - const char* d1 = static_cast(l1->InternalData()); + const char* d1 = reinterpret_cast(l1->data().data()); EXPECT_EQ(d1[0], 0); EXPECT_EQ(d1[1], 0); EXPECT_EQ(d1[2], 0); @@ -901,13 +1014,12 @@ TEST_F(LiteralUtilTest, F16) { EXPECT_EQ(d1[5], 0); EXPECT_EQ(d1[6], 0); EXPECT_EQ(d1[7], 0); - EXPECT_EQ(l1->InternalData(), l1->MutableInternalData()); half h1(1.0f); half h2(2.0f); auto m2 = Literal::CreateR2({{h1, h2}, {h2, h1}}); Literal* l2 = m2.get(); - const char* d2 = static_cast(l2->InternalData()); + const char* d2 = reinterpret_cast(l2->data().data()); EXPECT_EQ(d2[0], 0); EXPECT_EQ(d2[1], 0x3C); EXPECT_EQ(d2[2], 0); @@ -916,7 +1028,6 @@ TEST_F(LiteralUtilTest, F16) { EXPECT_EQ(d2[5], 0x40); EXPECT_EQ(d2[6], 0); EXPECT_EQ(d2[7], 0x3C); - EXPECT_EQ(l2->InternalData(), l2->MutableInternalData()); } TEST_F(LiteralUtilTest, Populate) { @@ -941,7 +1052,9 @@ TEST_F(LiteralUtilTest, Populate) { auto generator = [&](tensorflow::gtl::ArraySlice indexes) -> uint32 { // Offsets from linear index just to avoid R0 literals to be initialized // with zero. 
- return literal->LinearIndex(indexes) + 17; + return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(), + indexes) + + 17; }; TF_EXPECT_OK(literal->Populate(generator)); @@ -1118,16 +1231,18 @@ TEST_F(LiteralUtilTest, CopyFromProto_Bool) { for (int len = 0; len < 25; ++len) { p.mutable_shape()->clear_dimensions(); p.mutable_shape()->add_dimensions(len); + LayoutUtil::SetToDefaultLayout(p.mutable_shape()); p.clear_preds(); for (int i = 0; i < len; ++i) { p.add_preds((i % 2) == (len % 2)); } - Literal literal(p); - ASSERT_EQ(len, literal.preds_size()); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr literal, + Literal::CreateFromProto(p)); + ASSERT_EQ(len, literal->data().size()); int i = 0; - for (auto it = literal.preds().begin(); it < literal.preds().end(); ++it) { - EXPECT_EQ((i % 2) == (len % 2), *it); + for (bool value : literal->data()) { + EXPECT_EQ((i % 2) == (len % 2), value); ++i; } } @@ -1141,8 +1256,7 @@ TEST_F(LiteralUtilTest, ToProto_f16) { auto m = Literal::CreateR2({{h1, h2}, {h2, h1}}); Literal* l = m.get(); EXPECT_EQ(4, ShapeUtil::ElementsIn(l->shape())); - EXPECT_EQ(4, l->f16s().size()); - EXPECT_EQ(4, l->f16s_size()); + EXPECT_EQ(4, l->data().size()); LiteralProto p = l->ToProto(); EXPECT_EQ(4, ShapeUtil::ElementsIn(p.shape())); @@ -1168,17 +1282,12 @@ TEST_F(LiteralUtilTest, CopyFromProto_f16) { p.mutable_shape()->set_element_type(F16); p.mutable_shape()->clear_dimensions(); p.mutable_shape()->add_dimensions(4); + LayoutUtil::SetToDefaultLayout(p.mutable_shape()); p.clear_f16s(); p.set_f16s(half_vals, 8); - - Literal literal(p); - ASSERT_EQ(4, literal.f16s_size()); - ASSERT_EQ(h1, literal.f16s(0)); - ASSERT_EQ(h2, literal.f16s(1)); - ASSERT_EQ(h2, literal.f16s(2)); - ASSERT_EQ(h1, literal.f16s(3)); - - const std::vector& r = literal.f16s(); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr literal, + Literal::CreateFromProto(p)); + auto r = literal->data(); ASSERT_EQ(4, r.size()); ASSERT_EQ(h1, r[0]); ASSERT_EQ(h2, r[1]); @@ -1186,24 +1295,402 
@@ TEST_F(LiteralUtilTest, CopyFromProto_f16) { ASSERT_EQ(h1, r[3]); } -TEST_F(LiteralUtilTest, Subliterals) { +TEST_F(LiteralUtilTest, LiteralViewTest) { + auto scalar = Literal::CreateR0(1.0); + auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); + auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()}); + Literal nil(ShapeUtil::MakeNil()); + + EXPECT_EQ(LiteralView::Create(*scalar, {}), *scalar); + EXPECT_EQ(LiteralView::Create(*matrix, {}), *matrix); + EXPECT_EQ(LiteralView::Create(*tuple, {}), *tuple); + EXPECT_EQ(LiteralView::Create(*nested_tuple, {}), *nested_tuple); + EXPECT_EQ(LiteralView::Create(nil, {}), nil); + + EXPECT_EQ(LiteralView::Create(*tuple, {0}), *scalar); + EXPECT_EQ(LiteralView::Create(*tuple, {1}), *matrix); + + EXPECT_EQ(LiteralView::Create(*nested_tuple, {0}), *tuple); + EXPECT_EQ(LiteralView::Create(*nested_tuple, {0, 0}), *scalar); + EXPECT_EQ(LiteralView::Create(*nested_tuple, {0, 1}), *matrix); + EXPECT_EQ(LiteralView::Create(*nested_tuple, {1}), *scalar); +} + +TEST_F(LiteralUtilTest, MutatingLiteralView) { + auto scalar = Literal::CreateR0(1.0); + auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); + auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()}); + // Verify that changing the underlying data beneath the view changes the + // data of the view itself. 
+ const auto nested_tuple_view = LiteralView::Create(*nested_tuple); + EXPECT_EQ( + nested_tuple->Get(/*multi_index=*/{}, /*shape_index=*/{0, 0}), + 1.0f); + EXPECT_EQ(nested_tuple_view.Get(/*multi_index=*/{}, + /*shape_index=*/{0, 0}), + 1.0f); + nested_tuple->Set(/*multi_index=*/{}, /*shape_index=*/{0, 0}, 555.0f); + EXPECT_EQ( + nested_tuple->Get(/*multi_index=*/{}, /*shape_index=*/{0, 0}), + 555.0f); + EXPECT_EQ(nested_tuple_view.Get(/*multi_index=*/{}, + /*shape_index=*/{0, 0}), + 555.0f); +} + +TEST_F(LiteralUtilTest, LiteralViewOfALiteralView) { auto scalar = Literal::CreateR0(1.0); auto matrix = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); auto tuple = Literal::MakeTuple({scalar.get(), matrix.get()}); auto nested_tuple = Literal::MakeTuple({tuple.get(), scalar.get()}); - EXPECT_EQ(&scalar->GetSubliteral(/*index=*/{}), scalar.get()); - EXPECT_EQ(&matrix->GetSubliteral(/*index=*/{}), matrix.get()); - EXPECT_EQ(&tuple->GetSubliteral(/*index=*/{}), tuple.get()); - EXPECT_EQ(&nested_tuple->GetSubliteral(/*index=*/{}), nested_tuple.get()); + const auto nested_tuple_view = LiteralView::Create(*nested_tuple); + const auto tuple_view = + LiteralView::Create(nested_tuple_view, /*view_root=*/{0}); + const auto matrix_view = LiteralView::Create(tuple_view, /*view_root=*/{1}); + EXPECT_EQ(matrix_view, *Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}})); +} + +TEST_F(LiteralUtilTest, LiteralMove) { + std::unique_ptr matrix = + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + Literal literal(std::move(*matrix)); + + EXPECT_TRUE( + ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), literal.shape())); + EXPECT_EQ(literal.Get({0, 0}), 1.0); + EXPECT_EQ(literal.Get({0, 1}), 2.0); + EXPECT_EQ(literal.Get({1, 0}), 3.0); + EXPECT_EQ(literal.Get({1, 1}), 4.0); +} - EXPECT_EQ(tuple->GetSubliteral(/*index=*/{0}), *scalar); - EXPECT_EQ(tuple->GetSubliteral(/*index=*/{1}), *matrix); +TEST_F(LiteralUtilTest, DecomposeTuple) { + Literal nil_literal(ShapeUtil::MakeNil()); + auto nested_tuple = 
Literal::MakeTuple( + {Literal::CreateR2({{1, 2}, {3, 4}}).get(), + Literal::MakeTuple({Literal::CreateR0(42).get(), + Literal::CreateR1({23.0, 44.0}).get(), + &nil_literal}) + .get(), + &nil_literal}); + + EXPECT_FALSE(ShapeUtil::IsNil(nested_tuple->shape())); + std::vector elements = nested_tuple->DecomposeTuple(); + EXPECT_TRUE(ShapeUtil::IsNil(nested_tuple->shape())); + + ASSERT_EQ(elements.size(), 3); + + EXPECT_TRUE(ShapeUtil::Compatible(elements[0].shape(), + ShapeUtil::MakeShape(S32, {2, 2}))); + EXPECT_EQ(elements[0].Get({0, 0}), 1); + EXPECT_EQ(elements[0].Get({0, 1}), 2); + EXPECT_EQ(elements[0].Get({1, 0}), 3); + EXPECT_EQ(elements[0].Get({1, 1}), 4); + + EXPECT_TRUE(ShapeUtil::Compatible( + elements[1].shape(), + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {}), + ShapeUtil::MakeShape(F64, {2}), + ShapeUtil::MakeNil()}))); + EXPECT_EQ(elements[1].Get({}, /*shape_index=*/{0}), 42); + EXPECT_EQ(elements[1].Get({0}, /*shape_index=*/{1}), 23.0); + EXPECT_EQ(elements[1].Get({1}, /*shape_index=*/{1}), 44.0); + + EXPECT_TRUE(ShapeUtil::Compatible(elements[2].shape(), ShapeUtil::MakeNil())); +} + +TEST_F(LiteralUtilTest, DecomposeEmptyTuple) { + Literal nil_literal(ShapeUtil::MakeNil()); + std::vector elements = nil_literal.DecomposeTuple(); + EXPECT_EQ(elements.size(), 0); +} + +TEST_F(LiteralUtilTest, MoveIntoTuple) { + std::vector elements; + elements.push_back(std::move(*Literal::CreateR0(1.0))); + elements.push_back(std::move(*Literal::CreateR1({4, 8}))); + elements.push_back(std::move( + *Literal::MakeTuple({Literal::CreateR0(42).get(), + Literal::CreateR1({23.0, 44.0}).get()}) + + )); + + Literal literal = Literal::MoveIntoTuple(&elements); + ASSERT_TRUE(ShapeUtil::IsTuple(literal.shape())); + ASSERT_EQ(ShapeUtil::TupleElementCount(literal.shape()), 3); + + EXPECT_EQ(literal.Get({}, /*shape_index=*/{0}), 1.0); + EXPECT_EQ(literal.Get({0}, /*shape_index=*/{1}), 4); + EXPECT_EQ(literal.Get({1}, /*shape_index=*/{1}), 8); + EXPECT_EQ(literal.Get({}, 
/*shape_index=*/{2, 0}), 42); + EXPECT_EQ(literal.Get({0}, /*shape_index=*/{2, 1}), 23.0); + EXPECT_EQ(literal.Get({1}, /*shape_index=*/{2, 1}), 44.0); + + for (const Literal& element : elements) { + EXPECT_TRUE(ShapeUtil::IsNil(element.shape())); + } +} + +TEST_F(LiteralUtilTest, MoveIntoEmptyTuple) { + Literal literal = Literal::MoveIntoTuple({}); + ASSERT_TRUE(ShapeUtil::IsTuple(literal.shape())); + ASSERT_EQ(ShapeUtil::TupleElementCount(literal.shape()), 0); +} + +TEST_F(LiteralUtilTest, LiteralMoveAssignment) { + Literal literal; + EXPECT_TRUE(ShapeUtil::Equal(ShapeUtil::MakeNil(), literal.shape())); + + std::unique_ptr matrix = + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + literal = std::move(*matrix); + + EXPECT_TRUE( + ShapeUtil::Equal(ShapeUtil::MakeShape(F32, {2, 2}), literal.shape())); + EXPECT_EQ(literal.Get({0, 0}), 1.0); + EXPECT_EQ(literal.Get({0, 1}), 2.0); + EXPECT_EQ(literal.Get({1, 0}), 3.0); + EXPECT_EQ(literal.Get({1, 1}), 4.0); +} + +TEST_F(LiteralUtilTest, LiteralViewCopy) { + std::unique_ptr matrix = + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); + const auto matrix_view = LiteralView::Create(*matrix); + LiteralView matrix_view_copy(matrix_view); + + EXPECT_EQ(matrix_view_copy.Get({0, 0}), 1.0); + EXPECT_EQ(matrix_view_copy.Get({0, 1}), 2.0); + EXPECT_EQ(matrix_view_copy.Get({1, 0}), 3.0); + EXPECT_EQ(matrix_view_copy.Get({1, 1}), 4.0); +} + +TEST_F(LiteralUtilTest, GetSetTuple) { + auto tuple = Literal::MakeTuple( + {Literal::CreateR0(42.0).get(), + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get()}); + EXPECT_EQ(tuple->Get(/*multi_index=*/{}, /*shape_index=*/{0}), 42.0); + tuple->Set(/*multi_index=*/{}, /*shape_index=*/{0}, -5.0); + EXPECT_EQ(tuple->Get(/*multi_index=*/{}, /*shape_index=*/{0}), -5.0); + + EXPECT_EQ(tuple->Get(/*multi_index=*/{1, 0}, /*shape_index=*/{1}), + 3.0); + tuple->Set(/*multi_index=*/{1, 0}, /*shape_index=*/{1}, -4.0); + EXPECT_EQ(tuple->Get(/*multi_index=*/{1, 0}, /*shape_index=*/{1}), + -4.0); +} + 
+TEST_F(LiteralUtilTest, CreateFromShapeZeroInitialized) { + // Literals constructed using CreateFromShape should be zero initialized. + std::unique_ptr scalar_f32 = + Literal::CreateFromShape(ShapeUtil::MakeShape(F32, {})); + EXPECT_EQ(scalar_f32->Get({}), 0.0); + EXPECT_TRUE(scalar_f32->IsAll(0)); + + std::unique_ptr vector_s32 = + Literal::CreateFromShape(ShapeUtil::MakeShape(S32, {3})); + EXPECT_EQ(vector_s32->Get({0}), 0); + EXPECT_EQ(vector_s32->Get({1}), 0); + EXPECT_EQ(vector_s32->Get({2}), 0); + EXPECT_TRUE(vector_s32->IsAll(0)); + + std::unique_ptr tuple = + Literal::CreateFromShape(ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F64, {}), ShapeUtil::MakeShape(PRED, {2}), + ShapeUtil::MakeShape(U64, {2, 1}), ShapeUtil::MakeShape(C64, {})})); + + EXPECT_EQ(tuple->Get({}, {0}), 0.0); + EXPECT_EQ(tuple->Get({0}, {1}), false); + EXPECT_EQ(tuple->Get({1}, {1}), false); + EXPECT_EQ(tuple->Get({0, 0}, {2}), 0); + EXPECT_EQ(tuple->Get({1, 0}, {2}), 0); + EXPECT_EQ(tuple->Get({}, {3}), complex64(0.0f, 0.0f)); +} + +TEST_F(LiteralUtilTest, ProtoRoundTrip) { + // Test serializing then deserializing a Literal through a proto. 
+ auto one_f32 = Literal::CreateR0(1.0); + auto two_f32 = Literal::CreateR0(2.0); + auto vector_int8 = Literal::CreateR1({-128, 0, 2, 4, 7, 56, 127}); + auto vector_c64 = Literal::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); + auto vector_bfloat16 = Literal::CreateR1( + {bfloat16{-1.0}, bfloat16{2.0}, bfloat16{-3.0}}); + auto vector_half = + Literal::CreateR1({half{10.0}, half{20.0}, half{-30.0}}); + auto matrix_pred = + Literal::CreateR2({{true, false, true}, {false, false, true}}); + auto tuple = Literal::MakeTuple( + {one_f32.get(), vector_half.get(), matrix_pred.get(), matrix_pred.get()}); + Literal nil_literal(ShapeUtil::MakeNil()); + auto nested_tuple = Literal::MakeTuple( + {tuple.get(), vector_bfloat16.get(), tuple.get(), &nil_literal}); + + auto to_from_proto = [](const Literal& literal) -> Literal { + return std::move(*Literal::CreateFromProto(literal.ToProto()).ValueOrDie()); + }; + + EXPECT_EQ(*one_f32, to_from_proto(*one_f32)); + EXPECT_EQ(*vector_c64, to_from_proto(*vector_c64)); + EXPECT_EQ(*vector_bfloat16, to_from_proto(*vector_bfloat16)); + EXPECT_EQ(*matrix_pred, to_from_proto(*matrix_pred)); + EXPECT_EQ(*tuple, to_from_proto(*tuple)); + EXPECT_EQ(*nested_tuple, to_from_proto(*nested_tuple)); + EXPECT_EQ(nil_literal, to_from_proto(nil_literal)); + + EXPECT_NE(*one_f32, *two_f32); + EXPECT_NE(*one_f32, to_from_proto(*two_f32)); +} + +TEST_F(LiteralUtilTest, InvalidProtoNoValues) { + // Proto contains a shape, but no values. + LiteralProto proto; + *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3}); + Status status = Literal::CreateFromProto(proto).status(); + ASSERT_FALSE(status.ok()); + ASSERT_THAT(status.error_message(), + HasSubstr("Expected 3 elements in LiteralProto")); +} + +TEST_F(LiteralUtilTest, InvalidProtoNoShape) { + // Proto contains values, but no shape. 
+ LiteralProto proto; + proto.add_preds(false); + proto.add_preds(true); + proto.add_preds(false); + Status status = Literal::CreateFromProto(proto).status(); + ASSERT_FALSE(status.ok()); + ASSERT_THAT(status.error_message(), HasSubstr("LiteralProto has no shape")); +} + +TEST_F(LiteralUtilTest, InvalidProtoWrongContainer) { + // Proto contains values in wrong container. + LiteralProto proto; + *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {3}); + proto.add_preds(false); + proto.add_preds(true); + proto.add_preds(false); + Status status = Literal::CreateFromProto(proto).status(); + ASSERT_FALSE(status.ok()); + ASSERT_THAT(status.error_message(), + HasSubstr("Expected 3 elements in LiteralProto")); +} + +TEST_F(LiteralUtilTest, InvalidProtoTooFewValues) { + // Proto contains too few values. + LiteralProto proto; + *proto.mutable_shape() = ShapeUtil::MakeShape(F32, {42, 2}); + proto.add_f32s(1.0); + proto.add_f32s(2.0); + proto.add_f32s(3.0); + Status status = Literal::CreateFromProto(proto).status(); + ASSERT_FALSE(status.ok()); + ASSERT_THAT(status.error_message(), + HasSubstr("Expected 84 elements in LiteralProto")); +} + +TEST_F(LiteralUtilTest, InvalidProtoTooManyValues) { + // Proto contains too many values. + LiteralProto proto; + *proto.mutable_shape() = ShapeUtil::MakeShape(S32, {2}); + proto.add_s32s(42); + proto.add_s32s(-10); + proto.add_s32s(100); + Status status = Literal::CreateFromProto(proto).status(); + ASSERT_FALSE(status.ok()); + ASSERT_THAT(status.error_message(), + HasSubstr("Expected 2 elements in LiteralProto")); +} + +TEST_F(LiteralUtilTest, InvalidProtoMissingLayout) { + // Proto shape missing layout. 
+ LiteralProto proto; + *proto.mutable_shape() = ShapeUtil::MakeShape(PRED, {2, 2}); + LayoutUtil::ClearLayout(proto.mutable_shape()); + proto.add_preds(true); + proto.add_preds(false); + proto.add_preds(true); + proto.add_preds(false); + Status status = Literal::CreateFromProto(proto).status(); + ASSERT_FALSE(status.ok()); + ASSERT_THAT(status.error_message(), HasSubstr("LiteralProto has no layout")); +} + +TEST_F(LiteralUtilTest, InvalidProtoTooFewTupleElements) { + // Proto has the too few tuple elements. + LiteralProto proto; + *proto.mutable_shape() = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})}); + LiteralProto* element0 = proto.add_tuple_literals(); + *element0->mutable_shape() = + ShapeUtil::GetTupleElementShape(proto.shape(), 0); + element0->add_preds(false); + element0->add_preds(true); + + Status status = Literal::CreateFromProto(proto).status(); + ASSERT_FALSE(status.ok()); + ASSERT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements")); +} + +TEST_F(LiteralUtilTest, InvalidProtoTooManyTupleElements) { + // Proto has the too many tuple elements. 
+ LiteralProto proto; + *proto.mutable_shape() = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(PRED, {2}), ShapeUtil::MakeShape(F32, {})}); + LiteralProto* element0 = proto.add_tuple_literals(); + *element0->mutable_shape() = + ShapeUtil::GetTupleElementShape(proto.shape(), 0); + element0->add_preds(false); + element0->add_preds(true); + LiteralProto* element1 = proto.add_tuple_literals(); + *element1->mutable_shape() = + ShapeUtil::GetTupleElementShape(proto.shape(), 1); + element1->add_f32s(42.0); + LiteralProto* element2 = proto.add_tuple_literals(); + *element2->mutable_shape() = ShapeUtil::MakeShape(F32, {}); + element2->add_f32s(123.0); + + Status status = Literal::CreateFromProto(proto).status(); + ASSERT_FALSE(status.ok()); + ASSERT_THAT(status.error_message(), HasSubstr("Expected 2 tuple elements")); +} + +TEST_F(LiteralUtilTest, SortSparseElements) { + auto literal = + Literal::CreateSparse({10, 10, 10}, SparseIndexArray(10, 3), {}); + literal->AppendSparseElement({2, 3, 4}, 2.0); + literal->AppendSparseElement({3, 4, 5}, 3.0); + literal->AppendSparseElement({1, 2, 3}, 1.0); + literal->SortSparseElements(); + ASSERT_EQ(literal->ToString(false), + "f32[10,10,10]{[1, 2, 3]: 1, [2, 3, 4]: 2, [3, 4, 5]: 3}"); +} - EXPECT_EQ(nested_tuple->GetSubliteral(/*index=*/{0}), *tuple); - EXPECT_EQ(nested_tuple->GetSubliteral(/*index=*/{0, 0}), *scalar); - EXPECT_EQ(nested_tuple->GetSubliteral(/*index=*/{0, 1}), *matrix); - EXPECT_EQ(nested_tuple->GetSubliteral(/*index=*/{1}), *scalar); +TEST_F(LiteralUtilTest, GetSparseElementAsString) { + std::vector dimensions = {10, 10, 10}; + SparseIndexArray indices(10, {{1, 2, 3}, {2, 3, 4}, {3, 4, 5}}); + + ASSERT_EQ( + Literal::CreateSparse(dimensions, indices, {true, false, true}) + ->GetSparseElementAsString(1), + "false"); + ASSERT_EQ(Literal::CreateSparse(dimensions, indices, {1, 2, 3}) + ->GetSparseElementAsString(1), + tensorflow::strings::StrCat(int64{2})); + ASSERT_EQ(Literal::CreateSparse(dimensions, indices, 
{1.0, 2.0, 3.0}) + ->GetSparseElementAsString(1), + tensorflow::strings::StrCat(double{2.0})); + ASSERT_EQ(Literal::CreateSparse(dimensions, indices, + {half{1.0}, half{2.0}, half{3.0}}) + ->GetSparseElementAsString(1), + tensorflow::strings::StrCat(half{2.0})); + ASSERT_EQ( + Literal::CreateSparse( + dimensions, indices, + std::vector{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}) + ->GetSparseElementAsString(1), + tensorflow::strings::StrCat("(", float{3.0}, ", ", float{4.0}, ")")); } } // namespace diff --git a/tensorflow/compiler/xla/map_util.h b/tensorflow/compiler/xla/map_util.h index 51d0d5f86f00c539951e8e2baa6296337a5a21e9..8db8c6f3de84a6c46625eadbb6b0f83d2262e5f7 100644 --- a/tensorflow/compiler/xla/map_util.h +++ b/tensorflow/compiler/xla/map_util.h @@ -16,6 +16,11 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_MAP_UTIL_H_ #define TENSORFLOW_COMPILER_XLA_MAP_UTIL_H_ +#include +#include + +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -44,6 +49,41 @@ typename Collection::value_type::second_type& FindOrDie( return it->second; } +// Like FindOrDie but returns an error instead of dying if `key` is not in +// `container`. +template +StatusOr< + std::reference_wrapper> +MaybeFind(const Collection& collection, + const typename Collection::value_type::first_type& key) { + typename Collection::const_iterator it = collection.find(key); + if (it == collection.end()) { + std::ostringstream os; + os << key; + return NotFound("key not found: %s", os.str().c_str()); + } + return {it->second}; +} + +// Returns a const reference to the value associated with the given key if it +// exists, otherwise returns a const reference to the provided default value. +// +// WARNING: If a temporary object is passed as the default "value," +// this function will return a reference to that temporary object, +// which will be destroyed at the end of the statement. 
A common +// example: if you have a map with string values, and you pass a char* +// as the default "value," either use the returned value immediately +// or store it in a string (not string&). +template +const typename Collection::value_type::second_type& FindOrDefault( + const Collection& collection, + const typename Collection::value_type::first_type& key, + const typename Collection::value_type::second_type& value) { + auto it = collection.find(key); + if (it != collection.end()) return it->second; + return value; +} + // Inserts the key-value pair into the collection. Dies if key was already // present. template @@ -60,6 +100,12 @@ bool ContainsKey(const Collection& collection, const Key& key) { return collection.find(key) != collection.end(); } +// Inserts `value` into `set`. Dies if it was already present. +template +void InsertOrDie(Set* const set, const typename Set::value_type& value) { + CHECK(set->insert(value).second) << "duplicate value: " << value; +} + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_MAP_UTIL_H_ diff --git a/tensorflow/compiler/xla/packed_literal_reader.cc b/tensorflow/compiler/xla/packed_literal_reader.cc index 70e0f5a74711c8ceef1b6d4225141aa1cc9c6219..857aae0a7982a57bb3057a6f267f5f033a0fdde4 100644 --- a/tensorflow/compiler/xla/packed_literal_reader.cc +++ b/tensorflow/compiler/xla/packed_literal_reader.cc @@ -44,11 +44,11 @@ StatusOr> PackedLiteralReader::Read( VLOG(3) << "reading shape from file: " << ShapeUtil::HumanString(shape) << " layout: " << (layout == nullptr ? 
"" : layout->ShortDebugString()); - auto result = MakeUnique(); - *result->mutable_shape() = shape; + Shape literal_shape = shape; if (layout != nullptr) { - TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape(*layout, shape)); - *result->mutable_shape()->mutable_layout() = *layout; + TF_RETURN_IF_ERROR( + LayoutUtil::ValidateLayoutForShape(*layout, literal_shape)); + *literal_shape.mutable_layout() = *layout; } if (shape.element_type() != F32) { @@ -57,10 +57,12 @@ StatusOr> PackedLiteralReader::Read( PrimitiveType_Name(shape.element_type()).c_str()); } + auto result = MakeUnique(literal_shape); + result->PopulateWithValue(std::numeric_limits::quiet_NaN()); + int64 elements = ShapeUtil::ElementsIn(shape); - result->Resize(elements, std::numeric_limits::quiet_NaN()); - std::vector* field = result->mutable_f32s(); - char* data = tensorflow::bit_cast(field->data()); + tensorflow::gtl::ArraySlice field = result->data(); + char* data = tensorflow::bit_cast(field.data()); uint64 bytes = elements * sizeof(float); tensorflow::StringPiece sp; auto s = file_->Read(offset_, bytes, &sp, data); diff --git a/tensorflow/compiler/xla/primitive_util.cc b/tensorflow/compiler/xla/primitive_util.cc index 2bce56b7bd2f91f20ea670d0e7ccaa432c2b5f9f..143c9a2366be5786b7ef2148580caeb97d67d2d8 100644 --- a/tensorflow/compiler/xla/primitive_util.cc +++ b/tensorflow/compiler/xla/primitive_util.cc @@ -20,79 +20,6 @@ limitations under the License. 
namespace xla { namespace primitive_util { -template <> -PrimitiveType NativeToPrimitiveType() { - return PRED; -} - -// Unsigned integer -template <> -PrimitiveType NativeToPrimitiveType() { - return U8; -} - -template <> -PrimitiveType NativeToPrimitiveType() { - return U16; -} - -template <> -PrimitiveType NativeToPrimitiveType() { - return U32; -} - -template <> -PrimitiveType NativeToPrimitiveType() { - return U64; -} - -// Signed integer -template <> -PrimitiveType NativeToPrimitiveType() { - return S8; -} - -template <> -PrimitiveType NativeToPrimitiveType() { - return S16; -} - -template <> -PrimitiveType NativeToPrimitiveType() { - return S32; -} - -template <> -PrimitiveType NativeToPrimitiveType() { - return S64; -} - -// Floating point -template <> -PrimitiveType NativeToPrimitiveType() { - return F32; -} - -template <> -PrimitiveType NativeToPrimitiveType() { - return F64; -} - -template <> -PrimitiveType NativeToPrimitiveType() { - return BF16; -} - -template <> -PrimitiveType NativeToPrimitiveType() { - return F16; -} - -template <> -PrimitiveType NativeToPrimitiveType() { - return C64; -} - bool IsFloatingPointType(PrimitiveType type) { return type == F16 || type == F32 || type == F64 || type == BF16; } diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h index 19c6a138885c61f1304bfae3d8bb5d958a1bb5bc..b26a10ade63a5dad3bf8f9f3a2a33c3c5e67bdb2 100644 --- a/tensorflow/compiler/xla/primitive_util.h +++ b/tensorflow/compiler/xla/primitive_util.h @@ -26,6 +26,13 @@ limitations under the License. namespace xla { namespace primitive_util { +// The number of exponent bits in a BF16 value. +const int kBFloat16ExponentBits = 8; + +// The number of mantissa bits in a BF16 value. There is an implicit leading +// 1, so there is an implicit additional bit of precision. 
+const int kBFloat16MantissaBits = 7; + // Returns the XLA primitive type (eg, F32) corresponding to the given // template parameter native type (eg, float). template @@ -40,49 +47,81 @@ PrimitiveType NativeToPrimitiveType() { } // Declarations of specializations for each native type which correspond to a -// XLA primitive type. +// XLA primitive type. As an optimization, these are declared inline in the +// header. template <> -PrimitiveType NativeToPrimitiveType(); +inline PrimitiveType NativeToPrimitiveType() { + return PRED; +} // Unsigned integer template <> -PrimitiveType NativeToPrimitiveType(); +inline PrimitiveType NativeToPrimitiveType() { + return U8; +} template <> -PrimitiveType NativeToPrimitiveType(); +inline PrimitiveType NativeToPrimitiveType() { + return U16; +} template <> -PrimitiveType NativeToPrimitiveType(); +inline PrimitiveType NativeToPrimitiveType() { + return U32; +} template <> -PrimitiveType NativeToPrimitiveType(); +inline PrimitiveType NativeToPrimitiveType() { + return U64; +} // Signed integer template <> -PrimitiveType NativeToPrimitiveType(); +inline PrimitiveType NativeToPrimitiveType() { + return S8; +} template <> -PrimitiveType NativeToPrimitiveType(); +inline PrimitiveType NativeToPrimitiveType() { + return S16; +} template <> -PrimitiveType NativeToPrimitiveType(); +inline PrimitiveType NativeToPrimitiveType() { + return S32; +} template <> -PrimitiveType NativeToPrimitiveType(); +inline PrimitiveType NativeToPrimitiveType() { + return S64; +} // Floating point template <> -PrimitiveType NativeToPrimitiveType(); +inline PrimitiveType NativeToPrimitiveType() { + return F32; +} + template <> -PrimitiveType NativeToPrimitiveType(); +inline PrimitiveType NativeToPrimitiveType() { + return F64; +} + template <> -PrimitiveType NativeToPrimitiveType(); +inline PrimitiveType NativeToPrimitiveType() { + return F16; +} + template <> -PrimitiveType NativeToPrimitiveType(); +inline PrimitiveType NativeToPrimitiveType() { + return BF16; 
+} // Complex template <> -PrimitiveType NativeToPrimitiveType(); +inline PrimitiveType NativeToPrimitiveType() { + return C64; +} bool IsFloatingPointType(PrimitiveType type); diff --git a/tensorflow/compiler/xla/python/BUILD b/tensorflow/compiler/xla/python/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..e2972f06016ab3555c4fc0cc4616993fe6764b1e --- /dev/null +++ b/tensorflow/compiler/xla/python/BUILD @@ -0,0 +1,86 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") + +py_library( + name = "xla_client", + srcs = ["xla_client.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":pywrap_xla", + "//tensorflow/compiler/xla:xla_data_proto_py", + ], +) + +py_test( + name = "xla_client_test", + srcs = ["xla_client_test.py"], + main = "xla_client_test.py", + srcs_version = "PY2AND3", + deps = [ + ":xla_client", + "//tensorflow/python:platform_test", + ], +) + +cc_library( + name = "numpy_bridge", + srcs = ["numpy_bridge.cc"], + hdrs = ["numpy_bridge.h"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/python:numpy_lib", + ], +) + +cc_library( + name = "local_computation_builder", + srcs = ["local_computation_builder.cc"], + hdrs = ["local_computation_builder.h"], + deps = [ + "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:executable_build_options", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:shaped_buffer", + "//tensorflow/core:framework_lite", + "//tensorflow/core:lib", + ], +) + +tf_py_wrap_cc( + name 
= "pywrap_xla", + srcs = ["xla.i"], + swig_includes = [ + "local_computation_builder.i", + ], + deps = [ + ":local_computation_builder", + ":numpy_bridge", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:cpu_plugin", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/ndlstm/__init__.py b/tensorflow/compiler/xla/python/__init__.py similarity index 100% rename from tensorflow/contrib/ndlstm/__init__.py rename to tensorflow/compiler/xla/python/__init__.py diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc new file mode 100644 index 0000000000000000000000000000000000000000..a89146d4484a90fc4d89ced0b0240ae9585e1f28 --- /dev/null +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -0,0 +1,590 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/python/local_computation_builder.h" +#include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/platform/default/thread_annotations.h" + +namespace xla { + +namespace swig { + +// TODO(b/34473877) Ideally XLA would support AllReduce among arbitrary sets of +// device handles instead of needing to set the number of replicas at XLA +// service initialization time. +tensorflow::mutex g_local_client_mutex(tensorflow::LINKER_INITIALIZED); +int g_replica_count GUARDED_BY(g_local_client_mutex) = 1; +LocalClient* g_local_client GUARDED_BY(g_local_client_mutex) = nullptr; + +Status InitializeReplicaCount(int replica_count) { + if (replica_count < 1) { + return InvalidArgument("Replica count must be >= 1; got %d.", + replica_count); + } + tensorflow::mutex_lock lock(g_local_client_mutex); + if (g_local_client != nullptr) { + return FailedPrecondition( + "Attempted to set the replica count to %d, but a local XLA service was " + "previously created with a replica count of %d.", + replica_count, g_replica_count); + } + g_replica_count = replica_count; + return Status::OK(); +} + +int GetReplicaCount() { + tensorflow::mutex_lock lock(g_local_client_mutex); + return g_replica_count; +} + +LocalClient* GetOrCreateLocalClient() { + tensorflow::mutex_lock lock(g_local_client_mutex); + if (g_local_client != nullptr) { + return g_local_client; + } + LocalClientOptions options; + options.set_number_of_replicas(g_replica_count); + g_local_client = ClientLibrary::GetOrCreateLocalClient(options).ValueOrDie(); + CHECK(g_local_client != nullptr); + return g_local_client; +} + +Status TransferToInfeedLocal(const Literal& literal) { + VLOG(1) << "Infeeding literal without replica number; shape: " + << literal.shape(); + LocalClient* client = GetOrCreateLocalClient(); + return 
client->TransferToInfeedLocal(literal, /*device_ordinal=*/0); +} + +Status TransferToInfeedLocalReplica(const Literal& literal, + int replica_number) { + VLOG(1) << "Infeeding shape " << literal.shape() + << " to replica number: " << replica_number; + LocalClient* client = GetOrCreateLocalClient(); + TF_ASSIGN_OR_RETURN(int device_ordinal, + client->ReplicaNumberToDeviceOrdinal(replica_number)); + return client->TransferToInfeedLocal(literal, device_ordinal); +} + +StatusOr> TransferFromOutfeedLocalReplica( + const Shape& shape, int replica_number) { + VLOG(1) << "Outfeeding literal from replica number: " << replica_number + << " shape: " << shape; + LocalClient* client = GetOrCreateLocalClient(); + TF_ASSIGN_OR_RETURN(int device_ordinal, + client->ReplicaNumberToDeviceOrdinal(replica_number)); + return client->TransferFromOutfeedLocal(shape, device_ordinal); +} + +LocalShapedBuffer::LocalShapedBuffer( + std::unique_ptr shaped_buffer) + : shaped_buffer_(std::move(shaped_buffer)) {} + +const std::unique_ptr& LocalShapedBuffer::shaped_buffer() + const { + return shaped_buffer_; +} + +static StatusOr> ToBuffer( + LocalClient* client, int device_ordinal, const Literal& arg) { + return client->LiteralToShapedBuffer(arg, device_ordinal, + client->backend().memory_allocator()); +} + +/* static */ +LocalShapedBuffer* LocalShapedBuffer::FromLiteral( + const Literal& argument, + const tensorflow::gtl::optional& shape_with_layout) { + LocalClient* client = GetOrCreateLocalClient(); + std::unique_ptr buf; + if (shape_with_layout) { + std::unique_ptr relaid = + argument.Relayout(shape_with_layout.value()); + buf = ToBuffer(client, /*device_ordinal=*/0, *relaid).ConsumeValueOrDie(); + } else { + buf = ToBuffer(client, /*device_ordinal=*/0, argument).ConsumeValueOrDie(); + } + return new LocalShapedBuffer(std::move(buf)); +} + +std::unique_ptr LocalShapedBuffer::ToLiteral() const { + LocalClient* client = GetOrCreateLocalClient(); + return 
client->ShapedBufferToLiteral(*shaped_buffer()).ConsumeValueOrDie(); +} + +CompiledLocalComputation::CompiledLocalComputation( + std::unique_ptr executable) + : executable_(std::move(executable)) {} + +StatusOr> CompiledLocalComputation::Execute( + const std::vector& arguments, + const std::vector>& shapes_with_layout) { + LocalClient* client = GetOrCreateLocalClient(); + + VLOG(1) << "Execution requested with " << GetReplicaCount() << " replicas."; + + // Each replica populates a StatusOr result, but only replica zero actually + // retrieves its literal value. + std::vector>> results(GetReplicaCount()); + { + tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "xlarun", + GetReplicaCount()); + + for (int replica = 0; replica < GetReplicaCount(); ++replica) { + pool.Schedule([this, client, replica, &arguments, &shapes_with_layout, + &results] { + StatusOr device_ordinal_status = + client->ReplicaNumberToDeviceOrdinal(replica); + if (!device_ordinal_status.ok()) { + results[replica] = device_ordinal_status.status(); + return; + } + const int device_ordinal = device_ordinal_status.ValueOrDie(); + VLOG(3) << "Replica " << replica + << " mapped to device ordinal for execution: " + << device_ordinal; + + // Transfer arguments in + std::vector> scoped_buffers; + scoped_buffers.reserve(arguments.size()); + for (int i = 0; i < arguments.size(); ++i) { + const Literal& argument = arguments[i]; + const tensorflow::gtl::optional& shape_with_layout = + shapes_with_layout[i]; + + StatusOr> pushed; + if (shape_with_layout) { + std::unique_ptr relaid = + argument.Relayout(shape_with_layout.value()); + pushed = ToBuffer(client, device_ordinal, *relaid); + } else { + pushed = ToBuffer(client, device_ordinal, argument); + } + if (!pushed.ok()) { + results[replica] = pushed.status(); + return; + } + + scoped_buffers.push_back(std::move(pushed).ValueOrDie()); + } + + // Execute + std::vector argument_buffers; + argument_buffers.reserve(scoped_buffers.size()); + for (auto& 
buffer : scoped_buffers) { + argument_buffers.push_back(buffer.get()); + } + + DeviceAssignment device_assignment = + client->backend() + .computation_placer() + ->AssignDevices(GetReplicaCount(), /*computation_count=*/1) + .ConsumeValueOrDie(); + + ExecutableRunOptions options; + options.set_device_ordinal(device_ordinal); + options.set_allocator(client->backend().memory_allocator()); + options.set_inter_op_thread_pool( + client->backend().inter_op_thread_pool()); + options.set_intra_op_thread_pool( + client->backend().eigen_intra_op_thread_pool_device()); + options.set_device_assignment(&device_assignment); + StatusOr> result_buffer_status = + executable_->Run(argument_buffers, options); + if (!result_buffer_status.ok()) { + results[replica] = result_buffer_status.status(); + return; + } + + // Transfer result out + results[replica] = + client->ShapedBufferToLiteral(*result_buffer_status.ValueOrDie()); + }); + } + } + + for (int replica = 0; replica < GetReplicaCount(); ++replica) { + const auto& statusor = results[replica]; + if (!statusor.ok()) { + return InternalError( + "Failed running replica %d (other replicas may have failed as well): " + "%s.", + replica, statusor.status().ToString().c_str()); + } + } + + return std::move(results[0]); +} + +LocalShapedBuffer* CompiledLocalComputation::ExecuteWithShapedBuffers( + tensorflow::gtl::ArraySlice argument_handles) { + LocalClient* client = GetOrCreateLocalClient(); + + std::vector argument_buffers; + argument_buffers.reserve(argument_handles.size()); + for (auto& handle : argument_handles) { + argument_buffers.push_back(handle->shaped_buffer().get()); + } + + // Execute + ExecutableRunOptions options; + options.set_allocator(client->backend().memory_allocator()); + options.set_inter_op_thread_pool(client->backend().inter_op_thread_pool()); + options.set_intra_op_thread_pool( + client->backend().eigen_intra_op_thread_pool_device()); + std::unique_ptr result_buffer = + executable_->Run(argument_buffers, 
options).ConsumeValueOrDie(); + + return new LocalShapedBuffer(std::move(result_buffer)); +} + +LocalComputation::LocalComputation(Computation computation) + : computation_(std::move(computation)) {} + +StatusOr LocalComputation::Compile( + const std::vector& argument_shapes, + const ExecutableBuildOptions* build_options) { + std::vector argument_shape_pointers; + argument_shape_pointers.reserve(argument_shapes.size()); + for (auto& argument_shape : argument_shapes) { + argument_shape_pointers.push_back(&argument_shape); + } + + LocalClient* client = GetOrCreateLocalClient(); + ExecutableBuildOptions options; + if (build_options != nullptr) { + options = *build_options; + } + TF_ASSIGN_OR_RETURN( + auto local_executable, + client->Compile(computation_, argument_shape_pointers, options)); + return new CompiledLocalComputation(std::move(local_executable)); +} + +const Computation& LocalComputation::computation() const { + return computation_; +} + +LocalComputationBuilder::LocalComputationBuilder(const string& computation_name) + : builder_(GetOrCreateLocalClient(), computation_name) {} + +void LocalComputationBuilder::SetOpMetadata(const OpMetadata& metadata) { + builder_.SetOpMetadata(metadata); +} + +void LocalComputationBuilder::ClearOpMetadata() { builder_.ClearOpMetadata(); } + +StatusOr LocalComputationBuilder::Build() { + TF_ASSIGN_OR_RETURN(Computation computation, builder_.Build()); + return new LocalComputation(std::move(computation)); +} + +ComputationDataHandle LocalComputationBuilder::Parameter(int64 parameter_number, + const Shape& shape, + const string& name) { + return builder_.Parameter(parameter_number, shape, name); +} + +std::unique_ptr LocalComputationBuilder::GetShape( + const ComputationDataHandle& operand) { + return builder_.GetShape(operand).ConsumeValueOrDie(); +} + +StatusOr LocalComputationBuilder::GetReturnValueShape() { + TF_ASSIGN_OR_RETURN(ProgramShape program_shape, builder_.GetProgramShape()); + return program_shape.result(); +} + 
+ComputationDataHandle LocalComputationBuilder::Infeed(const Shape& shape) { + return builder_.Infeed(shape); +} + +void LocalComputationBuilder::Outfeed(const ComputationDataHandle& operand, + const Shape& shape, + const string& outfeed_config) { + builder_.Outfeed(operand, shape, outfeed_config); +} + +ComputationDataHandle LocalComputationBuilder::ConstantLiteral( + const Literal& literal) { + return builder_.ConstantLiteral(literal); +} + +ComputationDataHandle LocalComputationBuilder::Broadcast( + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice broadcast_sizes) { + return builder_.Broadcast(operand, broadcast_sizes); +} + +ComputationDataHandle LocalComputationBuilder::Pad( + const ComputationDataHandle& operand, + const ComputationDataHandle& padding_value, + const PaddingConfig& padding_config) { + return builder_.Pad(operand, padding_value, padding_config); +} + +ComputationDataHandle LocalComputationBuilder::Reshape( + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice new_sizes) { + return builder_.Reshape(operand, dimensions, new_sizes); +} + +ComputationDataHandle LocalComputationBuilder::Collapse( + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice dimensions) { + return builder_.Collapse(operand, dimensions); +} + +ComputationDataHandle LocalComputationBuilder::CrossReplicaSum( + const ComputationDataHandle& operand) { + return builder_.CrossReplicaSum(operand); +} + +ComputationDataHandle LocalComputationBuilder::Slice( + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides) { + return builder_.Slice(operand, start_indices, limit_indices, strides); +} + +ComputationDataHandle LocalComputationBuilder::DynamicSlice( + const ComputationDataHandle& operand, + const ComputationDataHandle& start_indices, + tensorflow::gtl::ArraySlice slice_sizes) { + 
return builder_.DynamicSlice(operand, start_indices, slice_sizes); +} + +ComputationDataHandle LocalComputationBuilder::DynamicUpdateSlice( + const ComputationDataHandle& operand, const ComputationDataHandle& update, + const ComputationDataHandle& start_indices) { + return builder_.DynamicUpdateSlice(operand, update, start_indices); +} + +ComputationDataHandle LocalComputationBuilder::ConcatInDim( + tensorflow::gtl::ArraySlice operands, + int64 dimension) { + return builder_.ConcatInDim(operands, dimension); +} + +ComputationDataHandle +LocalComputationBuilder::SelectAndScatterWithGeneralPadding( + const ComputationDataHandle& operand, const LocalComputation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + const ComputationDataHandle& source, + const ComputationDataHandle& init_value, const LocalComputation& scatter) { + return builder_.SelectAndScatterWithGeneralPadding( + operand, select.computation(), window_dimensions, window_strides, padding, + source, init_value, scatter.computation()); +} + +ComputationDataHandle LocalComputationBuilder::Tuple( + tensorflow::gtl::ArraySlice elements) { + return builder_.Tuple(elements); +} + +ComputationDataHandle LocalComputationBuilder::GetTupleElement( + const ComputationDataHandle& tuple_data, int64 index) { + return builder_.GetTupleElement(tuple_data, index); +} + +ComputationDataHandle LocalComputationBuilder::Dot( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) { + return builder_.Dot(lhs, rhs); +} + +ComputationDataHandle LocalComputationBuilder::DotGeneral( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + const DotDimensionNumbers& dimension_numbers) { + return builder_.DotGeneral(lhs, rhs, dimension_numbers); +} + +ComputationDataHandle LocalComputationBuilder::ConvGeneralDilated( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + 
tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding, + tensorflow::gtl::ArraySlice lhs_dilation, + tensorflow::gtl::ArraySlice rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers) { + return builder_.ConvGeneralDilated(lhs, rhs, window_strides, padding, + lhs_dilation, rhs_dilation, + dimension_numbers); +} + +ComputationDataHandle LocalComputationBuilder::ConvertElementType( + const ComputationDataHandle& operand, PrimitiveType new_element_type) { + return builder_.ConvertElementType(operand, new_element_type); +} + +ComputationDataHandle LocalComputationBuilder::Call( + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice operands) { + return builder_.Call(local_computation.computation(), operands); +} + +ComputationDataHandle LocalComputationBuilder::Transpose( + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice permutation) { + return builder_.Transpose(operand, permutation); +} + +ComputationDataHandle LocalComputationBuilder::Rev( + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice dimensions) { + return builder_.Rev(operand, dimensions); +} + +ComputationDataHandle LocalComputationBuilder::Map( + tensorflow::gtl::ArraySlice operands, + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice static_operands) { + return builder_.Map(operands, local_computation.computation(), dimensions, + static_operands); +} + +ComputationDataHandle LocalComputationBuilder::Reduce( + const ComputationDataHandle& operand, + const ComputationDataHandle& init_value, + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice dimensions_to_reduce) { + return builder_.Reduce(operand, init_value, local_computation.computation(), + dimensions_to_reduce); +} + +ComputationDataHandle LocalComputationBuilder::ReduceWindowWithGeneralPadding( + const ComputationDataHandle& operand, + const ComputationDataHandle& 
init_value, + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice> padding) { + return builder_.ReduceWindowWithGeneralPadding( + operand, init_value, local_computation.computation(), window_dimensions, + window_strides, padding); +} + +ComputationDataHandle LocalComputationBuilder::RngNormal( + const ComputationDataHandle& mu, const ComputationDataHandle& sigma, + const Shape& shape) { + return builder_.RngNormal(mu, sigma, shape); +} + +ComputationDataHandle LocalComputationBuilder::RngUniform( + const ComputationDataHandle& a, const ComputationDataHandle& b, + const Shape& shape) { + return builder_.RngUniform(a, b, shape); +} + +ComputationDataHandle LocalComputationBuilder::While( + const LocalComputation& condition, const LocalComputation& body, + const ComputationDataHandle& init) { + return builder_.While(condition.computation(), body.computation(), init); +} + +ComputationDataHandle LocalComputationBuilder::Conditional( + const ComputationDataHandle& predicate, + const ComputationDataHandle& true_operand, + const LocalComputation& true_computation, + const ComputationDataHandle& false_operand, + const LocalComputation& false_computation) { + return builder_.Conditional(predicate, true_operand, + true_computation.computation(), false_operand, + false_computation.computation()); +} + +#define _FORWARD(method_name, return_sig, args_sig, args) \ + return_sig LocalComputationBuilder::method_name args_sig { \ + return builder_.method_name args; \ + } + +#define _FORWARD_UNOP(method_name) \ + _FORWARD(method_name, ComputationDataHandle, \ + (const ComputationDataHandle& operand), (operand)) + +#define _FORWARD_BINOP(method_name) \ + _FORWARD( \ + method_name, ComputationDataHandle, \ + (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ + tensorflow::gtl::ArraySlice broadcast_dimensions), \ + (lhs, rhs, broadcast_dimensions)) + 
+#define _FORWARD_TRIOP(method_name) \ + _FORWARD( \ + method_name, ComputationDataHandle, \ + (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ + const ComputationDataHandle& ehs), \ + (lhs, rhs, ehs)) + +_FORWARD_TRIOP(Select) +_FORWARD_TRIOP(Clamp) +_FORWARD_BINOP(Eq) +_FORWARD_BINOP(Ne) +_FORWARD_BINOP(Ge) +_FORWARD_BINOP(Gt) +_FORWARD_BINOP(Lt) +_FORWARD_BINOP(Le) +_FORWARD_BINOP(Add) +_FORWARD_BINOP(Sub) +_FORWARD_BINOP(Mul) +_FORWARD_BINOP(Div) +_FORWARD_BINOP(Rem) +_FORWARD_BINOP(Max) +_FORWARD_BINOP(Min) +_FORWARD_BINOP(And) +_FORWARD_BINOP(Or) +_FORWARD_UNOP(Not) +_FORWARD_UNOP(Abs) +_FORWARD_UNOP(Exp) +_FORWARD_UNOP(Floor) +_FORWARD_UNOP(Ceil) +_FORWARD_UNOP(Round) +_FORWARD_UNOP(Log) +_FORWARD_UNOP(Sign) +_FORWARD_UNOP(Cos) +_FORWARD_UNOP(Sin) +_FORWARD_UNOP(Tanh) +_FORWARD_UNOP(SqrtF32) +_FORWARD_UNOP(SquareF32) +_FORWARD_BINOP(Pow) +_FORWARD_UNOP(IsFinite) +_FORWARD_UNOP(ReciprocalF32) +_FORWARD_UNOP(Neg) +_FORWARD_UNOP(Sort) + +#undef _FORWARD +#undef _FORWARD_UNOP +#undef _FORWARD_BINOP +#undef _FORWARD_TRIOP + +void DeleteLocalShapedBuffer(LocalShapedBuffer* local_shaped_buffer) { + delete local_shaped_buffer; +} + +void DeleteCompiledLocalComputation(CompiledLocalComputation* computation) { + delete computation; +} + +void DeleteLocalComputation(LocalComputation* computation) { + delete computation; +} + +} // namespace swig + +} // namespace xla diff --git a/tensorflow/compiler/xla/python/local_computation_builder.h b/tensorflow/compiler/xla/python/local_computation_builder.h new file mode 100644 index 0000000000000000000000000000000000000000..d682204d26a819556db6f960ee639e763b6f4988 --- /dev/null +++ b/tensorflow/compiler/xla/python/local_computation_builder.h @@ -0,0 +1,335 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_ + +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/executable_build_options.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/service/shaped_buffer.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace xla { + +namespace swig { + +// Initializes the number of replicas that XLA will be initialized with (when +// first obtaining a handle to the local XLA service). If this is called after +// the handle to the local XLA service has been established, then an error is +// returned. +Status InitializeReplicaCount(int replica_count); + +// Returns the replica count that is currently set, regardless of whether the +// local XLA service has been instantiated yet or not. +int GetReplicaCount(); + +// Wraps the local client's infeed-transfer function. +// +// The default device ordinal (0) is used. +Status TransferToInfeedLocal(const Literal& literal); + +// Transfers the given literal to the infeed of the given replica. +// +// The replica number is resolved to an appropriate device ordinal. 
+Status TransferToInfeedLocalReplica(const Literal& literal, int replica_number); + +// Transfers a literal of the given shape from the outfeed of the given replica. +// +// The replica number is resolved to an appropriate device ordinal. +StatusOr > TransferFromOutfeedLocalReplica( + const Shape& shape, int replica_number); + +// Wraps a ScopedShapedBuffer produced by copying a literal "to +// device," i.e. copying a literal to a scoped buffer via the local +// client. +class LocalShapedBuffer { + public: + static LocalShapedBuffer* FromLiteral( + const Literal& argument, + const tensorflow::gtl::optional& shape_with_layout); + LocalShapedBuffer(std::unique_ptr shaped_buffer); + const std::unique_ptr& shaped_buffer() const; + std::unique_ptr ToLiteral() const; + + private: + std::unique_ptr shaped_buffer_; +}; + +// Wraps a LocalExecutable produced by compiling a +// LocalComputation. The Execute method forwards to that of the +// underlying LocalExecutable, and additionally handles transferring +// arguments and return values in and back out of the client library's +// local client. This class is intended to be made available to Python +// via SWIG. +class CompiledLocalComputation { + public: + CompiledLocalComputation(std::unique_ptr executable); + + // Execute the computation with the given argument literals, and + // with optionally-specified argument layouts. The literals will be + // re-laid out according to the corresponding elements of + // shapes_with_layout. + StatusOr > Execute( + const std::vector& arguments, + const std::vector >& shapes_with_layout); + + LocalShapedBuffer* ExecuteWithShapedBuffers( + tensorflow::gtl::ArraySlice argument_handles); + + private: + std::unique_ptr executable_; +}; + +// Wraps a Computation produced by a LocalComputationBuilder. The +// Compile method compiles the computation to a (local) executable via +// the client library's local client. This class is intended to be +// made available to Python via SWIG.
+class LocalComputation { + public: + LocalComputation(Computation computation); + StatusOr Compile( + const std::vector& argument_shapes, + const ExecutableBuildOptions* build_options); + const Computation& computation() const; + + private: + Computation computation_; +}; + +// Wraps the ComputationBuilder API in order to: +// - Support consumption by SWIG in order to be made available to +// Python. +// - Set up the underlying builder to use the client library's +// LocalClient. +// - Wrap Computations in LocalComputations for Python access. +// - Correspondingly unwrap incoming LocalComputations. +class LocalComputationBuilder { + public: + LocalComputationBuilder(const string& computation_name); + + void SetOpMetadata(const OpMetadata& metadata); + void ClearOpMetadata(); + + // Returns an owned LocalComputation to the caller on success. + StatusOr Build(); + + ComputationDataHandle Parameter(int64 parameter_number, const Shape& shape, + const string& name); + + std::unique_ptr GetShape(const ComputationDataHandle& operand); + + // Returns the shape of the current return value for the computation. 
+ StatusOr GetReturnValueShape(); + + ComputationDataHandle Infeed(const Shape& shape); + + void Outfeed(const ComputationDataHandle& operand, const Shape& shape, + const string& outfeed_config); + + ComputationDataHandle ConstantLiteral(const Literal& literal); + + ComputationDataHandle Broadcast( + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice broadcast_sizes); + + ComputationDataHandle Pad(const ComputationDataHandle& operand, + const ComputationDataHandle& padding_value, + const PaddingConfig& padding_config); + + ComputationDataHandle Reshape(const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice new_sizes); + + ComputationDataHandle Collapse(const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice dimensions); + + ComputationDataHandle CrossReplicaSum(const ComputationDataHandle& operand); + + ComputationDataHandle Slice(const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice start_indices, + tensorflow::gtl::ArraySlice limit_indices, + tensorflow::gtl::ArraySlice strides); + + ComputationDataHandle DynamicSlice( + const ComputationDataHandle& operand, + const ComputationDataHandle& start_indices, + tensorflow::gtl::ArraySlice slice_sizes); + + ComputationDataHandle DynamicUpdateSlice( + const ComputationDataHandle& operand, const ComputationDataHandle& update, + const ComputationDataHandle& start_indices); + + ComputationDataHandle ConcatInDim( + tensorflow::gtl::ArraySlice operands, + int64 dimension); + + ComputationDataHandle SelectAndScatterWithGeneralPadding( + const ComputationDataHandle& operand, const LocalComputation& select, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice > padding, + const ComputationDataHandle& source, + const ComputationDataHandle& init_value, const LocalComputation& scatter); + + ComputationDataHandle Tuple( + tensorflow::gtl::ArraySlice elements); + + 
ComputationDataHandle GetTupleElement(const ComputationDataHandle& tuple_data, + int64 index); + + ComputationDataHandle Dot(const ComputationDataHandle& lhs, + const ComputationDataHandle& rhs); + + ComputationDataHandle DotGeneral( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + const DotDimensionNumbers& dimension_numbers); + + ComputationDataHandle ConvGeneralDilated( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice > padding, + tensorflow::gtl::ArraySlice lhs_dilation, + tensorflow::gtl::ArraySlice rhs_dilation, + const ConvolutionDimensionNumbers& dimension_numbers); + + ComputationDataHandle ConvertElementType(const ComputationDataHandle& operand, + PrimitiveType new_element_type); + + ComputationDataHandle Call( + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice operands); + + ComputationDataHandle Transpose( + const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice permutation); + + ComputationDataHandle Rev(const ComputationDataHandle& operand, + tensorflow::gtl::ArraySlice dimensions); + + ComputationDataHandle Map( + tensorflow::gtl::ArraySlice operands, + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice static_operands); + + ComputationDataHandle Reduce( + const ComputationDataHandle& operand, + const ComputationDataHandle& init_value, + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice dimensions_to_reduce); + + ComputationDataHandle ReduceWindowWithGeneralPadding( + const ComputationDataHandle& operand, + const ComputationDataHandle& init_value, + const LocalComputation& local_computation, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides, + tensorflow::gtl::ArraySlice > padding); + + ComputationDataHandle RngNormal(const ComputationDataHandle& mu, + const 
ComputationDataHandle& sigma, + const Shape& shape); + + ComputationDataHandle RngUniform(const ComputationDataHandle& a, + const ComputationDataHandle& b, + const Shape& shape); + + ComputationDataHandle While(const LocalComputation& condition, + const LocalComputation& body, + const ComputationDataHandle& init); + + ComputationDataHandle Conditional(const ComputationDataHandle& predicate, + const ComputationDataHandle& true_operand, + const LocalComputation& true_computation, + const ComputationDataHandle& false_operand, + const LocalComputation& false_computation); + +#define _FORWARD(method_name, return_sig, args_sig) \ + return_sig method_name args_sig; + +#define _FORWARD_UNOP(method_name) \ + _FORWARD(method_name, ComputationDataHandle, \ + (const ComputationDataHandle& operand)) + +#define _FORWARD_BINOP(method_name) \ + _FORWARD( \ + method_name, ComputationDataHandle, \ + (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ + tensorflow::gtl::ArraySlice broadcast_dimensions)) + +#define _FORWARD_TRIOP(method_name) \ + _FORWARD( \ + method_name, ComputationDataHandle, \ + (const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, \ + const ComputationDataHandle& ehs)) + + _FORWARD_TRIOP(Select) + _FORWARD_TRIOP(Clamp) + _FORWARD_BINOP(Eq) + _FORWARD_BINOP(Ne) + _FORWARD_BINOP(Ge) + _FORWARD_BINOP(Gt) + _FORWARD_BINOP(Lt) + _FORWARD_BINOP(Le) + _FORWARD_BINOP(Add) + _FORWARD_BINOP(Sub) + _FORWARD_BINOP(Mul) + _FORWARD_BINOP(Div) + _FORWARD_BINOP(Rem) + _FORWARD_BINOP(Max) + _FORWARD_BINOP(Min) + _FORWARD_BINOP(And) + _FORWARD_BINOP(Or) + _FORWARD_UNOP(Not) + _FORWARD_UNOP(Abs) + _FORWARD_UNOP(Exp) + _FORWARD_UNOP(Floor) + _FORWARD_UNOP(Ceil) + _FORWARD_UNOP(Round) + _FORWARD_UNOP(Log) + _FORWARD_UNOP(Sign) + _FORWARD_UNOP(Cos) + _FORWARD_UNOP(Sin) + _FORWARD_UNOP(Tanh) + _FORWARD_UNOP(SqrtF32) + _FORWARD_UNOP(SquareF32) + _FORWARD_BINOP(Pow) + _FORWARD_UNOP(IsFinite) + _FORWARD_UNOP(ReciprocalF32) + _FORWARD_UNOP(Neg) + 
_FORWARD_UNOP(Sort) + +#undef _FORWARD +#undef _FORWARD_UNOP +#undef _FORWARD_BINOP +#undef _FORWARD_TRIOP + + private: + ComputationBuilder builder_; +}; + +// Functions for freeing resources from the Python side. +void DeleteLocalShapedBuffer(LocalShapedBuffer* local_shaped_buffer); +void DeleteCompiledLocalComputation(CompiledLocalComputation* computation); +void DeleteLocalComputation(LocalComputation* computation); + +} // namespace swig + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_LOCAL_COMPUTATION_BUILDER_H_ diff --git a/tensorflow/compiler/xla/python/local_computation_builder.i b/tensorflow/compiler/xla/python/local_computation_builder.i new file mode 100644 index 0000000000000000000000000000000000000000..fa6c8bfa296c7f80e95e4afc4a6062d133643c53 --- /dev/null +++ b/tensorflow/compiler/xla/python/local_computation_builder.i @@ -0,0 +1,937 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// SWIG typemaps and declarations for building, compiling, and +// executing XLA computations, wrapping most of what is declared in +// local_computation_builder.h. 
+// +// The typemaps below implement/assert the following correspondences +// (with elaborations below): +// +// C++ Python +// -------------------------------------+--------------------------------------- +// ComputationDataHandle <-> int +// ArraySlice <- sequence of int +// ArraySlice <- sequence of int +// Literal <-> (nested tuple of) numpy ndarray +// std::vector <- sequence of (nested tuple of) ndarray +// Shape -> pair holding (dtype, dimensions) +// <- object duck-typed as xla_client.Shape +// std::vector <- sequence of xla_client.Shape objects +// PrimitiveType <- int +// ArraySlice> <- sequence of int pairs +// PaddingConfig proto <- corresponding Python proto +// ConvolutionDimensionNumbers proto <- corresponding Python proto +// DotDimensionNumbers proto <- corresponding Python proto +// +// Arrows indicate whether a conversion only ever occurs in one +// direction, or whether it is maintained bidirectionally. +// +// The Python objects corresponding to C++ Literals have the type: +// +// T = ndarray | (T, ...) +// +// where a terminal numpy ndarray translates to a Literal with a +// non-tuple Shape, an XLA primitive element type corresponding to the +// ndarray's dtype. Meanwhile, a non-terminal "tuple of T" translates +// to a tuple-shaped Literal whose tuple components are translated +// recursively. For example, if x is a numpy ndarray in Python, with +// shape (2, 3) and dtype of dtype('float32'), then x translates to a +// Literal with rank 2, dimension 2 and 3, and XLA primitive type +// F32. Meanwhile, +// +// (x, (x, x), (x,)), +// +// translates to a tuple-shaped XLA Literal, whose component subshapes +// are a 2x3 F32-shaped literal followed by two tuple-shaped literals. +// +// Shapes output by C++ become Python objects with the type: +// +// T = (dtype, S) +// S = DIMENSIONS | TUPLE_SHAPES +// DIMENSIONS = (int, ...) +// TUPLE_SHAPES = (T, ...) 
+// +// In the pair described by the T rule, the terminal dtype determines +// whether S expands as DIMENSIONS or TUPLE_SHAPES. Namely if it is +// dtype('O'), numpy's object dtype, the structure represents a tuple +// shape and the expansion of the non-terminal S is +// TUPLE_SHAPES. Otherwise, dtype describes a primitive element type +// and S expands into DIMENSIONS giving dimension sizes. For example: +// +// (dtype('float32'), (3, 5, 7)) +// +// describes a 3x5x7 array of F32s, and +// +// (dtype('O'), ((dtype('float32'), (2, 3)), +// (dtype('float64'), (4, 5)))) +// +// describes a tuple shape with two subshapes: the first a 2x3 F32, +// and the other a 4x5 F64. +// +// The Python int corresponding to a PrimitiveType enum must be valid +// per xla_data.proto (e.g. xla_data.PRED, xla_data.F32). +// +// The SWIG object wrappers generated by this file are not intended +// for end use, but rather for internal use in the Python XLA client, +// xla_client.py. +// +// One central reason for the Python-side indirection is that the +// Python-side objects produced by the typemaps in this file are +// further packaged up by xla_client before being passed on. For +// instance, xla_client wraps the long produced for a C++ +// ComputationDataHandle in a Python ComputationDataHandle proto, +// rather than exposing a raw long outside of the client. Similarly, +// the Python pair produced for a C++ Shape is further wrapped in a +// Python class (xla_client.Shape) so as not to expose the raw pair +// externally. +// +// Other SWIG object wrappers (e.g. of LocalComputation) are further +// wrapped by xla_client in order to set up a custom destructor that +// triggers memory deallocation on the C++ side. + +%module(threads="1") local_computation_builder + +// Keep the GIL except where explicitly specified. 
+%nothread; + +%include "tensorflow/python/platform/base.i" + +%{ +// Must be included first +#include "tensorflow/python/lib/core/numpy.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/compiler/xla/python/numpy_bridge.h" +#include "tensorflow/compiler/xla/python/local_computation_builder.h" + +using namespace xla; +using namespace xla::swig; + +namespace xla { +namespace swig { + +bool GetIntAttr(PyObject* o, const char* field, int64* result) { + PyObject* fo = PyObject_GetAttrString(o, field); + if (!fo) { + return false; + } + const int64 value = numpy::PyIntOrPyLongToLong(fo); + if (value == -1 && PyErr_Occurred()) { + Py_DECREF(fo); + return false; + } + Py_DECREF(fo); + *result = value; + return true; +} + +} +} +%} + +// Required to use PyArray_* functions. +%init %{ +tensorflow::ImportNumpy(); +%} + +// ComputationDataHandle + +%typemap(in) const ComputationDataHandle& (ComputationDataHandle temp) { + const int64 handle = numpy::PyIntOrPyLongToLong($input); + if (handle == -1 && PyErr_Occurred()) { + return NULL; + } + temp.set_handle(handle); + $1 = &temp; +} + +%typemap(out) ComputationDataHandle { + $result = numpy::LongToPyIntOrPyLong($1.handle()); +} + +%typemap(out) StatusOr { + if ($1.ok()) { + auto* value = $1.ValueOrDie(); + { + auto* $1 = value; + $typemap(out, xla::swig::CompiledLocalComputation*) + } + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + return NULL; + } +} + +%typemap(out) StatusOr< std::unique_ptr > { + if ($1.ok()) { + std::unique_ptr value = $1.ConsumeValueOrDie(); + $result = numpy::PyObjectFromXlaLiteral(*value); + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + return NULL; + } +} + +%typemap(out) StatusOr { + if ($1.ok()) { + auto* value = $1.ValueOrDie(); + { + auto* $1 = value; + 
$typemap(out, xla::swig::LocalComputation*) + } + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + return NULL; + } +} + +%typemap(out) StatusOr { + if ($1.ok()) { + $result = numpy::PyShapeInfoFromXlaShape($1.ConsumeValueOrDie()); + } else { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + return NULL; + } +} + +%typemap(out) Status { + if (!$1.ok()) { + PyErr_SetString( + PyExc_RuntimeError, $1.ToString().c_str()); + return NULL; + } + /* Py_None must be handed back to Python as an owned reference. */ + Py_INCREF(Py_None); + $result = Py_None; +} + +// ArraySlice + +%typemap(in) tensorflow::gtl::ArraySlice + (std::vector temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + return NULL; + } + const int size = PySequence_Size($input); + temps.resize(size); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + PyObject* py_int = numpy::PyNumberToPyInt(o); + if (!py_int) { + PyErr_SetString( + PyExc_TypeError, + "Argument sequence element cannot be converted to int"); + Py_DECREF(o); + return NULL; + } + temps[i] = numpy::PyIntOrPyLongToLong(py_int); + if (temps[i] == -1 && PyErr_Occurred()) { + Py_DECREF(py_int); + Py_DECREF(o); + return NULL; + } + Py_DECREF(py_int); + Py_DECREF(o); + } + $1 = temps; +} + +// ComputationDataHandle + +%typemap(in) tensorflow::gtl::ArraySlice + (std::vector temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + return NULL; + } + const int size = PySequence_Size($input); + temps.resize(size); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + PyObject* py_int = numpy::PyNumberToPyInt(o); + if (!py_int) { + PyErr_SetString( + PyExc_TypeError, + "Argument sequence element cannot be converted to int"); + /* Release the item reference, as the int typemap above does. */ + Py_DECREF(o); + return NULL; + } + const int64 handle = numpy::PyIntOrPyLongToLong(py_int); + if (handle == -1 && PyErr_Occurred()) { + Py_DECREF(py_int); + Py_DECREF(o); + return NULL; + } +
temps[i].set_handle(handle); + Py_DECREF(py_int); + Py_DECREF(o); + } + $1 = temps; +} + +// LocalShapedBuffer* + +%typemap(in) tensorflow::gtl::ArraySlice + (std::vector temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + return NULL; + } + const int size = PySequence_Size($input); + temps.reserve(size); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + LocalShapedBuffer* lsbp; + if ((SWIG_ConvertPtr(o, (void**) &lsbp, $descriptor(xla::swig::LocalShapedBuffer*), + SWIG_POINTER_EXCEPTION)) == -1) { + /* Drop the item reference before bailing out; o may be NULL here. */ + Py_XDECREF(o); + return NULL; + } + temps.push_back(lsbp); + Py_DECREF(o); + } + $1 = temps; +} + +// Literal + +%typemap(in) const Literal& (StatusOr< std::unique_ptr > literal_status) { + literal_status = numpy::XlaLiteralFromPyObject($input); + if (!literal_status.ok()) { + PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str()); + return NULL; + } + $1 = literal_status.ValueOrDie().get(); +} + +%typemap(out) std::unique_ptr { + $result = numpy::PyObjectFromXlaLiteral(*$1); +} + +%typemap(out) StatusOr< std::unique_ptr > { + if (!$1.ok()) { + PyErr_SetString(PyExc_RuntimeError, $1.status().ToString().c_str()); + return NULL; + } + $result = numpy::PyObjectFromXlaLiteral(*$1.ValueOrDie()); +} + +%typemap(in) const std::vector& (std::vector temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + return NULL; + } + const int size = PySequence_Size($input); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + StatusOr< std::unique_ptr > literal_status = numpy::XlaLiteralFromPyObject(o); + if (!literal_status.ok()) { + PyErr_SetString(PyExc_RuntimeError, literal_status.status().ToString().c_str()); + Py_DECREF(o); + return NULL; + } + temps.push_back(std::move(*literal_status.ConsumeValueOrDie())); + Py_DECREF(o); + } + $1 = &temps; +} + +// OpMetadata + +%typemap(in) const OpMetadata&
(OpMetadata temp) { + StatusOr statusor = numpy::OpMetadataFromPyObject($input); + if (!statusor.ok()) { + PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); + return NULL; + } + temp = std::move(statusor).ValueOrDie(); + $1 = &temp; +} + +// Shape + +%typemap(in) const Shape& (Shape temp) { + StatusOr statusor = numpy::XlaShapeFromPyShape($input); + if (!statusor.ok()) { + PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); + return NULL; + } + temp = std::move(statusor).ValueOrDie(); + $1 = &temp; +} + +%typemap(in) const tensorflow::gtl::optional& ( + tensorflow::gtl::optional temp) { + if ($input == Py_None) { + temp = tensorflow::gtl::nullopt; + $1 = &temp; + } else { + StatusOr statusor = numpy::XlaShapeFromPyShape($input); + if (!statusor.ok()) { + PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); + return NULL; + } + temp = std::move(statusor).ValueOrDie(); + $1 = &temp; + } +} + +%typemap(out) std::unique_ptr { + $result = numpy::PyShapeInfoFromXlaShape(*$1); +} + +%typemap(in) const std::vector& (std::vector temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + return NULL; + } + const int size = PySequence_Size($input); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + StatusOr statusor = numpy::XlaShapeFromPyShape(o); + Py_DECREF(o); + if (!statusor.ok()) { + PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); + return NULL; + } + temps.push_back(statusor.ConsumeValueOrDie()); + } + $1 = &temps; +} + +%typemap(in) const std::vector >& ( + std::vector > temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + return NULL; + } + const int size = PySequence_Size($input); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + if (o == Py_None) { + temps.push_back(tensorflow::gtl::nullopt); + /* PySequence_GetItem returned a new reference; release it in this branch too. */ + Py_DECREF(o); +
} else { + StatusOr statusor = numpy::XlaShapeFromPyShape(o); + Py_DECREF(o); + if (!statusor.ok()) { + PyErr_SetString(PyExc_RuntimeError, statusor.status().ToString().c_str()); + return NULL; + } + temps.push_back(statusor.ConsumeValueOrDie()); + } + } + $1 = &temps; +} + +// PrimitiveType + +%typemap(in) PrimitiveType { + PyObject* py_int = numpy::PyNumberToPyInt($input); + if (!py_int) { + PyErr_SetString(PyExc_TypeError, "Argument cannot be converted to int"); + return NULL; + } + const long value = numpy::PyIntOrPyLongToLong(py_int); + if (value == -1 && PyErr_Occurred()) { + Py_DECREF(py_int); + return NULL; + } + if (!PrimitiveType_IsValid(value)) { + PyErr_SetString( + PyExc_TypeError, "Argument not valid for PrimitiveType enum"); + Py_DECREF(py_int); + return NULL; + } + /* The converted int is released on the error paths above; release it on success too. */ + Py_DECREF(py_int); + $1 = static_cast(value); +} + +// ArraySlice> + +%typemap(in) tensorflow::gtl::ArraySlice > + (std::vector > temps) { + if (!PySequence_Check($input)) { + PyErr_SetString(PyExc_TypeError, "Argument is not a sequence"); + return NULL; + } + const int size = PySequence_Size($input); + temps.reserve(size); + for (int i = 0; i < size; ++i) { + PyObject* o = PySequence_GetItem($input, i); + if (!o) { + return NULL; + } + PyObject* first = PyTuple_GetItem(o, 0); + if (!first) { + Py_DECREF(o); + return NULL; + } + PyObject* first_pyint = numpy::PyNumberToPyInt(first); + if (!first_pyint) { + PyErr_SetString( + PyExc_TypeError, + "First pair item cannot be converted to int"); + Py_DECREF(o); + return NULL; + } + PyObject* second = PyTuple_GetItem(o, 1); + if (!second) { + Py_DECREF(o); + Py_DECREF(first_pyint); + return NULL; + } + PyObject* second_pyint = numpy::PyNumberToPyInt(second); + if (!second_pyint) { + PyErr_SetString( + PyExc_TypeError, + "Second pair item cannot be converted to int"); + Py_DECREF(o); + Py_DECREF(first_pyint); + return NULL; + } + const int64 first_value = numpy::PyIntOrPyLongToLong(first_pyint); + if (first_value == -1 && PyErr_Occurred()) { + Py_DECREF(o); +
Py_DECREF(first_pyint); + Py_DECREF(second_pyint); + return NULL; + } + const int64 second_value = numpy::PyIntOrPyLongToLong(second_pyint); + if (second_value == -1 && PyErr_Occurred()) { + Py_DECREF(o); + Py_DECREF(first_pyint); + Py_DECREF(second_pyint); + return NULL; + } + temps.push_back(std::make_pair(first_value, second_value)); + /* Release the converted ints as well as the pair item itself. */ + Py_DECREF(first_pyint); + Py_DECREF(second_pyint); + Py_DECREF(o); + } + $1 = temps; +} + +// DotDimensionNumbers + +%typemap(in) const DotDimensionNumbers& + (DotDimensionNumbers dimension_numbers) { + int length; + + /* lhs_contracting_dimensions */ + PyObject* lhs_contracting_dimensions = PyObject_GetAttrString( + $input, "lhs_contracting_dimensions"); + if (!lhs_contracting_dimensions) { + return NULL; + } + + length = PySequence_Size(lhs_contracting_dimensions); + if (length == -1) { + Py_DECREF(lhs_contracting_dimensions); + return NULL; + } + + for (int i = 0; i < length; ++i) { + PyObject* item = PySequence_GetItem(lhs_contracting_dimensions, i); + if (!item) { + Py_DECREF(lhs_contracting_dimensions); + return NULL; + } + const int64 dimension = numpy::PyIntOrPyLongToLong(item); + if (dimension == -1 && PyErr_Occurred()) { + Py_DECREF(item); + Py_DECREF(lhs_contracting_dimensions); + return NULL; + } + dimension_numbers.add_lhs_contracting_dimensions(dimension); + Py_DECREF(item); + } + Py_DECREF(lhs_contracting_dimensions); + + /* rhs_contracting_dimensions */ + PyObject* rhs_contracting_dimensions = PyObject_GetAttrString( + $input, "rhs_contracting_dimensions"); + if (!rhs_contracting_dimensions) { + return NULL; + } + + length = PySequence_Size(rhs_contracting_dimensions); + if (length == -1) { + Py_DECREF(rhs_contracting_dimensions); + return NULL; + } + + for (int i = 0; i < length; ++i) { + PyObject* item = PySequence_GetItem(rhs_contracting_dimensions, i); + if (!item) { + Py_DECREF(rhs_contracting_dimensions); + return NULL; + } + const int64 dimension = numpy::PyIntOrPyLongToLong(item); + if (dimension == -1 && PyErr_Occurred()) { + Py_DECREF(item); +
Py_DECREF(rhs_contracting_dimensions); + return NULL; + } + dimension_numbers.add_rhs_contracting_dimensions(dimension); + Py_DECREF(item); + } + Py_DECREF(rhs_contracting_dimensions); + + /* lhs_batch_dimensions */ + PyObject* lhs_batch_dimensions = PyObject_GetAttrString( + $input, "lhs_batch_dimensions"); + if (!lhs_batch_dimensions) { + return NULL; + } + + length = PySequence_Size(lhs_batch_dimensions); + if (length == -1) { + Py_DECREF(lhs_batch_dimensions); + return NULL; + } + + for (int i = 0; i < length; ++i) { + PyObject* item = PySequence_GetItem(lhs_batch_dimensions, i); + if (!item) { + Py_DECREF(lhs_batch_dimensions); + return NULL; + } + const int64 dimension = numpy::PyIntOrPyLongToLong(item); + if (dimension == -1 && PyErr_Occurred()) { + Py_DECREF(item); + Py_DECREF(lhs_batch_dimensions); + return NULL; + } + dimension_numbers.add_lhs_batch_dimensions(dimension); + Py_DECREF(item); + } + Py_DECREF(lhs_batch_dimensions); + + /* rhs_batch_dimensions */ + PyObject* rhs_batch_dimensions = PyObject_GetAttrString( + $input, "rhs_batch_dimensions"); + if (!rhs_batch_dimensions) { + return NULL; + } + + length = PySequence_Size(rhs_batch_dimensions); + if (length == -1) { + Py_DECREF(rhs_batch_dimensions); + return NULL; + } + + for (int i = 0; i < length; ++i) { + PyObject* item = PySequence_GetItem(rhs_batch_dimensions, i); + if (!item) { + Py_DECREF(rhs_batch_dimensions); + return NULL; + } + const int64 dimension = numpy::PyIntOrPyLongToLong(item); + if (dimension == -1 && PyErr_Occurred()) { + Py_DECREF(item); + Py_DECREF(rhs_batch_dimensions); + return NULL; + } + dimension_numbers.add_rhs_batch_dimensions(dimension); + Py_DECREF(item); + } + Py_DECREF(rhs_batch_dimensions); + + $1 = &dimension_numbers; +} + +// PaddingConfig + +%typemap(in) const PaddingConfig& + (PaddingConfig padding_config) { + PyObject* dimensions = PyObject_GetAttrString($input, "dimensions"); + if (!dimensions) { + return NULL; + } + + int length = 
PySequence_Size(dimensions); + if (length == -1) { + Py_DECREF(dimensions); + return NULL; + } + + for (int i = 0; i < length; ++i) { + PyObject* item = PySequence_GetItem(dimensions, i); + if (!item) { + Py_DECREF(dimensions); + return NULL; + } + int64 edge_padding_low, edge_padding_high, interior_padding; + if (!GetIntAttr(item, "edge_padding_low", &edge_padding_low) + || !GetIntAttr(item, "edge_padding_high", &edge_padding_high) + || !GetIntAttr(item, "interior_padding", &interior_padding)) { + Py_DECREF(item); + Py_DECREF(dimensions); + return NULL; + } + Py_DECREF(item); + + PaddingConfig::PaddingConfigDimension* dimension = + padding_config.add_dimensions(); + dimension->set_edge_padding_low(edge_padding_low); + dimension->set_edge_padding_high(edge_padding_high); + dimension->set_interior_padding(interior_padding); + } + Py_DECREF(dimensions); + + $1 = &padding_config; +} + +// ConvolutionDimensionNumbers + +%typemap(in) const ConvolutionDimensionNumbers& + (ConvolutionDimensionNumbers dimension_numbers) { + int64 value; + + if (!GetIntAttr($input, "input_batch_dimension", &value)) { + return NULL; + } + dimension_numbers.set_input_batch_dimension(value); + + if (!GetIntAttr($input, "input_feature_dimension", &value)) { + return NULL; + } + dimension_numbers.set_input_feature_dimension(value); + + if (!GetIntAttr($input, "output_batch_dimension", &value)) { + return NULL; + } + dimension_numbers.set_output_batch_dimension(value); + + if (!GetIntAttr($input, "output_feature_dimension", &value)) { + return NULL; + } + dimension_numbers.set_output_feature_dimension(value); + + if (!GetIntAttr($input, "kernel_output_feature_dimension", &value)) { + return NULL; + } + dimension_numbers.set_kernel_output_feature_dimension(value); + + if (!GetIntAttr($input, "kernel_input_feature_dimension", &value)) { + return NULL; + } + dimension_numbers.set_kernel_input_feature_dimension(value); + + PyObject* o; + int length; + + o = PyObject_GetAttrString($input, 
"input_spatial_dimensions"); + if (!o) { + return NULL; + } + length = PySequence_Size(o); + if (length == -1) { + Py_DECREF(o); + return NULL; + } + for (int i = 0; i < length; ++i) { + PyObject* item = PySequence_GetItem(o, i); + if (!item) { + Py_DECREF(o); + return NULL; + } + const int64 dimension = numpy::PyIntOrPyLongToLong(item); + if (dimension == -1 && PyErr_Occurred()) { + Py_DECREF(item); + Py_DECREF(o); + return NULL; + } + dimension_numbers.add_input_spatial_dimensions(dimension); + Py_DECREF(item); + } + Py_DECREF(o); + + o = PyObject_GetAttrString($input, "kernel_spatial_dimensions"); + if (!o) { + return NULL; + } + length = PySequence_Size(o); + if (length == -1) { + Py_DECREF(o); + return NULL; + } + for (int i = 0; i < length; ++i) { + PyObject* item = PySequence_GetItem(o, i); + if (!item) { + Py_DECREF(o); + return NULL; + } + const int64 dimension = numpy::PyIntOrPyLongToLong(item); + if (dimension == -1 && PyErr_Occurred()) { + Py_DECREF(item); + Py_DECREF(o); + return NULL; + } + dimension_numbers.add_kernel_spatial_dimensions(dimension); + Py_DECREF(item); + } + Py_DECREF(o); + + o = PyObject_GetAttrString($input, "output_spatial_dimensions"); + if (!o) { + return NULL; + } + length = PySequence_Size(o); + if (length == -1) { + Py_DECREF(o); + return NULL; + } + for (int i = 0; i < length; ++i) { + PyObject* item = PySequence_GetItem(o, i); + if (!item) { + Py_DECREF(o); + return NULL; + } + const int64 dimension = numpy::PyIntOrPyLongToLong(item); + if (dimension == -1 && PyErr_Occurred()) { + Py_DECREF(item); + Py_DECREF(o); + return NULL; + } + dimension_numbers.add_output_spatial_dimensions(dimension); + Py_DECREF(item); + } + Py_DECREF(o); + + $1 = &dimension_numbers; +} + +// ExecutableBuildOptions + +%typemap(in) const ExecutableBuildOptions* + (ExecutableBuildOptions build_options) { + if ($input == Py_None) { + $1 = NULL; + } else { + PyObject* o = PyObject_GetAttrString($input, "generate_hlo_graph"); + if (!o) { + return NULL; + 
} + if (o != Py_None) { + if (!PyString_Check(o)) { + PyErr_SetString(PyExc_TypeError, "ExecutableBuildOptions.generate_hlo_graph must be a string or None."); + /* Release the attribute reference on the error path as well. */ + Py_DECREF(o); + return NULL; + } + build_options.set_generate_hlo_graph(PyString_AsString(o)); + } + Py_DECREF(o); + + $1 = &build_options; + } +} + +%ignoreall +%unignore xla; +%unignore xla::swig; +%unignore xla::swig::InitializeReplicaCount; +%unignore xla::swig::GetReplicaCount; +%unignore xla::swig::TransferToInfeedLocal; +%unignore xla::swig::TransferToInfeedLocalReplica; +%unignore xla::swig::TransferFromOutfeedLocalReplica; +%unignore xla::swig::LocalShapedBuffer; +%unignore xla::swig::LocalShapedBuffer::FromLiteral; +%unignore xla::swig::LocalShapedBuffer::ToLiteral; +%unignore xla::swig::CompiledLocalComputation; +%unignore xla::swig::CompiledLocalComputation::Execute; +%unignore xla::swig::CompiledLocalComputation::ExecuteWithShapedBuffers; +%unignore xla::swig::LocalComputation; +%unignore xla::swig::LocalComputation::Compile; +%unignore xla::swig::LocalComputationBuilder; +%unignore xla::swig::LocalComputationBuilder::LocalComputationBuilder; +%unignore xla::swig::LocalComputationBuilder::Build; +%unignore xla::swig::LocalComputationBuilder::SetOpMetadata; +%unignore xla::swig::LocalComputationBuilder::ClearOpMetadata; +%unignore xla::swig::LocalComputationBuilder::Parameter; +%unignore xla::swig::LocalComputationBuilder::GetShape; +%unignore xla::swig::LocalComputationBuilder::GetReturnValueShape; +%unignore xla::swig::LocalComputationBuilder::Infeed; +%unignore xla::swig::LocalComputationBuilder::Outfeed; +%unignore xla::swig::LocalComputationBuilder::ConstantLiteral; +%unignore xla::swig::LocalComputationBuilder::ConstantR0; +%unignore xla::swig::LocalComputationBuilder::Broadcast; +%unignore xla::swig::LocalComputationBuilder::Pad; +%unignore xla::swig::LocalComputationBuilder::Reshape; +%unignore xla::swig::LocalComputationBuilder::Collapse; +%unignore xla::swig::LocalComputationBuilder::CrossReplicaSum;
+%unignore xla::swig::LocalComputationBuilder::Slice; +%unignore xla::swig::LocalComputationBuilder::DynamicSlice; +%unignore xla::swig::LocalComputationBuilder::DynamicUpdateSlice; +%unignore xla::swig::LocalComputationBuilder::ConcatInDim; +%unignore xla::swig::LocalComputationBuilder::SelectAndScatterWithGeneralPadding; +%unignore xla::swig::LocalComputationBuilder::Select; +%unignore xla::swig::LocalComputationBuilder::Tuple; +%unignore xla::swig::LocalComputationBuilder::GetTupleElement; +%unignore xla::swig::LocalComputationBuilder::ConvertElementType; +%unignore xla::swig::LocalComputationBuilder::Call; +%unignore xla::swig::LocalComputationBuilder::Transpose; +%unignore xla::swig::LocalComputationBuilder::Rev; +%unignore xla::swig::LocalComputationBuilder::Clamp; +%unignore xla::swig::LocalComputationBuilder::Map; +%unignore xla::swig::LocalComputationBuilder::Reduce; +%unignore xla::swig::LocalComputationBuilder::ReduceWindowWithGeneralPadding; +%unignore xla::swig::LocalComputationBuilder::RngNormal; +%unignore xla::swig::LocalComputationBuilder::RngUniform; +%unignore xla::swig::LocalComputationBuilder::RngBernoulli; +%unignore xla::swig::LocalComputationBuilder::While; +%unignore xla::swig::LocalComputationBuilder::Conditional; +%unignore xla::swig::LocalComputationBuilder::Eq; +%unignore xla::swig::LocalComputationBuilder::Ne; +%unignore xla::swig::LocalComputationBuilder::Ge; +%unignore xla::swig::LocalComputationBuilder::Gt; +%unignore xla::swig::LocalComputationBuilder::Lt; +%unignore xla::swig::LocalComputationBuilder::Le; +%unignore xla::swig::LocalComputationBuilder::Dot; +%unignore xla::swig::LocalComputationBuilder::DotGeneral; +%unignore xla::swig::LocalComputationBuilder::ConvGeneralDilated; +%unignore xla::swig::LocalComputationBuilder::Add; +%unignore xla::swig::LocalComputationBuilder::Sub; +%unignore xla::swig::LocalComputationBuilder::Mul; +%unignore xla::swig::LocalComputationBuilder::Div; +%unignore 
xla::swig::LocalComputationBuilder::Rem; +%unignore xla::swig::LocalComputationBuilder::Max; +%unignore xla::swig::LocalComputationBuilder::Min; +%unignore xla::swig::LocalComputationBuilder::And; +%unignore xla::swig::LocalComputationBuilder::Or; +%unignore xla::swig::LocalComputationBuilder::Not; +%unignore xla::swig::LocalComputationBuilder::Abs; +%unignore xla::swig::LocalComputationBuilder::Exp; +%unignore xla::swig::LocalComputationBuilder::Floor; +%unignore xla::swig::LocalComputationBuilder::Ceil; +%unignore xla::swig::LocalComputationBuilder::Round; +%unignore xla::swig::LocalComputationBuilder::Log; +%unignore xla::swig::LocalComputationBuilder::Sign; +%unignore xla::swig::LocalComputationBuilder::Cos; +%unignore xla::swig::LocalComputationBuilder::Sin; +%unignore xla::swig::LocalComputationBuilder::Tanh; +%unignore xla::swig::LocalComputationBuilder::SqrtF32; +%unignore xla::swig::LocalComputationBuilder::SquareF32; +%unignore xla::swig::LocalComputationBuilder::Pow; +%unignore xla::swig::LocalComputationBuilder::IsFinite; +%unignore xla::swig::LocalComputationBuilder::ReciprocalF32; +%unignore xla::swig::LocalComputationBuilder::Neg; +%unignore xla::swig::LocalComputationBuilder::Sort; +%unignore xla::swig::DeleteLocalShapedBuffer; +%unignore xla::swig::DeleteLocalComputation; +%unignore xla::swig::DeleteCompiledLocalComputation; + +%thread; +%include "tensorflow/compiler/xla/python/local_computation_builder.h" +%nothread; + +%unignoreall diff --git a/tensorflow/compiler/xla/python/numpy_bridge.cc b/tensorflow/compiler/xla/python/numpy_bridge.cc new file mode 100644 index 0000000000000000000000000000000000000000..3d87480728aab1d4ebbc71c6c7504d37cae5edaf --- /dev/null +++ b/tensorflow/compiler/xla/python/numpy_bridge.cc @@ -0,0 +1,517 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/python/numpy_bridge.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +namespace swig { + +namespace numpy { + +int PrimitiveTypeToNumpyType(PrimitiveType primitive_type) { + switch (primitive_type) { + case PRED: + return NPY_BOOL; + case S8: + return NPY_INT8; + case S16: + return NPY_INT16; + case S32: + return NPY_INT32; + case S64: + return NPY_INT64; + case U8: + return NPY_UINT8; + case U16: + return NPY_UINT16; + case U32: + return NPY_UINT32; + case U64: + return NPY_UINT64; + case F16: + return NPY_FLOAT16; + case F32: + return NPY_FLOAT32; + case F64: + return NPY_FLOAT64; + case TUPLE: + return NPY_OBJECT; + default: + LOG(FATAL) << "No Numpy type for XLA primitive type " << primitive_type; + } +} + +PrimitiveType NumpyTypeToPrimitiveType(int np_type) { + switch (np_type) { + case NPY_BOOL: + return PRED; + case NPY_INT8: + return S8; + case NPY_INT16: + return S16; + case NPY_INT32: + return S32; + case NPY_INT64: + return S64; + case NPY_UINT8: + return U8; + case NPY_UINT16: + return U16; + case NPY_UINT32: + return U32; + case NPY_UINT64: + return U64; + case NPY_FLOAT16: + return F16; + case NPY_FLOAT32: + return F32; + case NPY_FLOAT64: + return F64; + case NPY_OBJECT: + return TUPLE; + default: + LOG(FATAL) << "No XLA primitive type for Numpy type " << np_type; + } +} + +bool NumpyTypeIsValid(int np_type) { + switch (np_type) { + case NPY_BOOL: + case NPY_INT8: + case 
NPY_INT16: + case NPY_INT32: + case NPY_INT64: + case NPY_UINT8: + case NPY_UINT16: + case NPY_UINT32: + case NPY_UINT64: + case NPY_FLOAT16: + case NPY_FLOAT32: + case NPY_FLOAT64: + case NPY_OBJECT: + return true; + default: + return false; + } +} + +PyObject* PyShapeInfoFromXlaShape(const Shape& shape) { + int np_typenum = PrimitiveTypeToNumpyType(shape.element_type()); + PyArray_Descr* np_dtype = PyArray_DescrFromType(np_typenum); + + PyObject* dimensions; + if (ShapeUtil::IsTuple(shape)) { + int num_elements = ShapeUtil::TupleElementCount(shape); + dimensions = PyTuple_New(ShapeUtil::TupleElementCount(shape)); + for (int i = 0; i < num_elements; ++i) { + PyTuple_SET_ITEM( + dimensions, i, + PyShapeInfoFromXlaShape(ShapeUtil::GetTupleElementShape(shape, i))); + } + } else { + int rank = ShapeUtil::Rank(shape); + dimensions = PyTuple_New(rank); + for (int i = 0; i < rank; ++i) { + PyTuple_SET_ITEM(dimensions, i, + LongToPyIntOrPyLong(ShapeUtil::GetDimension(shape, i))); + } + } + return PyTuple_Pack(2, np_dtype, dimensions); +} + +// Precondition: o->ob_type == &PyArrayDescr_Type +static int NumpyTypenum(PyObject* o) { + return reinterpret_cast(o)->type_num; +} + +// Extracts the string held inside r and returns it as a C++ string. +// +// NOTE: this is an internal helper for conversion to a C++, and so decrefs r. +static string ExtractStringAndDecref(PyObject* r) { + auto error = [r] { + return tensorflow::strings::Printf("", r); + }; + if (r == nullptr) { + return error(); + } +#if PY_MAJOR_VERSION < 3 + string result = PyString_AsString(r); +#else + PyObject* bytes = PyUnicode_AsEncodedString(r, 0, 0); + if (bytes == nullptr) { + return error(); + } + CHECK(PyBytes_Check(bytes)); + string result = PyBytes_AsString(bytes); + Py_DECREF(bytes); +#endif + Py_DECREF(r); + return result; +} + +// Safely returns a str of the given Python object o as a C++ string. 
+static string PyObjectCppStr(PyObject* o) { + PyObject* s = PyObject_Str(o); + return ExtractStringAndDecref(s); +} + +// Safely returns a repr of the given Python object o as a C++ string. +static string PyObjectCppRepr(PyObject* o) { + PyObject* r = PyObject_Repr(o); + return ExtractStringAndDecref(r); +} + +StatusOr XlaShapeFromPyShape(PyObject* o) { + auto error = [o](const string& prefix) { + return InvalidArgument("%s; got %s", prefix.c_str(), + PyObjectCppRepr(o).c_str()); + }; + + auto get_attr = [o, &error](const string& field) -> StatusOr { + PyObject* result = + PyObject_GetAttrString(o, const_cast(field.c_str())); + if (result == nullptr) { + return error(tensorflow::strings::StrCat( + "Failed to get attribute of Shape object:", field)); + } + return result; + }; + + auto call_method = [o, &error](const string& method) -> StatusOr { + PyObject* result = + PyObject_CallMethod(o, const_cast(method.c_str()), nullptr); + if (result == nullptr) { + return error(tensorflow::strings::StrCat( + "Failed to call method of shape object:", method)); + } + return result; + }; + + PyObject* np_type; + TF_ASSIGN_OR_RETURN(np_type, get_attr("np_dtype")); + if (np_type->ob_type != &PyArrayDescr_Type) { + return error("Shape attribute np_dtype is not an integer numpy dtype"); + } + if (!NumpyTypeIsValid(NumpyTypenum(np_type))) { + return error("Shape attribute np_dtype is not a valid integer numpy dtype"); + } + const PrimitiveType element_type = + NumpyTypeToPrimitiveType(NumpyTypenum(np_type)); + Py_DECREF(np_type); + + if (element_type == TUPLE) { + PyObject* py_subshapes; + TF_ASSIGN_OR_RETURN(py_subshapes, call_method("tuple_shapes")); + if (!PyTuple_Check(py_subshapes)) { + return error( + "Return value of Shape method tuple_shapes() is not a tuple"); + } + const int length = PyTuple_Size(py_subshapes); + std::vector subshapes; + subshapes.reserve(length); + for (int i = 0; i < length; i++) { + TF_ASSIGN_OR_RETURN( + const Shape& subshape, + 
XlaShapeFromPyShape(PyTuple_GetItem(py_subshapes, i))); + subshapes.push_back(subshape); + } + Py_DECREF(py_subshapes); + return ShapeUtil::MakeTupleShape(subshapes); + } else { + PyObject* py_dimensions; + PyObject* py_minor_to_major; + TF_ASSIGN_OR_RETURN(py_dimensions, call_method("dimensions")); + TF_ASSIGN_OR_RETURN(py_minor_to_major, call_method("minor_to_major")); + if (!PyTuple_Check(py_dimensions)) { + return error("Return value of Shape method dimensions() is not a tuple"); + } + if (py_minor_to_major != Py_None && !PyTuple_Check(py_minor_to_major)) { + return error( + "Return value of Shape method minor_to_major() is neither a tuple " + "nor None"); + } + const int length = PyTuple_Size(py_dimensions); + if (py_minor_to_major != Py_None && + length != PyTuple_Size(py_minor_to_major)) { + return error( + "Shape methods dimensions() and minor_to_major() return " + "different-length tuples"); + } + std::vector dimensions(length); + std::vector minor_to_major(length); + for (int i = 0; i < length; i++) { + dimensions[i] = PyIntOrPyLongToLong(PyTuple_GetItem(py_dimensions, i)); + if (dimensions[i] == -1 && PyErr_Occurred()) { + return error("Dimension is not an int"); + } + + if (py_minor_to_major != Py_None) { + minor_to_major[i] = + PyIntOrPyLongToLong(PyTuple_GetItem(py_minor_to_major, i)); + if (minor_to_major[i] == -1 && PyErr_Occurred()) { + return error("Minor-to-major value is not an int"); + } + } + } + bool with_layout = py_minor_to_major != Py_None; + Py_DECREF(py_dimensions); + Py_DECREF(py_minor_to_major); + if (with_layout) { + return ShapeUtil::MakeShapeWithLayout(element_type, dimensions, + minor_to_major); + } else { + return ShapeUtil::MakeShape(element_type, dimensions); + } + } +} + +// Helper that retrieves the member with attr_name, stringifies it if is not +// None, and returns it as a C++ string. 
+static tensorflow::gtl::optional GetAttrAsString( + PyObject* o, const string& attr_name) { + if (!PyObject_HasAttrString(o, attr_name.c_str())) { + return tensorflow::gtl::nullopt; + } + PyObject* attr = PyObject_GetAttrString(o, attr_name.c_str()); + if (attr == Py_None) { + Py_DECREF(attr); + return tensorflow::gtl::nullopt; + } + string result = PyObjectCppStr(attr); + Py_DECREF(attr); + return result; +} + +// Helper that retrieves the member with attr_name, checks that it is an integer +// if it is not None, and returns it as an int32 value. +static tensorflow::gtl::optional GetAttrAsInt32( + PyObject* o, const string& attr_name) { + if (!PyObject_HasAttrString(o, attr_name.c_str())) { + return tensorflow::gtl::nullopt; + } + PyObject* attr = PyObject_GetAttrString(o, attr_name.c_str()); + if (attr == Py_None) { + Py_DECREF(attr); + return tensorflow::gtl::nullopt; + } + if (!CheckPyIntOrLong(attr)) { + Py_DECREF(attr); + return tensorflow::gtl::nullopt; + } + long value = PyIntOrPyLongToLong(attr); // NOLINT + Py_DECREF(attr); + if (value == -1 && PyErr_Occurred() != nullptr) { + return tensorflow::gtl::nullopt; + } + if (static_cast(value) != value) { + return tensorflow::gtl::nullopt; + } + return value; +} + +StatusOr OpMetadataFromPyObject(PyObject* o) { + OpMetadata result; + tensorflow::gtl::optional op_type = GetAttrAsString(o, "op_type"); + if (op_type.has_value()) { + result.set_op_type(op_type.value()); + } + tensorflow::gtl::optional op_name = GetAttrAsString(o, "op_name"); + if (op_name.has_value()) { + result.set_op_name(op_name.value()); + } + tensorflow::gtl::optional source_file = + GetAttrAsString(o, "source_file"); + if (source_file.has_value()) { + result.set_source_file(source_file.value()); + } + tensorflow::gtl::optional source_line = + GetAttrAsInt32(o, "source_line"); + if (source_line.has_value()) { + result.set_source_line(source_line.value()); + } + return result; +} + +PyObject* PyObjectFromXlaLiteral(const Literal& literal) { + 
if (ShapeUtil::IsTuple(literal.shape())) { + int num_elements = ShapeUtil::TupleElementCount(literal.shape()); + PyObject* tuple = PyTuple_New(num_elements); + for (int i = 0; i < num_elements; i++) { + PyTuple_SET_ITEM( + tuple, i, PyObjectFromXlaLiteral(LiteralView::Create(literal, {i}))); + } + return tuple; + } else { + int rank = ShapeUtil::Rank(literal.shape()); + std::vector dimensions(rank); // NOLINT - PyArray requires a long* + for (int i = 0; i < rank; i++) { + dimensions[i] = ShapeUtil::GetDimension(literal.shape(), i); + } + int np_type = PrimitiveTypeToNumpyType(literal.shape().element_type()); + PyObject* array = + PyArray_EMPTY(rank, dimensions.data(), np_type, /*fortran=*/0); + CopyLiteralToNumpyArray(np_type, literal, + reinterpret_cast(array)); + return array; + } +} + +StatusOr> XlaLiteralFromPyObject(PyObject* o) { + if (PyTuple_Check(o)) { + int num_elements = PyTuple_Size(o); + std::vector> elements; + elements.reserve(num_elements); + for (int i = 0; i < num_elements; i++) { + PyObject* element = PyTuple_GetItem(o, i); + TF_ASSIGN_OR_RETURN(auto literal, XlaLiteralFromPyObject(element)); + elements.push_back(std::move(literal)); + } + return Literal::MakeTupleOwned(std::move(elements)); + } else if (PyArray_Check(o)) { + PyArrayObject* py_array = reinterpret_cast(o); + int rank = PyArray_NDIM(py_array); + std::vector dimensions(rank); + for (int i = 0; i < rank; i++) { + dimensions[i] = PyArray_DIM(py_array, i); + } + int np_type = PyArray_TYPE(py_array); + auto literal = Literal::CreateFromDimensions( + NumpyTypeToPrimitiveType(np_type), dimensions); + TF_RETURN_IF_ERROR( + CopyNumpyArrayToLiteral(np_type, py_array, literal.get())); + return std::move(literal); + } else { + return InvalidArgument( + "Non-tuple or Numpy array encountered in conversion to XLA literal."); + } +} + +Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, + Literal* literal) { + switch (np_type) { + case NPY_BOOL: + CopyNumpyArrayToLiteral(py_array, 
literal); + break; + case NPY_INT32: + CopyNumpyArrayToLiteral(py_array, literal); + break; + case NPY_INT64: + CopyNumpyArrayToLiteral(py_array, literal); + break; + case NPY_UINT8: + CopyNumpyArrayToLiteral(py_array, literal); + break; + case NPY_UINT32: + CopyNumpyArrayToLiteral(py_array, literal); + break; + case NPY_UINT64: + CopyNumpyArrayToLiteral(py_array, literal); + break; + case NPY_FLOAT16: + CopyNumpyArrayToLiteral(py_array, literal); + break; + case NPY_FLOAT32: + CopyNumpyArrayToLiteral(py_array, literal); + break; + case NPY_FLOAT64: + CopyNumpyArrayToLiteral(py_array, literal); + break; + default: + return InvalidArgument( + "No XLA literal container for Numpy type number: %d", np_type); + } + return Status::OK(); +} + +void CopyLiteralToNumpyArray(int np_type, const Literal& literal, + PyArrayObject* py_array) { + switch (np_type) { + case NPY_BOOL: + CopyLiteralToNumpyArray(literal, py_array); + break; + case NPY_INT32: + CopyLiteralToNumpyArray(literal, py_array); + break; + case NPY_INT64: + CopyLiteralToNumpyArray(literal, py_array); + break; + case NPY_UINT8: + CopyLiteralToNumpyArray(literal, py_array); + break; + case NPY_UINT32: + CopyLiteralToNumpyArray(literal, py_array); + break; + case NPY_UINT64: + CopyLiteralToNumpyArray(literal, py_array); + break; + case NPY_FLOAT16: + CopyLiteralToNumpyArray(literal, py_array); + break; + case NPY_FLOAT32: + CopyLiteralToNumpyArray(literal, py_array); + break; + case NPY_FLOAT64: + CopyLiteralToNumpyArray(literal, py_array); + break; + default: + LOG(FATAL) << "No XLA literal container for Numpy type" << np_type; + } +} + +PyObject* LongToPyIntOrPyLong(long x) { // NOLINT +#if PY_MAJOR_VERSION < 3 + return PyInt_FromLong(x); +#else + return PyLong_FromLong(x); +#endif +} + +long PyIntOrPyLongToLong(PyObject* o) { // NOLINT +#if PY_MAJOR_VERSION < 3 + return PyInt_AsLong(o); +#else + return PyLong_AsLong(o); +#endif +} + +bool CheckPyIntOrLong(PyObject* o) { +#if PY_MAJOR_VERSION < 3 + return 
PyInt_Check(o); +#else + if (!PyLong_Check(o)) { + return false; + } + int overflow = 0; + PyLong_AsLongAndOverflow(o, &overflow); + return (overflow == 0); +#endif +} + +PyObject* PyNumberToPyInt(PyObject* o) { +#if PY_MAJOR_VERSION < 3 + return PyNumber_Int(o); +#else + return PyNumber_Long(o); +#endif +} + +} // namespace numpy + +} // namespace swig + +} // namespace xla diff --git a/tensorflow/compiler/xla/python/numpy_bridge.h b/tensorflow/compiler/xla/python/numpy_bridge.h new file mode 100644 index 0000000000000000000000000000000000000000..adfcc3b8588dce01718bb19dea936bace483be4d --- /dev/null +++ b/tensorflow/compiler/xla/python/numpy_bridge.h @@ -0,0 +1,123 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// These functions transform Python/Numpy data structures to XLA data +// structures and vice versa, performing copies where +// appropriate. Python tuples and Numpy ndarrays translate to XLA +// tuples and XLA literals, respectively, and Numpy shape/dtype +// information is translated to XLA shape information. 
+ +#ifndef TENSORFLOW_COMPILER_XLA_PYTHON_NUMPY_BRIDGE_H_ +#define TENSORFLOW_COMPILER_XLA_PYTHON_NUMPY_BRIDGE_H_ + +#include +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/python/lib/core/numpy.h" + +namespace xla { + +namespace swig { + +namespace numpy { + +// Maps XLA primitive types (PRED, S8, F32, ..., and TUPLE) to numpy +// dtypes (NPY_BOOL, NPY_INT8, NPY_FLOAT32, ..., and NPY_OBJECT), and +// vice versa. +int PrimitiveTypeToNumpyType(PrimitiveType primitive_type); +PrimitiveType NumpyTypeToPrimitiveType(int np_type); + +// Determines whether an integer-encoded Numpy dtype is valid, +// i.e. has a supported conversion to an XLA PrimitiveType. +bool NumpyTypeIsValid(int np_type); + +// Converts XLA shape information into a Python pair of the form +// (numpy dtype, dimensions). If the XLA shape represents a tuple, +// then the numpy dtype is NPY_OBJECT ('O') and `dimensions` is a +// Python tuple of shape-description pairs, created +// recursively. Otherwise, `dimensions` is a Python tuple-of-integers +// providing the array dimensions. +// +// The return value is a new reference. +PyObject* PyShapeInfoFromXlaShape(const Shape& shape); + +// Converts a Python object with a method interface mathing that of +// xla_client.Shape into an XLA Shape object. +// +// The return value is a new reference. +StatusOr XlaShapeFromPyShape(PyObject* o); + +// Converts a PyObject that represents operation metadata into protocol buffer +// form. +StatusOr OpMetadataFromPyObject(PyObject* o); + +// Converts an XLA literal to a Python object, either a Numpy ndarray +// or a nested Python tuple thereof. +// +// To avoid transferring ownership of the data buffers that underlie +// PyArrays and XLA literals, this function makes deep copies of all +// array data. +// +// The return value is a new reference. 
+PyObject* PyObjectFromXlaLiteral(const Literal& literal); + +// Converts a Numpy ndarray or a nested Python tuple thereof to a +// corresponding XLA literal. +// +// To avoid transferring ownership of the data buffers that underlie +// PyArrays and XLA literals, this function makes deep copies of all +// array data. +StatusOr > XlaLiteralFromPyObject(PyObject* o); + +// The following functions copy array data from the buffers underlying Numpy +// ndarrays into those underlying XLA literals, and vice versa. + +Status CopyNumpyArrayToLiteral(int np_type, PyArrayObject* py_array, + Literal* literal); + +void CopyLiteralToNumpyArray(int np_type, const Literal& literal, + PyArrayObject* py_array); + +template +void CopyNumpyArrayToLiteral(PyArrayObject* py_array, Literal* literal) { + NativeT* source = static_cast(PyArray_DATA(py_array)); + auto dest = literal->data(); + std::copy(source, source + PyArray_SIZE(py_array), dest.data()); +} + +template +void CopyLiteralToNumpyArray(const Literal& literal, PyArrayObject* py_array) { + NativeT* dest = static_cast(PyArray_DATA(py_array)); + auto source = literal.data(); + std::copy(source.begin(), source.end(), dest); +} + +// Workarounds for Python 2 and 3 interop + +PyObject* LongToPyIntOrPyLong(long x); // NOLINT +long PyIntOrPyLongToLong(PyObject* o); // NOLINT +bool CheckPyIntOrLong(PyObject* o); +PyObject* PyNumberToPyInt(PyObject* o); + +} // namespace numpy + +} // namespace swig + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_PYTHON_NUMPY_BRIDGE_H_ diff --git a/tensorflow/compiler/xla/python/xla.i b/tensorflow/compiler/xla/python/xla.i new file mode 100644 index 0000000000000000000000000000000000000000..1c4021a558d3fcff2abfdbdbad7f3928e86ed3b8 --- /dev/null +++ b/tensorflow/compiler/xla/python/xla.i @@ -0,0 +1,18 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/* XLA-wide SWIG wrapper */ + +%include "tensorflow/compiler/xla/python/local_computation_builder.i" diff --git a/tensorflow/compiler/xla/python/xla_client.py b/tensorflow/compiler/xla/python/xla_client.py new file mode 100644 index 0000000000000000000000000000000000000000..a9b4bdb1d6a8fc53adec8f2dcb37335d3f52cf21 --- /dev/null +++ b/tensorflow/compiler/xla/python/xla_client.py @@ -0,0 +1,1139 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""An in-process, local XLA client in Python, supporting AOT compilation.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import enum # pylint: disable=g-bad-import-order +import inspect +import itertools +import os + +import numpy as np + +from tensorflow.compiler.xla import xla_data_pb2 +from tensorflow.compiler.xla.python import pywrap_xla as c_api + + +# Most functions are snake_case for consistency with other modules, +# whereas method names of ComputationBuilder and LocalComputation are +# CamelCase for consistency with XLA. +# pylint: disable=invalid-name + + +_OP_METADATA_FIELDS = [ + 'op_type', + 'op_name', + 'source_file', + 'source_line', +] +OpMetadata = collections.namedtuple('OpMetadata', _OP_METADATA_FIELDS) + + +def OpMetadataToProto(pyobj): + proto = xla_data_pb2.OpMetadata() + for field in _OP_METADATA_FIELDS: + attr = getattr(pyobj, field) + if attr is not None: + setattr(proto, field, attr) + return proto + + +def CurrentSourceInfoMetadata(op_type=None, op_name=None, skip_frames=1): + """Helper for use in source mapping that returns an OpMetadata object.""" + full_filename, lineno = inspect.stack()[skip_frames][1:3] + filename = os.path.basename(full_filename) + return OpMetadata( + op_type=op_type, + op_name=op_name, + source_file=filename, + source_line=lineno) + + +class PaddingType(enum.Enum): + VALID = 1 + SAME = 2 + + +def _convert_padding_type_to_pad_values(padding_type, lhs_dims, rhs_dims, + window_strides): + """Maps PaddingType (VALID or SAME) to pad values (list of pairs of ints).""" + if padding_type == PaddingType.VALID: + return [(0, 0)] * len(window_strides) + + out_shape = np.ceil(np.true_divide(lhs_dims, window_strides)).astype(int) + pad_sizes = [max((out_size - 1) * stride + filter_size - in_size, 0) + for out_size, stride, filter_size, in_size + in 
zip(out_shape, window_strides, rhs_dims, lhs_dims)] + return [(pad_size // 2, pad_size - pad_size // 2) + for pad_size in pad_sizes] + + +_UNARY_OPS = [ + 'Not', + 'Abs', + 'Exp', + 'Floor', + 'Round', + 'Ceil', + 'Log', + 'Sign', + 'Cos', + 'Sin', + 'Tanh', + 'SqrtF32', + 'SquareF32', + 'IsFinite', + 'ReciprocalF32', + 'Neg', + 'Sort', +] + +_BINARY_OPS = [ + 'Eq', + 'Ne', + 'Ge', + 'Gt', + 'Lt', + 'Le', + 'Add', + 'Sub', + 'Mul', + 'Div', + 'Rem', + 'Max', + 'Min', + 'And', + 'Or', + 'Pow', +] + +XLA_ELEMENT_TYPE_TO_DTYPE = { + xla_data_pb2.F32: np.dtype(np.float32), + xla_data_pb2.F64: np.dtype(np.float64), + xla_data_pb2.S32: np.dtype(np.int32), + xla_data_pb2.S64: np.dtype(np.int64), + xla_data_pb2.U32: np.dtype(np.uint32), + xla_data_pb2.U64: np.dtype(np.uint64), + xla_data_pb2.PRED: np.dtype(np.bool), + xla_data_pb2.TUPLE: np.dtype(np.object), +} + +# Note the conversion on the key. Numpy has a known issue wherein dtype hashing +# doesn't work as expected (https://github.com/numpy/numpy/issues/7242). Thus, +# when keying by dtype in this dict, we use the string form of dtypes. +DTYPE_TO_XLA_ELEMENT_TYPE = { + str(v): k + for k, v in XLA_ELEMENT_TYPE_TO_DTYPE.items() +} + + +class LocalBuffer(object): + """Represents a handle to data owned by XLA. + + The referent is ready for use in executing a local, compiled + Computation. On XLA platforms involving a device (e.g. GPU), this + means the referent is in device memory. 
+ """ + + def __init__(self, c_local_shaped_buffer): + self.c_local_shaped_buffer = c_local_shaped_buffer + self._delete = c_api.DeleteLocalShapedBuffer + + @staticmethod + def from_py(npval, layout_fn=None): + npval = require_numpy_array_layout(npval) + if layout_fn: + shape = Shape.from_numpy(npval) + shape = shape.map_leaves(layout_fn) + else: + shape = None + return LocalBuffer(c_api.LocalShapedBuffer.FromLiteral(npval, shape)) + + def to_py(self): + return self.c_local_shaped_buffer.ToLiteral() + + def delete(self): + if self.c_local_shaped_buffer is not None: + self._delete(self.c_local_shaped_buffer) + self.c_local_shaped_buffer = None + + def is_deleted(self): + return self.c_local_shaped_buffer is None + + def __del__(self): + self.delete() + + +class Shape(object): + """XLA shape. + + Represents an XLA shape by a corresponding Python/Numpy type and a + list of dimensions, which are themselves Shapes in case this one + represents an XLA tuple. + """ + + def __init__(self, np_dtype, dimensions, minor_to_major=None): + assert isinstance(dimensions, tuple) + self.np_dtype = np_dtype + self._dimensions = dimensions + self._minor_to_major = minor_to_major + self._check_minor_to_major() + + def __eq__(self, other): + # pylint: disable=protected-access + return (self.np_dtype == other.np_dtype and + self._dimensions == other._dimensions and + self._minor_to_major == other._minor_to_major) + + def __repr__(self): + return ('xla_client.Shape(np_dtype={!r}, dimensions={!r}, ' + 'minor_to_major={!r})').format(self.np_dtype, self._dimensions, + self._minor_to_major) + + def element_type(self): + return DTYPE_TO_XLA_ELEMENT_TYPE[str(self.np_dtype)] + + def is_tuple(self): + return self.element_type() == xla_data_pb2.TUPLE + + def dimensions(self): + if self.is_tuple(): + raise ValueError('Tuple shape has no dimensions') + return self._dimensions + + def minor_to_major(self): + return self._minor_to_major + + def tuple_shapes(self): + if not self.is_tuple(): + raise 
ValueError('Shape is not a tuple shape') + return self._dimensions + + def rank(self): + return len(self.dimensions()) + + def map_leaves(self, f): + """Map f over each leaf-level array subshape. + + Args: + f: The function to apply. Whenever f returns None, the identity is + applied instead. + + Returns: + A new Shape with the mapped leaves. + """ + if self.is_tuple(): + children = tuple(child.map_leaves(f) for child in self.tuple_shapes()) + return Shape(np.dtype('O'), children) + else: + mapped = f(self) + return self if mapped is None else mapped + + def _check_minor_to_major(self): + mtm = self._minor_to_major + if self.is_tuple(): + assert mtm is None, self + if mtm is not None: + assert self.rank() == len(mtm), self + assert sorted(mtm) == range(len(mtm)), self + + def update_minor_to_major(self, minor_to_major): + if not isinstance(minor_to_major, tuple): + raise TypeError('minor_to_major must be a tuple') + updated = Shape(self.np_dtype, tuple(self.dimensions()), minor_to_major) + updated._check_minor_to_major() # pylint: disable=protected-access + return updated + + @staticmethod + def from_numpy(npval): + + def convert(npval): + if isinstance(npval, tuple): + return Shape(np.dtype('O'), tuple(convert(elt) for elt in npval)) + else: + return Shape(npval.dtype, np.shape(npval)) + + return convert(require_numpy_array_layout(npval)) + + +def _wrap_shape(shape_info): + dtype, dims = shape_info + element_type = DTYPE_TO_XLA_ELEMENT_TYPE[str(dtype)] + if element_type == xla_data_pb2.TUPLE: + dims = tuple(_wrap_shape(subshape_info) for subshape_info in dims) + return Shape(dtype, dims) + + +def _wrap_data_handle(handle): + cdh = xla_data_pb2.ComputationDataHandle() + cdh.handle = handle + return cdh + + +def _unwrap_data_handle(handle_proto): + return handle_proto.handle + + +def _unwrap_data_handles(handle_protos): + return [_unwrap_data_handle(cdh) for cdh in handle_protos] + + +def require_numpy_array_layout(value): + if isinstance(value, tuple): + return 
tuple(require_numpy_array_layout(x) for x in value) + else: + return np.require(value, requirements=['C', 'A']) + + +class CompileOptions(object): + """Python object for XLA compile options. + + These options can be passed to the 'compile' step when using a local XLA + client. + """ + + def __init__(self): + self.generate_hlo_graph = None + + +def transfer_to_infeed(value, replica_number=None): + """Transfers the given value into the XLA infeed queue. + + XLA's infeed queue is a single queue that feeds the "XLA virtual machine" with + a totally ordered stream of values. This is dequeued from XLA computations via + the Infeed() operation. + + Args: + value: the value that the caller would like to enqueue into the XLA infeed + queue + replica_number: the replica number to infeed the value to -- if not + provided, then the default replica (trivially replica 0) is used. + """ + if replica_number is None: + c_api.TransferToInfeedLocal(require_numpy_array_layout(value)) + else: + c_api.TransferToInfeedLocalReplica( + require_numpy_array_layout(value), replica_number) + + +def transfer_from_outfeed(shape, replica_number=None): + """Transfers a literal of the given shape from replica_number's outfeed. + + Args: + shape: The shape of the value to transfer from outfeed. + replica_number: The replica number ordinal to transfer the outfeed value + from. (Each replica has a distinct outfeed queue.) + + Returns: + The literal value that is produced from the outfeed queue. + """ + return c_api.TransferFromOutfeedLocalReplica(shape, replica_number or 0) + + +class LocalComputation(object): + """Python wrapper for a local XLA Computation. + + A LocalComputation can be executed if it is compiled. Otherwise, it + can still be used as a Computation where required by the + ComputationBuilder methods. 
+ """ + + def __init__(self, c_local_computation, is_compiled): + self.c_local_computation = c_local_computation + self.is_compiled = is_compiled + + # Ensure a reference to C-based destructor for use in __del__. + if is_compiled: + self._delete = c_api.DeleteCompiledLocalComputation + else: + self._delete = c_api.DeleteLocalComputation + + def Compile(self, argument_shapes=(), compile_options=None, layout_fn=None): + if self.is_compiled: + raise ValueError('Attempt to compile a compiled local XLA computation.') + if layout_fn: + argument_shapes = [ + shape.map_leaves(layout_fn) for shape in argument_shapes + ] + return LocalComputation( + self.c_local_computation.Compile(argument_shapes, compile_options), + is_compiled=True) + + def CompileWithExampleArguments(self, + arguments=(), + compile_options=None, + layout_fn=None): + return self.Compile( + argument_shapes=[Shape.from_numpy(arg) for arg in arguments], + compile_options=compile_options, + layout_fn=layout_fn) + + def Execute(self, arguments=(), layout_fn=None): + """Execute with Python values as arguments and return value.""" + if not self.is_compiled: + raise ValueError('Cannot execute an uncompiled local XLA computation.') + argument_shapes = [Shape.from_numpy(arg) for arg in arguments] + if layout_fn: + argument_shapes = [ + shape.map_leaves(layout_fn) for shape in argument_shapes + ] + else: + argument_shapes = [None for shape in argument_shapes] + arguments = tuple(map(require_numpy_array_layout, arguments)) + return self.c_local_computation.Execute(arguments, argument_shapes) + + def ExecuteWithLocalBuffers(self, arguments=()): + """Execute with LocalBuffer arguments and return value.""" + if not self.is_compiled: + raise ValueError('Cannot execute an uncompiled local XLA computation.') + arguments = tuple(arguments) + if any(arg.is_deleted() for arg in arguments): + raise ValueError('Executing with deleted local buffer argument') + return LocalBuffer( + 
self.c_local_computation.ExecuteWithShapedBuffers( + [arg.c_local_shaped_buffer for arg in arguments])) + + def __del__(self): + self._delete(self.c_local_computation) + + +class ComputationBuilder(object): + """XLA computation builder. + + Enqueues XLA ops in sequence and in order to build a + LocalComputation, which in turn can be compiled into a + CompiledLocalComputation, which in turn can be locally executed. + """ + + # The methods of this class map 1-to-1 onto the XLA C++ + # computation builder API. Therefore, there's no need to laboriously list + # arguments and return values for every method, especially where it's obvious. + # + # pylint: disable=g-doc-return-or-yield + # pylint: disable=g-doc-args + + def __init__(self, name): + self._client = c_api.LocalComputationBuilder(name.encode('utf8')) + self._parameter_numbering = itertools.count() + + def Build(self): + return LocalComputation(self._client.Build(), is_compiled=False) + + def SetOpMetadata(self, op_metadata): + """Set metadata for operations that are about to be enqueued.""" + self._client.SetOpMetadata(op_metadata) + + def ClearOpMetadata(self): + """Clear metadata for operations that are about to be enqueued.""" + self._client.ClearOpMetadata() + + def Infeed(self, shape): + """Enqueues an infeed op onto the computation. + + Infeed operations dequeue data of the given shape from the device's infeed + queue for subsequent use in the computation. + + Returns: + A ComputationDataHandle message. + """ + return _wrap_data_handle(self._client.Infeed(shape)) + + def Outfeed(self, operand): + """Enqueues an outfeed op onto the computation. + + Outfeed operations enqueue data, using the given operand, onto the XLA + outfeed queue for subsequent dequeue via the client API. + """ + self._client.Outfeed( + _unwrap_data_handle(operand), self.GetShape(operand), + ''.encode('utf-8')) + + def Constant(self, value): + """Enqueues a constant op onto the computation. 
+ + Args: + value: value for the constant, as a np.array with an explicit dtype set + to one of the supported types. + + Returns: + A ComputationDataHandle message. + """ + value = require_numpy_array_layout(value) + return _wrap_data_handle(self._client.ConstantLiteral(value)) + + def ConstantF32Scalar(self, value): + """Convenience method to enqueue a scalar F32 constant op. + + Args: + value: a floating-point number. + + Returns: + A ComputationDataHandle message. + """ + return self.Constant(np.array(value, dtype=np.float32)) + + def ConstantF64Scalar(self, value): + """Convenience method to enqueue a scalar F64 constant op. + + Args: + value: a floating-point number. + + Returns: + A ComputationDataHandle message. + """ + return self.Constant(np.array(value, dtype=np.float64)) + + def ConstantS32Scalar(self, value): + """Convenience method to enqueue a scalar S32 constant op. + + Args: + value: an integer. + + Returns: + A ComputationDataHandle message. + """ + return self.Constant(np.array(value, dtype=np.int32)) + + def ConstantS64Scalar(self, value): + """Convenience method to enqueue a scalar S64 constant op. + + Args: + value: an integer. + + Returns: + A ComputationDataHandle message. + """ + return self.Constant(np.array(value, dtype=np.int64)) + + def ConstantPredScalar(self, value): + """Convenience method to enqueue a scalar PRED constant op. + + Args: + value: a boolean value. + + Returns: + A ComputationDataHandle message. + """ + return self.Constant(np.array(value, dtype=np.bool)) + + def ParameterWithShape(self, shape, name=None, parameter_num=None): + """Enqueues a Parameter op onto the computation, given a shape. + + Args: + shape: the parameter's shape as a Shape object. + name: optional string name for the parameter. + parameter_num: parameter number in the computation function. If None, + the next linear parameter number is used. The default value capability + can be used for auto-numbering.
If you're using auto-numbering for some + parameters, use it for *all* parameters to avoid clashes. + + Returns: + A ComputationDataHandle message. + """ + if name is None: + name = '' + if parameter_num is None: + parameter_num = next(self._parameter_numbering) + + return _wrap_data_handle( + self._client.Parameter(parameter_num, shape, name.encode('utf8'))) + + def ParameterFromNumpy(self, value, name=None, parameter_num=None): + """Enqueues a Parameter op onto the computation. + + Args: + value: a Numpy array, or a nested tuple thereof, from which the + shape is inferred. + name: as in ParameterWithShape. + parameter_num: as in ParameterWithShape. + + Returns: + A ComputationDataHandle message. + """ + return self.ParameterWithShape( + Shape.from_numpy(value), name=name, parameter_num=parameter_num) + + def Broadcast(self, operand, sizes): + """Enqueues a broadcast operation onto the computation. + + Args: + operand: the operand ComputationDataHandle to broadcast. + sizes: an iterable of broadcast sizes. + + Returns: + A ComputationDataHandle representing the added broadcast op. + """ + return _wrap_data_handle( + self._client.Broadcast(_unwrap_data_handle(operand), sizes)) + + def Concatenate(self, operands, dimension): + """Enqueues a concatenate operation onto the computation. + + Args: + operands: the operands to concatenate. + dimension: the dimension in which to perform the concatenation. + + Returns: + A ComputationDataHandle representing the added concatenate op. + """ + return _wrap_data_handle( + self._client.ConcatInDim(_unwrap_data_handles(operands), dimension)) + + def ConvertElementType(self, operand, new_element_type): + """Enqueues an element type conversion operation onto the computation. + + Args: + operand: the operand to convert. + new_element_type: the target primitive type. + + Returns: + A ComputationDataHandle representing the added conversion op. 
+ """ + return _wrap_data_handle( + self._client.ConvertElementType( + _unwrap_data_handle(operand), new_element_type)) + + def GetShape(self, operand): + return _wrap_shape(self._client.GetShape(_unwrap_data_handle(operand))) + + def GetReturnValueShape(self): + return _wrap_shape(self._client.GetReturnValueShape()) + + def GetComputationStats(self): + raise NotImplementedError() + + def Pad(self, operand, padding_value, padding_config): + """Enqueues a Pad operation onto the computation. + + Args: + operand: ComputationDataHandle representing the array to pad. + padding_value: ComputationDataHandle representing the scalar pad value. + padding_config: either an xla_data_pb2.PaddingConfig or a list of integer + triples (edge_padding_low, edge_padding_high, interior_padding) + representing the configuration of the padding operation. + + Returns: + A ComputationDataHandle representing the added pad op. + """ + if not isinstance(padding_config, xla_data_pb2.PaddingConfig): + padding_config = GetPaddingConfigFromTriples(padding_config) + return _wrap_data_handle( + self._client.Pad(_unwrap_data_handle(operand), + _unwrap_data_handle(padding_value), + padding_config)) + + def Reshape(self, operand, dimensions, new_sizes): + """Reshape op.""" + return _wrap_data_handle( + self._client.Reshape( + _unwrap_data_handle(operand), dimensions, new_sizes)) + + def CrossReplicaSum(self, operand): + """CrossReplicaSum op. + + Args: + operand: the operand to sum across replica instances. + + Returns: + A ComputationDataHandle that has the sum of the value among all replicas. 
+ """ + return _wrap_data_handle( + self._client.CrossReplicaSum(_unwrap_data_handle(operand))) + + def Collapse(self, operand, dimensions): + """Collapse op.""" + return _wrap_data_handle( + self._client.Collapse(_unwrap_data_handle(operand), dimensions)) + + def Trans(self, operand): + """Specialized matrix transpose op.""" + return _wrap_data_handle( + self._client.Transpose(_unwrap_data_handle(operand), [1, 0])) + + def Transpose(self, operand, permutation): + """Transpose op.""" + return _wrap_data_handle( + self._client.Transpose(_unwrap_data_handle(operand), permutation)) + + def Rev(self, operand, dimensions): + """Rev op.""" + return _wrap_data_handle( + self._client.Rev(_unwrap_data_handle(operand), dimensions)) + + def Clamp(self, min, operand, max): # pylint: disable=redefined-builtin + """Clamp op.""" + return _wrap_data_handle( + self._client.Clamp(_unwrap_data_handle(min), + _unwrap_data_handle(operand), + _unwrap_data_handle(max))) + + def SelectAndScatter(self, operand, select, window_dimensions, window_strides, + padding, source, init_value, scatter): + """Select and scatter op, used by the gradient of ReduceWindow. + + Args: + operand: ComputationDataHandle for array of dimension N and type T over + which the windows slide. + select: Computation of type (T, T) -> Pred to apply to the elements of + each window to indicate which element is selected. + window_dimensions: sequence of N integers for dimensions of the window. + window_strides: sequence of N integers for the strides of the window. + padding: PaddingType representing either 'SAME' or 'VALID' padding. + source: ComputationDataHandle for array of type T with values to scatter. + init_value: ComputationDataHandle of scalar type T for initial out value. + scatter: Computation of type (T, T) -> T to apply to each scatter source + element with its destination element. + + Returns: + A ComputationDataHandle representing the added SelectAndScatter op.
+ """ + pads = _convert_padding_type_to_pad_values( + padding, self.GetShape(operand).dimensions(), + window_dimensions, window_strides) + return _wrap_data_handle( + self._client.SelectAndScatterWithGeneralPadding( + _unwrap_data_handle(operand), select.c_local_computation, + window_dimensions, window_strides, pads, + _unwrap_data_handle(source), _unwrap_data_handle(init_value), + scatter.c_local_computation)) + + def Select(self, pred, on_true, on_false): + """Element-wise selection op. + + Constructs an output array from elements of two input arrays, based on the + values of a predicate array. + """ + return _wrap_data_handle( + self._client.Select( + _unwrap_data_handle(pred), + _unwrap_data_handle(on_true), + _unwrap_data_handle(on_false))) + + def Slice(self, operand, start_indices, limit_indices, strides=None): + """Enqueues a slice operation onto the computation. + + Args: + operand: ComputationDataHandle for the N dimensional array to be sliced. + start_indices: iterable of N integers containing the starting indices of + the slice for each dimension. + limit_indices: iterable of N integers containing the ending indices + (exclusive) of the slice for each dimension. + strides: optional iterable of N integers containing the stride sizes for + each dimension. + + Returns: + A ComputationDataHandle representing the added Slice op. + """ + if strides is None: + start_indices = list(start_indices) + strides = [1] * len(start_indices) + return _wrap_data_handle( + self._client.Slice( + _unwrap_data_handle(operand), + start_indices, + limit_indices, + strides)) + + def DynamicSlice(self, operand, start_indices, slice_sizes): + """Enqueues a slice op with dynamic start indices onto the computation. + + Args: + operand: ComputationDataHandle for the N dimensional array to be sliced. + start_indices: ComputationDataHandle for the 1D array of N integers + containing the starting indices of the slice. 
+ slice_sizes: iterable of N integers containing the slice sizes in each + dimension. + + Returns: + A ComputationDataHandle representing the added DynamicSlice op. + """ + return _wrap_data_handle( + self._client.DynamicSlice( + _unwrap_data_handle(operand), + _unwrap_data_handle(start_indices), + slice_sizes)) + + def DynamicUpdateSlice(self, operand, update, start_indices): + """Enqueues a dynamic update slice operation onto the computation. + + Args: + operand: ComputationDataHandle for the N dimensional array to be updated. + update: N dimensional array comprising the slice update. + start_indices: Rank-1 array of N integers comprising the starting indices + of the slice along each dimension. + Returns: + A ComputationDataHandle representing the added DynamicUpdateSlice op. + """ + return _wrap_data_handle( + self._client.DynamicUpdateSlice( + _unwrap_data_handle(operand), + _unwrap_data_handle(update), + _unwrap_data_handle(start_indices))) + + def Tuple(self, *ops): + """Enqueues a tuple operation onto the computation. + + Args: + ops: a sequence of tuple operands (each a ComputationDataHandle). + + Returns: + A ComputationDataHandle representing the added Tuple op. + """ + return _wrap_data_handle(self._client.Tuple(_unwrap_data_handles(ops))) + + def GetTupleElement(self, tup, index): + """Enqueues a 'get tuple element' operation onto the computation. + + Args: + tup: the tuple operand (a ComputationDataHandle). + index: numeric index to select from the tuple. + + Returns: + A ComputationDataHandle representing the added GetTupleElement op. + """ + return _wrap_data_handle( + self._client.GetTupleElement(_unwrap_data_handle(tup), index)) + + def Call(self, computation_to_apply, operands): + """Enqueues a call operation onto the computation. + + Args: + computation_to_apply: a Computation object. + operands: an iterable of ComputationDataHandle. The number and types of + operands must match the arity of computation_to_apply. 
+ + Returns: + A ComputationDataHandle representing the added call op. + """ + return _wrap_data_handle( + self._client.Call(computation_to_apply.c_local_computation, + _unwrap_data_handles(operands))) + + def Map(self, operands, computation_to_apply, dimensions, static_operands=()): + """Enqueues a map operation onto the computation. + + Args: + operands: an iterable of ComputationDataHandle. + computation_to_apply: a Computation object. + dimensions: dimensions over which to map the function. + static_operands: auxiliary arguments passed to the applied computation. + + Returns: + A ComputationDataHandle representing the added Map op. + """ + return _wrap_data_handle( + self._client.Map( + _unwrap_data_handles(operands), + computation_to_apply.c_local_computation, + dimensions, + _unwrap_data_handles(static_operands))) + + def Reduce(self, operand, init_value, computation_to_apply, dimensions): + """Enqueues a reduction operation onto the computation. + + Args: + operand: reduction operand (ComputationDataHandle). + init_value: reduction initial value (ComputationDataHandle). + computation_to_apply: a Computation object - binary reduction function. + dimensions: sequence of dimensions (integers) to reduce on. + + Returns: + A ComputationDataHandle representing the added Reduce op. + """ + return _wrap_data_handle( + self._client.Reduce( + _unwrap_data_handle(operand), + _unwrap_data_handle(init_value), + computation_to_apply.c_local_computation, + dimensions)) + + def ReduceWindow(self, operand, init_value, computation_to_apply, + window_dimensions, window_strides, padding): + """Enqueues a windowed reduction operation onto the computation. + + Args: + operand: reduction operand (ComputationDataHandle). + init_value: reduction initial value (ComputationDataHandle). + computation_to_apply: a binary reduction function (Computation). + window_dimensions: dimensions of window (sequence of integers). + window_strides: strides for window (sequence of integers).
+ padding: PaddingType representing either 'SAME' or 'VALID' padding. + + Returns: + A ComputationDataHandle representing the added ReduceWindow op. + """ + pads = _convert_padding_type_to_pad_values( + padding, self.GetShape(operand).dimensions(), window_dimensions, + window_strides) + return _wrap_data_handle( + self._client.ReduceWindowWithGeneralPadding( + _unwrap_data_handle(operand), + _unwrap_data_handle(init_value), + computation_to_apply.c_local_computation, + window_dimensions, window_strides, pads)) + + def RngNormal(self, mu, sigma, dims): + """Enqueues an RngNormal operation onto the computation. + + Args: + mu: A ComputationDataHandle to an F32 scalar specifying the mean. + sigma: A ComputationDataHandle to an F32 scalar specifying the standard + deviation. + dims: A 1D array-like of nonnegative integers specifying the dimensions. + + Returns: a ComputationDataHandle to the generated array of F32 values. + """ + shape = Shape(self.GetShape(mu).np_dtype, dims) + return _wrap_data_handle( + self._client.RngNormal( + _unwrap_data_handle(mu), _unwrap_data_handle(sigma), shape)) + + def RngUniform(self, a, b, dims): + """Enqueues an RngUniform operation onto the computation. + + Args: + a: a ComputationDataHandle to an F32, S32, or U32 scalar (consistent with + the type of b) specifying the low end of the interval [a, b) over which + values are generated. + b: a ComputationDataHandle to an F32, S32, or U32 scalar (consistent with + the type of a) specifying the high end of the interval [a, b) over which + values are generated. + dims: A 1D array-like of nonnegative integers specifying the dimensions. + + Returns: a ComputationDataHandle to the generated array of values with the + same numeric type (F32, S32, or U32) as the arguments a and b. 
+ """ + shape = Shape(self.GetShape(a).np_dtype, dims) + return _wrap_data_handle( + self._client.RngUniform( + _unwrap_data_handle(a), _unwrap_data_handle(b), shape)) + + def While(self, cond, body, init): + """Enqueues a While operation onto the computation. + + Args: + cond: a Computation for the loop condition, which has type T -> PRED + body: a Computation for the loop body, which has type T -> T + init: a ComputationDataHandle for the initial parameter, which has type T + + Returns: a ComputationDataHandle representing the While operation. + """ + return _wrap_data_handle( + self._client.While(cond.c_local_computation, + body.c_local_computation, + _unwrap_data_handle(init))) + + def Conditional(self, pred, true_operand, true_computation, false_operand, + false_computation): + """Enqueues a Conditional operation onto the computation. + + Args: + pred: a ComputationDataHandle to test, which has scalar type PRED + true_operand: a ComputationDataHandle of type T_0 + true_computation: a Computation to apply to true_operand, type T_0 -> S + false_operand: a ComputationDataHandle of type T_1 + false_computation: a Computation to apply to false_operand, type T_1 -> S + + Returns: a ComputationDataHandle representing the Conditional operation. + """ + return _wrap_data_handle( + self._client.Conditional( + _unwrap_data_handle(pred), _unwrap_data_handle(true_operand), + true_computation.c_local_computation, + _unwrap_data_handle(false_operand), + false_computation.c_local_computation)) + + def Dot(self, lhs, rhs): + """Enqueues a dot operation onto the computation. + + Args: + lhs: ComputationDataHandle for the rank 1 or rank 2 left-hand-side array. + rhs: ComputationDataHandle for the rank 1 or rank 2 right-hand-side array. + + Returns: a ComputationDataHandle representing the Dot operation.
+ """ + return _wrap_data_handle( + self._client.Dot(_unwrap_data_handle(lhs), _unwrap_data_handle(rhs))) + + def DotGeneral(self, lhs, rhs, dimension_numbers): + """Enqueues a general dot operation onto the computation. + + Args: + lhs: ComputationDataHandle for the left-hand-side array. + rhs: ComputationDataHandle for the right-hand-side array. + dimension_numbers: either an xla_data_pb2.DotDimensionNumbers or a nested + tuple ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)) of lists of + integers representing the dimensions to treat as contracting dimensions + and batch dimensions on each input operand. + + Returns: a ComputationDataHandle representing the DotGeneral operation. + """ + if not isinstance(dimension_numbers, xla_data_pb2.DotDimensionNumbers): + dimension_numbers = GetDotDimensionsFromLists(dimension_numbers) + return _wrap_data_handle( + self._client.DotGeneral( + _unwrap_data_handle(lhs), _unwrap_data_handle(rhs), + dimension_numbers)) + + def Conv(self, lhs, rhs, window_strides, padding): + """Enqueues a Conv operation onto the computation. + + Args: + lhs: ComputationDataHandle for the rank N+2 array of inputs. + rhs: ComputationDataHandle for the rank N+2 array of kernel weights. + window_strides: length-N array-like of integer kernel strides. + padding: PaddingType representing either 'SAME' or 'VALID' padding. + + Returns: a ComputationDataHandle representing the Conv operation. 
+ """ + pads = _convert_padding_type_to_pad_values( + padding, self.GetShape(lhs).dimensions()[2:], + self.GetShape(rhs).dimensions()[2:], window_strides) + dimension_numbers = self._GetConvDimensionNumbers(len(window_strides)) + return _wrap_data_handle( + self._client.ConvGeneralDilated(_unwrap_data_handle(lhs), + _unwrap_data_handle(rhs), + window_strides, + pads, + (), + (), + dimension_numbers)) + + def ConvWithGeneralPadding(self, lhs, rhs, window_strides, padding, + lhs_dilation, rhs_dilation): + """Enqueues a ConvWithGeneralPadding operation onto the computation. + + Args: + lhs: ComputationDataHandle for the rank N+2 array of inputs. + rhs: ComputationDataHandle for the rank N+2 array of kernel weights. + window_strides: length-N array-like of kernel strides. + padding: length-N array-like of pairs of integers of (low, high) padding. + lhs_dilation: length-N array-like of dilation factors. + rhs_dilation: length-N array-like of dilation factors. + + Returns: + A ComputationDataHandle representing the added ConvWithGeneralPadding op.
+ """ + dimension_numbers = self._GetConvDimensionNumbers(len(window_strides)) + return _wrap_data_handle( + self._client.ConvGeneralDilated(_unwrap_data_handle(lhs), + _unwrap_data_handle(rhs), + window_strides, + padding, + lhs_dilation, + rhs_dilation, + dimension_numbers)) + + def _GetConvDimensionNumbers(self, num_spatial_dims): + """Create ConvolutionDimensionNumbers proto for convolutions.""" + nd = num_spatial_dims + dimension_numbers = xla_data_pb2.ConvolutionDimensionNumbers() + dimension_numbers.input_batch_dimension = 0 + dimension_numbers.input_feature_dimension = 1 + dimension_numbers.output_batch_dimension = 0 + dimension_numbers.output_feature_dimension = 1 + dimension_numbers.kernel_output_feature_dimension = 0 + dimension_numbers.kernel_input_feature_dimension = 1 + dimension_numbers.input_spatial_dimensions.extend(range(2, 2 + nd)) + dimension_numbers.kernel_spatial_dimensions.extend(range(2, 2 + nd)) + dimension_numbers.output_spatial_dimensions.extend(range(2, 2 + nd)) + return dimension_numbers + + +def _forward_methods_to_local_builder(): + """Forward remaining ComputationBuilder methods to the C API. + + Set up methods, corresponding to unary and binary XLA operations, + whose calls are forwarded in a boilerplate manner to the underlying + LocalComputationBuilder C-extension API. 
+ """ + + def forward_to_local_builder_with_handles(target_method, is_binop=False): + """Generate a forwarding method that wraps/unwraps data handles.""" + + def forward(self, *args, **kwargs): + unwrapped_args = [_unwrap_data_handle(arg) for arg in args] + + if is_binop and len(unwrapped_args) < 3: + unwrapped_args.append(kwargs.get('broadcast_dimensions', ())) + + return _wrap_data_handle( + target_method( + self._client, # pylint: disable=protected-access + *unwrapped_args)) + + return forward + + for method_name in _UNARY_OPS: + forward = forward_to_local_builder_with_handles( + getattr(c_api.LocalComputationBuilder, method_name)) + forward.__name__ = method_name + setattr(ComputationBuilder, method_name, forward) + + for method_name in _BINARY_OPS: + forward = forward_to_local_builder_with_handles( + getattr(c_api.LocalComputationBuilder, method_name), is_binop=True) + forward.__name__ = method_name + setattr(ComputationBuilder, method_name, forward) + + +_forward_methods_to_local_builder() + + +def initialize_replica_count(replica_count): + """Initializes the desired replica count to use on XLA service init. + + Args: + replica_count: number of replicas that are desired for set up during XLA + initialization. + + Raises: + A runtime exception if the XLA service has already been initialized. + """ + c_api.InitializeReplicaCount(replica_count) + + +def get_replica_count(): + """Returns the current replica count used for the XLA service. + + Note: this will return a value whether the XLA service has been initialized + yet or not. 
+ """ + return c_api.GetReplicaCount() + + +def GetPaddingConfigFromTriples(triples): + """Create PaddingConfig proto from list of triples of integers.""" + padding_config = xla_data_pb2.PaddingConfig() + for lo, hi, interior in triples: + dimension = padding_config.dimensions.add() + dimension.edge_padding_low = lo + dimension.edge_padding_high = hi + dimension.interior_padding = interior + return padding_config + + +def GetDotDimensionsFromLists(dimension_numbers): + (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dimension_numbers + dot_dims_proto = xla_data_pb2.DotDimensionNumbers() + dot_dims_proto.lhs_contracting_dimensions.extend(lhs_contract) + dot_dims_proto.rhs_contracting_dimensions.extend(rhs_contract) + dot_dims_proto.lhs_batch_dimensions.extend(lhs_batch) + dot_dims_proto.rhs_batch_dimensions.extend(rhs_batch) + return dot_dims_proto diff --git a/tensorflow/compiler/xla/python/xla_client_test.py b/tensorflow/compiler/xla/python/xla_client_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c9d09cd5d57e001fd48d2dba9f2b0ee18374231b --- /dev/null +++ b/tensorflow/compiler/xla/python/xla_client_test.py @@ -0,0 +1,1308 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the Python extension-based XLA client.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools +import threading + +import numpy as np + +from tensorflow.compiler.xla.python import xla_client +import unittest + + +class LocalComputationTest(unittest.TestCase): + """Base class for running an XLA Computation through the local client.""" + + def _NewComputation(self, name=None): + if name is None: + name = self.id() + return xla_client.ComputationBuilder(name) + + def _Execute(self, c, arguments): + compiled_c = c.Build().CompileWithExampleArguments(arguments) + return compiled_c.Execute(arguments) + + def _ExecuteAndAssertWith(self, assert_func, c, arguments, expected): + assert expected is not None + result = self._Execute(c, arguments) + # Numpy's comparison methods are a bit too lenient by treating inputs as + # "array-like", meaning that scalar 4 will be happily compared equal to + # [[4]]. We'd like to be more strict so assert shapes as well. 
+ self.assertEqual(np.asanyarray(result).shape, np.asanyarray(expected).shape) + assert_func(result, expected) + + def _ExecuteAndCompareExact(self, c, arguments=(), expected=None): + self._ExecuteAndAssertWith(np.testing.assert_equal, c, arguments, expected) + + def _ExecuteAndCompareClose(self, c, arguments=(), expected=None): + self._ExecuteAndAssertWith(np.testing.assert_allclose, c, arguments, + expected) + + +def NumpyArrayF32(*args, **kwargs): + """Convenience wrapper to create Numpy arrays with a np.float32 dtype.""" + return np.array(*args, dtype=np.float32, **kwargs) + + +def NumpyArrayF64(*args, **kwargs): + """Convenience wrapper to create Numpy arrays with a np.float64 dtype.""" + return np.array(*args, dtype=np.float64, **kwargs) + + +def NumpyArrayS32(*args, **kwargs): + """Convenience wrapper to create Numpy arrays with a np.int32 dtype.""" + return np.array(*args, dtype=np.int32, **kwargs) + + +def NumpyArrayS64(*args, **kwargs): + """Convenience wrapper to create Numpy arrays with a np.int64 dtype.""" + return np.array(*args, dtype=np.int64, **kwargs) + + +def NumpyArrayBool(*args, **kwargs): + """Convenience wrapper to create Numpy arrays with a np.bool dtype.""" + return np.array(*args, dtype=np.bool, **kwargs) + + +class ComputationsWithConstantsTest(LocalComputationTest): + """Tests focusing on Constant ops.""" + + def testConstantScalarSumF32(self): + c = self._NewComputation() + root = c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14)) + self.assertEqual(c.GetShape(root), c.GetReturnValueShape()) + self._ExecuteAndCompareClose(c, expected=4.25) + + def testConstantScalarSumF64(self): + c = self._NewComputation() + c.Add(c.ConstantF64Scalar(1.11), c.ConstantF64Scalar(3.14)) + self._ExecuteAndCompareClose(c, expected=4.25) + + def testConstantScalarSumS32(self): + c = self._NewComputation() + c.Add(c.ConstantS32Scalar(1), c.ConstantS32Scalar(2)) + self._ExecuteAndCompareClose(c, expected=3) + + def testConstantScalarSumS64(self): + c 
= self._NewComputation() + c.Add(c.ConstantS64Scalar(1), c.ConstantS64Scalar(2)) + self._ExecuteAndCompareClose(c, expected=3) + + def testConstantVectorMulF32(self): + c = self._NewComputation() + c.Mul( + c.Constant(NumpyArrayF32([2.5, 3.3, -1.2, 0.7])), + c.Constant(NumpyArrayF32([-1.2, 2, -2, -3]))) + self._ExecuteAndCompareClose(c, expected=[-3, 6.6, 2.4, -2.1]) + + def testConstantVectorMulF64(self): + c = self._NewComputation() + c.Mul( + c.Constant(NumpyArrayF64([2.5, 3.3, -1.2, 0.7])), + c.Constant(NumpyArrayF64([-1.2, 2, -2, -3]))) + self._ExecuteAndCompareClose(c, expected=[-3, 6.6, 2.4, -2.1]) + + def testConstantVectorScalarDivF32(self): + c = self._NewComputation() + c.Div( + c.Constant(NumpyArrayF32([1.5, 2.5, 3.0, -10.8])), + c.ConstantF32Scalar(2.0)) + self._ExecuteAndCompareClose(c, expected=[0.75, 1.25, 1.5, -5.4]) + + def testConstantVectorScalarDivF64(self): + c = self._NewComputation() + c.Div( + c.Constant(NumpyArrayF64([1.5, 2.5, 3.0, -10.8])), + c.ConstantF64Scalar(2.0)) + self._ExecuteAndCompareClose(c, expected=[0.75, 1.25, 1.5, -5.4]) + + def testConstantVectorScalarPowF32(self): + c = self._NewComputation() + c.Pow(c.Constant(NumpyArrayF32([1.5, 2.5, 3.0])), c.ConstantF32Scalar(2.)) + self._ExecuteAndCompareClose(c, expected=[2.25, 6.25, 9.]) + + def testConstantVectorScalarPowF64(self): + c = self._NewComputation() + c.Pow(c.Constant(NumpyArrayF64([1.5, 2.5, 3.0])), c.ConstantF64Scalar(2.)) + self._ExecuteAndCompareClose(c, expected=[2.25, 6.25, 9.]) + + def testBooleanAnd(self): + c = self._NewComputation() + c.And( + c.Constant(NumpyArrayBool([True, False, True, False])), + c.Constant(NumpyArrayBool([True, True, False, False]))) + self._ExecuteAndCompareExact(c, expected=[True, False, False, False]) + + def testBooleanOr(self): + c = self._NewComputation() + c.Or( + c.Constant(NumpyArrayBool([True, False, True, False])), + c.Constant(NumpyArrayBool([True, True, False, False]))) + self._ExecuteAndCompareExact(c, expected=[True, True, 
True, False]) + + def testSum2DF32(self): + c = self._NewComputation() + c.Add( + c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6]])), + c.Constant(NumpyArrayF32([[1, -1, 1], [-1, 1, -1]]))) + self._ExecuteAndCompareClose(c, expected=[[2, 1, 4], [3, 6, 5]]) + + def testSum2DF64(self): + c = self._NewComputation() + c.Add( + c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6]])), + c.Constant(NumpyArrayF64([[1, -1, 1], [-1, 1, -1]]))) + self._ExecuteAndCompareClose(c, expected=[[2, 1, 4], [3, 6, 5]]) + + def testSum2DWith1DBroadcastDim0F32(self): + # sum of a 2D array with a 1D array where the latter is replicated across + # dimension 0 to match the former's shape. + c = self._NewComputation() + c.Add( + c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + c.Constant(NumpyArrayF32([10, 20, 30])), + broadcast_dimensions=(0,)) + self._ExecuteAndCompareClose( + c, expected=[[11, 12, 13], [24, 25, 26], [37, 38, 39]]) + + def testSum2DWith1DBroadcastDim0F64(self): + # sum of a 2D array with a 1D array where the latter is replicated across + # dimension 0 to match the former's shape. + c = self._NewComputation() + c.Add( + c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + c.Constant(NumpyArrayF64([10, 20, 30])), + broadcast_dimensions=(0,)) + self._ExecuteAndCompareClose( + c, expected=[[11, 12, 13], [24, 25, 26], [37, 38, 39]]) + + def testSum2DWith1DBroadcastDim1F32(self): + # sum of a 2D array with a 1D array where the latter is replicated across + # dimension 1 to match the former's shape. + c = self._NewComputation() + c.Add( + c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + c.Constant(NumpyArrayF32([10, 20, 30])), + broadcast_dimensions=(1,)) + self._ExecuteAndCompareClose( + c, expected=[[11, 22, 33], [14, 25, 36], [17, 28, 39]]) + + def testSum2DWith1DBroadcastDim1F64(self): + # sum of a 2D array with a 1D array where the latter is replicated across + # dimension 1 to match the former's shape. 
+ c = self._NewComputation() + c.Add( + c.Constant(NumpyArrayF64([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + c.Constant(NumpyArrayF64([10, 20, 30])), + broadcast_dimensions=(1,)) + self._ExecuteAndCompareClose( + c, expected=[[11, 22, 33], [14, 25, 36], [17, 28, 39]]) + + def testConstantAxpyF32(self): + c = self._NewComputation() + c.Add( + c.Mul( + c.ConstantF32Scalar(2), + c.Constant(NumpyArrayF32([2.2, 3.3, 4.4, 5.5]))), + c.Constant(NumpyArrayF32([100, -100, 200, -200]))) + self._ExecuteAndCompareClose(c, expected=[104.4, -93.4, 208.8, -189]) + + def testConstantAxpyF64(self): + c = self._NewComputation() + c.Add( + c.Mul( + c.ConstantF64Scalar(2), + c.Constant(NumpyArrayF64([2.2, 3.3, 4.4, 5.5]))), + c.Constant(NumpyArrayF64([100, -100, 200, -200]))) + self._ExecuteAndCompareClose(c, expected=[104.4, -93.4, 208.8, -189]) + + +class ParametersTest(LocalComputationTest): + """Tests focusing on Parameter ops and argument-passing.""" + + def setUp(self): + self.f32_scalar_2 = NumpyArrayF32(2.0) + self.f32_4vector = NumpyArrayF32([-2.3, 3.3, -4.3, 5.3]) + self.f64_scalar_2 = NumpyArrayF64(2.0) + self.f64_4vector = NumpyArrayF64([-2.3, 3.3, -4.3, 5.3]) + self.s32_scalar_3 = NumpyArrayS32(3) + self.s32_4vector = NumpyArrayS32([10, 15, -2, 7]) + self.s64_scalar_3 = NumpyArrayS64(3) + self.s64_4vector = NumpyArrayS64([10, 15, -2, 7]) + + def testScalarTimesVectorAutonumberF32(self): + c = self._NewComputation() + p0 = c.ParameterFromNumpy(self.f32_scalar_2) + p1 = c.ParameterFromNumpy(self.f32_4vector) + c.Mul(p0, p1) + self._ExecuteAndCompareClose( + c, + arguments=[self.f32_scalar_2, self.f32_4vector], + expected=[-4.6, 6.6, -8.6, 10.6]) + + def testScalarTimesVectorAutonumberF64(self): + c = self._NewComputation() + p0 = c.ParameterFromNumpy(self.f64_scalar_2) + p1 = c.ParameterFromNumpy(self.f64_4vector) + c.Mul(p0, p1) + self._ExecuteAndCompareClose( + c, + arguments=[self.f64_scalar_2, self.f64_4vector], + expected=[-4.6, 6.6, -8.6, 10.6]) + + def 
testScalarTimesVectorS32(self): + c = self._NewComputation() + p0 = c.ParameterFromNumpy(self.s32_scalar_3) + p1 = c.ParameterFromNumpy(self.s32_4vector) + c.Mul(p0, p1) + self._ExecuteAndCompareExact( + c, + arguments=[self.s32_scalar_3, self.s32_4vector], + expected=[30, 45, -6, 21]) + + def testScalarTimesVectorS64(self): + c = self._NewComputation() + p0 = c.ParameterFromNumpy(self.s64_scalar_3) + p1 = c.ParameterFromNumpy(self.s64_4vector) + c.Mul(p0, p1) + self._ExecuteAndCompareExact( + c, + arguments=[self.s64_scalar_3, self.s64_4vector], + expected=[30, 45, -6, 21]) + + def testScalarMinusVectorExplicitNumberingF32(self): + # Use explicit numbering and pass parameter_num first. Sub is used since + # it's not commutative and can help catch parameter reversal within the + # computation. + c = self._NewComputation() + p1 = c.ParameterFromNumpy(self.f32_4vector, parameter_num=1) + p0 = c.ParameterFromNumpy(self.f32_scalar_2, parameter_num=0) + c.Sub(p1, p0) + self._ExecuteAndCompareClose( + c, + arguments=[self.f32_scalar_2, self.f32_4vector], + expected=[-4.3, 1.3, -6.3, 3.3]) + + def testScalarMinusVectorExplicitNumberingF64(self): + # Use explicit numbering and pass parameter_num first. Sub is used since + # it's not commutative and can help catch parameter reversal within the + # computation. 
+ c = self._NewComputation() + p1 = c.ParameterFromNumpy(self.f64_4vector, parameter_num=1) + p0 = c.ParameterFromNumpy(self.f64_scalar_2, parameter_num=0) + c.Sub(p1, p0) + self._ExecuteAndCompareClose( + c, + arguments=[self.f64_scalar_2, self.f64_4vector], + expected=[-4.3, 1.3, -6.3, 3.3]) + + +class LocalBufferTest(LocalComputationTest): + """Tests focusing on execution with LocalBuffers.""" + + def _Execute(self, c, arguments): + compiled_c = c.Build().CompileWithExampleArguments(arguments) + arg_buffers = [xla_client.LocalBuffer.from_py(arg) for arg in arguments] + result_buffer = compiled_c.ExecuteWithLocalBuffers(arg_buffers) + return result_buffer.to_py() + + def testConstantSum(self): + c = self._NewComputation() + c.Add(c.ConstantF32Scalar(1.11), c.ConstantF32Scalar(3.14)) + self._ExecuteAndCompareClose(c, expected=4.25) + + def testOneParameterSum(self): + c = self._NewComputation() + c.Add(c.ParameterFromNumpy(NumpyArrayF32(0.)), c.ConstantF32Scalar(3.14)) + self._ExecuteAndCompareClose( + c, + arguments=[NumpyArrayF32(1.11)], + expected=4.25) + + def testTwoParameterSum(self): + c = self._NewComputation() + c.Add(c.ParameterFromNumpy(NumpyArrayF32(0.)), + c.ParameterFromNumpy(NumpyArrayF32(0.))) + self._ExecuteAndCompareClose( + c, + arguments=[NumpyArrayF32(1.11), NumpyArrayF32(3.14)], + expected=4.25) + + def testCannotCallWithDeletedBuffers(self): + c = self._NewComputation() + c.Add(c.ParameterFromNumpy(NumpyArrayF32(0.)), c.ConstantF32Scalar(3.14)) + arg = NumpyArrayF32(1.11) + compiled_c = c.Build().CompileWithExampleArguments([arg]) + arg_buffer = xla_client.LocalBuffer.from_py(arg) + arg_buffer.delete() + with self.assertRaises(ValueError): + compiled_c.ExecuteWithLocalBuffers([arg_buffer]) + + +class SingleOpTest(LocalComputationTest): + """Tests for single ops. + + The goal here is smoke testing - to exercise the most basic functionality of + single XLA ops. 
As minimal as possible number of additional ops are added + around the op being tested. + """ + + def testConcatenateF32(self): + c = self._NewComputation() + c.Concatenate( + (c.Constant(NumpyArrayF32([1.0, 2.0, 3.0])), + c.Constant(NumpyArrayF32([4.0, 5.0, 6.0]))), + dimension=0) + self._ExecuteAndCompareClose(c, expected=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + + def testConcatenateF64(self): + c = self._NewComputation() + c.Concatenate( + (c.Constant(NumpyArrayF64([1.0, 2.0, 3.0])), + c.Constant(NumpyArrayF64([4.0, 5.0, 6.0]))), + dimension=0) + self._ExecuteAndCompareClose(c, expected=[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + + def testConvertElementType(self): + xla_types = { + np.bool: xla_client.xla_data_pb2.PRED, + np.int32: xla_client.xla_data_pb2.S32, + np.int64: xla_client.xla_data_pb2.S64, + np.float32: xla_client.xla_data_pb2.F32, + np.float64: xla_client.xla_data_pb2.F64, + } + + def _ConvertAndTest(template, src_dtype, dst_dtype): + c = self._NewComputation() + x = c.Constant(np.array(template, dtype=src_dtype)) + c.ConvertElementType(x, xla_types[dst_dtype]) + + result = c.Build().Compile().Execute() + expected = np.array(template, dtype=dst_dtype) + + self.assertEqual(result.shape, expected.shape) + self.assertEqual(result.dtype, expected.dtype) + np.testing.assert_equal(result, expected) + + x = [0, 1, 0, 0, 1] + for src_dtype, dst_dtype in itertools.product(xla_types, xla_types): + _ConvertAndTest(x, src_dtype, dst_dtype) + + def testCrossReplicaSumOneReplica(self): + samples = [ + NumpyArrayF32(42.0), + NumpyArrayF32([97.0]), + NumpyArrayF32([64.0, 117.0]), + NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]), + ] + for lhs in samples: + c = self._NewComputation() + c.CrossReplicaSum(c.Constant(lhs)) + self._ExecuteAndCompareExact(c, expected=lhs) + + def testDotMatrixVectorF32(self): + c = self._NewComputation() + lhs = NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]) + rhs = NumpyArrayF32([[10.0], [20.0]]) + c.Dot(c.Constant(lhs), c.Constant(rhs)) + 
self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs)) + + def testDotMatrixVectorF64(self): + c = self._NewComputation() + lhs = NumpyArrayF64([[2.0, 3.0], [4.0, 5.0]]) + rhs = NumpyArrayF64([[10.0], [20.0]]) + c.Dot(c.Constant(lhs), c.Constant(rhs)) + self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs)) + + def testDotMatrixMatrixF32(self): + c = self._NewComputation() + lhs = NumpyArrayF32([[2.0, 3.0], [4.0, 5.0]]) + rhs = NumpyArrayF32([[10.0, 20.0], [100.0, 200.0]]) + c.Dot(c.Constant(lhs), c.Constant(rhs)) + self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs)) + + def testDotMatrixMatrixF64(self): + c = self._NewComputation() + lhs = NumpyArrayF64([[2.0, 3.0], [4.0, 5.0]]) + rhs = NumpyArrayF64([[10.0, 20.0], [100.0, 200.0]]) + c.Dot(c.Constant(lhs), c.Constant(rhs)) + self._ExecuteAndCompareClose(c, expected=np.dot(lhs, rhs)) + + def testDotGeneral(self): + c = self._NewComputation() + rng = np.random.RandomState(0) + lhs = NumpyArrayF32(rng.randn(10, 3, 4)) + rhs = NumpyArrayF32(rng.randn(10, 4, 5)) + dimension_numbers = (([2], [1]), ([0], [0])) + c.DotGeneral(c.Constant(lhs), c.Constant(rhs), dimension_numbers) + self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs)) + + def testDotGeneralWithDotDimensionNumbersProto(self): + c = self._NewComputation() + rng = np.random.RandomState(0) + lhs = NumpyArrayF32(rng.randn(10, 3, 4)) + rhs = NumpyArrayF32(rng.randn(10, 4, 5)) + + dimension_numbers = xla_client.xla_data_pb2.DotDimensionNumbers() + dimension_numbers.lhs_contracting_dimensions.append(2) + dimension_numbers.rhs_contracting_dimensions.append(1) + dimension_numbers.lhs_batch_dimensions.append(0) + dimension_numbers.rhs_batch_dimensions.append(0) + + c.DotGeneral(c.Constant(lhs), c.Constant(rhs), dimension_numbers) + self._ExecuteAndCompareClose(c, expected=np.matmul(lhs, rhs)) + + def testConvF32Same(self): + c = self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 2, 3, 
4) + rhs = a(1, 2, 1, 2) * 10 + c.Conv(c.Constant(lhs), c.Constant(rhs), + [1, 1], xla_client.PaddingType.SAME) + result = np.array([[[[640., 700., 760., 300.], + [880., 940., 1000., 380.], + [1120., 1180., 1240., 460.]]]]) + self._ExecuteAndCompareClose(c, expected=result) + + def testConvF32Valid(self): + c = self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 2, 3, 4) + rhs = a(1, 2, 1, 2) * 10 + c.Conv(c.Constant(lhs), c.Constant(rhs), + [2, 1], xla_client.PaddingType.VALID) + result = np.array([[[[640., 700., 760.], + [1120., 1180., 1240.]]]]) + self._ExecuteAndCompareClose(c, expected=result) + + def testConvWithGeneralPaddingF32(self): + c = self._NewComputation() + a = lambda *dims: np.arange(np.prod(dims)).reshape(dims).astype("float32") + lhs = a(1, 1, 2, 3) + rhs = a(1, 1, 1, 2) * 10 + strides = [1, 1] + pads = [(1, 0), (0, 1)] + lhs_dilation = (2, 1) + rhs_dilation = (1, 1) + c.ConvWithGeneralPadding(c.Constant(lhs), c.Constant(rhs), + strides, pads, lhs_dilation, rhs_dilation) + result = np.array([[[[0., 0., 0.], + [10., 20., 0.], + [0., 0., 0.], + [40., 50., 0.]]]]) + self._ExecuteAndCompareClose(c, expected=result) + + def testBooleanNot(self): + c = self._NewComputation() + arr = NumpyArrayBool([True, False, True]) + c.Not(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=~arr) + + def testExp(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + c.Exp(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=np.exp(arr)) + + def testRound(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + c.Round(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=np.round(arr)) + + def testLog(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + c.Log(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=np.log(arr)) + + def testNeg(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + 
c.Neg(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=-arr) + + def testFloor(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + c.Floor(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=np.floor(arr)) + + def testCeil(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + c.Ceil(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=np.ceil(arr)) + + def testAbs(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, -12.1, 2.4, -1.]) + c.Abs(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=np.abs(arr)) + + def testTanh(self): + c = self._NewComputation() + arr = NumpyArrayF32([3.3, 12.1]) + c.Tanh(c.Constant(arr)) + self._ExecuteAndCompareClose(c, expected=np.tanh(arr)) + + def testTrans(self): + + def _TransposeAndTest(array): + c = self._NewComputation() + c.Trans(c.Constant(array)) + self._ExecuteAndCompareClose(c, expected=array.T) + + # Test square and non-square matrices in both default (C) and F orders. 
+ for array_fun in [NumpyArrayF32, NumpyArrayF64]: + _TransposeAndTest(array_fun([[1, 2, 3], [4, 5, 6]])) + _TransposeAndTest(array_fun([[1, 2, 3], [4, 5, 6]], order="F")) + _TransposeAndTest(array_fun([[1, 2], [4, 5]])) + _TransposeAndTest(array_fun([[1, 2], [4, 5]], order="F")) + + def testTranspose(self): + + def _TransposeAndTest(array, permutation): + c = self._NewComputation() + c.Transpose(c.Constant(array), permutation) + expected = np.transpose(array, permutation) + self._ExecuteAndCompareClose(c, expected=expected) + + _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [0, 1]) + _TransposeAndTest(NumpyArrayF32([[1, 2, 3], [4, 5, 6]]), [1, 0]) + _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [0, 1]) + _TransposeAndTest(NumpyArrayF32([[1, 2], [4, 5]]), [1, 0]) + + arr = np.random.RandomState(0).randn(2, 3, 4).astype(np.float32) + for permutation in itertools.permutations(range(arr.ndim)): + _TransposeAndTest(arr, permutation) + _TransposeAndTest(np.asfortranarray(arr), permutation) + + def testEq(self): + c = self._NewComputation() + c.Eq( + c.Constant(NumpyArrayS32([1, 2, 3, 4])), + c.Constant(NumpyArrayS32([4, 2, 3, 1]))) + self._ExecuteAndCompareExact(c, expected=[False, True, True, False]) + + def testNe(self): + c = self._NewComputation() + c.Ne( + c.Constant(NumpyArrayS32([1, 2, 3, 4])), + c.Constant(NumpyArrayS32([4, 2, 3, 1]))) + self._ExecuteAndCompareExact(c, expected=[True, False, False, True]) + + c.Ne( + c.Constant(NumpyArrayF32([-2.0, 0.0, + float("nan"), + float("nan")])), + c.Constant(NumpyArrayF32([2.0, -0.0, 1.0, float("nan")]))) + self._ExecuteAndAssertWith( + np.testing.assert_allclose, c, (), expected=[True, False, True, True]) + + def testGt(self): + c = self._NewComputation() + c.Gt( + c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])), + c.Constant(NumpyArrayS32([1, 0, 2, 7, 12]))) + self._ExecuteAndCompareExact(c, expected=[False, True, True, False, False]) + + def testGe(self): + c = self._NewComputation() + c.Ge( + 
c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])), + c.Constant(NumpyArrayS32([1, 0, 2, 7, 12]))) + self._ExecuteAndCompareExact(c, expected=[True, True, True, False, False]) + + def testLt(self): + c = self._NewComputation() + c.Lt( + c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])), + c.Constant(NumpyArrayS32([1, 0, 2, 7, 12]))) + self._ExecuteAndCompareExact(c, expected=[False, False, False, True, True]) + + def testLe(self): + c = self._NewComputation() + c.Le( + c.Constant(NumpyArrayS32([1, 2, 3, 4, 9])), + c.Constant(NumpyArrayS32([1, 0, 2, 7, 12]))) + self._ExecuteAndCompareExact(c, expected=[True, False, False, True, True]) + + def testMax(self): + c = self._NewComputation() + c.Max( + c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])), + c.Constant(NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0]))) + self._ExecuteAndCompareExact(c, expected=[1.0, 2.0, 3.0, 7.0, 12.0]) + + def testMaxExplicitBroadcastDim0(self): + c = self._NewComputation() + c.Max( + c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + c.Constant(NumpyArrayF32([3, 4, 5])), + broadcast_dimensions=(0,)) + self._ExecuteAndCompareExact(c, expected=[[3, 3, 3], [4, 5, 6], [7, 8, 9]]) + + def testMaxExplicitBroadcastDim1(self): + c = self._NewComputation() + c.Max( + c.Constant(NumpyArrayF32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + c.Constant(NumpyArrayF32([3, 4, 5])), + broadcast_dimensions=(1,)) + self._ExecuteAndCompareExact(c, expected=[[3, 4, 5], [4, 5, 6], [7, 8, 9]]) + + def testMin(self): + c = self._NewComputation() + c.Min( + c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0, 9.0])), + c.Constant(NumpyArrayF32([1.0, 0.0, 2.0, 7.0, 12.0]))) + self._ExecuteAndCompareExact(c, expected=[1.0, 0.0, 2.0, 4.0, 9.0]) + + def testPad(self): + c = self._NewComputation() + c.Pad( + c.Constant(NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])), + c.Constant(NumpyArrayF32(0.0)), + [(1, 2, 1), (0, 1, 0)]) + self._ExecuteAndCompareClose(c, expected=[[0.0, 0.0, 0.0], + [1.0, 2.0, 0.0], + [0.0, 0.0, 0.0], + [3.0, 4.0, 0.0], + 
[0.0, 0.0, 0.0], + [0.0, 0.0, 0.0]]) + + def testPadWithPaddingConfig(self): + c = self._NewComputation() + padding_config = xla_client.xla_data_pb2.PaddingConfig() + for lo, hi, interior in [(1, 2, 1), (0, 1, 0)]: + dimension = padding_config.dimensions.add() + dimension.edge_padding_low = lo + dimension.edge_padding_high = hi + dimension.interior_padding = interior + c.Pad( + c.Constant(NumpyArrayF32([[1.0, 2.0], [3.0, 4.0]])), + c.Constant(NumpyArrayF32(0.0)), + padding_config) + self._ExecuteAndCompareClose(c, expected=[[0.0, 0.0, 0.0], + [1.0, 2.0, 0.0], + [0.0, 0.0, 0.0], + [3.0, 4.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0]]) + + def testReshape(self): + c = self._NewComputation() + c.Reshape( + c.Constant(NumpyArrayS32([[1, 2], [3, 4], [5, 6]])), + dimensions=[0, 1], + new_sizes=[2, 3]) + self._ExecuteAndCompareExact(c, expected=[[1, 2, 3], [4, 5, 6]]) + + def testCollapse(self): + c = self._NewComputation() + c.Collapse( + c.Constant(NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])), + dimensions=[1, 2]) + self._ExecuteAndCompareExact(c, expected=[[1, 2, 3, 4], [5, 6, 7, 8]]) + + def testRev(self): + c = self._NewComputation() + c.Rev( + c.Constant(NumpyArrayS32([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])), + dimensions=[0, 2]) + self._ExecuteAndCompareExact( + c, expected=[[[6, 5], [8, 7]], [[2, 1], [4, 3]]]) + + def testClampF32(self): + c = self._NewComputation() + c.Clamp( + c.Constant(NumpyArrayF32(-1)), + c.Constant(NumpyArrayF32([-2, -1, 0, 1, 2, 3])), + c.Constant(NumpyArrayF32(2))) + self._ExecuteAndCompareExact(c, expected=[-1, -1, 0, 1, 2, 2]) + + # TODO(b/72689392): re-enable when bug S32 resolved + def DISABLED_testClampS32(self): + c = self._NewComputation() + c.Clamp( + c.Constant(NumpyArrayS32(-1)), + c.Constant(NumpyArrayS32([-2, -1, 0, 1, 2, 3])), + c.Constant(NumpyArrayS32(2))) + self._ExecuteAndCompareExact(c, expected=[-1, 0, 1, 2, 2]) + + def testSelect(self): + c = self._NewComputation() + c.Select( + c.Constant(NumpyArrayBool([True, 
False, False, True, False])), + c.Constant(NumpyArrayS32([1, 2, 3, 4, 5])), + c.Constant(NumpyArrayS32([-1, -2, -3, -4, -5]))) + self._ExecuteAndCompareExact(c, expected=[1, -2, -3, 4, -5]) + + def testSlice(self): + c = self._NewComputation() + c.Slice( + c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), [1, 0], + [3, 2]) + self._ExecuteAndCompareExact(c, expected=[[4, 5], [7, 8]]) + + def testDynamicSlice(self): + c = self._NewComputation() + c.DynamicSlice( + c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + c.Constant(NumpyArrayS32([1, 0])), [2, 2]) + self._ExecuteAndCompareExact(c, expected=[[4, 5], [7, 8]]) + + def testDynamicUpdateSlice(self): + c = self._NewComputation() + c.DynamicUpdateSlice( + c.Constant(NumpyArrayS32([[1, 2, 3], [4, 5, 6], [7, 8, 9]])), + c.Constant(NumpyArrayS32([[1, 2], [3, 4]])), + c.Constant(NumpyArrayS32([1, 1]))) + self._ExecuteAndCompareExact(c, expected=[[1, 2, 3], [4, 1, 2], [7, 3, 4]]) + + def testTuple(self): + c = self._NewComputation() + c.Tuple( + c.ConstantS32Scalar(42), c.Constant(NumpyArrayF32([1.0, 2.0])), + c.Constant(NumpyArrayBool([True, False, False, True]))) + result = c.Build().Compile().Execute() + self.assertIsInstance(result, tuple) + np.testing.assert_equal(result[0], 42) + np.testing.assert_allclose(result[1], [1.0, 2.0]) + np.testing.assert_equal(result[2], [True, False, False, True]) + + def testGetTupleElement(self): + c = self._NewComputation() + c.GetTupleElement( + c.Tuple( + c.ConstantS32Scalar(42), c.Constant(NumpyArrayF32([1.0, 2.0])), + c.Constant(NumpyArrayBool([True, False, False, True]))), 1) + self._ExecuteAndCompareClose(c, expected=[1.0, 2.0]) + + def testBroadcast(self): + c = self._NewComputation() + c.Broadcast(c.Constant(NumpyArrayS32([10, 20, 30, 40])), sizes=(3,)) + self._ExecuteAndCompareExact( + c, expected=[[10, 20, 30, 40], [10, 20, 30, 40], [10, 20, 30, 40]]) + + def testRngNormal(self): + shape = (2, 3) + c = self._NewComputation() + 
c.RngNormal(c.Constant(NumpyArrayF32(0.)), c.Constant(NumpyArrayF32(1.)), + dims=shape) + result = c.Build().Compile().Execute() + # since the result is random, we just check shape and uniqueness + self.assertEqual(result.shape, shape) + self.assertEqual(len(np.unique(result)), np.prod(shape)) + + def testRngUniformF32(self): + lo, hi = 2., 4. + shape = (2, 3) + c = self._NewComputation() + c.RngUniform(c.Constant(NumpyArrayF32(lo)), c.Constant(NumpyArrayF32(hi)), + dims=shape) + result = c.Build().Compile().Execute() + # since the result is random, we just check shape, uniqueness, and range + self.assertEqual(result.shape, shape) + self.assertEqual(len(np.unique(result)), np.prod(shape)) + self.assertTrue(np.all(lo <= result)) + self.assertTrue(np.all(result < hi)) + + def testRngUniformS32(self): + lo, hi = 2, 4 + shape = (2, 3) + c = self._NewComputation() + c.RngUniform(c.Constant(NumpyArrayS32(lo)), c.Constant(NumpyArrayS32(hi)), + dims=shape) + result = c.Build().Compile().Execute() + # since the result is random, we just check shape, integrality, and range + self.assertEqual(result.shape, shape) + self.assertEqual(result.dtype, np.int32) + self.assertTrue(np.all(lo <= result)) + self.assertTrue(np.all(result < hi)) + + +class EmbeddedComputationsTest(LocalComputationTest): + """Tests for XLA graphs with embedded computations (such as maps).""" + + def _CreateConstantS32Computation(self): + """Computation (f32) -> s32 that returns a constant 1 for any input.""" + c = self._NewComputation("constant_s32_one") + # TODO(eliben): consider adding a nicer way to create new parameters without + # having to create dummy Numpy arrays or populating Shape messages. Perhaps + # we need our own (Python-client-own) way to represent Shapes conveniently. 
+ c.ParameterFromNumpy(NumpyArrayF32(0)) + c.ConstantS32Scalar(1) + return c.Build() + + def _CreateConstantS64Computation(self): + """Computation (f64) -> s64 that returns a constant 1 for any input.""" + c = self._NewComputation("constant_s64_one") + # TODO(eliben): consider adding a nicer way to create new parameters without + # having to create dummy Numpy arrays or populating Shape messages. Perhaps + # we need our own (Python-client-own) way to represent Shapes conveniently. + c.ParameterFromNumpy(NumpyArrayF64(0)) + c.ConstantS64Scalar(1) + return c.Build() + + def _CreateConstantF32Computation(self): + """Computation (f32) -> f32 that returns a constant 1.0 for any input.""" + c = self._NewComputation("constant_f32_one") + c.ParameterFromNumpy(NumpyArrayF32(0)) + c.ConstantF32Scalar(1.0) + return c.Build() + + def _CreateConstantF64Computation(self): + """Computation (f64) -> f64 that returns a constant 1.0 for any input.""" + c = self._NewComputation("constant_f64_one") + c.ParameterFromNumpy(NumpyArrayF64(0)) + c.ConstantF64Scalar(1.0) + return c.Build() + + def _CreateMulF32By2Computation(self): + """Computation (f32) -> f32 that multiplies its parameter by 2.""" + c = self._NewComputation("mul_f32_by2") + c.Mul(c.ParameterFromNumpy(NumpyArrayF32(0)), c.ConstantF32Scalar(2.0)) + return c.Build() + + def _CreateMulF32ByParamComputation(self): + """Computation (f32) -> f32 that multiplies one parameter by the other.""" + c = self._NewComputation("mul_f32_by_param") + c.Mul(c.ParameterFromNumpy(NumpyArrayF32(0)), + c.ParameterFromNumpy(NumpyArrayF32(0))) + return c.Build() + + def _CreateMulF64By2Computation(self): + """Computation (f64) -> f64 that multiplies its parameter by 2.""" + c = self._NewComputation("mul_f64_by2") + c.Mul(c.ParameterFromNumpy(NumpyArrayF64(0)), c.ConstantF64Scalar(2.0)) + return c.Build() + + def _CreateBinaryAddF32Computation(self): + """Computation (f32, f32) -> f32 that adds its two parameters.""" + c = 
self._NewComputation("add_param0_by_param1") + c.Add( + c.ParameterFromNumpy(NumpyArrayF32(0)), + c.ParameterFromNumpy(NumpyArrayF32(0))) + return c.Build() + + def _CreateBinaryAddF64Computation(self): + """Computation (f64, f64) -> f64 that adds its two parameters.""" + c = self._NewComputation("add_param0_by_param1") + c.Add( + c.ParameterFromNumpy(NumpyArrayF64(0)), + c.ParameterFromNumpy(NumpyArrayF64(0))) + return c.Build() + + def _CreateBinaryDivF32Computation(self): + """Computation (f32, f32) -> f32 that divides its two parameters.""" + c = self._NewComputation("div_param0_by_param1") + c.Div( + c.ParameterFromNumpy(NumpyArrayF32(0)), + c.ParameterFromNumpy(NumpyArrayF32(0))) + return c.Build() + + def _CreateBinaryDivF64Computation(self): + """Computation (f64, f64) -> f64 that divides its two parameters.""" + c = self._NewComputation("div_param0_by_param1") + c.Div( + c.ParameterFromNumpy(NumpyArrayF64(0)), + c.ParameterFromNumpy(NumpyArrayF64(0))) + return c.Build() + + def _CreateTestF32Lt10Computation(self): + """Computation (f32) -> bool that tests if its parameter is less than 10.""" + c = self._NewComputation("test_f32_lt_10") + c.Lt(c.ParameterFromNumpy(NumpyArrayF32(0)), c.ConstantF32Scalar(10.)) + return c.Build() + + def _CreateTestF64Lt10Computation(self): + """Computation (f64) -> bool that tests if its parameter is less than 10.""" + c = self._NewComputation("test_f64_lt_10") + c.Lt(c.ParameterFromNumpy(NumpyArrayF64(0)), c.ConstantF64Scalar(10.)) + return c.Build() + + def _CreateBinaryGeF32Computation(self): + """Computation (f32, f32) -> bool that tests first_param >= second_param.""" + c = self._NewComputation("param0_lt_param1") + c.Ge(c.ParameterFromNumpy(NumpyArrayF32(0)), + c.ParameterFromNumpy(NumpyArrayF32(0))) + return c.Build() + + def _CreateBinaryGeF64Computation(self): + """Computation (f64, f64) -> bool that tests first_param >= second_param.""" + c = self._NewComputation("param0_lt_param1") + 
c.Ge(c.ParameterFromNumpy(NumpyArrayF64(0)), + c.ParameterFromNumpy(NumpyArrayF64(0))) + return c.Build() + + def _MakeSample3DArrayF32(self): + return NumpyArrayF32([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]], + [[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]]) + + def _MakeSample3DArrayF64(self): + return NumpyArrayF64([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]], + [[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [4, 5, 6]]]) + + def testCallF32(self): + c = self._NewComputation() + c.Call( + self._CreateMulF32By2Computation(), + operands=(c.ConstantF32Scalar(5.0),)) + self._ExecuteAndCompareClose(c, expected=10.0) + + def testCallF64(self): + c = self._NewComputation() + c.Call( + self._CreateMulF64By2Computation(), + operands=(c.ConstantF64Scalar(5.0),)) + self._ExecuteAndCompareClose(c, expected=10.0) + + def testMapEachElementToS32Constant(self): + c = self._NewComputation() + c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))], + self._CreateConstantS32Computation(), [0]) + self._ExecuteAndCompareExact(c, expected=[1, 1, 1, 1]) + + def testMapEachElementToS64Constant(self): + c = self._NewComputation() + c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))], + self._CreateConstantS64Computation(), [0]) + self._ExecuteAndCompareExact(c, expected=[1, 1, 1, 1]) + + def testMapMulBy2F32(self): + c = self._NewComputation() + c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))], + self._CreateMulF32By2Computation(), [0]) + self._ExecuteAndCompareClose(c, expected=[2.0, 4.0, 6.0, 8.0]) + + def testMapMulBy2F64(self): + c = self._NewComputation() + c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))], + self._CreateMulF64By2Computation(), [0]) + self._ExecuteAndCompareClose(c, expected=[2.0, 4.0, 6.0, 8.0]) + + def testSimpleMapChainF32(self): + # Chains a map of constant-f32 with a map of mul-by-2 + c = self._NewComputation() + const_f32 = c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))], + self._CreateConstantF32Computation(), [0]) + 
c.Map([const_f32], self._CreateMulF32By2Computation(), [0]) + self._ExecuteAndCompareClose(c, expected=[2.0, 2.0, 2.0, 2.0]) + + def testSimpleMapChainF64(self): + # Chains a map of constant-f64 with a map of mul-by-2 + c = self._NewComputation() + const_f64 = c.Map([c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0]))], + self._CreateConstantF64Computation(), [0]) + c.Map([const_f64], self._CreateMulF64By2Computation(), [0]) + self._ExecuteAndCompareClose(c, expected=[2.0, 2.0, 2.0, 2.0]) + + def testDivVectorsWithMapF32(self): + c = self._NewComputation() + c.Map((c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0])), + c.Constant(NumpyArrayF32([5.0, 5.0, 4.0, 4.0]))), + self._CreateBinaryDivF32Computation(), [0]) + self._ExecuteAndCompareClose(c, expected=[0.2, 0.4, 0.75, 1.0]) + + def testDivVectorsWithMapF64(self): + c = self._NewComputation() + c.Map((c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0])), + c.Constant(NumpyArrayF64([5.0, 5.0, 4.0, 4.0]))), + self._CreateBinaryDivF64Computation(), [0]) + self._ExecuteAndCompareClose(c, expected=[0.2, 0.4, 0.75, 1.0]) + + def DISABLED_testMapWithStaticOperands(self): + c = self._NewComputation() + factor = c.ConstantF32Scalar(3.0) + c.Map([c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0]))], + self._CreateMulF32ByParamComputation(), [0], + static_operands=[factor]) + self._ExecuteAndCompareClose(c, expected=[3.0, 6.0, 9.0, 12.0]) + + def testSelectAndScatterF32(self): + c = self._NewComputation() + c.SelectAndScatter(c.Constant(NumpyArrayF32([[1., 2., 6.], [4., 5., 3.]])), + select=self._CreateBinaryGeF32Computation(), + window_dimensions=(2, 1), + window_strides=(1, 2), + padding=xla_client.PaddingType.VALID, + source=c.Constant(NumpyArrayF32([[0.1, 0.2]])), + init_value=c.Constant(NumpyArrayF32(1)), + scatter=self._CreateBinaryAddF32Computation()) + self._ExecuteAndCompareClose(c, expected=[[1., 1., 1.2], [1.1, 1., 1.]]) + + def testSelectAndScatterF64(self): + c = self._NewComputation() + 
c.SelectAndScatter(c.Constant(NumpyArrayF64([[1., 2., 6.], [4., 5., 3.]])), + select=self._CreateBinaryGeF64Computation(), + window_dimensions=(2, 1), + window_strides=(1, 2), + padding=xla_client.PaddingType.VALID, + source=c.Constant(NumpyArrayF64([[0.1, 0.2]])), + init_value=c.Constant(NumpyArrayF64(1)), + scatter=self._CreateBinaryAddF64Computation()) + self._ExecuteAndCompareClose(c, expected=[[1., 1., 1.2], [1.1, 1., 1.]]) + + def testReduce1DtoScalarF32(self): + c = self._NewComputation() + c.Reduce( + operand=c.Constant(NumpyArrayF32([1.0, 2.0, 3.0, 4.0])), + init_value=c.ConstantF32Scalar(0), + computation_to_apply=self._CreateBinaryAddF32Computation(), + dimensions=[0]) + self._ExecuteAndCompareClose(c, expected=10) + + def testReduce1DtoScalarF64(self): + c = self._NewComputation() + c.Reduce( + operand=c.Constant(NumpyArrayF64([1.0, 2.0, 3.0, 4.0])), + init_value=c.ConstantF64Scalar(0), + computation_to_apply=self._CreateBinaryAddF64Computation(), + dimensions=[0]) + self._ExecuteAndCompareClose(c, expected=10) + + def testReduce2DTo1DDim0F32(self): + input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + c = self._NewComputation() + c.Reduce( + operand=c.Constant(input_array), + init_value=c.ConstantF32Scalar(0), + computation_to_apply=self._CreateBinaryAddF32Computation(), + dimensions=[0]) + self._ExecuteAndCompareClose(c, expected=[5, 7, 9]) + + def testReduce2DTo1DDim0F64(self): + input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + c = self._NewComputation() + c.Reduce( + operand=c.Constant(input_array), + init_value=c.ConstantF64Scalar(0), + computation_to_apply=self._CreateBinaryAddF64Computation(), + dimensions=[0]) + self._ExecuteAndCompareClose(c, expected=[5, 7, 9]) + + def testReduce2DTo1DDim1F32(self): + input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + c = self._NewComputation() + c.Reduce( + operand=c.Constant(input_array), + init_value=c.ConstantF32Scalar(0), + 
computation_to_apply=self._CreateBinaryAddF32Computation(), + dimensions=[1]) + self._ExecuteAndCompareClose(c, expected=[6, 15]) + + def testReduce2DTo1DDim1F64(self): + input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + c = self._NewComputation() + c.Reduce( + operand=c.Constant(input_array), + init_value=c.ConstantF64Scalar(0), + computation_to_apply=self._CreateBinaryAddF64Computation(), + dimensions=[1]) + self._ExecuteAndCompareClose(c, expected=[6, 15]) + + def testReduce3DAllPossibleWaysF32(self): + input_array = self._MakeSample3DArrayF32() + + def _ReduceAndTest(*dims): + c = self._NewComputation() + c.Reduce( + operand=c.Constant(input_array), + init_value=c.ConstantF32Scalar(0), + computation_to_apply=self._CreateBinaryAddF32Computation(), + dimensions=dims) + self._ExecuteAndCompareClose( + c, expected=np.sum(input_array, axis=tuple(dims))) + + _ReduceAndTest(0) + _ReduceAndTest(0) + _ReduceAndTest(0, 1) + _ReduceAndTest(0, 2) + _ReduceAndTest(1, 2) + _ReduceAndTest(0, 1, 2) + + def testReduce3DAllPossibleWaysF64(self): + input_array = self._MakeSample3DArrayF64() + + def _ReduceAndTest(*dims): + c = self._NewComputation() + c.Reduce( + operand=c.Constant(input_array), + init_value=c.ConstantF64Scalar(0), + computation_to_apply=self._CreateBinaryAddF64Computation(), + dimensions=dims) + self._ExecuteAndCompareClose( + c, expected=np.sum(input_array, axis=tuple(dims))) + + _ReduceAndTest(0) + _ReduceAndTest(0) + _ReduceAndTest(0, 1) + _ReduceAndTest(0, 2) + _ReduceAndTest(1, 2) + _ReduceAndTest(0, 1, 2) + + def testReduceWindowValidUnitStridesF32(self): + input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + c = self._NewComputation() + c.ReduceWindow(operand=c.Constant(input_array), + init_value=c.ConstantF32Scalar(0), + computation_to_apply=self._CreateBinaryAddF32Computation(), + window_dimensions=(2, 1), window_strides=(1, 1), + padding=xla_client.PaddingType.VALID) + self._ExecuteAndCompareClose(c, expected=[[5., 7., 
9.]]) + + def testReduceWindowSameUnitStridesF32(self): + input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + c = self._NewComputation() + c.ReduceWindow(operand=c.Constant(input_array), + init_value=c.ConstantF32Scalar(0), + computation_to_apply=self._CreateBinaryAddF32Computation(), + window_dimensions=(2, 1), window_strides=(1, 1), + padding=xla_client.PaddingType.SAME) + self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.], [4., 5., 6.]]) + + def testReduceWindowValidGeneralStridesF32(self): + input_array = NumpyArrayF32([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + c = self._NewComputation() + c.ReduceWindow(operand=c.Constant(input_array), + init_value=c.ConstantF32Scalar(0), + computation_to_apply=self._CreateBinaryAddF32Computation(), + window_dimensions=(2, 1), window_strides=(1, 2), + padding=xla_client.PaddingType.VALID) + self._ExecuteAndCompareClose(c, expected=[[5., 9.]]) + + def testReduceWindowValidUnitStridesF64(self): + input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + c = self._NewComputation() + c.ReduceWindow(operand=c.Constant(input_array), + init_value=c.ConstantF64Scalar(0), + computation_to_apply=self._CreateBinaryAddF64Computation(), + window_dimensions=(2, 1), window_strides=(1, 1), + padding=xla_client.PaddingType.VALID) + self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.]]) + + def testReduceWindowSameUnitStridesF64(self): + input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + c = self._NewComputation() + c.ReduceWindow(operand=c.Constant(input_array), + init_value=c.ConstantF64Scalar(0), + computation_to_apply=self._CreateBinaryAddF64Computation(), + window_dimensions=(2, 1), window_strides=(1, 1), + padding=xla_client.PaddingType.SAME) + self._ExecuteAndCompareClose(c, expected=[[5., 7., 9.], [4., 5., 6.]]) + + def testReduceWindowValidGeneralStridesF64(self): + input_array = NumpyArrayF64([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + c = self._NewComputation() + 
c.ReduceWindow(operand=c.Constant(input_array), + init_value=c.ConstantF64Scalar(0), + computation_to_apply=self._CreateBinaryAddF64Computation(), + window_dimensions=(2, 1), window_strides=(1, 2), + padding=xla_client.PaddingType.VALID) + self._ExecuteAndCompareClose(c, expected=[[5., 9.]]) + + def testWhileF32(self): + cond = self._CreateTestF32Lt10Computation() + body = self._CreateMulF32By2Computation() + c = self._NewComputation() + init = c.ConstantF32Scalar(1.) + c.While(cond, body, init) + self._ExecuteAndCompareClose(c, expected=16.) + + def testWhileF64(self): + cond = self._CreateTestF64Lt10Computation() + body = self._CreateMulF64By2Computation() + c = self._NewComputation() + init = c.ConstantF64Scalar(1.) + c.While(cond, body, init) + self._ExecuteAndCompareClose(c, expected=16.) + + def testConditionalTrue(self): + c = self._NewComputation() + pred = c.ConstantPredScalar(True) + true_operand = c.ConstantF32Scalar(3.) + true_computation = self._CreateMulF32By2Computation() + false_operand = c.ConstantF32Scalar(2.) + false_computation = self._CreateConstantF32Computation() + c.Conditional(pred, true_operand, true_computation, false_operand, + false_computation) + self._ExecuteAndCompareClose(c, expected=6.) + + def testConditionalFalse(self): + c = self._NewComputation() + pred = c.ConstantPredScalar(False) + true_operand = c.ConstantF32Scalar(3.) + true_computation = self._CreateMulF32By2Computation() + false_operand = c.ConstantF32Scalar(2.) + false_computation = self._CreateConstantF32Computation() + c.Conditional(pred, true_operand, true_computation, false_operand, + false_computation) + self._ExecuteAndCompareClose(c, expected=1.) 
+ + def testInfeedS32Values(self): + to_infeed = NumpyArrayS32([1, 2, 3, 4]) + c = self._NewComputation() + c.Infeed(xla_client.Shape.from_numpy(to_infeed[0])) + compiled_c = c.Build().CompileWithExampleArguments() + for item in to_infeed: + xla_client.transfer_to_infeed(item) + + for item in to_infeed: + result = compiled_c.Execute() + self.assertEqual(result, item) + + def testInfeedThenOutfeedS32(self): + to_round_trip = NumpyArrayS32([1, 2, 3, 4]) + c = self._NewComputation() + x = c.Infeed(xla_client.Shape.from_numpy(to_round_trip[0])) + c.Outfeed(x) + + compiled_c = c.Build().CompileWithExampleArguments() + + for want in to_round_trip: + execution = threading.Thread(target=compiled_c.Execute) + execution.start() + xla_client.transfer_to_infeed(want) + got = xla_client.transfer_from_outfeed( + xla_client.Shape.from_numpy(to_round_trip[0])) + execution.join() + self.assertEqual(want, got) + + +class ErrorTest(LocalComputationTest): + + def setUp(self): + self.f32_scalar_2 = NumpyArrayF32(2.0) + self.s32_scalar_2 = NumpyArrayS32(2) + + def testInvokeWithWrongElementType(self): + c = self._NewComputation() + c.SetOpMetadata(xla_client.CurrentSourceInfoMetadata()) + c.ParameterFromNumpy(self.s32_scalar_2) + c.ClearOpMetadata() + self.assertRaisesRegexp( + RuntimeError, r"Invalid argument shape.*xla_client_test.py.*" + r"expected s32\[\], got f32\[\]", + lambda: c.Build().CompileWithExampleArguments([self.f32_scalar_2])) + + +if __name__ == "__main__": + unittest.main() diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index 5bb81b80dde4c6d9324d33ddd5d6b6d6ad3cc1ac..a9acdae380af5b7f9efb3d08302fc717108f5e40 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -195,14 +195,26 @@ ReferenceUtil::ReduceWindow1DGeneric( const tensorflow::gtl::ArraySlice& window, const tensorflow::gtl::ArraySlice& stride, Padding padding) { std::vector 
dim_lengths{static_cast(operand.size())}; - auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); + return ReduceWindow1DGeneric( + operand, init, reduce_func, window, stride, + xla::MakePadding(dim_lengths, window, stride, padding)); +} +/* static */ std::unique_ptr> +ReferenceUtil::ReduceWindow1DGeneric( + const tensorflow::gtl::ArraySlice& operand, float init, + const std::function& reduce_func, + const tensorflow::gtl::ArraySlice& window, + const tensorflow::gtl::ArraySlice& stride, + const tensorflow::gtl::ArraySlice>& padding) { + std::vector dim_lengths{static_cast(operand.size())}; std::vector window_counts(window.size(), 0); std::vector pad_low(window.size(), 0); for (int64 i = 0; i < window.size(); ++i) { + int64 padded_width = padding[i].first + dim_lengths[i] + padding[i].second; window_counts[i] = - WindowCount(dim_lengths[i], window[i], stride[i], padding); - pad_low[i] = padding_both[i].first; + window_util::StridedBound(padded_width, window[i], stride[i]); + pad_low[i] = padding[i].first; } auto result = MakeUnique>(window_counts[0]); @@ -269,6 +281,51 @@ ReferenceUtil::ReduceWindow1DAdd( return result; } +/* static */ std::unique_ptr> ReferenceUtil::ReduceWindow3DAdd( + const Array3D& operand, float init, + const tensorflow::gtl::ArraySlice& window, + const tensorflow::gtl::ArraySlice& stride, Padding padding) { + std::vector dim_lengths{operand.n1(), operand.n2(), operand.n3()}; + auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding); + + std::vector window_counts(window.size(), 0); + std::vector pad_low(window.size(), 0); + for (int64 i = 0; i < window.size(); ++i) { + window_counts[i] = + WindowCount(dim_lengths[i], window[i], stride[i], padding); + pad_low[i] = padding_both[i].first; + } + auto result = MakeUnique>(window_counts[0], window_counts[1], + window_counts[2]); + + for (int64 i0 = 0; i0 < window_counts[0]; ++i0) { + for (int64 i1 = 0; i1 < window_counts[1]; ++i1) { + for (int64 i2 = 0; i2 < 
window_counts[2]; ++i2) { + int64 i0_base = i0 * stride[0] - pad_low[0]; + int64 i1_base = i1 * stride[1] - pad_low[1]; + int64 i2_base = i2 * stride[2] - pad_low[2]; + + float val = init; + for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) { + for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) { + for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) { + if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 && + i2_base + i2_win >= 0 && i0_base + i0_win < operand.n1() && + i1_base + i1_win < operand.n2() && + i2_base + i2_win < operand.n3()) { + val += operand(i0_base + i0_win, i1_base + i1_win, + i2_base + i2_win); + } + } + } + } + (*result)(i0, i1, i2) = val; + } + } + } + return result; +} + /* static */ std::unique_ptr> ReferenceUtil::ReduceWindow4DGeneric( const Array4D& operand, float init, @@ -520,7 +577,7 @@ ReferenceUtil::ConvArray4DGeneralDimensionsDilated( HloEvaluator evaluator; std::unique_ptr result_literal = - evaluator.Evaluate(*computation, {}).ConsumeValueOrDie(); + evaluator.Evaluate(*computation, {}).ConsumeValueOrDie(); CHECK_EQ(ShapeUtil::Rank(result_literal->shape()), 4); auto result = @@ -594,8 +651,12 @@ ReferenceUtil::ReduceToRowArray2D( i2 == 0 || (dim_set.count(2) && i2 < array.n3()); ++i2) { for (int64 i3 = 0; i3 == 0 || (dim_set.count(3) && i3 < array.n4()); ++i3) { - accumulator = reduce_function( - accumulator, array(a0 + i0, a1 + i1, a2 + i2, a3 + i3)); + // Handle zero-sized arrays. + if (array.n1() > 0 && array.n2() > 0 && array.n3() > 0 && + array.n4() > 0) { + accumulator = reduce_function( + accumulator, array(a0 + i0, a1 + i1, a2 + i2, a3 + i3)); + } } } } diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h index 62d455d71a70407e903a1e0be42a7e9f1898e523..3ec96f2f38b8f91e1549419b60481327fa9bbd5f 100644 --- a/tensorflow/compiler/xla/reference_util.h +++ b/tensorflow/compiler/xla/reference_util.h @@ -70,7 +70,7 @@ class ReferenceUtil { // dilation factors. 
static std::unique_ptr> ConvArray4DGeneralDimensionsDilated( const Array4D& lhs, const Array4D& rhs, - std::pair stride, Padding padding, + std::pair kernel_stride, Padding padding, std::pair lhs_dilation, std::pair rhs_dilation, ConvolutionDimensionNumbers dnums); @@ -173,6 +173,10 @@ class ReferenceUtil { const Array2D& operand, float init, const tensorflow::gtl::ArraySlice& window, const tensorflow::gtl::ArraySlice& stride, Padding padding); + static std::unique_ptr> ReduceWindow3DAdd( + const Array3D& operand, float init, + const tensorflow::gtl::ArraySlice& window, + const tensorflow::gtl::ArraySlice& stride, Padding padding); static std::unique_ptr> ReduceWindow4DAdd( const Array4D& operand, float init, const tensorflow::gtl::ArraySlice& window, @@ -184,11 +188,18 @@ class ReferenceUtil { const std::function& reduce_func, const tensorflow::gtl::ArraySlice& window, const tensorflow::gtl::ArraySlice& stride, Padding padding); + static std::unique_ptr> ReduceWindow1DGeneric( + const tensorflow::gtl::ArraySlice& operand, float init, + const std::function& reduce_func, + const tensorflow::gtl::ArraySlice& window, + const tensorflow::gtl::ArraySlice& stride, + const tensorflow::gtl::ArraySlice>& padding); static std::unique_ptr> ReduceWindow4DGeneric( const Array4D& operand, float init, const std::function& reduce_func, const tensorflow::gtl::ArraySlice& window, const tensorflow::gtl::ArraySlice& stride, Padding padding); + // With arbitrary padding. 
static std::unique_ptr> ReduceWindow4DGeneric( const Array4D& operand, float init, const std::function& reduce_func, diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc index 846ccdc83df900e3afedb6ababe07ebb1bd68f41..9da9bc60a2025e63b57a3be9ed360d150f88d73c 100644 --- a/tensorflow/compiler/xla/reference_util_test.cc +++ b/tensorflow/compiler/xla/reference_util_test.cc @@ -86,6 +86,13 @@ TEST_F(ReferenceUtilTest, ReduceToRowArray2D) { ErrorSpec(0.0001)); } +TEST_F(ReferenceUtilTest, Reduce4Dto1DZeroSizedArray) { + auto result = Literal::CreateR1(ReferenceUtil::Reduce4DTo1D( + Array4D(1, 0, 1, 1), /*init=*/0, /*dims=*/{0, 1, 2}, + [](float a, float b) { return a + b; })); + LiteralTestUtil::ExpectR1Equal({0}, *result); +} + TEST_F(ReferenceUtilTest, MapArray2D) { auto identity = [](float value) { return log(exp(value)); }; auto result = ReferenceUtil::MapArray2D(*matrix_, identity); diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index d3175c1e49974b060cc495d463d4995c925abcf7..83c67ed9368bc617a90c528f200b566ee8754edd 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -29,6 +29,11 @@ xla_proto_library( deps = ["//tensorflow/compiler/xla:xla_data_proto"], ) +xla_proto_library( + name = "hlo_profile_printer_data", + srcs = ["hlo_profile_printer_data.proto"], +) + # Filegroup used to collect source files for dependency checking. 
filegroup( name = "c_srcs", @@ -38,6 +43,81 @@ filegroup( ]), ) +cc_library( + name = "bfloat16_support", + srcs = ["bfloat16_support.cc"], + hdrs = ["bfloat16_support.h"], + deps = [ + ":hlo", + ], +) + +cc_library( + name = "bfloat16_conversion_folding", + srcs = ["bfloat16_conversion_folding.cc"], + hdrs = ["bfloat16_conversion_folding.h"], + deps = [ + ":bfloat16_support", + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "bfloat16_conversion_folding_test", + srcs = ["bfloat16_conversion_folding_test.cc"], + deps = [ + ":bfloat16_conversion_folding", + ":bfloat16_support", + ":hlo", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + ], +) + +cc_library( + name = "bfloat16_normalization", + srcs = ["bfloat16_normalization.cc"], + hdrs = ["bfloat16_normalization.h"], + deps = [ + ":bfloat16_support", + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "bfloat16_normalization_test", + srcs = ["bfloat16_normalization_test.cc"], + deps = [ + ":bfloat16_normalization", + ":bfloat16_support", + ":hlo", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + ], +) + 
cc_library( name = "shape_inference", srcs = ["shape_inference.cc"], @@ -108,6 +188,7 @@ tf_cc_test( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:reference_util", "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:test", @@ -115,6 +196,7 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/service:hlo_element_type_converter", "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep @@ -450,8 +532,10 @@ cc_library( ":hlo_evaluator", ":hlo_execution_profile", ":hlo_module_config", + ":hlo_proto_util", ":platform_util", ":session_proto", + ":source_map_util", ":transfer_manager", ":user_computation", ":versioned_computation_handle", @@ -500,6 +584,7 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:executable_build_options", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", ], @@ -903,6 +988,7 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", ], ) @@ -1009,9 +1095,9 @@ tf_cc_test( ) cc_library( - name = "batchnorm_rewriter", - srcs = ["batchnorm_rewriter.cc"], - hdrs = ["batchnorm_rewriter.h"], + name = "batchnorm_expander", + srcs = ["batchnorm_expander.cc"], + hdrs = ["batchnorm_expander.h"], deps = [ ":hlo", ":hlo_pass", @@ -1029,11 +1115,11 @@ cc_library( ) tf_cc_test( - name = "batchnorm_rewriter_test", + name = "batchnorm_expander_test", size = "small", - 
srcs = ["batchnorm_rewriter_test.cc"], + srcs = ["batchnorm_expander_test.cc"], deps = [ - ":batchnorm_rewriter", + ":batchnorm_expander", ":hlo", ":hlo_matchers", ":hlo_pass", @@ -1082,6 +1168,7 @@ tf_cc_test( "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep @@ -1143,6 +1230,49 @@ tf_cc_test( ], ) +cc_library( + name = "implicit_broadcast_remover", + srcs = ["implicit_broadcast_remover.cc"], + hdrs = ["implicit_broadcast_remover.h"], + deps = [ + ":hlo", + ":hlo_dce", + ":hlo_pass", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "implicit_broadcast_remover_test", + srcs = ["implicit_broadcast_remover_test.cc"], + deps = [ + ":hlo_matchers", + ":implicit_broadcast_remover", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + ], +) + +cc_library( + name = "dot_decomposer", + srcs = ["dot_decomposer.cc"], + hdrs = ["dot_decomposer.h"], + deps = [ + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:lib", + ], +) + cc_library( name = "tuple_simplifier", srcs = ["tuple_simplifier.cc"], @@ -1663,6 +1793,7 @@ tf_cc_test( ":hlo", ":hlo_graph_dumper", ":hlo_matchers", + ":hlo_runner", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:test", @@ -1670,7 +1801,6 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", 
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:test", ], ) @@ -1703,6 +1833,22 @@ cc_library( ], ) +tf_cc_test( + name = "hlo_verifier_test", + srcs = ["hlo_verifier_test.cc"], + deps = [ + ":hlo", + ":hlo_verifier", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + cc_library( name = "hlo_rematerialization", srcs = ["hlo_rematerialization.cc"], @@ -1781,7 +1927,9 @@ tf_cc_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", + "//tensorflow/core:test", ], ) @@ -1812,6 +1960,7 @@ cc_library( ":hlo", ":hlo_graph_dumper", ":hlo_pass", + ":hlo_proto_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", @@ -1889,6 +2038,32 @@ tf_cc_test( ], ) +cc_library( + name = "hlo_element_type_converter", + srcs = ["hlo_element_type_converter.cc"], + hdrs = ["hlo_element_type_converter.h"], + deps = [ + ":hlo", + ":hlo_evaluator", + ":hlo_pass", + ":hlo_query", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "hlo_element_type_converter_test", + srcs = ["hlo_element_type_converter_test.cc"], + deps = [ + ":hlo_element_type_converter", + ":hlo_matchers", + "//tensorflow/compiler/xla/tests:hlo_test_base", + ], +) + cc_library( name = "device_memory_allocator", srcs = ["device_memory_allocator.cc"], @@ -2021,6 +2196,7 @@ 
cc_library( "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla:xla_proto", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core:regexp_internal", ], alwayslink = 1, @@ -2074,6 +2250,41 @@ tf_cc_test( ], ) +cc_library( + name = "zero_sized_hlo_elimination", + srcs = ["zero_sized_hlo_elimination.cc"], + hdrs = ["zero_sized_hlo_elimination.h"], + deps = [ + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "zero_sized_hlo_elimination_test", + srcs = ["zero_sized_hlo_elimination_test.cc"], + deps = [ + ":hlo", + ":shape_inference", + ":zero_sized_hlo_elimination", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + ], +) + cc_library( name = "pool", hdrs = ["pool.h"], @@ -2165,11 +2376,96 @@ cc_library( srcs = ["hlo_profile_printer.cc"], hdrs = ["hlo_profile_printer.h"], deps = [ + ":hlo_profile_printer_data", ":human_readable_profile_builder", "//tensorflow/compiler/xla:types", ], ) +cc_library( + name = "tuple_util", + srcs = ["tuple_util.cc"], + hdrs = ["tuple_util.h"], + deps = [ + ":hlo", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "tuple_util_test", + srcs = ["tuple_util_test.cc"], + deps = [ + ":tuple_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", + ], 
+) + +cc_library( + name = "while_util", + srcs = ["while_util.cc"], + hdrs = ["while_util.h"], + deps = [ + ":call_inliner", + ":hlo", + ":tuple_util", + ], +) + +tf_cc_test( + name = "while_util_test", + srcs = ["while_util_test.cc"], + deps = [ + ":while_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/service:hlo_matchers", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", + ], +) + +cc_library( + name = "while_loop_invariant_code_motion", + srcs = ["while_loop_invariant_code_motion.cc"], + hdrs = ["while_loop_invariant_code_motion.h"], + deps = [ + ":hlo", + ":hlo_pass", + ":tuple_util", + ":while_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "while_loop_invariant_code_motion_test", + srcs = ["while_loop_invariant_code_motion_test.cc"], + deps = [ + ":hlo_matchers", + ":while_loop_invariant_code_motion", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/core:test", + ], +) + +cc_library( + name = "source_map_util", + srcs = ["source_map_util.cc"], + hdrs = ["source_map_util.h"], + deps = [ + ":executable", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index 71491218aa221cb26ea45f288ddc47173a15df3f..fb857559f972a220a19b108baa4c441e09b90e1f 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -193,6 +193,33 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { enable_dot_strength_reduction_(enable_dot_strength_reduction), 
enable_conv_simplification_(enable_conv_simplification) {} + // Transforms Dots where at least one input is a vector or has a degenerate + // dimension and converts it into a multiply and reduce. This should enable + // more fusion than leaving the nodes as Dot operations. + StatusOr HandleDotStrengthReduction(HloInstruction* dot); + + // Reshapes an instruction to rank 1 if it is not already rank 1. + HloInstruction* Flatten(HloInstruction* hlo) { + if (ShapeUtil::Rank(hlo->shape()) == 1) { + return hlo; + } + return computation_->AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(hlo->shape().element_type(), + {ShapeUtil::ElementsIn(hlo->shape())}), + hlo)); + } + + // Helper method to perform and add reduction in a single dimension. + HloInstruction* AddReduce(HloInstruction* hlo, int64 dim) { + HloInstruction* zero = computation_->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + HloComputation* AddReduce_computation = CreateScalarBinaryComputation( + computation_->parent(), F32, HloOpcode::kAdd); + Shape shape = ShapeUtil::DeleteDimension(dim, hlo->shape()); + return computation_->AddInstruction(HloInstruction::CreateReduce( + shape, hlo, zero, {dim}, AddReduce_computation)); + } + // Convenience method for replacing an instruction with a bitcast. void ReplaceWithBitcast(HloInstruction* instruction); @@ -252,6 +279,11 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + StatusOr OptimizeDotOfConcat(HloInstruction* dot); + StatusOr OptimizeDotOfConcatHelper( + const Shape& dot_shape, HloInstruction* lhs, int64 lhs_contracting_dim, + HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped); + // Current HloComputation instance the AlgebraicSimplifierVisitor is // traversing. HloComputation* computation_; @@ -329,6 +361,39 @@ Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { return Status::OK(); } + // Canonicalization: Put constants on the right. 
This makes the reassociation + // rules below simpler. + VLOG(10) << "trying transform [Const + A => A + Const]"; + if (lhs->IsConstant() && !rhs->IsConstant()) { + return ReplaceWithNewInstruction( + add, + HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, rhs, lhs)); + } + + // Reassociate to allow constant folding. + // + // Note: This is not general. For example, we won't reassociate + // + // (A + C1) + (B + C2) => A + B + (C1 + C2). + // + VLOG(10) << "trying transform [(A + C1) + C2 => A + (C1 + C2)]"; + if (rhs->IsConstant() && lhs->opcode() == HloOpcode::kAdd && + !lhs->operand(0)->IsConstant() && lhs->operand(1)->IsConstant()) { + auto* c1 = lhs->mutable_operand(1); + auto* c2 = rhs; + TF_ASSIGN_OR_RETURN( + Shape sum_of_constants_shape, + ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, c1, c2)); + + auto* sum_of_constants = + computation_->AddInstruction(HloInstruction::CreateBinary( + sum_of_constants_shape, HloOpcode::kAdd, c1, c2)); + return ReplaceWithNewInstruction( + add, HloInstruction::CreateBinary(add->shape(), HloOpcode::kAdd, + lhs->mutable_operand(0), + sum_of_constants)); + } + return Status::OK(); } @@ -433,13 +498,14 @@ static HloInstruction* BuildTupleConstant(HloComputation* computation, if (ShapeUtil::IsTuple(literal.shape())) { std::vector elems; elems.reserve(ShapeUtil::TupleElementCount(literal.shape())); - for (const Literal& child : literal.tuple_literals()) { - elems.push_back(BuildTupleConstant(computation, child)); + for (int i = 0; i < ShapeUtil::TupleElementCount(literal.shape()); ++i) { + elems.push_back( + BuildTupleConstant(computation, LiteralView::Create(literal, {i}))); } return computation->AddInstruction(HloInstruction::CreateTuple(elems)); } else { return computation->AddInstruction( - HloInstruction::CreateConstant(MakeUnique(literal))); + HloInstruction::CreateConstant(literal.CloneToUnique())); } } @@ -462,6 +528,16 @@ Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) { return 
Status::OK(); } + // Canonicalize subtraction of a constant to addition. + VLOG(10) << "trying transform [A - Const => A + (-Const)]"; + if (rhs->IsConstant() && !lhs->IsConstant()) { + HloInstruction* negative_const = computation_->AddInstruction( + HloInstruction::CreateUnary(rhs->shape(), HloOpcode::kNegate, rhs)); + return ReplaceWithNewInstruction( + sub, HloInstruction::CreateBinary(sub->shape(), HloOpcode::kAdd, lhs, + negative_const)); + } + return Status::OK(); } @@ -523,6 +599,23 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { return Status::OK(); } + // A / Const => A * (1 / Const) + // + // (Backends can do this transformation, but generally only if the constant is + // a scalar.) + if (lhs->opcode() != HloOpcode::kConstant && + rhs->opcode() == HloOpcode::kConstant) { + HloInstruction* one = + computation_->AddInstruction(HloInstruction::CreateConstant( + Literal::One(lhs->shape().element_type()).CloneToUnique())); + HloInstruction* inverse = + computation_->AddInstruction(HloInstruction::CreateBinary( + rhs->shape(), HloOpcode::kDivide, one, rhs)); + return ReplaceWithNewInstruction( + divide, HloInstruction::CreateBinary( + divide->shape(), HloOpcode::kMultiply, lhs, inverse)); + } + // (A / B) / (C / D) => (A / B)*(D / C) => (A * D) / (B * C) if (lhs->opcode() == HloOpcode::kDivide && rhs->opcode() == HloOpcode::kDivide) { @@ -574,70 +667,72 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { return Status::OK(); } -Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { - auto lhs = dot->mutable_operand(0); - auto rhs = dot->mutable_operand(1); +StatusOr AlgebraicSimplifierVisitor::HandleDotStrengthReduction( + HloInstruction* dot) { + HloInstruction* lhs = dot->mutable_operand(0); + HloInstruction* rhs = dot->mutable_operand(1); + int64 lhs_collapsing_dim = + dot->dot_dimension_numbers().lhs_contracting_dimensions(0); + if (lhs->IsRank2Transpose()) { + lhs = lhs->mutable_operand(0); 
+ lhs_collapsing_dim = 1 - lhs_collapsing_dim; + } + const int64 lhs_kept_dim = 1 - lhs_collapsing_dim; + + int64 rhs_collapsing_dim = + dot->dot_dimension_numbers().rhs_contracting_dimensions(0); + if (rhs->IsRank2Transpose()) { + rhs = rhs->mutable_operand(0); + rhs_collapsing_dim = 1 - rhs_collapsing_dim; + } + const int64 rhs_kept_dim = 1 - rhs_collapsing_dim; + + auto reshape_if_necessary = [&](HloInstruction* hlo) { + if (ShapeUtil::SameDimensions(hlo->shape(), dot->shape())) { + return hlo; + } + return computation_->AddInstruction( + HloInstruction::CreateReshape(dot->shape(), hlo)); + }; - // Only optimize F32 dot operations where the dot, rhs and lhs are rank 2 or - // below. - if (dot->shape().element_type() != F32 || ShapeUtil::Rank(lhs->shape()) > 2 || - ShapeUtil::Rank(rhs->shape()) > 2 || ShapeUtil::Rank(dot->shape()) > 2) { - return Status::OK(); - } + auto broadcast_to_dim = [&](HloInstruction* hlo, const Shape& shape, + int64 dim) { + return computation_->AddInstruction( + HloInstruction::CreateBroadcast(shape, hlo, {dim})); + }; - // Replace a zero element dot with a broadcast of the constant 0. - if (ShapeUtil::HasZeroElements(dot->shape()) || - ShapeUtil::HasZeroElements(lhs->shape()) || - ShapeUtil::HasZeroElements(rhs->shape())) { - auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - return ReplaceWithNewInstruction( - dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {})); - } + auto multiply = [&](HloInstruction* local_lhs, HloInstruction* local_rhs) { + return computation_->AddInstruction(HloInstruction::CreateBinary( + local_lhs->shape(), HloOpcode::kMultiply, local_lhs, local_rhs)); + }; - // Simplify dot(transpose(a), transpose(b)) to transpose(dot(b,a)). 
- if (lhs->IsRank2Transpose() && rhs->IsRank2Transpose()) { - auto new_dot = computation_->AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::PermuteDimensions({1, 0}, dot->shape()), HloOpcode::kDot, - rhs->mutable_operand(0), lhs->mutable_operand(0))); - return ReplaceWithNewInstruction( - dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0})); + // Strength reduce dot(a[K] , b[K]) = + // reshape(result.shape, + // reduce_sum(multiply(a, b), {0})) + if (ShapeUtil::Rank(rhs->shape()) == 1 && + ShapeUtil::Rank(lhs->shape()) == 1) { + TF_RETURN_IF_ERROR( + ReplaceInstruction(dot, reshape_if_necessary(AddReduce( + multiply(Flatten(lhs), Flatten(rhs)), 0)))); + return true; } - if (!enable_dot_strength_reduction_) { - return Status::OK(); + if (ShapeUtil::IsEffectiveScalar(rhs->shape()) && + ShapeUtil::IsEffectiveScalar(lhs->shape())) { + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, reshape_if_necessary(multiply(Flatten(lhs), Flatten(rhs))))); + return true; } // Simplify outer product into multiply with implicit broadcasting. // // A dot(a[M, 1], b[1, N]) = multiply(a [M,1], b [1, N]) - if (ShapeUtil::Rank(rhs->shape()) == 2 && rhs->shape().dimensions(0) == 1) { - return ReplaceWithNewInstruction( - dot, HloInstruction::CreateBinary(dot->shape(), HloOpcode::kMultiply, - lhs, rhs)); - } - - // The following graph transformations take Dots where at least one input is a - // vector or has a degenerate dimension and converts it into a multiply and - // reduce. This should enable more fusion than leaving the nodes as Dot - // operations. 
- - // Strength reduce dot(a[K] , b[K]) = - // reshape(result.shape, - // reduce_sum(multiply(a, b), {0})) - if (ShapeUtil::Rank(rhs->shape()) == 1 && - ShapeUtil::Rank(lhs->shape()) == 1) { - auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary( - rhs->shape(), HloOpcode::kMultiply, lhs, rhs)); - HloComputation* add_reduce_computation = CreateScalarBinaryComputation( - computation_->parent(), F32, HloOpcode::kAdd); - auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - auto reduce = computation_->AddInstruction(HloInstruction::CreateReduce( - ShapeUtil::MakeShape(dot->shape().element_type(), {}), multiply, zero, - {0}, add_reduce_computation)); - return ReplaceWithNewInstruction( - dot, HloInstruction::CreateReshape(dot->shape(), reduce)); + if (ShapeUtil::Rank(rhs->shape()) == 2 && + rhs->shape().dimensions(rhs_collapsing_dim) == 1) { + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, multiply(broadcast_to_dim(Flatten(lhs), dot->shape(), 0), + broadcast_to_dim(Flatten(rhs), dot->shape(), 1)))); + return true; } // Strength reduce dot(a[1, K], b) = @@ -648,35 +743,21 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { // ) // ) if (ShapeUtil::Rank(lhs->shape()) == 1 || - (ShapeUtil::Rank(lhs->shape()) == 2 && lhs->shape().dimensions(0) == 1)) { - auto new_lhs = computation_->AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShape(lhs->shape().element_type(), - {ShapeUtil::ElementsIn(lhs->shape())}), - lhs)); - HloComputation* add_reduce_computation = CreateScalarBinaryComputation( - computation_->parent(), F32, HloOpcode::kAdd); - auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - HloInstruction* reduce; + (ShapeUtil::Rank(lhs->shape()) == 2 && + lhs->shape().dimensions(lhs_kept_dim) == 1)) { if (ShapeUtil::Rank(rhs->shape()) == 1) { - auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary( - 
rhs->shape(), HloOpcode::kMultiply, new_lhs, rhs)); - reduce = computation_->AddInstruction(HloInstruction::CreateReduce( - ShapeUtil::MakeShape(dot->shape().element_type(), {}), multiply, zero, - {0}, add_reduce_computation)); - } else { - new_lhs = computation_->AddInstruction( - HloInstruction::CreateBroadcast(rhs->shape(), new_lhs, {0})); - auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary( - rhs->shape(), HloOpcode::kMultiply, new_lhs, rhs)); - - reduce = computation_->AddInstruction(HloInstruction::CreateReduce( - ShapeUtil::MakeShape(dot->shape().element_type(), - {rhs->shape().dimensions(1)}), - multiply, zero, {0}, add_reduce_computation)); + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, + reshape_if_necessary(AddReduce(multiply(Flatten(lhs), rhs), 0)))); + return true; } - return ReplaceWithNewInstruction( - dot, HloInstruction::CreateReshape(dot->shape(), reduce)); + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, reshape_if_necessary( + AddReduce(multiply(broadcast_to_dim(Flatten(lhs), rhs->shape(), + rhs_collapsing_dim), + rhs), + rhs_collapsing_dim)))); + return true; } // Strength reduce dot(a, b[K, 1]) = @@ -684,26 +765,208 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { // reduce_sum(multiply(a, broadcast(reshape([K],b), {1})), {0}) // ) if (ShapeUtil::Rank(rhs->shape()) == 1 || - (ShapeUtil::Rank(rhs->shape()) == 2 && rhs->shape().dimensions(1) == 1)) { - auto new_rhs = computation_->AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShape(rhs->shape().element_type(), - {ShapeUtil::ElementsIn(rhs->shape())}), - rhs)); - new_rhs = computation_->AddInstruction( - HloInstruction::CreateBroadcast(lhs->shape(), new_rhs, {1})); - auto multiply = computation_->AddInstruction(HloInstruction::CreateBinary( - lhs->shape(), HloOpcode::kMultiply, lhs, new_rhs)); - HloComputation* add_reduce_computation = CreateScalarBinaryComputation( - computation_->parent(), F32, HloOpcode::kAdd); + 
(ShapeUtil::Rank(rhs->shape()) == 2 && + rhs->shape().dimensions(rhs_kept_dim) == 1)) { + TF_RETURN_IF_ERROR(ReplaceInstruction( + dot, reshape_if_necessary(AddReduce( + multiply(lhs, broadcast_to_dim(Flatten(rhs), lhs->shape(), + lhs_collapsing_dim)), + lhs_collapsing_dim)))); + return true; + } + return false; +} + +StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfConcat( + HloInstruction* dot) { + const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); + if (dnums.lhs_contracting_dimensions_size() != 1 || + dnums.lhs_batch_dimensions_size() != 0) { + return nullptr; + } + + const int64 lhs_contracting_dim = dnums.lhs_contracting_dimensions(0); + const int64 rhs_contracting_dim = dnums.rhs_contracting_dimensions(0); + HloInstruction* lhs = dot->mutable_operand(0); + HloInstruction* rhs = dot->mutable_operand(1); + + TF_ASSIGN_OR_RETURN( + HloInstruction * optimized_lhs_concat, + OptimizeDotOfConcatHelper(dot->shape(), lhs, lhs_contracting_dim, rhs, + rhs_contracting_dim, /*swapped=*/false)); + if (optimized_lhs_concat) { + return optimized_lhs_concat; + } + + return OptimizeDotOfConcatHelper(dot->shape(), rhs, rhs_contracting_dim, lhs, + lhs_contracting_dim, /*swapped=*/true); +} + +StatusOr AlgebraicSimplifierVisitor::OptimizeDotOfConcatHelper( + const Shape& dot_shape, HloInstruction* lhs, int64 lhs_contracting_dim, + HloInstruction* rhs, int64 rhs_contracting_dim, bool swapped) { + bool can_optimize = lhs->opcode() == HloOpcode::kConcatenate && + lhs->concatenate_dimension() == lhs_contracting_dim && + rhs->opcode() == HloOpcode::kConstant; + if (!can_optimize) { + return nullptr; + } + + // We're replacing this: + // + // +-----+-----+-----+ +-------------------+ + // | | | | | | + // | | | | | R_0 | + // | | | | | | + // | | | | +-------------------+ + // | | | | | | + // | L_0 | L_1 | L_2 | * | R_1 | + // | | | | | | + // | | | | +-------------------+ + // | | | | | | + // | | | | | R_2 | + // | | | | | | + // +-----+-----+-----+ 
+-------------------+ + // + // with this: + // + // [Sum over i] + // + // +-----+ +-------------------+ + // | | | | + // | | * | R_i | + // | | | | + // | | +-------------------+ + // | | + // | L_i | + // | | + // | | + // | | + // | | + // | | + // +-----+ + // + // where the LHS is a concatenate operation (so we can "split" the LHS tensor + // for free) and the RHS is a constant tensor (and thus can be split at + // compile time). In the future, we may also want to do this when both the + // LHS and the RHS are concatenate operations that line up along the dimension + // being contracted over. + // + // We should be able to generalize this transform to work on a non-constant + // RHS when/if we have in-place slices or support input-fusing slices into + // Dots. + + // Dimension numbers for the new dot instructions we'll create (L_i * R_i in + // the diagram above). + DotDimensionNumbers new_dot_dnums; + new_dot_dnums.add_lhs_contracting_dimensions(swapped ? rhs_contracting_dim + : lhs_contracting_dim); + new_dot_dnums.add_rhs_contracting_dimensions(swapped ? lhs_contracting_dim + : rhs_contracting_dim); + + // Here we use the MKN notation, where the contracted dimension has K + // elements and the two non-contracted dimensions have M and N elements. 
+ HloInstruction* add_result = nullptr; + int64 rhs_contracting_dim_offset = 0; + int64 n = rhs->shape().dimensions(1 - rhs_contracting_dim); + for (HloInstruction* concat_op : lhs->operands()) { + int64 sub_k = concat_op->shape().dimensions(lhs_contracting_dim); + Shape rhs_slice_shape(rhs->shape()); + rhs_slice_shape.set_dimensions(rhs_contracting_dim, sub_k); + + std::array start_indices; + start_indices[rhs_contracting_dim] = rhs_contracting_dim_offset; + start_indices[1 - rhs_contracting_dim] = 0; + + std::array limit_indices; + limit_indices[rhs_contracting_dim] = rhs_contracting_dim_offset + sub_k; + limit_indices[1 - rhs_contracting_dim] = n; + + HloInstruction* rhs_slice = + computation_->AddInstruction(HloInstruction::CreateSlice( + rhs_slice_shape, rhs, /*start_indices=*/start_indices, + /*limit_indices=*/limit_indices, /*strides=*/{1, 1})); + + // TODO(b/69062148): We can get rid of `swapped` once all backends support + // "non-canonical" contraction dimensions (that contracts dimension 1 of the + // LHS with dimension 0 of the RHS). But for now we keep the same + // contraction dimensions as the incoming dot operation to ensure the new + // dot operations can be lowered. 
+ HloInstruction *new_dot_lhs, *new_dot_rhs; + if (swapped) { + new_dot_lhs = rhs_slice; + new_dot_rhs = concat_op; + } else { + new_dot_lhs = concat_op; + new_dot_rhs = rhs_slice; + } + + auto* new_dot = computation_->AddInstruction(HloInstruction::CreateDot( + dot_shape, new_dot_lhs, new_dot_rhs, new_dot_dnums)); + + if (add_result) { + add_result = computation_->AddInstruction(HloInstruction::CreateBinary( + dot_shape, HloOpcode::kAdd, add_result, new_dot)); + } else { + add_result = new_dot; + } + + rhs_contracting_dim_offset += sub_k; + } + + return add_result; +} + +Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { + auto lhs = dot->mutable_operand(0); + auto rhs = dot->mutable_operand(1); + + // Only optimize F32 dot operations where the dot, rhs and lhs are rank 2 or + // below. + if (dot->shape().element_type() != F32 || ShapeUtil::Rank(lhs->shape()) > 2 || + ShapeUtil::Rank(rhs->shape()) > 2 || ShapeUtil::Rank(dot->shape()) > 2) { + return Status::OK(); + } + + // Replace a zero element dot with a broadcast of the constant 0. 
+ if (ShapeUtil::HasZeroElements(dot->shape()) || + ShapeUtil::HasZeroElements(lhs->shape()) || + ShapeUtil::HasZeroElements(rhs->shape())) { auto zero = computation_->AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); - auto reduce = computation_->AddInstruction(HloInstruction::CreateReduce( - ShapeUtil::MakeShape(dot->shape().element_type(), - {lhs->shape().dimensions(0)}), - multiply, zero, {1}, add_reduce_computation)); return ReplaceWithNewInstruction( - dot, HloInstruction::CreateReshape(dot->shape(), reduce)); + dot, HloInstruction::CreateBroadcast(dot->shape(), zero, {})); + } + + TF_ASSIGN_OR_RETURN(HloInstruction * dot_of_concat_optimized, + OptimizeDotOfConcat(dot)); + if (dot_of_concat_optimized) { + VLOG(10) << "Replaced dot(concat(...), constant) with add(dot(..., " + "constant)...)"; + return ReplaceInstruction(dot, dot_of_concat_optimized); + } + + if (enable_dot_strength_reduction_ && !is_layout_sensitive_) { + TF_ASSIGN_OR_RETURN(bool did_strength_reduction, + HandleDotStrengthReduction(dot)); + if (did_strength_reduction) { + return Status::OK(); + } + } + + // Simplify dot(transpose(a), transpose(b)) to transpose(dot(b,a)). 
+ if (lhs->IsRank2Transpose() && rhs->IsRank2Transpose()) { + DotDimensionNumbers dot_dimension_numbers; + dot_dimension_numbers.add_lhs_contracting_dimensions(1); + dot_dimension_numbers.add_rhs_contracting_dimensions(0); + auto new_dot = computation_->AddInstruction(HloInstruction::CreateDot( + ShapeUtil::PermuteDimensions({1, 0}, dot->shape()), + rhs->mutable_operand(0), lhs->mutable_operand(0), + dot_dimension_numbers)); + return ReplaceWithNewInstruction( + dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0})); } + return Status::OK(); } @@ -980,6 +1243,11 @@ Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag) { } Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { + if (ShapeUtil::HasZeroElements(pad->operand(0)->shape())) { + return ReplaceWithNewInstruction( + pad, HloInstruction::CreateBroadcast(pad->shape(), + pad->mutable_operand(1), {})); + } // Eliminate nop pads (padding all zero), and replace a pad with negative // padding with a pad with non-negative padding followed by a slice. bool all_zero = true; @@ -1120,6 +1388,27 @@ Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kDivide, broadcast_one, lhs)); } + + VLOG(10) << "trying transform [pow(pow(A, X), Y) => pow(A, X*Y)]: " + << power->ToString(); + + // Don't perform this optimization if either of the exponents is complex; this + // identity is true only for real-valued exponents. In addition, we cowardly + // refuse to do this transformation if the two exponents have different + // element types. 
+ if (lhs->opcode() == HloOpcode::kPower && + !ShapeUtil::ElementIsComplex(lhs->operand(1)->shape()) && + !ShapeUtil::ElementIsComplex(rhs->shape()) && + ShapeUtil::SameElementType(lhs->operand(1)->shape(), rhs->shape())) { + auto exponent_product = + computation_->AddInstruction(HloInstruction::CreateBinary( + rhs->shape(), HloOpcode::kMultiply, lhs->mutable_operand(1), rhs)); + return ReplaceWithNewInstruction( + power, HloInstruction::CreateBinary(power->shape(), HloOpcode::kPower, + lhs->mutable_operand(0), + exponent_product)); + } + return Status::OK(); } @@ -1173,7 +1462,7 @@ StatusOr AlgebraicSimplifierVisitor:: ShapeUtil::MakeShapeWithLayout( user->shape().element_type(), AsInt64Slice(operand->shape().dimensions()), - AsInt64Slice(operand->shape().layout().minor_to_major())), + LayoutUtil::MinorToMajor(operand->shape())), new_user_operands)); VLOG(4) << " new user: " << new_user->ToString(); HloInstruction* new_reshape_or_broadcast = nullptr; @@ -1183,8 +1472,7 @@ StatusOr AlgebraicSimplifierVisitor:: ShapeUtil::MakeShapeWithLayout( user->shape().element_type(), AsInt64Slice(reshape_or_broadcast->shape().dimensions()), - AsInt64Slice( - reshape_or_broadcast->shape().layout().minor_to_major())), + LayoutUtil::MinorToMajor(reshape_or_broadcast->shape())), new_user)); } else { TF_RET_CHECK(reshape_or_broadcast->opcode() == HloOpcode::kBroadcast); @@ -1193,8 +1481,7 @@ StatusOr AlgebraicSimplifierVisitor:: ShapeUtil::MakeShapeWithLayout( user->shape().element_type(), AsInt64Slice(reshape_or_broadcast->shape().dimensions()), - AsInt64Slice( - reshape_or_broadcast->shape().layout().minor_to_major())), + LayoutUtil::MinorToMajor(reshape_or_broadcast->shape())), new_user, reshape_or_broadcast->dimensions())); } VLOG(4) << " new reshape/broadcast: " @@ -1331,9 +1618,12 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { reduce, HloInstruction::CreateBroadcast(reduce->shape(), init_value, {})); } + // A Transpose feeding a reduce can simply 
permute the reduction dimensions - // field. - if (arg->opcode() == HloOpcode::kTranspose) { + // field if the output of the reduce is a vector or scalar. Higher ranked + // result may require a transpose of the output. + if (ShapeUtil::Rank(reduce->shape()) <= 1 && + arg->opcode() == HloOpcode::kTranspose) { auto transpose_dimensions = arg->dimensions(); std::vector new_reduce_dimensions; for (auto dim : dimensions) { @@ -1403,6 +1693,12 @@ Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { Status AlgebraicSimplifierVisitor::HandleReduceWindow( HloInstruction* reduce_window) { + if (ShapeUtil::HasZeroElements(reduce_window->operand(0)->shape())) { + return ReplaceWithNewInstruction( + reduce_window, + HloInstruction::CreateBroadcast(reduce_window->shape(), + reduce_window->mutable_operand(1), {})); + } auto operand = reduce_window->mutable_operand(0); const Window& window = reduce_window->window(); auto function = reduce_window->to_apply(); @@ -1448,6 +1744,63 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( } } + // If the pad puts a single non-identity value in each window that we're + // reducing, then this is a broadcast. 
+ HloInstruction* pad_operand = operand->mutable_operand(0); + auto is_effective_broadcast = [&] { + if (window_util::HasStride(window)) { + VLOG(10) << "Window has stride."; + return false; + } + if (!window_util::HasSymmetricPadding(pad_config)) { + VLOG(10) << "Window has uneven padding."; + return false; + } + for (int64 i = 0; i < pad_config.dimensions_size(); ++i) { + const auto& pad_dimension = pad_config.dimensions(i); + if ((pad_dimension.edge_padding_low() != 0 || + pad_dimension.edge_padding_high() != 0) && + pad_operand->shape().dimensions(i) != 1) { + VLOG(10) << "Found non-trivial dimension being padded: " << i; + return false; + } + } + VLOG(10) << "Found to be padding trivial dimensions only."; + + for (int64 i = 0; i < window.dimensions_size(); ++i) { + const auto& pad_dimension = pad_config.dimensions(i); + const WindowDimension& window_dimension = window.dimensions(i); + bool dimension_has_padding = (pad_dimension.edge_padding_low() != 0 || + pad_dimension.edge_padding_high() != 0); + if (dimension_has_padding && + window_dimension.size() < pad_dimension.edge_padding_low() + 1) { + VLOG(10) << "Found window did not cover single unpadded element in " + "dimension: " + << i; + return false; + } + if (pad_operand->shape().dimensions(i) != 1 && + window_dimension.size() != 1) { + VLOG(10) << "Found window covers more than one element in non-trivial " + "dimension: " + << i; + return false; + } + } + VLOG(10) << "Found window covers a single unpadded element."; + return true; + }; + if (is_effective_broadcast()) { + VLOG(10) << "Replacing pad/reduce-window with (implicit) broadcast."; + auto fadd = [this](std::unique_ptr x) { + return computation_->AddInstruction(std::move(x)); + }; + return ReplaceWithNewInstruction( + reduce_window, HloInstruction::CreateBroadcastSequence( + /*output_shape=*/reduce_window->shape(), + /*operand=*/pad_operand, fadd)); + } + // Carry out the folding of the pad into reduce_window. 
VLOG(10) << "Folding pad into reduce-window."; Window new_window = window; @@ -1465,7 +1818,7 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( return ReplaceWithNewInstruction( reduce_window, HloInstruction::CreateReduceWindow( /*shape=*/reduce_window->shape(), - /*operand=*/operand->mutable_operand(0), + /*operand=*/pad_operand, /*init_value=*/reduce_window->mutable_operand(1), /*window=*/new_window, /*reduce_computation=*/function)); @@ -1473,7 +1826,6 @@ Status AlgebraicSimplifierVisitor::HandleReduceWindow( Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { auto operand = transpose->mutable_operand(0); - if (std::is_sorted(transpose->dimensions().begin(), transpose->dimensions().end())) { VLOG(10) << "deleting no-op transpose"; @@ -1500,6 +1852,18 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( HloInstruction* convolution) { auto lhs = convolution->mutable_operand(0); auto rhs = convolution->mutable_operand(1); + if (ShapeUtil::HasZeroElements(lhs->shape()) || + ShapeUtil::HasZeroElements(rhs->shape())) { + return ReplaceWithNewInstruction( + convolution, + HloInstruction::CreateBroadcast( + convolution->shape(), + computation_->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::MakeShape(convolution->shape().element_type(), {}), + computation_->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))))), + {})); + } const auto& window = convolution->window(); if (!enable_conv_simplification_) { return Status::OK(); @@ -1556,15 +1920,15 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( // still convert Conv into more efficient Matmul with operand transposition // (such as the transposition flags in cuBLAS SGEMM). 
if (!LayoutUtil::Equal(input_shape.layout(), convolution_shape.layout()) || - input_shape.layout().minor_to_major(0) != + LayoutUtil::Minor(input_shape.layout(), 0) != dnums.input_feature_dimension() || - convolution_shape.layout().minor_to_major(0) != + LayoutUtil::Minor(convolution_shape.layout(), 0) != dnums.output_feature_dimension() || // The input feature dimension should come later in the minor-to-major // order. - (PositionInContainer(filter_shape.layout().minor_to_major(), + (PositionInContainer(LayoutUtil::MinorToMajor(filter_shape), dnums.kernel_input_feature_dimension()) < - PositionInContainer(filter_shape.layout().minor_to_major(), + PositionInContainer(LayoutUtil::MinorToMajor(filter_shape), dnums.kernel_output_feature_dimension()))) { return Status::OK(); } @@ -1592,18 +1956,15 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( // We already checked feature_dimension is most minor, so data in input_shape // and row-major {conv_width,input_channels} are bitwise identical. - const Shape new_input_shape = - ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( - input_shape.element_type(), {conv_width, input_channels}); + const Shape new_input_shape = ShapeUtil::MakeShapeWithDescendingLayout( + input_shape.element_type(), {conv_width, input_channels}); // We already checked input_feature_dimension is more major than // output_feature_dimension, so data in filter_shape and row-major // {input_channels,output_channels} are bitwise identical. 
- const Shape new_filter_shape = - ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( - filter_shape.element_type(), {input_channels, output_channels}); - const Shape dot_output_shape = - ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( - convolution_shape.element_type(), {conv_width, output_channels}); + const Shape new_filter_shape = ShapeUtil::MakeShapeWithDescendingLayout( + filter_shape.element_type(), {input_channels, output_channels}); + const Shape dot_output_shape = ShapeUtil::MakeShapeWithDescendingLayout( + convolution_shape.element_type(), {conv_width, output_channels}); // We cannot insert bitcasts if the layouts will not be compatible. // TODO(b/33178038): Consider inserting a transpose if a bitcast would be @@ -1616,8 +1977,11 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( auto new_lhs = add_bitcast(new_input_shape, lhs); auto new_rhs = add_bitcast(new_filter_shape, rhs); - auto dot = computation_->AddInstruction(HloInstruction::CreateBinary( - dot_output_shape, HloOpcode::kDot, new_lhs, new_rhs)); + DotDimensionNumbers dot_dimension_numbers; + dot_dimension_numbers.add_lhs_contracting_dimensions(1); + dot_dimension_numbers.add_rhs_contracting_dimensions(0); + auto dot = computation_->AddInstruction(HloInstruction::CreateDot( + dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers)); return ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot)); } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index 56dfb1cf0bc22ed62653d1f0772fdcae58498c27..0f08eb3a3267c4b7b04958270a5788fc48d3fa04 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -30,6 +30,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -60,17 +61,63 @@ TEST_F(AlgebraicSimplifierTest, AddZero) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, zero)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } +// Test that Const + A is canonicalized to A + Const. 
+TEST_F(AlgebraicSimplifierTest, AddConstOnLHS) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction* constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, constant, param0)); + + auto computation = module().AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kAdd); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Add(param0, op::Constant())); +} + +// Test that [(A + C1) + C2] => [A + (C1 + C2)] for constants C1 and C2. +TEST_F(AlgebraicSimplifierTest, AddReassociateMergeConstants) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction* constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + HloInstruction* constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(3.14159f))); + + HloInstruction* add1 = builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, param0, constant1)); + builder.AddInstruction( + HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, add1, constant2)); + + auto computation = module().AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kAdd); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + 
ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Add(param0, op::Add(constant1, constant2))); +} + TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) { Shape r2f32 = ShapeUtil::MakeShape(F32, {3, 2}); HloComputation::Builder builder(TestName()); @@ -83,13 +130,12 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR0Operand) { builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } @@ -106,13 +152,12 @@ TEST_F(AlgebraicSimplifierTest, AddBroadcastZeroR1Operand) { builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kAdd, bcast, param0)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kAdd); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } @@ -128,17 +173,37 @@ TEST_F(AlgebraicSimplifierTest, SubZero) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kSubtract, param0, zero)); - auto module = CreateNewModule(); - auto computation = 
module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kSubtract); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } +// Test that A - Const is canonicalized to A + (-Const). +TEST_F(AlgebraicSimplifierTest, SubConstCanonicalization) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "param0")); + HloInstruction* constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + builder.AddInstruction(HloInstruction::CreateBinary( + r0f32, HloOpcode::kSubtract, param0, constant)); + + auto computation = module().AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root->opcode(), HloOpcode::kSubtract); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_THAT(root, op::Add(param0, op::Negate(constant))); +} + // Test that (A/B)/C is simplified to A/(B*C). 
TEST_F(AlgebraicSimplifierTest, LhsDivOfDiv) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); @@ -154,15 +219,14 @@ TEST_F(AlgebraicSimplifierTest, LhsDivOfDiv) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, div, param2)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Divide(op::Divide(param0, param1), param2)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Divide(param0, op::Multiply(param1, param2))); @@ -183,15 +247,14 @@ TEST_F(AlgebraicSimplifierTest, RhsDivOfDiv) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, div)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Divide(param0, op::Divide(param1, param2))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Divide(op::Multiply(param0, param2), param1)); @@ -217,8 +280,7 @@ TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) { builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, div0, div1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT( computation->root_instruction(), @@ -226,7 +288,7 @@ TEST_F(AlgebraicSimplifierTest, DivOfDivAndDiv) { 
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT( computation->root_instruction(), @@ -248,15 +310,14 @@ TEST_F(AlgebraicSimplifierTest, DivOfExp) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, exp)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Divide(param0, op::Exp(param1))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Multiply(param0, op::Exp(op::Negate(param1)))); @@ -277,15 +338,14 @@ TEST_F(AlgebraicSimplifierTest, DivOfPower) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, power)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Divide(param0, op::Power(param1, param2))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Multiply(param0, op::Power(param1, op::Negate(param2)))); @@ -308,15 +368,14 @@ TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) { builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide, param0, power)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation 
= module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Divide(param0, op::Power(param1, param2))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); ASSERT_THAT(computation->root_instruction(), op::Multiply(param0, op::Power(param1, op::Negate(param2)))); @@ -327,6 +386,75 @@ TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) { EXPECT_EQ(0, negate_shape.dimensions_size()); } +// A / Const => A * (1 / Const) +TEST_F(AlgebraicSimplifierTest, DivideByConstant) { + Shape r1f32 = ShapeUtil::MakeShape(F32, {3}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "param0")); + HloInstruction* constant = + builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR1({0.f, 1.f, 2.f}))); + builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide, + param0, constant)); + + auto computation = module().AddEntryComputation(builder.Build()); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), + op::Multiply(param0, op::Divide(op::Constant(), constant))); +} + +// pow(pow(A, X), Y) => pow(A, X*Y) +TEST_F(AlgebraicSimplifierTest, PowerOfPower) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + Shape r1f32 = ShapeUtil::MakeShape(F32, {7}); + HloComputation::Builder builder(TestName()); + HloInstruction* base = builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "param0")); + HloInstruction* exp1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r0f32, "param1")); + HloInstruction* exp2 = builder.AddInstruction( + HloInstruction::CreateParameter(2, r0f32, "param2")); + HloInstruction* 
inner_power = builder.AddInstruction( + HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, base, exp1)); + builder.AddInstruction(HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, + inner_power, exp2)); + + auto computation = module().AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + EXPECT_THAT(computation->root_instruction(), + op::Power(base, op::Multiply(exp1, exp2))); +} + +// Don't simplify pow(pow(A, X), Y) => pow(A, X*Y) if X and Y are complex +// numbers. +TEST_F(AlgebraicSimplifierTest, PowerOfPowerComplex) { + Shape r0c64 = ShapeUtil::MakeShape(C64, {}); + Shape r1c64 = ShapeUtil::MakeShape(C64, {7}); + HloComputation::Builder builder(TestName()); + HloInstruction* base = builder.AddInstruction( + HloInstruction::CreateParameter(0, r1c64, "param0")); + HloInstruction* exp1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r0c64, "param1")); + HloInstruction* exp2 = builder.AddInstruction( + HloInstruction::CreateParameter(2, r0c64, "param2")); + HloInstruction* inner_power = builder.AddInstruction( + HloInstruction::CreateBinary(r1c64, HloOpcode::kPower, base, exp1)); + builder.AddInstruction(HloInstruction::CreateBinary(r1c64, HloOpcode::kPower, + inner_power, exp2)); + + module().AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_FALSE(simplifier.Run(&module()).ValueOrDie()); +} + // Test that A/1 is simplified to A for a scalar. 
TEST_F(AlgebraicSimplifierTest, DivOneScalar) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); @@ -338,13 +466,12 @@ TEST_F(AlgebraicSimplifierTest, DivOneScalar) { HloInstruction* div = builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, param0, one)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, div); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } @@ -360,13 +487,12 @@ TEST_F(AlgebraicSimplifierTest, DivOneArray) { HloInstruction* div = builder.AddInstruction( HloInstruction::CreateBinary(r2f32, HloOpcode::kDivide, param0, one)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, div); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } @@ -385,13 +511,12 @@ TEST_F(AlgebraicSimplifierTest, ComplexOfRealImagC) { HloInstruction* cplx = builder.AddInstruction( HloInstruction::CreateBinary(r2c64, HloOpcode::kComplex, real, imag)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, cplx); AlgebraicSimplifier 
simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } @@ -410,13 +535,12 @@ TEST_F(AlgebraicSimplifierTest, RealOfComplex) { HloInstruction* real = builder.AddInstruction( HloInstruction::CreateUnary(r2f32, HloOpcode::kReal, cplx)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, real); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param0); } @@ -435,13 +559,12 @@ TEST_F(AlgebraicSimplifierTest, ImagOfComplex) { HloInstruction* imag = builder.AddInstruction( HloInstruction::CreateUnary(r2f32, HloOpcode::kImag, cplx)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, imag); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); root = computation->root_instruction(); EXPECT_EQ(root, param1); } @@ -463,13 +586,12 @@ TEST_F(AlgebraicSimplifierTest, SelectMakeTuple) { HloInstruction* add = builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kAdd, get, param2)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = 
module().AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root, add); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); root = computation->root_instruction(); EXPECT_THAT(root, op::Add(param1, param2)); } @@ -489,15 +611,14 @@ TEST_F(AlgebraicSimplifierTest, ExpDiv) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kDivide, exp0, exp1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Divide(op::Exp(param0), op::Exp(param1))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Exp(op::Subtract(param0, param1))); @@ -518,15 +639,14 @@ TEST_F(AlgebraicSimplifierTest, ExpMul) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kMultiply, exp0, exp1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Multiply(op::Exp(param0), op::Exp(param1))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Exp(op::Add(param0, param1))); @@ -545,15 +665,14 @@ TEST_F(AlgebraicSimplifierTest, PowExp) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, exp0, param1)); - auto 
module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Power(op::Exp(param0), param1)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Exp(op::Multiply(param0, param1))); @@ -572,15 +691,14 @@ TEST_F(AlgebraicSimplifierTest, LnPow) { builder.AddInstruction( HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, pow)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Log(op::Power(param0, param1))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Multiply(op::Log(param0), param1)); @@ -597,14 +715,13 @@ TEST_F(AlgebraicSimplifierTest, LnExp) { builder.AddInstruction( HloInstruction::CreateUnary(r0f32, HloOpcode::kLog, exp0)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Log(op::Exp(param0))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_EQ(computation->root_instruction(), param0); } @@ -626,15 +743,14 @@ TEST_F(AlgebraicSimplifierTest, LnExpDiv) { builder.AddInstruction( HloInstruction::CreateUnary(r0f32, 
HloOpcode::kLog, div)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Log(op::Divide(op::Exp(param0), op::Exp(param1)))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Subtract(param0, param1)); } @@ -651,14 +767,13 @@ TEST_F(AlgebraicSimplifierTest, Pow0Scalar) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, zero)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Constant()); @@ -676,14 +791,13 @@ TEST_F(AlgebraicSimplifierTest, Pow0Vector) { builder.AddInstruction( HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param0, zero)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Power(param0, zero)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); EXPECT_THAT(root, op::Broadcast()); @@ -705,14 +819,13 @@ 
TEST_F(AlgebraicSimplifierTest, Pow1) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, one)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Power(param0, one)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_EQ(computation->root_instruction(), param0); } @@ -728,14 +841,13 @@ TEST_F(AlgebraicSimplifierTest, Pow2) { builder.AddInstruction( HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, two)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Power(param0, two)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Multiply(param0, param0)); } @@ -751,14 +863,13 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) { builder.AddInstruction(HloInstruction::CreateBinary(r0f32, HloOpcode::kPower, param0, negative_one)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Power(param0, negative_one)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); HloInstruction* root = computation->root_instruction(); 
EXPECT_THAT(root, op::Divide(op::Broadcast(), param0)); @@ -767,6 +878,117 @@ TEST_F(AlgebraicSimplifierTest, PowNegative1) { 1); } +TEST_F(AlgebraicSimplifierTest, ZeroSizedConvolution) { + auto builder = HloComputation::Builder(TestName()); + HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {3, 3, 0}), "lhs")); + + HloInstruction* rhs = builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {3, 0, 3}), "rhs")); + + ConvolutionDimensionNumbers dnums; + dnums.set_input_batch_dimension(0); + dnums.add_input_spatial_dimensions(1); + dnums.set_input_feature_dimension(2); + + dnums.set_output_batch_dimension(0); + dnums.add_output_spatial_dimensions(1); + dnums.set_output_feature_dimension(2); + + dnums.add_kernel_spatial_dimensions(0); + dnums.set_kernel_input_feature_dimension(1); + dnums.set_kernel_output_feature_dimension(2); + Window window; + WindowDimension* dim = window.add_dimensions(); + dim->set_size(3); + dim->set_padding_low(0); + dim->set_padding_high(0); + dim->set_stride(1); + dim->set_window_dilation(1); + dim->set_base_dilation(1); + dim->set_window_reversal(false); + // Create add computation. 
+ builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeUtil::MakeShape(F32, {3, 3, 3}), lhs, rhs, window, dnums)); + module().AddEntryComputation(builder.Build()); + HloPassFix simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + EXPECT_THAT(module().entry_computation()->root_instruction(), + op::Convolution(lhs, rhs)); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + EXPECT_THAT(module().entry_computation()->root_instruction(), + op::Broadcast(op::Constant())); +} + +TEST_F(AlgebraicSimplifierTest, ZeroSizedReduceWindow) { + auto builder = HloComputation::Builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {3, 0}), "op")); + Window window; + for (int64 i = 0; i < 2; ++i) { + WindowDimension* dim = window.add_dimensions(); + dim->set_size(1); + dim->set_padding_low(1); + dim->set_padding_high(1); + dim->set_window_dilation(1); + dim->set_base_dilation(1); + } + // Create add computation. 
+ HloComputation* add_computation = nullptr; + { + HloComputation::Builder builder(TestName() + ".add"); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* p0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "p0")); + HloInstruction* p1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "p1")); + builder.AddInstruction( + HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); + add_computation = module().AddEmbeddedComputation(builder.Build()); + } + builder.AddInstruction(HloInstruction::CreateReduceWindow( + ShapeUtil::MakeShape(F32, {5, 2}), param, + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))), + window, add_computation)); + module().AddEntryComputation(builder.Build()); + HloPassFix simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + EXPECT_THAT(module().entry_computation()->root_instruction(), + op::ReduceWindow(param, op::Constant())); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + EXPECT_THAT(module().entry_computation()->root_instruction(), + op::Broadcast(op::Constant())); +} + +TEST_F(AlgebraicSimplifierTest, ZeroSizedPad) { + auto builder = HloComputation::Builder(TestName()); + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {3, 0}), "op")); + PaddingConfig padding; + for (int i = 0; i < 2; ++i) { + PaddingConfig::PaddingConfigDimension* dimension = padding.add_dimensions(); + dimension->set_edge_padding_low(1); + dimension->set_edge_padding_high(1); + dimension->set_interior_padding(0); + } + builder.AddInstruction(HloInstruction::CreatePad( + ShapeUtil::MakeShape(F32, {5, 2}), param, + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))), + padding)); + module().AddEntryComputation(builder.Build()); + EXPECT_THAT(module().entry_computation()->root_instruction(), + op::Pad(param, 
op::Constant())); + HloPassFix simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); + EXPECT_THAT(module().entry_computation()->root_instruction(), + op::Broadcast(op::Constant())); +} + TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); @@ -781,17 +1003,16 @@ TEST_F(AlgebraicSimplifierTest, ReshapeBroadcast) { ShapeUtil::MakeShape(F32, {3, 2}), broadcast)); auto computation = builder.Build(); - auto module = CreateNewModule(); - module->AddEntryComputation(std::move(computation)); + module().AddEntryComputation(std::move(computation)); - EXPECT_THAT(module->entry_computation()->root_instruction(), + EXPECT_THAT(module().entry_computation()->root_instruction(), op::Reshape(op::Broadcast(op::Reshape(op)))); HloPassFix simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); - EXPECT_THAT(module->entry_computation()->root_instruction(), op); + EXPECT_THAT(module().entry_computation()->root_instruction(), op); } // Test that convert(A, $TYPE) is simplified to A if A is of type $TYPE. 
@@ -802,14 +1023,13 @@ TEST_F(AlgebraicSimplifierTest, ConvertBetweenSameType) { builder.AddInstruction( HloInstruction::CreateConvert(ShapeUtil::MakeShape(F32, {}), input)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Convert(input)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), input); } @@ -823,14 +1043,13 @@ TEST_F(AlgebraicSimplifierTest, RemoveCopy) { builder.AddInstruction( HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param0); } @@ -844,14 +1063,13 @@ TEST_F(AlgebraicSimplifierTest, RemoveUnaryConcatenate) { builder.AddInstruction( HloInstruction::CreateConcatenate(param0->shape(), {param0}, 0)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Concatenate(param0)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), param0); } 
@@ -874,8 +1092,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) { builder.AddInstruction(HloInstruction::CreateConcatenate( result_shape, {empty_literal, param0, param0, empty_slice, param1}, 0)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT( computation->root_instruction(), @@ -883,7 +1100,7 @@ TEST_F(AlgebraicSimplifierTest, RemoveEmptyConcatenateOperands) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Concatenate(param0, param0, param1)); @@ -905,15 +1122,14 @@ TEST_F(AlgebraicSimplifierTest, OnlyEmptyConcatenateOperands) { builder.AddInstruction(HloInstruction::CreateConcatenate( result_shape, {empty_literal, empty_slice}, 0)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Concatenate(empty_literal, empty_slice)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_EQ(computation->root_instruction(), empty_literal); } @@ -930,14 +1146,13 @@ TEST_F(AlgebraicSimplifierTest, ConcatenateOfBroadcastBecomesPad) { HloInstruction* broadcast = builder.AddInstruction( HloInstruction::CreateBroadcast(r1f32, param1, {})); builder.AddInstruction(HloInstruction::CreateConcatenate( - param0->shape(), {broadcast, param0}, 0)); + ShapeUtil::MakeShape(F32, {200}), {broadcast, param0}, 0)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto 
computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Pad(param0, param1)); } @@ -951,8 +1166,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) { HloInstruction* copy = builder.AddInstruction( HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); // Set to different layouts. *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); @@ -962,7 +1176,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithDifferentLayout) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); // Copy has not been removed. EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); @@ -978,8 +1192,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) { HloInstruction* copy = builder.AddInstruction( HloInstruction::CreateUnary(param0->shape(), HloOpcode::kCopy, param0)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); // Set to same layouts. *param0->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); @@ -989,7 +1202,7 @@ TEST_F(AlgebraicSimplifierTest, CopyWithSameLayout) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); // Copy has been removed. 
EXPECT_THAT(computation->root_instruction(), param0); @@ -1010,14 +1223,13 @@ TEST_F(AlgebraicSimplifierTest, NoBitcastAdded) { *reshape->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1, 2, 3, 4, 5}); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); // Reshape is not replaced with a bitcast. EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); @@ -1056,8 +1268,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { builder.AddInstruction(HloInstruction::CreateTuple( {transformable_reshape, dimensions_wrong_reshape, layout_wrong_reshape})); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Tuple(transformable_reshape, dimensions_wrong_reshape, @@ -1065,7 +1276,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeReplacedWithBitcast) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, bitcasting_callback()); - simplifier.Run(module.get()).ValueOrDie(); + simplifier.Run(&module()).ValueOrDie(); // Verify that only the first reshape is replaced. 
EXPECT_THAT( @@ -1086,8 +1297,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeAfterEffectiveUnary) { builder.AddInstruction( HloInstruction::CreateBinary(ShapeUtil::MakeShape(F32, {1, 2, 3, 4, 5}), HloOpcode::kMaximum, movable_reshape, zero)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Maximum(op::Reshape(param), zero)); @@ -1095,7 +1305,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeAfterEffectiveUnary) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, bitcasting_callback()); - simplifier.Run(module.get()).ValueOrDie(); + simplifier.Run(&module()).ValueOrDie(); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Maximum(param, zero))); } @@ -1113,8 +1323,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeToScalarNotHoistedAfterEffectiveUnary) { HloInstruction::CreateConstant(Literal::CreateR1({1., 2., 3.}))); builder.AddInstruction(HloInstruction::CreateBinary( ShapeUtil::MakeShape(F32, {3}), HloOpcode::kMaximum, reshape, zero)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Maximum(op::Reshape(param), zero)); @@ -1122,7 +1331,7 @@ TEST_F(AlgebraicSimplifierTest, ReshapeToScalarNotHoistedAfterEffectiveUnary) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, bitcasting_callback()); - simplifier.Run(module.get()).ValueOrDie(); + simplifier.Run(&module()).ValueOrDie(); EXPECT_THAT(computation->root_instruction(), op::Maximum(op::Reshape(param), zero)); @@ -1147,9 +1356,8 @@ TEST_F(AlgebraicSimplifierTest, FailureToSinkReshapeDoesntAffectChangedBit) { AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, bitcasting_callback()); - auto module = CreateNewModule(); - 
module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + module().AddEntryComputation(builder.Build()); + EXPECT_TRUE(simplifier.Run(&module()).ValueOrDie()); } // Regression test for a bug where if we failed to sink a reshape, we'd set the @@ -1166,14 +1374,14 @@ TEST_F(AlgebraicSimplifierTest, FailureToSinkBroadcastDoesntAffectChangedBit) { builder.AddInstruction(HloInstruction::CreateConstant( Literal::CreateR2({{0, 0}, {0, 0}}))))); - builder.AddInstruction(HloInstruction::CreateBroadcast( - ShapeUtil::MakeShape(F32, {2, 2, 2}), add, /*broadcast_dimensions=*/{0})); + builder.AddInstruction( + HloInstruction::CreateBroadcast(ShapeUtil::MakeShape(F32, {2, 2, 2}), add, + /*broadcast_dimensions=*/{0, 1})); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, bitcasting_callback()); - auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + module().AddEntryComputation(builder.Build()); + EXPECT_TRUE(simplifier.Run(&module()).ValueOrDie()); } TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) { @@ -1190,14 +1398,13 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) { *transpose->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1, 2, 3}); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Transpose(param)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); // Verify that the reshape is replaced. 
EXPECT_THAT(computation->root_instruction(), op::Bitcast(param)); @@ -1217,14 +1424,13 @@ TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast2) { *transpose->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({3, 1, 2, 0}); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Transpose(param)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); // Verify that the reshape is replaced. EXPECT_THAT(computation->root_instruction(), op::Bitcast(param)); @@ -1243,15 +1449,14 @@ TEST_F(AlgebraicSimplifierTest, ReshapesMerged) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {1, 2, 1, 1, 2, 1}), reshape1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Reshape(param0))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(param0)); } @@ -1260,7 +1465,7 @@ TEST_F(AlgebraicSimplifierTest, CopiesMerged) { HloComputation::Builder builder(TestName()); HloInstruction* param0 = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(F32, {2, 2, 2}), + 0, ShapeUtil::MakeShapeWithDescendingLayout(F32, {2, 2, 2}), "param0")); HloInstruction* copy1 = builder.AddInstruction(HloInstruction::CreateUnary( @@ -1271,14 +1476,13 @@ TEST_F(AlgebraicSimplifierTest, CopiesMerged) { 
ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 2, 1}), HloOpcode::kCopy, copy1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Copy(op::Copy(param0))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/true, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Copy(param0)); } @@ -1296,14 +1500,13 @@ TEST_F(AlgebraicSimplifierTest, TransposesMerged) { builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {4, 3, 2}), transpose1, {1, 0, 2})); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Transpose(transpose1)); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Transpose(param0)); EXPECT_EQ(std::vector({2, 1, 0}), @@ -1318,17 +1521,16 @@ TEST_F(AlgebraicSimplifierTest, ReshapeAndBroadcastMerged) { auto reshape1 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {1, 5, 1}), param0)); builder.AddInstruction(HloInstruction::CreateBroadcast( - ShapeUtil::MakeShape(F32, {1, 2, 3, 5, 1}), reshape1, {0, 2, 3})); + ShapeUtil::MakeShape(F32, {1, 2, 3, 5, 1}), reshape1, {0, 3, 2})); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Broadcast(op::Reshape(param0))); 
AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Broadcast(param0)); } @@ -1343,15 +1545,14 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshapeMerged) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {2, 3, 7, 2, 1, 3, 2}), broadcast1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Broadcast(param0))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Broadcast(param0)); } @@ -1365,15 +1566,14 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x1_3) { builder.AddInstruction( HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), broadcast)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Broadcast(param))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Broadcast(param))); @@ -1388,15 +1588,14 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4_6x1x1x4) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {6, 1, 1, 4}), broadcast)); - auto module = CreateNewModule(); - HloComputation* computation = 
module->AddEntryComputation(builder.Build()); + HloComputation* computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Broadcast(param))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Broadcast(param)); EXPECT_THAT(computation->root_instruction()->dimensions(), @@ -1412,15 +1611,14 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_1_3x2x1_6x1x1x1) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {6, 1, 1, 1}), broadcast)); - auto module = CreateNewModule(); - HloComputation* computation = module->AddEntryComputation(builder.Build()); + HloComputation* computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Broadcast(param))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Broadcast(param)); const std::vector broadcast_dims = @@ -1438,15 +1636,14 @@ TEST_F(AlgebraicSimplifierTest, BroadcastAndReshape_4_3x2x4x2_6x8) { builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {6, 8}), broadcast)); - auto module = CreateNewModule(); - HloComputation* computation = module->AddEntryComputation(builder.Build()); + HloComputation* computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Broadcast(param))); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - EXPECT_FALSE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_FALSE(simplifier.Run(&module()).ValueOrDie()); 
EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Broadcast(param))); @@ -2138,8 +2335,10 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { builder.AddInstruction(HloInstruction::CreateParameter(0, r1f32, "x")); HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(1, r1f32, "y")); - builder.AddInstruction( - HloInstruction::CreateBinary(r1f32, HloOpcode::kDot, x, y)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + builder.AddInstruction(HloInstruction::CreateDot(r1f32, x, y, dot_dnums)); std::unique_ptr dot_computation(builder.Build()); HloComputation::Builder call_builder(TestName() + ".Call"); @@ -2150,12 +2349,11 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { call_builder.AddInstruction( HloInstruction::CreateCall(r1f32, {zero, one}, dot_computation.get())); - auto module = CreateNewModule(); - module->AddEmbeddedComputation(std::move(dot_computation)); - module->AddEntryComputation(call_builder.Build()); + module().AddEmbeddedComputation(std::move(dot_computation)); + module().AddEntryComputation(call_builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); } // Test that a constant with tuple shape becomes a tuple of constants. 
@@ -2168,12 +2366,11 @@ TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) { Literal::CreateR1(constant_vector).get()}); builder.AddInstruction(HloInstruction::CreateConstant(std::move(value))); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Tuple(op::Constant(), op::Constant())); } @@ -2193,11 +2390,10 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) { HloInstruction::CreateConstant(Literal::CreateR1({0, 0, 0}))), /*slice_sizes=*/{10, 100, 1000})); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Parameter()); } @@ -2227,14 +2423,354 @@ TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) { builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR1({0, 0, 0}))))); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, non_bitcasting_callback()); - ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + ASSERT_TRUE(simplifier.Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::DynamicSlice(op::Parameter(), op::Parameter())); } +struct PadReduceWindowEffectiveBroadcastCase { + 
std::vector input_spatials; + std::vector symmetric_pad_spatials; + std::vector reduce_window_spatials; + // Whether to use `B F S0 S1` form vs `B S0 S1 F` form. + // + // This doesn't test any different functionality but is useful for making sure + // kBroadcast nodes are well formed. + bool prepend_a; + bool should_become_broadcast; + + string ToTestCaseName() const { + return tensorflow::strings::StrCat( + tensorflow::str_util::Join(input_spatials, ","), ";", + tensorflow::str_util::Join(symmetric_pad_spatials, ","), ";", + tensorflow::str_util::Join(reduce_window_spatials, ","), ";", prepend_a, + ";", should_become_broadcast); + } +}; + +void PrintTo(const PadReduceWindowEffectiveBroadcastCase& c, std::ostream* os) { + *os << c.ToTestCaseName(); +} + +class PadReduceWindowEffectiveBroadcastTest + : public AlgebraicSimplifierTest, + public ::testing::WithParamInterface< + PadReduceWindowEffectiveBroadcastCase> {}; + +TEST_P(PadReduceWindowEffectiveBroadcastTest, DoIt) { + const auto& param = GetParam(); + + // a and b are parallel bounds we can either turn into a B F S0 S1 or + // `B S0 S1 F` kind of pattern. 
+ auto decorate_spatials = [&param](tensorflow::gtl::ArraySlice spatials, + int64 a, int64 b) { + std::vector result; + if (param.prepend_a) { + result.push_back(a); + } + for (int64 s : spatials) { + result.push_back(s); + } + if (!param.prepend_a) { + result.push_back(a); + } + result.push_back(b); + return result; + }; + + HloComputation::Builder builder(TestName()); + const Shape input_shape = ShapeUtil::MakeShape( + F32, decorate_spatials(param.input_spatials, 128, 2048)); + HloInstruction* input = builder.AddInstruction( + HloInstruction::CreateParameter(0, input_shape, "input")); + + PaddingConfig padding = window_util::MakeSymmetricPadding( + decorate_spatials(param.symmetric_pad_spatials, 0, 0)); + TF_ASSERT_OK_AND_ASSIGN( + const Shape pad_shape, + ShapeInference::InferPadShape(input->shape(), + ShapeUtil::MakeShape(F32, {}), padding)); + HloInstruction* pad = builder.AddInstruction(HloInstruction::CreatePad( + pad_shape, input, + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))), + padding)); + + HloComputation* add_computation = nullptr; + { + HloComputation::Builder builder(TestName() + ".add"); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* p0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "p0")); + HloInstruction* p1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "p1")); + builder.AddInstruction( + HloInstruction::CreateBinary(scalar_shape, HloOpcode::kAdd, p0, p1)); + add_computation = module().AddEmbeddedComputation(builder.Build()); + } + + Window window = window_util::MakeWindow( + decorate_spatials(param.reduce_window_spatials, 1, 1)); + auto zero = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0f))); + TF_ASSERT_OK_AND_ASSIGN(const Shape output_shape, + ShapeInference::InferReduceWindowShape( + pad->shape(), zero->shape(), window, + add_computation->ComputeProgramShape())); +
builder.AddInstruction(HloInstruction::CreateReduceWindow( + output_shape, pad, zero, window, add_computation)); + + auto computation = module().AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module())); + ASSERT_TRUE(run_successful); + + EXPECT_TRUE( + ShapeUtil::Equal(computation->root_instruction()->shape(), output_shape)); + + if (param.should_become_broadcast) { + EXPECT_THAT(computation->root_instruction(), op::Broadcast(::testing::_)); + } else { + EXPECT_THAT(computation->root_instruction(), + op::ReduceWindow(::testing::_, zero)); + } +} + +const std::vector& +PadReduceWindowEffectiveBroadcastCases() { + static auto* cases = new std::vector{ + {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{6, 6}, + /*reduce_window_spatials=*/{7, 7}, /*prepend_a=*/true, + /*should_become_broadcast=*/true}, // + {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{6, 6}, + /*reduce_window_spatials=*/{7, 7}, /*prepend_a=*/false, + /*should_become_broadcast=*/true}, // + {/*input_spatials=*/{2, 2}, /*symmetric_pad_amount=*/{6, 6}, + /*reduce_window_spatials=*/{7, 7}, /*prepend_a=*/true, + /*should_become_broadcast=*/false}, // + {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{2, 2}, + /*reduce_window_spatials=*/{5, 5}, /*prepend_a=*/true, + /*should_become_broadcast=*/true}, // + {/*input_spatials=*/{1, 1}, /*symmetric_pad_amount=*/{2, 2}, + /*reduce_window_spatials=*/{1, 1}, /*prepend_a=*/true, + /*should_become_broadcast=*/false}, // + {/*input_spatials=*/{5, 1}, /*symmetric_pad_amount=*/{0, 2}, + /*reduce_window_spatials=*/{2, 5}, /*prepend_a=*/true, + /*should_become_broadcast=*/false}, // + }; + return *cases; +} + +INSTANTIATE_TEST_CASE_P( + PadReduceWindowEffectiveBroadcastInstantiation, + PadReduceWindowEffectiveBroadcastTest, + ::testing::ValuesIn(PadReduceWindowEffectiveBroadcastCases())); + +class 
DotStrengthReductionTest + : public AlgebraicSimplifierTest, + public ::testing::WithParamInterface< + ::testing::tuple> {}; +TEST_P(DotStrengthReductionTest, DotStrengthReduction) { + int m, k, n; + bool transpose_lhs, transpose_rhs; + std::tie(m, k, n, transpose_lhs, transpose_rhs) = GetParam(); + + Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n}); + Shape lhs_shape = ShapeUtil::MakeShape(F32, {m, k}); + Shape transposed_lhs_shape = ShapeUtil::MakeShape(F32, {k, m}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {k, n}); + Shape transposed_rhs_shape = ShapeUtil::MakeShape(F32, {n, k}); + HloComputation::Builder builder(TestName()); + + auto lhs = builder.AddInstruction(HloInstruction::CreateParameter( + 0, transpose_lhs ? transposed_lhs_shape : lhs_shape, "lhs")); + if (transpose_lhs) { + lhs = builder.AddInstruction( + HloInstruction::CreateTranspose(lhs_shape, lhs, {1, 0})); + } + auto rhs = builder.AddInstruction(HloInstruction::CreateParameter( + 1, transpose_rhs ? transposed_rhs_shape : rhs_shape, "rhs")); + if (transpose_rhs) { + rhs = builder.AddInstruction( + HloInstruction::CreateTranspose(rhs_shape, rhs, {1, 0})); + } + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + builder.AddInstruction( + HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums)); + auto computation = module().AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + TF_ASSERT_OK_AND_ASSIGN(bool changed, simplifier.Run(&module())); + const bool dot_should_be_transformed = m == 1 || k == 1 || n == 1; + const bool computation_should_be_modified = + dot_should_be_transformed || (transpose_lhs && transpose_rhs); + EXPECT_EQ(changed, computation_should_be_modified); + bool has_no_dot = true; + for (const auto& hlo : computation->instructions()) { + if (hlo->opcode() == HloOpcode::kDot) { + has_no_dot = false; + break; + } + } + 
EXPECT_EQ(has_no_dot, dot_should_be_transformed); +} + +INSTANTIATE_TEST_CASE_P( + DotStrengthReductionTestInstantiation, DotStrengthReductionTest, + ::testing::Combine(::testing::Values(1, 2), ::testing::Values(1, 2), + ::testing::Values(1, 2), ::testing::Bool(), + ::testing::Bool())); + +struct DotOfConcatTestSpec { + int64 m; + int64 k; + int64 n; +}; + +class DotOfConcatSimplificationTest + : public HloVerifiedTestBase, + public ::testing::WithParamInterface {}; + +// Test that we transform +// dot(const, concat(A, B, C)) +// to +// add(dot(const_0, A), dot(const_1, B), dot(const_2, C)) +TEST_P(DotOfConcatSimplificationTest, ConstantLHS) { + HloComputation::Builder builder(TestName()); + + DotOfConcatTestSpec spec = GetParam(); + + ASSERT_GE(spec.k, 3); + + int64 k0 = spec.k / 3; + int64 k1 = spec.k / 3; + int64 k2 = spec.k - k0 - k1; + + Shape lhs_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.k}); + auto* lhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + /*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.m, /*cols=*/spec.k))); + + Shape rhs0_shape = ShapeUtil::MakeShape(F32, {k0, spec.n}); + Shape rhs1_shape = ShapeUtil::MakeShape(F32, {k1, spec.n}); + Shape rhs2_shape = ShapeUtil::MakeShape(F32, {k2, spec.n}); + + HloInstruction* rhs0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, rhs0_shape, "rhs0")); + HloInstruction* rhs1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, rhs1_shape, "rhs1")); + HloInstruction* rhs2 = builder.AddInstruction( + HloInstruction::CreateParameter(2, rhs2_shape, "rhs2")); + + Shape rhs_shape = ShapeUtil::MakeShape(F32, {spec.k, spec.n}); + HloInstruction* rhs = builder.AddInstruction( + HloInstruction::CreateConcatenate(rhs_shape, {rhs0, rhs1, rhs2}, 0)); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + + Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n}); + 
builder.AddInstruction( + HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums)); + + auto computation = module().AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module())); + ASSERT_TRUE(run_successful); + + EXPECT_TRUE( + ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); + + auto match_dot_0 = op::Dot(op::Slice(op::Constant()), op::Parameter(0)); + auto match_dot_1 = op::Dot(op::Slice(op::Constant()), op::Parameter(1)); + auto match_dot_2 = op::Dot(op::Slice(op::Constant()), op::Parameter(2)); + EXPECT_THAT(computation->root_instruction(), + op::Add(op::Add(match_dot_0, match_dot_1), match_dot_2)); +} + +// Test that we transform +// dot(concat(A, B, C), const) +// to +// add(dot(A, const_0), dot(B, const_1), dot(C, const_2)) +TEST_P(DotOfConcatSimplificationTest, ConstantRHS) { + HloComputation::Builder builder(TestName()); + + DotOfConcatTestSpec spec = GetParam(); + + ASSERT_GE(spec.k, 4); + + int64 k0 = spec.k / 4; + int64 k1 = spec.k / 4; + int64 k2 = spec.k / 4; + int64 k3 = spec.k - k0 - k1 - k2; + + Shape lhs0_shape = ShapeUtil::MakeShape(F32, {spec.m, k0}); + Shape lhs1_shape = ShapeUtil::MakeShape(F32, {spec.m, k1}); + Shape lhs2_shape = ShapeUtil::MakeShape(F32, {spec.m, k2}); + Shape lhs3_shape = ShapeUtil::MakeShape(F32, {spec.m, k3}); + + HloInstruction* lhs0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, lhs0_shape, "lhs0")); + HloInstruction* lhs1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, lhs1_shape, "lhs1")); + HloInstruction* lhs2 = builder.AddInstruction( + HloInstruction::CreateParameter(2, lhs2_shape, "lhs2")); + HloInstruction* lhs3 = builder.AddInstruction( + HloInstruction::CreateParameter(3, lhs3_shape, "lhs3")); + + Shape lhs_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.k}); + HloInstruction* lhs = + 
builder.AddInstruction(HloInstruction::CreateConcatenate( + lhs_shape, {lhs0, lhs1, lhs2, lhs3}, 1)); + + Shape rhs_shape = ShapeUtil::MakeShape(F32, {spec.k, spec.n}); + auto* rhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2F32Linspace( + /*from=*/10.0, /*to=*/10000.0, /*rows=*/spec.k, /*cols=*/spec.n))); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + + Shape dot_shape = ShapeUtil::MakeShape(F32, {spec.m, spec.n}); + builder.AddInstruction( + HloInstruction::CreateDot(dot_shape, lhs, rhs, dot_dnums)); + + auto computation = module().AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + TF_ASSERT_OK_AND_ASSIGN(bool run_successful, simplifier.Run(&module())); + ASSERT_TRUE(run_successful); + EXPECT_TRUE( + ShapeUtil::Equal(computation->root_instruction()->shape(), dot_shape)); + + auto match_dot_0 = op::Dot(op::Parameter(0), op::Slice(op::Constant())); + auto match_dot_1 = op::Dot(op::Parameter(1), op::Slice(op::Constant())); + auto match_dot_2 = op::Dot(op::Parameter(2), op::Slice(op::Constant())); + auto match_dot_3 = op::Dot(op::Parameter(3), op::Slice(op::Constant())); + EXPECT_THAT(computation->root_instruction(), + op::Add(op::Add(op::Add(match_dot_0, match_dot_1), match_dot_2), + match_dot_3)); +} + +DotOfConcatTestSpec kDotOfConcatTestSpecs[] = { + {/*m=*/3, /*k=*/9, /*n=*/3}, // + {/*m=*/3, /*k=*/20, /*n=*/3}, // + {/*m=*/1, /*k=*/18, /*n=*/5}, // + {/*m=*/20, /*k=*/20, /*n=*/1}, // + {/*m=*/1, /*k=*/16, /*n=*/1}, // +}; + +INSTANTIATE_TEST_CASE_P(DotOfConcatSimplificationTestInstantiation, + DotOfConcatSimplificationTest, + ::testing::ValuesIn(kDotOfConcatTestSpecs)); } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc index 
ad2fee2d39a8ca183b87212bdeea22c351aaa88a..4e80679c11dfdf7fdf8077a9f354139a4cab6803 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.cc +++ b/tensorflow/compiler/xla/service/allocation_tracker.cc @@ -27,191 +27,161 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/stream_executor_no_cuda.h" - -namespace se = ::perftools::gputools; namespace xla { -AllocationTracker::AllocationTracker() : next_handle_(1) {} - -GlobalDataHandle AllocationTracker::Register(Backend* backend, - int device_ordinal, - se::DeviceMemoryBase device_memory, - const Shape& shape, - const string& tag) { - tensorflow::mutex_lock lock(allocation_mutex_); +StatusOr AllocationTracker::Register( + std::unique_ptr shaped_buffer, const string& tag) { + tensorflow::mutex_lock lock(mutex_); VLOG(2) << "Register"; - return RegisterInternal(backend, device_ordinal, device_memory, shape, tag, - /*initial_ref_count=*/1); + return RegisterInternal(std::move(shaped_buffer), tag); } -GlobalDataHandle AllocationTracker::RegisterInternal( - Backend* backend, int device_ordinal, se::DeviceMemoryBase device_memory, - const Shape& shape, const string& tag, int initial_ref_count) { +StatusOr AllocationTracker::RegisterInternal( + std::unique_ptr shaped_buffer, const string& tag) { VLOG(2) << "RegisterInternal(" << "tag: \"" << tag << "\" " - << "device_ordinal: " << device_ordinal << " " - << "device_memory: " << device_memory.opaque() << " " - << "shape: " << shape.ShortDebugString() << ")"; - TF_CHECK_OK(ShapeUtil::ValidateShape(shape)); - - int64 handle; - HandleMap& handle_map = GetOrCreateOpaqueToHandleMap(device_ordinal); - auto handle_it = handle_map.find(device_memory.opaque()); - if (handle_it != handle_map.end()) { - handle = handle_it->second; - auto& allocation = FindOrDie(handle_to_allocation_, handle); - int ref_count = 
allocation->ref_count(); - CHECK_GT(ref_count, 0); - VLOG(2) << "ref_count: " << ref_count << " -> " << - (ref_count + initial_ref_count); - allocation->increment_ref_count(initial_ref_count); - } else { - handle = next_handle_++; - VLOG(2) << "ref_count: " << initial_ref_count; - InsertOrDie(&handle_map, device_memory.opaque(), handle); - auto inserted = handle_to_allocation_.emplace( - handle, MakeUnique(backend, device_ordinal, device_memory, - shape, tag, initial_ref_count)); - CHECK(inserted.second); + << "shaped_buffer: " << *shaped_buffer; + if (shaped_buffer->platform() != backend_->platform()) { + return InvalidArgument( + "AllocationTracker for platform %s cannot register buffer from " + "platform %s", + backend_->platform()->Name().c_str(), + shaped_buffer->platform()->Name().c_str()); } + int64 handle = next_handle_++; + std::vector shape_indices; + ShapeUtil::ForEachSubshape(shaped_buffer->on_device_shape(), + [this, &shape_indices](const Shape& /*subshape*/, + const ShapeIndex& index) { + shape_indices.push_back(index); + }); + for (const ShapeIndex& index : shape_indices) { + AddAllocationOrIncrementRefCount(shaped_buffer->buffer(index), + shaped_buffer->device_ordinal()); + } GlobalDataHandle result; result.set_handle(handle); + + handle_to_shaped_buffer_[handle] = std::move(shaped_buffer); + VLOG(2) << "handle: " << handle; return result; } tensorflow::Status AllocationTracker::Unregister(const GlobalDataHandle& data) { - tensorflow::mutex_lock lock(allocation_mutex_); - TF_ASSIGN_OR_RETURN(Allocation * allocation, ResolveInternal(data)); - std::set deallocated_buffers; - TF_RETURN_IF_ERROR( - DeallocateShape(allocation->backend(), allocation->device_ordinal(), - allocation->mutable_device_memory(), allocation->shape(), - &deallocated_buffers)); - return tensorflow::Status::OK(); -} - -tensorflow::Status AllocationTracker::DeallocateShape( - Backend* backend, int device_ordinal, se::DeviceMemoryBase* device_memory, - const Shape& shape, std::set* 
deallocated_buffers) { - VLOG(2) << "DeallocateShape(" - << "shape: \"" << shape.ShortDebugString() << "\" " - << "device_memory: " << device_memory->opaque() << ")"; - if (ContainsKey(*deallocated_buffers, device_memory->opaque())) { - // Buffer has already been deallocated. Nothing to do. - VLOG(2) << "already deallocated"; - return tensorflow::Status::OK(); - } - - // Add buffer to deallocated set so we do not try to deallocate it again - // if it is encountered again while traversing a tuple. - deallocated_buffers->insert(device_memory->opaque()); - - HandleMap& handle_map = GetOrCreateOpaqueToHandleMap(device_ordinal); - auto handle_it = handle_map.find(device_memory->opaque()); - if (handle_it != handle_map.end()) { - int64 handle = handle_it->second; - auto& allocation = FindOrDie(handle_to_allocation_, handle); - int ref_count = allocation->ref_count(); - VLOG(2) << "ref_count: " << ref_count << " -> " << ref_count - 1; - allocation->decrement_ref_count(); - if (allocation->ref_count() > 0) { - // Buffer is referred to by another allocation. Don't deallocate it. - return tensorflow::Status::OK(); - } - handle_map.erase(device_memory->opaque()); + tensorflow::mutex_lock lock(mutex_); + VLOG(2) << "Unregister(" + << "handle: " << data.handle() << ")"; + TF_ASSIGN_OR_RETURN(ShapedBuffer * shaped_buffer, ResolveInternal(data)); + std::vector shape_indices; + ShapeUtil::ForEachSubshape(shaped_buffer->on_device_shape(), + [this, &shape_indices](const Shape& /*subshape*/, + const ShapeIndex& index) { + shape_indices.push_back(index); + }); + for (const ShapeIndex& index : shape_indices) { + TF_RETURN_IF_ERROR(DecrementRefCount(shaped_buffer->buffer(index), + shaped_buffer->device_ordinal())); } - if (ShapeUtil::IsTuple(shape)) { - // Traverse into tuple recursively deallocating buffers. 
- TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, - backend->stream_executor(device_ordinal)); - TF_ASSIGN_OR_RETURN(std::vector elements, - backend->transfer_manager()->ShallowCopyTupleFromDevice( - executor, *device_memory, shape)); - - TF_RET_CHECK(ShapeUtil::TupleElementCount(shape) == elements.size()) - << "tuple has unexpected number of elements: " << elements.size() - << " != " << ShapeUtil::TupleElementCount(shape); - for (size_t i = 0; i < elements.size(); ++i) { - VLOG(2) << "recursing onto the tuple elements"; - TF_RETURN_IF_ERROR(DeallocateShape(backend, device_ordinal, &elements[i], - shape.tuple_shapes(i), - deallocated_buffers)); - } - } + // Keep a nullptr as a tombstone for unregistered handles. This enables better + // error messages. That is, "handle has been deallocated" versus "handle does + // not exist". + handle_to_shaped_buffer_.at(data.handle()).reset(); - return backend->memory_allocator()->Deallocate(device_ordinal, device_memory); + return tensorflow::Status::OK(); } StatusOr> AllocationTracker::DeconstructTuple( const GlobalDataHandle& data) { - tensorflow::mutex_lock lock(allocation_mutex_); - TF_ASSIGN_OR_RETURN(Allocation * allocation, ResolveInternal(data)); + tensorflow::mutex_lock lock(mutex_); - if (!ShapeUtil::IsTuple(allocation->shape())) { + TF_ASSIGN_OR_RETURN(ShapedBuffer * shaped_buffer, ResolveInternal(data)); + if (!ShapeUtil::IsTuple(shaped_buffer->on_host_shape())) { return InvalidArgument("global data handle %lld is not a tuple", data.handle()); } + // If the on-host representation is a tuple, then the on-device one should be + // as well. 
+ TF_RET_CHECK(ShapeUtil::IsTuple(shaped_buffer->on_device_shape())); - if (ShapeUtil::IsNestedTuple(allocation->shape())) { + if (ShapeUtil::IsNestedTuple(shaped_buffer->on_device_shape())) { return Unimplemented("deconstructing nested tuples not yet supported"); } - TF_ASSIGN_OR_RETURN( - se::StreamExecutor * executor, - allocation->backend()->stream_executor(allocation->device_ordinal())); - TF_ASSIGN_OR_RETURN( - std::vector element_bases, - allocation->backend()->transfer_manager()->ShallowCopyTupleFromDevice( - executor, allocation->device_memory(), allocation->shape())); - std::vector element_handles; - element_handles.reserve(element_bases.size()); - for (int i = 0; i < element_bases.size(); ++i) { - element_handles.push_back(RegisterInternal( - allocation->backend(), allocation->device_ordinal(), element_bases[i], - ShapeUtil::GetSubshape(allocation->shape(), {i}), - tensorflow::strings::StrCat(allocation->tag(), ".element_", i), - /*initial_ref_count=*/2)); + for (int i = 0; + i < ShapeUtil::TupleElementCount(shaped_buffer->on_device_shape()); + ++i) { + auto element_buffer = MakeUnique( + ShapeUtil::GetTupleElementShape(shaped_buffer->on_host_shape(), i), + ShapeUtil::GetTupleElementShape(shaped_buffer->on_device_shape(), i), + shaped_buffer->platform(), shaped_buffer->device_ordinal()); + element_buffer->set_buffer(shaped_buffer->buffer(/*index=*/{i}), + /*index=*/{}); + TF_ASSIGN_OR_RETURN( + GlobalDataHandle element_handle, + RegisterInternal(std::move(element_buffer), "deconstructed tuple")); + + element_handles.push_back(element_handle); } return std::move(element_handles); } -StatusOr AllocationTracker::Resolve( +StatusOr AllocationTracker::Resolve( const GlobalDataHandle& data) { - tensorflow::mutex_lock lock(allocation_mutex_); + tensorflow::mutex_lock lock(mutex_); return AllocationTracker::ResolveInternal(data); } -StatusOr AllocationTracker::ResolveInternal( +StatusOr AllocationTracker::ResolveInternal( const GlobalDataHandle& data) { VLOG(2) 
<< "resolve:" << data.handle(); - auto it = handle_to_allocation_.find(data.handle()); - if (it == handle_to_allocation_.end()) { + auto it = handle_to_shaped_buffer_.find(data.handle()); + if (it == handle_to_shaped_buffer_.end()) { return NotFound("no allocation record for global data handle: %lld", data.handle()); } - Allocation* allocation = it->second.get(); + ShapedBuffer* shaped_buffer = it->second.get(); - if (allocation->is_deallocated()) { + if (shaped_buffer == nullptr) { return InvalidArgument("global data handle %lld was previously deallocated", data.handle()); } - return allocation; + return shaped_buffer; +} + +void AllocationTracker::AddAllocationOrIncrementRefCount( + perftools::gputools::DeviceMemoryBase device_memory, int device_ordinal) { + AllocationMap& allocation_map = opaque_to_allocation_map_[device_ordinal]; + auto it = allocation_map.find(device_memory.opaque()); + if (it == allocation_map.end()) { + allocation_map[device_memory.opaque()] = {device_memory, device_ordinal, + /*ref_count=*/1}; + } else { + it->second.ref_count++; + } } -AllocationTracker::HandleMap& AllocationTracker::GetOrCreateOpaqueToHandleMap( - int device_ordinal) { - if (opaque_to_handle_.size() <= device_ordinal) { - opaque_to_handle_.resize(device_ordinal + 1); +Status AllocationTracker::DecrementRefCount( + perftools::gputools::DeviceMemoryBase device_memory, int device_ordinal) { + AllocationMap& allocation_map = opaque_to_allocation_map_[device_ordinal]; + auto it = allocation_map.find(device_memory.opaque()); + TF_RET_CHECK(it != allocation_map.end()); + Allocation& allocation = it->second; + TF_RET_CHECK(allocation.ref_count >= 1); + if (allocation.ref_count == 1) { + TF_RETURN_IF_ERROR(backend_->memory_allocator()->Deallocate( + device_ordinal, &device_memory)); + allocation_map.erase(it); + } else { + allocation.ref_count--; } - return opaque_to_handle_[device_ordinal]; + return tensorflow::Status::OK(); } } // namespace xla diff --git 
a/tensorflow/compiler/xla/service/allocation_tracker.h b/tensorflow/compiler/xla/service/allocation_tracker.h index ebbf35b6fe87bc7322ccb99cfe8f8eed56de06b3..807af8694972083d097604a67ee46d2f73d9545a 100644 --- a/tensorflow/compiler/xla/service/allocation_tracker.h +++ b/tensorflow/compiler/xla/service/allocation_tracker.h @@ -28,147 +28,92 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/thread_annotations.h" #include "tensorflow/core/platform/types.h" namespace xla { -// A global allocation in device space, tracked by the XLA service. -class Allocation { - public: - Allocation(Backend* backend, int device_ordinal, - perftools::gputools::DeviceMemoryBase device_memory, - const Shape& shape, const string& tag, int initial_ref_count) - : backend_(backend), - device_ordinal_(device_ordinal), - device_memory_(device_memory), - shape_(shape), - tag_(tag), - ref_count_(initial_ref_count) {} - - Backend* backend() const { return backend_; } - int device_ordinal() const { return device_ordinal_; } - perftools::gputools::DeviceMemoryBase device_memory() const { - return device_memory_; - } - const Shape& shape() const { return shape_; } - const string& tag() const { return tag_; } - - bool is_deallocated() const { - CHECK_GE(ref_count_, 0); - return ref_count_ == 0; - } - int ref_count() const { - CHECK_GE(ref_count_, 0); - return ref_count_; - } - void increment_ref_count(int inc) { - CHECK_GT(ref_count_, 0); - CHECK_LE(ref_count_, INT_MAX - inc); - ref_count_ += inc; - } - void decrement_ref_count() { - CHECK_GT(ref_count_, 0); - --ref_count_; - } - perftools::gputools::DeviceMemoryBase* mutable_device_memory() { - return &device_memory_; - } - - private: - // The backend that the memory is allocated on. 
- Backend* backend_; - - // The device that the memory is allocated on. - int device_ordinal_; - - // The pointer to this allocation. - perftools::gputools::DeviceMemoryBase device_memory_; - - // The shape of this allocation. - Shape shape_; - - // An informal description of this allocation shown in tools. - string tag_; - - // This is the number of Allocation objects which refer to this memory - // allocation. - int ref_count_; - - // Return a string representation of this allocation for debugging or logging - // purposes. - string ToString() const; -}; - // Tracks allocations for the XLA service; allocations can be registered // with shape/device/tag and resolved from a handle for later use. class AllocationTracker { public: - AllocationTracker(); + // The allocator is used for deallocating memory when allocations are + // deregistered. All registered allocations must have the same platform as the + // allocator. + AllocationTracker(Backend* backend) : backend_(backend), next_handle_(1) {} - // Registers device memory with a given shape, device identifier, and tag, and - // returns a corresponding handle that can be used for talking to XLA - // clients. - GlobalDataHandle Register(Backend* backend, int device_ordinal, - perftools::gputools::DeviceMemoryBase device_memory, - const Shape& shape, const string& tag); + // Registers a shaped buffer of device memory, and returns a corresponding + // handle that can be used for talking to XLA clients. + StatusOr Register( + std::unique_ptr shaped_buffer, const string& tag); // Unregister the allocation for the given data handle. - tensorflow::Status Unregister(const GlobalDataHandle& data); + Status Unregister(const GlobalDataHandle& data); // Returns a vector of global data handles that point to the tuple elements. 
StatusOr> DeconstructTuple( const GlobalDataHandle& Data); - // Resolve a handle from an XLA client to an allocation, or provide an - // error status to say whether it was not found (or found, but found - // deallocated). - StatusOr Resolve(const GlobalDataHandle& data); + // Resolve a handle from an XLA client to a shaped buffer, or provide an error + // status to say whether it was not found (or found, but found deallocated). + StatusOr Resolve(const GlobalDataHandle& data); private: - // Internal helper which resolves the given GlobalDataHandle to an Allocation. - StatusOr ResolveInternal(const GlobalDataHandle& data) - EXCLUSIVE_LOCKS_REQUIRED(allocation_mutex_); - - GlobalDataHandle RegisterInternal( - Backend* backend, int device_ordinal, - perftools::gputools::DeviceMemoryBase device_memory, const Shape& shape, - const string& tag, int initial_ref_count) - EXCLUSIVE_LOCKS_REQUIRED(allocation_mutex_); - - // Helper function which deallocates the memory buffer containing the given - // shape referred to by device_memory. Tuples are traversed recursively - // deallocating all nested buffers. The parameter deallocated_buffers contains - // the set of buffers deallocated so far stored as opaque values (void *) from - // DeviceMemoryBase. Keeping track of deallocated buffers prevents - // double-freeing of buffers which may be referred to more than once in a - // nested tuple. - tensorflow::Status DeallocateShape( - Backend* backend, int device_ordinal, - perftools::gputools::DeviceMemoryBase* device_memory, const Shape& shape, - std::set* deallocated_buffers) - EXCLUSIVE_LOCKS_REQUIRED(allocation_mutex_); - - // Returns the opaque_to_handle_ map for the given device_ordinal, creating - // a new map if there is not one for the device_ordinal. - using HandleMap = std::map; - HandleMap& GetOrCreateOpaqueToHandleMap(int device_ordinal) - EXCLUSIVE_LOCKS_REQUIRED(allocation_mutex_); - - tensorflow::mutex allocation_mutex_; // Guards the allocation mapping. 
+ // Data structure encapsulating single memory allocation on the device. + struct Allocation { + // The pointer to this allocation. + perftools::gputools::DeviceMemoryBase device_memory; + + // The device that the memory is allocated on. + int device_ordinal; + + // This is the number of times this memory allocation is referred to by + // registered data handles. + int ref_count; + }; + + // Internal helper which resolves the given GlobalDataHandle to a + // ShapedBuffer. + StatusOr ResolveInternal(const GlobalDataHandle& data) + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Internal helper which registers a shaped buffer. + StatusOr RegisterInternal( + std::unique_ptr shaped_buffer, const string& tag) + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Adds the given device address to the allocation tracker, or if it already + // exists, then increment it's reference count. + void AddAllocationOrIncrementRefCount( + perftools::gputools::DeviceMemoryBase device_memory, int device_ordinal) + EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // Decrements the reference count of the given device memory. Then, if it is + // zero, deallocate the memory. + Status DecrementRefCount(perftools::gputools::DeviceMemoryBase device_memory, + int device_ordinal) EXCLUSIVE_LOCKS_REQUIRED(mutex_); + + // A map from device memory opaque value to allocation. One such map is + // maintained per device ordinal. + using AllocationMap = tensorflow::gtl::FlatMap; + + tensorflow::mutex mutex_; + + // Backend to use with this tracker. The backend supplies the memory allocator + // to use when deallocating memory. + Backend* backend_; // The next handle to assign to an allocation, guarded by the same mutex as // the mapping as they'll be mutated at the same time. - int64 next_handle_ GUARDED_BY(allocation_mutex_); + int64 next_handle_ GUARDED_BY(mutex_); - // A map from DeviceMemoryBase to handle for each device_ordinal. 
- std::vector opaque_to_handle_ GUARDED_BY(allocation_mutex_); + // A map from device ordinal to AllocationMap. + tensorflow::gtl::FlatMap opaque_to_allocation_map_ + GUARDED_BY(mutex_); - // Mapping from GlobalDataHandle handle to the corresponding registered - // Allocation object. - std::map> handle_to_allocation_ - GUARDED_BY(allocation_mutex_); + // A map from data handle to ShapedBuffer. + tensorflow::gtl::FlatMap> + handle_to_shaped_buffer_ GUARDED_BY(mutex_); TF_DISALLOW_COPY_AND_ASSIGN(AllocationTracker); }; diff --git a/tensorflow/compiler/xla/service/batchnorm_rewriter.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc similarity index 58% rename from tensorflow/compiler/xla/service/batchnorm_rewriter.cc rename to tensorflow/compiler/xla/service/batchnorm_expander.cc index c6193b3fbbd651088a823605af3ba84bca4a77ee..27ddfd47aa3096afd3e245af1ac3cedd9b48ce4a 100644 --- a/tensorflow/compiler/xla/service/batchnorm_rewriter.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/batchnorm_rewriter.h" +#include "tensorflow/compiler/xla/service/batchnorm_expander.h" #include #include @@ -45,9 +45,9 @@ limitations under the License. namespace xla { -// BatchNormRewriterVisitor traverses the HLO computation and rewrites BatchNorm +// BatchNormExpanderVisitor traverses the HLO computation and rewrites BatchNorm // operations into smaller operations. -class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault { +class BatchNormExpanderVisitor : public DfsHloVisitorWithDefault { public: // Default visitor action is to do nothing and return OK. 
Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { @@ -68,10 +68,10 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault { // Returns whether any batch norm ops were rewritten. const bool changed() const { return changed_; } - ~BatchNormRewriterVisitor() override = default; + ~BatchNormExpanderVisitor() override = default; private: - explicit BatchNormRewriterVisitor(HloComputation* computation, + explicit BatchNormExpanderVisitor(HloComputation* computation, bool rewrite_training_op, bool rewrite_inference_op, bool rewrite_grad_op, bool use_fusion) @@ -94,7 +94,7 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault { return computation_->parent()->AddEmbeddedComputation(b.Build(scalar_op)); } - // Current HloComputation instance the BatchNormRewriter is + // Current HloComputation instance the BatchNormExpander is // traversing. HloComputation* computation_; @@ -130,11 +130,11 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault { } }; -bool BatchNormRewriterVisitor::Run(HloComputation* computation, +bool BatchNormExpanderVisitor::Run(HloComputation* computation, bool rewrite_training_op, bool rewrite_inference_op, bool rewrite_grad_op, bool use_fusion) { - BatchNormRewriterVisitor visitor( + BatchNormExpanderVisitor visitor( computation, /*rewrite_training_op=*/rewrite_training_op, /*rewrite_inference_op=*/rewrite_inference_op, @@ -144,11 +144,20 @@ bool BatchNormRewriterVisitor::Run(HloComputation* computation, return visitor.changed_; } -Status BatchNormRewriterVisitor::HandleBatchNormTraining( +Status BatchNormExpanderVisitor::HandleBatchNormTraining( HloInstruction* batch_norm) { if (!rewrite_training_op_) { return Status::OK(); } + + std::vector added_instructions; + auto add = [&](std::unique_ptr inst) { + HloInstruction* added_inst = computation_->AddInstruction(std::move(inst)); + added_instructions.push_back(added_inst); + return added_inst; + }; + int64 instruction_count_before = 
computation_->instruction_count(); + // Expand batch norm training into smaller HLO ops. HloInstruction* operand = batch_norm->mutable_operand(0); const Shape operand_shape = operand->shape(); @@ -160,7 +169,7 @@ Status BatchNormRewriterVisitor::HandleBatchNormTraining( Literal::CreateR0(size_in_elements / feature_count); TF_ASSIGN_OR_RETURN(elements_per_feature_literal, elements_per_feature_literal->Convert(ptype)); - auto elements_per_feature = computation_->AddInstruction( + auto elements_per_feature = add( HloInstruction::CreateConstant(std::move(elements_per_feature_literal))); HloInstruction* scale = batch_norm->mutable_operand(1); @@ -169,14 +178,12 @@ Status BatchNormRewriterVisitor::HandleBatchNormTraining( auto zero_literal = Literal::CreateR0(0.0f); TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype)); - auto zero = computation_->AddInstruction( - HloInstruction::CreateConstant(std::move(zero_literal))); + auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal))); auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon()); TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); - auto epsilon = computation_->AddInstruction( - HloInstruction::CreateConstant(std::move(epsilon_literal))); - + auto epsilon = + add(HloInstruction::CreateConstant(std::move(epsilon_literal))); std::vector dimensions_without_feature; for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) { @@ -185,109 +192,116 @@ Status BatchNormRewriterVisitor::HandleBatchNormTraining( } } - auto scale_broadcasted = computation_->AddInstruction( + auto scale_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index})); - auto offset_broadcasted = computation_->AddInstruction( + auto offset_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index})); HloComputation* add_reduce_computation = GetScalarBinaryComputation(ptype, HloOpcode::kAdd); // X^2. 
- auto operand_squared = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kMultiply, operand, operand)); + auto operand_squared = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kMultiply, operand, operand)); // Sum[X]. - auto sum = computation_->AddInstruction(HloInstruction::CreateReduce( - feature_shape, operand, zero, dimensions_without_feature, - add_reduce_computation)); + auto sum = add(HloInstruction::CreateReduce(feature_shape, operand, zero, + dimensions_without_feature, + add_reduce_computation)); // Sum[X^2]. - auto squared_sum = computation_->AddInstruction(HloInstruction::CreateReduce( + auto squared_sum = add(HloInstruction::CreateReduce( feature_shape, operand_squared, zero, dimensions_without_feature, add_reduce_computation)); // Fuse two parallel reduces together to improve performance. - if (use_fusion_) { - auto tuple = computation_->AddInstruction( - HloInstruction::CreateTuple({sum, squared_sum})); + if (use_fusion_ && !batch_norm->has_sharding()) { + auto tuple = add(HloInstruction::CreateTuple({sum, squared_sum})); auto fused = computation_->CreateFusionInstruction( {tuple, sum, squared_sum, operand_squared}, HloInstruction::FusionKind::kInput); - sum = computation_->AddInstruction( - HloInstruction::CreateGetTupleElement(feature_shape, fused, 0)); + sum = add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 0)); - squared_sum = computation_->AddInstruction( - HloInstruction::CreateGetTupleElement(feature_shape, fused, 1)); + squared_sum = + add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 1)); } // E[X]. - auto mean = computation_->AddInstruction(HloInstruction::CreateBinary( + auto mean = add(HloInstruction::CreateBinary( feature_shape, HloOpcode::kDivide, sum, elements_per_feature)); - auto mean_broadcasted = computation_->AddInstruction( + auto mean_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index})); // E[X^2]. 
- auto square_mean = computation_->AddInstruction(HloInstruction::CreateBinary( + auto square_mean = add(HloInstruction::CreateBinary( feature_shape, HloOpcode::kDivide, squared_sum, elements_per_feature)); // E^2[X]. - auto mean_square = computation_->AddInstruction(HloInstruction::CreateBinary( + auto mean_square = add(HloInstruction::CreateBinary( feature_shape, HloOpcode::kMultiply, mean, mean)); // Var[X]. - auto var = computation_->AddInstruction(HloInstruction::CreateBinary( + auto var = add(HloInstruction::CreateBinary( feature_shape, HloOpcode::kSubtract, square_mean, mean_square)); - auto var_broadcasted = computation_->AddInstruction( - HloInstruction::CreateBroadcast(operand_shape, var, {feature_index})); + auto var_broadcasted = + add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index})); // Var[X] + epsilon. - auto var_add_epsilon = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon)); + auto var_add_epsilon = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon)); auto neg_half_literal = Literal::CreateR0(-0.5f); TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype)); - auto neg_half = computation_->AddInstruction( - HloInstruction::CreateConstant(std::move(neg_half_literal))); + auto neg_half = + add(HloInstruction::CreateConstant(std::move(neg_half_literal))); // 1 / Sqrt[Var[X] + epsilon]. - auto rsqrt_var_add_epsilon = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half)); + auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half)); // X - E[X]. 
- auto operand_minus_mean = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted)); + auto operand_minus_mean = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted)); // (X - E[X]) / Sqrt[Var[X] + epsilon]. - auto normalized = computation_->AddInstruction( + auto normalized = add( HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply, operand_minus_mean, rsqrt_var_add_epsilon)); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale. - auto scaled_normalized = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted)); + auto scaled_normalized = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted)); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset. - auto shifted_normalized = computation_->AddInstruction( - HloInstruction::CreateBinary(operand_shape, HloOpcode::kAdd, - scaled_normalized, offset_broadcasted)); - - TF_CHECK_OK(ReplaceWithNewInstruction( - batch_norm, - HloInstruction::CreateTuple({shifted_normalized, mean, var}))); + auto shifted_normalized = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kAdd, scaled_normalized, offset_broadcasted)); + + auto tuple = HloInstruction::CreateTuple({shifted_normalized, mean, var}); + + if (batch_norm->has_sharding()) { + int64 instruction_count_after = computation_->instruction_count(); + CHECK_EQ(instruction_count_after, + instruction_count_before + added_instructions.size()); + HloSharding operand_sharding = + batch_norm->sharding().GetAsShapeTree(batch_norm->shape()).element({0}); + for (HloInstruction* inst : added_instructions) { + if (ShapeUtil::Equal(inst->shape(), operand_shape)) { + inst->set_sharding(operand_sharding); + } else { + inst->set_sharding(HloSharding::Replicate()); + } + } + tuple->set_sharding(batch_norm->sharding()); + } + 
TF_CHECK_OK(ReplaceWithNewInstruction(batch_norm, std::move(tuple))); return Status::OK(); } -Status BatchNormRewriterVisitor::HandleBatchNormInference( +Status BatchNormExpanderVisitor::HandleBatchNormInference( HloInstruction* batch_norm) { if (!rewrite_inference_op_) { return Status::OK(); @@ -317,58 +331,75 @@ Status BatchNormRewriterVisitor::HandleBatchNormInference( } } - auto scale_broadcasted = computation_->AddInstruction( + std::vector added_instructions; + auto add = [&](std::unique_ptr inst) { + HloInstruction* added_inst = computation_->AddInstruction(std::move(inst)); + added_instructions.push_back(added_inst); + return added_inst; + }; + int64 instruction_count_before = computation_->instruction_count(); + + auto scale_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index})); - auto offset_broadcasted = computation_->AddInstruction( + auto offset_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index})); - auto mean_broadcasted = computation_->AddInstruction( + auto mean_broadcasted = add( HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index})); - auto var_broadcasted = computation_->AddInstruction( - HloInstruction::CreateBroadcast(operand_shape, var, {feature_index})); + auto var_broadcasted = + add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index})); // Var[X] + epsilon. 
- auto var_add_epsilon = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon)); + auto var_add_epsilon = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon)); auto neg_half_literal = Literal::CreateR0(-0.5f); TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype)); - auto neg_half = computation_->AddInstruction( - HloInstruction::CreateConstant(std::move(neg_half_literal))); + auto neg_half = + add(HloInstruction::CreateConstant(std::move(neg_half_literal))); // 1 / Sqrt[Var[X] + epsilon]. - auto rsqrt_var_add_epsilon = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half)); + auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half)); // X - E[X]. - auto operand_minus_mean = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted)); + auto operand_minus_mean = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted)); // (X - E[X]) / Sqrt[Var[X] + epsilon]. - auto normalized = computation_->AddInstruction( + auto normalized = add( HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply, operand_minus_mean, rsqrt_var_add_epsilon)); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale. - auto scaled_normalized = - computation_->AddInstruction(HloInstruction::CreateBinary( - operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted)); + auto scaled_normalized = add(HloInstruction::CreateBinary( + operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted)); // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset. 
auto shifted_normalized = HloInstruction::CreateBinary( operand_shape, HloOpcode::kAdd, scaled_normalized, offset_broadcasted); + int64 instruction_count_after = computation_->instruction_count(); + CHECK_EQ(instruction_count_after, + instruction_count_before + added_instructions.size()); + if (batch_norm->has_sharding()) { + for (HloInstruction* inst : added_instructions) { + if (ShapeUtil::Equal(inst->shape(), operand_shape)) { + inst->set_sharding(batch_norm->sharding()); + } else { + inst->set_sharding(HloSharding::Replicate()); + } + } + shifted_normalized->set_sharding(batch_norm->sharding()); + } TF_CHECK_OK( ReplaceWithNewInstruction(batch_norm, std::move(shifted_normalized))); return Status::OK(); } -Status BatchNormRewriterVisitor::HandleBatchNormGrad( +Status BatchNormExpanderVisitor::HandleBatchNormGrad( HloInstruction* batch_norm) { // Use the following formulas to calculate gradients: // scale_grad = @@ -385,6 +416,13 @@ Status BatchNormRewriterVisitor::HandleBatchNormGrad( if (!rewrite_grad_op_) { return Status::OK(); } + std::vector added_instructions; + auto add = [&](std::unique_ptr inst) { + HloInstruction* added_inst = computation_->AddInstruction(std::move(inst)); + added_instructions.push_back(added_inst); + return added_inst; + }; + int64 instruction_count_before = computation_->instruction_count(); HloInstruction* activation = batch_norm->mutable_operand(0); const Shape activation_shape = activation->shape(); @@ -403,23 +441,22 @@ Status BatchNormRewriterVisitor::HandleBatchNormGrad( Literal::CreateR0(size_in_elements / feature_count); TF_ASSIGN_OR_RETURN(elements_per_feature_literal, elements_per_feature_literal->Convert(ptype)); - auto elements_per_feature = computation_->AddInstruction( + auto elements_per_feature = add( HloInstruction::CreateConstant(std::move(elements_per_feature_literal))); auto zero_literal = Literal::CreateR0(0.0f); TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype)); - auto zero = 
computation_->AddInstruction( - HloInstruction::CreateConstant(std::move(zero_literal))); + auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal))); auto neg_half_literal = Literal::CreateR0(-0.5f); TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype)); - auto neg_half = computation_->AddInstruction( - HloInstruction::CreateConstant(std::move(neg_half_literal))); + auto neg_half = + add(HloInstruction::CreateConstant(std::move(neg_half_literal))); auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon()); TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype)); - auto epsilon = computation_->AddInstruction( - HloInstruction::CreateConstant(std::move(epsilon_literal))); + auto epsilon = + add(HloInstruction::CreateConstant(std::move(epsilon_literal))); std::vector dimensions_without_feature; @@ -429,141 +466,148 @@ Status BatchNormRewriterVisitor::HandleBatchNormGrad( } } - auto scale_broadcasted = - computation_->AddInstruction(HloInstruction::CreateBroadcast( - activation_shape, scale, {feature_index})); - auto variance_broadcasted = - computation_->AddInstruction(HloInstruction::CreateBroadcast( - activation_shape, variance, {feature_index})); + auto scale_broadcasted = add(HloInstruction::CreateBroadcast( + activation_shape, scale, {feature_index})); + auto variance_broadcasted = add(HloInstruction::CreateBroadcast( + activation_shape, variance, {feature_index})); // E[X]. - auto mean_broadcasted = computation_->AddInstruction( + auto mean_broadcasted = add( HloInstruction::CreateBroadcast(activation_shape, mean, {feature_index})); // rsqrt[Var[X] + epsilon]. 
- auto rsqrt_var_add_epsilon_broadcasted = - computation_->AddInstruction(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kPower, - computation_->AddInstruction( - HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd, - variance_broadcasted, epsilon)), - neg_half)); - - auto rsqrt_var_add_epsilon = - computation_->AddInstruction(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kPower, - computation_->AddInstruction(HloInstruction::CreateBinary( - feature_shape, HloOpcode::kAdd, variance, epsilon)), - neg_half)); + auto rsqrt_var_add_epsilon_broadcasted = add(HloInstruction::CreateBinary( + activation_shape, HloOpcode::kPower, + add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd, + variance_broadcasted, epsilon)), + neg_half)); + + auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( + feature_shape, HloOpcode::kPower, + add(HloInstruction::CreateBinary(feature_shape, HloOpcode::kAdd, variance, + epsilon)), + neg_half)); // X - E[X]. - auto activation_minus_mean = computation_->AddInstruction( - HloInstruction::CreateBinary(activation_shape, HloOpcode::kSubtract, - activation, mean_broadcasted)); + auto activation_minus_mean = add(HloInstruction::CreateBinary( + activation_shape, HloOpcode::kSubtract, activation, mean_broadcasted)); // Grad[Y] * (X - E[X]). - auto grad_output_times_activiation_minus_mean = computation_->AddInstruction( - HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, - grad_output, activation_minus_mean)); + auto grad_output_times_activiation_minus_mean = + add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, + grad_output, activation_minus_mean)); HloComputation* add_reduce_computation = GetScalarBinaryComputation(ptype, HloOpcode::kAdd); // sum(Grad[Y] * (X - E[X])). 
auto sum_grad_output_times_activiation_minus_mean = - computation_->AddInstruction(HloInstruction::CreateReduce( + add(HloInstruction::CreateReduce( feature_shape, grad_output_times_activiation_minus_mean, zero, dimensions_without_feature, add_reduce_computation)); // Grad[beta] = Sum(Grad[Y]). - auto grad_beta = computation_->AddInstruction(HloInstruction::CreateReduce( + auto grad_beta = add(HloInstruction::CreateReduce( feature_shape, grad_output, zero, dimensions_without_feature, add_reduce_computation)); - if (use_fusion_) { - auto tuple = computation_->AddInstruction(HloInstruction::CreateTuple( + if (use_fusion_ && !batch_norm->has_sharding()) { + auto tuple = add(HloInstruction::CreateTuple( {sum_grad_output_times_activiation_minus_mean, grad_beta})); auto fused = computation_->CreateFusionInstruction( {tuple, sum_grad_output_times_activiation_minus_mean, grad_beta}, HloInstruction::FusionKind::kInput); - sum_grad_output_times_activiation_minus_mean = computation_->AddInstruction( - HloInstruction::CreateGetTupleElement(feature_shape, fused, 0)); + sum_grad_output_times_activiation_minus_mean = + add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 0)); - grad_beta = computation_->AddInstruction( - HloInstruction::CreateGetTupleElement(feature_shape, fused, 1)); + grad_beta = + add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 1)); } // Grad[scale] = Sum(Grad[Y] * (X - E[X]) * rsqrt[Var[X] + epsilon]). 
- auto grad_scale = computation_->AddInstruction(HloInstruction::CreateBinary( + auto grad_scale = add(HloInstruction::CreateBinary( feature_shape, HloOpcode::kMultiply, sum_grad_output_times_activiation_minus_mean, rsqrt_var_add_epsilon)); // I2 = Sum(Grad[Y]) - auto I2 = computation_->AddInstruction(HloInstruction::CreateBroadcast( - activation_shape, grad_beta, {feature_index})); + auto i2 = add(HloInstruction::CreateBroadcast(activation_shape, grad_beta, + {feature_index})); // I3 = Sum(Grad[Y] * (X - E[X])) - auto I3 = computation_->AddInstruction(HloInstruction::CreateBroadcast( + auto i3 = add(HloInstruction::CreateBroadcast( activation_shape, sum_grad_output_times_activiation_minus_mean, {feature_index})); // I4 = (X - E[X]) * I3 - auto I4 = computation_->AddInstruction(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kMultiply, I3, activation_minus_mean)); + auto i4 = add(HloInstruction::CreateBinary( + activation_shape, HloOpcode::kMultiply, i3, activation_minus_mean)); // I5 = I4 / (Var[X] + epsilon) - auto I5 = computation_->AddInstruction(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kDivide, I4, - computation_->AddInstruction(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kAdd, variance_broadcasted, epsilon)))); + auto i5 = add(HloInstruction::CreateBinary( + activation_shape, HloOpcode::kDivide, i4, + add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd, + variance_broadcasted, epsilon)))); // scale * rsqrt[Var[X] + epsilon] * 1/N - auto scale_times_rsqrt_var_add_epsilon = - computation_->AddInstruction(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kMultiply, scale_broadcasted, - rsqrt_var_add_epsilon_broadcasted)); + auto scale_times_rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( + activation_shape, HloOpcode::kMultiply, scale_broadcasted, + rsqrt_var_add_epsilon_broadcasted)); - scale_times_rsqrt_var_add_epsilon = - 
computation_->AddInstruction(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kDivide, - scale_times_rsqrt_var_add_epsilon, elements_per_feature)); + scale_times_rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary( + activation_shape, HloOpcode::kDivide, scale_times_rsqrt_var_add_epsilon, + elements_per_feature)); - auto I1 = computation_->AddInstruction( - HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, - grad_output, elements_per_feature)); + auto i1 = + add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, + grad_output, elements_per_feature)); // I6 = I1 - I2 - I5 - auto I6 = computation_->AddInstruction(HloInstruction::CreateBinary( + auto i6 = add(HloInstruction::CreateBinary( activation_shape, HloOpcode::kSubtract, - computation_->AddInstruction(HloInstruction::CreateBinary( - activation_shape, HloOpcode::kSubtract, I1, I2)), - I5)); + add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kSubtract, + i1, i2)), + i5)); // Grad[X] = scale * rsqrt[Var[X] + epsilon] * 1/N * I6. 
- auto grad_activation = computation_->AddInstruction( - HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, - scale_times_rsqrt_var_add_epsilon, I6)); + auto grad_activation = + add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply, + scale_times_rsqrt_var_add_epsilon, i6)); + auto tuple = + HloInstruction::CreateTuple({grad_activation, grad_scale, grad_beta}); + if (batch_norm->has_sharding()) { + int64 instruction_count_after = computation_->instruction_count(); + CHECK_EQ(instruction_count_after, + instruction_count_before + added_instructions.size()); + HloSharding activation_sharding = + batch_norm->sharding().GetAsShapeTree(batch_norm->shape()).element({0}); + for (HloInstruction* inst : added_instructions) { + if (ShapeUtil::Equal(inst->shape(), activation_shape)) { + inst->set_sharding(activation_sharding); + } else { + inst->set_sharding(HloSharding::Replicate()); + } + } + tuple->set_sharding(batch_norm->sharding()); + } - TF_CHECK_OK(ReplaceWithNewInstruction( - batch_norm, - HloInstruction::CreateTuple({grad_activation, grad_scale, grad_beta}))); + TF_CHECK_OK(ReplaceWithNewInstruction(batch_norm, std::move(tuple))); return Status::OK(); } -StatusOr BatchNormRewriter::Run(HloModule* module) { - XLA_VLOG_LINES(2, "BatchNormRewriter::Run(), before:\n" + module->ToString()); +StatusOr BatchNormExpander::Run(HloModule* module) { + XLA_VLOG_LINES(2, "BatchNormExpander::Run(), before:\n" + module->ToString()); bool changed = false; for (auto* comp : module->MakeNonfusionComputations()) { - if (BatchNormRewriterVisitor::Run(comp, rewrite_training_op_, + if (BatchNormExpanderVisitor::Run(comp, rewrite_training_op_, rewrite_inference_op_, rewrite_grad_op_, use_fusion_)) { changed = true; } } - XLA_VLOG_LINES(2, "BatchNormRewriter::Run(), after:\n" + module->ToString()); + XLA_VLOG_LINES(2, "BatchNormExpander::Run(), after:\n" + module->ToString()); return changed; } diff --git 
a/tensorflow/compiler/xla/service/batchnorm_rewriter.h b/tensorflow/compiler/xla/service/batchnorm_expander.h similarity index 83% rename from tensorflow/compiler/xla/service/batchnorm_rewriter.h rename to tensorflow/compiler/xla/service/batchnorm_expander.h index f601741d964376058a2bafade311ede4c8567fd2..4ad987085da91684bb7891070afeefd19be4138f 100644 --- a/tensorflow/compiler/xla/service/batchnorm_rewriter.h +++ b/tensorflow/compiler/xla/service/batchnorm_expander.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BATCHNORM_REWRITER_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_BATCHNORM_REWRITER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BATCHNORM_EXPANDER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_BATCHNORM_EXPANDER_H_ #include @@ -26,18 +26,18 @@ namespace xla { // A pass which rewrites batch norm operations into more operations. Breaking a // big operation into smaller operations helps leverage our generic fusion // logic. -class BatchNormRewriter : public HloPassInterface { +class BatchNormExpander : public HloPassInterface { public: // When use_fusion is set, a multi-output fusion node is created. - BatchNormRewriter(bool rewrite_training_op = false, + BatchNormExpander(bool rewrite_training_op = false, bool rewrite_inference_op = false, bool rewrite_grad_op = false, bool use_fusion = true) : rewrite_training_op_(rewrite_training_op), rewrite_inference_op_(rewrite_inference_op), rewrite_grad_op_(rewrite_grad_op), use_fusion_(use_fusion) {} - ~BatchNormRewriter() = default; - tensorflow::StringPiece name() const override { return "batchnorm_rewriter"; } + ~BatchNormExpander() = default; + tensorflow::StringPiece name() const override { return "batchnorm_expander"; } // Run operation expander on the given computation. Returns whether the // computation was changed. 
@@ -52,4 +52,4 @@ class BatchNormRewriter : public HloPassInterface { } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BATCHNORM_REWRITER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BATCHNORM_EXPANDER_H_ diff --git a/tensorflow/compiler/xla/service/batchnorm_rewriter_test.cc b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc similarity index 93% rename from tensorflow/compiler/xla/service/batchnorm_rewriter_test.cc rename to tensorflow/compiler/xla/service/batchnorm_expander_test.cc index 590f79aee51ccf410823b91fd8ad09fc7c429c7d..aa36e64b07099a372dab67babc7a18a2d39596bc 100644 --- a/tensorflow/compiler/xla/service/batchnorm_rewriter_test.cc +++ b/tensorflow/compiler/xla/service/batchnorm_expander_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/batchnorm_rewriter.h" +#include "tensorflow/compiler/xla/service/batchnorm_expander.h" #include #include @@ -36,10 +36,10 @@ limitations under the License. namespace xla { namespace { -using BatchNormRewriterTest = HloTestBase; +using BatchNormExpanderTest = HloTestBase; // Test that we expand BatchNormTraining. 
-TEST_F(BatchNormRewriterTest, BatchNormTraining) { +TEST_F(BatchNormExpanderTest, BatchNormTraining) { Shape input_shape = ShapeUtil::MakeShape(F32, {2, 2, 2, 2}); Shape scale_shape = ShapeUtil::MakeShape(F32, {2}); Shape offset_shape = ShapeUtil::MakeShape(F32, {2}); @@ -63,7 +63,7 @@ TEST_F(BatchNormRewriterTest, BatchNormTraining) { auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kBatchNormTraining); - BatchNormRewriter rewriter(/*rewrite_training_op=*/true, + BatchNormExpander rewriter(/*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie()); @@ -73,7 +73,7 @@ TEST_F(BatchNormRewriterTest, BatchNormTraining) { } // Test that we expand BatchNormGrad. -TEST_F(BatchNormRewriterTest, BatchNormGrad) { +TEST_F(BatchNormExpanderTest, BatchNormGrad) { Shape input_shape = ShapeUtil::MakeShape(F32, {2, 2, 2, 2}); Shape scale_shape = ShapeUtil::MakeShape(F32, {2}); Shape mean_shape = ShapeUtil::MakeShape(F32, {2}); @@ -105,7 +105,7 @@ TEST_F(BatchNormRewriterTest, BatchNormGrad) { auto computation = module->AddEntryComputation(builder.Build()); HloInstruction* root = computation->root_instruction(); EXPECT_EQ(root->opcode(), HloOpcode::kBatchNormGrad); - BatchNormRewriter rewriter(/*rewrite_training_op=*/true, + BatchNormExpander rewriter(/*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true); ASSERT_TRUE(rewriter.Run(module.get()).ValueOrDie()); diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc new file mode 100644 index 0000000000000000000000000000000000000000..cde990e176ddb57a8e93ecc3c60260b2dbae32a8 --- /dev/null +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.cc @@ -0,0 +1,184 @@ +/* Copyright 2018 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/bfloat16_conversion_folding.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +class BFloat16ConversionFoldingVisitor : public DfsHloVisitorWithDefault { + public: + explicit BFloat16ConversionFoldingVisitor( + HloComputation* computation, const BFloat16Support* bfloat16_support) + : computation_(computation), bfloat16_support_(bfloat16_support) {} + + Status DefaultAction(HloInstruction* hlo) override; + + static bool Run(HloComputation* computation, + const BFloat16Support* bfloat16_support) { + BFloat16ConversionFoldingVisitor visitor(computation, bfloat16_support); + TF_CHECK_OK(computation->Accept(&visitor)); + return visitor.changed_; + } + + private: + // Checks if the HLO has a BF16 -> F32 conversion as input, or a F32 -> BF16 + // conversion as output, and folds them to the HLO itself if feasible. 
+ Status TryFoldBF16Conversions(HloInstruction* hlo); + + // Folds the F32 -> BF16 conversions from the HLO's output. + // + // Precondition: all of the HLO's users are F32 -> BF16 conversions. + Status FoldOutputConversions(HloInstruction* hlo); + + // Folds the BF16 -> F32 conversion operand to the HLO. + // + // Precondition: the operand is a BF16 -> F32 conversion. + Status FoldOperandConversion(HloInstruction* hlo, int64 operand_index); + + HloComputation* computation_; + const BFloat16Support* bfloat16_support_; + bool changed_ = false; +}; + +Status BFloat16ConversionFoldingVisitor::FoldOutputConversions( + HloInstruction* hlo) { + std::vector materialized_users = hlo->users(); + hlo->mutable_shape()->set_element_type(BF16); + for (auto user : materialized_users) { + CHECK_EQ(user->opcode(), HloOpcode::kConvert); + TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(hlo)); + changed_ = true; + } + return Status::OK(); +} + +Status BFloat16ConversionFoldingVisitor::FoldOperandConversion( + HloInstruction* hlo, int64 operand_index) { + // The operand is a convert from BF16 to F32. + auto operand = hlo->mutable_operand(operand_index); + CHECK_EQ(operand->opcode(), HloOpcode::kConvert); + TF_RETURN_IF_ERROR( + hlo->ReplaceOperandWith(operand_index, operand->mutable_operand(0))); + changed_ = true; + return Status::OK(); +} + +Status BFloat16ConversionFoldingVisitor::TryFoldBF16Conversions( + HloInstruction* hlo) { + std::vector bf16_to_f32_operands; + bool has_other_f32_operands = false; + for (int64 i = 0; i < hlo->operands().size(); ++i) { + auto operand = hlo->operand(i); + if (operand->shape().element_type() == F32) { + if (operand->opcode() == HloOpcode::kConvert && + operand->operand(0)->shape().element_type() == BF16 && + bfloat16_support_->SupportsBF16Operand(*hlo, i)) { + // Operand is a convert from BF16 to F32 and we support BF16 input + // directly in the current HLO at the operand index. 
+ bf16_to_f32_operands.push_back(i); + } else { + has_other_f32_operands = true; + } + continue; + } + } + + bool fold_output_conversion = hlo->user_count() > 0 && + hlo->shape().element_type() == F32 && + bfloat16_support_->SupportsBF16Output(*hlo) && + hlo != computation_->root_instruction(); + if (fold_output_conversion) { + for (auto user : hlo->users()) { + if (user->opcode() == HloOpcode::kConvert && + user->shape().element_type() == BF16) { + continue; + } + // We should not change the output type if any user is not a conversion + // from F32 to BF16. + fold_output_conversion = false; + break; + } + } + + if (!bfloat16_support_->SupportsMixedPrecisions(*hlo)) { + if (has_other_f32_operands || + (!fold_output_conversion && hlo->shape().element_type() == F32)) { + // Some of the operands/output will remain F32, but we cannot use mixed + // precisions, so we cannot do anything here. + return Status::OK(); + } + } + + if (fold_output_conversion) { + TF_RETURN_IF_ERROR(FoldOutputConversions(hlo)); + } + + for (int64 i : bf16_to_f32_operands) { + TF_RETURN_IF_ERROR(FoldOperandConversion(hlo, i)); + } + return Status::OK(); +} + +Status BFloat16ConversionFoldingVisitor::DefaultAction(HloInstruction* hlo) { + // Do not fold BF16 conversions for instructions related to tuples, entry and + // exit of a computation, fusion, convert, and control flow. 
+ if (hlo->opcode() == HloOpcode::kTuple || // + hlo->opcode() == HloOpcode::kGetTupleElement || // + hlo->opcode() == HloOpcode::kInfeed || // + hlo->opcode() == HloOpcode::kOutfeed || // + hlo->opcode() == HloOpcode::kConstant || // + hlo->opcode() == HloOpcode::kParameter || // + hlo->opcode() == HloOpcode::kFusion || // + hlo->opcode() == HloOpcode::kConvert || // + hlo->opcode() == HloOpcode::kCall || // + hlo->opcode() == HloOpcode::kCustomCall || // + hlo->opcode() == HloOpcode::kWhile || // + hlo->opcode() == HloOpcode::kConditional) { + return Status::OK(); + } + if (hlo == computation_->root_instruction() && + !bfloat16_support_->SupportsMixedPrecisions(*hlo)) { + // If hlo is the root instruction, we cannot change its output, so folding + // can only happen when it supports mixed precision so that we can change + // its operands. + return Status::OK(); + } + return TryFoldBF16Conversions(hlo); +} + +StatusOr BFloat16ConversionFolding::Run(HloModule* module) { + XLA_VLOG_LINES( + 2, "BFloat16ConversionFolding::Run(), before:\n" + module->ToString()); + bool changed = false; + for (auto* comp : module->MakeNonfusionComputations()) { + if (BFloat16ConversionFoldingVisitor::Run(comp, bfloat16_support_)) { + changed = true; + } + } + XLA_VLOG_LINES( + 2, "BFloat16ConversionFolding::Run(), after:\n" + module->ToString()); + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h new file mode 100644 index 0000000000000000000000000000000000000000..c9398387098fad84ba28735c30e426fedd9b0cb0 --- /dev/null +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding.h @@ -0,0 +1,52 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_CONVERSION_FOLDING_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_CONVERSION_FOLDING_H_ + +#include "tensorflow/compiler/xla/service/bfloat16_support.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// A pass which folds F32 <-> BF16 conversions to their operands or users, when +// it is supported by the backend. +// +// This pass follows the passed-in backend-specific BF16 support rules, but can +// introduce mixed precision in individual HLOs which breaks the assumption of +// some other HLO passes. So it should be used at the end of the HLO +// optimization pipeline followed by a DCE pass. If other passes are needed +// after this pass, run BFloat16MixedPrecisionRemoval first to undo some of the +// changes made by this pass. +class BFloat16ConversionFolding : public HloPassInterface { + public: + explicit BFloat16ConversionFolding(const BFloat16Support* bfloat16_support) + : bfloat16_support_(bfloat16_support) {} + + ~BFloat16ConversionFolding() override = default; + tensorflow::StringPiece name() const override { return "bfloat16-fold"; } + + // Run BF16 conversion folding on the given computation. Returns whether the + // computation was changed. 
+ StatusOr Run(HloModule* module) override; + + private: + const BFloat16Support* bfloat16_support_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_CONVERSION_FOLDING_H_ diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..cb37759439debf41a305ec7dccaa548e1bf234cd --- /dev/null +++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc @@ -0,0 +1,209 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/bfloat16_conversion_folding.h" +#include "tensorflow/compiler/xla/service/bfloat16_support.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +class TestBFloat16Support : public BFloat16Support { + public: + TestBFloat16Support() {} + ~TestBFloat16Support() override {} + + bool SupportsBF16Operand(const HloInstruction& hlo, + int64 operand_index) const override { + if (hlo.opcode() == HloOpcode::kAdd || + hlo.opcode() == HloOpcode::kSubtract || + hlo.opcode() == HloOpcode::kTuple || + hlo.opcode() == HloOpcode::kGetTupleElement) { + return true; + } + return false; + } + + bool SupportsBF16Output(const HloInstruction& hlo) const override { + if (hlo.opcode() == HloOpcode::kAdd || + hlo.opcode() == HloOpcode::kSubtract || + hlo.opcode() == HloOpcode::kTuple || + hlo.opcode() == HloOpcode::kGetTupleElement) { + return true; + } + return false; + } + + bool SupportsMixedPrecisions(const HloInstruction& hlo) const override { + if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kTuple || + hlo.opcode() == HloOpcode::kGetTupleElement) { + return true; + } + return false; + } +}; + +class BFloat16ConversionFoldingTest : public HloTestBase { + protected: + bool FoldConversions(HloModule* module) { + TestBFloat16Support bfloat16_support_; + BFloat16ConversionFolding fold(&bfloat16_support_); + StatusOr result = fold.Run(module); + EXPECT_IS_OK(result.status()); + return 
result.ValueOrDie(); + } +}; + +TEST_F(BFloat16ConversionFoldingTest, FoldIfSupported) { + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, f32_shape, "b")); + HloInstruction* c = builder.AddInstruction( + HloInstruction::CreateParameter(2, f32_shape, "c")); + + HloInstruction* add0 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, a, b)); + HloInstruction* convert0 = + builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, add0)); + HloInstruction* convert1 = builder.AddInstruction( + HloInstruction::CreateConvert(f32_shape, convert0)); + + HloInstruction* add1 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, convert1, c)); + builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, add1)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(FoldConversions(module.get())); + + EXPECT_EQ(computation->root_instruction(), add1); + EXPECT_EQ(add0->shape().element_type(), BF16); + EXPECT_EQ(add1->shape().element_type(), BF16); + EXPECT_EQ(add1->operand(0), add0); +} + +TEST_F(BFloat16ConversionFoldingTest, DoNotFoldIfUnsupported) { + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, f32_shape, "b")); + HloInstruction* c = builder.AddInstruction( + HloInstruction::CreateParameter(2, f32_shape, "c")); + + HloInstruction* mul0 = 
builder.AddInstruction( + HloInstruction::CreateBinary(f32_shape, HloOpcode::kMultiply, a, b)); + HloInstruction* convert0 = + builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, mul0)); + HloInstruction* convert1 = builder.AddInstruction( + HloInstruction::CreateConvert(f32_shape, convert0)); + + HloInstruction* mul1 = builder.AddInstruction(HloInstruction::CreateBinary( + f32_shape, HloOpcode::kMultiply, convert1, c)); + HloInstruction* convert2 = + builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, mul1)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_FALSE(FoldConversions(module.get())); + + EXPECT_EQ(computation->root_instruction(), convert2); + EXPECT_EQ(mul0->shape().element_type(), F32); + EXPECT_EQ(mul1->shape().element_type(), F32); + EXPECT_EQ(mul1->operand(0), convert1); +} + +TEST_F(BFloat16ConversionFoldingTest, DoNotFoldUnsupportedMixedPrecision) { + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, f32_shape, "b")); + HloInstruction* c = builder.AddInstruction( + HloInstruction::CreateParameter(2, f32_shape, "c")); + + HloInstruction* sub0 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_shape, HloOpcode::kSubtract, a, b)); + HloInstruction* convert0 = + builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, sub0)); + HloInstruction* convert1 = builder.AddInstruction( + HloInstruction::CreateConvert(f32_shape, convert0)); + + HloInstruction* sub1 = builder.AddInstruction(HloInstruction::CreateBinary( + f32_shape, HloOpcode::kSubtract, convert1, c)); + HloInstruction* convert2 = + 
builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, sub1)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_FALSE(FoldConversions(module.get())); + + EXPECT_EQ(computation->root_instruction(), convert2); + EXPECT_EQ(sub0->shape().element_type(), F32); + EXPECT_EQ(sub1->shape().element_type(), F32); + EXPECT_EQ(sub1->operand(0), convert1); +} + +TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) { + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, bf16_shape, "b")); + HloInstruction* convert0 = + builder.AddInstruction(HloInstruction::CreateConvert(f32_shape, b)); + + HloInstruction* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({a, convert0})); + HloInstruction* gte = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(f32_shape, tuple, 0)); + HloInstruction* convert1 = + builder.AddInstruction(HloInstruction::CreateConvert(bf16_shape, gte)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_FALSE(FoldConversions(module.get())); + + EXPECT_EQ(computation->root_instruction(), convert1); + EXPECT_EQ(gte->shape().element_type(), F32); + EXPECT_EQ(tuple->operand(1), convert0); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.cc b/tensorflow/compiler/xla/service/bfloat16_normalization.cc new file mode 100644 index 0000000000000000000000000000000000000000..b032c040e8aff49f9e0fc1ff9a1c1e79ea4bb77f --- /dev/null +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.cc @@ -0,0 +1,351 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/bfloat16_normalization.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault { + public: + explicit BFloat16NormalizationVisitor(HloComputation* computation, + const BFloat16Support* bfloat16_support) + : computation_(computation), bfloat16_support_(bfloat16_support) {} + + Status DefaultAction(HloInstruction* hlo) override; + + // Special handling for cross-replica-sum which can have a tuple output. + Status HandleCrossReplicaSum(HloInstruction* crs) override; + + static bool Run(HloComputation* computation, + const BFloat16Support* bfloat16_support) { + BFloat16NormalizationVisitor visitor(computation, bfloat16_support); + TF_CHECK_OK(computation->Accept(&visitor)); + return visitor.changed_; + } + + private: + // Checks if the HLO uses BF16 in an unsupported way, and if so, inserts + // conversions between F32 and BF16 to make it supported. 
+ Status HandleInstruction(HloInstruction* hlo); + + // Inserts a conversion HLO that changes the given HLO's output type. + Status InsertConvertAfterOutput(HloInstruction* hlo, PrimitiveType to, + HloComputation* computation); + + // Changes the output type to the specified type, then inserts a conversion + // to the original type. + Status ChangeOutputTypeThenInsertConvertBack(HloInstruction* hlo, + PrimitiveType to, + HloComputation* computation); + + // Inserts a conversion HLO that changes the given HLO's operand type. + Status InsertConvertBeforeOperand(HloInstruction* hlo, int64 operand_idx, + PrimitiveType to, + HloComputation* computation); + + // Inserts conversion HLOs to replace the called computations' BF16 + // operands/outputs to F32. + Status ConvertCalledComputations( + HloInstruction* hlo, + tensorflow::gtl::ArraySlice bf16_called_comps); + + HloComputation* computation_; + const BFloat16Support* bfloat16_support_; + bool changed_ = false; +}; + +Status BFloat16NormalizationVisitor::InsertConvertAfterOutput( + HloInstruction* hlo, PrimitiveType to, HloComputation* computation) { + bool is_root = computation->root_instruction() == hlo; + std::vector materialized_users = hlo->users(); + // Use inst's shape temporarily, in order to pass checks in ReplaceUseWith. 
+ auto convert = computation->AddInstruction( + HloInstruction::CreateConvert(hlo->shape(), hlo)); + for (auto* user : materialized_users) { + TF_RETURN_IF_ERROR(hlo->ReplaceUseWith(user, convert)); + } + if (is_root) { + computation->set_root_instruction(convert); + } + convert->mutable_shape()->set_element_type(to); + changed_ = true; + return Status::OK(); +} + +Status BFloat16NormalizationVisitor::ChangeOutputTypeThenInsertConvertBack( + HloInstruction* hlo, PrimitiveType to, HloComputation* computation) { + auto original_type = hlo->shape().element_type(); + hlo->mutable_shape()->set_element_type(to); + return InsertConvertAfterOutput(hlo, original_type, computation); +} + +Status BFloat16NormalizationVisitor::InsertConvertBeforeOperand( + HloInstruction* hlo, int64 operand_idx, PrimitiveType to, + HloComputation* computation) { + auto operand = hlo->mutable_operand(operand_idx); + auto convert = computation->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(operand->shape(), to), operand)); + TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(operand_idx, convert)); + changed_ = true; + return Status::OK(); +} + +Status BFloat16NormalizationVisitor::ConvertCalledComputations( + HloInstruction* hlo, + tensorflow::gtl::ArraySlice bf16_called_comps) { + std::map cloned_computations; + for (auto& comp : bf16_called_comps) { + auto cloned = comp->parent()->AddEmbeddedComputation(comp->Clone()); + cloned_computations[comp] = cloned; + changed_ = true; + } + hlo->ReplaceCalledComputations([&](HloComputation* comp) { + auto it = cloned_computations.find(comp); + if (it != cloned_computations.end()) { + return it->second; + } + return comp; + }); + for (auto& comp_pair : cloned_computations) { + auto comp = comp_pair.second; + if (comp->root_instruction()->shape().element_type() == BF16) { + TF_RETURN_IF_ERROR( + InsertConvertAfterOutput(comp->root_instruction(), F32, comp)); + } + for (auto* param : comp->parameter_instructions()) { + if 
(param->shape().element_type() == BF16) { + // This changes the parameter to F32 then inserts a convert after it. + TF_RETURN_IF_ERROR( + ChangeOutputTypeThenInsertConvertBack(param, F32, comp)); + } + } + } + return Status::OK(); +} + +Status BFloat16NormalizationVisitor::HandleCrossReplicaSum( + HloInstruction* crs) { + if (!ShapeUtil::IsTuple(crs->shape())) { + return HandleInstruction(crs); + } + + std::vector operand_types(crs->operand_count()); + std::vector output_types(crs->operand_count()); + bool has_f32 = false; + bool has_bf16 = false; + bool has_bf16_output = false; + for (int64 i = 0; i < crs->operand_count(); ++i) { + operand_types[i] = crs->operand(i)->shape().element_type(); + output_types[i] = ShapeUtil::GetSubshape(crs->shape(), {i}).element_type(); + if (operand_types[i] == F32 || output_types[i] == F32) { + has_f32 = true; + } else if (operand_types[i] == BF16) { + has_bf16 = true; + } + if (output_types[i] == BF16) { + has_bf16 = true; + has_bf16_output = true; + } + } + + for (int64 i = 0; i < crs->operand_count(); ++i) { + if (operand_types[i] != BF16) { + continue; + } + if (bfloat16_support_->SupportsBF16Operand(*crs, i) && + (bfloat16_support_->SupportsMixedPrecisions(*crs) || !has_f32)) { + continue; + } + TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(crs, i, F32, computation_)); + has_f32 = true; + } + + if (!has_bf16_output) { + return Status::OK(); + } + + if (bfloat16_support_->SupportsBF16Output(*crs) && + (bfloat16_support_->SupportsMixedPrecisions(*crs) || !has_f32)) { + return Status::OK(); + } + + std::vector output_elements(crs->operand_count()); + auto original_shape = crs->shape(); + for (int64 i = 0; i < crs->operand_count(); ++i) { + auto subshape = ShapeUtil::GetMutableSubshape(crs->mutable_shape(), {i}); + if (output_types[i] != BF16) { + output_elements[i] = computation_->AddInstruction( + HloInstruction::CreateGetTupleElement(*subshape, crs, i)); + continue; + } + subshape->set_element_type(F32); + auto gte = 
computation_->AddInstruction( + HloInstruction::CreateGetTupleElement(*subshape, crs, i)); + output_elements[i] = + computation_->AddInstruction(HloInstruction::CreateConvert( + ShapeUtil::ChangeElementType(*subshape, BF16), gte)); + } + auto tuple = computation_->AddInstruction( + HloInstruction::CreateTuple(output_elements)); + + std::vector materialized_users = crs->users(); + // Use the crs' shape temporarily, in order to pass checks in + // ReplaceUseWith. + *tuple->mutable_shape() = crs->shape(); + for (auto* user : materialized_users) { + TF_RETURN_IF_ERROR(crs->ReplaceUseWith(user, tuple)); + } + *tuple->mutable_shape() = original_shape; + return Status::OK(); +} + +Status BFloat16NormalizationVisitor::HandleInstruction(HloInstruction* hlo) { + std::vector bf16_operands; + std::vector f32_operands; + bool has_f32 = false; + bool has_bf16 = false; + + for (int64 i = 0; i < hlo->operand_count(); ++i) { + if (hlo->operand(i)->shape().element_type() == F32) { + f32_operands.push_back(i); + has_f32 = true; + } else if (hlo->operand(i)->shape().element_type() == BF16) { + bf16_operands.push_back(i); + has_bf16 = true; + } + } + + if (hlo->shape().element_type() == F32) { + has_f32 = true; + } else if (hlo->shape().element_type() == BF16) { + has_bf16 = true; + } + + std::vector bf16_called_comps; + for (auto* comp : hlo->called_computations()) { + bool comp_has_bf16 = false; + if (comp->root_instruction()->shape().element_type() == F32) { + has_f32 = true; + } else if (comp->root_instruction()->shape().element_type() == BF16) { + has_bf16 = true; + comp_has_bf16 = true; + } + for (auto* param : comp->parameter_instructions()) { + if (param->shape().element_type() == F32) { + has_f32 = true; + } else if (param->shape().element_type() == BF16) { + has_bf16 = true; + comp_has_bf16 = true; + } + } + if (comp_has_bf16) { + bf16_called_comps.push_back(comp); + } + } + + if (!bfloat16_support_->SupportsMixedPrecisions(*hlo) && has_bf16 && + has_f32) { + // Resolve 
unsupported mixed precision. + // + // See if we can change everything to BF16. + if (hlo->called_computations().empty() && + hlo->shape().element_type() == BF16) { + bool can_use_bf16 = true; + for (int i : f32_operands) { + if (bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(*hlo, + i) && + bfloat16_support_->SupportsBF16Operand(*hlo, i)) { + continue; + } + can_use_bf16 = false; + break; + } + if (can_use_bf16) { + for (int i : f32_operands) { + TF_RETURN_IF_ERROR( + InsertConvertBeforeOperand(hlo, i, BF16, computation_)); + } + return Status::OK(); + } + } + if (hlo->shape().element_type() == BF16) { + TF_RETURN_IF_ERROR( + ChangeOutputTypeThenInsertConvertBack(hlo, F32, computation_)); + } + for (int i : bf16_operands) { + TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_)); + } + return ConvertCalledComputations(hlo, bf16_called_comps); + } + + for (int i : bf16_operands) { + if (!bfloat16_support_->SupportsBF16Operand(*hlo, i)) { + TF_RETURN_IF_ERROR(InsertConvertBeforeOperand(hlo, i, F32, computation_)); + } + } + + if (hlo->shape().element_type() == BF16 && + !bfloat16_support_->SupportsBF16Output(*hlo)) { + TF_RETURN_IF_ERROR( + ChangeOutputTypeThenInsertConvertBack(hlo, F32, computation_)); + } + + return Status::OK(); +} + +Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) { + // Do not change instructions related to entry and exit of a computation, + // tuples, fusion, convert, and control flow. 
+ if (hlo->opcode() == HloOpcode::kTuple || // + hlo->opcode() == HloOpcode::kGetTupleElement || // + hlo->opcode() == HloOpcode::kInfeed || // + hlo->opcode() == HloOpcode::kOutfeed || // + hlo->opcode() == HloOpcode::kConstant || // + hlo->opcode() == HloOpcode::kParameter || // + hlo->opcode() == HloOpcode::kFusion || // + hlo->opcode() == HloOpcode::kConvert || // + hlo->opcode() == HloOpcode::kCall || // + hlo->opcode() == HloOpcode::kCustomCall || // + hlo->opcode() == HloOpcode::kWhile || // + hlo->opcode() == HloOpcode::kConditional) { + return Status::OK(); + } + return HandleInstruction(hlo); +} + +StatusOr BFloat16Normalization::Run(HloModule* module) { + XLA_VLOG_LINES( + 2, "BFloat16Normalization::Run(), before:\n" + module->ToString()); + bool changed = false; + for (auto* comp : module->MakeComputationPostOrder()) { + if (BFloat16NormalizationVisitor::Run(comp, bfloat16_support_)) { + changed = true; + } + } + XLA_VLOG_LINES(2, + "BFloat16Normalization::Run(), after:\n" + module->ToString()); + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.h b/tensorflow/compiler/xla/service/bfloat16_normalization.h new file mode 100644 index 0000000000000000000000000000000000000000..2a60fe0af3218484acb95e6c69815d551350764c --- /dev/null +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.h @@ -0,0 +1,92 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_NORMALIZATION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_NORMALIZATION_H_ + +#include "tensorflow/compiler/xla/service/bfloat16_support.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// A pass which adds F32 <-> BF16 conversions for HLO instructions that do not +// support BF16 input/output or mixed precision, according to the passed-in +// backend-specific BF16 support rules. +class BFloat16Normalization : public HloPassInterface { + public: + explicit BFloat16Normalization(const BFloat16Support* bfloat16_support) + : bfloat16_support_(bfloat16_support) {} + + ~BFloat16Normalization() override = default; + tensorflow::StringPiece name() const override { return "bf16-normalization"; } + + // Run BF16 normalization on the given computation. Returns whether the + // computation was changed. + StatusOr Run(HloModule* module) override; + + private: + const BFloat16Support* bfloat16_support_; +}; + +// A pass that unconditionally removes the mixed F32/BF16 uses in HLO +// instructions (excluding convert) by adding F32 <-> BF16 conversions. Unlike +// BFloat16Normalization, this pass does not use a backend-specific +// BFloat16Support, and does not change HLOs that have BF16 data if they do not +// use mixed precision; it removes mixed precision even if the backend supports +// it. This pass is used to make the HLO module valid for other HLO passes which +// do not support mixed precision. +class BFloat16MixedPrecisionRemoval : public HloPassInterface { + public: + BFloat16MixedPrecisionRemoval() {} + + ~BFloat16MixedPrecisionRemoval() override = default; + + tensorflow::StringPiece name() const override { + return "bf16-mixed-precision-removal"; + } + + // Run mixed precision removal on the given computation. 
Returns whether the + // computation was changed. + StatusOr Run(HloModule* module) override { + BFloat16Normalization normalization(&no_mixed_precision_support_); + return normalization.Run(module); + } + + private: + class BFloat16SupportForMixedPrecisionRemoval : public BFloat16Support { + public: + BFloat16SupportForMixedPrecisionRemoval() {} + + ~BFloat16SupportForMixedPrecisionRemoval() override = default; + + bool SupportsBF16Operand(const HloInstruction& hlo, + int64 operand_index) const override { + return true; + } + + bool SupportsBF16Output(const HloInstruction& hlo) const override { + return true; + } + + bool SupportsMixedPrecisions(const HloInstruction& hlo) const override { + return false; + } + } no_mixed_precision_support_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_NORMALIZATION_H_ diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..66c3085842c4afe7ffc4d5891883e4cce9389d45 --- /dev/null +++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc @@ -0,0 +1,248 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/bfloat16_normalization.h" +#include "tensorflow/compiler/xla/service/bfloat16_support.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { + +class TestBFloat16Support : public BFloat16Support { + public: + TestBFloat16Support() {} + ~TestBFloat16Support() override {} + + bool SupportsBF16Operand(const HloInstruction& hlo, + int64 operand_index) const override { + if (hlo.opcode() == HloOpcode::kAdd || + hlo.opcode() == HloOpcode::kSubtract || + hlo.opcode() == HloOpcode::kReduce || + hlo.opcode() == HloOpcode::kTuple || + hlo.opcode() == HloOpcode::kGetTupleElement) { + return true; + } + return false; + } + + bool SupportsBF16Output(const HloInstruction& hlo) const override { + if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kReduce || + hlo.opcode() == HloOpcode::kSubtract || + hlo.opcode() == HloOpcode::kTuple || + hlo.opcode() == HloOpcode::kGetTupleElement) { + return true; + } + return false; + } + + bool SupportsMixedPrecisions(const HloInstruction& hlo) const override { + if (hlo.opcode() == HloOpcode::kAdd || hlo.opcode() == HloOpcode::kTuple || + hlo.opcode() == HloOpcode::kGetTupleElement) { + return true; + } + return false; + } +}; + +class BFloat16NormalizationTest : public HloTestBase { + protected: + bool Normalize(HloModule* module) { + TestBFloat16Support bfloat16_support_; + BFloat16Normalization normalization(&bfloat16_support_); + StatusOr result = 
normalization.Run(module); + EXPECT_IS_OK(result.status()); + return result.ValueOrDie(); + } +}; + +TEST_F(BFloat16NormalizationTest, NoopIfSupported) { + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, bf16_shape, "b")); + HloInstruction* c = builder.AddInstruction( + HloInstruction::CreateParameter(2, f32_shape, "c")); + + HloInstruction* add0 = builder.AddInstruction( + HloInstruction::CreateBinary(bf16_shape, HloOpcode::kAdd, a, b)); + + HloInstruction* add1 = builder.AddInstruction( + HloInstruction::CreateBinary(f32_shape, HloOpcode::kAdd, add0, c)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_FALSE(Normalize(module.get())); + + EXPECT_EQ(computation->root_instruction(), add1); + EXPECT_EQ(add0->shape().element_type(), BF16); + EXPECT_EQ(add1->shape().element_type(), F32); +} + +TEST_F(BFloat16NormalizationTest, ResolveIfUnsupportedBF16) { + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, bf16_shape, "b")); + HloInstruction* c = builder.AddInstruction( + HloInstruction::CreateParameter(2, f32_shape, "c")); + + HloInstruction* mul0 = builder.AddInstruction( + HloInstruction::CreateBinary(bf16_shape, HloOpcode::kMultiply, a, b)); + + HloInstruction* mul1 = builder.AddInstruction( + HloInstruction::CreateBinary(bf16_shape, HloOpcode::kMultiply, mul0, c)); + + auto module = CreateNewModule(); + auto 
computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(Normalize(module.get())); + + EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); + EXPECT_EQ(computation->root_instruction()->operand(0), mul1); + EXPECT_EQ(mul0->shape().element_type(), F32); + EXPECT_EQ(mul1->shape().element_type(), F32); + EXPECT_EQ(mul1->operand(0)->opcode(), HloOpcode::kConvert); +} + +TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionSubtraction) { + auto builder = HloComputation::Builder(TestName()); + Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, bf16_shape, "b")); + HloInstruction* c = builder.AddInstruction( + HloInstruction::CreateParameter(2, f32_shape, "c")); + + HloInstruction* sub0 = builder.AddInstruction( + HloInstruction::CreateBinary(bf16_shape, HloOpcode::kSubtract, a, b)); + + HloInstruction* sub1 = builder.AddInstruction( + HloInstruction::CreateBinary(bf16_shape, HloOpcode::kSubtract, sub0, c)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(Normalize(module.get())); + + EXPECT_EQ(computation->root_instruction()->opcode(), HloOpcode::kConvert); + EXPECT_EQ(computation->root_instruction()->operand(0), sub1); + EXPECT_EQ(sub0->shape().element_type(), F32); + EXPECT_EQ(sub1->shape().element_type(), F32); + EXPECT_EQ(sub1->operand(0)->opcode(), HloOpcode::kConvert); +} + +TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) { + Shape f32_input_shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape f32_output_shape = ShapeUtil::MakeShape(F32, {4}); + + Shape bf16_scalar_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + + auto reduce_comp_builder = HloComputation::Builder("reduce_comp"); 
+ auto reduce_comp_param0 = reduce_comp_builder.AddInstruction( + HloInstruction::CreateParameter(0, bf16_scalar_shape, "param0")); + auto reduce_comp_param1 = reduce_comp_builder.AddInstruction( + HloInstruction::CreateParameter(1, bf16_scalar_shape, "param1")); + reduce_comp_builder.AddInstruction( + HloInstruction::CreateBinary(bf16_scalar_shape, HloOpcode::kAdd, + reduce_comp_param0, reduce_comp_param1)); + + auto module = CreateNewModule(); + auto reduce_computation = + module->AddEmbeddedComputation(reduce_comp_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + HloInstruction* input = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_input_shape, "a")); + HloInstruction* init = builder.AddInstruction( + HloInstruction::CreateParameter(1, bf16_scalar_shape, "init")); + HloInstruction* reduce = builder.AddInstruction(HloInstruction::CreateReduce( + f32_output_shape, input, init, {0}, reduce_computation)); + + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(Normalize(module.get())); + + EXPECT_EQ(computation->root_instruction(), reduce); + EXPECT_EQ(reduce->called_computations().size(), 1); + EXPECT_EQ(reduce->called_computations()[0]->num_parameters(), 2); + EXPECT_EQ(reduce->called_computations()[0] + ->parameter_instruction(0) + ->shape() + .element_type(), + F32); + EXPECT_EQ(reduce->called_computations()[0] + ->parameter_instruction(1) + ->shape() + .element_type(), + F32); + EXPECT_EQ(reduce->called_computations()[0] + ->root_instruction() + ->shape() + .element_type(), + F32); + EXPECT_EQ(reduce->shape().element_type(), F32); + EXPECT_EQ(reduce->operand(0), input); + EXPECT_EQ(input->shape().element_type(), F32); + EXPECT_EQ(reduce->operand(1)->opcode(), HloOpcode::kConvert); + EXPECT_EQ(reduce->operand(1)->shape().element_type(), F32); +} + +TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) { + auto builder = HloComputation::Builder(TestName()); + Shape 
f32_shape = ShapeUtil::MakeShape(F32, {2, 4}); + Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4}); + + HloInstruction* a = builder.AddInstruction( + HloInstruction::CreateParameter(0, f32_shape, "a")); + HloInstruction* b = builder.AddInstruction( + HloInstruction::CreateParameter(1, bf16_shape, "b")); + + HloInstruction* crs = + builder.AddInstruction(HloInstruction::CreateCrossReplicaSum( + ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b})); + HloInstruction* gte = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(Normalize(module.get())); + + EXPECT_EQ(computation->root_instruction(), gte); + EXPECT_EQ(gte->shape().element_type(), BF16); + EXPECT_EQ(crs->operand(1)->shape().element_type(), F32); + EXPECT_EQ(ShapeUtil::GetSubshape(crs->shape(), {1}).element_type(), F32); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_support.cc b/tensorflow/compiler/xla/service/bfloat16_support.cc new file mode 100644 index 0000000000000000000000000000000000000000..3fd9e24601f27633c8063e4574c7c4f91f30dcff --- /dev/null +++ b/tensorflow/compiler/xla/service/bfloat16_support.cc @@ -0,0 +1,111 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/bfloat16_support.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" + +namespace xla { + +bool BFloat16Support::SupportsBF16Operand(const HloInstruction& hlo, + int64 operand_index) const { + switch (hlo.opcode()) { + case HloOpcode::kCall: + case HloOpcode::kConditional: + case HloOpcode::kCustomCall: + case HloOpcode::kGetTupleElement: + case HloOpcode::kTuple: + case HloOpcode::kWhile: + return true; + case HloOpcode::kConvert: + CHECK_EQ(operand_index, 0); + return hlo.operand(0)->shape().element_type() == BF16; + default: + break; + } + return false; +} + +bool BFloat16Support::SupportsBF16Output(const HloInstruction& hlo) const { + switch (hlo.opcode()) { + case HloOpcode::kCall: + case HloOpcode::kConditional: + case HloOpcode::kCustomCall: + case HloOpcode::kGetTupleElement: + case HloOpcode::kTuple: + case HloOpcode::kWhile: + return true; + case HloOpcode::kConvert: + return hlo.shape().element_type() == BF16; + default: + break; + } + return false; +} + +bool BFloat16Support::SupportsMixedPrecisions(const HloInstruction& hlo) const { + switch (hlo.opcode()) { + case HloOpcode::kCall: + case HloOpcode::kConditional: + case HloOpcode::kConvert: + case HloOpcode::kCustomCall: + case HloOpcode::kGetTupleElement: + case HloOpcode::kTuple: + case HloOpcode::kWhile: + return true; + default: + break; + } + return false; +} + +/* static */ +bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision( + const HloInstruction& hlo, int64 operand_index) { + switch (hlo.opcode()) { + case HloOpcode::kAbs: + case HloOpcode::kBroadcast: + case HloOpcode::kClamp: + case HloOpcode::kConcatenate: + case HloOpcode::kCopy: + case HloOpcode::kGetTupleElement: + case HloOpcode::kMaximum: + case HloOpcode::kMinimum: + case HloOpcode::kPad: + case HloOpcode::kReshape: + case 
HloOpcode::kReverse: + case HloOpcode::kSlice: + case HloOpcode::kSort: + case HloOpcode::kTranspose: + case HloOpcode::kTuple: + return true; + case HloOpcode::kDynamicSlice: + return operand_index == 0; + case HloOpcode::kDynamicUpdateSlice: + return operand_index == 0 || operand_index == 1; + case HloOpcode::kSelect: + return operand_index == 1 || operand_index == 2; + default: + break; + } + return false; +} + +bool BFloat16Support::EffectiveOperandPrecisionIsBF16( + const HloInstruction& hlo, int64 operand_index) const { + return false; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/bfloat16_support.h b/tensorflow/compiler/xla/service/bfloat16_support.h new file mode 100644 index 0000000000000000000000000000000000000000..29f662d22b4e5486662a1387407d41e0fd2ed1b3 --- /dev/null +++ b/tensorflow/compiler/xla/service/bfloat16_support.h @@ -0,0 +1,60 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_SUPPORT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_SUPPORT_H_ + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" + +namespace xla { + +class BFloat16Support { + public: + BFloat16Support() {} + virtual ~BFloat16Support() {} + + // Returns whether the backend supports BF16 operand for the HLO instruction + // at the given index. + virtual bool SupportsBF16Operand(const HloInstruction& hlo, + int64 operand_index) const; + + // Returns whether the backend supports BF16 output for the HLO instruction. + virtual bool SupportsBF16Output(const HloInstruction& hlo) const; + + // Returns whether the backend support mixed precision: the operands, output, + // and parameters/output of the called computations can have different + // precisions (BF16 and F32). + virtual bool SupportsMixedPrecisions(const HloInstruction& hlo) const; + + // Returns whether the given HLO inherits its BF16 operand precision at the + // given index, so even if the output is F32, elements in the output that + // depend on the BF16 operand will still have BF16 effective precision even if + // they have F32 format. Similarly, this also means if the output is BF16 then + // increasing the operand precision from BF16 to F32 will not change the + // output. This typically includes HLOs that pass elements from the operand to + // the output without arithmetic operations. + static bool EffectiveOperandPrecisionIsOutputPrecision( + const HloInstruction& hlo, int64 operand_index); + + // Returns if the backend only uses BF16 precision for the operand at the + // specified index, even if the operand is F32. 
+ virtual bool EffectiveOperandPrecisionIsBF16(const HloInstruction& hlo, + int64 operand_index) const; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_SUPPORT_H_ diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 19a9ff04def5fc3d0b3739bbcf546a74114759a6..b1e693da9d5af4babe619b8796007f2da318f6a8 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -45,6 +45,8 @@ using ::tensorflow::gtl::FlatMap; using ::tensorflow::gtl::FlatSet; using ::tensorflow::strings::Appendf; using ::tensorflow::strings::HumanReadableNumBytes; +using ::tensorflow::strings::Printf; +using ::tensorflow::strings::StrAppend; size_t BufferAllocation::Slice::Hasher::operator()(Slice s) const { uint64 h = std::hash()(s.index()); @@ -73,9 +75,10 @@ void BufferAllocation::AddAssignment(const LogicalBuffer& buffer, int64 offset, CHECK_LE(offset, size_) << "LogicalBuffer " << buffer << " offset out of range"; CHECK_LE(offset + size, size_) - << "LogicalBuffer " << buffer << " size out of range"; + << "LogicalBuffer " << buffer + << " size out of range at offset: " << offset << " with size: " << size; CHECK_EQ(buffer.color(), color()) - << "Buffer color " << buffer.color() + << "Buffer color " << buffer.color() << " for buffer " << buffer << " does not match allocation color " << color() << "."; OffsetSize offset_size; offset_size.offset = offset; @@ -92,6 +95,9 @@ BufferAllocationProto BufferAllocation::ToProto() const { proto.set_color(color_.value()); if (is_entry_computation_parameter_) { proto.set_is_entry_computation_parameter(true); + for (int64 idx : param_shape_index()) { + proto.add_parameter_shape_index(idx); + } proto.set_parameter_number(parameter_number_); } proto.set_maybe_live_out(maybe_live_out_); @@ -111,25 +117,24 @@ BufferAllocationProto BufferAllocation::ToProto() const { string 
BufferAllocation::ToString() const { string output; - tensorflow::strings::StrAppend( - &output, tensorflow::strings::Printf("allocation %lld: %p, size %lld", - index_, this, size())); + Appendf(&output, "allocation %lld: %p, size %lld", index_, this, size()); if (color().value() != 0) { - tensorflow::strings::StrAppend(&output, ", color ", color().value()); + StrAppend(&output, ", color ", color().value()); } if (is_entry_computation_parameter()) { - tensorflow::strings::StrAppend(&output, ", parameter ", parameter_number()); + StrAppend(&output, ", parameter ", parameter_number(), " at ShapeIndex ", + param_shape_index().ToString()); } if (is_thread_local()) { - tensorflow::strings::StrAppend(&output, ", thread-local"); + StrAppend(&output, ", thread-local"); } if (maybe_live_out()) { - tensorflow::strings::StrAppend(&output, ", maybe-live-out"); + StrAppend(&output, ", maybe-live-out"); } if (IsPreallocatedTempBuffer()) { - tensorflow::strings::StrAppend(&output, ", preallocated-temp"); + StrAppend(&output, ", preallocated-temp"); } - tensorflow::strings::StrAppend(&output, ":\n"); + StrAppend(&output, ":\n"); // Dump the assigned buffers ordered by id. 
std::vector sorted_buffers; for (const auto& buffer_offset_size : assigned_buffers_) { @@ -141,12 +146,11 @@ string BufferAllocation::ToString() const { }); for (const LogicalBuffer* buffer : sorted_buffers) { const OffsetSize& offset_size = FindOrDie(assigned_buffers_, buffer); - tensorflow::strings::StrAppend( - &output, - tensorflow::strings::Printf( - " %s [%lld,%lld]: %s\n", buffer->ToString().c_str(), - offset_size.offset, offset_size.size, - ShapeUtil::HumanStringWithLayout(buffer->shape()).c_str())); + StrAppend(&output, + tensorflow::strings::Printf( + " %s [%lld,%lld]: %s\n", buffer->ToString().c_str(), + offset_size.offset, offset_size.size, + ShapeUtil::HumanStringWithLayout(buffer->shape()).c_str())); } return output; } @@ -581,6 +585,7 @@ Status GatherComputationsByAllocationType( instruction->called_computations()) { switch (instruction->opcode()) { case HloOpcode::kCall: + case HloOpcode::kConditional: case HloOpcode::kWhile: // Call and while must be called from a computation with global // allocations as they may return references to buffers inside the @@ -838,20 +843,19 @@ Status BufferAssigner::AssignBuffersForComputation( /*is_thread_local=*/false, /*is_reusable=*/false); allocation->set_entry_computation_parameter( - instruction->parameter_number()); + instruction->parameter_number(), buffer->index()); VLOG(3) << "New allocation #" << allocation->index() << " for entry computation parameter: " << *buffer; continue; } - if (is_thread_local || instruction->opcode() == HloOpcode::kCustomCall) { - // Custom call operations never have reusable buffers. Also we do not - // reuse thread-local buffers for now, because they are dynamically - // allocated and their lifetimes are hard to compute. + if (is_thread_local) { + // We do not reuse thread-local buffers for now, because they are + // dynamically allocated and their lifetimes are hard to compute. 
BufferAllocation* allocation = assignment->NewAllocation( *buffer, buffer_size, is_thread_local, /*is_reusable=*/false); VLOG(3) << "New allocation #" << allocation->index() - << " for thread-local/CustomCall: " << *buffer; + << " for thread-local: " << *buffer; continue; } @@ -976,8 +980,8 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( const HloOrdering& hlo_ordering = assignment->liveness().hlo_ordering(); if (run_whole_module_heap_simulation) { // Run the heap simulation over the whole module. This reduces memory usage, - // since buffers for kCall and kWhile sub-computations are only live for the - // duration of their calling instructions. + // since buffers for kCall, kWhile, and kConditional sub-computations are + // only live for the duration of their calling instructions. VLOG(1) << "Running whole-module heap simulation"; SequentialHloOrdering::HloModuleSequence module_sequence; FlatSet all_buffers_to_assign; @@ -996,14 +1000,15 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( auto color = single_colored_set.first; VLOG(2) << "Simulating heap for color " << color; int64 alignment = assignment->color_alignment_(color); + HeapSimulator::Options options; + options.buffers_to_assign = &single_colored_set.second; TF_ASSIGN_OR_RETURN( const HeapSimulator::Result result, HeapSimulator::Run(MakeUnique( MakeUnique(alignment)), assignment->module(), module_sequence, assignment->points_to_analysis(), - assignment->buffer_size_, - &single_colored_set.second)); + assignment->buffer_size_, options)); AssignBuffersFromHeapSimulator(result, assignment, single_colored_set.first); } @@ -1023,14 +1028,15 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering( auto color = single_colored_set.first; VLOG(2) << "Simulating heap for color " << color; int64 alignment = assignment->color_alignment_(color); + HeapSimulator::Options options; + options.buffers_to_assign = &single_colored_set.second; TF_ASSIGN_OR_RETURN( const HeapSimulator::Result 
result, HeapSimulator::Run(MakeUnique( MakeUnique(alignment)), *computation, *instruction_sequence, assignment->points_to_analysis(), - assignment->buffer_size_, - &single_colored_set.second)); + assignment->buffer_size_, options)); AssignBuffersFromHeapSimulator(result, assignment, single_colored_set.first); } @@ -1119,140 +1125,6 @@ void BufferAssigner::AddSetToColocatedBufferSets( } } -// Conceptually the same as AddSetToColocatedBufferSets, but specific to the -// colocated buffers for while instructions. 'colocated_set' contains the -// buffers for a single while instruction that must be colocated. The idea here -// is to apply a memory-saving heuristic for separate while instructions whose -// buffers are disjoint in liveness, by using the colocation mechanism to force -// buffer sharing. This often reduces memory for multi-layer RNNs. -// -// TODO(b/32491382): We should be able to remove this heuristic after we -// implement module-level liveness analysis, which would let us directly detect -// buffer sharing opportunities between the while instruction buffer and the -// buffers from the predicate and body computation, as well as sharing across -// different while instructions. -void BufferAssigner::AddWhileSetToColocatedBufferSets( - const std::vector& colocated_set, - const LogicalBuffer* while_init_buffer, - const LogicalBuffer* while_result_buffer, const HloInstruction* while_hlo, - const HloComputation& computation, const BufferLiveness& buffer_liveness, - const LogicalBuffer::SizeFunction& buffer_size, - std::vector* colocated_buffer_sets) { - CHECK(!colocated_set.empty()); - const TuplePointsToAnalysis& points_to_analysis = - buffer_liveness.points_to_analysis(); - - // Parallel while loops cannot safely share colocated buffer sets. 
- if (buffer_liveness.hlo_ordering().SequentialOrder(computation) == nullptr) { - AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); - return; - } - - // Scan 'colocated_buffer_sets' in reverse order for locality; colocated sets - // are added in postorder over computations and instructions. - const int64 init_buffer_size = buffer_size(*while_init_buffer); - const bool is_live_out = buffer_liveness.MaybeLiveOut(*while_result_buffer); - for (int i = colocated_buffer_sets->size() - 1; i >= 0; --i) { - const ColocatedBufferSet& predecessor_set = (*colocated_buffer_sets)[i]; - - // Skip predecessor sets not associated with while loops. - if (std::all_of(predecessor_set.begin(), predecessor_set.end(), - [](const LogicalBuffer* buffer) { - return buffer->instruction()->opcode() != - HloOpcode::kWhile; - })) { - continue; - } - - // Skip predecessor sets already associated with 'while_hlo'. - if (std::any_of(predecessor_set.begin(), predecessor_set.end(), - [&while_hlo](const LogicalBuffer* buffer) { - return buffer->instruction() == while_hlo; - })) { - continue; - } - - // Skip predecessor sets with entry parameter if the while result is live - // out. - if (is_live_out && - std::any_of(predecessor_set.begin(), predecessor_set.end(), - [](const LogicalBuffer* buffer) { - auto* instruction = buffer->instruction(); - auto* computation = instruction->parent(); - auto* module = computation->parent(); - return instruction->opcode() == HloOpcode::kParameter && - computation == module->entry_computation(); - })) { - continue; - } - - // Build vector of predecessor while result and init buffers, which are - // checked for liveness interference below. We must check both the result - // and init buffers because they're aliased together, but - // TuplePointsToAnalysis is unaware of this aliasing. 
- std::vector predecessor_while_buffers; - for (const LogicalBuffer* buffer : predecessor_set) { - const HloInstruction* instruction = buffer->instruction(); - if (instruction->opcode() == HloOpcode::kWhile && - buffer_size(*buffer) == init_buffer_size && - instruction->parent() == &computation) { - predecessor_while_buffers.push_back(buffer); - // Add the init buffer at the same index, which must also exist in the - // predecessor set, and must be unambiguous. - const PointsToSet& init_points_to = - points_to_analysis.GetPointsToSet(instruction->operand(0)); - const auto& init_buffers = init_points_to.element(buffer->index()); - CHECK_EQ(init_buffers.size(), 1); - CHECK_GT(predecessor_set.count(init_buffers[0]), 0); - predecessor_while_buffers.push_back(init_buffers[0]); - } - } - if (predecessor_while_buffers.empty()) { - continue; - } - - // Skip predecessor set if the live range of any predecessor - // buffers overlaps with 'while_init_buffer' or - // 'while_result_buffer' (we need to check both since they're - // aliased together, but the points-to analysis is unaware of this - // aliasing). Note that tuple element buffer forwarding can cause - // the same buffer to appear on both sides of the interference - // comparison below. - auto may_interfere_with_init_or_result = [&](const LogicalBuffer* buffer) { - if (while_init_buffer->id() != buffer->id() && - buffer_liveness.MayInterfere(*while_init_buffer, *buffer)) { - return true; - } - - if (while_result_buffer->id() != buffer->id() && - buffer_liveness.MayInterfere(*while_result_buffer, *buffer)) { - return true; - } - - return false; - }; - - if (std::any_of(predecessor_while_buffers.begin(), - predecessor_while_buffers.end(), - may_interfere_with_init_or_result)) { - continue; - } - - // All our checks have passed; merge 'predecessor_set' with 'colocated_set', - // and add the merged set to 'colocated_buffer_sets'. This forces the - // colocation of buffers across different while instructions. 
-    FlatSet<const LogicalBuffer*> unique;
-    unique.insert(predecessor_set.begin(), predecessor_set.end());
-    unique.insert(colocated_set.begin(), colocated_set.end());
-    std::vector<const LogicalBuffer*> merged_set(unique.begin(), unique.end());
-    AddSetToColocatedBufferSets(merged_set, colocated_buffer_sets);
-    return;
-  }
-
-  // Failed to merge into predecessor set; add 'colocated_set' as-is.
-  AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets);
-}
-
 namespace {
 
 // Checks that points-to set of 'instruction' is unambiguous and distinct
@@ -1269,10 +1141,133 @@ const LogicalBuffer* AddBufferToColocatedSet(
   return colocated_set->back();
 }
 
+// Given the interference map of a graph (the list of interfering node indices
+// for each node), perform graph coloring such that interfering nodes are
+// assigned to different colors. Returns the assigned color of the nodes, where
+// the colors are represented as integer values [0, color_count).
+std::vector<int64> ColorInterferenceGraph(
+    const std::vector<std::vector<int64>>& interference_map) {
+  const int64 node_count = interference_map.size();
+
+  // Sort the nodes such that we assign nodes with more interference first. This
+  // relies on the common heuristic of assigning the most constrained node
+  // first, but it would be good to investigate other ordering heuristics too.
+  std::vector<int64> nodes(node_count);
+  std::iota(nodes.begin(), nodes.end(), 0);
+  std::sort(nodes.begin(), nodes.end(),
+            [&interference_map](const int64 i, const int64 j) {
+              return interference_map[i].size() > interference_map[j].size();
+            });
+
+  const int64 kColorUnassigned = -1;
+  std::vector<int64> assigned_colors(node_count, kColorUnassigned);
+  for (int64 node : nodes) {
+    // Mark the colors that are already assigned to the neighbors.
+    std::vector<bool> available_colors(node_count, true);
+    for (int64 neighbor : interference_map[node]) {
+      int64 color = assigned_colors[neighbor];
+      if (color != kColorUnassigned) {
+        available_colors[color] = false;
+      }
+    }
+
+    // Find the lowest color that is not yet assigned to the neighbors.
+    int64 color = kColorUnassigned;
+    for (color = 0; color < available_colors.size(); ++color) {
+      if (available_colors[color]) {
+        break;
+      }
+    }
+    CHECK_LT(color, node_count);
+    assigned_colors[node] = color;
+  }
+  return assigned_colors;
+}
+
 }  // namespace
 
+std::vector<BufferAssigner::ColocatedBufferSet>
+BufferAssigner::MergeColocatedBufferSets(
+    const std::vector<ColocatedBufferSet>& colocated_buffer_sets,
+    const BufferLiveness& buffer_liveness,
+    const LogicalBuffer::SizeFunction& buffer_size) {
+  VLOG(1) << "colocation sets count before coalescing:"
+          << colocated_buffer_sets.size();
+
+  // Returns true if the given buffer is for the entry parameter.
+  auto is_entry_parameter = [](const LogicalBuffer& buffer) {
+    auto* instruction = buffer.instruction();
+    auto* computation = instruction->parent();
+    auto* module = computation->parent();
+    return instruction->opcode() == HloOpcode::kParameter &&
+           computation == module->entry_computation();
+  };
+
+  // Returns true if the two colocated buffer sets (specified by their indices
+  // into the colocated_buffer_sets) cannot be merged into a single set.
+  auto cannot_merge_buffer_sets = [&colocated_buffer_sets, &buffer_liveness,
+                                   &buffer_size,
+                                   &is_entry_parameter](int64 i, int64 j) {
+    for (auto& buffer_a : colocated_buffer_sets[i]) {
+      for (auto& buffer_b : colocated_buffer_sets[j]) {
+        // Do not merge a live-out buffer with an entry parameter.
+        if ((buffer_liveness.MaybeLiveOut(*buffer_a) &&
+             is_entry_parameter(*buffer_b)) ||
+            (buffer_liveness.MaybeLiveOut(*buffer_b) &&
+             is_entry_parameter(*buffer_a))) {
+          return true;
+        }
+        // Do not merge if the buffers interfere with each other.
+        if (buffer_a->id() != buffer_b->id() &&
+            buffer_liveness.MayInterfere(*buffer_a, *buffer_b)) {
+          return true;
+        }
+        // Do not merge if the buffer sizes are different.
+        if (buffer_size(*buffer_a) != buffer_size(*buffer_b)) {
+          return true;
+        }
+      }
+    }
+    return false;
+  };
+
+  // Build the interference map among the colocated buffer sets (nodes), by
+  // adding an edge between any two nodes that cannot be merged into a single
+  // colocated buffer set.
+  std::vector<std::vector<int64>> interference_map(
+      colocated_buffer_sets.size());
+  for (int64 i = 0; i < colocated_buffer_sets.size(); ++i) {
+    for (int64 j = i + 1; j < colocated_buffer_sets.size(); ++j) {
+      if (cannot_merge_buffer_sets(i, j)) {
+        interference_map[i].push_back(j);
+        interference_map[j].push_back(i);
+      }
+    }
+  }
+
+  // Assign a color to each colocation set in colocated_buffer_sets, such that
+  // the sets that can be merged are assigned with the same color.
+  auto assigned_colors = ColorInterferenceGraph(interference_map);
+
+  // Merge the buffer sets with the same color.
+  CHECK(!assigned_colors.empty());
+  int64 num_sets =
+      *std::max_element(assigned_colors.begin(), assigned_colors.end()) + 1;
+  std::vector<ColocatedBufferSet> new_colocated_buffer_sets(num_sets);
+  for (int64 i = 0; i < colocated_buffer_sets.size(); ++i) {
+    const auto& buffer_set = colocated_buffer_sets[i];
+    new_colocated_buffer_sets[assigned_colors[i]].insert(buffer_set.begin(),
+                                                         buffer_set.end());
+  }
+
+  VLOG(1) << "colocation sets count after coalescing:"
+          << new_colocated_buffer_sets.size();
+  return new_colocated_buffer_sets;
+}
+
 // Builds sets of buffers in 'colocated_buffer_sets' which should be colocated
-// in the same allocation (currently just supports kWhile and kCall).
+// in the same allocation (currently just supports kWhile, kCall, and
+// kConditional).
void BufferAssigner::BuildColocatedBufferSets( const HloModule* module, const BufferLiveness& buffer_liveness, const LogicalBuffer::SizeFunction& buffer_size, @@ -1295,12 +1290,11 @@ void BufferAssigner::BuildColocatedBufferSets( const Shape& /*subshape*/, const ShapeIndex& index) { std::vector colocated_set; // Add while.init. - auto* init_buffer = - AddBufferToColocatedSet(while_hlo->operand(0), index, - points_to_analysis, &colocated_set); + AddBufferToColocatedSet(while_hlo->operand(0), index, + points_to_analysis, &colocated_set); // Add while.result. - auto* result_buffer = AddBufferToColocatedSet( - while_hlo, index, points_to_analysis, &colocated_set); + AddBufferToColocatedSet(while_hlo, index, points_to_analysis, + &colocated_set); // Add while.cond.parameter. AddBufferToColocatedSet( while_hlo->while_condition()->parameter_instruction(0), index, @@ -1313,10 +1307,7 @@ void BufferAssigner::BuildColocatedBufferSets( AddBufferToColocatedSet( while_hlo->while_body()->root_instruction(), index, points_to_analysis, &colocated_set); - AddWhileSetToColocatedBufferSets( - colocated_set, init_buffer, result_buffer, while_hlo, - *computation, buffer_liveness, buffer_size, - colocated_buffer_sets); + AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); }); } else if (opcode == HloOpcode::kCall) { const HloInstruction* call_hlo = instruction; @@ -1336,9 +1327,82 @@ void BufferAssigner::BuildColocatedBufferSets( &colocated_set); AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); }); + } else if (opcode == HloOpcode::kConditional) { + const HloInstruction* conditional_hlo = instruction; + ShapeUtil::ForEachSubshape( + conditional_hlo->shape(), + [this, conditional_hlo, &points_to_analysis, colocated_buffer_sets]( + const Shape& /*subshape*/, const ShapeIndex& index) { + std::vector colocated_set; + // Add conditional.result. 
+ AddBufferToColocatedSet(conditional_hlo, index, + points_to_analysis, &colocated_set); + // Add conditional.true_computation.root. + AddBufferToColocatedSet( + conditional_hlo->true_computation()->root_instruction(), + index, points_to_analysis, &colocated_set); + // Add conditional.false_computation.root. + AddBufferToColocatedSet( + conditional_hlo->false_computation()->root_instruction(), + index, points_to_analysis, &colocated_set); + AddSetToColocatedBufferSets(colocated_set, colocated_buffer_sets); + }); + + // Add true_operand and conditional.true_computation.parameter(0) as a + // colocated buffer set. Note that this has to be done for each subshape + // in the true_operand of the conditional. + ShapeUtil::ForEachSubshape( + conditional_hlo->operand(1)->shape(), + [this, conditional_hlo, &points_to_analysis, colocated_buffer_sets]( + const Shape& /*subshape*/, const ShapeIndex& index) { + std::vector true_set; + // Add conditional.true_operand. + AddBufferToColocatedSet(conditional_hlo->operand(1), index, + points_to_analysis, &true_set); + // Add conditional.true_computation.parameter_instruction(0). + AddBufferToColocatedSet( + conditional_hlo->true_computation()->parameter_instruction(0), + index, points_to_analysis, &true_set); + AddSetToColocatedBufferSets(true_set, colocated_buffer_sets); + }); + + // Add false_operand and conditional.false_computation.parameter(0) as a + // colocated buffer set. Note that this has to be done for each subshape + // in the false_operand of the conditional. + ShapeUtil::ForEachSubshape( + conditional_hlo->operand(2)->shape(), + [this, conditional_hlo, &points_to_analysis, colocated_buffer_sets]( + const Shape& /*subshape*/, const ShapeIndex& index) { + std::vector false_set; + // Add conditional.false_operand. + AddBufferToColocatedSet(conditional_hlo->operand(2), index, + points_to_analysis, &false_set); + // Add conditional.false_computation.parameter_instruction(0). 
+ AddBufferToColocatedSet( + conditional_hlo->false_computation()->parameter_instruction( + 0), + index, points_to_analysis, &false_set); + AddSetToColocatedBufferSets(false_set, colocated_buffer_sets); + }); } } } + + if (colocated_buffer_sets->empty()) { + return; + } + + // Try to find more coalescing opportunities among the colocated buffer sets. + // + // TODO(b/32491382): We should be able to remove this by using the + // module-level liveness analysis, which would let us directly detect buffer + // sharing opportunities between the while instruction buffer and the buffers + // from the predicate and body computation, as well as sharing across + // different while instructions. + std::vector new_colocated_buffer_sets = + MergeColocatedBufferSets(*colocated_buffer_sets, buffer_liveness, + buffer_size); + std::swap(*colocated_buffer_sets, new_colocated_buffer_sets); } // Assigns all colocated buffer sets in 'colocated_buffer_sets' to the same @@ -1350,39 +1414,47 @@ void BufferAssigner::AssignColocatedBufferSets( FlatSet* colocated_allocations) { for (const ColocatedBufferSet& colocated_buffer_set : colocated_buffer_sets) { BufferAllocation* allocation = nullptr; - // Set 'entry_parameter_number' if entry param in 'colocated_buffer_set'. + // Set 'entry_parameter_number' and 'entry_parameter_shape_idx' if entry + // param in 'colocated_buffer_set'. 
int64 entry_parameter_number = -1; + const ShapeIndex* entry_parameter_shape_idx = nullptr; for (const LogicalBuffer* buffer : colocated_buffer_set) { const HloInstruction* instruction = buffer->instruction(); const HloComputation* computation = instruction->parent(); if (instruction->opcode() == HloOpcode::kParameter && computation == computation->parent()->entry_computation()) { entry_parameter_number = instruction->parameter_number(); + entry_parameter_shape_idx = &buffer->index(); break; } } for (const LogicalBuffer* buffer : colocated_buffer_set) { + const int64 buffer_size = assignment->buffer_size_(*buffer); if (allocation == nullptr) { // TODO(b/32491382) Avoid current trivial solution of using new // allocations for each colocated buffer set. When liveness has // module-level scope, we can allow buffers to be shared across // computations (in some cases). - allocation = assignment->NewAllocation( - *buffer, assignment->buffer_size_(*buffer), - /*is_thread_local=*/false, /*is_reusable=*/true); + allocation = assignment->NewAllocation(*buffer, buffer_size, + /*is_thread_local=*/false, + /*is_reusable=*/true); if (entry_parameter_number >= 0) { // This colocated buffer set contains an entry parameter and other // logical buffers which use the parameter as read-only in a while // body computation (which updates in place). // Set 'entry_computation_parameter' to indicate that it contains // an entry parameter, and to prevent reuse in MaybeAssignBuffer. 
- allocation->set_entry_computation_parameter(entry_parameter_number); + allocation->set_entry_computation_parameter( + entry_parameter_number, *entry_parameter_shape_idx); } colocated_allocations->insert(allocation->index()); } else { + CHECK_EQ(buffer_size, allocation->size()) + << "Buffer: " << *buffer << " size mismatch in colocated buffer " + << "allocation: " << *allocation; assignment->AddAssignment(allocation, *buffer, /*offset=*/0, - assignment->buffer_size_(*buffer)); + buffer_size); } colocated_buffers->insert(buffer); } diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 08a40bfeb2a2a78c25805308e73154c6cc667f21..6b7fd0014d103ef0617afcc5cb3f663554a01aa4 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -91,6 +91,13 @@ class BufferAllocation { return parameter_number_; } + // If this allocation is for a parameter of the entry computation, this + // function returns which subshape of the parameter the allocation is for. + const ShapeIndex& param_shape_index() const { + CHECK(is_entry_computation_parameter_); + return param_shape_index_; + } + // Returns whether this allocation is assigned a LogicalBuffer which may // be live out of the entry computation. bool maybe_live_out() const { return maybe_live_out_; } @@ -203,9 +210,11 @@ class BufferAllocation { // Adds a LogicalBuffer to the set assigned to this buffer. 
void AddAssignment(const LogicalBuffer& buffer, int64 offset, int64 size); - void set_entry_computation_parameter(int64 parameter_number) { + void set_entry_computation_parameter(int64 parameter_number, + ShapeIndex param_shape_index) { is_entry_computation_parameter_ = true; parameter_number_ = parameter_number; + param_shape_index_ = std::move(param_shape_index); } void set_maybe_live_out(bool value) { maybe_live_out_ = value; } void set_index(Index index) { index_ = index; } @@ -235,6 +244,10 @@ class BufferAllocation { // indicates the index (starting from 0) of the parameter. int64 parameter_number_ = 0; + // If this buffer is for an entry computation parameter, which subshape of the + // parameter is it for? + ShapeIndex param_shape_index_; + // Whether the allocation contains a LogicalBuffer which may be live-out of // the entry computation. Note that this flag is conservatively computed by // TuplePointsToAnalysis. That is, an allocation marked `maybe_live_out_` @@ -528,15 +541,13 @@ class BufferAssigner { const std::vector& colocated_set, std::vector* colocated_buffer_sets); - // Conceptually the same as AddSetToColocatedBufferSets, but specific to the - // colocated buffers for while instructions. - void AddWhileSetToColocatedBufferSets( - const std::vector& colocated_set, - const LogicalBuffer* while_init_buffer, - const LogicalBuffer* while_result_buffer, const HloInstruction* while_hlo, - const HloComputation& computation, const BufferLiveness& buffer_liveness, - const LogicalBuffer::SizeFunction& buffer_size, - std::vector* colocated_buffer_sets); + // Given a list of colocated buffer sets (each colocated buffer set represents + // the logical buffers that would be assigned to the same physical buffer), + // try to merge the sets if the buffers can be shared. Returns the merged set. 
+ std::vector MergeColocatedBufferSets( + const std::vector& colocated_buffer_sets, + const BufferLiveness& buffer_liveness, + const LogicalBuffer::SizeFunction& buffer_size); // Split a set of buffers into several sets, each of which contains buffers // colored with the same color. diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index 8fba8ef5e5c799eaac429017f4a0ff6a0315ba7c..cd73654b8f666c4b96c000235cc3ad2cd0a46c17 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -166,6 +166,15 @@ class BufferAssignmentTest : public HloTestBase { return builder.Build(); } + std::unique_ptr BuildR0F32UnaryOpComputation( + HloOpcode opcode, const string& name) { + auto builder = HloComputation::Builder(name); + auto param = + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x")); + builder.AddInstruction(HloInstruction::CreateUnary(r0f32_, opcode, param)); + return builder.Build(); + } + // Verifies that the given instruction hlo has a valid input buffer assigned, // i.e., the parameter number matches the op's. const BufferAllocation& GetAssignedInputAllocation( @@ -605,7 +614,7 @@ TEST_F(BufferAssignmentTest, TrivialMap) { BufferAllocation map_buffer = GetAssignedOutputAllocation(*buffers, map); EXPECT_NE(param0_buffer.index(), map_buffer.index()); - // The final computation node of the map is an add of an f32 parm and a + // The final computation node of the map is an add of an f32 param and a // constant. 
EXPECT_EQ(HloOpcode::kAdd, inner_last->opcode()); const BufferAllocation& inner_add_buffer = @@ -740,6 +749,56 @@ TEST_F(BufferAssignmentTest, ExampleWhile) { << " instructions; total buffer size " << size0 + sizec + sizeb; } +TEST_F(BufferAssignmentTest, ExampleConditional) { + auto module = CreateNewModule(); + auto true_computation = module->AddEmbeddedComputation( + BuildR0F32UnaryOpComputation(HloOpcode::kCeil, "Ceil")); + auto false_computation = module->AddEmbeddedComputation( + BuildR0F32UnaryOpComputation(HloOpcode::kFloor, "Floor")); + + auto builder = HloComputation::Builder(TestName()); + auto pred = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + auto const1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(56.4f))); + auto const2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(12.4f))); + auto conditional = builder.AddInstruction(HloInstruction::CreateConditional( + r0f32_, pred, const1, true_computation, const2, false_computation)); + module->AddEntryComputation(builder.Build()); + + const std::vector conditional_instrs = + GetInstructions(conditional); + const std::vector true_instrs = + GetInstructions(true_computation->root_instruction()); + const std::vector false_instrs = + GetInstructions(false_computation->root_instruction()); + EXPECT_EQ(4, conditional_instrs.size()); + EXPECT_EQ(2, true_instrs.size()); + EXPECT_EQ(2, false_instrs.size()); + + auto buffers = RunBufferAssignment(module.get()); + ValidateBuffers(conditional_instrs, *buffers); + ValidateBuffers(true_instrs, *buffers); + ValidateBuffers(false_instrs, *buffers); + + EXPECT_FALSE(BuffersDistinct(conditional_instrs, true_instrs, *buffers)) + << "Should be reuse between conditional and true computation."; + EXPECT_FALSE(BuffersDistinct(conditional_instrs, false_instrs, *buffers)) + << "Should be reuse between conditional and false computation."; + EXPECT_FALSE(BuffersDistinct(true_instrs, 
false_instrs, *buffers)) + << "Should be reuse between true and false computations."; + + const BufferAllocation& conditional_buffer = + GetTopLevelAllocation(*buffers, conditional); + const BufferAllocation& true_buffer = + GetTopLevelAllocation(*buffers, true_computation->root_instruction()); + const BufferAllocation& false_buffer = + GetTopLevelAllocation(*buffers, false_computation->root_instruction()); + EXPECT_EQ(conditional_buffer.size(), true_buffer.size()); + EXPECT_EQ(conditional_buffer.size(), false_buffer.size()); +} + TEST_F(BufferAssignmentTest, UnaryOpReuseChain) { // param0[100] ---> (exp) ---> (tanh) ---> (exp) ---> (neg) auto builder = HloComputation::Builder(TestName()); @@ -1360,10 +1419,13 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) { HloInstruction::CreateParameter(1, shape_3x4, "param_b")); auto param_c = builder.AddInstruction( HloInstruction::CreateParameter(2, shape_4x4, "param_c")); - auto dot_ab = builder.AddInstruction(HloInstruction::CreateBinary( - shape_2x4, HloOpcode::kDot, param_a, param_b)); - auto dot_bc = builder.AddInstruction(HloInstruction::CreateBinary( - shape_3x4, HloOpcode::kDot, param_b, param_c)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto dot_ab = builder.AddInstruction( + HloInstruction::CreateDot(shape_2x4, param_a, param_b, dot_dnums)); + auto dot_bc = builder.AddInstruction( + HloInstruction::CreateDot(shape_3x4, param_b, param_c, dot_dnums)); builder.AddInstruction( HloInstruction::CreateConcatenate(shape_5x4, {dot_ab, dot_bc}, 1)); @@ -1525,6 +1587,117 @@ TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) { assignment->GetUniqueSlice(while1, {1}).ConsumeValueOrDie()); } +// Tests that the colocated buffers for while instructions are properly assigned +// during buffer assignment such that the result tuple elements are not assigned +// to the same buffer. 
+// +// %infeed --> %while.0 --> %while.1 --+ +// +-- %tuple +// %zero --> %add --> %while.2 --+ +// +// Execution Order: +// %infeed -> %while.0 -> %while.1 -> %zero -> %add -> %while.2 -> %tuple +// +// The HLO computation used in this test requires specific ordering to expose +// the bug (b/72496031). During buffer assignment, the visitation order of +// colocated buffers is %while.2 -> while.0 -> while.1, and the buffer +// assignment was coalescing the colocated buffers for all 3 while instructions, +// therefore assigning the same buffer to the two result tuple elements. +TEST_F(WhileBufferAssignmentTest, ColocatedBuffers) { + const Shape r0s32 = ShapeUtil::MakeShape(S32, {}); + + // Builds a condition computation: x -> x < 4 + auto build_cond = [&]() { + auto builder = HloComputation::Builder("cond"); + auto const4 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(4))); + auto param = + builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x")); + builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, param, const4)); + return builder.Build(); + }; + + // Builds a body computation: x -> x + 9 + auto build_body = [&]() { + auto builder = HloComputation::Builder("body"); + auto const9 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(9))); + auto param = + builder.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "x")); + builder.AddInstruction( + HloInstruction::CreateBinary(r0s32, HloOpcode::kAdd, param, const9)); + return builder.Build(); + }; + + // Build the entry computation as described in the comment above. 
+ auto module = xla::MakeUnique(TestName()); + auto builder = HloComputation::Builder("entry"); + + auto infeed = builder.AddInstruction(HloInstruction::CreateInfeed(r0s32, "")); + auto cond0 = module->AddEmbeddedComputation(build_cond()); + auto body0 = module->AddEmbeddedComputation(build_body()); + auto while0 = builder.AddInstruction( + HloInstruction::CreateWhile(r0s32, cond0, body0, infeed)); + + auto cond1 = module->AddEmbeddedComputation(build_cond()); + auto body1 = module->AddEmbeddedComputation(build_body()); + auto while1 = builder.AddInstruction( + HloInstruction::CreateWhile(r0s32, cond1, body1, while0)); + + auto zero = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0))); + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(r0s32, HloOpcode::kAdd, zero, zero)); + auto cond2 = module->AddEmbeddedComputation(build_cond()); + auto body2 = module->AddEmbeddedComputation(build_body()); + auto while2 = builder.AddInstruction( + HloInstruction::CreateWhile(r0s32, cond2, body2, add)); + + auto tuple = + builder.AddInstruction(HloInstruction::CreateTuple({while2, while1})); + module->AddEntryComputation(builder.Build()); + + // Run CopyInsertion and check if the graph constructed above doesn't need + // any copies inserted for BufferAssignment to run. + int64 instruction_count = module->instruction_count(); + CopyInsertion copy_insertion; + ASSERT_IS_OK(copy_insertion.Run(module.get()).status()); + ASSERT_EQ(instruction_count, module->instruction_count()); + + // Create a sequential order among all the instructions in the entry + // computation, since the issue this test stresses depends on the order the + // nodes are traversed during BufferAssignment. 
+ SequentialHloOrdering::HloModuleSequence sequence; + sequence[module->entry_computation()] = {infeed, while0, while1, zero, + add, while2, tuple}; + TF_ASSERT_OK_AND_ASSIGN( + auto assignment, + BufferAssigner::Run( + module.get(), + xla::MakeUnique(module.get(), sequence), + backend().compiler()->BufferSizeBytesFunction(), + [](LogicalBuffer::Color) { return 1; })); + + // The result tuple elements must be assigned with different buffers. + TF_ASSERT_OK_AND_ASSIGN(auto slice0, assignment->GetUniqueSlice(tuple, {0})); + TF_ASSERT_OK_AND_ASSIGN(auto slice1, assignment->GetUniqueSlice(tuple, {1})); + EXPECT_NE(slice0, slice1); + + // while0 and while1 result buffers must be equal to slice1. + TF_ASSERT_OK_AND_ASSIGN(auto slice_while0, + assignment->GetUniqueSlice(while0, {})); + TF_ASSERT_OK_AND_ASSIGN(auto slice_while1, + assignment->GetUniqueSlice(while1, {})); + EXPECT_EQ(slice1, slice_while0); + EXPECT_EQ(slice1, slice_while1); + + // while2 result buffer must be equal to slice0. + TF_ASSERT_OK_AND_ASSIGN(auto slice_while2, + assignment->GetUniqueSlice(while2, {})); + EXPECT_EQ(slice0, slice_while2); +} + TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) { auto module = xla::MakeUnique(TestName()); auto builder = HloComputation::Builder("entry"); @@ -1708,9 +1881,8 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) { BufferAssigner::Run( module.get(), xla::MakeUnique(module.get(), sequence), - ByteSizeOf, - [](LogicalBuffer::Color) { return 1; }) - .ConsumeValueOrDie(); + ByteSizeOf, [](LogicalBuffer::Color) { return 1; }) + .ConsumeValueOrDie(); EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment)); } diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc index 513bfa3b7f7b45696093d03c1dd8250c548d260a..37982aaef9eddd64ef6b57ad5a9cf8dd6a565097 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.cc +++ 
b/tensorflow/compiler/xla/service/buffer_liveness.cc @@ -102,8 +102,8 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a, return false; } - // Every user of 'a' must be a predecessor of 'b' or 'b' itself. for (const BufferAlias& alias : points_to_analysis_->GetBufferAliases(a)) { + // Every user of 'a' must be a predecessor of 'b' or 'b' itself. for (auto user : alias.instruction()->users()) { if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), user, points_to_analysis())) { @@ -114,6 +114,17 @@ bool BufferLiveness::live_range_strictly_before(const LogicalBuffer& a, return false; } } + + // If the root instruction aliases the buffer 'a', the live range of 'a' is + // until the end of the computation and can never be strictly before another + // buffer defined in the same computation. This is needed to prevent the + // root instruction's buffers from being reused by later instructions even + // when the root is not the last instruction in the schedule. + if (alias.instruction()->parent()->root_instruction() == + alias.instruction() && + alias.instruction()->parent() == b.instruction()->parent()) { + return false; + } } // If 'b' is a user of 'a' then the buffers interfere unless 'a.instruction' diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc index bbb42d494b8003176d4911bacbe8a10dc5fc7c6a..f623aef67a4f98b447a9a15634a78deb60cfe6f1 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc @@ -167,11 +167,10 @@ TEST_F(BufferLivenessTest, MultipleEntryParameters_Sequential) { SequentialHloOrdering::HloModuleSequence sequence; sequence.insert({entry, {param0, negate, param1, exp, add}}); - auto liveness = BufferLiveness::Run( - module.get(), - xla::MakeUnique( - module.get(), sequence)) - .ConsumeValueOrDie(); + auto liveness = + BufferLiveness::Run(module.get(), xla::MakeUnique( + 
module.get(), sequence)) + .ConsumeValueOrDie(); // Entry parameters interfere as if they are defined simultaneously at // the very beginning. @@ -296,7 +295,7 @@ TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) { module_sequence.emplace(computation, order); auto liveness = BufferLiveness::Run(module.get(), xla::MakeUnique( - module.get(), module_sequence)) + module.get(), module_sequence)) .ConsumeValueOrDie(); EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate)); @@ -312,6 +311,48 @@ TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) { EXPECT_FALSE(InstructionsMayInterfere(*liveness, add, exp)); } +TEST_F(BufferLivenessTest, RootInstructionIsNotLastInSequentialOrder) { + // Tests that when the root instruction is not the last instruction in the + // schedule, the live range of its buffers interfere with the buffers of the + // later instructions. + // + // Two sets of independent instructions are executed in the computation. + // param --> add (root) + // recv --> recv-done --> send --> send-done + // + // Sequential order: + // param, add (root), recv, recv-done, send, send-done + auto builder = HloComputation::Builder(TestName()); + auto param = + builder.AddInstruction(HloInstruction::CreateParameter(0, vec_, "param")); + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(vec_, HloOpcode::kAdd, param, param)); + auto recv = builder.AddInstruction( + HloInstruction::CreateRecv(vec_, /*channel_id=*/0)); + auto recv_done = builder.AddInstruction(HloInstruction::CreateRecvDone(recv)); + auto send = builder.AddInstruction( + HloInstruction::CreateSend(recv_done, /*channel_id=*/1)); + auto send_done = builder.AddInstruction(HloInstruction::CreateSendDone(send)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build(add)); + + SequentialHloOrdering::HloModuleSequence module_sequence; + std::vector order = {param, add, recv, + recv_done, send, send_done}; + 
module_sequence.emplace(computation, order); + auto liveness = + BufferLiveness::Run(module.get(), xla::MakeUnique( + module.get(), module_sequence)) + .ConsumeValueOrDie(); + + EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, add)); + // Check the root instruction (add) buffer interferes with the recv buffer. + EXPECT_TRUE( + liveness->MayInterfere(GetBuffer(*liveness, add, /*index=*/{}), + GetBuffer(*liveness, recv, /*index=*/{0}))); +} + TEST_F(BufferLivenessTest, TupleLiveOut) { // Verify MaybeLiveOut with nested tuples. Result of computation looks like: // @@ -625,9 +666,8 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest { // Run BufferLiveness on 'module'. auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique( - module.get())) + BufferLiveness::Run( + module.get(), xla::MakeUnique(module.get())) .ConsumeValueOrDie(); // Return whether or not buffers interference is detected between // 'tuple_param0' and 'tuple_root' at shape index '{1}'. @@ -738,9 +778,8 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest { module->AddEmbeddedComputation(builder.Build()); // Run BufferLiveness on 'module'. auto liveness = - BufferLiveness::Run(module.get(), - xla::MakeUnique( - module.get())) + BufferLiveness::Run( + module.get(), xla::MakeUnique(module.get())) .ConsumeValueOrDie(); // Return whether or not buffers interference is detected between // 'tuple_param0' and 'tuple_root' at shape index '{1}'. 
diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index 1adecdb939cb2c1259003d3be2c90b5a299b0f30..13eb02ca012f44b2b5ed7c6f5becb7d54b07c33c 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -54,6 +54,7 @@ std::ostream& operator<<(std::ostream& out, const CallContext& context) { CallContext GetInstructionCallContext(const HloInstruction* instruction) { switch (instruction->opcode()) { case HloOpcode::kCall: + case HloOpcode::kConditional: case HloOpcode::kWhile: return CallContext::kSequential; case HloOpcode::kMap: diff --git a/tensorflow/compiler/xla/service/call_graph_test.cc b/tensorflow/compiler/xla/service/call_graph_test.cc index 0395ea8c8b52315f7ca2221f412750ebadda2dd8..1ea7d538cd515c3098b6a1f03c6146d288330406 100644 --- a/tensorflow/compiler/xla/service/call_graph_test.cc +++ b/tensorflow/compiler/xla/service/call_graph_test.cc @@ -34,12 +34,13 @@ using ::testing::UnorderedElementsAre; class CallGraphTest : public HloTestBase { protected: // Build and return a trivial computation taking and returning a scalar. - std::unique_ptr MakeScalarComputation() { + std::unique_ptr MakeScalarComputation( + HloOpcode opcode = HloOpcode::kNegate) { HloComputation::Builder builder(TestName() + ".ScalarComputation"); HloInstruction* param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, kScalarShape, "param0")); builder.AddInstruction( - HloInstruction::CreateUnary(kScalarShape, HloOpcode::kNegate, param0)); + HloInstruction::CreateUnary(kScalarShape, opcode, param0)); return builder.Build(); } @@ -236,6 +237,54 @@ TEST_F(CallGraphTest, ContextBothComputations) { EXPECT_EQ(CallContext::kBoth, sub_node.context()); } +TEST_F(CallGraphTest, ComputationWithConditional) { + // Test a call graph of a module with a conditional. 
+ auto module = CreateNewModule(); + HloComputation* true_computation = + module->AddEmbeddedComputation(MakeScalarComputation(HloOpcode::kCeil)); + HloComputation* false_computation = + module->AddEmbeddedComputation(MakeScalarComputation(HloOpcode::kFloor)); + + HloComputation::Builder builder(TestName()); + HloInstruction* pred = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloInstruction* const1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(56.4f))); + HloInstruction* const2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(12.6f))); + HloInstruction* conditional = + builder.AddInstruction(HloInstruction::CreateConditional( + kScalarShape, pred, const1, true_computation, const2, + false_computation)); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + + std::unique_ptr call_graph = CallGraph::Build(module.get()); + + EXPECT_EQ(3, call_graph->nodes().size()); + + const CallGraphNode& entry_node = call_graph->GetNode(entry_computation); + EXPECT_EQ(entry_computation, entry_node.computation()); + EXPECT_EQ(1, entry_node.callsites().size()); + + const CallSite& conditional_callsite = entry_node.callsites()[0]; + EXPECT_EQ(conditional, conditional_callsite.instruction()); + EXPECT_THAT(conditional_callsite.called_computations(), + UnorderedElementsAre(true_computation, false_computation)); + EXPECT_EQ(CallContext::kSequential, conditional_callsite.context()); + EXPECT_EQ(entry_node.GetCallSite(conditional), &conditional_callsite); + + const CallGraphNode& true_node = call_graph->GetNode(true_computation); + EXPECT_TRUE(true_node.callees().empty()); + EXPECT_EQ(1, true_node.callers().size()); + EXPECT_EQ(entry_computation, true_node.callers()[0]); + + const CallGraphNode& false_node = call_graph->GetNode(false_computation); + EXPECT_TRUE(false_node.callees().empty()); + EXPECT_EQ(1, false_node.callers().size()); + 
EXPECT_EQ(entry_computation, false_node.callers()[0]); +} + TEST_F(CallGraphTest, ComplexGraph) { // Test a call graph of a module with several computation called in various // contexts. The call graph looks like: diff --git a/tensorflow/compiler/xla/service/call_inliner.cc b/tensorflow/compiler/xla/service/call_inliner.cc index 3aa7f5c4d5829ccc0e8df697c1363754128ff436..482ccc5b67109258f544e5657ecfa0e8f62192c0 100644 --- a/tensorflow/compiler/xla/service/call_inliner.cc +++ b/tensorflow/compiler/xla/service/call_inliner.cc @@ -82,6 +82,10 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault { return outer_->ReplaceInstruction(call_, new_root); } + CallInliner::InlinedInstructionMap ConsumeInstructionMap() { + return std::move(subcomputation_hlo_to_new_hlo_); + } + private: // Resolves the callee subcomputation_hlo to the new (inline) HLO in the // caller computation, or returns a NotFound error if that subcomputation HLO @@ -112,13 +116,13 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault { HloInstruction* call_; HloComputation* outer_; - std::unordered_map - subcomputation_hlo_to_new_hlo_; + CallInliner::InlinedInstructionMap subcomputation_hlo_to_new_hlo_; }; } // namespace -/* static */ Status CallInliner::Inline(HloInstruction* call) { +/* static */ StatusOr CallInliner::Inline( + HloInstruction* call) { TF_RET_CHECK(call->opcode() == HloOpcode::kCall) << "Instruction was not a call op: " << call->opcode(); const auto& callees = call->called_computations(); @@ -126,7 +130,8 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault { HloComputation* callee = callees[0]; // We visit the callee, cloning its body into its caller. 
SubcomputationInsertionVisitor visitor(call); - return callee->Accept(&visitor); + TF_RETURN_IF_ERROR(callee->Accept(&visitor)); + return visitor.ConsumeInstructionMap(); } StatusOr CallInliner::Run(HloModule* module) { @@ -140,7 +145,7 @@ StatusOr CallInliner::Run(HloModule* module) { VLOG(1) << "Visiting callsite: " << callsite.ToString(); if (callsite.instruction()->opcode() == HloOpcode::kCall) { HloInstruction* call = callsite.instruction(); - TF_RETURN_IF_ERROR(Inline(call)); + TF_RETURN_IF_ERROR(Inline(call).status()); did_mutate = true; } } diff --git a/tensorflow/compiler/xla/service/call_inliner.h b/tensorflow/compiler/xla/service/call_inliner.h index 2dbd38bf1ac90d3efa1453e6af6f791668d5e72a..a8345a394d46c90a48305313dac0bcd9b06938ac 100644 --- a/tensorflow/compiler/xla/service/call_inliner.h +++ b/tensorflow/compiler/xla/service/call_inliner.h @@ -27,8 +27,12 @@ namespace xla { // called function, and proceed recursively. class CallInliner : public HloPassInterface { public: - // Inlines one call instruction. - static Status Inline(HloInstruction* call); + using InlinedInstructionMap = + std::unordered_map; + + // Inlines one call instruction. Returns a mapping from the original + // instructions to their inlined versions. 
+ static StatusOr Inline(HloInstruction* call); ~CallInliner() override = default; tensorflow::StringPiece name() const override { return "CallInliner"; } diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc index 865ed993da121d26ceb61123f1822d93814cbb9b..738d00881dd057fc13c115006c15e8f5b6d14a1d 100644 --- a/tensorflow/compiler/xla/service/call_inliner_test.cc +++ b/tensorflow/compiler/xla/service/call_inliner_test.cc @@ -135,7 +135,7 @@ TEST_F(CallInlinerTest, InlineWithoutRunningPass) { HloInstruction::CreateCall(pred, {}, false_computation)); auto computation = module->AddEntryComputation(call_false_builder.Build()); - TF_ASSERT_OK(CallInliner::Inline(call)); + TF_ASSERT_OK(CallInliner::Inline(call).status()); EXPECT_THAT(computation->root_instruction(), op::Constant()); EXPECT_THAT(computation->root_instruction()->control_successors(), ElementsAre(op::Constant())); diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index 9e96898d9b4215e67c8686d372e4b4e6edd1d88b..dab73596e1639eed62151197048ee8d29570b20a 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -101,12 +101,13 @@ CompileOnlyService::CompileAheadOfTime( TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, CreateModuleConfig(*program_shape, instance.argument_layouts, - &execution_options)); + &execution_options, *user_computation)); TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_module, computation_tracker_.BuildHloModule( versioned_handle, *module_config, /*include_unreachable_instructions=*/true)); + TF_RETURN_IF_ERROR(MaybeDumpHloModule(*hlo_module)); hlo_modules.push_back(std::move(hlo_module)); } diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index fc67330f5cbdbcb0d1a259d284599916a908d1fe..74fd24edf88d44b2dfdc87556b0af43987e69e08 
100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -72,8 +72,18 @@ class AotCompilationOptions { // Returns the ID of the platform to which these options apply. virtual perftools::gputools::Platform::Id PlatformId() const = 0; + // Optional allocator that may be used for allocating temp space on the device + // during compilation. + DeviceMemoryAllocator* device_allocator() const { return device_allocator_; } + void set_device_allocator(DeviceMemoryAllocator* device_allocator) { + device_allocator_ = device_allocator; + } + protected: AotCompilationOptions() = default; + + private: + DeviceMemoryAllocator* device_allocator_ = nullptr; }; // Abstract compiler interface that is subclassed for compilation on a @@ -99,9 +109,16 @@ class Compiler { // Runs Hlo passes to optimize the given Hlo module, returns the optimized // module. + // + // If device_allocator is not null, the compiler may use it to allocate temp + // space on the device for use during compilation. For example, the compiler + // may allocate buffers on the device and then run variants of a given + // algorithm over those buffers, to see which variant is fastest. Any space + // allocated should be deallocated before this function returns. virtual StatusOr> RunHloPasses( std::unique_ptr module, - perftools::gputools::StreamExecutor* executor) = 0; + perftools::gputools::StreamExecutor* executor, + DeviceMemoryAllocator* device_allocator) = 0; // Compiles the HLO module for execution on a device given by the executor, // and returns an executable object or an error status. No HLO passes are @@ -112,21 +129,27 @@ class Compiler { // The compiler may optionally specialize to the individual device // (not just type of device) indicated by the executor. // + // device_allocator is optional; see RunHloPasses. + // // Use the overload below to compile computations that run in parallel. 
virtual StatusOr> RunBackend( std::unique_ptr module, - perftools::gputools::StreamExecutor* executor) = 0; + perftools::gputools::StreamExecutor* executor, + DeviceMemoryAllocator* device_allocator) = 0; // Compiles a set of HLO modules that can run in parallel, potentially // communicating data between the modules, and returns a corresponding // sequence of executable objects. // + // device_allocator is optional; see RunHloPasses. + // // TODO(b/68666782): Remove this method after adding support for multiple // modules to RunHloPasses and RunBackends. virtual StatusOr>> Compile( std::vector> modules, std::vector> - stream_exec) = 0; + stream_exec, + DeviceMemoryAllocator* device_allocator) = 0; // Compiles the HLO module for ahead-of-time execution. This is intended for // use in static compilation. diff --git a/tensorflow/compiler/xla/service/copy_insertion_test.cc b/tensorflow/compiler/xla/service/copy_insertion_test.cc index 3278fd5f064902459ded4d9367b5390cf8a63f27..153f062d015e49db11c4c9ae0a2a61e76c020f02 100644 --- a/tensorflow/compiler/xla/service/copy_insertion_test.cc +++ b/tensorflow/compiler/xla/service/copy_insertion_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_runner.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" @@ -339,7 +340,7 @@ TEST_F(CopyInsertionTest, ElementOfNestedTupleParameter) { ShapeUtil::MakeShape(F32, {42})}), "param0")); - // The return value of the computation is the zero-th elemnt of the nested + // The return value of the computation is the zero-th element of the nested // tuple. This element is itself a tuple. 
auto gte = builder.AddInstruction(HloInstruction::CreateGetTupleElement( ShapeUtil::GetSubshape(param->shape(), {0}), param, 0)); @@ -1723,8 +1724,242 @@ void BM_ParallelWhiles(int num_iters, int num_whiles) { } } +std::unique_ptr MakeBenchmarkWhileBody( + const int num_tuple_inputs) { + auto builder = HloComputation::Builder("benchmark_loop_body"); + const Shape element_shape = ShapeUtil::MakeShape(F32, {}); + std::vector input_shape(num_tuple_inputs, element_shape); + const Shape loop_state_shape = ShapeUtil::MakeTupleShape(input_shape); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_state_shape, "loop_state")); + std::vector gte_nodes(num_tuple_inputs); + for (int i = 0; i < num_tuple_inputs; ++i) { + gte_nodes[i] = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(element_shape, param, i)); + } + builder.AddInstruction(HloInstruction::CreateTuple(gte_nodes)); + return builder.Build(); +} + +void BM_ManyElementTuple(int num_iters, const int num_tuple_inputs) { + tensorflow::testing::StopTiming(); + HloModuleConfig config; + config.set_debug_options(legacy_flags::GetDebugOptionsFromFlags()); + CopyInsertion copy_insertion; + const Shape element_shape = ShapeUtil::MakeShape(F32, {}); + std::vector tuple_params(num_tuple_inputs); + for (int i = 0; i < num_iters; ++i) { + auto builder = HloComputation::Builder("BM_ParallelWhiles"); + HloModule module("BM_ManyElementTuple", VersionedComputationHandle(), + config); + for (int j = 0; j < num_tuple_inputs; ++j) { + tuple_params[j] = builder.AddInstruction( + HloInstruction::CreateParameter(j, element_shape, "")); + } + HloInstruction* init = + builder.AddInstruction(HloInstruction::CreateTuple(tuple_params)); + HloComputation* condition = + module.AddEmbeddedComputation(MakeTrivialCondition(init->shape())); + HloComputation* body = + module.AddEmbeddedComputation(MakeBenchmarkWhileBody(num_tuple_inputs)); + HloInstruction* xla_while = builder.AddInstruction( + 
HloInstruction::CreateWhile(init->shape(), condition, body, init)); + builder.AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::MakeShape(F32, {}), xla_while, 0)); + module.AddEntryComputation(builder.Build()); + tensorflow::testing::StartTiming(); + ASSERT_IS_OK(copy_insertion.Run(&module).status()); + tensorflow::testing::StopTiming(); + } +} + BENCHMARK(BM_SequentialWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096); BENCHMARK(BM_ParallelWhiles)->Arg(512)->Arg(1024)->Arg(2048)->Arg(4096); +BENCHMARK(BM_ManyElementTuple)->Arg(1024)->Arg(12288); + +TEST_F(CopyInsertionTest, SimpleControlFlowTest) { + const string& hlo_string = R"( +HloModule TestModule + +if-body.v5 { + constant.3 = s32[] constant(-1) + p.1 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) + get-tuple-element.18 = (s32[], s32[], s32[]) get-tuple-element(p.1), index=1 + get-tuple-element.65 = s32[] get-tuple-element(get-tuple-element.18), index=0 + get-tuple-element.66 = s32[] get-tuple-element(get-tuple-element.18), index=1 + add.3 = s32[] add(get-tuple-element.65, get-tuple-element.66) + tuple.33 = (s32[]) tuple(add.3) + ROOT tuple.34 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(constant.3, get-tuple-element.18, tuple.33) +} + +if-condition.v4 { + p.2 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) + get-tuple-element.67 = s32[] get-tuple-element(p.2), index=0 + constant.4 = s32[] constant(0) + ROOT equal-to = pred[] equal-to(get-tuple-element.67, constant.4) +} + +_functionalize_body_1__.v28 { + arg_tuple.4 = (s32[], s32[], s32[], s32[]) parameter(0) + get-tuple-element.68 = s32[] get-tuple-element(arg_tuple.4), index=0 + constant.7 = s32[] constant(1) + add.4 = s32[] add(get-tuple-element.68, constant.7) + get-tuple-element.69 = s32[] get-tuple-element(arg_tuple.4), index=1 + get-tuple-element.70 = s32[] get-tuple-element(arg_tuple.4), index=2 + less-than-or-equal-to = pred[] less-than-or-equal-to(get-tuple-element.69, get-tuple-element.70) + constant.8 = s32[] 
constant(0) + select = s32[] select(less-than-or-equal-to, constant.8, constant.7) + get-tuple-element.71 = s32[] get-tuple-element(arg_tuple.4), index=3 + tuple.35 = (s32[], s32[], s32[]) tuple(get-tuple-element.69, get-tuple-element.71, get-tuple-element.70) + tuple.36 = (s32[]) tuple(constant.8) + tuple.37 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(select, tuple.35, tuple.36) + while = (s32[], (s32[], s32[], s32[]), (s32[])) while(tuple.37), condition=if-condition.v4, body=if-body.v5 + get-tuple-element.72 = (s32[]) get-tuple-element(while), index=2 + get-tuple-element.73 = s32[] get-tuple-element(get-tuple-element.72), index=0 + ROOT tuple.38 = (s32[], s32[], s32[], s32[]) tuple(add.4, get-tuple-element.69, get-tuple-element.70, get-tuple-element.73) +} + +cond_wrapper.v3.1 { + inputs.1 = (s32[], s32[], s32[], s32[]) parameter(0) + get-tuple-element.75 = s32[] get-tuple-element(inputs.1), index=0 + constant.11 = s32[] constant(7) + ROOT less-than.2 = pred[] less-than(get-tuple-element.75, constant.11) +} + +_functionalize_body_2__.v25 { + arg_tuple.5 = (s32[], s32[], s32[], s32[], s32[]) parameter(0) + get-tuple-element.76 = s32[] get-tuple-element(arg_tuple.5), index=0 + get-tuple-element.77 = s32[] get-tuple-element(arg_tuple.5), index=2 + get-tuple-element.78 = s32[] get-tuple-element(arg_tuple.5), index=3 + get-tuple-element.79 = s32[] get-tuple-element(arg_tuple.5), index=4 + tuple.39 = (s32[], s32[], s32[], s32[]) tuple(get-tuple-element.76, get-tuple-element.77, get-tuple-element.78, get-tuple-element.79) + while.2 = (s32[], s32[], s32[], s32[]) while(tuple.39), condition=cond_wrapper.v3.1, body=_functionalize_body_1__.v28 + get-tuple-element.80 = s32[] get-tuple-element(while.2), index=0 + get-tuple-element.81 = s32[] get-tuple-element(arg_tuple.5), index=1 + constant.12 = s32[] constant(1) + add.5 = s32[] add(get-tuple-element.81, constant.12) + get-tuple-element.82 = s32[] get-tuple-element(while.2), index=3 + ROOT tuple.40 = (s32[], s32[], 
s32[], s32[], s32[]) tuple(get-tuple-element.80, add.5, get-tuple-element.77, get-tuple-element.78, get-tuple-element.82) +} + +cond_wrapper.v3.2 { + inputs.2 = (s32[], s32[], s32[], s32[], s32[]) parameter(0) + get-tuple-element.83 = s32[] get-tuple-element(inputs.2), index=1 + constant.13 = s32[] constant(5) + ROOT less-than.3 = pred[] less-than(get-tuple-element.83, constant.13) +} + +ENTRY TestComputation { + arg_tuple.6 = (s32[], s32[], s32[], s32[], s32[]) parameter(0) + ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25 +} +)"; + auto module_or_status = + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); + auto module = module_or_status.ConsumeValueOrDie(); + InsertCopies(module.get()); +} + +TEST_F(CopyInsertionTest, ControlFlowTest) { + const string& hlo_string = R"( +HloModule TestModule + +if-body.v5 { + constant.3 = s32[] constant(-1) + p.1 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) + get-tuple-element.18 = (s32[], s32[], s32[]) get-tuple-element(p.1), index=1 + get-tuple-element.65 = s32[] get-tuple-element(get-tuple-element.18), index=0 + get-tuple-element.66 = s32[] get-tuple-element(get-tuple-element.18), index=1 + add.3 = s32[] add(get-tuple-element.65, get-tuple-element.66) + tuple.33 = (s32[]) tuple(add.3) + ROOT tuple.34 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(constant.3, get-tuple-element.18, tuple.33) +} + +if-condition.v4 { + p.2 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) + get-tuple-element.67 = s32[] get-tuple-element(p.2), index=0 + constant.4 = s32[] constant(0) + ROOT equal-to = pred[] equal-to(get-tuple-element.67, constant.4) +} + +if-body.v5.1 { + constant.5 = s32[] constant(-1) + p.3 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) + get-tuple-element.68 = (s32[], s32[], s32[]) get-tuple-element(p.3), index=1 + get-tuple-element.70 = s32[] get-tuple-element(get-tuple-element.68), index=2 + 
multiply.1 = s32[] multiply(get-tuple-element.70, get-tuple-element.70) + tuple.35 = (s32[]) tuple(multiply.1) + ROOT tuple.36 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(constant.5, get-tuple-element.68, tuple.35) +} + +if-condition.v4.1 { + p.4 = (s32[], (s32[], s32[], s32[]), (s32[])) parameter(0) + get-tuple-element.71 = s32[] get-tuple-element(p.4), index=0 + constant.6 = s32[] constant(1) + ROOT equal-to.1 = pred[] equal-to(get-tuple-element.71, constant.6) +} + +_functionalize_body_1__.v28 { + arg_tuple.4 = (s32[], s32[], s32[], s32[]) parameter(0) + get-tuple-element.72 = s32[] get-tuple-element(arg_tuple.4), index=0 + constant.7 = s32[] constant(1) + add.4 = s32[] add(get-tuple-element.72, constant.7) + get-tuple-element.73 = s32[] get-tuple-element(arg_tuple.4), index=1 + get-tuple-element.74 = s32[] get-tuple-element(arg_tuple.4), index=2 + less-than-or-equal-to = pred[] less-than-or-equal-to(get-tuple-element.73, get-tuple-element.74) + constant.8 = s32[] constant(0) + select = s32[] select(less-than-or-equal-to, constant.8, constant.7) + get-tuple-element.75 = s32[] get-tuple-element(arg_tuple.4), index=3 + tuple.37 = (s32[], s32[], s32[]) tuple(get-tuple-element.73, get-tuple-element.75, get-tuple-element.74) + tuple.38 = (s32[]) tuple(constant.8) + tuple.39 = (s32[], (s32[], s32[], s32[]), (s32[])) tuple(select, tuple.37, tuple.38) + while = (s32[], (s32[], s32[], s32[]), (s32[])) while(tuple.39), condition=if-condition.v4, body=if-body.v5 + while.1 = (s32[], (s32[], s32[], s32[]), (s32[])) while(while), condition=if-condition.v4.1, body=if-body.v5.1 + get-tuple-element.76 = (s32[]) get-tuple-element(while.1), index=2 + get-tuple-element.77 = s32[] get-tuple-element(get-tuple-element.76), index=0 + ROOT tuple.40 = (s32[], s32[], s32[], s32[]) tuple(add.4, get-tuple-element.73, get-tuple-element.74, get-tuple-element.77) +} + +cond_wrapper.v3.1 { + inputs.1 = (s32[], s32[], s32[], s32[]) parameter(0) + get-tuple-element.78 = s32[] 
get-tuple-element(inputs.1), index=0 + constant.11 = s32[] constant(7) + ROOT less-than.2 = pred[] less-than(get-tuple-element.78, constant.11) +} + +_functionalize_body_2__.v25 { + arg_tuple.5 = (s32[], s32[], s32[], s32[], s32[]) parameter(0) + get-tuple-element.79 = s32[] get-tuple-element(arg_tuple.5), index=0 + get-tuple-element.80 = s32[] get-tuple-element(arg_tuple.5), index=2 + get-tuple-element.81 = s32[] get-tuple-element(arg_tuple.5), index=3 + get-tuple-element.82 = s32[] get-tuple-element(arg_tuple.5), index=4 + tuple.41 = (s32[], s32[], s32[], s32[]) tuple(get-tuple-element.79, get-tuple-element.80, get-tuple-element.81, get-tuple-element.82) + while.2 = (s32[], s32[], s32[], s32[]) while(tuple.41), condition=cond_wrapper.v3.1, body=_functionalize_body_1__.v28 + get-tuple-element.83 = s32[] get-tuple-element(while.2), index=0 + get-tuple-element.84 = s32[] get-tuple-element(arg_tuple.5), index=1 + constant.12 = s32[] constant(1) + add.5 = s32[] add(get-tuple-element.84, constant.12) + get-tuple-element.85 = s32[] get-tuple-element(while.2), index=3 + ROOT tuple.42 = (s32[], s32[], s32[], s32[], s32[]) tuple(get-tuple-element.83, add.5, get-tuple-element.80, get-tuple-element.81, get-tuple-element.85) +} + +cond_wrapper.v3.2 { + inputs.2 = (s32[], s32[], s32[], s32[], s32[]) parameter(0) + get-tuple-element.86 = s32[] get-tuple-element(inputs.2), index=1 + constant.13 = s32[] constant(5) + ROOT less-than.3 = pred[] less-than(get-tuple-element.86, constant.13) +} + +ENTRY TestComputation { + arg_tuple.6 = (s32[], s32[], s32[], s32[], s32[]) parameter(0) + ROOT while.3 = (s32[], s32[], s32[], s32[], s32[]) while(arg_tuple.6), condition=cond_wrapper.v3.2, body=_functionalize_body_2__.v25 +} +)"; + auto module_or_status = + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); + auto module = module_or_status.ConsumeValueOrDie(); + InsertCopies(module.get()); +} } // namespace } // namespace xla diff --git 
a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index e1eed498f6adfdae9df1dbf183f7c0505afd4ea2..c13a0b1cdf0b5be0b69db98b2b9587f30ca4c304 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -81,14 +81,15 @@ cc_library( ":conv_canonicalization", ":cpu_copy_insertion", ":cpu_executable", + ":cpu_hlo_support_checker", ":cpu_instruction_fusion", + ":cpu_layout_assignment", ":cpu_options", ":cpu_parallelization_preparation", ":disassembler", ":dot_op_emitter", ":ir_emission_utils", ":ir_emitter", - ":layout_assignment", ":parallel_cpu_executable", ":parallel_task_assignment", ":simple_orc_jit", @@ -100,16 +101,18 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:algebraic_simplifier", - "//tensorflow/compiler/xla/service:batchnorm_rewriter", + "//tensorflow/compiler/xla/service:batchnorm_expander", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:buffer_liveness", "//tensorflow/compiler/xla/service:call_inliner", + "//tensorflow/compiler/xla/service:dot_decomposer", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_constant_folding", "//tensorflow/compiler/xla/service:hlo_cse", "//tensorflow/compiler/xla/service:hlo_dce", + "//tensorflow/compiler/xla/service:hlo_element_type_converter", "//tensorflow/compiler/xla/service:hlo_ordering", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:hlo_pass_pipeline", @@ -124,7 +127,9 @@ cc_library( "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service:tuple_simplifier", + "//tensorflow/compiler/xla/service:while_loop_invariant_code_motion", 
"//tensorflow/compiler/xla/service:while_loop_simplifier", + "//tensorflow/compiler/xla/service:zero_sized_hlo_elimination", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", # fixdeps: keep "//tensorflow/core:lib", # fixdeps: keep "//tensorflow/core:stream_executor_no_cuda", @@ -135,8 +140,6 @@ cc_library( "@llvm//:core", "@llvm//:mc", # fixdeps: keep "@llvm//:object", - "@llvm//:powerpc_code_gen", # fixdeps: keep - "@llvm//:powerpc_disassembler", # fixdeps: keep "@llvm//:support", "@llvm//:target", # fixdeps: keep "@llvm//:x86_code_gen", # fixdeps: keep @@ -147,19 +150,21 @@ cc_library( cc_library( name = "simple_orc_jit", - srcs = ["simple_orc_jit.cc"], + srcs = [ + "simple_orc_jit.cc", + "windows_compatibility.cc", + "windows_compatibility.h", + ], hdrs = ["simple_orc_jit.h"], deps = [ ":compiler_functor", ":cpu_runtime", - ":cpu_runtime_avx", - ":cpu_runtime_neon", - ":cpu_runtime_sse4_1", ":custom_call_target_registry", ":disassembler", ":external_constant_pool", ":orc_jit_memory_mapper", ":runtime_conv2d", + ":runtime_fft", ":runtime_fork_join", ":runtime_matmul", ":runtime_single_threaded_conv2d", @@ -250,8 +255,11 @@ cc_library( ":dot_op_emitter", ":external_constant_pool", ":ir_emission_utils", + ":ir_function", + ":parallel_loop_emitter", ":shape_partition", ":simple_orc_jit", + ":target_machine_features", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -280,6 +288,54 @@ cc_library( ], ) +cc_library( + name = "target_machine_features", + srcs = [ + "target_machine_features.cc", + ], + hdrs = ["target_machine_features.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/core:lib", + "@llvm//:analysis", + "@llvm//:target", + ], +) + +cc_library( + name = "ir_function", + srcs = ["ir_function.cc"], + hdrs = ["ir_function.h"], + deps = [ + ":ir_emission_utils", + ":shape_partition", + "//tensorflow/compiler/xla:shape_util", + 
"//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/service/cpu:cpu_runtime", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/core:lib", + "@llvm//:core", + ], +) + +cc_library( + name = "parallel_loop_emitter", + srcs = ["parallel_loop_emitter.cc"], + hdrs = ["parallel_loop_emitter.h"], + deps = [ + ":ir_emission_utils", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service/llvm_ir:ir_array", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", + "//tensorflow/core:lib", + "@llvm//:core", + ], +) + cc_library( name = "dot_op_emitter", srcs = ["dot_op_emitter.cc"], @@ -287,6 +343,8 @@ cc_library( deps = [ ":cpu_options", ":cpu_runtime", + ":target_machine_features", + ":vector_support_library", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:types", @@ -298,7 +356,6 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:kernel_support_library", "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", - "//tensorflow/compiler/xla/service/llvm_ir:vector_support_library", "//tensorflow/core:lib", "@llvm//:core", ], @@ -336,7 +393,6 @@ cc_library( "@llvm//:mc", "@llvm//:mc_disassembler", "@llvm//:object", - "@llvm//:powerpc_disassembler", # fixdeps: keep "@llvm//:support", "@llvm//:target", "@llvm//:x86_disassembler", # fixdeps: keep @@ -349,9 +405,6 @@ cc_library( hdrs = ["compiler_functor.h"], deps = [ ":cpu_runtime", - ":cpu_runtime_avx", - ":cpu_runtime_neon", - ":cpu_runtime_sse4_1", ":disassembler", ":llvm_ir_runtime", "//tensorflow/compiler/xla:statusor", @@ -371,43 +424,6 @@ cc_library( ], ) -cc_library( - name = "cpu_runtime_sse4_1", - srcs = ["cpu_runtime_sse4_1.cc"], - hdrs = 
["cpu_runtime_sse4_1.h"], - copts = ["-DEIGEN_AVOID_STL_ARRAY"], - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/core:framework_lite", - "//third_party/eigen3", - ], -) - -cc_library( - name = "cpu_runtime_avx", - srcs = ["cpu_runtime_avx.cc"], - hdrs = ["cpu_runtime_avx.h"], - copts = ["-DEIGEN_AVOID_STL_ARRAY"], - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/core:framework_lite", - "//third_party/eigen3", - ], -) - -cc_library( - name = "cpu_runtime_neon", - srcs = ["cpu_runtime_neon.cc"], - hdrs = ["cpu_runtime_neon.h"], - # runtime_copts() enables -mfpu=neon - copts = ["-DEIGEN_AVOID_STL_ARRAY"] + runtime_copts(), - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/core:framework_lite", - "//third_party/eigen3", - ], -) - cc_library( name = "cpu_runtime", srcs = [ @@ -438,6 +454,7 @@ cc_library( "llvm_ir_runtime.h", ], deps = [ + ":vector_support_library", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:lib", "@llvm//:core", @@ -462,6 +479,24 @@ cc_library( ], ) +cc_library( + name = "runtime_fft", + srcs = [ + "runtime_fft.cc", + "runtime_fft_impl.h", + ], + hdrs = ["runtime_fft.h"], + copts = runtime_copts(), + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:framework", + "//tensorflow/core:framework_lite", + "//third_party/eigen3", + ], +) + cc_library( name = "runtime_matvec", srcs = ["runtime_matvec.cc"], @@ -615,13 +650,14 @@ cc_library( "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:window_util", "//tensorflow/compiler/xla/service:hlo", + "@llvm//:core", ], ) cc_library( - name = "layout_assignment", - srcs = ["layout_assignment.cc"], - hdrs = ["layout_assignment.h"], + name = "cpu_layout_assignment", + srcs = ["cpu_layout_assignment.cc"], + hdrs = ["cpu_layout_assignment.h"], deps = [ ":dot_op_emitter", ":ir_emission_utils", @@ -633,11 
+669,11 @@ cc_library( ) tf_cc_test( - name = "layout_assignment_test", + name = "cpu_layout_assignment_test", size = "small", - srcs = ["layout_assignment_test.cc"], + srcs = ["cpu_layout_assignment_test.cc"], deps = [ - ":layout_assignment", + ":cpu_layout_assignment", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", @@ -763,6 +799,21 @@ cc_library( ], ) +cc_library( + name = "vector_support_library", + srcs = ["vector_support_library.cc"], + hdrs = ["vector_support_library.h"], + deps = [ + ":target_machine_features", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", + "@llvm//:core", + "@llvm//:support", + ], +) + tf_cc_test( name = "cpu_copy_insertion_test", srcs = ["cpu_copy_insertion_test.cc"], @@ -783,6 +834,32 @@ tf_cc_test( ], ) +cc_library( + name = "cpu_hlo_support_checker", + srcs = ["cpu_hlo_support_checker.cc"], + hdrs = ["cpu_hlo_support_checker.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "cpu_hlo_support_checker_test", + srcs = ["cpu_hlo_support_checker_test.cc"], + deps = [ + ":cpu_hlo_support_checker", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc index 04b4a8c5c80eeefdbe10001ba5c462affbc9b21d..ed290fcdf8bb69f1bbad57fa5a0926376bc9405a 100644 --- 
a/tensorflow/compiler/xla/service/cpu/compiler_functor.cc +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.cc @@ -37,9 +37,6 @@ limitations under the License. #include "llvm/Transforms/IPO/PassManagerBuilder.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" -#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h" -#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h" -#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h" #include "tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -50,15 +47,6 @@ limitations under the License. namespace xla { namespace cpu { -/* static */ CompilerFunctor::VectorIntrinsics -CompilerFunctor::AllIntrinsics() { - VectorIntrinsics intrinsics; - intrinsics.sse_intrinsics = true; - intrinsics.avx_intrinsics = true; - intrinsics.neon_intrinsics = true; - return intrinsics; -} - /* Create filtered versions of the LLVM Pass Managers to filter out some of the expensive passes. Profiling: @@ -192,89 +180,28 @@ operator()(llvm::Module& module) const { std::move(object_file), std::move(memory_buffer)); } -namespace { -// Returns the set of vectorized library functions supported for the target. 
-std::vector VectorFunctionsForTargetLibraryInfoImpl( - llvm::Triple::ArchType arch, llvm::StringRef feature_string, - CompilerFunctor::VectorIntrinsics const& available_intrinsics) { - std::vector vector_functions; - - const llvm::VecDesc four_wide_vector_functions_neon[] = { - {"expf", runtime::kExpV4F32NEONSymbolName, 4}, - {"llvm.exp.f32", runtime::kExpV4F32NEONSymbolName, 4}, - - {"logf", runtime::kLogV4F32NEONSymbolName, 4}, - {"llvm.log.f32", runtime::kLogV4F32NEONSymbolName, 4}, - }; - - const llvm::VecDesc four_wide_vector_functions_sse[] = { - {"expf", runtime::kExpV4F32SSESymbolName, 4}, - {"llvm.exp.f32", runtime::kExpV4F32SSESymbolName, 4}, - - {"logf", runtime::kLogV4F32SSESymbolName, 4}, - {"llvm.log.f32", runtime::kLogV4F32SSESymbolName, 4}, - }; - - const llvm::VecDesc eight_wide_vector_functions_avx[] = { - {"expf", runtime::kExpV8F32AVXSymbolName, 8}, - {"llvm.exp.f32", runtime::kExpV8F32AVXSymbolName, 8}, - - {"logf", runtime::kLogV8F32AVXSymbolName, 8}, - {"llvm.log.f32", runtime::kLogV8F32AVXSymbolName, 8}, - }; - - // These functions are generated by XLA as LLVM IR, so they're always - // available. 
- const llvm::VecDesc ir_vector_functions[] = { +static std::vector VectorFunctionsForTargetLibraryInfoImpl() { + std::vector result = { {"tanhf", runtime::kTanhV4F32SymbolName, 4}, {"llvm.tanh.f32", runtime::kTanhV4F32SymbolName, 4}, {"tanhf", runtime::kTanhV8F32SymbolName, 8}, {"llvm.tanh.f32", runtime::kTanhV8F32SymbolName, 8}, - }; - llvm::SmallVector features; - feature_string.split(features, ',', -1, /*KeepEmpty=*/false); - auto has_feature = [&features](const llvm::StringRef feature) { - return std::find(features.begin(), features.end(), feature) != - features.end(); - }; + {"expf", runtime::kExpV4F32SymbolName, 4}, + {"llvm.exp.f32", runtime::kExpV4F32SymbolName, 4}, - switch (arch) { - case llvm::Triple::x86: - case llvm::Triple::x86_64: { - if (has_feature("+sse4.1") && available_intrinsics.sse_intrinsics) { - vector_functions.insert(vector_functions.end(), - std::begin(four_wide_vector_functions_sse), - std::end(four_wide_vector_functions_sse)); - } - if (has_feature("+avx") && available_intrinsics.avx_intrinsics) { - vector_functions.insert(vector_functions.end(), - std::begin(eight_wide_vector_functions_avx), - std::end(eight_wide_vector_functions_avx)); - } - break; - } - case llvm::Triple::arm: - case llvm::Triple::aarch64: { - if (has_feature("+neon") && available_intrinsics.neon_intrinsics) { - vector_functions.insert(vector_functions.end(), - std::begin(four_wide_vector_functions_neon), - std::end(four_wide_vector_functions_neon)); - } - break; - } - default: - break; - } + {"expf", runtime::kExpV8F32SymbolName, 8}, + {"llvm.exp.f32", runtime::kExpV8F32SymbolName, 8}, - vector_functions.insert(vector_functions.end(), - std::begin(ir_vector_functions), - std::end(ir_vector_functions)); + {"logf", runtime::kLogV4F32SymbolName, 4}, + {"llvm.log.f32", runtime::kLogV4F32SymbolName, 4}, - return vector_functions; + {"logf", runtime::kLogV8F32SymbolName, 8}, + {"llvm.log.f32", runtime::kLogV8F32SymbolName, 8}, + }; + return result; } -} // namespace void 
CompilerFunctor::AddTargetInfoPasses( llvm::legacy::PassManagerBase* passes) const { @@ -282,9 +209,7 @@ void CompilerFunctor::AddTargetInfoPasses( auto target_library_info_impl = MakeUnique(target_triple); target_library_info_impl->addVectorizableFunctions( - VectorFunctionsForTargetLibraryInfoImpl( - target_triple.getArch(), target_machine_->getTargetFeatureString(), - available_intrinsics_)); + VectorFunctionsForTargetLibraryInfoImpl()); passes->add( new llvm::TargetLibraryInfoWrapperPass(*target_library_info_impl)); passes->add(createTargetTransformInfoWrapperPass( diff --git a/tensorflow/compiler/xla/service/cpu/compiler_functor.h b/tensorflow/compiler/xla/service/cpu/compiler_functor.h index 8cdd049e7b773bdc455db627ff1749997d621ee4..1a8283a702223a7414c1ffcd99c1ac42c04ac068 100644 --- a/tensorflow/compiler/xla/service/cpu/compiler_functor.h +++ b/tensorflow/compiler/xla/service/cpu/compiler_functor.h @@ -31,21 +31,10 @@ namespace cpu { // Orc JIT compile layer. class CompilerFunctor { public: - // Describes the set of vector intrinsics available to the generated code. - struct VectorIntrinsics { - bool sse_intrinsics; - bool avx_intrinsics; - bool neon_intrinsics; - }; - - // Returns a VectorIntrinsics where all intrinsics are available. 
- static VectorIntrinsics AllIntrinsics(); - explicit CompilerFunctor( llvm::TargetMachine* target_machine, const Disassembler* disassembler, int opt_level, bool optimize_for_size, bool enable_fast_math, bool disable_expensive_passes, - const VectorIntrinsics& available_intrinsics, LLVMCompiler::ModuleHook pre_optimization_hook = nullptr, LLVMCompiler::ModuleHook post_optimization_hook = nullptr) : target_machine_(target_machine), @@ -54,7 +43,6 @@ class CompilerFunctor { optimize_for_size_(optimize_for_size), enable_fast_math_(enable_fast_math), disable_expensive_passes_(disable_expensive_passes), - available_intrinsics_(available_intrinsics), pre_optimization_hook_(pre_optimization_hook), post_optimization_hook_(post_optimization_hook) {} @@ -78,7 +66,6 @@ class CompilerFunctor { const bool optimize_for_size_; const bool enable_fast_math_; const bool disable_expensive_passes_; - const VectorIntrinsics available_intrinsics_; LLVMCompiler::ModuleHook pre_optimization_hook_; LLVMCompiler::ModuleHook post_optimization_hook_; }; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index addd7284c593f3dcdd86b1745f9aef7b6a1c30c6..f9cc9651846cca7bd6ab7e9e61590cec4e2400da 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -31,6 +31,7 @@ limitations under the License. #include "llvm/IR/Function.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" +#include "llvm/IR/Verifier.h" #include "llvm/Object/ObjectFile.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/TargetRegistry.h" @@ -42,7 +43,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" -#include "tensorflow/compiler/xla/service/batchnorm_rewriter.h" +#include "tensorflow/compiler/xla/service/batchnorm_expander.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_inliner.h" @@ -50,24 +51,27 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/conv_canonicalization.h" #include "tensorflow/compiler/xla/service/cpu/cpu_copy_insertion.h" #include "tensorflow/compiler/xla/service/cpu/cpu_executable.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h" #include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h" #include "tensorflow/compiler/xla/service/cpu/cpu_options.h" #include "tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h" #include "tensorflow/compiler/xla/service/cpu/disassembler.h" #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/ir_emitter.h" -#include "tensorflow/compiler/xla/service/cpu/layout_assignment.h" #include "tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h" #include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h" #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/dot_decomposer.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_constant_folding.h" 
#include "tensorflow/compiler/xla/service/hlo_cse.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_ordering.h" @@ -83,7 +87,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" +#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h" #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" +#include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -149,11 +155,6 @@ CpuCompiler::CpuCompiler() { LLVMInitializeAArch64TargetMC(); LLVMInitializeAArch64AsmPrinter(); LLVMInitializeAArch64Disassembler(); - LLVMInitializePowerPCTarget(); - LLVMInitializePowerPCTargetInfo(); - LLVMInitializePowerPCTargetMC(); - LLVMInitializePowerPCAsmPrinter(); - LLVMInitializePowerPCDisassembler(); } namespace { @@ -166,42 +167,16 @@ namespace { // first module is compiled. std::once_flag llvm_command_line_options_initialized; -void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) { - auto options = config.debug_options().xla_backend_extra_options(); - if (!options.empty()) { - std::vector fake_argv_storage; - fake_argv_storage.push_back(""); - for (const auto& it : options) { - // Skip options the XLA backend itself consumes. 
- if (!tensorflow::StringPiece(it.first).starts_with("xla_")) { - if (it.second.empty()) { - fake_argv_storage.push_back(it.first); - } else { - fake_argv_storage.push_back(it.first + "=" + it.second); - } - } - } - - VLOG(2) << "Passing argv to LLVM:"; - std::vector fake_argv; - for (const auto& s : fake_argv_storage) { - fake_argv.push_back(s.c_str()); - VLOG(2) << s; - } - llvm::cl::ParseCommandLineOptions(fake_argv.size(), &fake_argv[0]); - } -} - // This visitor records which HLO instructions should have profiling information // recorded. class CollectProfileCandidates : public DfsHloVisitorWithDefault { public: - static StatusOr> + static StatusOr> GetCandidatesForComputation( HloComputation* computation, const std::unordered_map& assigned_indices) { - std::unordered_map hlo_to_profile_idx; + std::unordered_map hlo_to_profile_idx; CollectProfileCandidates profile_candidates_for_computation( &hlo_to_profile_idx, assigned_indices); TF_RETURN_IF_ERROR( @@ -211,7 +186,7 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault { private: CollectProfileCandidates( - std::unordered_map* hlo_to_profile_idx, + std::unordered_map* hlo_to_profile_idx, const std::unordered_map& assigned_indices) : hlo_to_profile_idx_(hlo_to_profile_idx), assigned_indices_(assigned_indices) {} @@ -251,7 +226,7 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault { return Status::OK(); } - std::unordered_map* hlo_to_profile_idx_; + std::unordered_map* hlo_to_profile_idx_; const std::unordered_map& assigned_indices_; }; } // namespace @@ -259,7 +234,8 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault { Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { // Optimization pipeline. 
HloPassPipeline pipeline("CPU"); - pipeline.AddInvariantChecker(ShapeSizeBytesFunction()); + pipeline.AddInvariantChecker(); + pipeline.AddPass(); ReducePrecisionInsertion::AddPasses( &pipeline, module->config().debug_options(), @@ -272,14 +248,14 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { // TODO(b/65775800): Fix wrong output bug in Call and remove the CallInliner // pass. pipeline.AddPass(); - + pipeline.AddPass(); pipeline.AddPass(); { auto& pass = pipeline.AddPass>("simplification"); - pass.AddInvariantChecker(ShapeSizeBytesFunction()); + pass.AddInvariantChecker(); - pass.AddPass( + pass.AddPass( /*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true, @@ -288,6 +264,12 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }, /*enable_dot_strength_reduction=*/false); + + // BatchNormExpander can create zero-sized ops, so zero-sized HLO + // elimination has to come after that pass. + pipeline.AddPass(); + + pass.AddPass(); pass.AddPass(); pass.AddPass(); pass.AddPass(); @@ -318,6 +300,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { [](const Shape&, const Shape&) { return true; }, /*enable_dot_strength_reduction=*/false); pipeline.AddPass(/*is_layout_sensitive=*/true); + pipeline.AddPass(BF16, F32); // Outline ops in the entry computation into calls to subcomputations. const int max_parallelism = module->config().intra_op_parallelism_threads() > 0 @@ -435,11 +418,27 @@ Status InitializeModuleHooks( return Status::OK(); } +Status VerifyLlvmModule(const llvm::Module& llvm_module) { + XLA_SCOPED_LOGGING_TIMER("CpuCompiler - Running LLVM verifier"); + + std::string err; + llvm::raw_string_ostream err_stream(err); + + // verifyModule() returns true if the module is broken. 
+ TF_RET_CHECK(!llvm::verifyModule(llvm_module, &err_stream)) + << "Invalid LLVM IR before optimizations:\n" + << err_stream.str() + << "\nThis probably indicates a bug in the HLO -> LLVM IR lowering. " + "Rerun with --xla_dump_ir_to to get the IR. "; + return Status::OK(); +} + } // namespace StatusOr> CpuCompiler::RunHloPasses( std::unique_ptr module, - perftools::gputools::StreamExecutor* /*stream_exec*/) { + perftools::gputools::StreamExecutor* /*stream_exec*/, + DeviceMemoryAllocator* /*device_allocator*/) { VLOG(2) << "Before optimization:"; XLA_VLOG_LINES(2, module->ToString()); @@ -452,7 +451,8 @@ StatusOr> CpuCompiler::RunHloPasses( StatusOr> CpuCompiler::RunBackend( std::unique_ptr module, - perftools::gputools::StreamExecutor* stream_exec) { + perftools::gputools::StreamExecutor* stream_exec, + DeviceMemoryAllocator* /*device_allocator*/) { const string timer_message = "Compiling [" + module->name() + "] for CPU using JIT"; XLA_SCOPED_LOGGING_TIMER(timer_message); @@ -460,7 +460,7 @@ StatusOr> CpuCompiler::RunBackend( VLOG(1) << "Compiling: " << module->name(); TF_RET_CHECK(stream_exec != nullptr); std::call_once(llvm_command_line_options_initialized, - &InitializeLLVMCommandLineOptions, module->config()); + &llvm_ir::InitializeLLVMCommandLineOptions, module->config()); ModuleHook pre_optimization_ir_hook; ModuleHook post_optimization_ir_hook; @@ -483,17 +483,19 @@ StatusOr> CpuCompiler::RunBackend( llvm_module->setDataLayout(jit->data_layout()); llvm_module->setTargetTriple(jit->target_triple().getTriple()); - HloComputation* computation = module->entry_computation(); - std::unordered_map hlo_to_profile_idx; + HloComputation* entry_computation = module->entry_computation(); + std::unordered_map instruction_to_profile_idx; + std::unordered_map computation_to_profile_idx; std::unique_ptr hlo_profile_index_map; - std::unique_ptr hlo_profile_printer; + std::unique_ptr hlo_profile_printer_data; if (module->config().hlo_profiling_enabled()) { 
hlo_profile_index_map = MakeUnique(*module); TF_ASSIGN_OR_RETURN( - hlo_to_profile_idx, + instruction_to_profile_idx, CollectProfileCandidates::GetCandidatesForComputation( - computation, hlo_profile_index_map->instruction_to_profile_idx())); + entry_computation, + hlo_profile_index_map->instruction_to_profile_idx())); auto shape_size_bytes = [](const Shape& shape) { // On the cpu, opaques are pointers. @@ -504,8 +506,11 @@ StatusOr> CpuCompiler::RunBackend( }; HloCostAnalysis cost_analysis(shape_size_bytes); - hlo_profile_printer = - CreateHloProfilePrinter(*hlo_profile_index_map, cost_analysis); + TF_RETURN_IF_ERROR(entry_computation->Accept(&cost_analysis)); + hlo_profile_printer_data = + CreateHloProfilePrinterData(*hlo_profile_index_map, cost_analysis); + computation_to_profile_idx = + hlo_profile_index_map->computation_to_profile_idx(); } std::unique_ptr cpu_executable; @@ -514,8 +519,8 @@ StatusOr> CpuCompiler::RunBackend( // ownership is std::moved. const bool embed_ir_in_executable = module->config().debug_options().xla_embed_ir_in_executable(); - const string xla_dump_hlo_proto_to = - module->config().debug_options().xla_dump_hlo_proto_to(); + const string xla_dump_optimized_hlo_proto_to = + module->config().debug_options().xla_dump_optimized_hlo_proto_to(); if (options::CpuParallelBackendRequested(module->config())) { VLOG(1) << "Using parallel cpu backend"; @@ -528,17 +533,17 @@ StatusOr> CpuCompiler::RunBackend( // uses data dependencies for determining order. TF_ASSIGN_OR_RETURN( std::unique_ptr assignment, - BufferAssigner::Run(module.get(), - xla::MakeUnique(module.get()), - BufferSizeBytesFunction(), memory_alignment)); + BufferAssigner::Run( + module.get(), xla::MakeUnique(module.get()), + BufferSizeBytesFunction(), memory_alignment)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. 
XLA_VLOG_LINES(2, assignment->ToString()); - if (!xla_dump_hlo_proto_to.empty()) { + if (!xla_dump_optimized_hlo_proto_to.empty()) { HloProto proto = MakeHloProto(*module, *assignment); TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( - proto, xla_dump_hlo_proto_to, module->name())); + proto, xla_dump_optimized_hlo_proto_to, module->name())); } // If we are using the parallel CPU backend, we need to create map from @@ -546,7 +551,7 @@ StatusOr> CpuCompiler::RunBackend( std::map parallel_computations; std::unordered_map> aligned_constants; - for (auto instruction : computation->MakeInstructionPostOrder()) { + for (auto instruction : entry_computation->MakeInstructionPostOrder()) { // Parameters and constants don't get their own computation. if (instruction->opcode() == HloOpcode::kParameter) { continue; @@ -554,7 +559,7 @@ StatusOr> CpuCompiler::RunBackend( if (instruction->opcode() == HloOpcode::kConstant) { // Copy the constant out of the ProtocolBuffer so that we can give it a // higher alignment. - const void* data = instruction->literal().InternalData(); + const void* data = instruction->literal().untyped_data(); int64 size = CpuExecutable::ShapeSizeBytes(instruction->shape()); auto iter = aligned_constants.emplace( instruction, xla::MakeUnique(size)); @@ -571,22 +576,15 @@ StatusOr> CpuCompiler::RunBackend( parallel_computations.emplace(to_apply, instruction); } - // We always profile the entire computation as a whole, even if hlo - // profiling is disabled. When hlo profiling is diabled, we pass in a - // profile counter array of just one element, which corresponds to the whole - // computation. - size_t entry_computation_profile_idx = - hlo_profile_index_map ? 
hlo_profile_index_map->GetProfileIndexFor( - *module->entry_computation()) - : 0; IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), - hlo_to_profile_idx, entry_computation_profile_idx, + std::move(instruction_to_profile_idx), + std::move(computation_to_profile_idx), jit->target_machine(), jit->external_constant_pool()); std::unique_ptr> function_names( new HloInstructionMap()); for (auto embedded_computation : - computation->MakeEmbeddedComputationsList()) { + entry_computation->MakeEmbeddedComputationsList()) { if (embedded_computation->IsFusionComputation()) { continue; } @@ -600,7 +598,7 @@ StatusOr> CpuCompiler::RunBackend( llvm::Function * ir_function, ir_emitter.EmitComputation( embedded_computation, embedded_computation->name(), - /*is_entry_computation=*/computation_is_parallel, + /*is_top_level_computation=*/computation_is_parallel, /*instruction_order=*/nullptr)); // If this computation is parallel, remember it in the function name map. // This way we know what function to execute when we try to run code for @@ -616,13 +614,14 @@ StatusOr> CpuCompiler::RunBackend( if (embed_ir_in_executable) { ir_module_string = llvm_ir::DumpModuleToString(*llvm_module); } + TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module)); // JIT compile the LLVM IR module to in-memory machine code. jit->AddModule(std::move(llvm_module)); cpu_executable.reset(new ParallelCpuExecutable( std::move(jit), std::move(assignment), std::move(module), std::move(function_names), std::move(aligned_constants), - std::move(hlo_profile_printer), std::move(hlo_profile_index_map))); + std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map))); if (embed_ir_in_executable) { static_cast(*cpu_executable) @@ -642,27 +641,19 @@ StatusOr> CpuCompiler::RunBackend( // temporary buffers are required to run the computation. 
TF_ASSIGN_OR_RETURN( std::unique_ptr assignment, - BufferAssigner::Run( - module.get(), - xla::MakeUnique(module.get(), module_sequence), - BufferSizeBytesFunction(), memory_alignment)); + BufferAssigner::Run(module.get(), + xla::MakeUnique( + module.get(), module_sequence), + BufferSizeBytesFunction(), memory_alignment)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. XLA_VLOG_LINES(2, assignment->ToString()); - if (!xla_dump_hlo_proto_to.empty()) { + if (!xla_dump_optimized_hlo_proto_to.empty()) { HloProto proto = MakeHloProto(*module, *assignment); TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( - proto, xla_dump_hlo_proto_to, module->name())); + proto, xla_dump_optimized_hlo_proto_to, module->name())); } - // We always profile the entire computation as a whole, even if hlo - // profiling is disabled. When hlo profiling is diabled, we pass in a - // profile counter array of just one element, which corresponds to the whole - // computation. - size_t entry_computation_profile_idx = - hlo_profile_index_map ? hlo_profile_index_map->GetProfileIndexFor( - *module->entry_computation()) - : 0; // Each computation is a single function. Emit all embedded computations // before the entry computation. The order of computations returned from @@ -670,11 +661,12 @@ StatusOr> CpuCompiler::RunBackend( // before a caller computation. 
IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), - hlo_to_profile_idx, entry_computation_profile_idx, + std::move(instruction_to_profile_idx), + std::move(computation_to_profile_idx), jit->target_machine(), jit->external_constant_pool()); for (auto embedded_computation : - computation->MakeEmbeddedComputationsList()) { + entry_computation->MakeEmbeddedComputationsList()) { if (embedded_computation->IsFusionComputation()) { continue; } @@ -682,29 +674,33 @@ StatusOr> CpuCompiler::RunBackend( ir_emitter .EmitComputation(embedded_computation, embedded_computation->name(), - /*is_entry_computation=*/false, + /*is_top_level_computation=*/false, &module_sequence.at(embedded_computation)) .status()); } - string function_name_prefix = - computation->name().empty() ? "__compute" : computation->name(); + string function_name_prefix = entry_computation->name().empty() + ? "__compute" + : entry_computation->name(); TF_ASSIGN_OR_RETURN( llvm::Function * entry_function, - ir_emitter.EmitComputation(computation, function_name_prefix, - /*is_entry_computation=*/true, - &module_sequence.at(computation))); + ir_emitter.EmitComputation(entry_computation, function_name_prefix, + /*is_top_level_computation=*/true, + &module_sequence.at(entry_computation))); string function_name = llvm_ir::AsString(entry_function->getName()); string ir_module_string; if (embed_ir_in_executable) { ir_module_string = llvm_ir::DumpModuleToString(*llvm_module); } + TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module)); + + XLA_VLOG_LINES(2, "LLVM IR:\n" + llvm_ir::DumpModuleToString(*llvm_module)); // JIT compile the LLVM IR module to in-memory machine code. 
jit->AddModule(std::move(llvm_module)); cpu_executable.reset(new CpuExecutable( std::move(jit), std::move(assignment), std::move(module), function_name, - std::move(hlo_profile_printer), std::move(hlo_profile_index_map))); + std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map))); if (embed_ir_in_executable) { static_cast(*cpu_executable) @@ -721,7 +717,8 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, const AotCompilationOptions& aot_options) { TF_RET_CHECK(!modules.empty()); std::call_once(llvm_command_line_options_initialized, - &InitializeLLVMCommandLineOptions, modules[0]->config()); + &llvm_ir::InitializeLLVMCommandLineOptions, + modules[0]->config()); // We can pass just one llvm::TargetOptions when we compile the LLVM module, // so we bail if the configs have conflicting flags. At the moment, the only @@ -824,27 +821,28 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, TF_ASSIGN_OR_RETURN( std::unique_ptr assignment, BufferAssigner::Run( - module, xla::MakeUnique(module, module_sequence), + module, + xla::MakeUnique(module, module_sequence), BufferSizeBytesFunction(), memory_alignment)); // BufferAssignment::ToString() includes a header, so no need for us to // print one ourselves. 
XLA_VLOG_LINES(2, assignment->ToString()); - const string xla_dump_hlo_proto_to = - module->config().debug_options().xla_dump_hlo_proto_to(); - if (!xla_dump_hlo_proto_to.empty()) { + const string xla_dump_optimized_hlo_proto_to = + module->config().debug_options().xla_dump_optimized_hlo_proto_to(); + if (!xla_dump_optimized_hlo_proto_to.empty()) { HloProto proto = MakeHloProto(*module, *assignment); TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( - proto, xla_dump_hlo_proto_to, module->name())); + proto, xla_dump_optimized_hlo_proto_to, module->name())); } - IrEmitter ir_emitter( - *module, *assignment, &llvm_module, - /*hlo_to_profile_idx=*/ - std::unordered_map{}, - /*entry_computation_profile_idx=*/tensorflow::gtl::nullopt, - target_machine.get(), - /*external_constant_pool=*/nullptr); + IrEmitter ir_emitter(*module, *assignment, &llvm_module, + /*instruction_to_profile_idx=*/ + std::unordered_map{}, + /*computation_to_profile_idx=*/ + std::unordered_map{}, + target_machine.get(), + /*external_constant_pool=*/nullptr); HloComputation* computation = module->entry_computation(); for (auto embedded_computation : computation->MakeEmbeddedComputationsList()) { @@ -855,7 +853,7 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, ir_emitter .EmitComputation(embedded_computation, embedded_computation->name(), - /*is_entry_computation=*/false, + /*is_top_level_computation=*/false, &module_sequence.at(embedded_computation)) .status()); } @@ -863,7 +861,7 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, TF_ASSIGN_OR_RETURN( llvm::Function * entry_function, ir_emitter.EmitComputation(computation, entry_point_name, - /*is_entry_computation=*/true, + /*is_top_level_computation=*/true, &module_sequence.at(computation))); CHECK(entry_function->getName() == llvm_ir::AsStringRef(entry_point_name)); @@ -874,14 +872,23 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, *module, user_pre_optimization_hook_, user_post_optimization_hook_, 
&pre_optimization_ir_dump_hook, &post_optimization_ir_dump_hook)); + // Run the LLVM verifier over the unoptimized LLVM IR. If it fails, run the + // pre-optimization IR dump hook before returning. + { + Status verify_status = VerifyLlvmModule(llvm_module); + if (!verify_status.ok() && pre_optimization_ir_dump_hook) { + pre_optimization_ir_dump_hook(llvm_module).IgnoreError(); + } + TF_RETURN_IF_ERROR(verify_status); + } + Disassembler disassembler(*target_machine); CompilerFunctor compiler_functor( target_machine.get(), &disassembler, opt_level, options::OptimizeForSizeRequested(module->config()), module->config().debug_options().xla_enable_fast_math(), module->config().debug_options().xla_llvm_disable_expensive_passes(), - CompilerFunctor::AllIntrinsics(), pre_optimization_ir_dump_hook, - post_optimization_ir_dump_hook); + pre_optimization_ir_dump_hook, post_optimization_ir_dump_hook); llvm::object::OwningBinary object_file = compiler_functor(llvm_module); llvm::StringRef object_file_data_ref = object_file.getBinary()->getData(); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h index ebed7058d8f7968c6e03ef90d0da6b2325037eb0..3498139ab95d21383c6dc008ae5614b7bfe91148 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h @@ -118,11 +118,13 @@ class CpuCompiler : public LLVMCompiler { StatusOr> RunHloPasses( std::unique_ptr module, - perftools::gputools::StreamExecutor* stream_exec) override; + perftools::gputools::StreamExecutor* stream_exec, + DeviceMemoryAllocator* device_allocator) override; StatusOr> RunBackend( std::unique_ptr module, - perftools::gputools::StreamExecutor* stream_exec) override; + perftools::gputools::StreamExecutor* stream_exec, + DeviceMemoryAllocator* device_allocator) override; StatusOr>> CompileAheadOfTime(std::vector> modules, diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc 
b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index e956f478b86d9816615e2902f5bbeae6d6384162..802d0a6fb46890b31d14b1fbf3b2e7d6520caccc 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -55,9 +55,9 @@ CpuExecutable::CpuExecutable( std::unique_ptr assignment, std::unique_ptr hlo_module, const string& entry_function_name, - std::unique_ptr hlo_profile_printer, + std::unique_ptr hlo_profile_printer_data, std::unique_ptr hlo_profile_index_map) - : Executable(std::move(hlo_module), std::move(hlo_profile_printer), + : Executable(std::move(hlo_module), std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map)), jit_(std::move(jit)), assignment_(std::move(assignment)) { @@ -73,28 +73,6 @@ CpuExecutable::CpuExecutable( reinterpret_cast(cantFail(sym.getAddress())); } -// Given a pointer to an output buffer (following the CPU JIT calling -// conventions), mark addresses that are "live". The initial pointer itself is -// trivially live. If the shape of the buffer is a tuple, this analysis looks -// into the tuple's elements and marks them live as well (since tuples keep -// pointers to buffers) and also works recursively. address is an in-memory -// buffer address that contains some runtime XLA object. shape is its -// shape. marked_addresses is the set of live addresses to populate. 
-static void MarkLiveAddressesInOutput( - const void* address, const Shape& shape, - std::unordered_set* marked_addresses) { - marked_addresses->insert(address); - const uintptr_t* address_buffer = static_cast(address); - if (ShapeUtil::IsTuple(shape)) { - for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { - const uintptr_t* element_address = address_buffer + i; - const void* element = reinterpret_cast(*element_address); - MarkLiveAddressesInOutput( - element, ShapeUtil::GetTupleElementShape(shape, i), marked_addresses); - } - } -} - Status CpuExecutable::AllocateBuffers( DeviceMemoryAllocator* memory_allocator, int device_ordinal, std::vector* buffers) { @@ -148,20 +126,6 @@ Status CpuExecutable::ExecuteComputeFunction( tensorflow::gtl::ArraySlice arguments, tensorflow::gtl::ArraySlice buffers, HloExecutionProfile* hlo_execution_profile) { - std::vector argument_buffers; - argument_buffers.reserve(arguments.size()); - for (const auto* argument : arguments) { - argument_buffers.push_back(argument->buffer(/*index=*/{})); - } - return ExecuteComputeFunction(run_options, argument_buffers, buffers, - hlo_execution_profile); -} - -Status CpuExecutable::ExecuteComputeFunction( - const ExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, - tensorflow::gtl::ArraySlice buffers, - HloExecutionProfile* hlo_execution_profile) { // The calling convention for JITed functions is: // // void function(void* result, const void* run_options, void** args_array, @@ -177,23 +141,19 @@ Status CpuExecutable::ExecuteComputeFunction( // determined by buffer analysis. // std::vector args_array; - for (se::DeviceMemoryBase arg_mem : arguments) { - args_array.push_back(arg_mem.opaque()); + for (const ShapedBuffer* argument : arguments) { + args_array.push_back(argument->root_buffer().opaque()); } uint64 start_micros = tensorflow::Env::Default()->NowMicros(); - // Allocate profiling counters for each hlo instruction that we would like to - // profile. 
Even when not Hlo profiling, we allocate a counter for the entire - // computation, which we use to update ExecutionProfile below. - std::vector* profile_counters = nullptr; - std::vector profile_counter_for_entry_computation; - if (hlo_execution_profile) { - profile_counters = hlo_execution_profile->mutable_profile_counters(); - } else { - profile_counters = &profile_counter_for_entry_computation; - profile_counter_for_entry_computation.push_back(0); - } + size_t profile_counters_size = + hlo_execution_profile ? hlo_execution_profile->profile_counters().size() + : 0; + int64* profile_counters = + hlo_execution_profile + ? hlo_execution_profile->mutable_profile_counters()->data() + : nullptr; // Call the computation function following the calling convention. std::vector buffer_pointers; @@ -208,7 +168,7 @@ Status CpuExecutable::ExecuteComputeFunction( VLOG(3) << tensorflow::strings::Printf( " func(void* result, void* params[%zu], void* temps[%zu], " "uint64 profile_counters[%zu])", - args_array.size(), buffer_pointers.size(), profile_counters->size()); + args_array.size(), buffer_pointers.size(), profile_counters_size); VLOG(3) << tensorflow::strings::Printf(" result = %p", result_buffer); auto ptr_printer = [](string* out, const void* p) { tensorflow::strings::StrAppend(out, tensorflow::strings::Printf("%p", p)); @@ -220,11 +180,11 @@ Status CpuExecutable::ExecuteComputeFunction( " temps = [%s]", tensorflow::str_util::Join(buffer_pointers, ", ", ptr_printer).c_str()); VLOG(3) << tensorflow::strings::Printf(" profile_counters = %p", - profile_counters->data()); + profile_counters); } compute_function_(result_buffer, run_options, args_array.data(), - buffer_pointers.data(), profile_counters->data()); + buffer_pointers.data(), profile_counters); uint64 end_micros = tensorflow::Env::Default()->NowMicros(); @@ -232,13 +192,11 @@ Status CpuExecutable::ExecuteComputeFunction( tensorflow::mutex_lock lock(mutex_); const double nanoseconds = (end_micros - start_micros) * 
1000.0; execution_profile_.set_compute_time_ns(std::max(nanoseconds, 1.0)); - + // If hlo profiling was disabled then the cycle count is left empty. if (hlo_execution_profile) { execution_profile_.set_compute_cycle_count( hlo_execution_profile->total_cycles_executed( *module().entry_computation())); - } else { - execution_profile_.set_compute_cycle_count(profile_counters->back()); } } @@ -246,11 +204,23 @@ Status CpuExecutable::ExecuteComputeFunction( } static void LogLiveAddresses( - const std::unordered_set& marked_addresses) { + tensorflow::gtl::ArraySlice buffers, + const std::vector& buffers_in_result) { + if (!VLOG_IS_ON(3)) { + return; + } + + CHECK_EQ(buffers.size(), buffers_in_result.size()); + std::vector live_out_buffers; + for (int i = 0; i < buffers.size(); ++i) { + if (buffers_in_result[i]) { + live_out_buffers.push_back(buffers[i].opaque()); + } + } VLOG(3) << "Live addresses in output marking found " - << marked_addresses.size() << " addresses:\n" + << live_out_buffers.size() << " addresses:\n" << tensorflow::str_util::Join( - marked_addresses, ", ", [](string* out, const void* address) { + live_out_buffers, ", ", [](string* out, const void* address) { tensorflow::strings::StrAppend( out, tensorflow::strings::Printf("%p", address)); }); @@ -259,13 +229,12 @@ static void LogLiveAddresses( static Status DeallocateTempBuffers( DeviceMemoryAllocator* allocator, se::Stream* stream, tensorflow::gtl::ArraySlice buffers, - const std::unordered_set& marked_addresses) { - // Keep those marked live because they are referenced by the output of the - // computation and are needed by the service. They will be deallocated by the - // service. + const std::vector& buffers_in_result) { + // Keep those buffers in the output of the marked live because they are needed + // by the service. They will be deallocated by the service. 
for (size_t i = 0; i < buffers.size(); ++i) { se::DeviceMemoryBase alloc = buffers[i]; - if (marked_addresses.count(alloc.opaque()) == 0 && !alloc.is_null()) { + if (!buffers_in_result[i] && !alloc.is_null()) { VLOG(3) << "CpuExecutable deallocating buffer #" << i << " [" << alloc.opaque() << "]"; TF_RETURN_IF_ERROR( @@ -276,33 +245,43 @@ static Status DeallocateTempBuffers( return Status::OK(); } -StatusOr CpuExecutable::ExecuteOnStream( +StatusOr> CpuExecutable::CreateResultShapedBuffer( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, - HloExecutionProfile* hlo_execution_profile) { + tensorflow::gtl::ArraySlice + allocated_buffers, + std::vector* buffers_in_result) { se::Stream* stream = run_options->stream(); - DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - std::vector buffers(assignment_->Allocations().size()); + auto result_buffer = MakeUnique( + /*on_host_shape=*/result_shape(), /*on_device_shape=*/result_shape(), + stream->parent()->platform(), stream->parent()->device_ordinal()); - TF_RETURN_IF_ERROR(AllocateBuffers( - memory_allocator, stream->parent()->device_ordinal(), &buffers)); - TF_RETURN_IF_ERROR(ExecuteComputeFunction( - &run_options->run_options(), arguments, buffers, hlo_execution_profile)); - - // Mark the buffers that are actually live (used in the output) when the - // computation finishes executing. 
- std::unordered_set marked_addresses; - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, - assignment_->GetUniqueTopLevelOutputSlice()); - se::DeviceMemoryBase top_level_output = buffers[result_slice.index()]; - MarkLiveAddressesInOutput(top_level_output.opaque(), result_shape(), - &marked_addresses); - - LogLiveAddresses(marked_addresses); - TF_RETURN_IF_ERROR(DeallocateTempBuffers(memory_allocator, stream, buffers, - marked_addresses)); - - return top_level_output; + // Copy DeviceMemoryBase values which contain the array(s) of the result into + // the respective location in ShapedBuffer which is returned to the caller. + TF_RETURN_IF_ERROR(result_buffer->buffers().ForEachMutableElementWithStatus( + [&](const ShapeIndex& index, se::DeviceMemoryBase* device_memory) { + const auto& sources = this->GetRootPointsToSet().element(index); + // The points to set is unambiguous so the set should be a + // singleton. + CHECK_EQ(1, sources.size()); + const LogicalBuffer* buffer_source = sources[0]; + HloInstruction* src = buffer_source->instruction(); + + // The source for this result buffer can be a nested buffer such as + // a tuple element. The source instruction should have a + // non-parameter buffer assigned. 
+ TF_ASSIGN_OR_RETURN( + const BufferAllocation::Slice slice, + this->assignment_->GetUniqueSlice(src, buffer_source->index())); + CHECK(!slice.allocation()->is_entry_computation_parameter()); + + const BufferAllocation::Index buffer_index = slice.index(); + const se::DeviceMemoryBase& buffer = allocated_buffers[buffer_index]; + CHECK(!buffer.is_null() || buffer.size() == 0); + *device_memory = buffer; + (*buffers_in_result)[buffer_index] = true; + return Status::OK(); + })); + return std::move(result_buffer); } StatusOr> CpuExecutable::ExecuteOnStream( @@ -317,67 +296,26 @@ StatusOr> CpuExecutable::ExecuteOnStream( DeviceMemoryAllocator* memory_allocator = run_options->allocator(); std::vector buffers(assignment_->Allocations().size()); - auto result_buffer = - MakeUnique(result_shape(), stream->parent()->platform(), - stream->parent()->device_ordinal()); - TF_RETURN_IF_ERROR(AllocateBuffers( memory_allocator, stream->parent()->device_ordinal(), &buffers)); TF_RETURN_IF_ERROR(ExecuteComputeFunction( &run_options->run_options(), arguments, buffers, hlo_execution_profile)); - // Copy DeviceMemoryBase values which contain the array(s) of the result into - // the respective location in ShapedBuffer which is returned to the caller. std::vector buffers_in_result(assignment_->Allocations().size(), false); - TF_RETURN_IF_ERROR( - result_buffer->mutable_shape_index_to_buffer_entry() - ->ForEachMutableElementWithStatus( - [&buffers, &buffers_in_result, &result_buffer, this]( - const ShapeIndex& index, size_t* buffer_entry) { - const auto& sources = this->GetRootPointsToSet().element(index); - // The points to set is unambiguous so the set should be a - // singleton. - CHECK_EQ(1, sources.size()); - const LogicalBuffer* buffer_source = sources[0]; - HloInstruction* src = buffer_source->instruction(); - - // The source for this result buffer can be a nested buffer - // such as a tuple element. - - // The source instruction should have a non-parameter buffer - // assigned. 
- TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, - this->assignment_->GetUniqueSlice( - src, buffer_source->index())); - CHECK(!slice.allocation()->is_entry_computation_parameter()); - - const BufferAllocation::Index buffer_index = slice.index(); - const se::DeviceMemoryBase& buffer = buffers[buffer_index]; - CHECK(!buffer.is_null() || buffer.size() == 0); - *buffer_entry = result_buffer->mutable_buffers()->size(); - result_buffer->mutable_buffers()->push_back(buffer); - buffers_in_result[buffer_index] = true; - return Status::OK(); - })); + TF_ASSIGN_OR_RETURN( + std::unique_ptr result_buffer, + CreateResultShapedBuffer(run_options, buffers, &buffers_in_result)); // Free all buffers not in the result. - for (size_t i = 0; i < buffers.size(); ++i) { - se::DeviceMemoryBase alloc = buffers[i]; - if (!buffers_in_result[i] && !alloc.is_null()) { - VLOG(3) << "CpuExecutable deallocating buffer #" << i << " [" - << alloc.opaque() << "]"; - TF_RETURN_IF_ERROR(memory_allocator->Deallocate( - stream->parent()->device_ordinal(), &alloc)); - } - } + TF_RETURN_IF_ERROR(DeallocateTempBuffers(memory_allocator, stream, buffers, + buffers_in_result)); return std::move(result_buffer); } -StatusOr -CpuExecutable::ExecuteAsyncOnStream( +StatusOr> CpuExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) { + tensorflow::gtl::ArraySlice arguments) { if (hlo_profiling_enabled()) { return Unimplemented( "Asynchronous execution on stream with hlo profiling is not yet " @@ -393,29 +331,25 @@ CpuExecutable::ExecuteAsyncOnStream( TF_RETURN_IF_ERROR(AllocateBuffers( memory_allocator, stream->parent()->device_ordinal(), &buffers)); - // Mark the buffers that are actually live (used in the output) when the - // computation finishes executing. 
- std::unordered_set marked_addresses; - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, - assignment_->GetUniqueTopLevelOutputSlice()); - se::DeviceMemoryBase top_level_output = buffers[result_slice.index()]; - MarkLiveAddressesInOutput(top_level_output.opaque(), result_shape(), - &marked_addresses); + std::vector buffers_in_result(assignment_->Allocations().size(), false); + TF_ASSIGN_OR_RETURN( + std::unique_ptr result_buffer, + CreateResultShapedBuffer(run_options, buffers, &buffers_in_result)); - LogLiveAddresses(marked_addresses); + LogLiveAddresses(buffers, buffers_in_result); host_stream->EnqueueTask([this, run_options, arguments, buffers, - marked_addresses, memory_allocator, stream]() { + buffers_in_result, memory_allocator, stream]() { // Failing a CHECK here is not great, but I don't see an obvious way to // return a failed Status asynchronously. TF_CHECK_OK(ExecuteComputeFunction(&run_options->run_options(), arguments, buffers, /*hlo_execution_profile=*/nullptr)); TF_CHECK_OK(DeallocateTempBuffers(memory_allocator, stream, buffers, - marked_addresses)); + buffers_in_result)); }); - return top_level_output; + return std::move(result_buffer); } /*static*/ int64 CpuExecutable::ShapeSizeBytes(const Shape& shape) { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h index 17ee2d673ee7cde1847bf29e2399e6033cb7e30e..267b89a10b3c038dc2048f0ad5b5b343c88ef0f9 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h @@ -51,25 +51,18 @@ class CpuExecutable : public Executable { std::unique_ptr assignment, std::unique_ptr hlo_module, const string& entry_function_name, - std::unique_ptr hlo_profile_printer, + std::unique_ptr hlo_profile_printer_data, std::unique_ptr hlo_profile_index_map); ~CpuExecutable() override {} - StatusOr ExecuteOnStream( - const ServiceExecutableRunOptions* run_options, - 
tensorflow::gtl::ArraySlice - arguments, - HloExecutionProfile* hlo_execution_profile) override; - StatusOr> ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) override; - StatusOr ExecuteAsyncOnStream( + StatusOr> ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments) override; + tensorflow::gtl::ArraySlice arguments) override; // This should be called after set_ir_module_string. const string& ir_module_string() const { return ir_module_string_; } @@ -108,13 +101,6 @@ class CpuExecutable : public Executable { // Calls the generated function performing the computation with the given // arguments using the supplied buffers. - Status ExecuteComputeFunction( - const ExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments, - tensorflow::gtl::ArraySlice - buffers, - HloExecutionProfile* hlo_execution_profile); Status ExecuteComputeFunction( const ExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, @@ -122,6 +108,18 @@ class CpuExecutable : public Executable { buffers, HloExecutionProfile* hlo_execution_profile); + // Create a ShapedBuffer for holding the result of the computation. The + // addresses (DeviceMemoryBases) are set according to buffer assignment. + // 'buffers_in_result' should point to a vector of the same size as + // 'allocated_buffers'. An element in buffers_in_result is set to true if the + // corresponding buffer is live out of the computation (and thus contained in + // the returned ShapedBuffer). + StatusOr> CreateResultShapedBuffer( + const ServiceExecutableRunOptions* run_options, + tensorflow::gtl::ArraySlice + allocated_buffers, + std::vector* buffers_in_result); + // Returns the points-to set of the root instruction of the entry // computation. Uses points-to analysis from buffer assignment. 
const PointsToSet& GetRootPointsToSet() const; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc new file mode 100644 index 0000000000000000000000000000000000000000..7bd4741a04b1135d9780e0cf765b7b33378526e1 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.cc @@ -0,0 +1,48 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h" + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { + +StatusOr CpuHloSupportChecker::Run(HloModule* module) { + for (auto* computation : module->computations()) { + for (const auto& instruction : computation->instructions()) { + TF_RETURN_IF_ERROR( + ShapeUtil::ValidateShapeWithOptionalLayout(instruction->shape())); + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + instruction->shape(), + [&instruction](const Shape& subshape, const ShapeIndex&) { + if (LayoutUtil::IsSparseArray(subshape)) { + return xla::Unimplemented( + "CPU backend does not support HLO instruction %s with shape " + "containing a sparse layout: %s", + instruction->ToString().c_str(), + 
ShapeUtil::HumanStringWithLayout(instruction->shape()) + .c_str()); + } + return Status::OK(); + })); + } + } + return false; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h new file mode 100644 index 0000000000000000000000000000000000000000..2924b6365943f0a3ec998d7a77767a76cbb576ae --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h @@ -0,0 +1,42 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_HLO_SUPPORT_CHECKER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_HLO_SUPPORT_CHECKER_H_ + +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// This pass should run early in the HLO pipeline and checks for HLO constructs +// which are not supported by the CPU backend and cannot be removed via HLO +// transformations (eg, sparse layouts). +class CpuHloSupportChecker : public HloPassInterface { + public: + CpuHloSupportChecker() = default; + ~CpuHloSupportChecker() override = default; + + tensorflow::StringPiece name() const override { + return "cpu_hlo_support_checker"; + } + + // Note: always returns false (no instructions are ever modified by this + // pass). 
+ StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_HLO_SUPPORT_CHECKER_H_ diff --git a/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0f463e6de623fc6ab43d685ff2a5d6882ba7b8a2 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker_test.cc @@ -0,0 +1,72 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/cpu_hlo_support_checker.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/error_codes.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +using ::testing::HasSubstr; + +class CpuHloSupportCheckerTest : public HloTestBase { + protected: + CpuHloSupportChecker& checker() { return checker_; } + + private: + CpuHloSupportChecker checker_; +}; + +TEST_F(CpuHloSupportCheckerTest, Add) { + HloComputation::Builder builder(TestName()); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "param1")); + builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kAdd, param0, param1)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK(checker().Run(module.get()).status()); +} + +TEST_F(CpuHloSupportCheckerTest, SparseUnimplemented) { + HloComputation::Builder builder(TestName()); + const Shape sparse_shape = ShapeUtil::MakeShapeWithSparseLayout(F32, {10}, 2); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, sparse_shape, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, sparse_shape, "param1")); + builder.AddInstruction(HloInstruction::CreateBinary( + sparse_shape, HloOpcode::kAdd, param0, param1)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + Status status = checker().Run(module.get()).status(); + ASSERT_EQ(status.code(), 
tensorflow::error::UNIMPLEMENTED); + EXPECT_THAT(status.error_message(), + HasSubstr("CPU backend does not support")); + EXPECT_THAT(status.error_message(), + HasSubstr(ShapeUtil::HumanStringWithLayout(sparse_shape))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc index f87ee3cecd932faac140636a3db7cd4aa0371b85..482e04052d5a914eab0e5bff2c7a83f3b698052f 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc @@ -26,7 +26,7 @@ int64 BytesInDimension(const Shape& shape, int64 dimension) { shape.dimensions(dimension); } -bool IsFusile(const HloInstruction& hlo) { +bool CanBeLoopFused(const HloInstruction& hlo) { // These are the only ones we fuse since we rely on effective elemental IR // generation. return hlo.IsElementwise() || // @@ -42,6 +42,23 @@ bool IsFusile(const HloInstruction& hlo) { hlo.opcode() == HloOpcode::kTranspose; } +bool IsMatrixVectorDot(const HloInstruction* hlo) { + const Shape& hlo_shape = hlo->shape(); + return hlo->opcode() == HloOpcode::kDot && hlo_shape.dimensions_size() == 2 && + (hlo_shape.dimensions(0) == 1 || hlo_shape.dimensions(1) == 1); +} + +bool CanBeOutputFused(const HloInstruction* producer, + const HloInstruction* consumer) { + return consumer->opcode() == HloOpcode::kAdd && IsMatrixVectorDot(producer) && + producer->user_count() == 1; +} + +bool CanBeOutputFusedIntoSomeOperand(const HloInstruction* consumer) { + return consumer->opcode() == HloOpcode::kAdd && + (CanBeOutputFused(consumer->operand(0), consumer) || + CanBeOutputFused(consumer->operand(1), consumer)); +} } // namespace bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, @@ -52,7 +69,15 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, constexpr int kFusionThresholdBytes = 16 * 1024; - if (!IsFusile(*producer)) { + if 
(CanBeOutputFused(producer, consumer)) { + return true; + } + + if (CanBeOutputFusedIntoSomeOperand(producer)) { + return false; + } + + if (!CanBeLoopFused(*producer)) { VLOG(2) << "Producer is not fusile."; return false; } @@ -108,16 +133,13 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, } } - if (consumer->opcode() == HloOpcode::kFusion) { - // InstructionFusion::ShouldFuse above only allows kLoop and kInput fusions. - // The CPU backend does not create kInput fusions, so we only expect to see - // kLoop here. - CHECK(consumer->fusion_kind() == HloInstruction::FusionKind::kLoop); + if (consumer->opcode() == HloOpcode::kFusion && + consumer->fusion_kind() == HloInstruction::FusionKind::kLoop) { VLOG(2) << "Fusing: consumer is a fusion node."; return true; } - if (IsFusile(*consumer)) { + if (CanBeLoopFused(*consumer)) { VLOG(2) << "Fusing: consumer is elementwise or fusile."; return true; } @@ -126,5 +148,11 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, return false; } +HloInstruction::FusionKind CpuInstructionFusion::ChooseKind( + const HloInstruction* producer, const HloInstruction* consumer) { + return CanBeOutputFused(producer, consumer) + ? 
HloInstruction::FusionKind::kOutput + : HloInstruction::FusionKind::kLoop; +} } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h index 0eca4c3473e1454fe5dbd8bf855b4418cf553a94..07aff34974e0cfa6c7a129f82017b280fb1ccd59 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h @@ -30,6 +30,8 @@ class CpuInstructionFusion : public InstructionFusion { protected: bool ShouldFuse(HloInstruction* consumer, int64 operand_index) override; + HloInstruction::FusionKind ChooseKind( + const HloInstruction* producer, const HloInstruction* consumer) override; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index b9e4d006d77ae76e33ac51440349400ea4eff118..595c3f55b321f47e2312b93e0c238c7637495d77 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -31,6 +31,14 @@ namespace { using InstructionFusionTest = HloTestBase; +std::unique_ptr MakeDot(const Shape& shape, HloInstruction* lhs, + HloInstruction* rhs) { + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums); +} + TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) { HloComputation::Builder builder(TestName()); HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter( @@ -40,8 +48,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) { HloInstruction* exp0 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {1024, 256}), HloOpcode::kExp, arg0)); - HloInstruction* dot = 
builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1024, 1}), HloOpcode::kDot, exp0, arg1)); + HloInstruction* dot = builder.AddInstruction( + MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), exp0, arg1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -59,8 +67,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Basic_1) { HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {256, 1024}), HloOpcode::kExp, arg1)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1, 1024}), HloOpcode::kDot, arg0, exp1)); + HloInstruction* dot = builder.AddInstruction( + MakeDot(ShapeUtil::MakeShape(F32, {1, 1024}), arg0, exp1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -80,8 +88,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Bitcast) { ShapeUtil::MakeShape(S32, {2, 512, 2, 128}), HloOpcode::kExp, arg0)); HloInstruction* bitcast0 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {1024, 256}), HloOpcode::kBitcast, exp0)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1024, 1}), HloOpcode::kDot, bitcast0, arg1)); + HloInstruction* dot = builder.AddInstruction( + MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), bitcast0, arg1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -102,8 +110,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Reshape) { HloInstruction* reshape0 = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(S32, {1024, 256}), exp0)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1024, 1}), HloOpcode::kDot, reshape0, arg1)); + HloInstruction* dot = builder.AddInstruction( + 
MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), reshape0, arg1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -121,8 +129,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_TooLarge) { HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {256, 32 * 1024}), HloOpcode::kExp, arg1)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1, 32 * 1024}), HloOpcode::kDot, arg0, exp1)); + HloInstruction* dot = builder.AddInstruction( + MakeDot(ShapeUtil::MakeShape(F32, {1, 32 * 1024}), arg0, exp1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -140,8 +148,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_ElementReuse) { HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary( ShapeUtil::MakeShape(S32, {256, 1024}), HloOpcode::kExp, arg1)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {2, 1024}), HloOpcode::kDot, arg0, exp1)); + HloInstruction* dot = builder.AddInstruction( + MakeDot(ShapeUtil::MakeShape(F32, {2, 1024}), arg0, exp1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -162,8 +170,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_TransposeFusion) { HloInstruction* transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(S32, {256, 1024}), exp1, {1, 0})); - builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1, 1024}), HloOpcode::kDot, arg0, transpose1)); + builder.AddInstruction( + MakeDot(ShapeUtil::MakeShape(F32, {1, 1024}), arg0, transpose1)); auto module = CreateNewModule(); auto computation = module->AddEntryComputation(builder.Build()); @@ -188,7 +196,9 @@ class OpcodeFusionTest : public InstructionFusionTest { // Runs CPU instruction fusion on 
the given module, and tests that the result // contains a fused op at the root with exactly the given multiset of opcodes. void RunFusionAndCheckOpcodesWereFused( - HloModule* module, const std::multiset& expected_opcodes) { + HloModule* module, const std::multiset& expected_opcodes, + HloInstruction::FusionKind fusion_kind = + HloInstruction::FusionKind::kLoop) { auto computation = module->entry_computation(); auto did_fusion = CpuInstructionFusion().Run(module); ASSERT_TRUE(did_fusion.ok()); @@ -196,7 +206,7 @@ class OpcodeFusionTest : public InstructionFusionTest { HloInstruction* root = computation->root_instruction(); ASSERT_THAT(root, op::Fusion()); - EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kLoop); + EXPECT_EQ(root->fusion_kind(), fusion_kind); std::vector fused_opcodes(root->fused_instruction_count()); std::transform(root->fused_instructions().begin(), @@ -608,6 +618,88 @@ TEST_F(OpcodeFusionTest, ReuseViaImplicitBroadcastBinary) { Not(op::Fusion())); } +void CreateComputationForDotAddOutputFusionTest(const string& test_name, + HloModule* module, int m, int k, + int n, + bool add_extra_use_for_dot) { + HloComputation::Builder builder(test_name); + + Shape dot_lhs_shape = ShapeUtil::MakeShape(F32, {m, k}); + Shape dot_rhs_shape = ShapeUtil::MakeShape(F32, {k, n}); + Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n}); + + auto* dot_lhs = builder.AddInstruction( + HloInstruction::CreateParameter(0, dot_lhs_shape, "param0")); + auto* dot_rhs = builder.AddInstruction( + HloInstruction::CreateParameter(1, dot_rhs_shape, "param1")); + auto* addend = builder.AddInstruction( + HloInstruction::CreateParameter(2, dot_shape, "param2")); + + auto* dot = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs)); + builder.AddInstruction( + HloInstruction::CreateBinary(dot_shape, HloOpcode::kAdd, dot, addend)); + + if (add_extra_use_for_dot) { + builder.AddInstruction( + HloInstruction::CreateOutfeed(dot_shape, dot, 
"no_config")); + } + + module->AddEntryComputation(builder.Build()); +} + +TEST_F(OpcodeFusionTest, DotAddOutputFusion_1x50x19) { + auto module = CreateNewModule(); + CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/1, + /*k=*/50, /*n=*/19, + /*add_extra_use_for_dot=*/false); + + RunFusionAndCheckOpcodesWereFused( + module.get(), + {HloOpcode::kDot, HloOpcode::kAdd, HloOpcode::kParameter, + HloOpcode::kParameter, HloOpcode::kParameter}, + HloInstruction::FusionKind::kOutput); +} + +TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1) { + auto module = CreateNewModule(); + CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19, + /*k=*/50, /*n=*/1, + /*add_extra_use_for_dot=*/false); + + RunFusionAndCheckOpcodesWereFused( + module.get(), + {HloOpcode::kDot, HloOpcode::kAdd, HloOpcode::kParameter, + HloOpcode::kParameter, HloOpcode::kParameter}, + HloInstruction::FusionKind::kOutput); +} + +TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x19) { + auto module = CreateNewModule(); + CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19, + /*k=*/50, /*n=*/19, + /*add_extra_use_for_dot=*/false); + + TF_ASSERT_OK_AND_ASSIGN(bool fused_something, + CpuInstructionFusion().Run(module.get())); + EXPECT_FALSE(fused_something); + EXPECT_THAT(module->entry_computation()->root_instruction(), + Not(op::Fusion())); +} + +TEST_F(OpcodeFusionTest, DotAddOutputFusion_19x50x1_multi_use) { + auto module = CreateNewModule(); + CreateComputationForDotAddOutputFusionTest(TestName(), module.get(), /*m=*/19, + /*k=*/50, /*n=*/1, + /*add_extra_use_for_dot=*/true); + + TF_ASSERT_OK_AND_ASSIGN(bool fused_something, + CpuInstructionFusion().Run(module.get())); + EXPECT_FALSE(fused_something); + EXPECT_THAT(module->entry_computation()->root_instruction(), + Not(op::Fusion())); +} + } // namespace } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/layout_assignment.cc 
b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc similarity index 55% rename from tensorflow/compiler/xla/service/cpu/layout_assignment.cc rename to tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc index 3f2d101959db50d9f775097f01d5a2ba25a0da8c..e8117377e61a4e21b8c45b929c518a18878fcb60 100644 --- a/tensorflow/compiler/xla/service/cpu/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/cpu/layout_assignment.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h" #include @@ -25,58 +25,77 @@ limitations under the License. namespace xla { namespace cpu { -Status CpuLayoutAssignment::AddBackendConstraints( - LayoutConstraints* constraints) { - auto row_major_shape = [](const Shape& old_shape) { - Shape new_shape(old_shape); - std::vector dimension_order(new_shape.dimensions_size()); - std::iota(dimension_order.rbegin(), dimension_order.rend(), 0); - *new_shape.mutable_layout() = LayoutUtil::MakeLayout(dimension_order); - return new_shape; - }; - auto col_major_shape = [](const Shape& old_shape) { - Shape new_shape(old_shape); - std::vector dimension_order(new_shape.dimensions_size()); - std::iota(dimension_order.begin(), dimension_order.end(), 0); - *new_shape.mutable_layout() = LayoutUtil::MakeLayout(dimension_order); - return new_shape; - }; - - // We want to change the layout of constant arrays to be column major when all - // of their users are dot operations that can be made faster with the flipped - // layout. To avoid going quadriatic over the # of instructions, we cache - // this property in should_make_rhs_col_major -- it maps a constant to true if - // all of the users of said constant are dot operations that can be sped up. 
- // This cache is populated lazily as we encounter dot operations traversing - // the instruction stream. - tensorflow::gtl::FlatMap - should_make_rhs_col_major_cache; - auto should_make_rhs_col_major = [&](const HloInstruction& instruction) { - if (ProfitableToImplementDotInUntiledLlvmIr(instruction) != - DotInLlvmIrProfitable::kWithColumnMajorRhs) { +// We want to change the layout of constant arrays to be column major when all +// of their users are dot operations that can be made faster with the flipped +// layout. To avoid going quadriatic over the # of instructions, we cache this +// property in should_make_rhs_col_major -- it maps a constant to true if all of +// the users of said constant are dot operations that can be sped up. This +// cache is populated lazily as we encounter dot operations traversing the +// instruction stream. + +namespace { +using ::tensorflow::gtl::nullopt; +using ::tensorflow::gtl::optional; + +using ShouldMakeOperandColMajorCache = + tensorflow::gtl::FlatMap; +} // namespace + +static bool ShouldMakeAllUsersColMajor(const HloInstruction* instruction) { + for (auto* user : instruction->users()) { + optional operand_idx = ProfitableToMakeDotOperandColumnMajor(*user); + if (!operand_idx || user->operand(*operand_idx) != instruction || + std::count(user->operands().begin(), user->operands().end(), + instruction) != 1) { return false; } + } + return true; +} - const auto* rhs = instruction.operand(1); - if (rhs->opcode() != HloOpcode::kConstant) { - return false; - } +static optional ShouldMakeOperandColumnMajor( + ShouldMakeOperandColMajorCache* cache, const HloInstruction& instruction) { + optional operand_idx = + ProfitableToMakeDotOperandColumnMajor(instruction); + if (!operand_idx) { + return nullopt; + } - auto it = should_make_rhs_col_major_cache.find(rhs); - if (it != should_make_rhs_col_major_cache.end()) { - return it->second; - } + const HloInstruction* operand = instruction.operand(*operand_idx); + if (operand->opcode() != 
HloOpcode::kConstant) { + return nullopt; + } - bool result = std::all_of( - rhs->users().begin(), rhs->users().end(), [&](HloInstruction* user) { - return ProfitableToImplementDotInUntiledLlvmIr(*user) == - DotInLlvmIrProfitable::kWithColumnMajorRhs && - user->operand(0) != rhs; - }); + auto it = cache->find(operand); + if (it == cache->end()) { + auto insert_result = + cache->insert({operand, ShouldMakeAllUsersColMajor(operand)}); + CHECK(insert_result.second); + it = insert_result.first; + } - InsertOrDie(&should_make_rhs_col_major_cache, rhs, result); - return result; - }; + return it->second ? operand_idx : nullopt; +} + +static Shape RowMajorShape(const Shape& old_shape) { + Shape new_shape(old_shape); + std::vector dimension_order(new_shape.dimensions_size()); + std::iota(dimension_order.rbegin(), dimension_order.rend(), 0); + *new_shape.mutable_layout() = LayoutUtil::MakeLayout(dimension_order); + return new_shape; +} + +static Shape ColMajorShape(const Shape& old_shape) { + Shape new_shape(old_shape); + std::vector dimension_order(new_shape.dimensions_size()); + std::iota(dimension_order.begin(), dimension_order.end(), 0); + *new_shape.mutable_layout() = LayoutUtil::MakeLayout(dimension_order); + return new_shape; +} + +Status CpuLayoutAssignment::AddBackendConstraints( + LayoutConstraints* constraints) { + ShouldMakeOperandColMajorCache cache; const HloComputation* computation = constraints->computation(); for (auto* instruction : computation->instructions()) { @@ -91,9 +110,9 @@ Status CpuLayoutAssignment::AddBackendConstraints( // // These constraints are not hard constraints. Ideally, we should decide // which layouts to choose according to some cost model. 
- Shape output_shape(row_major_shape(convolution->shape())); - Shape input_shape(row_major_shape(lhs_instruction->shape())); - Shape filter_shape(row_major_shape(rhs_instruction->shape())); + Shape output_shape(RowMajorShape(convolution->shape())); + Shape input_shape(RowMajorShape(lhs_instruction->shape())); + Shape filter_shape(RowMajorShape(rhs_instruction->shape())); // Set layouts of the instructions' shapes. TF_RETURN_IF_ERROR( @@ -102,11 +121,11 @@ Status CpuLayoutAssignment::AddBackendConstraints( constraints->SetOperandLayout(filter_shape, convolution, 1)); TF_RETURN_IF_ERROR( constraints->SetInstructionLayout(output_shape, convolution)); - } else if (should_make_rhs_col_major(*instruction)) { - auto* dot = instruction; - const auto& rhs_shape = dot->operand(1)->shape(); - TF_RETURN_IF_ERROR( - constraints->SetOperandLayout(col_major_shape(rhs_shape), dot, 1)); + } else if (optional op_idx = + ShouldMakeOperandColumnMajor(&cache, *instruction)) { + const HloInstruction* op = instruction->operand(*op_idx); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout( + ColMajorShape(op->shape()), instruction, *op_idx)); } else if (PotentiallyImplementedAsEigenDot(*instruction)) { const HloInstruction* dot = instruction; // In order to implement `dot` with Eigen dot, the layouts of the lhs, @@ -114,17 +133,17 @@ Status CpuLayoutAssignment::AddBackendConstraints( // // These constraints are not hard constraints. Ideally, we should decide // which layouts to choose according to some cost model. - Shape output_shape(row_major_shape(dot->shape())); + Shape output_shape(RowMajorShape(dot->shape())); const HloInstruction* lhs_instruction = dot->operand(0); - Shape lhs_shape(row_major_shape(lhs_instruction->shape())); + Shape lhs_shape(RowMajorShape(lhs_instruction->shape())); TF_RETURN_IF_ERROR(constraints->SetOperandLayout(lhs_shape, dot, 0)); // dot is a kDot or a kTransposeDot fusion node. In the latter case, if // it represents X @ X, it may have just one operand. 
if (dot->operand_count() > 1) { const HloInstruction* rhs_instruction = dot->operand(1); - Shape rhs_shape(row_major_shape(rhs_instruction->shape())); + Shape rhs_shape(RowMajorShape(rhs_instruction->shape())); TF_RETURN_IF_ERROR(constraints->SetOperandLayout(rhs_shape, dot, 1)); } @@ -141,8 +160,12 @@ Status CpuLayoutAssignment::AddBackendConstraints( if (constraints->OperandBufferForwarded(instruction, operand_no)) { continue; } + // Skip operands with non-array shapes. + if (!ShapeUtil::IsArray(instruction->operand(operand_no)->shape())) { + continue; + } Shape operand_shape( - row_major_shape(instruction->operand(operand_no)->shape())); + RowMajorShape(instruction->operand(operand_no)->shape())); TF_RETURN_IF_ERROR(constraints->SetOperandLayout( operand_shape, instruction, operand_no)); } diff --git a/tensorflow/compiler/xla/service/cpu/layout_assignment.h b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h similarity index 86% rename from tensorflow/compiler/xla/service/cpu/layout_assignment.h rename to tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h index 4fd8d68dd6b4f2a8b16f6c048743a996ea76a560..c8edbb9e15a5b6f9c574f5fe9d130d149499ebd2 100644 --- a/tensorflow/compiler/xla/service/cpu/layout_assignment.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_LAYOUT_ASSIGNMENT_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_LAYOUT_ASSIGNMENT_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_LAYOUT_ASSIGNMENT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_LAYOUT_ASSIGNMENT_H_ #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" @@ -38,4 +38,4 @@ class CpuLayoutAssignment : public LayoutAssignment { } // namespace cpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_LAYOUT_ASSIGNMENT_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_LAYOUT_ASSIGNMENT_H_ diff --git a/tensorflow/compiler/xla/service/cpu/layout_assignment_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc similarity index 54% rename from tensorflow/compiler/xla/service/cpu/layout_assignment_test.cc rename to tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc index 1ea5e8c7fc4896512e62396d0a756cda44785f11..6ba030fff3bbc5f413bfb133114ceb5309b77672 100644 --- a/tensorflow/compiler/xla/service/cpu/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/cpu/layout_assignment.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h" #include #include @@ -40,6 +40,8 @@ limitations under the License. 
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" +namespace op = xla::testing::opcode_matchers; + namespace xla { namespace { @@ -61,8 +63,8 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensor) { HloInstruction::CreateParameter(0, lhs_shape, "param0")); auto dot_rhs = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape))); - auto result = builder.AddInstruction(HloInstruction::CreateBinary( - result_shape, HloOpcode::kDot, dot_lhs, dot_rhs)); + auto result = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); @@ -98,10 +100,10 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor0) { HloInstruction::CreateParameter(1, lhs_shape, "param1")); auto dot_rhs = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape))); - auto dot_a_result = builder.AddInstruction(HloInstruction::CreateBinary( - result_shape, HloOpcode::kDot, dot_a_lhs, dot_rhs)); - auto dot_b_result = builder.AddInstruction(HloInstruction::CreateBinary( - result_shape, HloOpcode::kDot, dot_b_lhs, dot_rhs)); + auto dot_a_result = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(result_shape, dot_a_lhs, dot_rhs)); + auto dot_b_result = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(result_shape, dot_b_lhs, dot_rhs)); builder.AddInstruction(HloInstruction::CreateBinary( result_shape, HloOpcode::kAdd, dot_a_result, dot_b_result)); @@ -142,10 +144,10 @@ TEST_F(CpuLayoutAssignmentTest, MultipleDotsWithSameConstantRhsTensor1) { HloInstruction::CreateParameter(1, lhs_b_shape, "param1")); auto dot_rhs = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateFromShape(rhs_shape))); - auto dot_a_result = builder.AddInstruction(HloInstruction::CreateBinary( - result_a_shape, 
HloOpcode::kDot, dot_a_lhs, dot_rhs)); - auto dot_b_result = builder.AddInstruction(HloInstruction::CreateBinary( - result_b_shape, HloOpcode::kDot, dot_b_lhs, dot_rhs)); + auto dot_a_result = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(result_a_shape, dot_a_lhs, dot_rhs)); + auto dot_b_result = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(result_b_shape, dot_b_lhs, dot_rhs)); auto tuple_result = builder.AddInstruction( HloInstruction::CreateTuple({dot_a_result, dot_b_result})); @@ -180,8 +182,8 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantLhsTensor) { HloInstruction::CreateConstant(Literal::CreateFromShape(lhs_shape))); auto dot_rhs = builder.AddInstruction( HloInstruction::CreateParameter(0, rhs_shape, "param0")); - auto dot_result = builder.AddInstruction(HloInstruction::CreateBinary( - result_shape, HloOpcode::kDot, dot_lhs, dot_rhs)); + auto dot_result = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); @@ -220,8 +222,8 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensorThroughGTE) { HloInstruction::CreateParameter(0, lhs_shape, "param0")); auto dot_rhs = builder.AddInstruction( HloInstruction::CreateGetTupleElement(rhs_shape, constant, 1)); - auto dot_result = builder.AddInstruction(HloInstruction::CreateBinary( - result_shape, HloOpcode::kDot, dot_lhs, dot_rhs)); + auto dot_result = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(result_shape, dot_lhs, dot_rhs)); auto module = CreateNewModule(); HloComputation* computation = module->AddEntryComputation(builder.Build()); @@ -241,5 +243,172 @@ TEST_F(CpuLayoutAssignmentTest, DotWithConstantRhsTensorThroughGTE) { EXPECT_NE(instruction->opcode(), HloOpcode::kCopy); } } + +struct DotOutputFusionLayoutAssignmentResult { + bool layout_assignment_changed_something; + const HloInstruction* 
dot_lhs_fusion_param; + const HloInstruction* dot_rhs_fusion_param; + const HloInstruction* addend_fusion_param; +}; + +static StatusOr RunDotOutputFusion( + HloModule* module, const string& test_name, int m, int k, int n, + const int64 dot_operand_idx_in_add) { + DotOutputFusionLayoutAssignmentResult result; + + CHECK(dot_operand_idx_in_add == 0 || dot_operand_idx_in_add == 1); + + auto builder = HloComputation::Builder(test_name); + + Shape dot_lhs_shape = ShapeUtil::MakeShape(F32, {m, k}); + Shape dot_rhs_shape = ShapeUtil::MakeShape(F32, {k, n}); + Shape dot_shape = ShapeUtil::MakeShape(F32, {m, n}); + + HloInstruction* dot_lhs = builder.AddInstruction( + HloInstruction::CreateParameter(0, dot_lhs_shape, "param0")); + HloInstruction* addend = builder.AddInstruction( + HloInstruction::CreateParameter(1, dot_shape, "param1")); + HloInstruction* dot_rhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateFromShape(dot_rhs_shape))); + HloInstruction* dot_result = builder.AddInstruction( + HloInstruction::CreateCanonicalDot(dot_shape, dot_lhs, dot_rhs)); + HloInstruction* add_result; + if (dot_operand_idx_in_add == 0) { + add_result = builder.AddInstruction(HloInstruction::CreateBinary( + dot_shape, HloOpcode::kAdd, dot_result, addend)); + } else { + add_result = builder.AddInstruction(HloInstruction::CreateBinary( + dot_shape, HloOpcode::kAdd, addend, dot_result)); + } + + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + HloInstruction* fusion_instruction = + module->entry_computation()->AddInstruction(HloInstruction::CreateFusion( + dot_shape, HloInstruction::FusionKind::kOutput, add_result)); + TF_RETURN_IF_ERROR( + computation->ReplaceInstruction(add_result, fusion_instruction)); + + HloInstruction* fused_add = + fusion_instruction->fused_instructions_computation()->root_instruction(); + HloInstruction* fused_dot = fusion_instruction->FuseInstruction(dot_result); + + TF_RETURN_IF_ERROR( + 
computation->RemoveInstructionAndUnusedOperands(dot_result)); + + ComputationLayout computation_layout(computation->ComputeProgramShape()); + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(LayoutUtil::GetWithDefaultLayout(dot_lhs_shape)); + *computation_layout.mutable_parameter_layout(1) = + ShapeLayout(LayoutUtil::GetWithDefaultLayout(dot_shape)); + *computation_layout.mutable_result_layout() = + ShapeLayout(LayoutUtil::GetWithDefaultLayout(dot_shape)); + + result.dot_lhs_fusion_param = + fusion_instruction->operand(fused_dot->operand(0)->parameter_number()); + result.dot_rhs_fusion_param = + fusion_instruction->operand(fused_dot->operand(1)->parameter_number()); + result.addend_fusion_param = fusion_instruction->operand( + fused_add->operand(1 - dot_operand_idx_in_add)->parameter_number()); + + cpu::CpuLayoutAssignment layout_assignment(&computation_layout); + TF_ASSIGN_OR_RETURN(result.layout_assignment_changed_something, + layout_assignment.Run(module)); + + return result; +} + +static void AssertCorrectLayoutForDotOutputFusion( + const HloComputation* computation, + const DotOutputFusionLayoutAssignmentResult& layout_assignment_result, + bool expect_col_major_dot_rhs) { + Layout expected_dot_rhs_layout = expect_col_major_dot_rhs + ? 
LayoutUtil::MakeLayout({0, 1}) + : LayoutUtil::MakeLayout({1, 0}); + EXPECT_TRUE(LayoutUtil::Equal( + expected_dot_rhs_layout, + layout_assignment_result.dot_rhs_fusion_param->shape().layout())); + + EXPECT_TRUE(LayoutUtil::Equal( + LayoutUtil::MakeLayout({1, 0}), + layout_assignment_result.dot_lhs_fusion_param->shape().layout())); + + EXPECT_TRUE(LayoutUtil::Equal( + LayoutUtil::MakeLayout({1, 0}), + layout_assignment_result.addend_fusion_param->shape().layout())); + EXPECT_THAT(computation->instructions(), Each(Not(op::Copy()))); +} + +TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_1x50x19_dot_idx_0) { + std::unique_ptr module = CreateNewModule(); + TF_ASSERT_OK_AND_ASSIGN( + DotOutputFusionLayoutAssignmentResult layout_assignment_result, + RunDotOutputFusion(module.get(), TestName(), /*m=*/1, /*k=*/50, /*n=*/19, + /*dot_operand_idx_in_add=*/0)); + ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something); + AssertCorrectLayoutForDotOutputFusion(module->entry_computation(), + layout_assignment_result, + /*expect_col_major_dot_rhs=*/true); +} + +TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_1x50x19_dot_idx_1) { + std::unique_ptr module = CreateNewModule(); + TF_ASSERT_OK_AND_ASSIGN( + DotOutputFusionLayoutAssignmentResult layout_assignment_result, + RunDotOutputFusion(module.get(), TestName(), /*m=*/1, /*k=*/50, /*n=*/19, + /*dot_operand_idx_in_add=*/1)); + ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something); + AssertCorrectLayoutForDotOutputFusion(module->entry_computation(), + layout_assignment_result, + /*expect_col_major_dot_rhs=*/true); +} + +TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x1_dot_idx_0) { + std::unique_ptr module = CreateNewModule(); + TF_ASSERT_OK_AND_ASSIGN( + DotOutputFusionLayoutAssignmentResult layout_assignment_result, + RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/1, + /*dot_operand_idx_in_add=*/0)); + 
ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something); + AssertCorrectLayoutForDotOutputFusion(module->entry_computation(), + layout_assignment_result, + /*expect_col_major_dot_rhs=*/false); +} + +TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x1_dot_idx_1) { + std::unique_ptr module = CreateNewModule(); + TF_ASSERT_OK_AND_ASSIGN( + DotOutputFusionLayoutAssignmentResult layout_assignment_result, + RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/1, + /*dot_operand_idx_in_add=*/1)); + ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something); + AssertCorrectLayoutForDotOutputFusion(module->entry_computation(), + layout_assignment_result, + /*expect_col_major_dot_rhs=*/false); +} + +TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x19_dot_idx_0) { + std::unique_ptr module = CreateNewModule(); + TF_ASSERT_OK_AND_ASSIGN( + DotOutputFusionLayoutAssignmentResult layout_assignment_result, + RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/19, + /*dot_operand_idx_in_add=*/0)); + ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something); + AssertCorrectLayoutForDotOutputFusion(module->entry_computation(), + layout_assignment_result, + /*expect_col_major_dot_rhs=*/false); +} + +TEST_F(CpuLayoutAssignmentTest, DotOutputFusion_19x50x19_dot_idx_1) { + std::unique_ptr module = CreateNewModule(); + TF_ASSERT_OK_AND_ASSIGN( + DotOutputFusionLayoutAssignmentResult layout_assignment_result, + RunDotOutputFusion(module.get(), TestName(), /*m=*/19, /*k=*/50, /*n=*/19, + /*dot_operand_idx_in_add=*/1)); + ASSERT_TRUE(layout_assignment_result.layout_assignment_changed_something); + AssertCorrectLayoutForDotOutputFusion(module->entry_computation(), + layout_assignment_result, + /*expect_col_major_dot_rhs=*/false); +} } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index 
7908dc173d79a4a9dcb6127ac344267e27d2b5f2..1ef45dbec39a0880ebb123ba3fcd1fd6c89eb39a 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -37,6 +37,7 @@ extern const char* const kEigenMatMulF64SymbolName = "__xla_cpu_runtime_EigenMatMulF64"; extern const char* const kEigenConvF32SymbolName = "__xla_cpu_runtime_EigenConvF32"; +extern const char* const kEigenFftSymbolName = "__xla_cpu_runtime_EigenFft"; extern const char* const kEigenSingleThreadedMatMulF32SymbolName = "__xla_cpu_runtime_EigenSingleThreadedMatMulF32"; extern const char* const kEigenSingleThreadedMatMulF64SymbolName = diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h index 2ade455b8a0a43dda8c93bbb79891439da2e4f75..3e1f08071119c938619d02777513e5b834077118 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h @@ -44,6 +44,7 @@ namespace runtime { extern const char* const kEigenMatMulF32SymbolName; extern const char* const kEigenMatMulF64SymbolName; extern const char* const kEigenConvF32SymbolName; +extern const char* const kEigenFftSymbolName; extern const char* const kEigenSingleThreadedMatMulF32SymbolName; extern const char* const kEigenSingleThreadedMatMulF64SymbolName; extern const char* const kEigenSingleThreadedConvF32SymbolName; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h deleted file mode 100644 index acfada8540d89bb098bb0b04e109441e2123e678..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h +++ /dev/null @@ -1,51 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -// This header declares functions which may be called by the generated code on -// the CPU. Calls to these functions must be resolved explicitly in the JIT in -// xla::cpu::SimpleResolver. It also defines a per-CpuExecutable context -// which is used to cache expensive state and resources utilized by the -// aforementioned functions. - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_AVX_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_AVX_H_ - -#include "tensorflow/core/platform/macros.h" - -namespace xla { -namespace cpu { -namespace runtime { - -extern const char *const kExpV8F32AVXSymbolName; -extern const char *const kLogV8F32AVXSymbolName; - -typedef float V8F32AVX __attribute__((__vector_size__(32))); -} // namespace runtime -} // namespace cpu -} // namespace xla - -extern "C" { - -// The following functions are vectorized versions of a selection of libm -// library functions. -// References to these functions are created by the LLVM vectorizer. 
-xla::cpu::runtime::V8F32AVX __xla_cpu_runtime_ExpV8F32AVX( - xla::cpu::runtime::V8F32AVX x) TF_ATTRIBUTE_WEAK; - -xla::cpu::runtime::V8F32AVX __xla_cpu_runtime_LogV8F32AVX( - xla::cpu::runtime::V8F32AVX x) TF_ATTRIBUTE_WEAK; -} - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_AVX_H_ diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.cc deleted file mode 100644 index abe792b2787ce8baf56ee62585a0ab886d922a23..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.cc +++ /dev/null @@ -1,46 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h" - -#define EIGEN_USE_THREADS - -#include "third_party/eigen3/Eigen/Core" - -#ifdef __ARM_NEON__ - -xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_ExpV4F32NEON( - xla::cpu::runtime::V4F32NEON x) { - return Eigen::internal::pexp(x); -} - -xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_LogV4F32NEON( - xla::cpu::runtime::V4F32NEON x) { - Eigen::internal::Packet4f p = x; - return Eigen::internal::plog(p); -} - -#endif // __ARM_NEON__ - -namespace xla { -namespace cpu { -namespace runtime { - -const char *const kExpV4F32NEONSymbolName = "__xla_cpu_runtime_ExpV4F32NEON"; -const char *const kLogV4F32NEONSymbolName = "__xla_cpu_runtime_LogV4F32NEON"; - -} // namespace runtime -} // namespace cpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h deleted file mode 100644 index 75cb16b273973d2bf665d378084343fd612a2941..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h +++ /dev/null @@ -1,62 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_NEON_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_NEON_H_ - -// This header declares functions which may be called by the generated code on -// the CPU. Calls to these functions must be resolved explicitly in the JIT in -// xla::cpu::SimpleResolver. - -#include "tensorflow/core/platform/macros.h" - -#ifdef __ARM_NEON__ -// For the other runtimes (AVX, SSE4.1) we define the vector type directly using -// __attribute__((__vector_size__(*))). Unfortunately, the typedef for the ARM -// NEON SIMD types is not portable, so the type has to come from -#include -#endif // __ARM_NEON__ - -namespace xla { -namespace cpu { -namespace runtime { - -extern const char *const kExpV4F32NEONSymbolName; -extern const char *const kLogV4F32NEONSymbolName; - -#ifdef __ARM_NEON__ -typedef float32x4_t V4F32NEON; -#else -// On non-ARM platforms ensure the declaration is present -struct V4F32NEON; -#endif // __ARM_NEON__ - -} // namespace runtime -} // namespace cpu -} // namespace xla - -extern "C" { - -// The following functions are vectorized versions of a selection of libm -// library functions. -// References to these functions are created by the LLVM vectorizer. 
-xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_ExpV4F32NEON( - xla::cpu::runtime::V4F32NEON x) TF_ATTRIBUTE_WEAK; - -xla::cpu::runtime::V4F32NEON __xla_cpu_runtime_LogV4F32NEON( - xla::cpu::runtime::V4F32NEON x) TF_ATTRIBUTE_WEAK; -} - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_NEON_H_ diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.cc deleted file mode 100644 index a9a45db5a424d2faecbd437542c41fbd7fdf0bb8..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.cc +++ /dev/null @@ -1,47 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h" - -#define EIGEN_USE_THREADS - -#include "third_party/eigen3/Eigen/Core" - -#ifdef __SSE4_1__ - -xla::cpu::runtime::V4F32SSE __xla_cpu_runtime_ExpV4F32SSE( - xla::cpu::runtime::V4F32SSE x) { - Eigen::internal::Packet4f p = x; - return Eigen::internal::pexp(p); -} - -xla::cpu::runtime::V4F32SSE __xla_cpu_runtime_LogV4F32SSE( - xla::cpu::runtime::V4F32SSE x) { - Eigen::internal::Packet4f p = x; - return Eigen::internal::plog(p); -} - -#endif // __SSE4_1__ - -namespace xla { -namespace cpu { -namespace runtime { - -const char *const kExpV4F32SSESymbolName = "__xla_cpu_runtime_ExpV4F32SSE"; -const char *const kLogV4F32SSESymbolName = "__xla_cpu_runtime_LogV4F32SSE"; - -} // namespace runtime -} // namespace cpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h deleted file mode 100644 index 96587d10d2b86e14ff6a7400fdf14ca0d994ddc5..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h +++ /dev/null @@ -1,52 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -// This header declares functions which may be called by the generated code on -// the CPU. Calls to these functions must be resolved explicitly in the JIT in -// xla::cpu::SimpleResolver. It also defines a per-CpuExecutable context -// which is used to cache expensive state and resources utilized by the -// aforementioned functions. - -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_SSE4_1_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_SSE4_1_H_ - -#include "tensorflow/core/platform/macros.h" - -namespace xla { -namespace cpu { -namespace runtime { - -extern const char *const kExpV4F32SSESymbolName; -extern const char *const kLogV4F32SSESymbolName; - -typedef float V4F32SSE __attribute__((__vector_size__(16))); - -} // namespace runtime -} // namespace cpu -} // namespace xla - -extern "C" { - -// The following functions are vectorized versions of a selection of libm -// library functions. -// References to these functions are created by the LLVM vectorizer. 
-xla::cpu::runtime::V4F32SSE __xla_cpu_runtime_ExpV4F32SSE( - xla::cpu::runtime::V4F32SSE x) TF_ATTRIBUTE_WEAK; - -xla::cpu::runtime::V4F32SSE __xla_cpu_runtime_LogV4F32SSE( - xla::cpu::runtime::V4F32SSE x) TF_ATTRIBUTE_WEAK; -} - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_RUNTIME_SSE4_1_H_ diff --git a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc index b53719fcc260d706eab3d7460c42af4a1b5e775f..f5e61aef534da57ce13d3ee9bbeaeaec31f53d2e 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc @@ -98,7 +98,7 @@ Status CpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, if (!ShapeUtil::IsTuple(shape)) { int64 size = GetByteSizeRequirement(shape); - return TransferBufferToInfeed(executor, size, literal.InternalData()); + return TransferBufferToInfeed(executor, size, literal.untyped_data()); } if (ShapeUtil::IsNestedTuple(shape)) { @@ -111,20 +111,20 @@ Status CpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, // enqueue the resulting destination device addresses with the // infeed manager. 
std::vector buffers; - buffers.reserve(literal.tuple_literals_size()); + buffers.reserve(ShapeUtil::TupleElementCount(shape)); auto cleanup = tensorflow::gtl::MakeCleanup([&buffers]() { for (cpu::runtime::XfeedBuffer* b : buffers) { b->Done(Cancelled("Failed to infeed buffer to device.")); } }); - for (const auto& tuple_element : literal.tuple_literals()) { - const Shape& tuple_element_shape = tuple_element.shape(); + for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + const Shape& tuple_element_shape = ShapeUtil::GetSubshape(shape, {i}); int64 tuple_element_size = GetByteSizeRequirement(tuple_element_shape); TF_ASSIGN_OR_RETURN( cpu::runtime::XfeedBuffer * buffer, TransferBufferToInfeedInternal(executor, tuple_element_size, - tuple_element.InternalData())); + literal.untyped_data({i}))); buffers.push_back(buffer); } @@ -187,14 +187,14 @@ Status CpuTransferManager::TransferLiteralFromOutfeed( literal_shape.element_type(), dimensions)); TF_ASSIGN_OR_RETURN(Shape received_shape, TransferArrayBufferFromOutfeed( - executor, literal->MutableInternalData(), size)); + executor, literal->untyped_data(), size)); TF_RET_CHECK(ShapeUtil::Compatible(received_shape, literal->shape())) << "Shape received from outfeed " << ShapeUtil::HumanString(received_shape) << " did not match the shape that was requested for outfeed: " << ShapeUtil::HumanString(literal_shape); TF_RET_CHECK(size == GetByteSizeRequirement(received_shape)); - *literal->mutable_shape() = received_shape; + *literal->mutable_shape_do_not_use() = received_shape; return Status::OK(); } @@ -217,7 +217,7 @@ Status CpuTransferManager::TransferLiteralFromOutfeed( auto empty = Literal::CreateFromDimensions( tuple_element_shape.element_type(), dimensions); int64 size = GetByteSizeRequirement(tuple_element_shape); - buffer_data.push_back({empty->MutableInternalData(), size}); + buffer_data.push_back({empty->untyped_data(), size}); elements.push_back(std::move(empty)); } @@ -233,7 +233,7 @@ Status 
CpuTransferManager::TransferLiteralFromOutfeed( GetByteSizeRequirement(received_shape)); for (int64 i = 0; i < literal_shape.tuple_shapes_size(); ++i) { - *elements[i]->mutable_shape() = received_shape.tuple_shapes(i); + *elements[i]->mutable_shape_do_not_use() = received_shape.tuple_shapes(i); } *literal = std::move(*Literal::MakeTupleOwned(std::move(elements))); TF_RET_CHECK(ShapeUtil::Equal(literal->shape(), literal_shape)); diff --git a/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h b/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h index 2994642356d55df26c31553ef28dc653503d05be..664125ecc95ca5ac10be4201b9120ddbdb9b9821 100644 --- a/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h +++ b/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CUSTOM_CALL_TARGET_REGISTRY_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CUSTOM_CALL_TARGET_REGISTRY_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CUSTOM_CALL_TARGET_REGISTRY_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CUSTOM_CALL_TARGET_REGISTRY_H_ // This file is depended on by kernels that have to build for mobile devices. 
// For this reason, we avoid relying on TensorFlow and instead only use the @@ -71,4 +71,4 @@ class RegisterCustomCallTarget { } // namespace cpu } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CUSTOM_CALL_TARGET_REGISTRY_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CUSTOM_CALL_TARGET_REGISTRY_H_ diff --git a/tensorflow/compiler/xla/service/cpu/disassembler.h b/tensorflow/compiler/xla/service/cpu/disassembler.h index b6feaa7e45cee26eb7f850081bd1fad2cb63b15c..5e302f88990ee4a3c37758881ecec4d6f71dd8e6 100644 --- a/tensorflow/compiler/xla/service/cpu/disassembler.h +++ b/tensorflow/compiler/xla/service/cpu/disassembler.h @@ -37,7 +37,7 @@ struct DisassemblerResult { DisassemblerResult(const string& text, size_t code_size_bytes) : text(text), code_size_bytes(code_size_bytes) {} - // The dissassembled text sections of the object file. + // The disassembled text sections of the object file. string text; // The total number of bytes of executable code in the object file. uint64_t code_size_bytes; @@ -53,7 +53,7 @@ class Disassembler { // Returns a DisassemblerResult for the given object file, containing the // disassembled code. // - // If we couldnt' retrieve a disassembler for this platform, an error status + // If we couldn't retrieve a disassembler for this platform, an error status // is returned. StatusOr DisassembleObjectFile( const llvm::object::ObjectFile& object_file) const; diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 4c40dae5122b0853a72d6428fc120220e3a69237..c9fc586b9a4c06eb9e1f111d8f9bd2f717990aab 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -23,10 +23,11 @@ limitations under the License. 
#include "llvm/IR/Module.h" #include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" +#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" -#include "tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/util.h" @@ -143,7 +144,8 @@ class ColumnMajorMatrixVectorProductEmitter { ColumnMajorMatrixVectorProductEmitter(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols, int64 m, int64 k, llvm::Value* lhs, - llvm::Value* rhs, llvm::Value* result, + llvm::Value* rhs, llvm::Value* addend, + llvm::Value* result, llvm::IRBuilder<>* ir_builder) : scalar_type_(scalar_type), tile_rows_(tile_rows), @@ -152,6 +154,7 @@ class ColumnMajorMatrixVectorProductEmitter { k_(k), lhs_(lhs), rhs_(rhs), + addend_(addend), result_(result), ir_builder_(ir_builder), ksl_(ir_builder_), @@ -173,7 +176,7 @@ class ColumnMajorMatrixVectorProductEmitter { } // Load a tile of values from the RHS. For the RHS a "tile" is a contiguous - // sequnce of `count` values, each one broadcasted to the vector width. + // sequence of `count` values, each one broadcasted to the vector width. 
std::vector LoadRhsTile(llvm::Value* offset, int64 count) { llvm::Value* base_pointer = vsl_.ComputeOffsetPointer(rhs_, offset); std::vector result; @@ -198,6 +201,7 @@ class ColumnMajorMatrixVectorProductEmitter { int64 k_; llvm::Value* lhs_; llvm::Value* rhs_; + llvm::Value* addend_; llvm::Value* result_; llvm::IRBuilder<>* ir_builder_; KernelSupportLibrary ksl_; @@ -242,9 +246,10 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopTiled( /*step=*/tile_rows_, [&](llvm::Value* row) { std::vector lhs_tile = lhs_tile_loader->LoadTile(/*minor_dim_offset=*/row); - llvm::Value* accumulator = is_first_column - ? vsl_.GetZeroVector() - : vsl_.LoadVector(result_, row); + llvm::Value* accumulator = + is_first_column ? (addend_ ? vsl_.LoadVector(addend_, row) + : vsl_.GetZeroVector()) + : vsl_.LoadVector(result_, row); for (int i = 0; i < columns; i++) { accumulator = vsl_.MulAdd(lhs_tile[i], rhs_tile[i], accumulator); } @@ -288,7 +293,18 @@ void ColumnMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( ir_builder_->getInt1(is_first_tiled_column)); ksl_.If( setting_result_first_time, - [&]() { vsl_.StoreScalar(product, result_, scalar_row); }, + /*true_block_generator=*/ + [&]() { + if (addend_) { + vsl_.StoreScalar( + vsl_.Add(vsl_.LoadScalar(addend_, scalar_row), + product), + result_, scalar_row); + } else { + vsl_.StoreScalar(product, result_, scalar_row); + } + }, + /*false_block_generator=*/ [&]() { vsl_.StoreScalar( vsl_.Add(vsl_.LoadScalar(result_, scalar_row), product), @@ -353,7 +369,7 @@ class RowMajorMatrixVectorProductEmitter { RowMajorMatrixVectorProductEmitter(PrimitiveType scalar_type, int64 tile_rows, int64 tile_cols, int64 m, int64 k, llvm::Value* lhs, llvm::Value* rhs, - llvm::Value* result, + llvm::Value* addend, llvm::Value* result, llvm::IRBuilder<>* ir_builder) : scalar_type_(scalar_type), tile_rows_(tile_rows), @@ -362,6 +378,7 @@ class RowMajorMatrixVectorProductEmitter { k_(k), lhs_(lhs), rhs_(rhs), + addend_(addend), result_(result), 
ir_builder_(ir_builder), ksl_(ir_builder_), @@ -394,6 +411,7 @@ class RowMajorMatrixVectorProductEmitter { int64 k_; llvm::Value* lhs_; llvm::Value* rhs_; + llvm::Value* addend_; llvm::Value* result_; llvm::IRBuilder<>* ir_builder_; KernelSupportLibrary ksl_; @@ -415,11 +433,32 @@ void RowMajorMatrixVectorProductEmitter::EmitOuterLoopBody(llvm::Value* row, EmitInnerLoopEpilogue(/*current_tile_row=*/row, /*rows=*/row_count, &scalar_accumulators); + std::vector accumulator_values; + std::transform( + vector_accumulators.begin(), vector_accumulators.end(), + std::back_inserter(accumulator_values), + [](const VectorVariable& vector_var) { return vector_var.Get(); }); + + std::vector horizontal_sums; + if (row_count == vsl_.vector_size()) { + if (addend_) { + horizontal_sums = vsl_.ComputeHorizontalSums( + std::move(accumulator_values), vsl_.LoadVector(addend_, row)); + } else { + horizontal_sums = + vsl_.ComputeHorizontalSums(std::move(accumulator_values)); + } + } else { + horizontal_sums = vsl_.ComputeHorizontalSums(std::move(accumulator_values)); + } + for (int i = 0; i < row_count; i++) { llvm::Value* result_value = - vsl_.Add(vsl_.AddReduce(vector_accumulators[i].Get()), - scalar_accumulators[i].Get()); + vsl_.Add(horizontal_sums[i], scalar_accumulators[i].Get()); llvm::Value* offset = ir_builder_->CreateAdd(ir_builder_->getInt64(i), row); + if (addend_ && row_count != vsl_.vector_size()) { + result_value = vsl_.Add(vsl_.LoadScalar(addend_, offset), result_value); + } vsl_.StoreScalar(result_value, result_, offset); } } @@ -483,49 +522,52 @@ void RowMajorMatrixVectorProductEmitter::EmitInnerLoopEpilogue( } // namespace -DotOpEmitter::DotOpEmitter(const HloInstruction& dot, bool transpose_lhs, - bool transpose_rhs, - const llvm_ir::IrArray& target_array, - const llvm_ir::IrArray& lhs_array, - const llvm_ir::IrArray& rhs_array, - llvm::Value* executable_run_options_value, - llvm::IRBuilder<>* ir_builder, - const HloModuleConfig& hlo_module_config) 
+DotOpEmitter::DotOpEmitter( + const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs, + const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, + const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, + llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder, + const HloModuleConfig& hlo_module_config, + const TargetMachineFeatures& target_machine_features) : dot_(dot), transpose_lhs_(transpose_lhs), transpose_rhs_(transpose_rhs), target_array_(target_array), lhs_array_(lhs_array), rhs_array_(rhs_array), + addend_array_(addend_array), executable_run_options_value_(executable_run_options_value), ir_builder_(ir_builder), - hlo_module_config_(hlo_module_config) {} + hlo_module_config_(hlo_module_config), + target_machine_features_(target_machine_features) {} /* static */ tensorflow::Status DotOpEmitter::EmitDotOperation( const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, - const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder, - const HloModuleConfig& hlo_module_config) { + const HloModuleConfig& hlo_module_config, + const TargetMachineFeatures& target_machine_features) { PrimitiveType type = target_array.GetShape().element_type(); TF_RET_CHECK(F32 == type || F64 == type || C64 == type); DotOpEmitter dot_emitter(dot, transpose_lhs, transpose_rhs, target_array, - lhs_array, rhs_array, executable_run_options_value, - ir_builder, hlo_module_config); + lhs_array, rhs_array, addend_array, + executable_run_options_value, ir_builder, + hlo_module_config, target_machine_features); return dot_emitter.Emit(); } bool DotOpEmitter::ShapesAreLegalForRuntimeDot() const { return true; } bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { - if (dot_.shape().dimensions_size() != 2 || - 
ProfitableToImplementDotInUntiledLlvmIr(dot_) == - DotInLlvmIrProfitable::kYes) { + if (dot_.shape().dimensions_size() != 2) { return false; } - if (!primitive_util::IsFloatingPointType(dot_.shape().element_type()) && - !primitive_util::IsIntegralType(dot_.shape().element_type())) { + PrimitiveType primitive_type = dot_.shape().element_type(); + + if (!primitive_util::IsFloatingPointType(primitive_type) && + !primitive_util::IsIntegralType(primitive_type)) { return false; } @@ -575,30 +617,76 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() { int64 tiling_factor = GetGemvTilingFactor(); CHECK_GT(tiling_factor, 0); + llvm::Value* result_op = target_array_.GetBasePointer(); + llvm::Value* lhs_op = + swap_operands ? rhs_array_.GetBasePointer() : lhs_array_.GetBasePointer(); + llvm::Value* rhs_op = + swap_operands ? lhs_array_.GetBasePointer() : rhs_array_.GetBasePointer(); + + const bool enable_fast_math = + hlo_module_config_.debug_options().xla_enable_fast_math(); + const bool optimize_for_size = + options::OptimizeForSizeRequested(hlo_module_config_); + + const int target_vector_register_element_size = + target_machine_features_.vector_register_num_elements( + *ir_builder_->GetInsertBlock()->getParent(), primitive_type); + + // We may not always know the vector register size for the target we're + // compiling against, in which case target_vector_register_element_size is 0. + // In these cases we choose a default LLVM IR register size. + const int kUnknownTargetVectorRegisterSize = 4; + const int vector_register_element_size = + target_vector_register_element_size == 0 + ? kUnknownTargetVectorRegisterSize + : target_vector_register_element_size; + if (is_column_major_matrix_vector) { VLOG(2) << "Emitting column major matrix-vector multiply with m = " << m << " and k = " << k; - ColumnMajorMatrixVectorProductEmitter emitter( - dot_.shape().element_type(), /*tile_rows=*/8, - /*tile_cols=*/tiling_factor, m, k, - swap_operands ? 
rhs_array_.GetBasePointer() - : lhs_array_.GetBasePointer(), - swap_operands ? lhs_array_.GetBasePointer() - : rhs_array_.GetBasePointer(), - target_array_.GetBasePointer(), ir_builder_); - emitter.Emit(); + int64 tile_rows = vector_register_element_size; + int64 tile_cols = tiling_factor; + + string kernel_name = tensorflow::strings::StrCat( + "col_major_gemv_", PrimitiveType_Name(primitive_type), "_", tile_rows, + "_", tile_cols, "_", m, "_", k, addend_array_ ? "_with_addend" : ""); + + KernelSupportLibrary::EmitAndCallOutlinedKernel( + /*enable_fast_math=*/enable_fast_math, + /*optimize_for_size=*/optimize_for_size, ir_builder_, kernel_name, + lhs_op, rhs_op, + addend_array_ ? addend_array_->GetBasePointer() : nullptr, result_op, + [this, tile_rows, tile_cols, m, k, primitive_type]( + llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* addend_op, + llvm::Value* result_op) { + ColumnMajorMatrixVectorProductEmitter emitter( + primitive_type, tile_rows, tile_cols, m, k, lhs_op, rhs_op, + addend_op, result_op, ir_builder_); + emitter.Emit(); + }); } else { VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m << " and k = " << k; - RowMajorMatrixVectorProductEmitter emitter( - dot_.shape().element_type(), /*tile_rows=*/tiling_factor, - /*tile_cols=*/8, m, k, - swap_operands ? rhs_array_.GetBasePointer() - : lhs_array_.GetBasePointer(), - swap_operands ? lhs_array_.GetBasePointer() - : rhs_array_.GetBasePointer(), - target_array_.GetBasePointer(), ir_builder_); - emitter.Emit(); + int64 tile_rows = tiling_factor; + int64 tile_cols = vector_register_element_size; + + string kernel_name = tensorflow::strings::StrCat( + "row_major_gemv_", PrimitiveType_Name(primitive_type), "_", tile_rows, + "_", tile_cols, "_", m, "_", k, addend_array_ ? 
"_with_addend" : ""); + + KernelSupportLibrary::EmitAndCallOutlinedKernel( + /*enable_fast_math=*/enable_fast_math, + /*optimize_for_size=*/optimize_for_size, ir_builder_, kernel_name, + lhs_op, rhs_op, + addend_array_ ? addend_array_->GetBasePointer() : nullptr, result_op, + [this, tile_rows, tile_cols, m, k, primitive_type]( + llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* addend_op, + llvm::Value* result_op) { + RowMajorMatrixVectorProductEmitter emitter( + primitive_type, tile_rows, tile_cols, m, k, lhs_op, rhs_op, + addend_op, result_op, ir_builder_); + emitter.Emit(); + }); } return true; @@ -641,6 +729,8 @@ tensorflow::Status DotOpEmitter::Emit() { return Status::OK(); } + CHECK_EQ(addend_array_, nullptr); + if (PotentiallyImplementedAsEigenDot(dot_)) { return EmitCallToRuntime(); } @@ -915,8 +1005,8 @@ DotOpEmitter::MatMultDims DotOpEmitter::GetMatMultDims() const { return {lhs_shape.dimensions(transpose_lhs_ ? 1 : 0), lhs_shape.dimensions(transpose_lhs_ ? 0 : 1), rhs_shape.dimensions(transpose_rhs_ ? 0 : 1), - lhs_shape.layout().minor_to_major(0) == 0, - rhs_shape.layout().minor_to_major(0) == 0}; + LayoutUtil::Minor(lhs_shape.layout(), 0) == 0, + LayoutUtil::Minor(rhs_shape.layout(), 0) == 0}; } llvm_ir::IrArray::Index DotOpEmitter::EmitOperandArrayLoopNest( @@ -927,8 +1017,8 @@ llvm_ir::IrArray::Index DotOpEmitter::EmitOperandArrayLoopNest( // reduction dimension. 
std::vector dimensions; const Shape& shape = operand_array.GetShape(); - for (int i = shape.layout().minor_to_major_size() - 1; i >= 0; --i) { - int64 dimension = shape.layout().minor_to_major(i); + for (int i = LayoutUtil::MinorToMajor(shape).size() - 1; i >= 0; --i) { + int64 dimension = LayoutUtil::Minor(shape.layout(), i); if (dimension != reduction_dimension) { dimensions.push_back(dimension); } @@ -977,9 +1067,7 @@ bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) { return false; } - if (ProfitableToImplementDotInUntiledLlvmIr(hlo) == - DotInLlvmIrProfitable::kYes || - ProfitableToImplementDotInTiledLlvmIr(hlo)) { + if (ProfitableToImplementDotInTiledLlvmIr(hlo)) { return false; } @@ -1010,46 +1098,42 @@ bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) { return false; } -DotInLlvmIrProfitable ProfitableToImplementDotInUntiledLlvmIr( - const HloInstruction& dot) { - if (dot.opcode() == HloOpcode::kDot && dot.shape().dimensions_size() == 2) { - const Shape& result_shape = dot.shape(); - // kReductionDimensionThresholdBytes was chosen to be 1/4 of a typical L1 - // cache line size, so that we can have the reduction dimension of both the - // LHS and RHS matrices and still have some space "left over". This needs - // to be tuned further. - const int64 kReductionDimensionThresholdBytes = 8 * 1024; - const bool single_threaded_eigen = - !dot.GetModule()->config().debug_options().xla_cpu_multi_thread_eigen(); - - // This is the point at which it is better to call into Eigen and shard the - // dot across multiple worker threads. This is a rough estimate by running - // a matmult benchmark on my local machine, and it can be tuned further. 
- const int64 kMaxSingleThreadedFlops = 16 * 1024; - - const int64 M = result_shape.dimensions(0); - const int64 N = result_shape.dimensions(1); - const int64 K = dot.operand(1)->shape().dimensions(0); - const int64 primitive_type_size = - ShapeUtil::ByteSizeOfPrimitiveType(result_shape.element_type()); - if (M == 1 && - K * primitive_type_size <= kReductionDimensionThresholdBytes && - (single_threaded_eigen || M * K * N <= kMaxSingleThreadedFlops)) { - // Heuristics: - // - // - Look for a configuration where we will likely be able to keep LHS in - // L1 and do a cache-optimal traversal of RHS. - // - // - Bail out on matrices that are large enough that Eigen can profitably - // shard the computation across multiple cores. This only applies when - // multi-threading is enabled. - return LayoutUtil::IsMonotonicWithDim0Major( - dot.operand(1)->shape().layout()) - ? DotInLlvmIrProfitable::kWithColumnMajorRhs - : DotInLlvmIrProfitable::kYes; +// For vector-matrix dot products, it is always profitable to make the Rhs +// column major. 
+tensorflow::gtl::optional ProfitableToMakeDotOperandColumnMajor( + const HloInstruction& hlo) { + if (hlo.opcode() == HloOpcode::kDot && hlo.shape().dimensions_size() == 2 && + hlo.shape().dimensions(0) == 1) { + if (hlo.dot_dimension_numbers().rhs_contracting_dimensions(0) == 0) { + return 1; + } + return {}; + } + + if (hlo.opcode() == HloOpcode::kFusion && + hlo.fusion_kind() == HloInstruction::FusionKind::kOutput) { + auto* fusion_root = + hlo.fused_instructions_computation()->root_instruction(); + if (fusion_root->opcode() != HloOpcode::kAdd) { + return {}; + } + + for (auto* fusion_root_op : fusion_root->operands()) { + if (fusion_root_op->opcode() != HloOpcode::kDot) { + continue; + } + if (auto operand_num = + ProfitableToMakeDotOperandColumnMajor(*fusion_root_op)) { + auto* operand = fusion_root_op->operand(*operand_num); + if (operand->opcode() == HloOpcode::kParameter && + operand->user_count() == 1) { + return operand->parameter_number(); + } + } } } - return DotInLlvmIrProfitable::kNo; + + return {}; } bool ProfitableToImplementDotInTiledLlvmIr(const HloInstruction& dot) { diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h index c9168ccc0f6629c2a2bfbc7d4dc9c7ebab0a5708..9d748eb81f7850f3ccdb10f076eecfdc8326c05f 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h @@ -18,6 +18,7 @@ limitations under the License. 
#include "llvm/IR/IRBuilder.h" #include "tensorflow/compiler/xla/service/cpu/cpu_options.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" @@ -32,19 +33,11 @@ namespace cpu { bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo); -enum class DotInLlvmIrProfitable { kYes, kNo, kWithColumnMajorRhs }; - -// Returns a value to indicate if (and under what conditions) will lowering -// |dot| as a untiled LLVM IR dot operation be profitable over calling into -// Eigen or emitting a tiled LLVM IR implementation. Possible return values -// are: -// -// * DotInLlvmIrProfitable::kYes - always profitable. -// * DotInLlvmIrProfitable::kNo - never profitable. -// * DotInLlvmIrProfitable::kWithColumnMajorRhs - only if we can manage to make -// the Rhs layout column major. -DotInLlvmIrProfitable ProfitableToImplementDotInUntiledLlvmIr( - const HloInstruction& dot); +// Returns the index for an operand to `hlo` that should ideally be column +// major. Returns nullopt if there is no such operand or if `hlo` is not a dot +// or a fusion containing a dot. +tensorflow::gtl::optional ProfitableToMakeDotOperandColumnMajor( + const HloInstruction& hlo); // Returns true to indicate that we can generate a tiled LLVM IR implementation // for |dot|. @@ -57,21 +50,29 @@ class DotOpEmitter { // place the result in target_array. IR is emitted at current insert point of // the builder. Upon completion of the method, the insert point is set to the // end of all instructions emitted for this operation. + // + // If `addend_array` is not nullptr then it must be an array of the same + // dimensions as the result, and the result is computed as `addend_array` + + // dot(`lhs_array`, `rhs_array`). A non-null `addend_array` is only supported + // for Matrix-vector products. 
static tensorflow::Status EmitDotOperation( const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, - const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray& rhs_array, const llvm_ir::IrArray* addend_array, llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder, - const HloModuleConfig& hlo_module_config); + const HloModuleConfig& hlo_module_config, + const TargetMachineFeatures& target_machine_features); private: DotOpEmitter(const HloInstruction& dot, bool transpose_lhs, bool transpose_rhs, const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& lhs_array, const llvm_ir::IrArray& rhs_array, + const llvm_ir::IrArray* addend_array, llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder, - const HloModuleConfig& hlo_module_config); + const HloModuleConfig& hlo_module_config, + const TargetMachineFeatures& target_machine_features); // Emits the IR to perform the dot operation. 
tensorflow::Status Emit(); @@ -140,9 +141,11 @@ class DotOpEmitter { const llvm_ir::IrArray& target_array_; const llvm_ir::IrArray& lhs_array_; const llvm_ir::IrArray& rhs_array_; + const llvm_ir::IrArray* addend_array_; llvm::Value* executable_run_options_value_; llvm::IRBuilder<>* ir_builder_; const HloModuleConfig& hlo_module_config_; + const TargetMachineFeatures& target_machine_features_; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc index ba693ec89ab7c4090f8c9d1e4d65f17a80d0ac55..99c5e16db70c6a203b4751c1ed8a106c0dc260e6 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc @@ -33,8 +33,14 @@ StatusOr CpuElementalIrEmitter::EmitFloatUnaryOp( switch (op->opcode()) { case HloOpcode::kTanh: { PrimitiveType element_type = op->shape().element_type(); + bool cast_result_to_fp16 = false; string function_name; switch (element_type) { + case F16: + cast_result_to_fp16 = true; + operand_value = ir_builder_->CreateFPCast(operand_value, + ir_builder_->getFloatTy()); + TF_FALLTHROUGH_INTENDED; case F32: function_name = "tanhf"; break; @@ -44,26 +50,61 @@ StatusOr CpuElementalIrEmitter::EmitFloatUnaryOp( default: return Unimplemented("tanh"); } - // Create function type for the function. - llvm::FunctionType* function_type = llvm::FunctionType::get( - llvm_ir::PrimitiveTypeToIrType(element_type, module_), - llvm_ir::PrimitiveTypeToIrType(element_type, module_), - /*isVarArg=*/false); - // Create function declaration for 'tanhf'. + // Create a function declaration. 
llvm::Function* function = llvm::cast(module_->getOrInsertFunction( - llvm_ir::AsStringRef(function_name), function_type)); + llvm_ir::AsStringRef(function_name), operand_value->getType(), + operand_value->getType())); function->setCallingConv(llvm::CallingConv::C); function->setDoesNotThrow(); function->setDoesNotAccessMemory(); - // Create instruction to call 'tanhf'. - return ir_builder_->CreateCall(function, operand_value); + // Create an instruction to call the function. + llvm::Value* result = ir_builder_->CreateCall(function, operand_value); + if (cast_result_to_fp16) { + result = ir_builder_->CreateFPCast(result, ir_builder_->getHalfTy()); + } + return result; } default: return ElementalIrEmitter::EmitFloatUnaryOp(op, operand_value); } } +StatusOr CpuElementalIrEmitter::EmitAtan2( + PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const { + string function_name; + bool cast_result_to_fp16 = false; + switch (prim_type) { + case F16: + cast_result_to_fp16 = true; + lhs = ir_builder_->CreateFPCast(lhs, ir_builder_->getFloatTy()); + rhs = ir_builder_->CreateFPCast(rhs, ir_builder_->getFloatTy()); + TF_FALLTHROUGH_INTENDED; + case F32: + function_name = "atan2f"; + break; + case F64: + function_name = "atan2"; + break; + default: + return Unimplemented("atan2"); + } + // Create a function declaration. + llvm::Function* function = + llvm::cast(module_->getOrInsertFunction( + llvm_ir::AsStringRef(function_name), lhs->getType(), lhs->getType(), + rhs->getType())); + function->setCallingConv(llvm::CallingConv::C); + function->setDoesNotThrow(); + function->setDoesNotAccessMemory(); + // Create an instruction to call the function. 
+ llvm::Value* result = ir_builder_->CreateCall(function, {lhs, rhs}); + if (cast_result_to_fp16) { + result = ir_builder_->CreateFPCast(result, ir_builder_->getHalfTy()); + } + return result; +} + llvm_ir::ElementGenerator CpuElementalIrEmitter::MakeElementGenerator( const HloInstruction* hlo, const HloToElementGeneratorMap& operand_to_generator) const { diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h index 7e9f27befb456c17581f556868712f92fd8fd083..4446dfd2821fb4b6e75f33694367392ecbcdd8bf 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h @@ -41,6 +41,8 @@ class CpuElementalIrEmitter : public ElementalIrEmitter { protected: StatusOr EmitFloatUnaryOp( const HloInstruction* op, llvm::Value* operand_value) const override; + StatusOr EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs, + llvm::Value* rhs) const override; IrEmitter* ir_emitter_; }; diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc b/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc index c9f8e5584965d0c73771750e26bd63c401d5b0c0..7dcc4ca7fa08b478f24065275ffa69725dc51682 100644 --- a/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc +++ b/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc @@ -33,15 +33,12 @@ void ExternalConstantPool::Insert(string name, const Literal& literal, CHECK(entries_.find(name) == entries_.end()); int64 literal_size = ShapeUtil::ByteSizeOf(literal.shape()); - void* raw_pointer; - CHECK_EQ( - posix_memalign(&raw_pointer, std::max(alignment, sizeof(void*)), - literal_size), - 0) - << "failed to allocate " << literal_size << " bytes with alignment of " - << alignment; - - std::memcpy(raw_pointer, literal.InternalData(), literal_size); + void* raw_pointer = tensorflow::port::AlignedMalloc( + literal_size, std::max(alignment, sizeof(void*))); + 
CHECK(raw_pointer != nullptr) << "failed to allocate " << literal_size + << " bytes with alignment of " << alignment; + + std::memcpy(raw_pointer, literal.untyped_data(), literal_size); entries_.emplace(std::move(name), static_cast(raw_pointer)); } diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool.h b/tensorflow/compiler/xla/service/cpu/external_constant_pool.h index ade28cbcbcfda05a9ad0adab1139bf316720e11f..8008a56df4dbf16e7b57aee8a344058bb0d5883d 100644 --- a/tensorflow/compiler/xla/service/cpu/external_constant_pool.h +++ b/tensorflow/compiler/xla/service/cpu/external_constant_pool.h @@ -13,13 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_ #include #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/platform/mem.h" namespace xla { namespace cpu { @@ -49,10 +50,10 @@ class ExternalConstantPool { const uint8* Find(const string& name); private: - // We need to `free()` pointers allocated into `entries_` since we allocate - // them with `posix_memalign`. + // We need to `AlignedFree` pointers allocated into `entries_` since we + // allocate them with `AlignedMalloc`. 
struct FreeDeleter { - void operator()(void* ptr) { free(ptr); } + void operator()(void* ptr) { tensorflow::port::AlignedFree(ptr); } }; tensorflow::gtl::FlatMap> @@ -61,4 +62,4 @@ class ExternalConstantPool { } // namespace cpu } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_ diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc index 3993779da636e519f8d8fded468c3271d27ee093..788217aab6172b4e548452b3f6ffd4197c163ce4 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc @@ -44,6 +44,9 @@ bool PotentiallyImplementedAsEigenConvolution( ShapeUtil::ElementIsComplex(kernel_shape)) { return false; } + if (window_util::HasWindowReversal(convolution.window())) { + return false; + } const ConvolutionDimensionNumbers& dnums = convolution.convolution_dimension_numbers(); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h index ac361ddfb4c8d253ffb1c99200939f6324cad2bb..34b2003916933f5ec0a15d9e219063c0a912fa40 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMISSION_UTILS_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_EMISSION_UTILS_H_ +#include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" namespace xla { @@ -23,6 +24,19 @@ namespace cpu { bool PotentiallyImplementedAsEigenConvolution( const HloInstruction& convolution); + +// Dynamic loop bounds are specified as an array of dimension index +// [start, limit) pairs of ir values (one for each partitioned outer dimension). 
+// +// EX: Let 'shape' = [8, 16, 32], with the loop bounds of the two-most major +// dimensions dynamic. Then 'dynamic_loop_bounds' will contain the +// following ir values for the two most-major dimensions: +// [dim0_index_start_ir_value, dim0_index_limit_ir_value] +// [dim1_index_start_ir_value, dim1_index_limit_ir_value] +// +// See IrFunction and ParallelLoopEmitter for details. +using DynamicLoopBounds = std::vector>; + } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 502dd2e7387d701e69e1c7ecb67fbdac26c6b5de..d9eeb1c3bdc2a8058992de0e13045a240bf56b8d 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -24,6 +24,7 @@ limitations under the License. #include #include +#include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/platform/logging.h" // IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc" #include "llvm/CodeGen/TargetRegisterInfo.h" @@ -42,6 +43,8 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/cpu/ir_function.h" +#include "tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h" #include "tensorflow/compiler/xla/service/cpu/shape_partition.h" #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" @@ -59,6 +62,7 @@ limitations under the License. 
#include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" @@ -69,6 +73,7 @@ namespace { using llvm_ir::AsStringRef; using llvm_ir::IrName; using llvm_ir::SetToFirstInsertPoint; +namespace gtl = tensorflow::gtl; } // namespace namespace cpu { @@ -76,16 +81,16 @@ namespace cpu { IrEmitter::IrEmitter( const HloModule& hlo_module, const BufferAssignment& assignment, llvm::Module* llvm_module, - std::unordered_map hlo_to_profile_idx, - tensorflow::gtl::optional entry_computation_profile_idx, + std::unordered_map instruction_to_profile_idx, + std::unordered_map computation_to_profile_idx, llvm::TargetMachine* target_machine, ExternalConstantPool* external_constant_pool) : assignment_(assignment), module_(llvm_module), arch_type_(llvm::Triple(llvm_module->getTargetTriple()).getArch()), ir_builder_(llvm_module->getContext()), - hlo_to_profile_idx_(std::move(hlo_to_profile_idx)), - entry_computation_profile_idx_(std::move(entry_computation_profile_idx)), + instruction_to_profile_idx_(std::move(instruction_to_profile_idx)), + computation_to_profile_idx_(std::move(computation_to_profile_idx)), alias_analysis_(hlo_module, assignment, &llvm_module->getContext()), hlo_module_config_(hlo_module.config()), parallel_cpu_backend_( @@ -117,138 +122,33 @@ StatusOr IrEmitter::EmitComputation( // readcyclecounter if it is unavailable. 
bool use_rdtscp = arch_type_ == llvm::Triple::ArchType::x86 || arch_type_ == llvm::Triple::ArchType::x86_64; - profiling_state_ = ProfilingState(is_top_level_computation_, use_rdtscp, - GetProfileCountersArgument()); + profiling_state_ = ProfilingState(use_rdtscp, GetProfileCountersArgument()); if (instruction_order == nullptr) { TF_RETURN_IF_ERROR(computation->Accept(this)); } else { TF_RETURN_IF_ERROR(computation->AcceptOrdered(this, *instruction_order)); } - InsertOrDie(&emitted_functions_, computation, compute_function_); - - return compute_function_; -} - -static llvm::Argument* GetArg(llvm::Function* f, int idx) { - llvm::Function::arg_iterator arg_iter = f->arg_begin(); - std::advance(arg_iter, idx); - return &*arg_iter; + llvm::Function* ir_function = compute_function_->function(); + InsertOrDie(&emitted_functions_, computation, ir_function); + // Delete 'compute_function', finalizing 'ir_function' and restoring caller + // IR insert point. + compute_function_.reset(); + return ir_function; } void IrEmitter::InitializeIrFunction(const string& function_name) { - // The function signature is: - // void function(i8* retval, i8* run_options, i8** params, i8** temps, - // i64* dynamic_loop_bounds, i64* prof_counters) - // - // retval: points to the returned value. - // params: address of an array with pointers to parameters. - // temps: address of an array with pointers to temporary buffers. - // - // Therefore, the generated function's signature (FunctionType) is statically - // determined - parameter unpacking is done in code generated into the - // function, rather than by a prologue dictated by the platform ABI. - // - // /--------------\ - // retval ----------> | return value | - // \--------------/ - // - // /-------------------------------\ - // run_options -----> | xla::ExecutableRunOptions | - // \-------------------------------/ - // - // /---------------------------------------------\ - // params --------> | param 0 | param 1 | ..... 
| param N-1 | - // | addr | addr | | addr | - // \---------------------------------------------/ - // | | | - // | | | - // V V V - // /---------\ /---------\ /-----------\ - // | param 0 | | param 1 | | param N-1 | - // \---------/ \---------/ \-----------/ - // - // /---------------------------------------------\ - // temps ---------> | temp 0 | temp 1 | ..... | temp N-1 | - // | addr | addr | | addr | - // \---------------------------------------------/ - // | | | - // | | | - // V V V - // /---------\ /---------\ /-----------\ - // | temp 0 | | temp 1 | | temp N-1 | - // \---------/ \---------/ \-----------/ - // - // /--------------------------------------------\ - // dynamic loop bounds -> | outer_dim0_start | outer_dim0_limit | .....| - // (elided for aot) \--------------------------------------------/ - // - // /---------------------------------------------\ - // prof counters -> | counter 0 | counter 1 | ..... | counter N-1 | - // (elided for aot) \---------------------------------------------/ - - // Even though the type of params and temps is void** in the host's view, in - // LLVM IR this is represented by i8*, similarly to void*. It's up to the code - // to use GEPs to unravel the indirection layers. - llvm::FunctionType* compute_function_type = llvm::FunctionType::get( - /*Result=*/llvm::Type::getVoidTy(module_->getContext()), - /*Params=*/GetComputeFunctionParams(), - /*isVarArg=*/false); - // Functions with local linkage get an inlining bonus. Because we know // a-priori that embedded functions (non-entry functions) will not have its // name resolved, give it local linkage. llvm::Function::LinkageTypes linkage = is_top_level_computation_ ? 
llvm::GlobalValue::ExternalLinkage : llvm::GlobalValue::InternalLinkage; - compute_function_ = - llvm::Function::Create(/*Ty=*/compute_function_type, - /*Linkage=*/linkage, - /*Name=*/AsStringRef(function_name), - /*Module=*/module_); - compute_function_->setCallingConv(llvm::CallingConv::C); - - // Set meaningful names for the function's arguments: useful for debugging. - llvm::Function::arg_iterator arg_iter = compute_function_->arg_begin(); - arg_iter->setName("retval"); - (++arg_iter)->setName("run_options"); - (++arg_iter)->setName("params"); - (++arg_iter)->setName("temps"); - if (num_dynamic_loop_bounds_ > 0) { - (++arg_iter)->setName("dynamic_loop_bounds"); - } - (++arg_iter)->setName("prof_counters"); - - // We know a-priori that the function arguments are guaranteed to point to - // disjoint objects. - llvm::Argument* retval = GetResultArgument(); - for (llvm::Argument& argument : compute_function_->args()) { - // However, the return buffer aliases the temporaries and thus cannot be - // marked noalias. - if (&argument == retval) { - continue; - } - compute_function_->addAttribute(argument.getArgNo() + 1, - llvm::Attribute::NoAlias); - } - - // Add the optize attribute to the function if optimizing for size. This - // controls internal behavior of some optimization passes (e.g. loop - // unrolling). - if (options::OptimizeForSizeRequested(hlo_module_config_)) { - compute_function_->addFnAttr(llvm::Attribute::OptimizeForSize); - } - - if (hlo_module_config_.debug_options().xla_enable_fast_math()) { - compute_function_->addFnAttr("unsafe-fp-math", "true"); - compute_function_->addFnAttr("no-infs-fp-math", "true"); - compute_function_->addFnAttr("no-nans-fp-math", "true"); - compute_function_->addFnAttr("no-signed-zeros-fp-math", "true"); - } - - ir_builder_.SetInsertPoint(llvm::BasicBlock::Create( - /*Context=*/module_->getContext(), - /*Name=*/"entry", - /*Parent=*/compute_function_)); + // Create and initialize new IrFunction. 
+ compute_function_.reset( + new IrFunction(function_name, linkage, + options::OptimizeForSizeRequested(hlo_module_config_), + hlo_module_config_.debug_options().xla_enable_fast_math(), + module_, &ir_builder_, num_dynamic_loop_bounds_)); } IrEmitter::~IrEmitter() {} @@ -344,11 +244,12 @@ int IrEmitter::MinimumAlignmentForBufferSize(int64 buffer_size) { // Calculate the alignment of a buffer allocated for a given primitive type. int IrEmitter::MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type) { - int64 buffer_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); - DCHECK_GE(buffer_size, 0); - DCHECK_LE(buffer_size, SIZE_MAX); - - return MinimumAlignmentForBufferSize(buffer_size); + int64 byte_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); + DCHECK_GE(byte_size, 0); + // Largest scalar is a complex64 so we don't need to worry about the + // int64->int truncation here. + DCHECK_LE(byte_size, 8); + return byte_size; } int64 IrEmitter::ByteSizeOf(const Shape& shape) const { @@ -357,6 +258,10 @@ int64 IrEmitter::ByteSizeOf(const Shape& shape) const { // Calculate the alignment of a buffer allocated for a given shape. int IrEmitter::MinimumAlignmentForShape(const Shape& shape) { + if (ShapeUtil::IsScalar(shape)) { + return MinimumAlignmentForPrimitiveType(shape.element_type()); + } + int64 buffer_size = ByteSizeOf(shape); DCHECK_GE(buffer_size, 0); DCHECK_LE(buffer_size, SIZE_MAX); @@ -574,7 +479,7 @@ Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) { Status IrEmitter::HandleSort(HloInstruction* sort) { // TODO(b/26783907): Implement sort on CPU. 
- return Unimplemented("Sort is not supported on CPU (b/26783907)."); + return Unimplemented("Sort is not implemented on CPU."); } Status IrEmitter::HandleTuple(HloInstruction* tuple) { @@ -588,7 +493,7 @@ Status IrEmitter::HandleTuple(HloInstruction* tuple) { } Status IrEmitter::HandleMap(HloInstruction* map) { - tensorflow::gtl::ArraySlice operands(map->operands()); + gtl::ArraySlice operands(map->operands()); HloComputation* function = map->to_apply(); // The called computation should have been emitted previously. llvm::Function* mapped_ir_function = FindOrDie(emitted_functions_, function); @@ -612,12 +517,12 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { HloComputation* function = reduce_window->to_apply(); TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*reduce_window, /*operands=*/{operand}, - /*supported_types=*/{F32})); + /*supported_types=*/{F32, BF16})); // TODO(b/31410564): Implement dilation for reduce-window. if (window_util::HasDilation(window)) { return Unimplemented( - "Dilation for reduce-window not implemented on CPU. See b/31410564."); + "Dilation for ReduceWindow is not implemented on CPU."); } // The called computation should have been emitted previously. @@ -720,8 +625,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { // TODO(b/31410564): Implement dilation for select-and-scatter. if (window_util::HasDilation(window)) { return Unimplemented( - "Dilation for select-and-scatter not implemented on CPU. " - "See b/31410564."); + "Dilation for SelectAndScatter is not implemented on CPU. "); } // The select and scatter computations should have been emitted previously. 
@@ -898,6 +802,24 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*dot, /*operands=*/{lhs, rhs}, /*supported_types=*/{F32, F64, C64})); + const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); + if (dnums.lhs_batch_dimensions_size() > 0 || + dnums.rhs_batch_dimensions_size() > 0) { + return Unimplemented("Dot with batch dimensions not implemented."); + } + + if (dnums.lhs_contracting_dimensions_size() != 1) { + // This is disallowed by ShapeInference today. + return Unimplemented( + "Dot with multiple contracting dimensions not implemented."); + } + + if (dnums.lhs_contracting_dimensions(0) != + std::min(lhs->shape().dimensions_size() - 1, 1) || + dnums.rhs_contracting_dimensions(0) != 0) { + return Unimplemented( + "Dot with non-standard contracting dimensions not implemented."); + } llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs)); llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs)); @@ -916,8 +838,9 @@ Status IrEmitter::HandleDot(HloInstruction* dot) { // Dot operation is complicated so we delegate to a helper class. return DotOpEmitter::EmitDotOperation( *dot, /*transpose_lhs=*/false, /*transpose_rhs=*/false, target_array, - lhs_array, rhs_array, GetExecutableRunOptionsArgument(), &ir_builder_, - hlo_module_config_); + lhs_array, rhs_array, /*addend_array=*/nullptr, + GetExecutableRunOptionsArgument(), &ir_builder_, hlo_module_config_, + target_machine_features_); } Status IrEmitter::HandleConvolution(HloInstruction* convolution) { @@ -1189,8 +1112,14 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { llvm_ir::IrArray kernel_array(GetIrArrayFor(rhs)); llvm_ir::IrArray::Index kernel_index(num_dims); for (int i = 0; i < num_spatial_dims; ++i) { - kernel_index[dnums.kernel_spatial_dimensions(i)] = kernel_spatial[i]; + kernel_index[dnums.kernel_spatial_dimensions(i)] = + window.dimensions(i).window_reversal() + ? 
ir_builder_.CreateNSWSub( + ir_builder_.getInt64(window.dimensions(i).size() - 1), + kernel_spatial[i]) + : kernel_spatial[i]; } + kernel_index[dnums.kernel_input_feature_dimension()] = input_feature; kernel_index[dnums.kernel_output_feature_dimension()] = output_feature; @@ -1207,10 +1136,66 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { }); } +Status IrEmitter::HandleFft(HloInstruction* fft) { + auto operand = fft->operand(0); + TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( + /*instruction=*/*fft, /*operands=*/{operand}, + /*supported_types=*/{F32, C64})); + TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(operand->shape().layout())); + TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(fft->shape().layout())); + VLOG(3) << "operand=" << ShapeUtil::HumanStringWithLayout(operand->shape()); + VLOG(3) << "fft=" << ShapeUtil::HumanStringWithLayout(fft->shape()); + + llvm::Value* operand_address = GetEmittedValueFor(operand); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fft)); + + const std::vector& fft_length = fft->fft_length(); + int64 input_batch = 1; + for (int i = 0; i < fft->shape().dimensions_size() - fft_length.size(); i++) { + input_batch *= fft->shape().dimensions(i); + } + + // Args have been computed, make the call. 
+ llvm::Type* int8_ptr_type = ir_builder_.getInt8Ty()->getPointerTo(); + llvm::Type* int32_type = ir_builder_.getInt32Ty(); + llvm::Type* int64_type = ir_builder_.getInt64Ty(); + llvm::FunctionType* fft_type = llvm::FunctionType::get( + ir_builder_.getVoidTy(), + {int8_ptr_type, int8_ptr_type, int8_ptr_type, int32_type, int32_type, + int64_type, int64_type, int64_type, int64_type}, + /*isVarArg=*/false); + const char* fn_name = runtime::kEigenFftSymbolName; + llvm::Function* fft_func = llvm::cast( + module_->getOrInsertFunction(fn_name, fft_type)); + fft_func->setCallingConv(llvm::CallingConv::C); + fft_func->setDoesNotThrow(); + fft_func->setOnlyAccessesInaccessibleMemOrArgMem(); + const int fft_rank = fft_length.size(); + ir_builder_.CreateCall( + fft_func, + {GetExecutableRunOptionsArgument(), + ir_builder_.CreateBitCast(GetEmittedValueFor(fft), int8_ptr_type), + ir_builder_.CreateBitCast(operand_address, int8_ptr_type), + ir_builder_.getInt32(fft->fft_type()), ir_builder_.getInt32(fft_rank), + ir_builder_.getInt64(input_batch), + ir_builder_.getInt64(fft_rank > 0 ? fft_length[0] : 0), + ir_builder_.getInt64(fft_rank > 1 ? fft_length[1] : 0), + ir_builder_.getInt64(fft_rank > 2 ? fft_length[2] : 0)}); + + return Status::OK(); +} + Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { + if (hlo_module_config_.replica_count() == 1) { + // When there is a single replica, a cross replica sum is the identity + // function, and the buffer assignment expects a copy (we could eliminate + // these at the HLO level as an optimization). + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(crs)); + return EmitMemcpy(*crs->operand(0), *crs); + } + // TODO(b/33011107): Support cross replica sum on CPU. - return Unimplemented( - "Cross replica sum not implemented on CPU. 
See b/33011107."); + return Unimplemented("CrossReplicaSum is not implemented on CPU."); } // Fills up the free variables in 'index_with_free_var' with values from @@ -1240,205 +1225,6 @@ static llvm_ir::IrArray::Index FillReducedDimensionIndex( return index_with_free_var; } -Status IrEmitter::HandleBatchNormTraining(HloInstruction* batch_norm_training) { - // The output of BatchNormTraining is a tuple of three element: - // - An N-dimensional array containing normalized values. - // - A 1 dimensional array containing the mean value for each feature. - // - A 1 dimensional array containing the variance value for each feature. - HloInstruction* operand = batch_norm_training->operands()[0]; - HloInstruction* scale = batch_norm_training->operands()[1]; - HloInstruction* offset = batch_norm_training->operands()[2]; - float epsilon = batch_norm_training->epsilon(); - int64 feature_index = batch_norm_training->feature_index(); - TF_RET_CHECK(ShapeUtil::IsTuple(batch_norm_training->shape()) && - ShapeUtil::TupleElementCount(batch_norm_training->shape()) == 3); - - const Shape& output_shape = - ShapeUtil::GetTupleElementShape(batch_norm_training->shape(), 0); - const Shape& feature_shape = - ShapeUtil::GetTupleElementShape(batch_norm_training->shape(), 1); - - // Reduce vector of the non-feature dimensions. - std::vector dimensions_to_reduce; - - for (int64 i = 0; i < operand->shape().dimensions_size(); ++i) { - if (i != feature_index) { - dimensions_to_reduce.push_back(i); - } - } - - // Get the second and third allocations in the output tuple, which should be - // used to store the result of mean and variance value calculation. 
- TF_ASSIGN_OR_RETURN( - const BufferAllocation::Slice slice_mean, - assignment_.GetUniqueSlice(batch_norm_training, /*index=*/{1})); - TF_ASSIGN_OR_RETURN( - const BufferAllocation::Slice slice_var, - assignment_.GetUniqueSlice(batch_norm_training, /*index=*/{2})); - const int feature_count = output_shape.dimensions(feature_index); - const int size_in_elements = ShapeUtil::ElementsIn(output_shape); - TF_RET_CHECK(ShapeUtil::ElementsIn(operand->shape()) == size_in_elements); - const int elements_per_feature = size_in_elements / feature_count; - - llvm::Value* mean = EmitTempBufferPointer(slice_mean, feature_shape); - llvm_ir::IrArray mean_array(mean, feature_shape); - - llvm::Value* var = EmitTempBufferPointer(slice_var, feature_shape); - llvm_ir::IrArray var_array(var, feature_shape); - - // This loop calculates mean and variance for each feature. - // - // In theory this could be swapped by multi-output fusion. We will evaluate - // this when it's ready. - // - // For variance calculation, we use a simplified formula so we can fuse the - // computation into the same loop to calculate mean: Var=E(X^2) - E(X)^2. - TF_RETURN_IF_ERROR( - llvm_ir::LoopEmitter( - [&](const llvm_ir::IrArray::Index& index) { - PrimitiveType element_type = operand->shape().element_type(); - // Used to calculate E(X). - llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(element_type, module_), - "sum_address", &ir_builder_, - MinimumAlignmentForPrimitiveType(element_type)); - - // Used to calculate E(X^2). 
- llvm::Value* sum_square_address = - llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(element_type, module_), - "sum_square_address", &ir_builder_, - MinimumAlignmentForPrimitiveType(element_type)); - - ir_builder_.CreateStore( - llvm::ConstantFP::get(ir_builder_.getFloatTy(), 0.0), - sum_address); - - ir_builder_.CreateStore( - llvm::ConstantFP::get(ir_builder_.getFloatTy(), 0.0), - sum_square_address); - - llvm_ir::ForLoopNest loops(IrName(batch_norm_training, "inner"), - &ir_builder_); - - const llvm_ir::IrArray::Index reduced_dims_index = - loops.AddLoopsForShapeOnDimensions( - operand->shape(), dimensions_to_reduce, "reduction_dim"); - - SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), - &ir_builder_); - - llvm_ir::IrArray operand_array(GetIrArrayFor(operand)); - llvm_ir::IrArray::Index input_index = - FillReducedDimensionIndex(reduced_dims_index, index); - llvm::Value* new_value = - operand_array.EmitReadArrayElement(input_index, &ir_builder_); - - llvm::Value* new_value_square = - ir_builder_.CreateFMul(new_value, new_value); - - llvm::Value* current_sum = ir_builder_.CreateLoad(sum_address); - llvm::Value* current_sum_square = - ir_builder_.CreateLoad(sum_square_address); - // Update sum. - ir_builder_.CreateStore( - ir_builder_.CreateFAdd(current_sum, new_value), sum_address); - - // Update sum square. - ir_builder_.CreateStore( - ir_builder_.CreateFAdd(current_sum_square, new_value_square), - sum_square_address); - - SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), - &ir_builder_); - - llvm::Value* sum = ir_builder_.CreateLoad(sum_address); - llvm::Value* elements_per_feature_value = llvm::ConstantFP::get( - ir_builder_.getFloatTy(), elements_per_feature); - llvm::Value* mean = - ir_builder_.CreateFDiv(sum, elements_per_feature_value); - llvm::Value* mean_square = ir_builder_.CreateFMul(mean, mean); - llvm::Value* sum_square = - ir_builder_.CreateLoad(sum_square_address); - - // Var=E(X^2) - E(X)^2. 
- llvm::Value* var = ir_builder_.CreateFSub( - ir_builder_.CreateFDiv(sum_square, elements_per_feature_value), - mean_square); - - var_array.EmitWriteArrayElement(index, var, &ir_builder_); - return mean; - }, - mean_array, &ir_builder_) - .EmitLoop(IrName(batch_norm_training, "mean_var"))); - - TF_RETURN_IF_ERROR(EmitTargetAddressForOp(batch_norm_training)); - TF_ASSIGN_OR_RETURN( - const BufferAllocation::Slice slice, - assignment_.GetUniqueSlice(batch_norm_training, /*index=*/{0})); - - llvm::Value* normalized = EmitTempBufferPointer(slice, output_shape); - - llvm_ir::IrArray target_array(normalized, output_shape); - - AddAliasingInformationToIrArray(*batch_norm_training, &target_array); - - TF_RETURN_IF_ERROR( - llvm_ir::LoopEmitter( - [this, mean_array, var_array, epsilon, operand, dimensions_to_reduce, - feature_index, offset, scale](const llvm_ir::IrArray::Index& index) { - // The following logic normalizes the input value, scales and shifts - // it: - // - // normalized = (input - mean) / sqrt(variance + epsilon) - // result = normalized * scale + offset - - // Current index in the feature dimension. 
- llvm_ir::IrArray::Index feature_index_value(1, - index[feature_index]); - - llvm::Value* mean = mean_array.EmitReadArrayElement( - feature_index_value, &ir_builder_); - llvm::Value* var = var_array.EmitReadArrayElement( - feature_index_value, &ir_builder_); - - llvm_ir::IrArray operand_array(GetIrArrayFor(operand)); - llvm::Value* input = - operand_array.EmitReadArrayElement(index, &ir_builder_); - - llvm::Value* variance_with_epsilon = ir_builder_.CreateFAdd( - var, llvm::ConstantFP::get(ir_builder_.getFloatTy(), epsilon)); - llvm::Function* func_llvm_sqrt = llvm::Intrinsic::getDeclaration( - module_, llvm::Intrinsic::sqrt, {ir_builder_.getFloatTy()}); - llvm::Value* variance_sqrt = - ir_builder_.CreateCall(func_llvm_sqrt, {variance_with_epsilon}); - llvm::Value* normalized = ir_builder_.CreateFDiv( - ir_builder_.CreateFSub(input, mean), variance_sqrt); - llvm_ir::IrArray offset_array(GetIrArrayFor(offset)); - llvm::Value* offset = offset_array.EmitReadArrayElement( - feature_index_value, &ir_builder_); - llvm_ir::IrArray scale_array(GetIrArrayFor(scale)); - llvm::Value* scale = scale_array.EmitReadArrayElement( - feature_index_value, &ir_builder_); - llvm::Value* result = ir_builder_.CreateFAdd( - ir_builder_.CreateFMul(normalized, scale), offset); - - return result; - }, - target_array, &ir_builder_) - .EmitLoop(IrName(batch_norm_training, "normalize"))); - - llvm_ir::EmitTuple(GetIrArrayFor(batch_norm_training), - {normalized, mean, var}, &ir_builder_, module_); - return Status::OK(); -} - -Status IrEmitter::HandleBatchNormGrad(HloInstruction* batch_norm_grad) { - // TODO(b/62843645) Implement BatchNormGrad on CPU backend. - return Unimplemented( - "BatchNormGrad is not implemented on CPU. 
See b/62843645."); -} - Status IrEmitter::HandleParameter(HloInstruction* parameter) { VLOG(2) << "HandleParameter: " << parameter->ToString(); auto param_number = parameter->parameter_number(); @@ -1452,15 +1238,20 @@ Status IrEmitter::HandleParameter(HloInstruction* parameter) { // // Where Param is the actual element type of the underlying buffer (for // example, float for an XLA F32 element type). - llvm::Argument* params = GetArg(compute_function_, 2); + llvm::Value* params = compute_function_->parameters_arg(); llvm::Value* param_address_offset = llvm_ir::EmitBufferIndexingGEP(params, param_number, &ir_builder_); llvm::LoadInst* param_address_untyped = ir_builder_.CreateLoad(param_address_offset); param_address_untyped->setName(AsStringRef(IrName(parameter, "untyped"))); - if (hlo_module_config_.debug_options() + if (is_top_level_computation_ && + hlo_module_config_.debug_options() .xla_llvm_enable_invariant_load_metadata()) { - // We never reassign parameters, so this load is invariant. + // In the entry computation the parameter slots in the %params argument are + // invariant through program execution. In computations that are called + // from the entry computation (via kWhile, kCall and kConditional) the + // parameter slots are *not* invariant since they're written to by their + // callers. param_address_untyped->setMetadata( llvm::LLVMContext::MD_invariant_load, llvm::MDNode::get(param_address_untyped->getContext(), /*MDs=*/{})); @@ -1479,6 +1270,52 @@ Status IrEmitter::HandleParameter(HloInstruction* parameter) { return Status::OK(); } +// Returns true if the relative order of the unreduced dimensions stays the same +// through the reduce operation. +static bool ReductionPreservesLayout(const HloInstruction& reduce) { + DCHECK_EQ(reduce.opcode(), HloOpcode::kReduce); + + // Maps dimensions that were not reduced from their dimension numbers in the + // source shape to their dimensions numbers in the destination shape. 
+ // + // So if we reduce f32[A,B,C,D] on dimensions 1 and 2, this map contains + // [0->0, 3->1]. + gtl::FlatMap unreduced_dim_map; + + gtl::FlatSet reduced_dims(reduce.dimensions().begin(), + reduce.dimensions().end()); + + const Shape& operand_shape = reduce.operand(0)->shape(); + const Shape& result_shape = reduce.shape(); + + int64 delta = 0; + for (int64 i = 0; i < operand_shape.dimensions_size(); i++) { + if (reduced_dims.count(i)) { + delta++; + } else { + InsertOrDie(&unreduced_dim_map, i, i - delta); + } + } + + // Iterate dimensions minor to major and check that the corresponding + // dimensions in the source and target shapes are equivalent. + int64 result_dim_idx = 0; + for (int64 operand_dim_idx = 0; + operand_dim_idx < operand_shape.dimensions_size(); operand_dim_idx++) { + int64 operand_dim = operand_shape.layout().minor_to_major(operand_dim_idx); + if (!reduced_dims.count(operand_dim)) { + if (FindOrDie(unreduced_dim_map, operand_dim) != + result_shape.layout().minor_to_major(result_dim_idx++)) { + return false; + } + } + } + + CHECK_EQ(result_dim_idx, result_shape.dimensions_size()); + + return true; +} + IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator( HloComputation* function, string* failure_reason) const { CHECK_EQ(function->num_parameters(), 2); @@ -1495,7 +1332,7 @@ IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator( if (ShapeUtil::ElementIsComplex(root_shape)) { // TODO(b/65408531): Complex add could by done via bitcast to // Complex multiply would be more challenging. We could perhaps use a - // strided load to get all reals in a vector, all imags in a vector, or use + // strided load to get all reals in a vector, all images in a vector, or use // CreateShuffleVector on a bitcast to float x [2N]. 
*failure_reason = "complex values not supported"; return nullptr; @@ -1587,13 +1424,9 @@ IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator( IrEmitter::ShardedVectorType IrEmitter::CreateShardedVectorType( PrimitiveType element_type, unsigned element_count) { - // Here we assume that the largest register is a vector register. - int max_vector_register_size_in_bytes = - target_machine_features_.largest_register_size_in_bytes( - compute_function_); - int vector_register_size_in_elements = - max_vector_register_size_in_bytes / + target_machine_features_.vector_register_byte_size( + *compute_function_->function()) / ShapeUtil::ByteSizeOfPrimitiveType(element_type); ShardedVectorType sharded_vector_type; @@ -1646,7 +1479,7 @@ IrEmitter::EmitInnerLoopForVectorizedReduction( const ReductionGenerator& reduction_generator, const llvm_ir::IrArray::Index& output_index, const ShardedVectorType& accumulator_type, HloInstruction* init_value, - HloInstruction* arg, tensorflow::gtl::ArraySlice dimensions, + HloInstruction* arg, gtl::ArraySlice dimensions, unsigned element_alignment) { ShardedVector accumulator; accumulator.reserve(accumulator_type.size()); @@ -1748,23 +1581,14 @@ void IrEmitter::EmitShardedVectorStore( } } -namespace { -// TODO(sanjoy): This is duplicated in tensorflow/core/lib/core/arena.cc. 
-// Extract out a common implementation to tensorflow/core/lib/math/math_util.h -uint32 GCD(uint32 x, uint32 y) { - while (y != 0) { - uint32 r = x % y; - x = y; - y = r; - } - return x; -} -} // namespace - StatusOr IrEmitter::EmitVectorizedReduce( HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value, - tensorflow::gtl::ArraySlice dimensions, HloComputation* function, + gtl::ArraySlice dimensions, HloComputation* function, string* failure_reason) { + if (!ReductionPreservesLayout(*reduce)) { + return false; + } + ReductionGenerator reduction_generator = MatchReductionGenerator(function, failure_reason); if (!reduction_generator) { @@ -1781,11 +1605,12 @@ StatusOr IrEmitter::EmitVectorizedReduce( bool is_reduction_over_minor_dimension = std::find(dimensions.begin(), dimensions.end(), - arg->shape().layout().minor_to_major(0)) != dimensions.end(); + LayoutUtil::Minor(arg->shape().layout(), 0)) != + dimensions.end(); - unsigned element_alignment = - GCD(ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()), - MinimumAlignmentForPrimitiveType(reduce->shape().element_type())); + unsigned element_alignment = tensorflow::MathUtil::GCD( + ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()), + MinimumAlignmentForPrimitiveType(reduce->shape().element_type())); if (is_reduction_over_minor_dimension) { // TODO(sanjoy): Implement vectorized reduction over the minor dimension. 
@@ -1818,8 +1643,9 @@ StatusOr IrEmitter::EmitVectorizedReduce( llvm_ir::ForLoopNest loop_nest(IrName(reduce), &ir_builder_); llvm_ir::IrArray::Index array_index(reduce->shape().dimensions_size()); - for (int i = reduce->shape().layout().minor_to_major_size() - 1; i > 0; --i) { - int64 dimension = reduce->shape().layout().minor_to_major(i); + for (int i = LayoutUtil::MinorToMajor(reduce->shape()).size() - 1; i > 0; + --i) { + int64 dimension = LayoutUtil::Minor(reduce->shape().layout(), i); int64 start_index = 0; int64 end_index = reduce->shape().dimensions(dimension); std::unique_ptr loop = @@ -1828,7 +1654,7 @@ StatusOr IrEmitter::EmitVectorizedReduce( array_index[dimension] = loop->GetIndVarValue(); } - int64 innermost_dimension = reduce->shape().layout().minor_to_major(0); + int64 innermost_dimension = LayoutUtil::Minor(reduce->shape().layout(), 0); int64 innermost_dimension_size = reduce->shape().dimensions(innermost_dimension); @@ -1864,10 +1690,10 @@ StatusOr IrEmitter::EmitVectorizedReduce( target_array); if (auto exit_terminator = loop->GetExitBasicBlock()->getTerminator()) { - CHECK_GT(reduce->shape().layout().minor_to_major_size(), 1); + CHECK_GT(LayoutUtil::MinorToMajor(reduce->shape()).size(), 1); ir_builder_.SetInsertPoint(exit_terminator); } else { - CHECK_EQ(reduce->shape().layout().minor_to_major_size(), 1); + CHECK_EQ(LayoutUtil::MinorToMajor(reduce->shape()).size(), 1); ir_builder_.SetInsertPoint(loop->GetExitBasicBlock()); } } @@ -1906,7 +1732,7 @@ StatusOr IrEmitter::EmitVectorizedReduce( Status IrEmitter::HandleReduce(HloInstruction* reduce) { auto arg = reduce->mutable_operand(0); auto init_value = reduce->mutable_operand(1); - tensorflow::gtl::ArraySlice dimensions(reduce->dimensions()); + gtl::ArraySlice dimensions(reduce->dimensions()); HloComputation* function = reduce->to_apply(); if (!options::VectorizedReduceDisabled(hlo_module_config_)) { string vectorization_failure_reason; @@ -1983,19 +1809,19 @@ Status 
IrEmitter::HandleReduce(HloInstruction* reduce) { Status IrEmitter::HandleSend(HloInstruction* send) { // TODO(b/33942983): Support Send/Recv on CPU. - return Unimplemented("Send is not implemented on CPU. See b/33942983."); + return Unimplemented("Send is not implemented on CPU."); } Status IrEmitter::HandleSendDone(HloInstruction* send_done) { // TODO(b/33942983): Support Send/Recv on CPU. - return Unimplemented("Send-done is not implemented on CPU. See b/33942983."); + return Unimplemented("Send-done is not implemented on CPU."); } Status IrEmitter::HandleSlice(HloInstruction* slice) { VLOG(2) << "HandleSlice: " << slice->ToString(); auto operand = slice->operand(0); // The code below emits a sequential loop nest. For the parallel backend, use - // EmitParallelTargetElementLoop() which respects dynamic loop bounds. + // ParallelLoopEmitter which respects dynamic loop bounds. if (ShouldEmitParallelLoopFor(*slice)) { return DefaultAction(slice); } @@ -2026,8 +1852,8 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) { // // * Implement the memcpy within the innermost loop. - tensorflow::gtl::FlatSet inner_dims; - for (int64 dim : layout.minor_to_major()) { + gtl::FlatSet inner_dims; + for (int64 dim : LayoutUtil::MinorToMajor(layout)) { if (operand->shape().dimensions(dim) != slice->shape().dimensions(dim)) { break; } @@ -2054,7 +1880,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) { // memcpy_dim is the innermost (in terms of layout) dimension for which the // slice does *not* just copy all the elements along the dimension. 
- const int64 memcpy_dim = layout.minor_to_major(inner_dims.size()); + const int64 memcpy_dim = LayoutUtil::Minor(layout, inner_dims.size()); const bool memcpy_is_contiguous = slice->slice_strides(memcpy_dim) == 1; // The number of logical elements that can be copied in a single call @@ -2153,12 +1979,12 @@ Status IrEmitter::HandleDynamicUpdateSlice( Status IrEmitter::HandleRecv(HloInstruction* recv) { // TODO(b/33942983): Support Send/Recv on CPU. - return Unimplemented("Recv is not implemented on CPU. See b/33942983."); + return Unimplemented("Recv is not implemented on CPU."); } Status IrEmitter::HandleRecvDone(HloInstruction* recv_done) { // TODO(b/33942983): Support Send/Recv on CPU. - return Unimplemented("Recv-done is not implemented on CPU. See b/33942983."); + return Unimplemented("Recv-done is not implemented on CPU."); } Status IrEmitter::HandlePad(HloInstruction* pad) { @@ -2167,10 +1993,10 @@ Status IrEmitter::HandlePad(HloInstruction* pad) { for (auto& padding_dimension : pad->padding_config().dimensions()) { if (padding_dimension.edge_padding_low() < 0 || padding_dimension.edge_padding_high() < 0) { - return Unimplemented( - "Negative padding not supported in the CPU backend (b/34628603); " - "this should have been eliminated at the HLO level: %s", - pad->padding_config().ShortDebugString().c_str()); + return InternalErrorStrCat( + "Encountered negative padding in IrEmitter on CPU. " + "This should have been eliminated at the HLO level. 
", + pad->ToString()); } } @@ -2263,8 +2089,8 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation( *root, root->operand(0)->IsRank2Transpose(), root->operand(1)->IsRank2Transpose(), target_array, lhs_array, - rhs_array, GetExecutableRunOptionsArgument(), &ir_builder_, - hlo_module_config_)); + rhs_array, /*addend_array=*/nullptr, GetExecutableRunOptionsArgument(), + &ir_builder_, hlo_module_config_, target_machine_features_)); return Status::OK(); } else if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion, assignment_)) { @@ -2285,6 +2111,35 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter)); return EmitTargetElementLoop(fusion, fused_emitter.GetRootGenerator()); + } else if (fusion->fusion_kind() == HloInstruction::FusionKind::kOutput) { + VLOG(3) << "HandleFusion kOutput"; + int64 dot_op_index = root->operand(0)->opcode() == HloOpcode::kDot ? 
0 : 1; + const HloInstruction* dot = root->operand(dot_op_index); + CHECK_EQ(dot->opcode(), HloOpcode::kDot) + << dot->ToString() << " " + << fusion->fused_instructions_computation()->ToString(); + + int64 dot_lhs_param_number = dot->operand(0)->parameter_number(); + int64 dot_rhs_param_number = dot->operand(1)->parameter_number(); + int64 addend_param_number = + root->operand(1 - dot_op_index)->parameter_number(); + + Shape target_shape = fusion->shape(); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion)); + llvm_ir::IrArray target_array = GetIrArrayFor(fusion); + + llvm_ir::IrArray lhs_array( + GetIrArrayFor(fusion->operand(dot_lhs_param_number))); + llvm_ir::IrArray rhs_array( + GetIrArrayFor(fusion->operand(dot_rhs_param_number))); + llvm_ir::IrArray addend_array( + GetIrArrayFor(fusion->operand(addend_param_number))); + + TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation( + *dot, /*transpose_lhs=*/false, /*transpose_rhs=*/false, target_array, + lhs_array, rhs_array, &addend_array, GetExecutableRunOptionsArgument(), + &ir_builder_, hlo_module_config_, target_machine_features_)); + return Status::OK(); } else { return Unimplemented("Fusion kind not implemented on CPU"); } @@ -2305,9 +2160,17 @@ Status IrEmitter::HandleCall(HloInstruction* call) { !parallel_cpu_backend_) { // ParallelTaskAssignment assigned partitions, emit call to // ParallelForkJoin. 
- TF_RETURN_IF_ERROR(EmitParallelForkJoin(parameter_addresses, - emitted_value_[call], computation, - call_ir_function)); + std::vector call_args = GetArrayFunctionCallArguments( + parameter_addresses, &ir_builder_, computation->name(), + /*return_value_buffer=*/emitted_value_[call], + /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), + /*temp_buffers_arg=*/GetTempBuffersArgument(), + /*profile_counters_arg=*/GetProfileCountersArgument()); + + HloInstruction* root = computation->root_instruction(); + TF_RETURN_IF_ERROR(EmitCallToParallelForkJoin( + call_args, root->shape(), root->outer_dimension_partitions(), + &ir_builder_, call_ir_function, computation->name())); } else { EmitArrayFunctionCallInto(call_ir_function, parameter_addresses, emitted_value_[call], computation->name()); @@ -2317,8 +2180,7 @@ Status IrEmitter::HandleCall(HloInstruction* call) { } Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { - tensorflow::gtl::ArraySlice operands( - custom_call->operands()); + gtl::ArraySlice operands(custom_call->operands()); tensorflow::StringPiece custom_call_target(custom_call->custom_call_target()); llvm::Type* i8_ptr_type = ir_builder_.getInt8PtrTy(); llvm::AllocaInst* operands_alloca = @@ -2410,7 +2272,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { // Terminates the current block with a branch to a while header. llvm::BasicBlock* header_bb = llvm::BasicBlock::Create( module_->getContext(), AsStringRef(IrName(xla_while, "header")), - compute_function_); + compute_function_->function()); ir_builder_.CreateBr(header_bb); ir_builder_.SetInsertPoint(header_bb); @@ -2427,7 +2289,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { // Branches to the body or to the while exit depending on the condition. 
llvm::BasicBlock* body_bb = llvm::BasicBlock::Create( module_->getContext(), AsStringRef(IrName(xla_while, "body")), - compute_function_); + compute_function_->function()); llvm::BasicBlock* exit_bb = llvm::BasicBlock::Create( module_->getContext(), AsStringRef(IrName(xla_while, "exit"))); ir_builder_.CreateCondBr(while_predicate, body_bb, exit_bb); @@ -2442,15 +2304,14 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { ir_builder_.CreateBr(header_bb); // Adds the exit block to the function and sets the insert point there. - compute_function_->getBasicBlockList().push_back(exit_bb); + compute_function_->function()->getBasicBlockList().push_back(exit_bb); ir_builder_.SetInsertPoint(exit_bb); return Status::OK(); } StatusOr IrEmitter::EmitFastConcatenate( - HloInstruction* concatenate, - tensorflow::gtl::ArraySlice operands, + HloInstruction* concatenate, gtl::ArraySlice operands, string* failure_reason) { if (ShouldEmitParallelLoopFor(*concatenate)) { *failure_reason = @@ -2478,14 +2339,13 @@ StatusOr IrEmitter::EmitFastConcatenate( int64 concat_dim = concatenate->dimensions(0); const Layout& output_layout = output_shape.layout(); + auto output_min2maj = LayoutUtil::MinorToMajor(output_layout); auto concat_dim_layout_itr = - std::find(output_layout.minor_to_major().begin(), - output_layout.minor_to_major().end(), concat_dim); + std::find(output_min2maj.begin(), output_min2maj.end(), concat_dim); - std::vector inner_dims(output_layout.minor_to_major().begin(), - concat_dim_layout_itr); + std::vector inner_dims(output_min2maj.begin(), concat_dim_layout_itr); std::vector outer_dims(std::next(concat_dim_layout_itr), - output_layout.minor_to_major().end()); + output_min2maj.end()); llvm::Type* i8_ptr_type = ir_builder_.getInt8PtrTy(); llvm::Type* i8_type = ir_builder_.getInt8Ty(); @@ -2560,7 +2420,7 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source, const llvm_ir::IrArray& source_array) { unsigned primitive_type_size = 
ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); - unsigned element_alignment = GCD( + unsigned element_alignment = tensorflow::MathUtil::GCD( primitive_type_size, MinimumAlignmentForPrimitiveType(primitive_type)); llvm::Type* primitive_ptr_type = llvm::PointerType::getUnqual( llvm_ir::PrimitiveTypeToIrType(primitive_type, module_)); @@ -2590,8 +2450,7 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source, } Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) { - tensorflow::gtl::ArraySlice operands( - concatenate->operands()); + gtl::ArraySlice operands(concatenate->operands()); string failure_reason; TF_ASSIGN_OR_RETURN( bool successful, @@ -2607,6 +2466,65 @@ Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) { return DefaultAction(concatenate); } +Status IrEmitter::HandleConditional(HloInstruction* conditional) { + auto pred = conditional->operand(0); + auto true_arg = conditional->operand(1); + auto false_arg = conditional->operand(2); + TF_RET_CHECK(ShapeUtil::IsScalar(pred->shape()) && + pred->shape().element_type() == PRED) + << "Predicate on a Conditional must be bool; got: " + << ShapeUtil::HumanString(pred->shape()); + + HloComputation* true_computation = conditional->true_computation(); + HloComputation* false_computation = conditional->false_computation(); + TF_RET_CHECK(ShapeUtil::Equal(conditional->shape(), + true_computation->root_instruction()->shape())) + << "Shape of conditional should be same as the shape of the true " + << "computation; got: " << ShapeUtil::HumanString(conditional->shape()) + << " and " + << ShapeUtil::HumanString(true_computation->root_instruction()->shape()); + + TF_RET_CHECK(ShapeUtil::Equal(conditional->shape(), + false_computation->root_instruction()->shape())) + << "Shape of conditional should be same as the shape of the false " + << "computation; got: " << ShapeUtil::HumanString(conditional->shape()) + << " and " + << 
ShapeUtil::HumanString(false_computation->root_instruction()->shape()); + + llvm::Function* true_function = + FindOrDie(emitted_functions_, true_computation); + llvm::Function* false_function = + FindOrDie(emitted_functions_, false_computation); + + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(conditional)); + llvm::Value* conditional_result = GetEmittedValueFor(conditional); + + // Generating: + // if (pred) + // cond_result = true_computation(true_operand) + // else + // cond_result = false_computation(false_operand) + llvm::LoadInst* pred_value = ir_builder_.CreateLoad( + GetIrArrayFor(pred).GetBasePointer(), "load_predicate_value"); + llvm::Value* pred_cond = ir_builder_.CreateICmpNE( + pred_value, + llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0), + "boolean_predicate"); + llvm_ir::LlvmIfData if_data = + llvm_ir::EmitIfThenElse(pred_cond, "conditional", &ir_builder_); + + SetToFirstInsertPoint(if_data.true_block, &ir_builder_); + EmitArrayFunctionCallInto(true_function, {GetEmittedValueFor(true_arg)}, + conditional_result, IrName(conditional, "_true")); + + SetToFirstInsertPoint(if_data.false_block, &ir_builder_); + EmitArrayFunctionCallInto(false_function, {GetEmittedValueFor(false_arg)}, + conditional_result, IrName(conditional, "_false")); + + SetToFirstInsertPoint(if_data.after_block, &ir_builder_); + return Status::OK(); +} + Status IrEmitter::FinishVisit(HloInstruction* root) { // When this method is called, we should have already emitted an IR value for // the root (return) op. The IR value holds the address of the buffer holding @@ -2618,57 +2536,51 @@ Status IrEmitter::FinishVisit(HloInstruction* root) { llvm::Value* root_value = GetEmittedValueFor(root); VLOG(2) << " value: " << llvm_ir::DumpToString(*root_value); - llvm::Value* prof_counter = [&]() { - // For the parallel cpu backend, we record the total for each embedded - // computation callee with its caller kCall HLO. 
- if (parallel_cpu_backend_ && is_top_level_computation_) { - auto* computation = root->parent(); - auto* entry_computation = computation->parent()->entry_computation(); - if (computation != entry_computation) { - for (HloInstruction* instruction : entry_computation->instructions()) { - if (instruction->opcode() == HloOpcode::kCall && - instruction->to_apply()->root_instruction() == root) { - return GetProfileCounterFor(*instruction); - } + auto record_complete_computation = [&](llvm::Value* prof_counter) { + if (prof_counter) { + profiling_state_.RecordCompleteComputation(&ir_builder_, prof_counter); + } + }; + + // For the parallel cpu backend, we record the total for each embedded + // computation callee with its caller kCall HLO. + if (parallel_cpu_backend_ && is_top_level_computation_) { + auto* computation = root->parent(); + auto* entry_computation = computation->parent()->entry_computation(); + if (computation != entry_computation) { + for (HloInstruction* instruction : entry_computation->instructions()) { + if (instruction->opcode() == HloOpcode::kCall && + instruction->to_apply()->root_instruction() == root) { + record_complete_computation(GetProfileCounterFor(*instruction)); + return Status::OK(); } } } - - // Otherwise we record the total computation cycles in a dedicated slot for - // the entry computation. - return GetProfileCounterForEntryComputation(); - }(); - - if (prof_counter) { - profiling_state_.RecordCompleteComputation(&ir_builder_, prof_counter); } - ir_builder_.CreateRetVoid(); + + // For the entry computation this increment is cumulative of embedded + // computations since it includes cycles spent in computations invoked by + // While, Call etc. 
+ record_complete_computation(GetProfileCounterFor(*root->parent())); return Status::OK(); } -llvm::Value* IrEmitter::GetProfileCounterFor(const HloInstruction& hlo) { - auto it = hlo_to_profile_idx_.find(&hlo); - if (it == hlo_to_profile_idx_.end()) { +template +llvm::Value* IrEmitter::GetProfileCounterCommon( + const T& hlo, + const std::unordered_map& profile_index_map) { + auto it = profile_index_map.find(&hlo); + if (it == profile_index_map.end()) { return nullptr; } - size_t prof_counter_idx = it->second; + int64 prof_counter_idx = it->second; string counter_name = IrName("prof_counter", hlo.name()); return ir_builder_.CreateGEP(GetProfileCountersArgument(), ir_builder_.getInt64(prof_counter_idx), AsStringRef(counter_name)); } -llvm::Value* IrEmitter::GetProfileCounterForEntryComputation() { - if (entry_computation_profile_idx_) { - return ir_builder_.CreateGEP( - GetProfileCountersArgument(), - ir_builder_.getInt64(*entry_computation_profile_idx_), - "prof_counter.computation"); - } - return nullptr; -} - void IrEmitter::ProfilingState::UpdateProfileCounter( llvm::IRBuilder<>* ir_builder, llvm::Value* prof_counter, llvm::Value* cycle_end, llvm::Value* cycle_start) { @@ -2731,8 +2643,7 @@ void IrEmitter::ProfilingState::RecordCycleDelta(llvm::IRBuilder<>* ir_builder, void IrEmitter::ProfilingState::RecordCompleteComputation( llvm::IRBuilder<>* ir_builder, llvm::Value* prof_counter) { - if (is_top_level_computation_ && last_read_cycle_end_ && - first_read_cycle_start_) { + if (last_read_cycle_end_ && first_read_cycle_start_) { UpdateProfileCounter(ir_builder, prof_counter, last_read_cycle_end_, first_read_cycle_start_); } @@ -2740,7 +2651,7 @@ void IrEmitter::ProfilingState::RecordCompleteComputation( Status IrEmitter::Preprocess(HloInstruction* hlo) { VLOG(3) << "Visiting: " << hlo->ToString(); - if (hlo_to_profile_idx_.count(hlo)) { + if (instruction_to_profile_idx_.count(hlo)) { profiling_state_.RecordCycleStart(&ir_builder_, hlo); } return Status::OK(); @@ 
-2783,43 +2694,16 @@ llvm::Type* IrEmitter::IrShapeType(const Shape& shape) { return llvm_ir::ShapeToIrType(shape, module_); } -std::vector IrEmitter::GetComputeFunctionParams() { - llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext()); - llvm::Type* i8_ptr_ptr_type = i8_ptr_type->getPointerTo(); - llvm::Type* i64_ptr_type = llvm::Type::getInt64PtrTy(module_->getContext()); - std::vector compute_function_params( - {i8_ptr_type, i8_ptr_type, i8_ptr_ptr_type, i8_ptr_ptr_type}); - if (num_dynamic_loop_bounds_ > 0) { - compute_function_params.push_back(i64_ptr_type); - } - compute_function_params.push_back(i64_ptr_type); - return compute_function_params; -} - -llvm::Argument* IrEmitter::GetResultArgument() { - return GetArg(compute_function_, 0); -} - -llvm::Argument* IrEmitter::GetProfileCountersArgument() { - const int64 arg_index = num_dynamic_loop_bounds_ > 0 ? 5 : 4; - return GetArg(compute_function_, arg_index); +llvm::Value* IrEmitter::GetProfileCountersArgument() { + return compute_function_->profile_counters_arg(); } llvm::Value* IrEmitter::GetTempBuffersArgument() { - return GetArg(compute_function_, 3); -} - -llvm::Value* IrEmitter::GetDynamicLoopBound(const int64 offset) { - CHECK_GT(num_dynamic_loop_bounds_, 0); - CHECK_LT(offset, num_dynamic_loop_bounds_ * 2); - llvm::Argument* loop_bounds_arg = GetArg(compute_function_, 4); - string name = tensorflow::strings::StrCat("dynamic_loop_bound_", offset); - return ir_builder_.CreateLoad(ir_builder_.CreateGEP( - loop_bounds_arg, ir_builder_.getInt64(offset), AsStringRef(name))); + return compute_function_->temp_buffers_arg(); } llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() { - return GetArg(compute_function_, 1); + return compute_function_->exec_run_options_arg(); } llvm::Value* IrEmitter::EmitTempBufferPointer( @@ -2850,10 +2734,14 @@ llvm::Value* IrEmitter::EmitTempBufferPointer( GetTempBuffersArgument(), slice.index(), &ir_builder_); llvm::LoadInst* tempbuf_address_base = 
ir_builder_.CreateLoad(tempbuf_address_ptr); - if (hlo_module_config_.debug_options() + if (is_top_level_computation_ && + hlo_module_config_.debug_options() .xla_llvm_enable_invariant_load_metadata()) { - // Loading the address of a buffer is invariant of the point at which the - // load is executed in the program because we never reassign buffers. + // In the entry computation the parameter slots in the %params argument are + // invariant through program execution. In computations that are called + // from the entry computation (via kWhile, kCall and kConditional) the + // parameter slots are *not* invariant since they're written to by their + // callers. tempbuf_address_base->setMetadata( llvm::LLVMContext::MD_invariant_load, llvm::MDNode::get(tempbuf_address_base->getContext(), /*MDs=*/{})); @@ -2875,7 +2763,7 @@ llvm::Value* IrEmitter::EmitTempBufferPointer( // for a single element_type value, and loads it after call. llvm::Value* IrEmitter::EmitElementFunctionCall( llvm::Function* function, const Shape& return_shape, - tensorflow::gtl::ArraySlice parameter_addresses, + gtl::ArraySlice parameter_addresses, tensorflow::StringPiece name) { llvm::Value* return_value_buffer = EmitArrayFunctionCall( function, return_shape, 1, parameter_addresses, name); @@ -2884,42 +2772,6 @@ llvm::Value* IrEmitter::EmitElementFunctionCall( AsStringRef(tensorflow::strings::StrCat(name, "_return_value"))); } -// Emits code to allocate an array of parameter address pointers, and store -// each address from 'parameter_addresses'. -// Returns an array of compute function call arguments (including parameter -// address buffer). 
-std::vector IrEmitter::GetArrayFunctionCallArguments( - tensorflow::gtl::ArraySlice parameter_addresses, - llvm::Value* return_value_buffer, tensorflow::StringPiece name) { - llvm::Value* parameter_addresses_buffer = - llvm_ir::EmitAllocaAtFunctionEntryWithCount( - ir_builder_.getInt8PtrTy(), - ir_builder_.getInt32(parameter_addresses.size()), - tensorflow::strings::StrCat(name, "_parameter_addresses"), - &ir_builder_); - for (size_t i = 0; i < parameter_addresses.size(); ++i) { - llvm::Value* parameter_as_i8ptr = ir_builder_.CreateBitCast( - parameter_addresses[i], ir_builder_.getInt8PtrTy(), - AsStringRef(tensorflow::strings::StrCat(name, "_parameter_", i, - "_address_as_i8ptr"))); - llvm::Value* slot_in_param_adresses = ir_builder_.CreateInBoundsGEP( - parameter_addresses_buffer, {ir_builder_.getInt64(i)}); - ir_builder_.CreateStore(parameter_as_i8ptr, slot_in_param_adresses); - } - - const auto to_int8_ptr = [this](llvm::Value* ptr) { - return ir_builder_.CreatePointerCast(ptr, ir_builder_.getInt8PtrTy()); - }; - std::vector arguments{ - to_int8_ptr(return_value_buffer), - to_int8_ptr(GetExecutableRunOptionsArgument()), - parameter_addresses_buffer, GetTempBuffersArgument()}; - if (auto* profile_counters = GetProfileCountersArgument()) { - arguments.push_back(profile_counters); - } - return arguments; -} - // Emits a core function call based on the following pseudo-code. // // char** parameter_addresses_buffer = @@ -2931,17 +2783,20 @@ std::vector IrEmitter::GetArrayFunctionCallArguments( // temps) // return return_value_buffer -- address of the return value. 
void IrEmitter::EmitArrayFunctionCallInto( - llvm::Function* function, - tensorflow::gtl::ArraySlice parameter_addresses, + llvm::Function* function, gtl::ArraySlice parameter_addresses, llvm::Value* return_value_buffer, tensorflow::StringPiece name) { ir_builder_.CreateCall( - function, GetArrayFunctionCallArguments(parameter_addresses, - return_value_buffer, name)); + function, GetArrayFunctionCallArguments( + parameter_addresses, &ir_builder_, name, + /*return_value_buffer=*/return_value_buffer, + /*exec_run_options_arg=*/GetExecutableRunOptionsArgument(), + /*temp_buffers_arg=*/GetTempBuffersArgument(), + /*profile_counters_arg=*/GetProfileCountersArgument())); } llvm::Value* IrEmitter::EmitArrayFunctionCall( llvm::Function* function, const Shape& return_shape, int64 element_count, - tensorflow::gtl::ArraySlice parameter_addresses, + gtl::ArraySlice parameter_addresses, tensorflow::StringPiece name) { llvm::Value* elements = llvm::ConstantInt::get(ir_builder_.getInt64Ty(), element_count); @@ -2956,117 +2811,13 @@ llvm::Value* IrEmitter::EmitArrayFunctionCall( return return_value_buffer; } -// Emits a call to a runtime fork/join function which dispatches parallel -// calls to 'parallel_function' (and joins threads before returning). -Status IrEmitter::EmitParallelForkJoin( - tensorflow::gtl::ArraySlice parameter_addresses, - llvm::Value* output_address, HloComputation* computation, - llvm::Function* parallel_function) { - HloInstruction* root = computation->root_instruction(); - - // Build ParallelForkJoin function type. - std::vector compute_function_params = GetComputeFunctionParams(); - // Number of parallel compute functions. - compute_function_params.push_back(ir_builder_.getInt32Ty()); - // Array of partitions. There is an array element for each - // partition x partition_dim x 2 (for dimension start and limit). 
- compute_function_params.push_back( - llvm::Type::getInt64PtrTy(module_->getContext())); - // Number of partitioned most-major dimensions in 'root.shape'. - compute_function_params.push_back(ir_builder_.getInt32Ty()); - // Function pointer for compute function to be dispatched in parallel. - compute_function_params.push_back( - llvm::Type::getInt8PtrTy(module_->getContext())); - - llvm::FunctionType* fork_join_type = llvm::FunctionType::get( - /*Result=*/llvm::Type::getVoidTy(module_->getContext()), - /*Params=*/compute_function_params, - /*isVarArg=*/false); - - llvm::Function* fork_join_func = - llvm::cast(module_->getOrInsertFunction( - runtime::kParallelForkJoinSymbolName, fork_join_type)); - fork_join_func->setCallingConv(llvm::CallingConv::C); - fork_join_func->setDoesNotThrow(); - - // Add common compute function arguments. - const string name = computation->name(); - std::vector arguments = - GetArrayFunctionCallArguments(parameter_addresses, output_address, name); - - // Create ShapePartitionIterator to generate all partitions of 'root.shape'. - ShapePartitionIterator partition_iterator(root->shape(), - root->outer_dimension_partitions()); - const int64 num_partitions = partition_iterator.GetTotalPartitionCount(); - // Add argument specifying the number of parallel partitions. - arguments.push_back(ir_builder_.getInt32(num_partitions)); - - // The number of partitioned most-major dimensions in 'root.shape'. - const int32 num_partitioned_dims = root->outer_dimension_partitions().size(); - // A dimension partition consists of two elements: [start_index, limit_index). - const int32 dim_partition_size = 2; - // Calculate array partition stride. - const int32 array_partition_stride = - num_partitioned_dims * dim_partition_size; - // Calculate the total number of elements in the partition array. 
- const int32 partition_array_size = - dim_partition_size * num_partitioned_dims * num_partitions; - - // Store dimension partition values as llvm constants in 'partitions'. - // See comments in runtime_fork_join.cc for array layout description. - std::vector partitions(partition_array_size); - for (int32 i = 0; i < num_partitions; ++i) { - std::vector> dim_partitions = - partition_iterator.GetPartition(i); - CHECK_EQ(num_partitioned_dims, dim_partitions.size()); - const int32 partition_index = i * array_partition_stride; - for (int32 j = 0; j < num_partitioned_dims; ++j) { - const std::pair& dim_partition = dim_partitions[j]; - const int32 index = partition_index + j * dim_partition_size; - // Store partition [dim_start, dim_limit) intervals for each dimension. - partitions[index] = ir_builder_.getInt64(dim_partition.first); - partitions[index + 1] = - ir_builder_.getInt64(dim_partition.first + dim_partition.second); - } - } - - // Create global variable out of dimension partitions in 'partitions'. - llvm::ArrayType* partitions_array_type = - llvm::ArrayType::get(ir_builder_.getInt64Ty(), partition_array_size); - llvm::Constant* partitions_array = - llvm::ConstantArray::get(partitions_array_type, partitions); - llvm::GlobalVariable* global_partitions_array = new llvm::GlobalVariable( - /*Module=*/*module_, - /*Type=*/partitions_array_type, - /*isConstant=*/true, - /*Linkage=*/llvm::GlobalValue::PrivateLinkage, - /*Initializer=*/partitions_array, - /*Name=*/ - AsStringRef( - tensorflow::strings::StrCat(name, "_parallel_dimension_partitions"))); - - // Add argument specifying parallel dimension partitions. - arguments.push_back(ir_builder_.CreateBitCast( - global_partitions_array, - llvm::Type::getInt64PtrTy(module_->getContext()))); - // Add argument specifying the number of partitioned most-major dimensions. - arguments.push_back(ir_builder_.getInt32(num_partitioned_dims)); - // Add argument for parallel compute function pointer. 
- arguments.push_back( - ir_builder_.CreateBitCast(parallel_function, ir_builder_.getInt8PtrTy())); - // Emit call to parallel fork/join. - ir_builder_.CreateCall(fork_join_func, arguments); - - return Status::OK(); -} - Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) { llvm::Value* addr; const Shape& target_shape = op->shape(); if (op == op->parent()->root_instruction()) { // For the root node, we write directly to the output buffer of the // function. - llvm::Argument* retval = GetResultArgument(); + llvm::Argument* retval = compute_function_->result_arg(); if (!ShapeUtil::IsNil(target_shape)) { llvm::AttrBuilder attr_builder; attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape)); @@ -3127,8 +2878,13 @@ Status IrEmitter::EmitTargetElementLoop( } else { if (ShouldEmitParallelLoopFor(*target_op)) { - TF_RETURN_IF_ERROR(EmitParallelTargetElementLoop( - target_shape, element_generator, IrName(target_op), &target_array)); + // Emit code to read dynamic loop bounds from compute function argument. + std::vector> dynamic_loop_bounds = + compute_function_->GetDynamicLoopBounds(); + // Emit parallel loop with dynamic loop bounds for most-major dimensions. + TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, target_array, + &dynamic_loop_bounds, &ir_builder_) + .EmitLoop(IrName(target_op))); } else { TF_RETURN_IF_ERROR( llvm_ir::LoopEmitter(element_generator, target_array, &ir_builder_) @@ -3138,60 +2894,6 @@ Status IrEmitter::EmitTargetElementLoop( return Status::OK(); } -Status IrEmitter::EmitParallelTargetElementLoop( - const Shape& target_shape, - const llvm_ir::ElementGenerator& element_generator, - tensorflow::StringPiece loop_name, llvm_ir::IrArray* target_array) { - CHECK(!ShapeUtil::IsTuple(target_shape)); - CHECK(!ShapeUtil::IsScalar(target_shape)); - - // Emit code to read dynamic loop bounds from function argument 4. 
- std::vector dynamic_loop_bounds(2 * num_dynamic_loop_bounds_); - for (int i = 0; i < 2 * num_dynamic_loop_bounds_; ++i) { - dynamic_loop_bounds[i] = GetDynamicLoopBound(i); - } - - llvm_ir::ForLoopNest loop_nest(loop_name, &ir_builder_); - const int64 num_dims = target_shape.dimensions_size(); - llvm_ir::IrArray::Index array_index(num_dims); - - // Add loops from outer-most to inner-most dimensions. - for (int i = target_shape.layout().minor_to_major_size() - 1; i >= 0; --i) { - const int64 dimension = target_shape.layout().minor_to_major(i); - const int bounds_index = num_dims - 1 - i; - if (bounds_index < num_dynamic_loop_bounds_) { - // Emit dynamic loop bounds for this dimension. Dynamic loop bounds - // are read from ir function dynamic loop bounds argument. - llvm::Value* start_index = dynamic_loop_bounds[bounds_index * 2 + 0]; - llvm::Value* end_index = dynamic_loop_bounds[bounds_index * 2 + 1]; - - std::unique_ptr loop = loop_nest.AddLoop( - /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension), - start_index, end_index); - array_index[dimension] = loop->GetIndVarValue(); - } else { - // Emit static loop bounds for this dimension. - std::unique_ptr loop = loop_nest.AddLoop( - /*start_index=*/0, - /*end_index=*/target_shape.dimensions(dimension), - /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension)); - array_index[dimension] = loop->GetIndVarValue(); - } - } - // Point IR builder at inner loop BB. - SetToFirstInsertPoint(loop_nest.GetInnerLoopBodyBasicBlock(), &ir_builder_); - - // Emit loop body. - TF_ASSIGN_OR_RETURN(llvm::Value * target_element, - element_generator(array_index)); - target_array->EmitWriteArrayElement(array_index, target_element, - &ir_builder_); - // Point IR builder at outer loop exit BB. 
- SetToFirstInsertPoint(loop_nest.GetOuterLoopExitBasicBlock(), &ir_builder_); - - return Status::OK(); -} - Status IrEmitter::EmitMemcpy(const HloInstruction& source, const HloInstruction& destination) { llvm::Value* source_value = GetEmittedValueFor(&source); @@ -3204,8 +2906,8 @@ Status IrEmitter::EmitMemcpy(const HloInstruction& source, Status IrEmitter::ElementTypesSameAndSupported( const HloInstruction& instruction, - tensorflow::gtl::ArraySlice operands, - tensorflow::gtl::ArraySlice supported_types) { + gtl::ArraySlice operands, + gtl::ArraySlice supported_types) { for (auto operand : operands) { TF_RET_CHECK( ShapeUtil::SameElementType(operands[0]->shape(), operand->shape())); @@ -3249,37 +2951,5 @@ StatusOr IrEmitter::EmitScalarCall( ShapeUtil::MakeShape(return_type, {}), argument_addrs, name); } - -unsigned TargetMachineFeatures::largest_register_size_in_bytes( - llvm::Function* function) { - auto itr = largest_register_size_in_bytes_.find(function); - if (itr != largest_register_size_in_bytes_.end()) { - return itr->second; - } - - int result = largest_register_size_in_bytes_impl(function); - - InsertOrDie(&largest_register_size_in_bytes_, function, result); - DCHECK_EQ(result, largest_register_size_in_bytes_.begin()->second); - return result; -} - -unsigned TargetMachineFeatures::largest_register_size_in_bytes_impl( - llvm::Function* function) const { - auto register_info = - target_machine_->getSubtargetImpl(*function)->getRegisterInfo(); - - unsigned largest_register_size = 0; - for (const llvm::TargetRegisterClass* register_class : - register_info->regclasses()) { - if (register_class->isAllocatable()) { - largest_register_size = - std::max(largest_register_size, - register_info->getRegSizeInBits(*register_class)); - } - } - - return largest_register_size / 8; -} } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 
351c95278c17f536e56d9f085b938a9baea9cde1..509440251497cd7337284c39dae05c5f6c28e7c2 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include #include #include @@ -30,6 +31,8 @@ limitations under the License. #include "llvm/Target/TargetMachine.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h" +#include "tensorflow/compiler/xla/service/cpu/ir_function.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -49,49 +52,6 @@ limitations under the License. namespace xla { namespace cpu { - -// Wraps an llvm::TargetMachine and parses out some information that feeds into -// code LLVM IR generation decisions. -// -// Ideally we'd be able to use llvm::TargetTransformInfo here (since its -// interface is pretty much a perfect fit for our use case), but obtaining an -// instance of llvm::TargetTransformInfo outside an LLVM pass pipeline without -// super-ugly hacks is difficult. -// -// TODO(b/66049221): See if the LLVM community will be receptive to exposing an -// API that lets us directly create and use llvm::TargetTransformInfo instances -// outside of a pass manager. -class TargetMachineFeatures { - public: - TargetMachineFeatures(llvm::TargetMachine* target_machine) - : target_machine_(target_machine) {} - - // Return the vectorization factor, which is the number of bytes of data - // explicitly vectorized routines will try to process at once. 
- int vectorization_factor_in_bytes() const { - // Ideally this should be a function of the cache line size (which we can - // get from llvm::TargetTransformInfo::getCacheLineSize) of the target - // machine. Guess a value of 128 bytes for now. - return 128; - } - - // Return the size of the largest register size in bytes. We need to pass in - // "function" since llvm functions can contain annotations for specializing - // them to specific micro-architectures (though currently XLA does not use - // this functionality). - // - // Ideally we should have been able to use - // llvm::TargetTransformInfo::getRegisterBitWidth(true) here. - unsigned largest_register_size_in_bytes(llvm::Function* function); - - private: - unsigned largest_register_size_in_bytes_impl(llvm::Function* function) const; - - tensorflow::gtl::FlatMap - largest_register_size_in_bytes_; - llvm::TargetMachine* target_machine_; -}; - // This class is the top-level API for the XLA HLO --> LLVM IR compiler. It // implements the DfsHloVisitor interface and emits HLO computations as LLVM IR // functions. @@ -103,20 +63,21 @@ class IrEmitter : public DfsHloVisitorWithDefault { // assignment: a BufferAssignment from which we know which temporary buffers // are used by the HLO nodes. // llvm_module: the LLVM module to emit IR into. - // hlo_to_profile_idx: the mapping from HLO to its index in the profiling - // array. - // entry_computation_profile_idx: the index in the profiling array - // for the entry computation. + // instruction_to_profile_idx: the mapping from HLO instructions to their + // index in the profiling array. + // computation_to_profile_idx: the mapping from HLO computations to their + // index in the profiling array. // external_constant_pool: if non-null, points to an ExternalConstantPool // instance into which the Ir emitter can spill // constants. 
- IrEmitter( - const HloModule& hlo_module, const BufferAssignment& assignment, - llvm::Module* llvm_module, - std::unordered_map hlo_to_profile_idx, - tensorflow::gtl::optional entry_computation_profile_idx, - llvm::TargetMachine* target_machine, - ExternalConstantPool* external_constant_pool); + IrEmitter(const HloModule& hlo_module, const BufferAssignment& assignment, + llvm::Module* llvm_module, + std::unordered_map + instruction_to_profile_idx, + std::unordered_map + computation_to_profile_idx, + llvm::TargetMachine* target_machine, + ExternalConstantPool* external_constant_pool); ~IrEmitter() override; // Emit and return the given HLO computation as an LLVM IR @@ -163,8 +124,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleSelect(HloInstruction* select) override; Status HandleDot(HloInstruction* dot) override; Status HandleConvolution(HloInstruction* convolution) override; - Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override; - Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; + Status HandleFft(HloInstruction* fft) override; Status HandleCrossReplicaSum(HloInstruction* crs) override; Status HandleInfeed(HloInstruction* infeed) override; Status HandleOutfeed(HloInstruction* outfeed) override; @@ -189,6 +149,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleCustomCall(HloInstruction* custom_call) override; Status HandleWhile(HloInstruction* xla_while) override; Status HandleConcatenate(HloInstruction* concatenate) override; + Status HandleConditional(HloInstruction* conditional) override; Status FinishVisit(HloInstruction* root) override; Status Preprocess(HloInstruction* hlo) override; @@ -198,14 +159,23 @@ class IrEmitter : public DfsHloVisitorWithDefault { // Private helper to initialize an IR function for the computation. 
void InitializeIrFunction(const string& function_name); - // Convenience function to generate a GEP into the profile counter parameter - // which would correspond to the index for a given HLO. - llvm::Value* GetProfileCounterFor(const HloInstruction& hlo); + template + llvm::Value* GetProfileCounterCommon( + const T& hlo, + const std::unordered_map& profile_index_map); + + // Convenience functions to generate a GEP into the profile counter parameter + // which would correspond to the index for a given HLO instruction or + // computation. + llvm::Value* GetProfileCounterFor(const HloInstruction& instruction) { + return GetProfileCounterCommon(instruction, + instruction_to_profile_idx_); + } - // Convenience function to generate a GEP into the profile counter parameter - // corresponding to the index for the entry computation. Returns nullptr if - // profiling the entry computation is disabled. - llvm::Value* GetProfileCounterForEntryComputation(); + llvm::Value* GetProfileCounterFor(const HloComputation& computation) { + return GetProfileCounterCommon(computation, + computation_to_profile_idx_); + } // Gets the IR Value emitted previously for the given hlo. // @@ -233,16 +203,9 @@ class IrEmitter : public DfsHloVisitorWithDefault { // Convenience function to get the IR type matching the given shape. llvm::Type* IrShapeType(const Shape& shape); - // Returns an array of compute function parameter types. - std::vector GetComputeFunctionParams(); - - // Get the llvm::Value* that represents the "retval" argument of the - // computation function being emitted by this emitter. - llvm::Argument* GetResultArgument(); - // Get the llvm::Value* that represents the "prof_counters" argument of the // computation function being emitted by this emitter. 
- llvm::Argument* GetProfileCountersArgument(); + llvm::Value* GetProfileCountersArgument(); // Get the xla::ExecutableRunOptions that represents the "run_options" // argument of the computation function being emitted by this emitter. @@ -252,11 +215,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { // computation function being emitted by this emitter. llvm::Value* GetTempBuffersArgument(); - // Emit ir to read and return the ir value for the dynamic loop bound at - // 'offset' from the "dynamic_loop_bounds" argument of the computation - // function being emitted by this emitter. - llvm::Value* GetDynamicLoopBound(const int64 offset); - // Emits code that computes the address of the given temporary buffer to the // function. target_shape is the shape of this temporary buffer. // The returned Value's type is a pointer to element_type. @@ -310,18 +268,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { tensorflow::gtl::ArraySlice parameter_addresses, tensorflow::StringPiece name); - // Returns an array of compute function call arguments. - std::vector GetArrayFunctionCallArguments( - tensorflow::gtl::ArraySlice parameter_addresses, - llvm::Value* return_value_buffer, tensorflow::StringPiece name); - - // Emits a call to a runtime fork/join function which dispatches parallel - // calls to 'parallel_function' (and joins threads before returning). - Status EmitParallelForkJoin( - tensorflow::gtl::ArraySlice parameter_addresses, - llvm::Value* output_address, HloComputation* computation, - llvm::Function* parallel_function); - // Verifies that the element types of all of the given operand instructions // match and are of one of the given supported types. 
Status ElementTypesSameAndSupported( @@ -346,15 +292,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { HloInstruction* target_op, tensorflow::StringPiece desc, const llvm_ir::ElementGenerator& element_generator); - // Emit IR to perform a computation for every element in a partition/slice of - // 'target_shape'. The loop bounds for the outer-dimension partitions are - // passed into the compute function as a runtime argument (accessible from - // GetDynamicLoopBound). - Status EmitParallelTargetElementLoop( - const Shape& target_shape, - const llvm_ir::ElementGenerator& element_generator, - tensorflow::StringPiece loop_name, llvm_ir::IrArray* target_array); - // Emits a memcpy from the source instruction's result value to the // destination's. Both source and destination must have an entry in the // emitted_value_ table. @@ -476,13 +413,19 @@ class IrEmitter : public DfsHloVisitorWithDefault { thread_local_buffers_; // The following fields track the IR emission state. According to LLVM memory - // management rules, their memory is owned by the module. - llvm::Function* compute_function_; + // management rules, their memory is owned by the module (Note that IrFunction + // creates the encapsulated llvm::Function s.t. it is added to the llvm + // module's function list). + std::unique_ptr compute_function_; llvm::IRBuilder<> ir_builder_; - // Maps HLOs to their index into the profile counter array. - std::unordered_map hlo_to_profile_idx_; - const tensorflow::gtl::optional entry_computation_profile_idx_; + // Maps HLO instructions to their index into the profile counter array. + const std::unordered_map + instruction_to_profile_idx_; + + // Maps HLO computations to their index into the profile counter array. + const std::unordered_map + computation_to_profile_idx_; // Maps HLOs to Values emitted for them. 
std::unordered_map emitted_value_; @@ -490,7 +433,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { llvm_ir::AliasAnalysis alias_analysis_; // The number of root instruction outer dimensions used in parallel loop - // emission (EmitParallelTargetElementLoop). + // emission (ParallelLoopEmitter). int64 num_dynamic_loop_bounds_ = 0; // Returns whether the given instruction should be emitted as a parallel loop. @@ -505,15 +448,9 @@ class IrEmitter : public DfsHloVisitorWithDefault { // profiling a computation. class ProfilingState { public: - ProfilingState() - : is_top_level_computation_(false), - use_rdtscp_(false), - prof_counters_(nullptr) {} - ProfilingState(bool is_top_level_computation, bool use_rdtscp, - llvm::Argument* prof_counters) - : is_top_level_computation_(is_top_level_computation), - use_rdtscp_(use_rdtscp), - prof_counters_(prof_counters) {} + ProfilingState() : use_rdtscp_(false), prof_counters_(nullptr) {} + ProfilingState(bool use_rdtscp, llvm::Value* prof_counters) + : use_rdtscp_(use_rdtscp), prof_counters_(prof_counters) {} // Record the cycle counter before an HLO executes. void RecordCycleStart(llvm::IRBuilder<>* ir_builder, HloInstruction* hlo); @@ -535,15 +472,12 @@ class IrEmitter : public DfsHloVisitorWithDefault { llvm::Value* cycle_start); private: - // Is this IrEmitter for a top-level computation? - bool is_top_level_computation_; - // Should we use the x86-specific rdtscp or the generic readcyclecounter // intrinsic? bool use_rdtscp_; // The argument which corresponds to the profile counter buffer. - llvm::Argument* prof_counters_; + llvm::Value* prof_counters_; // The first read cycle counter in the program. 
llvm::Value* first_read_cycle_start_ = nullptr; diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.cc b/tensorflow/compiler/xla/service/cpu/ir_function.cc new file mode 100644 index 0000000000000000000000000000000000000000..2d6f2f3818a7bd4424aaa7d918ca86abef15c0e9 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/ir_function.cc @@ -0,0 +1,333 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/service/cpu/ir_function.h" + +#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" +#include "tensorflow/compiler/xla/service/cpu/shape_partition.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/status_macros.h" + +namespace xla { + +namespace { +using llvm_ir::AsStringRef; +} // namespace + +namespace cpu { + +static std::vector GetComputeFunctionParams( + llvm::Module* llvm_module, const int64 num_dynamic_loop_bounds) { + llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(llvm_module->getContext()); + llvm::Type* i8_ptr_ptr_type = i8_ptr_type->getPointerTo(); + llvm::Type* i64_ptr_type = + llvm::Type::getInt64PtrTy(llvm_module->getContext()); + std::vector compute_function_params( + {i8_ptr_type, i8_ptr_type, i8_ptr_ptr_type, i8_ptr_ptr_type}); + if (num_dynamic_loop_bounds > 0) { + compute_function_params.push_back(i64_ptr_type); 
+ } + compute_function_params.push_back(i64_ptr_type); + return compute_function_params; +} + +IrFunction::IrFunction(const string& function_name, + llvm::Function::LinkageTypes linkage, + const bool optimize_for_size_requested, + const bool enable_fast_math, llvm::Module* llvm_module, + llvm::IRBuilder<>* ir_builder, + int64 num_dynamic_loop_bounds) + : ir_builder_(ir_builder), + llvm_module_(llvm_module), + caller_insert_point_guard_(*ir_builder), + num_dynamic_loop_bounds_(num_dynamic_loop_bounds) { + Initialize(function_name, linkage, optimize_for_size_requested, + enable_fast_math); +} + +IrFunction::~IrFunction() { + // Emit function return value. + ir_builder_->CreateRetVoid(); +} + +DynamicLoopBounds IrFunction::GetDynamicLoopBounds() { + DynamicLoopBounds dynamic_loop_bounds(num_dynamic_loop_bounds_); + for (int i = 0; i < num_dynamic_loop_bounds_; ++i) { + dynamic_loop_bounds[i].first = GetDynamicLoopBound(i * 2 + 0); + dynamic_loop_bounds[i].second = GetDynamicLoopBound(i * 2 + 1); + } + return dynamic_loop_bounds; +} + +void IrFunction::Initialize(const string& function_name, + llvm::Function::LinkageTypes linkage, + const bool optimize_for_size_requested, + const bool enable_fast_math) { + // The function signature is: + // void function(i8* retval, i8* run_options, i8** params, i8** temps, + // i64* dynamic_loop_bounds, i64* prof_counters) + // + // retval: points to the returned value. + // params: address of an array with pointers to parameters. + // temps: address of an array with pointers to temporary buffers. + // + // Therefore, the generated function's signature (FunctionType) is statically + // determined - parameter unpacking is done in code generated into the + // function, rather than by a prologue dictated by the platform ABI. 
+ // + // /--------------\ + // retval ----------> | return value | + // \--------------/ + // + // /-------------------------------\ + // run_options -----> | xla::ExecutableRunOptions | + // \-------------------------------/ + // + // /---------------------------------------------\ + // params --------> | param 0 | param 1 | ..... | param N-1 | + // | addr | addr | | addr | + // \---------------------------------------------/ + // | | | + // | | | + // V V V + // /---------\ /---------\ /-----------\ + // | param 0 | | param 1 | | param N-1 | + // \---------/ \---------/ \-----------/ + // + // /---------------------------------------------\ + // temps ---------> | temp 0 | temp 1 | ..... | temp N-1 | + // | addr | addr | | addr | + // \---------------------------------------------/ + // | | | + // | | | + // V V V + // /---------\ /---------\ /-----------\ + // | temp 0 | | temp 1 | | temp N-1 | + // \---------/ \---------/ \-----------/ + // + // /--------------------------------------------\ + // dynamic loop bounds -> | outer_dim0_start | outer_dim0_limit | .....| + // (elided for aot) \--------------------------------------------/ + // + // /---------------------------------------------\ + // prof counters -> | counter 0 | counter 1 | ..... | counter N-1 | + // \---------------------------------------------/ + + // Even though the type of params and temps is void** in the host's view, in + // LLVM IR this is represented by i8*, similarly to void*. It's up to the code + // to use GEPs to unravel the indirection layers. + llvm::FunctionType* function_type = llvm::FunctionType::get( + /*Result=*/llvm::Type::getVoidTy(llvm_module_->getContext()), + /*Params=*/ + GetComputeFunctionParams(llvm_module_, num_dynamic_loop_bounds_), + /*isVarArg=*/false); + + // Functions with local linkage get an inlining bonus. Because we know + // a-priori that embedded functions (non-entry functions) will not have its + // name resolved, give it local linkage. 
+ function_ = + llvm_ir::CreateFunction(function_type, linkage, + /*enable_fast_math=*/enable_fast_math, + /*optimize_for_size=*/optimize_for_size_requested, + function_name, llvm_module_); + + // Set meaningful names for the function's arguments: useful for debugging. + llvm::Function::arg_iterator arg_iter = function_->arg_begin(); + arg_iter->setName("retval"); + result_arg_ = &*arg_iter; + (++arg_iter)->setName("run_options"); + exec_run_options_arg_ = &*arg_iter; + (++arg_iter)->setName("params"); + parameters_arg_ = &*arg_iter; + (++arg_iter)->setName("temps"); + temp_buffers_arg_ = &*arg_iter; + if (num_dynamic_loop_bounds_ > 0) { + (++arg_iter)->setName("dynamic_loop_bounds"); + dynamic_loop_bounds_arg_ = &*arg_iter; + } + (++arg_iter)->setName("prof_counters"); + profile_counters_arg_ = &*arg_iter; + + // We know a-priori that the function arguments are guaranteed to point to + // disjoint objects. + llvm::Argument* retval = result_arg(); + for (llvm::Argument& argument : function_->args()) { + // However, the return buffer aliases the temporaries and thus cannot be + // marked noalias. + if (&argument == retval) { + continue; + } + function_->addAttribute(argument.getArgNo() + 1, llvm::Attribute::NoAlias); + } + + ir_builder_->SetInsertPoint(llvm::BasicBlock::Create( + /*Context=*/llvm_module_->getContext(), + /*Name=*/"entry", + /*Parent=*/function_)); +} + +llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) { + CHECK_GT(num_dynamic_loop_bounds_, 0); + CHECK_LT(offset, num_dynamic_loop_bounds_ * 2); + string name = tensorflow::strings::StrCat("dynamic_loop_bound_", offset); + return ir_builder_->CreateLoad( + ir_builder_->CreateGEP(CHECK_NOTNULL(dynamic_loop_bounds_arg_), + ir_builder_->getInt64(offset), AsStringRef(name))); +} + +// Emits code to allocate an array of parameter address pointers, and store +// each address from 'parameter_addresses'. 
+// Returns an array of compute function call arguments (including parameter +// address buffer). +std::vector GetArrayFunctionCallArguments( + tensorflow::gtl::ArraySlice parameter_addresses, + llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece name, + llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg, + llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg) { + llvm::Value* parameter_addresses_buffer = + llvm_ir::EmitAllocaAtFunctionEntryWithCount( + ir_builder->getInt8PtrTy(), + ir_builder->getInt32(parameter_addresses.size()), + tensorflow::strings::StrCat(name, "_parameter_addresses"), + ir_builder); + for (size_t i = 0; i < parameter_addresses.size(); ++i) { + llvm::Value* parameter_as_i8ptr = ir_builder->CreateBitCast( + parameter_addresses[i], ir_builder->getInt8PtrTy(), + AsStringRef(tensorflow::strings::StrCat(name, "_parameter_", i, + "_address_as_i8ptr"))); + llvm::Value* slot_in_param_addresses = ir_builder->CreateInBoundsGEP( + parameter_addresses_buffer, {ir_builder->getInt64(i)}); + ir_builder->CreateStore(parameter_as_i8ptr, slot_in_param_addresses); + } + + const auto to_int8_ptr = [=](llvm::Value* ptr) { + return ir_builder->CreatePointerCast(ptr, ir_builder->getInt8PtrTy()); + }; + std::vector arguments{ + to_int8_ptr(return_value_buffer), to_int8_ptr(exec_run_options_arg), + parameter_addresses_buffer, temp_buffers_arg}; + if (profile_counters_arg != nullptr) { + arguments.push_back(profile_counters_arg); + } + return arguments; +} + +// Emits a call to a runtime fork/join function which dispatches parallel +// calls to 'parallel_function' (and joins threads before returning). +Status EmitCallToParallelForkJoin( + const std::vector& arguments, const Shape& shape, + const std::vector& dimension_partition_counts, + llvm::IRBuilder<>* ir_builder, llvm::Function* parallel_function, + const string& name) { + llvm::Module* module = ir_builder->GetInsertBlock()->getModule(); + + // Build ParallelForkJoin function type. 
+ std::vector compute_function_params = + GetComputeFunctionParams(module, /*num_dynamic_loop_bounds=*/0); + // Number of parallel compute functions. + compute_function_params.push_back(ir_builder->getInt32Ty()); + // Array of partitions. There is an array element for each + // partition x partition_dim x 2 (for dimension start and limit). + compute_function_params.push_back( + llvm::Type::getInt64PtrTy(module->getContext())); + // Number of partitioned most-major dimensions in 'shape'. + compute_function_params.push_back(ir_builder->getInt32Ty()); + // Function pointer for compute function to be dispatched in parallel. + compute_function_params.push_back( + llvm::Type::getInt8PtrTy(module->getContext())); + + llvm::FunctionType* fork_join_type = llvm::FunctionType::get( + /*Result=*/llvm::Type::getVoidTy(module->getContext()), + /*Params=*/compute_function_params, + /*isVarArg=*/false); + + llvm::Function* fork_join_func = + llvm::cast(module->getOrInsertFunction( + runtime::kParallelForkJoinSymbolName, fork_join_type)); + fork_join_func->setCallingConv(llvm::CallingConv::C); + fork_join_func->setDoesNotThrow(); + + // Add common compute function arguments. + std::vector fork_join_arguments(arguments); + + // Create ShapePartitionIterator to generate all partitions of 'shape'. + ShapePartitionIterator partition_iterator(shape, dimension_partition_counts); + const int64 num_partitions = partition_iterator.GetTotalPartitionCount(); + // Add argument specifying the number of parallel partitions. + fork_join_arguments.push_back(ir_builder->getInt32(num_partitions)); + + // The number of partitioned most-major dimensions in 'shape'. + const int32 num_partitioned_dims = dimension_partition_counts.size(); + // A dimension partition consists of two elements: [start_index, limit_index). + const int32 dim_partition_size = 2; + // Calculate array partition stride. 
+ const int32 array_partition_stride = + num_partitioned_dims * dim_partition_size; + // Calculate the total number of elements in the partition array. + const int32 partition_array_size = + dim_partition_size * num_partitioned_dims * num_partitions; + + // Store dimension partition values as llvm constants in 'partitions'. + // See comments in runtime_fork_join.cc for array layout description. + std::vector partitions(partition_array_size); + for (int32 i = 0; i < num_partitions; ++i) { + std::vector> dim_partitions = + partition_iterator.GetPartition(i); + CHECK_EQ(num_partitioned_dims, dim_partitions.size()); + const int32 partition_index = i * array_partition_stride; + for (int32 j = 0; j < num_partitioned_dims; ++j) { + const std::pair& dim_partition = dim_partitions[j]; + const int32 index = partition_index + j * dim_partition_size; + // Store partition [dim_start, dim_limit) intervals for each dimension. + partitions[index] = ir_builder->getInt64(dim_partition.first); + partitions[index + 1] = + ir_builder->getInt64(dim_partition.first + dim_partition.second); + } + } + + // Create global variable out of dimension partitions in 'partitions'. + llvm::ArrayType* partitions_array_type = + llvm::ArrayType::get(ir_builder->getInt64Ty(), partition_array_size); + llvm::Constant* partitions_array = + llvm::ConstantArray::get(partitions_array_type, partitions); + llvm::GlobalVariable* global_partitions_array = new llvm::GlobalVariable( + /*M=*/*module, + /*Ty=*/partitions_array_type, + /*isConstant=*/true, + /*Linkage=*/llvm::GlobalValue::PrivateLinkage, + /*Initializer=*/partitions_array, + /*Name=*/ + AsStringRef( + tensorflow::strings::StrCat(name, "_parallel_dimension_partitions"))); + + // Add argument specifying parallel dimension partitions. + fork_join_arguments.push_back(ir_builder->CreateBitCast( + global_partitions_array, + llvm::Type::getInt64PtrTy(module->getContext()))); + // Add argument specifying the number of partitioned most-major dimensions. 
+ fork_join_arguments.push_back(ir_builder->getInt32(num_partitioned_dims)); + // Add argument for parallel compute function pointer. + fork_join_arguments.push_back( + ir_builder->CreateBitCast(parallel_function, ir_builder->getInt8PtrTy())); + // Emit call to parallel fork/join. + ir_builder->CreateCall(fork_join_func, fork_join_arguments); + + return Status::OK(); +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.h b/tensorflow/compiler/xla/service/cpu/ir_function.h new file mode 100644 index 0000000000000000000000000000000000000000..557aa4a6bfc2ef70cafca4b226f8d8f15ea01e2b --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/ir_function.h @@ -0,0 +1,134 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_ + +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace xla { +namespace cpu { + +// IrFunction creates and encapsulates an llvm::Function, exposing methods to +// emitters for function and function argument access. +// The llvm::Function is created with the standard function signature +// used in the XLA CPU backend (see ir_function.cc for argument details). +// In addtion IrFunction saves the callers IR insert point during contruction, +// and restores it after desctruction. +// +// Example usage: +// +// // Create and initialize new IrFunction. +// std::unique_ptr compute_function(new IrFunction(...)); +// // Emit IR for function body using IrFunction helper methods. +// ... +// // Store reference to llvm::Function for future invocation. +// ir_functions.push_back(compute_function.function()); +// // Delete IrFunction (finalizes IR function and restores caller insertion +// // point). +// compute_function.reset(); +// + +class IrFunction { + public: + IrFunction(const string& function_name, llvm::Function::LinkageTypes linkage, + const bool optimize_for_size_requested, + const bool enable_fast_math, llvm::Module* llvm_module, + llvm::IRBuilder<>* ir_builder, int64 num_dynamic_loop_bounds); + ~IrFunction(); + + // Emit ir to read and return the set of ir values representing the dynamic + // loop bounds argument of this function. 
+ // Each element in returned vector is a pair of ir values representing + // the loop bounds for a specific dimension, where the first element of the + // pair is the dimension start index, and the second element of the pair + // is the dimension limit. + // EX: [dimension_i_index_start_ir_value, dimension_i_index_limit_ir_value] + // + DynamicLoopBounds GetDynamicLoopBounds(); + + // Returns the encapculated llvm::Function. + llvm::Function* function() { return function_; } + + // Get the llvm::Value* that represents this functions "retval" argument. + llvm::Argument* result_arg() { return result_arg_; } + + // Get the xla::ExecutableRunOptions that represents this functions + // "run_options" argument. + llvm::Value* exec_run_options_arg() { return exec_run_options_arg_; } + + // Get the llvm::Value* that represents this functions parameters argument. + llvm::Value* parameters_arg() { return parameters_arg_; } + + // Get the llvm::Value* that represents this functions "temps" argument. + llvm::Value* temp_buffers_arg() { return temp_buffers_arg_; } + + // Get the llvm::Value* that represents this functions "prof_counters" + // argument. + llvm::Value* profile_counters_arg() { return profile_counters_arg_; } + + private: + // Initialize an llvm::Function with standard signature based on arguments. + void Initialize(const string& function_name, + llvm::Function::LinkageTypes linkage, + bool optimize_for_size_requested, bool enable_fast_math); + + // Emit ir to read and return the ir value for the dynamic loop bound at + // 'offset' from the "dynamic_loop_bounds" argument of this function. + llvm::Value* GetDynamicLoopBound(int64 offset); + + llvm::IRBuilder<>* ir_builder_; + llvm::Module* llvm_module_; + llvm::IRBuilder<>::InsertPointGuard caller_insert_point_guard_; + + int64 num_dynamic_loop_bounds_ = 0; + // Encapsulated llvm::Function. + llvm::Function* function_; + // Function argument IR values. 
+ llvm::Argument* result_arg_; + llvm::Value* exec_run_options_arg_; + llvm::Value* parameters_arg_; + llvm::Value* temp_buffers_arg_; + llvm::Value* dynamic_loop_bounds_arg_ = nullptr; + llvm::Value* profile_counters_arg_; +}; + +// Returns an array of compute function call argument ir values. +std::vector GetArrayFunctionCallArguments( + tensorflow::gtl::ArraySlice parameter_addresses, + llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece name, + llvm::Value* return_value_buffer, llvm::Value* exec_run_options_arg, + llvm::Value* temp_buffers_arg, llvm::Value* profile_counters_arg); + +// Emits a call to a runtime fork/join function which dispatches parallel +// calls to 'parallel_function' (and joins threads before returning). +Status EmitCallToParallelForkJoin( + const std::vector& arguments, const Shape& shape, + const std::vector& dimension_partition_counts, + llvm::IRBuilder<>* ir_builder, llvm::Function* parallel_function, + const string& name); + +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_ diff --git a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc index 81c29e4726c7be53b433be896f558f502e43c885..2e5cc96098241415b82f225afc81981f3e1069e0 100644 --- a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc @@ -20,6 +20,8 @@ limitations under the License. 
#include "llvm/IR/Intrinsics.h" #include "llvm/IR/Verifier.h" #include "llvm/Transforms/Utils/Cloning.h" +#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" +#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/platform/logging.h" namespace xla { @@ -28,6 +30,10 @@ namespace runtime { const char* const kTanhV4F32SymbolName = "__xla_cpu_runtime_TanhV4F32"; const char* const kTanhV8F32SymbolName = "__xla_cpu_runtime_TanhV8F32"; +const char* const kExpV4F32SymbolName = "__xla_cpu_runtime_ExpV4F32"; +const char* const kExpV8F32SymbolName = "__xla_cpu_runtime_ExpV8F32"; +const char* const kLogV4F32SymbolName = "__xla_cpu_runtime_LogV4F32AVX"; +const char* const kLogV8F32SymbolName = "__xla_cpu_runtime_LogV8F32AVX"; namespace { llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module, @@ -42,62 +48,257 @@ llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module, } llvm::LLVMContext* context = &module->getContext(); - llvm::Type* float_type = llvm::Type::getFloatTy(*context); - llvm::VectorType* vector_type = - llvm::VectorType::get(float_type, vector_width); llvm::BasicBlock* vector_tanh_body = llvm::BasicBlock::Create(*context, "body", vector_tanh_function); llvm::IRBuilder<> ir_builder(vector_tanh_body); - llvm::FastMathFlags fast_math_flags; fast_math_flags.setFast(); ir_builder.setFastMathFlags(fast_math_flags); + VectorSupportLibrary vsl(F32, vector_width, &ir_builder, "tanh_f32"); + llvm::Value* input = &*vector_tanh_function->arg_begin(); - CHECK_EQ(input->getType(), vector_type); + CHECK_EQ(input->getType(), vsl.vector_type()); // This implements the same rational interpolant as implemented in Eigen3. 
- llvm::Value* input_clamped = llvm_ir::EmitFloatMin( - llvm_ir::EmitFloatMax(input, llvm::ConstantFP::get(vector_type, -9.0), - &ir_builder), - llvm::ConstantFP::get(vector_type, 9.0), &ir_builder); - - std::array numerator_coeffs( - {{-2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f, - 5.12229709037114e-08f, 1.48572235717979e-05f, 6.37261928875436e-04f, - 4.89352455891786e-03f}}); - - std::array denominator_coeffs( - {{1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f, - 4.89352518554385e-03f}}); - - llvm::Value* input_squared = - ir_builder.CreateFMul(input_clamped, input_clamped); - llvm::Value* numerator = - llvm::ConstantFP::get(vector_type, numerator_coeffs[0]); + llvm::Value* input_clamped = + vsl.Clamp(input, /*low=*/GetIeeeF32(-9.0), /*high=*/GetIeeeF32(9.0)); + + std::array numerator_coeffs{ + -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f, + 5.12229709037114e-08f, 1.48572235717979e-05f, 6.37261928875436e-04f, + 4.89352455891786e-03f}; + + std::array denominator_coeffs{ + 1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f, + 4.89352518554385e-03f}; + + llvm::Value* input_squared = vsl.Mul(input_clamped, input_clamped); + llvm::Value* numerator = vsl.SplatFloat(GetIeeeF32(numerator_coeffs[0])); for (int i = 1; i < numerator_coeffs.size(); i++) { - numerator = ir_builder.CreateFAdd( - ir_builder.CreateFMul(input_squared, numerator), - llvm::ConstantFP::get(vector_type, numerator_coeffs[i])); + numerator = + vsl.MulAdd(input_squared, numerator, GetIeeeF32(numerator_coeffs[i])); } - numerator = ir_builder.CreateFMul(input_clamped, numerator); - llvm::Value* denominator = - llvm::ConstantFP::get(vector_type, denominator_coeffs[0]); + numerator = vsl.Mul(input_clamped, numerator); + + llvm::Value* denominator = vsl.SplatFloat(GetIeeeF32(denominator_coeffs[0])); for (int i = 1; i < denominator_coeffs.size(); i++) { - denominator = ir_builder.CreateFAdd( - 
ir_builder.CreateFMul(input_squared, denominator), - llvm::ConstantFP::get(vector_type, denominator_coeffs[i])); + denominator = vsl.MulAdd(input_squared, denominator, + GetIeeeF32(denominator_coeffs[i])); } - llvm::Value* result = ir_builder.CreateFDiv(numerator, denominator); + llvm::Value* result = vsl.Div(numerator, denominator); ir_builder.CreateRet(result); DCHECK(!llvm::verifyFunction(*vector_tanh_function)); return vector_tanh_function; } + +llvm::Function* EmitVectorF32ExpIfNeeded(llvm::Module* module, + llvm::StringRef function_name, + int vector_width, + bool enable_fast_math) { + llvm::Function* vector_exp_function = module->getFunction(function_name); + if (vector_exp_function == nullptr) { + // If the function declaration is not present in the module, there can't be + // any calls to resolve. Don't emit the function in this case. + return nullptr; + } + + llvm::LLVMContext* context = &module->getContext(); + + llvm::BasicBlock* vector_exp_body = + llvm::BasicBlock::Create(*context, "body", vector_exp_function); + + llvm::IRBuilder<> ir_builder(vector_exp_body); + llvm::FastMathFlags fast_math_flags; + fast_math_flags.setFast(); + ir_builder.setFastMathFlags(fast_math_flags); + + VectorSupportLibrary vsl(F32, vector_width, &ir_builder, "exp_f32"); + + // This implements the same polynomial approximation as implemented in Eigen3. 
+ + const llvm::APFloat half = GetIeeeF32(0.5); + const llvm::APFloat one = GetIeeeF32(1.0); + + const llvm::APFloat exp_hi = GetIeeeF32(88.3762626647950); + const llvm::APFloat exp_lo = GetIeeeF32(-88.3762626647949); + + const llvm::APFloat cephes_LOG2EF = GetIeeeF32(1.44269504088896341); + const llvm::APFloat cephes_exp_C1 = GetIeeeF32(0.693359375); + const llvm::APFloat cephes_exp_C2 = GetIeeeF32(-2.12194440e-4); + + const llvm::APFloat cephes_exp_p0 = GetIeeeF32(1.9875691500E-4); + const llvm::APFloat cephes_exp_p1 = GetIeeeF32(1.3981999507E-3); + const llvm::APFloat cephes_exp_p2 = GetIeeeF32(8.3334519073E-3); + const llvm::APFloat cephes_exp_p3 = GetIeeeF32(4.1665795894E-2); + const llvm::APFloat cephes_exp_p4 = GetIeeeF32(1.6666665459E-1); + const llvm::APFloat cephes_exp_p5 = GetIeeeF32(5.0000001201E-1); + + llvm::Value* input = &*vector_exp_function->arg_begin(); + llvm::Value* input_clamped = + vsl.Clamp(input, /*low=*/exp_lo, /*high=*/exp_hi); + llvm::Value* fx = vsl.Floor(vsl.MulAdd(input_clamped, cephes_LOG2EF, half)); + llvm::Value* tmp = vsl.Mul(cephes_exp_C1, fx); + llvm::Value* z = vsl.Mul(cephes_exp_C2, fx); + llvm::Value* x = vsl.Sub(input_clamped, tmp); + x = vsl.Sub(x, z); + z = vsl.Mul(x, x); + + llvm::Value* y = vsl.MulAdd(x, cephes_exp_p0, cephes_exp_p1); + y = vsl.MulAdd(y, x, cephes_exp_p2); + y = vsl.MulAdd(y, x, cephes_exp_p3); + y = vsl.MulAdd(y, x, cephes_exp_p4); + y = vsl.MulAdd(y, x, cephes_exp_p5); + y = vsl.MulAdd(y, z, x); + y = vsl.Add(one, y); + + // VectorSupportLibrary (intentionally) can't juggle more than one type at a + // time so drop down to IRBuilder for this bit. 
+ llvm::Value* vector_constant_0x7f = + ir_builder.CreateVectorSplat(vector_width, ir_builder.getInt32(0x7f)); + llvm::Value* vector_constant_23 = + ir_builder.CreateVectorSplat(vector_width, ir_builder.getInt32(23)); + llvm::Type* i32_vector_type = + llvm::VectorType::get(ir_builder.getInt32Ty(), vector_width); + // fx is clamped so we don't have to worry about it being out of range for + // i32. + llvm::Value* emm0 = ir_builder.CreateFPToSI(fx, i32_vector_type); + emm0 = ir_builder.CreateAdd(emm0, vector_constant_0x7f); + emm0 = ir_builder.CreateShl(emm0, vector_constant_23); + llvm::Value* emm0_f32 = ir_builder.CreateBitCast(emm0, vsl.vector_type()); + + llvm::Value* result = vsl.Max(vsl.Mul(y, emm0_f32), input); + + ir_builder.CreateRet(result); + + DCHECK(!llvm::verifyFunction(*vector_exp_function)); + return vector_exp_function; +} + +llvm::Function* EmitVectorF32LogIfNeeded(llvm::Module* module, + llvm::StringRef function_name, + int vector_width, + bool enable_fast_math) { + llvm::Function* vector_log_function = module->getFunction(function_name); + if (vector_log_function == nullptr) { + // If the function declaration is not present in the module, there can't be + // any calls to resolve. Don't emit the function in this case. + return nullptr; + } + + llvm::LLVMContext* context = &module->getContext(); + + llvm::BasicBlock* vector_log_body = + llvm::BasicBlock::Create(*context, "body", vector_log_function); + + llvm::IRBuilder<> ir_builder(vector_log_body); + llvm::FastMathFlags fast_math_flags; + fast_math_flags.setFast(); + ir_builder.setFastMathFlags(fast_math_flags); + + llvm::Value* input = &*vector_log_function->arg_begin(); + VectorSupportLibrary vsl(F32, vector_width, &ir_builder, "log_f32"); + + const llvm::APFloat half = GetIeeeF32(0.5); + const llvm::APFloat one = GetIeeeF32(1.0); + + // This implements the same polynomial approximation as implemented in Eigen3. 
+ // Returns NaN for x < 0, -INF for x = 0 + const llvm::APFloat cephes_SQRTHF = GetIeeeF32(0.707106781186547524); + const llvm::APFloat cephes_log_p0 = GetIeeeF32(7.0376836292E-2); + const llvm::APFloat cephes_log_p1 = GetIeeeF32(-1.1514610310E-1); + const llvm::APFloat cephes_log_p2 = GetIeeeF32(1.1676998740E-1); + const llvm::APFloat cephes_log_p3 = GetIeeeF32(-1.2420140846E-1); + const llvm::APFloat cephes_log_p4 = GetIeeeF32(+1.4249322787E-1); + const llvm::APFloat cephes_log_p5 = GetIeeeF32(-1.6668057665E-1); + const llvm::APFloat cephes_log_p6 = GetIeeeF32(+2.0000714765E-1); + const llvm::APFloat cephes_log_p7 = GetIeeeF32(-2.4999993993E-1); + const llvm::APFloat cephes_log_p8 = GetIeeeF32(+3.3333331174E-1); + const llvm::APFloat cephes_log_q1 = GetIeeeF32(-2.12194440e-4); + const llvm::APFloat cephes_log_q2 = GetIeeeF32(0.693359375); + + // The smallest non denormalized float number. + const llvm::APFloat min_norm_pos = GetIeeeF32FromBitwiseRep(0x00800000); + const llvm::APFloat minus_inf = GetIeeeF32FromBitwiseRep(0xff800000); + const llvm::APFloat inv_mant_mask = GetIeeeF32FromBitwiseRep(~0x7f800000); + + // invalid_mask is set if x is negative or NaN (and therefore output + // must be NaN). + llvm::Value* invalid_mask = vsl.FCmpULEMask(input, vsl.GetZeroVector()); + llvm::Value* iszero_mask = vsl.FCmpEQMask(input, vsl.GetZeroVector()); + + // Cut off denormalized stuff. + input = vsl.Max(min_norm_pos, input); + + // VectorSupportLibrary (intentionally) can't juggle more than one type at a + // time so drop down to IRBuilder for this bit. 
+ llvm::Value* vector_constant_0x7f = + ir_builder.CreateVectorSplat(vector_width, ir_builder.getInt32(0x7f)); + llvm::Value* vector_constant_23 = + ir_builder.CreateVectorSplat(vector_width, ir_builder.getInt32(23)); + llvm::Type* i32_vector_type = + llvm::VectorType::get(ir_builder.getInt32Ty(), vector_width); + + llvm::Value* emm0 = ir_builder.CreateLShr( + ir_builder.CreateBitCast(input, i32_vector_type), vector_constant_23); + + // Keep only the fractional part. + input = vsl.FloatAnd(input, inv_mant_mask); + input = vsl.FloatOr(input, half); + + emm0 = ir_builder.CreateSub(emm0, vector_constant_0x7f); + llvm::Value* e = + vsl.Add(one, ir_builder.CreateSIToFP(emm0, vsl.vector_type())); + + // part2: + // if( x < SQRTHF ) { + // e -= 1; + // x = x + x - 1.0; + // } else { x = x - 1.0; } + llvm::Value* mask = vsl.FCmpOLTMask(input, cephes_SQRTHF); + llvm::Value* tmp = vsl.FloatAnd(input, mask); + input = vsl.Sub(input, one); + e = vsl.Sub(e, vsl.FloatAnd(mask, one)); + input = vsl.Add(input, tmp); + + llvm::Value* x2 = vsl.Mul(input, input); + llvm::Value* x3 = vsl.Mul(x2, input); + + llvm::Value *y, *y1, *y2; + y = vsl.MulAdd(input, cephes_log_p0, cephes_log_p1); + y1 = vsl.MulAdd(input, cephes_log_p3, cephes_log_p4); + y2 = vsl.MulAdd(input, cephes_log_p6, cephes_log_p7); + y = vsl.MulAdd(y, input, cephes_log_p2); + y1 = vsl.MulAdd(y1, input, cephes_log_p5); + y2 = vsl.MulAdd(y2, input, cephes_log_p8); + y = vsl.MulAdd(y, x3, y1); + y = vsl.MulAdd(y, x3, y2); + y = vsl.Mul(y, x3); + + y1 = vsl.Mul(cephes_log_q1, e); + tmp = vsl.Mul(half, x2); + y = vsl.Add(y, y1); + input = vsl.Sub(input, tmp); + y2 = vsl.Mul(cephes_log_q2, e); + input = vsl.Add(input, y); + input = vsl.Add(input, y2); + + // Negative arg will be NAN, 0 will be -INF. 
+ llvm::Value* or_lhs = + vsl.FloatAndNot(iszero_mask, vsl.FloatOr(input, invalid_mask)); + llvm::Value* or_rhs = vsl.FloatAnd(iszero_mask, minus_inf); + llvm::Value* result = vsl.FloatOr(or_lhs, or_rhs); + + ir_builder.CreateRet(result); + + DCHECK(!llvm::verifyFunction(*vector_log_function)); + return vector_log_function; +} } // namespace void RewriteIRRuntimeFunctions(llvm::Module* module, bool enable_fast_math) { @@ -108,11 +309,28 @@ void RewriteIRRuntimeFunctions(llvm::Module* module, bool enable_fast_math) { EmitVectorF32TanhIfNeeded(module, kTanhV8F32SymbolName, /*vector_width=*/8, enable_fast_math); + auto* exp_v4f32 = + EmitVectorF32ExpIfNeeded(module, kExpV4F32SymbolName, + /*vector_width=*/4, enable_fast_math); + auto* exp_v8f32 = + EmitVectorF32ExpIfNeeded(module, kExpV8F32SymbolName, + /*vector_width=*/8, enable_fast_math); + + auto* log_v4f32 = + EmitVectorF32LogIfNeeded(module, kLogV4F32SymbolName, + /*vector_width=*/4, enable_fast_math); + auto* log_v8f32 = + EmitVectorF32LogIfNeeded(module, kLogV8F32SymbolName, + /*vector_width=*/8, enable_fast_math); + // Gather all the call sites, force inline them and then delete the vector // function bodies. + // + // TODO(b/73081976): Should we avoid inlining these intrinsics in some cases? 
std::vector calls_to_inline; - for (auto* function : {tanh_v4f32, tanh_v8f32}) { + for (auto* function : + {tanh_v4f32, tanh_v8f32, exp_v4f32, exp_v8f32, log_v4f32, log_v8f32}) { if (function != nullptr) { for (auto* user : function->users()) { calls_to_inline.push_back(llvm::cast(user)); @@ -125,7 +343,8 @@ void RewriteIRRuntimeFunctions(llvm::Module* module, bool enable_fast_math) { CHECK(llvm::InlineFunction(call_to_inline, inline_function_info)); } - for (auto* function : {tanh_v4f32, tanh_v8f32}) { + for (auto* function : + {tanh_v4f32, tanh_v8f32, exp_v4f32, exp_v8f32, log_v4f32, log_v8f32}) { if (function != nullptr) { function->eraseFromParent(); } diff --git a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h index 7f31fb98b0d03c16ef40bff9822227e01f6be46b..5553972677512617ccb6ac4f57a4d33400b664e3 100644 --- a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h +++ b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_LLVM_IR_RUNTINE_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_LLVM_IR_RUNTINE_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_LLVM_IR_RUNTIME_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_LLVM_IR_RUNTIME_H_ #include "llvm/IR/Module.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" @@ -25,6 +25,10 @@ namespace runtime { extern const char* const kTanhV4F32SymbolName; extern const char* const kTanhV8F32SymbolName; +extern const char* const kExpV4F32SymbolName; +extern const char* const kExpV8F32SymbolName; +extern const char* const kLogV4F32SymbolName; +extern const char* const kLogV8F32SymbolName; // The following CPU runtime functions have LLVM-IR only implementations: // @@ -40,4 +44,4 @@ void RewriteIRRuntimeFunctions(llvm::Module* module, bool enable_fast_math); } // namespace cpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_LLVM_IR_RUNTINE_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_LLVM_IR_RUNTIME_H_ diff --git a/tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h b/tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h index 2d29550fd5bd659770cc6300e56b57bf1763e671..f8963841158b71a30aa926e3b2b153c42bf78eb1 100644 --- a/tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h +++ b/tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_ORC_JIT_MEMORY_MAPPER_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_ORC_JIT_MEMORY_MAPPER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_ORC_JIT_MEMORY_MAPPER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_ORC_JIT_MEMORY_MAPPER_H_ #include @@ -53,4 +53,4 @@ class Registrar { } // namespace cpu } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_ORC_JIT_MEMORY_MAPPER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_ORC_JIT_MEMORY_MAPPER_H_ diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc index 0077e344e2bd34aa598ee076220fee678f31b4ad..cd997f07890cdc1d9a546ede58cc1d992b6416ae 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc @@ -61,9 +61,9 @@ ParallelCpuExecutable::ParallelCpuExecutable( std::unique_ptr> function_names, std::unordered_map> aligned_constants, - std::unique_ptr hlo_profile_printer, + std::unique_ptr hlo_profile_printer_data, std::unique_ptr hlo_profile_index_map) - : Executable(std::move(hlo_module), std::move(hlo_profile_printer), + : Executable(std::move(hlo_module), std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map)), jit_(std::move(jit)), assignment_(std::move(assignment)), @@ -376,19 +376,6 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions( tensorflow::gtl::ArraySlice arguments, tensorflow::gtl::ArraySlice buffers, HloExecutionProfile* hlo_execution_profile) { - std::vector argument_buffers(arguments.size()); - for (int i = 0; i < arguments.size(); ++i) { - argument_buffers[i] = arguments[i]->buffer(/*index=*/{}); - } - return ExecuteComputeFunctions(run_options, argument_buffers, buffers, - hlo_execution_profile); -} - -Status 
ParallelCpuExecutable::ExecuteComputeFunctions( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, - tensorflow::gtl::ArraySlice buffers, - HloExecutionProfile* hlo_execution_profile) { // Allocate profiling counters for each hlo instruction that we would like to // profile. std::vector* profile_counters = nullptr; @@ -428,8 +415,9 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions( // just copy the existing buffer into the map containing instruction // results.. if (instruction->opcode() == HloOpcode::kParameter) { - InsertOrDie(&results, instruction, - arguments[instruction->parameter_number()].opaque()); + InsertOrDie( + &results, instruction, + arguments[instruction->parameter_number()]->root_buffer().opaque()); } else if (instruction->opcode() == HloOpcode::kConstant) { unsigned char* aligned_data = FindOrDie(aligned_constants_, instruction).get(); @@ -461,69 +449,6 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions( return Status::OK(); } -StatusOr -ParallelCpuExecutable::ExecuteOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, - HloExecutionProfile* hlo_execution_profile) { - se::Stream* stream = run_options->stream(); - DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - VLOG(3) << "ExecuteOnStream arg size: " << arguments.size(); - if (!arguments.empty()) { - VLOG(3) << "ExecuteOnStream arg[0]: " << arguments.at(0).opaque(); - } - - // Allocate the temporary buffers required for the computation. 
- se::StreamExecutor* stream_executor = stream->parent(); - int device_ordinal = stream_executor->device_ordinal(); - int64 buffer_count = assignment_->Allocations().size(); - VLOG(3) << "temp buffer count: " << buffer_count; - - std::vector device_allocations( - assignment_->Allocations().size()); - TF_RETURN_IF_ERROR(AllocateBuffers(memory_allocator, - stream->parent()->device_ordinal(), - &device_allocations)); - - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice, - assignment_->GetUniqueTopLevelOutputSlice()); - const BufferAllocation::Index result_index = result_slice.index(); - VLOG(3) << "result index: " << result_index; - - TF_RETURN_IF_ERROR(ExecuteComputeFunctions( - run_options, arguments, device_allocations, hlo_execution_profile)); - - // Mark the buffers that are actually live (used in the output) when the - // computation finishes executing. - std::unordered_set marked_addresses; - MarkLiveAddressesInOutput(device_allocations[result_index].opaque(), - result_shape(), &marked_addresses); - - VLOG(3) << "Live addresses in output marking found " - << marked_addresses.size() << " addresses:\n" - << tensorflow::str_util::Join( - marked_addresses, ", ", [](string* out, const void* address) { - tensorflow::strings::StrAppend( - out, tensorflow::strings::Printf("%p", address)); - }); - - // Computation is done - deallocate temp buffers. Keep those marked - // live because they are referenced by the output of the computation - // and are needed by the service. They will be deallocated by the - // service. 
- for (size_t i = 0; i < device_allocations.size(); ++i) { - auto alloc = device_allocations[i]; - if (marked_addresses.count(alloc.opaque()) == 0 && - alloc.opaque() != nullptr) { - VLOG(3) << "ParallelCpuExecutable deallocating buffer #" << i << " [" - << alloc.opaque() << "]"; - TF_RETURN_IF_ERROR(memory_allocator->Deallocate(device_ordinal, &alloc)); - } - } - - return device_allocations[result_index]; -} - StatusOr> ParallelCpuExecutable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, @@ -536,9 +461,9 @@ StatusOr> ParallelCpuExecutable::ExecuteOnStream( DeviceMemoryAllocator* memory_allocator = run_options->allocator(); std::vector buffers(assignment_->Allocations().size()); - auto result_buffer = - MakeUnique(result_shape(), stream->parent()->platform(), - stream->parent()->device_ordinal()); + auto result_buffer = MakeUnique( + /*on_host_shape=*/result_shape(), /*on_device_shape=*/result_shape(), + stream->parent()->platform(), stream->parent()->device_ordinal()); TF_RETURN_IF_ERROR(AllocateBuffers( memory_allocator, stream->parent()->device_ordinal(), &buffers)); @@ -549,37 +474,30 @@ StatusOr> ParallelCpuExecutable::ExecuteOnStream( // Copy DeviceMemoryBase values which into the respective location in // ShapedBuffer which is returned to the caller. std::vector buffers_in_result(assignment_->Allocations().size(), false); - TF_RETURN_IF_ERROR( - result_buffer->mutable_shape_index_to_buffer_entry() - ->ForEachMutableElementWithStatus( - [&buffers, &buffers_in_result, &result_buffer, this]( - const ShapeIndex& index, size_t* buffer_entry) { - const auto& sources = - this->GetRootPointsToSet().element(index); - // The points to set is unambiguous so the set should be a - // singleton. 
- CHECK_EQ(1, sources.size()); - const LogicalBuffer* buffer_source = sources[0]; - HloInstruction* src = buffer_source->instruction(); - - // The source for this result buffer can be a nested buffer - // such as a tuple element. - - // The source instruction should have a non-parameter buffer - // assigned. - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, - this->assignment_->GetUniqueSlice( - src, buffer_source->index())); - CHECK(!slice.allocation()->is_entry_computation_parameter()); - - const BufferAllocation::Index buffer_index = slice.index(); - const se::DeviceMemoryBase& buffer = buffers[buffer_index]; - CHECK(!buffer.is_null() || buffer.size() == 0); - *buffer_entry = result_buffer->mutable_buffers()->size(); - result_buffer->mutable_buffers()->push_back(buffer); - buffers_in_result[buffer_index] = true; - return Status::OK(); - })); + TF_RETURN_IF_ERROR(result_buffer->buffers().ForEachMutableElementWithStatus( + [&](const ShapeIndex& index, se::DeviceMemoryBase* device_memory) { + const auto& sources = this->GetRootPointsToSet().element(index); + + // The points to set is unambiguous so the set should be a singleton. + CHECK_EQ(1, sources.size()); + const LogicalBuffer* buffer_source = sources[0]; + HloInstruction* src = buffer_source->instruction(); + + // The source for this result buffer can be a nested buffer such as a + // tuple element. The source instruction should have a non-parameter + // buffer assigned. + TF_ASSIGN_OR_RETURN( + const BufferAllocation::Slice slice, + this->assignment_->GetUniqueSlice(src, buffer_source->index())); + CHECK(!slice.allocation()->is_entry_computation_parameter()); + + const BufferAllocation::Index buffer_index = slice.index(); + const se::DeviceMemoryBase& buffer = buffers[buffer_index]; + CHECK(!buffer.is_null() || buffer.size() == 0); + *device_memory = buffer; + buffers_in_result[buffer_index] = true; + return Status::OK(); + })); // Free all buffers not in the result. 
for (size_t i = 0; i < buffers.size(); ++i) { @@ -595,10 +513,10 @@ StatusOr> ParallelCpuExecutable::ExecuteOnStream( return std::move(result_buffer); } -StatusOr +StatusOr> ParallelCpuExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) { + tensorflow::gtl::ArraySlice arguments) { // TODO(b/30671675): Implement asynchronous execution mode. return Unimplemented( "Asynchronous execution on stream is not yet supported on CPU."); diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h index d65e3f42f3cb34eff005f34b51b81fd5c42974a3..c393e9b8ea39bfb4c605ebba8e2cd29726bc4af9 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h @@ -55,25 +55,18 @@ class ParallelCpuExecutable : public Executable { std::unordered_map> aligned_constants, - std::unique_ptr hlo_profile_printer, + std::unique_ptr hlo_profile_printer_data, std::unique_ptr hlo_profile_index_map); ~ParallelCpuExecutable() override {} - StatusOr ExecuteOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments, - HloExecutionProfile* hlo_execution_profile) override; - StatusOr> ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) override; - StatusOr ExecuteAsyncOnStream( + StatusOr> ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments) override; + tensorflow::gtl::ArraySlice arguments) override; // This should be called after set_ir_module_string. 
const string& ir_module_string() const { return ir_module_string_; } @@ -108,13 +101,6 @@ class ParallelCpuExecutable : public Executable { // Calls the generated functions in 'function_names_', performing the // computation with the given arguments using the supplied buffers. - Status ExecuteComputeFunctions( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments, - tensorflow::gtl::ArraySlice - buffers, - HloExecutionProfile* hlo_execution_profile); Status ExecuteComputeFunctions( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc new file mode 100644 index 0000000000000000000000000000000000000000..1e439cde11cf74272101b80c867a308e51ab26a6 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc @@ -0,0 +1,76 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h" + +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/core/lib/strings/stringprintf.h" + +namespace xla { +namespace cpu { + +ParallelLoopEmitter::ParallelLoopEmitter( + const llvm_ir::ElementGenerator& target_element_generator, + const llvm_ir::IrArray& target_array, + const DynamicLoopBounds* dynamic_loop_bounds, llvm::IRBuilder<>* ir_builder) + : LoopEmitter(target_element_generator, target_array, ir_builder), + dynamic_loop_bounds_(dynamic_loop_bounds) {} + +llvm_ir::IrArray::Index ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( + tensorflow::StringPiece loop_name) { + CHECK(!ShapeUtil::IsTuple(shape_)); + CHECK(!ShapeUtil::IsScalar(shape_)); + + llvm_ir::ForLoopNest loop_nest(loop_name, ir_builder_); + const int64 num_dims = shape_.dimensions_size(); + llvm_ir::IrArray::Index array_index(num_dims); + + // Add loops from outer-most to inner-most dimensions. + for (int i = LayoutUtil::MinorToMajor(shape_).size() - 1; i >= 0; --i) { + const int64 dimension = LayoutUtil::Minor(shape_.layout(), i); + const int bounds_index = num_dims - 1 - i; + if (bounds_index < dynamic_loop_bounds_->size()) { + // Emit dynamic loop bounds for this dimension. Dynamic loop bounds + // are read from ir function dynamic loop bounds argument. + llvm::Value* start_index = (*dynamic_loop_bounds_)[bounds_index].first; + llvm::Value* end_index = (*dynamic_loop_bounds_)[bounds_index].second; + + std::unique_ptr loop = loop_nest.AddLoop( + /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension), + start_index, end_index); + array_index[dimension] = loop->GetIndVarValue(); + } else { + // Emit static loop bounds for this dimension. 
+ std::unique_ptr loop = loop_nest.AddLoop( + /*start_index=*/0, + /*end_index=*/shape_.dimensions(dimension), + /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension)); + array_index[dimension] = loop->GetIndVarValue(); + } + } + // Point IR builder at inner loop BB. + llvm_ir::SetToFirstInsertPoint(loop_nest.GetInnerLoopBodyBasicBlock(), + ir_builder_); + + // Set exit_bb_ to the exit block of the loop nest. + exit_bb_ = loop_nest.GetOuterLoopExitBasicBlock(); + CHECK(exit_bb_ != nullptr); + + return array_index; +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h new file mode 100644 index 0000000000000000000000000000000000000000..ce92e36a944de33b991d97460f0b2e859ad56081 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h @@ -0,0 +1,73 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_LOOP_EMITTER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_LOOP_EMITTER_H_ + +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" + +namespace xla { +namespace cpu { + +// ParallelLoopEmitter emits a loop nest for the target array shape. +// The outer loop bounds of the loop nest are passed as ir values at runtime +// (specified in 'dynamic_loop_bounds'), and the inner loop bounds are static. +// Dynamic loop bounds are specified as an array of dimension index +// [start, limit) pairs of ir values (one for each partitioned outer dimension). +// +// EX: Let 'shape' = [8, 16, 32], with the loop bounds of the two most-major +// dimensions dynamic. Then 'dynamic_loop_bounds' will contain the +// following ir values for the two most-major dimensions: +// [dim0_index_start_ir_value, dim0_index_limit_ir_value] +// [dim1_index_start_ir_value, dim1_index_limit_ir_value] +// +// Code emitted by ParallelLoopEmitter will be called in a multi-threaded +// context where each thread will be assigned a different set of outer dimension +// partitions, and where all threads will collectively iterate over the +// entire target array shape. +// +// Outer dimension partitions can be generated using the ShapePartitionAssigner +// and ShapePartitionIterator utility classes from shape_partition.cc. +// +class ParallelLoopEmitter : public llvm_ir::LoopEmitter { + public: + // Constructs a ParallelLoopEmitter which uses 'target_element_generator' to + // generate elements, 'dynamic_loop_bounds' to set the loop bounds of the + // most-major dimensions, and the shape of 'target_array' to set the static loop + // bounds for the most-minor dimensions. 
+ ParallelLoopEmitter(const llvm_ir::ElementGenerator& target_element_generator, + const llvm_ir::IrArray& target_array, + const DynamicLoopBounds* dynamic_loop_bounds, + llvm::IRBuilder<>* ir_builder); + + ParallelLoopEmitter(const ParallelLoopEmitter&) = delete; + ParallelLoopEmitter& operator=(const ParallelLoopEmitter&) = delete; + ~ParallelLoopEmitter() override = default; + + llvm_ir::IrArray::Index EmitIndexAndSetExitBasicBlock( + tensorflow::StringPiece loop_name) override; + + private: + const DynamicLoopBounds* dynamic_loop_bounds_; +}; + +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_LOOP_EMITTER_H_ diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index 4b44ac8941e222d5954121bbb9654062e41f55d6..deb21bf4ef5895cfdbec5c2449b6ce7b306a7008 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -126,7 +126,7 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( HloInstruction* instruction) { // Currently, we do not assign parallel tasks to instructions with at least // one of the following properties: - // *) Internal threading (library calls to kConv, kDot, and kCustomCall). + // *) Internal threading (library calls to kConv, kDot, kFft, kCustomCall). // *) Emit custom loops (kSelectAndScatter, FusionKind::kTransposeDot). // *) Tuple-shaped. // TODO(b/27458679) Parallelize instructions which are skipped here. 
@@ -137,6 +137,7 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( instruction->opcode() == HloOpcode::kSelectAndScatter || instruction->opcode() == HloOpcode::kGetTupleElement || instruction->opcode() == HloOpcode::kBitcast || + instruction->opcode() == HloOpcode::kFft || (instruction->opcode() == HloOpcode::kConvolution && PotentiallyImplementedAsEigenConvolution(*instruction)) || PotentiallyImplementedAsEigenDot(*instruction) || diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h index 5801ec8d270cdaed7f2f65c24987a9ea643edb02..7140dabe516cd7ea9260456e994e8b63b68c60d6 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_ #include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -99,4 +99,4 @@ class ParallelTaskAssigner : public HloPassInterface { } // namespace cpu } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_ diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fft.cc b/tensorflow/compiler/xla/service/cpu/runtime_fft.cc new file mode 100644 index 0000000000000000000000000000000000000000..848d2d22414e8fc9bca82de90f7676011d8992fd --- /dev/null +++ 
b/tensorflow/compiler/xla/service/cpu/runtime_fft.cc @@ -0,0 +1,37 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_fft.h" + +#define EIGEN_USE_THREADS + +#include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h" +#include "tensorflow/core/platform/dynamic_annotations.h" +#include "tensorflow/core/platform/types.h" + +using tensorflow::int32; +using tensorflow::int64; + +TF_ATTRIBUTE_NO_SANITIZE_MEMORY void __xla_cpu_runtime_EigenFft( + const void* run_options_ptr, void* out, void* operand, int32 fft_type, + int32 fft_rank, int64 input_batch, int64 fft_length0, int64 fft_length1, + int64 fft_length2) { + const xla::ExecutableRunOptions* run_options = + static_cast(run_options_ptr); + tensorflow::xla::EigenFftImpl(*run_options->intra_op_thread_pool(), out, + operand, fft_type, fft_rank, input_batch, + fft_length0, fft_length1, fft_length2); +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fft.h b/tensorflow/compiler/xla/service/cpu/runtime_fft.h new file mode 100644 index 0000000000000000000000000000000000000000..f20c5aa0aa2dcbc700f47c718e75baae18650d1a --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_fft.h @@ -0,0 +1,31 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_H_ + +#include "tensorflow/core/platform/types.h" + +extern "C" { + +extern void __xla_cpu_runtime_EigenFft( + const void* /* xla::ExecutableRunOptions* */ run_options_ptr, void* out, + void* operand, tensorflow::int32 fft_type, tensorflow::int32 fft_rank, + tensorflow::int64 input_batch, tensorflow::int64 fft_length0, + tensorflow::int64 fft_length1, tensorflow::int64 fft_length2); + +} // extern "C" + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_H_ diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h b/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..984cb0616e02475babad7160d0f43bb23de0b50e --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_fft_impl.h @@ -0,0 +1,240 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_IMPL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_IMPL_H_ + +#include + +#include "third_party/eigen3/Eigen/Core" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/numeric_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/types.h" + +// 'tensorflow' namespace is used so that int64 and other types don't require +// qualification. +namespace tensorflow { +namespace xla { + +namespace internal { + +// Computes either a forward or reverse complex-to-complex FFT. +template +void EigenFftC2C(const EigenDevice& device, complex64* out, complex64* operand, + int64 input_batch, int64 fft_length0, int64 fft_length1, + int64 fft_length2) { + // Create the axes (which are always trailing). + const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank); + constexpr auto direction = Forward ? 
Eigen::FFT_FORWARD : Eigen::FFT_REVERSE; + + const std::array fft_shape = { + {fft_length0, fft_length1, fft_length2}}; + + Eigen::DSizes dims; + dims[0] = input_batch; + for (int i = 0; i < FFTRank; i++) { + dims[i + 1] = fft_shape[i]; + } + const Eigen::TensorMap, + Eigen::Aligned> + input(operand, dims); + Eigen::TensorMap, + Eigen::Aligned> + output(out, dims); + output.device(device) = input.template fft(axes); +} + +// Computes a forward real->complex FFT, slicing out redundant negative +// frequencies from the innermost dimension. +template +void EigenFftR2C(const EigenDevice& device, complex64* out, float* operand, + int64 input_batch, int64 fft_length0, int64 fft_length1, + int64 fft_length2) { + const std::array fft_shape = { + {fft_length0, fft_length1, fft_length2}}; + + Eigen::DSizes in_dims; + in_dims[0] = input_batch; + Eigen::DSizes out_dims; + out_dims[0] = input_batch; + TensorShape temp_shape{input_batch}; + for (int i = 0; i < FFTRank; i++) { + in_dims[i + 1] = fft_shape[i]; + out_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i]; + temp_shape.AddDim(fft_shape[i]); + } + const Eigen::TensorMap, + Eigen::Aligned> + input(operand, in_dims); + Eigen::TensorMap, + Eigen::Aligned> + output(out, out_dims); + + // Create the axes (which are always trailing). + const auto axes = Eigen::ArrayXi::LinSpaced(FFTRank, 1, FFTRank); + + // Compute the full FFT using a temporary tensor. + Tensor temp(DataTypeToEnum::v(), temp_shape); + auto full_fft = temp.flat_inner_dims(); + const Eigen::DSizes zero_start_indices; + full_fft.device(device) = + input.template fft(axes); + + // Slice away the negative frequency components. + output.device(device) = full_fft.slice(zero_start_indices, out_dims); +} + +// Computes a reverse complex->real FFT, reconstructing redundant negative +// frequencies using reverse conjugate on innermost dimension after doing IFFT +// on outer dimensions. 
+template +void EigenFftC2R(const EigenDevice& device, float* out, complex64* operand, + int64 input_batch, int64 fft_length0, int64 fft_length1, + int64 fft_length2) { + const std::array fft_shape = { + {fft_length0, fft_length1, fft_length2}}; + + Eigen::DSizes in_dims; + in_dims[0] = input_batch; + Eigen::DSizes out_dims; + out_dims[0] = input_batch; + TensorShape temp_shape{input_batch}; + for (int i = 0; i < FFTRank; i++) { + in_dims[i + 1] = i == FFTRank - 1 ? fft_shape[i] / 2 + 1 : fft_shape[i]; + out_dims[i + 1] = fft_shape[i]; + temp_shape.AddDim(fft_shape[i]); + } + const Eigen::TensorMap, + Eigen::Aligned> + input(operand, in_dims); + Eigen::TensorMap, + Eigen::Aligned> + output(out, out_dims); + + // Calculate the shape of the temporary tensor for the full FFT and the + // region we will slice from input given fft_shape. We slice input to + // fft_shape on its inner-most dimensions, except the last (which we + // slice to fft_shape[-1] / 2 + 1). + Tensor temp(DataTypeToEnum::v(), temp_shape); + auto full_fft = temp.flat_inner_dims(); + + // Calculate the starting point and range of the source of + // negative frequency part. + auto neg_sizes = in_dims; + neg_sizes[FFTRank] = fft_shape[FFTRank - 1] - in_dims[FFTRank]; + Eigen::DSizes neg_target_indices; + neg_target_indices[FFTRank] = in_dims[FFTRank]; + + const Eigen::DSizes zero_start_indices; + Eigen::DSizes neg_start_indices; + neg_start_indices[FFTRank] = 1; + + full_fft.slice(zero_start_indices, in_dims).device(device) = input; + + // First, conduct IFFTs on outer dimensions. We save computation (and + // avoid touching uninitialized memory) by slicing full_fft to the + // subregion we wrote input to. 
+ if (FFTRank > 1) { + const auto outer_axes = + Eigen::ArrayXi::LinSpaced(FFTRank - 1, 1, FFTRank - 1); + full_fft.slice(zero_start_indices, in_dims).device(device) = + full_fft.slice(zero_start_indices, in_dims) + .template fft(outer_axes); + } + + // Reconstruct the full FFT by appending reversed and conjugated + // spectrum as the negative frequency part. + Eigen::array reverse_last_axis; + for (auto i = 0; i <= FFTRank; i++) { + reverse_last_axis[i] = i == FFTRank; + } + + if (neg_sizes[FFTRank] != 0) { + full_fft.slice(neg_target_indices, neg_sizes).device(device) = + full_fft.slice(neg_start_indices, neg_sizes) + .reverse(reverse_last_axis) + .conjugate(); + } + + auto inner_axis = Eigen::array{FFTRank}; + output.device(device) = + full_fft.template fft(inner_axis); +} + +template +void EigenFftWithRank(const EigenDevice& device, void* out, void* operand, + int32 fft_type, int64 input_batch, int64 fft_length0, + int64 fft_length1, int64 fft_length2) { + CHECK(::xla::FftType_IsValid(fft_type)) << fft_type; + switch (fft_type) { + case ::xla::FftType::FFT: + EigenFftC2C( + device, static_cast(out), + static_cast(operand), input_batch, fft_length0, + fft_length1, fft_length2); + break; + case ::xla::FftType::IFFT: + EigenFftC2C( + device, static_cast(out), + static_cast(operand), input_batch, fft_length0, + fft_length1, fft_length2); + break; + case ::xla::FftType::RFFT: + EigenFftR2C( + device, static_cast(out), static_cast(operand), + input_batch, fft_length0, fft_length1, fft_length2); + break; + case ::xla::FftType::IRFFT: + EigenFftC2R( + device, static_cast(out), static_cast(operand), + input_batch, fft_length0, fft_length1, fft_length2); + break; + default: + LOG(FATAL) << "Unsupported FFT type: " << fft_type; + } +} + +} // namespace internal + +template +void EigenFftImpl(const EigenDevice& device, void* out, void* operand, + int32 fft_type, int32 fft_rank, int64 input_batch, + int64 fft_length0, int64 fft_length1, int64 fft_length2) { + switch 
(fft_rank) { + case 1: + internal::EigenFftWithRank<1, EigenDevice>( + device, out, operand, fft_type, input_batch, fft_length0, 0, 0); + break; + case 2: + internal::EigenFftWithRank<2, EigenDevice>(device, out, operand, fft_type, + input_batch, fft_length0, + fft_length1, 0); + break; + case 3: + internal::EigenFftWithRank<3, EigenDevice>(device, out, operand, fft_type, + input_batch, fft_length0, + fft_length1, fft_length2); + break; + default: + LOG(FATAL) << "Unsupported FFT rank " << fft_rank; + } +} + +} // namespace xla +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FFT_IMPL_H_ diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h index fcf1cc62078d3847435a2e75e3ca9d109cf8b200..1cf0ec6e3df400e35fa4e755a0b25b4ce7966e8f 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FORK_JOIN_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FORK_JOIN_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FORK_JOIN_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FORK_JOIN_H_ #include "tensorflow/core/platform/types.h" @@ -30,4 +30,4 @@ extern void __xla_cpu_runtime_ParallelForkJoin( } // extern "C" -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FORK_JOIN_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FORK_JOIN_H_ diff --git a/tensorflow/compiler/xla/service/cpu/runtime_matvec.h b/tensorflow/compiler/xla/service/cpu/runtime_matvec.h index cb7e0a81f09e2702de565012e1fcac8b7cd841ab..1bd8dfb377acc1f7cfbe9a92773f87f0ef25de3a 100644 --- a/tensorflow/compiler/xla/service/cpu/runtime_matvec.h +++ b/tensorflow/compiler/xla/service/cpu/runtime_matvec.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATVEC_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATVEC_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATVEC_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATVEC_H_ #include "tensorflow/core/platform/types.h" @@ -42,4 +42,4 @@ void EigenMatVecF64(double* out, double* lhs, double* rhs, tensorflow::int64 m, } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATVEC_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_MATVEC_H_ diff --git a/tensorflow/compiler/xla/service/cpu/shape_partition.h b/tensorflow/compiler/xla/service/cpu/shape_partition.h index 7a2d00421cfdc8e41ec48698a16665621de16bda..33d02b70e61e3311c9af934e80874939fbe3adae 100644 --- a/tensorflow/compiler/xla/service/cpu/shape_partition.h +++ b/tensorflow/compiler/xla/service/cpu/shape_partition.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SHAPE_PARTITION_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SHAPE_PARTITION_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SHAPE_PARTITION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SHAPE_PARTITION_H_ #include @@ -102,4 +102,4 @@ class ShapePartitionIterator { } // namespace cpu } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SHAPE_PARTITION_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_SHAPE_PARTITION_H_ diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index cda2783307925b77ac6d8cfe679c5b325db2befc..64d3a51f41676bbb4b59c9d272d22f52a87a0559 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -15,29 +15,28 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" -#include #include #include #include #include #include "llvm/ExecutionEngine/ExecutionEngine.h" +#include "llvm/ExecutionEngine/JITSymbol.h" #include "llvm/ExecutionEngine/SectionMemoryManager.h" #include "llvm/IR/Mangler.h" #include "llvm/Support/CodeGen.h" #include "llvm/Support/Host.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" -#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h" -#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h" -#include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h" #include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" #include "tensorflow/compiler/xla/service/cpu/orc_jit_memory_mapper.h" #include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_fft.h" #include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h" 
#include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" +#include "tensorflow/compiler/xla/service/cpu/windows_compatibility.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/logging.h" @@ -46,7 +45,7 @@ namespace cpu { namespace { // A simple SymbolResolver that delegates to the host dynamic linker. -class SimpleResolver : public llvm::JITSymbolResolver { +class SimpleResolver : public llvm::LegacyJITSymbolResolver { public: explicit SimpleResolver(ExternalConstantPool* external_constant_pool) : external_constant_pool_(external_constant_pool) {} @@ -99,15 +98,6 @@ llvm::StringRef GetHostCpuName() { cpu_name.consume_back("-avx512"); return cpu_name; } - -CompilerFunctor::VectorIntrinsics GetAvailableIntrinsics() { - CompilerFunctor::VectorIntrinsics intrinsics; - intrinsics.sse_intrinsics = (&__xla_cpu_runtime_ExpV4F32SSE != nullptr); - intrinsics.avx_intrinsics = (&__xla_cpu_runtime_ExpV8F32AVX != nullptr); - intrinsics.neon_intrinsics = (&__xla_cpu_runtime_ExpV4F32NEON != nullptr); - return intrinsics; -} - } // namespace SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options, @@ -126,34 +116,57 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options, /*MAttrs=*/DetectMachineAttributes()))), disassembler_(*target_machine_), data_layout_(target_machine_->createDataLayout()), - object_layer_([] { - return std::make_shared( - orc_jit_memory_mapper::GetInstance()); - }), - compile_layer_( - object_layer_, - CompilerFunctor(target_machine_.get(), &disassembler_, opt_level, - optimize_for_size, enable_fast_math, - disable_expensive_passes, GetAvailableIntrinsics(), - std::move(pre_optimization_hook), - std::move(post_optimization_hook))) { + execution_session_(string_pool_), + symbol_resolver_(llvm::orc::createLegacyLookupResolver( + 
[this](const std::string& name) -> llvm::JITSymbol { + if (const uint8* from_constant_pool = + external_constant_pool_.Find(string(name))) { + return llvm::JITEvaluatedSymbol( + reinterpret_cast(from_constant_pool), + llvm::JITSymbolFlags::None); + } + + void* func_addr = CustomCallTargetRegistry::Global()->Lookup(name); + if (func_addr == nullptr) { + return nullptr; + } + llvm::JITEvaluatedSymbol symbol_info( + reinterpret_cast(func_addr), + llvm::JITSymbolFlags::None); + return symbol_info; + }, + [](llvm::Error Err) { + cantFail(std::move(Err), "lookupFlags failed"); + })), + object_layer_( + execution_session_, + [](llvm::orc::VModuleKey) { + return std::make_shared( + orc_jit_memory_mapper::GetInstance()); + }, + [this](llvm::orc::VModuleKey K) { return symbol_resolver_; }), + compile_layer_(object_layer_, + CompilerFunctor(target_machine_.get(), &disassembler_, + opt_level, optimize_for_size, + enable_fast_math, disable_expensive_passes, + std::move(pre_optimization_hook), + std::move(post_optimization_hook))) { VLOG(1) << "CPU target: " << target_machine_->getTargetCPU().str() << " features: " << target_machine_->getTargetFeatureString().str(); } -SimpleOrcJIT::ModuleHandleT SimpleOrcJIT::AddModule( +SimpleOrcJIT::VModuleKeyT SimpleOrcJIT::AddModule( std::unique_ptr module) { - auto handle = cantFail(compile_layer_.addModule( - std::move(module), MakeUnique(external_constant_pool()))); - module_handles_.push_back(handle); - return handle; + auto key = execution_session_.allocateVModule(); + cantFail(compile_layer_.addModule(key, std::move(module))); + module_keys_.push_back(key); + return key; } -void SimpleOrcJIT::RemoveModule(SimpleOrcJIT::ModuleHandleT handle) { - module_handles_.erase( - std::remove(module_handles_.begin(), module_handles_.end(), handle), - module_handles_.end()); - cantFail(compile_layer_.removeModule(handle)); +void SimpleOrcJIT::RemoveModule(SimpleOrcJIT::VModuleKeyT key) { + module_keys_.erase(std::remove(module_keys_.begin(), 
module_keys_.end(), key), + module_keys_.end()); + cantFail(compile_layer_.removeModule(key)); } llvm::JITSymbol SimpleOrcJIT::FindSymbol(const std::string& name) { @@ -165,10 +178,10 @@ llvm::JITSymbol SimpleOrcJIT::FindSymbol(const std::string& name) { // Resolve symbol from last module to first, allowing later redefinitions of // symbols shadow earlier ones. - for (auto& handle : - llvm::make_range(module_handles_.rbegin(), module_handles_.rend())) { + for (auto& key : + llvm::make_range(module_keys_.rbegin(), module_keys_.rend())) { if (auto symbol = - compile_layer_.findSymbolIn(handle, mangled_name, + compile_layer_.findSymbolIn(key, mangled_name, /*ExportedSymbolsOnly=*/true)) { return symbol; } @@ -196,17 +209,12 @@ bool RegisterKnownJITSymbols() { REGISTER_CPU_RUNTIME_SYMBOL(AcquireInfeedBufferForDequeue); REGISTER_CPU_RUNTIME_SYMBOL(AcquireOutfeedBufferForPopulation); REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF32); + REGISTER_CPU_RUNTIME_SYMBOL(EigenFft); REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF64); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF32); REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF64); - REGISTER_CPU_RUNTIME_SYMBOL(ExpV4F32NEON); - REGISTER_CPU_RUNTIME_SYMBOL(ExpV4F32SSE); - REGISTER_CPU_RUNTIME_SYMBOL(ExpV8F32AVX); - REGISTER_CPU_RUNTIME_SYMBOL(LogV4F32NEON); - REGISTER_CPU_RUNTIME_SYMBOL(LogV4F32SSE); - REGISTER_CPU_RUNTIME_SYMBOL(LogV8F32AVX); REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin); REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue); REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation); @@ -253,15 +261,15 @@ bool RegisterKnownJITSymbols() { REGISTER_LIBM_SYMBOL(ilogb, int (*)(double)); REGISTER_LIBM_SYMBOL(ldexp, double (*)(double, int)); REGISTER_LIBM_SYMBOL(lgamma, double (*)(double)); - REGISTER_LIBM_SYMBOL(llrint, long long (*)(double)); - REGISTER_LIBM_SYMBOL(llround, long long (*)(double)); + 
REGISTER_LIBM_SYMBOL(llrint, long long (*)(double)); // NOLINT(runtime/int) + REGISTER_LIBM_SYMBOL(llround, long long (*)(double)); // NOLINT(runtime/int) REGISTER_LIBM_SYMBOL(log, double (*)(double)); REGISTER_LIBM_SYMBOL(log10, double (*)(double)); REGISTER_LIBM_SYMBOL(log1p, double (*)(double)); REGISTER_LIBM_SYMBOL(log2, double (*)(double)); REGISTER_LIBM_SYMBOL(logb, double (*)(double)); - REGISTER_LIBM_SYMBOL(lrint, long (*)(double)); - REGISTER_LIBM_SYMBOL(lround, long (*)(double)); + REGISTER_LIBM_SYMBOL(lrint, long (*)(double)); // NOLINT(runtime/int) + REGISTER_LIBM_SYMBOL(lround, long (*)(double)); // NOLINT(runtime/int) REGISTER_LIBM_SYMBOL(modf, double (*)(double, double*)); REGISTER_LIBM_SYMBOL(nan, double (*)(const char*)); REGISTER_LIBM_SYMBOL(nearbyint, double (*)(double)); @@ -272,10 +280,15 @@ bool RegisterKnownJITSymbols() { REGISTER_LIBM_SYMBOL(remquo, double (*)(double, double, int*)); REGISTER_LIBM_SYMBOL(rint, double (*)(double)); REGISTER_LIBM_SYMBOL(round, double (*)(double)); - REGISTER_LIBM_SYMBOL(scalbln, double (*)(double, long)); + REGISTER_LIBM_SYMBOL(scalbln, + double (*)(double, long)); // NOLINT(runtime/int) REGISTER_LIBM_SYMBOL(scalbn, double (*)(double, int)); REGISTER_LIBM_SYMBOL(sin, double (*)(double)); +#ifdef __APPLE__ + REGISTER_LIBM_SYMBOL(__sincos, void (*)(double, double*, double*)); +#else REGISTER_LIBM_SYMBOL(sincos, void (*)(double, double*, double*)); +#endif REGISTER_LIBM_SYMBOL(sinh, double (*)(double)); REGISTER_LIBM_SYMBOL(sqrt, double (*)(double)); REGISTER_LIBM_SYMBOL(tan, double (*)(double)); diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h index ded01e9e4d7442296f7406dd035e6ab385458238..50993afc8f73617a2c65310ae73b3ab00519f550 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h @@ -21,8 +21,10 @@ limitations under the License. 
#include #include "llvm/ADT/Triple.h" +#include "llvm/ExecutionEngine/Orc/Core.h" #include "llvm/ExecutionEngine/Orc/IRCompileLayer.h" #include "llvm/ExecutionEngine/Orc/RTDyldObjectLinkingLayer.h" +#include "llvm/ExecutionEngine/Orc/SymbolStringPool.h" #include "llvm/IR/Module.h" #include "llvm/Target/TargetMachine.h" #include "tensorflow/compiler/xla/service/cpu/compiler_functor.h" @@ -48,7 +50,7 @@ class SimpleOrcJIT { std::function( llvm::Module&)>; using CompileLayerT = llvm::orc::IRCompileLayer; - using ModuleHandleT = CompileLayerT::ModuleHandleT; + using VModuleKeyT = llvm::orc::VModuleKey; // Create a new JIT, targeting the host architecture. // The |target_options| parameter allows customization of certain code @@ -78,12 +80,12 @@ class SimpleOrcJIT { return target_machine_->getTargetTriple(); } - // Add a module to the JIT. Returns an opaque handle that can be used to later + // Add a module to the JIT. Returns an opaque key that can be used to later // remove this module. - ModuleHandleT AddModule(std::unique_ptr module); + VModuleKeyT AddModule(std::unique_ptr module); // Remove a module from the JIT and free the memory associated with it. - void RemoveModule(ModuleHandleT handle); + void RemoveModule(VModuleKeyT key); // Get the runtime address of the compiled symbol whose name is given. Returns // nullptr if the symbol cannot be found. 
@@ -96,10 +98,13 @@ class SimpleOrcJIT { } private: - std::vector module_handles_; + std::vector module_keys_; std::unique_ptr target_machine_; const Disassembler disassembler_; const llvm::DataLayout data_layout_; + llvm::orc::SymbolStringPool string_pool_; + llvm::orc::ExecutionSession execution_session_; + std::shared_ptr symbol_resolver_; ObjLayerT object_layer_; CompileLayerT compile_layer_; ExternalConstantPool external_constant_pool_; diff --git a/tensorflow/compiler/xla/service/cpu/target_machine_features.cc b/tensorflow/compiler/xla/service/cpu/target_machine_features.cc new file mode 100644 index 0000000000000000000000000000000000000000..eeb049737dddd11ef2ce229df772baec3ac03dd8 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/target_machine_features.cc @@ -0,0 +1,35 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" + +namespace xla { +namespace cpu { + +llvm::TargetTransformInfo* TargetMachineFeatures::GetTargetTransformInfoFor( + const llvm::Function& function) const { + auto it = target_transform_info_cache_.find(&function); + if (it == target_transform_info_cache_.end()) { + auto emplace_result = target_transform_info_cache_.emplace( + &function, target_machine_->getTargetTransformInfo(function)); + CHECK(emplace_result.second); + it = emplace_result.first; + } + + return &it->second; +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/target_machine_features.h b/tensorflow/compiler/xla/service/cpu/target_machine_features.h new file mode 100644 index 0000000000000000000000000000000000000000..703942615e552dccde7ddec8c8b90e8a486652af --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/target_machine_features.h @@ -0,0 +1,84 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_H_ + +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Target/TargetMachine.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/core/lib/gtl/flatmap.h" + +namespace xla { +namespace cpu { + +// Wraps an llvm::TargetMachine and parses out some information that feeds into +// LLVM IR code generation decisions. +class TargetMachineFeatures { + public: + static constexpr int kX86AvxVectorByteSize = 32; + + TargetMachineFeatures(llvm::TargetMachine* target_machine) + : target_machine_(target_machine) {} + + // Return the vectorization factor, which is the number of bytes of data + // explicitly vectorized routines will try to process at once. + int vectorization_factor_in_bytes() const { + // Ideally this should be a function of the cache line size (which we can + // get from llvm::TargetTransformInfo::getCacheLineSize) of the target + // machine. Guess a value of 128 bytes for now. + return 128; + } + + // Return the size of the largest vector size in bytes. We need to pass in + // "function" since llvm functions can contain annotations for specializing + // them to specific micro-architectures (though currently XLA does not use + // this functionality). + int vector_register_byte_size(const llvm::Function& function) const { + llvm::TargetTransformInfo* tti = GetTargetTransformInfoFor(function); + return tti->getRegisterBitWidth(/*Vector=*/true) / 8; + } + + // Return the number of elements of type `type` that can fit into the largest + // vector register available. We need to pass in "function" since llvm + // functions can contain annotations for specializing them to specific + // micro-architectures (though currently XLA does not use this functionality). 
+ int vector_register_num_elements(const llvm::Function& function, + PrimitiveType type) const { + return vector_register_byte_size(function) / + (primitive_util::BitWidth(type) / 8); + } + + private: + llvm::TargetTransformInfo* GetTargetTransformInfoFor( + const llvm::Function& function) const; + + // This cache saves us from having to create a llvm::TargetTransformInfo for + // every call to GetTargetTransformInfoFor (creating a TargetTransformInfo + // costs one heap allocation on X86). + // + // This is mutated from within `GetTargetTransformInfoFor` which is + // semantically a getter (and thus `const`); and is therefore declared + // mutable. Making this mutable is okay because it has cache semantics. + mutable tensorflow::gtl::FlatMap + target_transform_info_cache_; + llvm::TargetMachine* target_machine_; +}; + +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_H_ diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.cc b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc new file mode 100644 index 0000000000000000000000000000000000000000..150db1cb6edec1af6724a8bca6a5f6272f1a7416 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.cc @@ -0,0 +1,424 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" + +#include "llvm/Support/raw_ostream.h" +#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" + +namespace xla { +namespace cpu { +VectorSupportLibrary::VectorSupportLibrary(PrimitiveType primitive_type, + int64 vector_size, + llvm::IRBuilder<>* ir_builder, + std::string name) + : vector_size_(vector_size), + primitive_type_(primitive_type), + ir_builder_(ir_builder), + name_(std::move(name)) { + scalar_type_ = llvm_ir::PrimitiveTypeToIrType( + primitive_type, ir_builder_->GetInsertBlock()->getModule()); + scalar_pointer_type_ = llvm::PointerType::getUnqual(scalar_type_); + vector_type_ = llvm::VectorType::get(scalar_type_, vector_size); + vector_pointer_type_ = llvm::PointerType::getUnqual(vector_type_); +} + +static string TypeToString(llvm::Type* type) { + std::string o; + llvm::raw_string_ostream ostream(o); + type->print(ostream); + return ostream.str(); +} + +void VectorSupportLibrary::AssertCorrectTypes( + std::initializer_list values) { + for (llvm::Value* v : values) { + llvm::Type* type = v->getType(); + if (type != scalar_type() && type != vector_type()) { + LOG(FATAL) << "Expected either " << TypeToString(scalar_type()) << " or " + << TypeToString(vector_type()) << " but got " + << TypeToString(type); + } + } +} + +llvm::Value* VectorSupportLibrary::Mul(llvm::Value* lhs, llvm::Value* rhs) { + AssertCorrectTypes({lhs, rhs}); + return MulInternal(lhs, rhs); +} + +llvm::Value* VectorSupportLibrary::MulInternal(llvm::Value* lhs, + llvm::Value* rhs) { + if (scalar_type_->isFloatingPointTy()) { + return ir_builder()->CreateFMul(lhs, rhs, name()); + } else { + return ir_builder()->CreateMul(lhs, rhs, name()); + } +} + +llvm::Value* VectorSupportLibrary::Add(llvm::Value* lhs, llvm::Value* rhs) { + AssertCorrectTypes({lhs, 
rhs}); + return AddInternal(lhs, rhs); +} + +llvm::Value* VectorSupportLibrary::Sub(llvm::Value* lhs, llvm::Value* rhs) { + AssertCorrectTypes({lhs, rhs}); + return ir_builder()->CreateFSub(lhs, rhs); +} + +llvm::Value* VectorSupportLibrary::Max(llvm::Value* lhs, llvm::Value* rhs) { + AssertCorrectTypes({lhs, rhs}); + if (scalar_type_->isFloatingPointTy()) { + return llvm_ir::EmitFloatMax(lhs, rhs, ir_builder_); + } else { + LOG(FATAL) << "Max for integers is unimplemented"; + } +} + +llvm::Value* VectorSupportLibrary::Floor(llvm::Value* a) { + AssertCorrectTypes({a}); + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::floor, {a}, + {a->getType()}, ir_builder()); +} + +llvm::Value* VectorSupportLibrary::Div(llvm::Value* lhs, llvm::Value* rhs) { + AssertCorrectTypes({lhs, rhs}); + if (scalar_type_->isFloatingPointTy()) { + return ir_builder()->CreateFDiv(lhs, rhs, name()); + } else { + LOG(FATAL) << "Division for integers is unimplemented"; + } +} + +llvm::Value* VectorSupportLibrary::Clamp(llvm::Value* a, + const llvm::APFloat& low, + const llvm::APFloat& high) { + AssertCorrectTypes({a}); + llvm::Type* type = a->getType(); + CHECK(low.compare(high) == llvm::APFloat::cmpLessThan); + CHECK(scalar_type_->isFloatingPointTy()); + return llvm_ir::EmitFloatMin( + llvm_ir::EmitFloatMax(a, GetConstantFloat(type, low), ir_builder_), + GetConstantFloat(type, high), ir_builder_); +} + +llvm::Value* VectorSupportLibrary::FCmpEQMask(llvm::Value* lhs, + llvm::Value* rhs) { + AssertCorrectTypes({lhs, rhs}); + return I1ToFloat(ir_builder()->CreateFCmpOEQ(lhs, rhs, name())); +} + +llvm::Value* VectorSupportLibrary::FCmpOLTMask(llvm::Value* lhs, + llvm::Value* rhs) { + AssertCorrectTypes({lhs, rhs}); + return I1ToFloat(ir_builder()->CreateFCmpOLT(lhs, rhs, name())); +} + +llvm::Value* VectorSupportLibrary::FCmpULEMask(llvm::Value* lhs, + llvm::Value* rhs) { + AssertCorrectTypes({lhs, rhs}); + return I1ToFloat(ir_builder()->CreateFCmpULE(lhs, rhs, name())); +} + +llvm::Value* 
VectorSupportLibrary::I1ToFloat(llvm::Value* i1) { + bool is_vector = llvm::isa(i1->getType()); + llvm::Type* integer_type = IntegerTypeForFloatSize(is_vector); + return ir_builder()->CreateBitCast( + ir_builder()->CreateSExt(i1, integer_type, name()), + is_vector ? vector_type() : scalar_type(), name()); +} + +llvm::Type* VectorSupportLibrary::IntegerTypeForFloatSize(bool vector) { + CHECK(scalar_type()->isFloatingPointTy()); + const llvm::DataLayout& data_layout = + ir_builder()->GetInsertBlock()->getModule()->getDataLayout(); + int64 float_size_bits = data_layout.getTypeSizeInBits(scalar_type()); + llvm::Type* scalar_int_type = ir_builder()->getIntNTy(float_size_bits); + if (vector) { + return llvm::VectorType::get(scalar_int_type, vector_size()); + } else { + return scalar_int_type; + } +} + +llvm::Value* VectorSupportLibrary::BroadcastScalar(llvm::Value* x) { + CHECK_EQ(x->getType(), scalar_type()); + return ir_builder()->CreateVectorSplat(vector_size(), x, name()); +} + +llvm::Value* VectorSupportLibrary::FloatAnd(llvm::Value* lhs, + llvm::Value* rhs) { + AssertCorrectTypes({lhs, rhs}); + llvm::Type* int_type = + IntegerTypeForFloatSize(lhs->getType() == vector_type()); + return ir_builder()->CreateBitCast( + ir_builder()->CreateAnd( + ir_builder()->CreateBitCast(lhs, int_type, name()), + ir_builder()->CreateBitCast(rhs, int_type, name()), name()), + lhs->getType()); +} + +llvm::Value* VectorSupportLibrary::FloatNot(llvm::Value* lhs) { + AssertCorrectTypes({lhs}); + llvm::Type* int_type = + IntegerTypeForFloatSize(lhs->getType() == vector_type()); + return ir_builder()->CreateBitCast( + ir_builder()->CreateNot( + ir_builder()->CreateBitCast(lhs, int_type, name()), name()), + lhs->getType()); +} + +llvm::Value* VectorSupportLibrary::FloatOr(llvm::Value* lhs, llvm::Value* rhs) { + AssertCorrectTypes({lhs, rhs}); + llvm::Type* int_type = + IntegerTypeForFloatSize(lhs->getType() == vector_type()); + return ir_builder()->CreateBitCast( +
ir_builder()->CreateOr(ir_builder()->CreateBitCast(lhs, int_type, name()), + ir_builder()->CreateBitCast(rhs, int_type, name()), + name()), + lhs->getType(), name()); +} + +llvm::Value* VectorSupportLibrary::AddInternal(llvm::Value* lhs, + llvm::Value* rhs) { + if (scalar_type_->isFloatingPointTy()) { + return ir_builder()->CreateFAdd(lhs, rhs, name()); + } else { + return ir_builder()->CreateAdd(lhs, rhs, name()); + } +} + +llvm::Value* VectorSupportLibrary::ComputeOffsetPointer( + llvm::Value* base_pointer, llvm::Value* offset_elements) { + if (base_pointer->getType() != scalar_pointer_type()) { + base_pointer = ir_builder()->CreateBitCast(base_pointer, + scalar_pointer_type(), name()); + } + return ir_builder()->CreateInBoundsGEP(base_pointer, {offset_elements}, + name()); +} + +llvm::Value* VectorSupportLibrary::LoadVector(llvm::Value* pointer) { + if (pointer->getType() != vector_pointer_type()) { + pointer = + ir_builder()->CreateBitCast(pointer, vector_pointer_type(), name()); + } + return ir_builder()->CreateAlignedLoad( + pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_), name()); +} + +llvm::Value* VectorSupportLibrary::LoadScalar(llvm::Value* pointer) { + if (pointer->getType() != scalar_pointer_type()) { + pointer = + ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name()); + } + return ir_builder()->CreateAlignedLoad( + pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_), name()); +} + +void VectorSupportLibrary::StoreVector(llvm::Value* value, + llvm::Value* pointer) { + AssertCorrectTypes({value}); + if (pointer->getType() != vector_pointer_type()) { + pointer = ir_builder()->CreateBitCast(pointer, vector_pointer_type()); + } + ir_builder()->CreateAlignedStore( + value, pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_)); +} + +void VectorSupportLibrary::StoreScalar(llvm::Value* value, + llvm::Value* pointer) { + AssertCorrectTypes({value}); + if (pointer->getType() != scalar_pointer_type()) { + pointer = +
ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name()); + } + ir_builder()->CreateAlignedStore( + value, pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_)); +} + +llvm::Value* VectorSupportLibrary::LoadBroadcast(llvm::Value* pointer) { + if (pointer->getType() != scalar_pointer_type()) { + pointer = + ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name()); + } + return ir_builder()->CreateVectorSplat( + vector_size(), ir_builder()->CreateLoad(pointer), name()); +} + +llvm::Value* VectorSupportLibrary::AddReduce(llvm::Value* vector) { + llvm::SmallVector mask(vector_size(), nullptr); + for (unsigned i = vector_size(); i != 1; i >>= 1) { + // On every iteration, we shuffle half of the remaining lanes to the top + // half of shuffle, and add two old and the new vector. + + for (unsigned j = 0; j < vector_size(); ++j) { + if (j < (i / 2)) { + mask[j] = ir_builder()->getInt32(i / 2 + j); + } else { + mask[j] = llvm::UndefValue::get(ir_builder()->getInt32Ty()); + } + } + + llvm::Value* half_remaining_lanes = ir_builder()->CreateShuffleVector( + vector, llvm::UndefValue::get(vector_type()), + llvm::ConstantVector::get(mask), ""); + vector = Add(vector, half_remaining_lanes); + } + + return ir_builder()->CreateExtractElement(vector, ir_builder()->getInt32(0), + name()); +} + +llvm::Value* VectorSupportLibrary::AvxStyleHorizontalAdd(llvm::Value* lhs, + llvm::Value* rhs) { + CHECK_EQ(lhs->getType(), vector_type()); + CHECK_EQ(rhs->getType(), vector_type()); + CHECK_EQ(vector_size() % 2, 0); + + llvm::SmallVector mask_a, mask_b; + + // Adding the values shuffled using mask_a and mask_b gives us the + // AVX-style horizontal add we want. 
The masks work as documented + // in https://llvm.org/docs/LangRef.html#shufflevector-instruction + // + // Here are the masks for vector_size() == 8: + // + // index: |0 |1 |2 | 3 |4 |5 | 6 | 7 + // --------+--+--+--+---+--+--+---+--- + // mask_a: |0 |2 |8 |10 |4 |6 |12 |14 + // mask_b: |1 |3 |9 |11 |5 |7 |13 |15 + // + // So, as an example, the value at lane 3 of the result vector is + // the result of adding lane 10 and lane 11 in the combined lhs++rhs + // vector, which are the lanes 2 and 3 in the rhs vector. + for (int i = 0; i < vector_size(); i += 2) { + int increment = i < vector_size() / 2 ? 0 : (vector_size() / 2); + mask_a.push_back(ir_builder()->getInt32(increment + i)); + mask_b.push_back(ir_builder()->getInt32(increment + i + 1)); + } + for (int i = 0; i < vector_size(); i += 2) { + int increment = i < vector_size() / 2 ? (vector_size() / 2) : vector_size(); + mask_a.push_back(ir_builder()->getInt32(increment + i)); + mask_b.push_back(ir_builder()->getInt32(increment + i + 1)); + } + + llvm::Value* shuffle_0 = ir_builder()->CreateShuffleVector( + lhs, rhs, llvm::ConstantVector::get(mask_a)); + llvm::Value* shuffle_1 = ir_builder()->CreateShuffleVector( + lhs, rhs, llvm::ConstantVector::get(mask_b)); + + return Add(shuffle_0, shuffle_1); +} + +llvm::Value* VectorSupportLibrary::ExtractLowHalf(llvm::Value* vector) { + llvm::SmallVector mask; + for (int i = 0; i < vector_size() / 2; i++) { + mask.push_back(ir_builder()->getInt32(i)); + } + + return ir_builder()->CreateShuffleVector(vector, + llvm::UndefValue::get(vector_type()), + llvm::ConstantVector::get(mask)); +} + +llvm::Value* VectorSupportLibrary::ExtractHighHalf(llvm::Value* vector) { + llvm::SmallVector mask; + for (int i = 0; i < vector_size() / 2; i++) { + mask.push_back(ir_builder()->getInt32(i + vector_size() / 2)); + } + + return ir_builder()->CreateShuffleVector(vector, + llvm::UndefValue::get(vector_type()), + llvm::ConstantVector::get(mask)); +} + +std::vector
VectorSupportLibrary::ComputeHorizontalSums( + std::vector vectors, llvm::Value* init_values) { + const int x86_avx_vector_elements = + TargetMachineFeatures::kX86AvxVectorByteSize / scalar_byte_size(); + if (vector_size() == x86_avx_vector_elements && + vectors.size() == x86_avx_vector_elements) { + return ComputeAvxOptimizedHorizontalSums(std::move(vectors), init_values); + } + + std::vector result; + std::transform(vectors.begin(), vectors.end(), std::back_inserter(result), + [this](llvm::Value* vector) { return AddReduce(vector); }); + if (init_values) { + for (int64 i = 0, e = result.size(); i < e; i++) { + result[i] = Add(result[i], ir_builder()->CreateExtractElement( + init_values, ir_builder()->getInt32(i))); + } + } + return result; +} + +std::vector +VectorSupportLibrary::ComputeAvxOptimizedHorizontalSums( + std::vector vectors, llvm::Value* init_values) { + while (vectors.size() != 2) { + std::vector new_vectors; + for (int i = 0; i < vectors.size(); i += 2) { + new_vectors.push_back(AvxStyleHorizontalAdd(vectors[i], vectors[i + 1])); + } + + vectors = std::move(new_vectors); + } + + llvm::Value* low = + AddInternal(ExtractLowHalf(vectors[0]), ExtractHighHalf(vectors[0])); + if (init_values) { + low = AddInternal(ExtractLowHalf(init_values), low); + } + llvm::Value* high = + AddInternal(ExtractLowHalf(vectors[1]), ExtractHighHalf(vectors[1])); + if (init_values) { + high = AddInternal(ExtractHighHalf(init_values), high); + } + + std::vector results; + for (int i = 0; i < vector_size(); i++) { + llvm::Value* scalar_result = ir_builder()->CreateExtractElement( + i < vector_size() / 2 ?
low : high, ir_builder()->getInt32(i % (vector_size() / 2)), name()); + results.push_back(scalar_result); + } + + return results; +} + +llvm::Value* VectorSupportLibrary::GetZeroVector() { + return llvm::Constant::getNullValue(vector_type()); +} + +llvm::Value* VectorSupportLibrary::GetZeroScalar() { + return llvm::Constant::getNullValue(scalar_type()); +} + +LlvmVariable::LlvmVariable(llvm::Type* type, llvm::IRBuilder<>* ir_builder) + : ir_builder_(ir_builder) { + alloca_ = llvm_ir::EmitAllocaAtFunctionEntry(type, "", ir_builder_); +} + +llvm::Value* LlvmVariable::Get() const { + return ir_builder_->CreateLoad(alloca_); +} + +void LlvmVariable::Set(llvm::Value* new_value) { + ir_builder_->CreateStore(new_value, alloca_); +} +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/vector_support_library.h b/tensorflow/compiler/xla/service/cpu/vector_support_library.h new file mode 100644 index 0000000000000000000000000000000000000000..6479bf76aab581ae3ec2923d98dab53720cab203 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/vector_support_library.h @@ -0,0 +1,317 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_VECTOR_SUPPORT_LIBRARY_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_VECTOR_SUPPORT_LIBRARY_H_ + +#include + +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace cpu { + +// Simple wrappers around llvm::APFloat::APFloat to make the calling code more +// obvious. + +inline llvm::APFloat GetIeeeF32(float f) { return llvm::APFloat(f); } +inline llvm::APFloat GetIeeeF32FromBitwiseRep(int32 bitwise_value) { + return llvm::APFloat(llvm::APFloat::IEEEsingle(), + llvm::APInt(/*numBits=*/32, /*val=*/bitwise_value)); +} + +// A thin wrapper around llvm_util.h to make code generating vector math flow +// more readable. +class VectorSupportLibrary { + public: + // This VectorSupportLibrary instance remembers `primitive_type` and + // `vector_size`, and these are implicitly used by the methods on this + // instance (i.e. LoadVector will load a vector of type <`vector_size` x + // `primitive_type`>). + VectorSupportLibrary(PrimitiveType primitive_type, int64 vector_size, + llvm::IRBuilder<>* ir_builder, std::string name); + + llvm::Value* Mul(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* Mul(int64 lhs, llvm::Value* rhs) { + return Mul(ir_builder()->getInt64(lhs), rhs); + } + llvm::Value* Mul(const llvm::APFloat& lhs, llvm::Value* rhs) { + return Mul(GetConstantFloat(rhs->getType(), lhs), rhs); + } + + // If your call resolved to these then you probably wanted the versions taking + // APFloat. 
+ llvm::Value* Mul(double lhs, llvm::Value* rhs) = delete; + llvm::Value* Mul(float lhs, llvm::Value* rhs) = delete; + + llvm::Value* Add(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* Add(int64 lhs, llvm::Value* rhs) { + return Add(ir_builder()->getInt64(lhs), rhs); + } + llvm::Value* Add(const llvm::APFloat& lhs, llvm::Value* rhs) { + return Add(GetConstantFloat(rhs->getType(), lhs), rhs); + } + + // If your call resolved to these then you probably wanted the versions taking + // APFloat. + llvm::Value* Add(double lhs, llvm::Value* rhs) = delete; + llvm::Value* Add(float lhs, llvm::Value* rhs) = delete; + + llvm::Value* Sub(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* Sub(llvm::Value* lhs, const llvm::APFloat& rhs) { + return Sub(lhs, GetConstantFloat(lhs->getType(), rhs)); + } + llvm::Value* Max(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* Max(const llvm::APFloat& lhs, llvm::Value* rhs) { + return Max(GetConstantFloat(rhs->getType(), lhs), rhs); + } + llvm::Value* Div(llvm::Value* lhs, llvm::Value* rhs); + + llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, llvm::Value* c) { + return Add(c, Mul(a, b)); + } + + llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, const llvm::APFloat& c) { + return Add(GetConstantFloat(a->getType(), c), Mul(a, b)); + } + + llvm::Value* MulAdd(llvm::Value* a, const llvm::APFloat& b, + const llvm::APFloat& c) { + return Add(GetConstantFloat(a->getType(), c), + Mul(a, GetConstantFloat(a->getType(), b))); + } + + llvm::Value* Floor(llvm::Value* a); + + llvm::Value* Clamp(llvm::Value* a, const llvm::APFloat& low, + const llvm::APFloat& high); + llvm::Value* SplatFloat(const llvm::APFloat& d) { + return GetConstantFloat(vector_type(), d); + } + + // These compare instructions return a floating point typed mask instead of an + // i1. For instance, on a vector typed input, lanes where the predicate is + // true get a float with all ones and other lanes get a float with all zeros.
+ // This is slightly odd from the perspective of LLVM's type system, but it + // makes kernel IR generation code written using VectorSupportLibrary (its + // raison d'etre) less cluttered. + + llvm::Value* FCmpEQMask(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* FCmpULEMask(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* FCmpOLTMask(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* FCmpOLTMask(llvm::Value* lhs, const llvm::APFloat& rhs) { + return FCmpOLTMask(lhs, GetConstantFloat(lhs->getType(), rhs)); + } + + // These boolean operations operate on the bitwise values of the floating + // point inputs. They return a (vector of) float(s) but like in the mask + // generating predicates above this type system oddity makes the kernel IR + // generation code less cluttered. + llvm::Value* FloatAnd(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* FloatAnd(llvm::Value* lhs, const llvm::APFloat& rhs) { + return FloatAnd(lhs, GetConstantFloat(lhs->getType(), rhs)); + } + llvm::Value* FloatOr(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* FloatOr(llvm::Value* lhs, const llvm::APFloat& rhs) { + return FloatOr(lhs, GetConstantFloat(lhs->getType(), rhs)); + } + llvm::Value* FloatNot(llvm::Value* lhs); + llvm::Value* FloatAndNot(llvm::Value* lhs, llvm::Value* rhs) { + return FloatAnd(FloatNot(lhs), rhs); + } + + llvm::Value* BroadcastScalar(llvm::Value* x); + llvm::Value* BroadcastScalar(const llvm::APFloat& d) { + return BroadcastScalar(GetConstantFloat(scalar_type(), d)); + } + + llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer, + llvm::Value* offset_elements); + llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer, + int64 offset_elements) { + return ComputeOffsetPointer(base_pointer, + ir_builder()->getInt64(offset_elements)); + } + + llvm::Value* LoadVector(llvm::Value* pointer); + + llvm::Value* LoadVector(llvm::Value* base_pointer, + llvm::Value* offset_elements) { + return LoadVector(ComputeOffsetPointer(base_pointer, 
offset_elements)); + } + + llvm::Value* LoadVector(llvm::Value* base_pointer, int64 offset_elements) { + return LoadVector(base_pointer, ir_builder()->getInt64(offset_elements)); + } + + llvm::Value* LoadScalar(llvm::Value* pointer); + + llvm::Value* LoadScalar(llvm::Value* base_pointer, + llvm::Value* offset_elements) { + return LoadScalar(ComputeOffsetPointer(base_pointer, offset_elements)); + } + + llvm::Value* LoadScalar(llvm::Value* base_pointer, int64 offset_elements) { + return LoadScalar(base_pointer, ir_builder()->getInt64(offset_elements)); + } + + void StoreVector(llvm::Value* value, llvm::Value* pointer); + + void StoreVector(llvm::Value* value, llvm::Value* base_pointer, + llvm::Value* offset_elements) { + StoreVector(value, ComputeOffsetPointer(base_pointer, offset_elements)); + } + + void StoreVector(llvm::Value* value, llvm::Value* base_pointer, + int64 offset_elements) { + StoreVector(value, base_pointer, ir_builder()->getInt64(offset_elements)); + } + + void StoreScalar(llvm::Value* value, llvm::Value* pointer); + void StoreScalar(llvm::Value* value, llvm::Value* base_pointer, + llvm::Value* offset_elements) { + StoreScalar(value, ComputeOffsetPointer(base_pointer, offset_elements)); + } + + void StoreScalar(llvm::Value* value, llvm::Value* base_pointer, + int64 offset_elements) { + StoreScalar(value, base_pointer, ir_builder()->getInt64(offset_elements)); + } + + llvm::Value* LoadBroadcast(llvm::Value* pointer); + llvm::Value* LoadBroadcast(llvm::Value* base_pointer, + llvm::Value* offset_elements) { + return LoadBroadcast(ComputeOffsetPointer(base_pointer, offset_elements)); + } + llvm::Value* LoadBroadcast(llvm::Value* base_pointer, int64 offset_elements) { + return LoadBroadcast(base_pointer, ir_builder()->getInt64(offset_elements)); + } + + // Compute the horizontal sum of each vector in `vectors`. The i'th element + // in the result vector is the (scalar) horizontal sum of the i'th vector in + // `vectors`.
If `init_values` is not nullptr then the value in the i'th lane + // in `init_values` is added to the i'th horizontal sum. + std::vector ComputeHorizontalSums( + std::vector vectors, llvm::Value* init_values = nullptr); + + llvm::Value* GetZeroVector(); + llvm::Value* GetZeroScalar(); + + llvm::IRBuilder<>* ir_builder() const { return ir_builder_; } + int64 vector_size() const { return vector_size_; } + llvm::Type* vector_type() const { return vector_type_; } + llvm::Type* vector_pointer_type() const { return vector_pointer_type_; } + llvm::Type* scalar_type() const { return scalar_type_; } + llvm::Type* scalar_pointer_type() const { return scalar_pointer_type_; } + int64 scalar_byte_size() const { + return primitive_util::BitWidth(primitive_type_) / 8; + } + + const std::string& name() const { return name_; } + + private: + llvm::Value* ExtractLowHalf(llvm::Value*); + llvm::Value* ExtractHighHalf(llvm::Value*); + + llvm::Value* MulInternal(llvm::Value* lhs, llvm::Value* rhs); + llvm::Value* AddInternal(llvm::Value* lhs, llvm::Value* rhs); + + llvm::Value* AddReduce(llvm::Value* vector); + + // Checks that each value in `values` is either of type scalar_type() or + // vector_type(). This LOG(FATAL)'s so it should only be called in cases + // where a mismatching type is a programmer bug. + void AssertCorrectTypes(std::initializer_list values); + + // Perform an X86 AVX style horizontal add between `lhs` and `rhs`. The + // resulting IR for an 8-float wide vector is expected to lower to a single + // vhaddps instruction on a CPU that supports vhaddps, and not be too bad in + // other cases. 
+ // + // For a vector width of 8, the result vector is computed as: + // Result[0] = Lhs[0] + Lhs[1] + // Result[1] = Lhs[2] + Lhs[3] + // Result[2] = Rhs[0] + Rhs[1] + // Result[3] = Rhs[2] + Rhs[3] + // Result[4] = Lhs[4] + Lhs[5] + // Result[5] = Lhs[6] + Lhs[7] + // Result[6] = Rhs[4] + Rhs[5] + // Result[7] = Rhs[6] + Rhs[7] + llvm::Value* AvxStyleHorizontalAdd(llvm::Value* lhs, llvm::Value* rhs); + + std::vector ComputeAvxOptimizedHorizontalSums( + std::vector vectors, llvm::Value* init_values); + + llvm::Type* IntegerTypeForFloatSize(bool vector); + llvm::Value* I1ToFloat(llvm::Value* i1); + llvm::Value* GetConstantFloat(llvm::Type* type, const llvm::APFloat& f) { + llvm::Constant* scalar_value = llvm::ConstantFP::get(type->getContext(), f); + if (llvm::isa(type)) { + return llvm::ConstantVector::getSplat(vector_size(), scalar_value); + } + return scalar_value; + } + + int64 vector_size_; + PrimitiveType primitive_type_; + llvm::IRBuilder<>* ir_builder_; + llvm::Type* vector_type_; + llvm::Type* vector_pointer_type_; + llvm::Type* scalar_type_; + llvm::Type* scalar_pointer_type_; + std::string name_; +}; + +// This wraps an alloca-backed stack variable which LLVM's SSA construction pass +// can later convert to a SSA value. 
+class LlvmVariable { + public: + LlvmVariable(llvm::Type*, llvm::IRBuilder<>* ir_builder); + + llvm::Value* Get() const; + void Set(llvm::Value* new_value); + + private: + llvm::AllocaInst* alloca_; + llvm::IRBuilder<>* ir_builder_; +}; + +class VectorVariable : public LlvmVariable { + public: + VectorVariable(VectorSupportLibrary* vector_support, + llvm::Value* initial_value) + : LlvmVariable(vector_support->vector_type(), + vector_support->ir_builder()) { + Set(initial_value); + } +}; + +class ScalarVariable : public LlvmVariable { + public: + ScalarVariable(VectorSupportLibrary* vector_support, + llvm::Value* initial_value) + : LlvmVariable(vector_support->scalar_type(), + vector_support->ir_builder()) { + Set(initial_value); + } +}; +} // namespace cpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_VECTOR_SUPPORT_LIBRARY_H_ diff --git a/tensorflow/compiler/xla/service/cpu/windows_compatibility.cc b/tensorflow/compiler/xla/service/cpu/windows_compatibility.cc new file mode 100644 index 0000000000000000000000000000000000000000..ab308ee6cb16ba95e24694b59a4b5737765bbb8b --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/windows_compatibility.cc @@ -0,0 +1,32 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/windows_compatibility.h" + +#ifdef _MSC_VER + +#include + +void sincos(double x, double *sinv, double *cosv) { + *sinv = sin(x); + *cosv = cos(x); +} + +void sincosf(float x, float *sinv, float *cosv) { + *sinv = sinf(x); + *cosv = cosf(x); +} + +#endif // _MSC_VER diff --git a/tensorflow/compiler/xla/service/cpu/windows_compatibility.h b/tensorflow/compiler/xla/service/cpu/windows_compatibility.h new file mode 100644 index 0000000000000000000000000000000000000000..262f379d8b6017f4a7e0156b724bfee7e8ec5b9a --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/windows_compatibility.h @@ -0,0 +1,31 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_WINDOWS_COMPATIBILITY_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_WINDOWS_COMPATIBILITY_H_ + +#ifdef _MSC_VER + +extern "C" { + +// MSVC does not have sincos[f]. 
+void sincos(double x, double *sinv, double *cosv); +void sincosf(float x, float *sinv, float *cosv); + +} + +#endif // _MSC_VER + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_WINDOWS_COMPATIBILITY_H_ diff --git a/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc b/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc index d0f214202908266371639af8f431ad8269ad0e35..47543b2082f55cf7b8cf60f1c5bbb16a0a609912 100644 --- a/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/xfeed_manager.cc @@ -41,6 +41,8 @@ void XfeedQueueManager::EnqueueBuffersAtomically( tensorflow::mutex_lock l(mu_); bool was_empty = enqueued_buffers_.empty(); for (XfeedBuffer* b : buffers) { + VLOG(3) << "Enqueueing " << queue_name_ << " buffer (of " << buffers.size() + << " buffers) with length: " << b->length(); enqueued_buffers_.push_back(b); } if (was_empty && !buffers.empty()) { @@ -54,9 +56,11 @@ void XfeedQueueManager::EnqueueBuffersAtomically( XfeedBuffer* XfeedQueueManager::BlockingDequeueBuffer() { tensorflow::mutex_lock l(mu_); + VLOG(3) << "Waiting for an available buffer."; while (enqueued_buffers_.empty()) { cv_.wait(l); } + VLOG(3) << "A buffer is available!"; CHECK(current_buffer_ == nullptr); current_buffer_ = enqueued_buffers_.front(); enqueued_buffers_.pop_front(); @@ -65,6 +69,9 @@ XfeedBuffer* XfeedQueueManager::BlockingDequeueBuffer() { void XfeedQueueManager::ReleaseCurrentBuffer(int32 length, void* data, StatusOr shape) { + VLOG(3) << "Releasing buffer with shape: " + << (shape.ok() ? 
ShapeUtil::HumanString(shape.ValueOrDie()) + : ""); tensorflow::mutex_lock l(mu_); CHECK(current_buffer_ != nullptr); CHECK_EQ(length, current_buffer_->length()); diff --git a/tensorflow/compiler/xla/service/cpu/xfeed_manager.h b/tensorflow/compiler/xla/service/cpu/xfeed_manager.h index 6af55700052007a2ca419d52b63dddea2052bd0b..b4ace232607e14fbfec01d48946f0031d96cd027 100644 --- a/tensorflow/compiler/xla/service/cpu/xfeed_manager.h +++ b/tensorflow/compiler/xla/service/cpu/xfeed_manager.h @@ -50,7 +50,7 @@ class XfeedBuffer { // Reusable component for managing the infeed and outfeed queue state. class XfeedQueueManager { public: - XfeedQueueManager() = default; + XfeedQueueManager(string queue_name) : queue_name_(queue_name) {} // Calls the completion callback for any enqueued buffers that have // not been dequeued by the runtime, and empties the @@ -86,6 +86,8 @@ class XfeedQueueManager { void ReleaseCurrentBuffer(int32 length, void* data, StatusOr shape); private: + const string queue_name_; + tensorflow::mutex mu_; // Condition variable that is signaled every time a buffer is @@ -112,8 +114,8 @@ class XfeedManager { XfeedQueueManager* outfeed() { return &outfeed_; } private: - XfeedQueueManager infeed_; - XfeedQueueManager outfeed_; + XfeedQueueManager infeed_ = {"infeed"}; + XfeedQueueManager outfeed_ = {"outfeed"}; }; } // namespace runtime diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.cc b/tensorflow/compiler/xla/service/device_memory_allocator.cc index 2e4b0a5230516b5308aeed892de9a49565a09f2e..78e7aa48accdbb51a8477455f5f9c004828c068f 100644 --- a/tensorflow/compiler/xla/service/device_memory_allocator.cc +++ b/tensorflow/compiler/xla/service/device_memory_allocator.cc @@ -24,7 +24,7 @@ limitations under the License. 
namespace xla { StreamExecutorMemoryAllocator::StreamExecutorMemoryAllocator( - perftools::gputools::Platform* platform, + const perftools::gputools::Platform* platform, tensorflow::gtl::ArraySlice stream_executors) : DeviceMemoryAllocator(platform), diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.h b/tensorflow/compiler/xla/service/device_memory_allocator.h index 00caefab667cba6abfef200050ca18f229fc0320..39dfad84c1c1c1c461c24de555ecd919cea47d83 100644 --- a/tensorflow/compiler/xla/service/device_memory_allocator.h +++ b/tensorflow/compiler/xla/service/device_memory_allocator.h @@ -33,7 +33,7 @@ class DeviceMemoryAllocator { public: // Parameter platform indicates which platform the allocator allocates memory // on. Must be non-null. - explicit DeviceMemoryAllocator(perftools::gputools::Platform* platform) + explicit DeviceMemoryAllocator(const perftools::gputools::Platform* platform) : platform_(platform) {} virtual ~DeviceMemoryAllocator() {} @@ -49,14 +49,14 @@ class DeviceMemoryAllocator { int device_ordinal, perftools::gputools::DeviceMemoryBase* mem) = 0; // Return the platform that the allocator allocates memory on. - perftools::gputools::Platform* platform() const { return platform_; } + const perftools::gputools::Platform* platform() const { return platform_; } // Can we call Deallocate() as soon as a computation has been scheduled on // a stream, or do we have to wait for the computation to complete first? 
virtual bool AllowsAsynchronousDeallocation() const = 0; protected: - perftools::gputools::Platform* platform_; + const perftools::gputools::Platform* platform_; }; // Default memory allocator for a platform which uses @@ -64,7 +64,7 @@ class DeviceMemoryAllocator { class StreamExecutorMemoryAllocator : public DeviceMemoryAllocator { public: StreamExecutorMemoryAllocator( - perftools::gputools::Platform* platform, + const perftools::gputools::Platform* platform, tensorflow::gtl::ArraySlice stream_executors); diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 91086fd4a5f68211ef56c2417bb0ef4a38de2cff..a803b3171f9afa6297553c5507c4f9aa45e420ab 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -103,6 +103,7 @@ class DfsHloVisitorBase { return HandleElementwiseBinary(hlo); } virtual Status HandleConvolution(HloInstructionPtr hlo) = 0; + virtual Status HandleFft(HloInstructionPtr fft) = 0; virtual Status HandleCrossReplicaSum(HloInstructionPtr hlo) = 0; virtual Status HandleCompare(HloInstructionPtr hlo) { return HandleElementwiseBinary(hlo); @@ -247,6 +248,10 @@ class DfsHloVisitorBase { // affecting correctness. void ReserveVisitStates(int num) { visit_state_.Reserve(num); } + // Useful when we want to visit the same computation more than once with the + // same visitor. 
+ void ResetVisitStates() { visit_state_.Reset(); } + void SetVisitState(int id, VisitState state) { visit_state_.SetState(id, state); } @@ -326,6 +331,7 @@ class DfsHloVisitorBase { *w = (*w & ~mask) | (static_cast(state) << shift); DCHECK_EQ(GetState(id), state); } + void Reset() { states_.clear(); } private: static const uint32 kStatesPerWord = sizeof(uint64) / 2 /*bits per entry*/; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 133aa2509405738de8388708b0c61a82023e2738..170adb3d241b3648bc53f96dde9866f0b794f80a 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -85,6 +85,9 @@ class DfsHloVisitorWithDefaultBase Status HandleConvolution(HloInstructionPtr convolution) override { return DefaultAction(convolution); } + Status HandleFft(HloInstructionPtr fft) override { + return DefaultAction(fft); + } Status HandleCrossReplicaSum(HloInstructionPtr crs) override { return DefaultAction(crs); } diff --git a/tensorflow/compiler/xla/service/dot_decomposer.cc b/tensorflow/compiler/xla/service/dot_decomposer.cc new file mode 100644 index 0000000000000000000000000000000000000000..12faed69677cd99c6ed82c8d13dad3138d9461b7 --- /dev/null +++ b/tensorflow/compiler/xla/service/dot_decomposer.cc @@ -0,0 +1,185 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/dot_decomposer.h" + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +namespace { + +// TODO(b/69062148) Remove this code when all backends support BatchDot +// natively. +Status DecomposeBatchDot(HloInstruction* dot) { + auto computation = dot->parent(); + const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); + HloInstruction* lhs = dot->mutable_operand(0); + HloInstruction* rhs = dot->mutable_operand(1); + const Shape& lhs_shape = lhs->shape(); + const Shape& rhs_shape = rhs->shape(); + const Shape& dot_shape = dot->shape(); + + // ShapeInference should guarantee that lhs/rhs batch dimensions match. + CHECK_EQ(dnums.lhs_batch_dimensions_size(), + dnums.rhs_batch_dimensions_size()); + const int64 num_batch_dims = dnums.lhs_batch_dimensions_size(); + // Calculate total batch size (note that ShapeInference requires that + // the batch dimensions are most-major). + int64 batch_size = 1; + for (int i = 0; i < num_batch_dims; ++i) { + CHECK_EQ(lhs_shape.dimensions(dnums.lhs_batch_dimensions(i)), + rhs_shape.dimensions(dnums.rhs_batch_dimensions(i))); + batch_size *= lhs_shape.dimensions(dnums.lhs_batch_dimensions(i)); + } + + // Set lhs/rhs_transpose. 
+ CHECK_EQ(1, dnums.lhs_contracting_dimensions_size()); + const int64 lhs_contracting_dim_number = dnums.lhs_contracting_dimensions(0); + const bool lhs_transpose = (lhs_contracting_dim_number - num_batch_dims) == 0; + + CHECK_EQ(1, dnums.rhs_contracting_dimensions_size()); + const int64 rhs_contracting_dim_number = dnums.rhs_contracting_dimensions(0); + const bool rhs_transpose = (rhs_contracting_dim_number - num_batch_dims) == 1; + + // Compute R3 and R3 shapes for lhs. + PrimitiveType lhs_type = lhs_shape.element_type(); + const int64 lhs_rows = lhs_shape.dimensions(num_batch_dims + 0); + const int64 lhs_cols = lhs_shape.dimensions(num_batch_dims + 1); + Shape lhs_shape_r3 = + ShapeUtil::MakeShape(lhs_type, {batch_size, lhs_rows, lhs_cols}); + Shape lhs_slice_shape_r3 = + ShapeUtil::MakeShape(lhs_type, {1, lhs_rows, lhs_cols}); + Shape lhs_slice_shape_r2 = + ShapeUtil::MakeShape(lhs_type, {lhs_rows, lhs_cols}); + + // Compute R3 and R3 shapes for rhs. + PrimitiveType rhs_type = rhs_shape.element_type(); + const int64 rhs_rows = rhs_shape.dimensions(num_batch_dims + 0); + const int64 rhs_cols = rhs_shape.dimensions(num_batch_dims + 1); + Shape rhs_shape_r3 = + ShapeUtil::MakeShape(rhs_type, {batch_size, rhs_rows, rhs_cols}); + Shape rhs_slice_shape_r3 = + ShapeUtil::MakeShape(rhs_type, {1, rhs_rows, rhs_cols}); + Shape rhs_slice_shape_r2 = + ShapeUtil::MakeShape(rhs_type, {rhs_rows, rhs_cols}); + + // Compute R3 and R3 shapes for dot output. + PrimitiveType dot_type = dot_shape.element_type(); + const int64 dot_rows = dot_shape.dimensions(num_batch_dims + 0); + const int64 dot_cols = dot_shape.dimensions(num_batch_dims + 1); + Shape dot_shape_r2 = ShapeUtil::MakeShape(dot_type, {dot_rows, dot_cols}); + Shape dot_shape_r3 = ShapeUtil::MakeShape(dot_type, {1, dot_rows, dot_cols}); + Shape concat_shape_r3 = + ShapeUtil::MakeShape(dot_type, {batch_size, dot_rows, dot_cols}); + + // Reshape lhs/rhs into R3. 
+ auto lhs_r3 = computation->AddInstruction( + HloInstruction::CreateReshape(lhs_shape_r3, lhs)); + auto rhs_r3 = computation->AddInstruction( + HloInstruction::CreateReshape(rhs_shape_r3, rhs)); + + // Loop through batch size, slicing out required lhs/rhs to compute each Dot. + std::vector output_slices(batch_size); + for (int64 i = 0; i < batch_size; ++i) { + // Slice R3 shape from 'lhs' and reshape to R2. + auto lhs_slice_r3 = computation->AddInstruction( + HloInstruction::CreateSlice(lhs_slice_shape_r3, lhs_r3, {i, 0, 0}, + {i + 1, lhs_rows, lhs_cols}, {1, 1, 1})); + auto lhs_slice_r2 = computation->AddInstruction( + HloInstruction::CreateReshape(lhs_slice_shape_r2, lhs_slice_r3)); + + // Slice R3 shape from 'rhs' and reshape to R2. + auto rhs_slice_r3 = computation->AddInstruction( + HloInstruction::CreateSlice(rhs_slice_shape_r3, rhs_r3, {i, 0, 0}, + {i + 1, rhs_rows, rhs_cols}, {1, 1, 1})); + auto rhs_slice_r2 = computation->AddInstruction( + HloInstruction::CreateReshape(rhs_slice_shape_r2, rhs_slice_r3)); + + // Transpose lhs/rhs (if needed). + if (lhs_transpose) { + Shape lhs_slice_shape_r2_transpose = + ShapeUtil::MakeShape(lhs_type, {lhs_cols, lhs_rows}); + lhs_slice_r2 = + computation->AddInstruction(HloInstruction::CreateTranspose( + lhs_slice_shape_r2_transpose, lhs_slice_r2, {1, 0})); + } + if (rhs_transpose) { + Shape rhs_slice_shape_r2_transpose = + ShapeUtil::MakeShape(rhs_type, {rhs_cols, rhs_rows}); + rhs_slice_r2 = + computation->AddInstruction(HloInstruction::CreateTranspose( + rhs_slice_shape_r2_transpose, rhs_slice_r2, {1, 0})); + } + + // Compute Dot of lhs/rhs R2 slices. + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto dot_r2 = computation->AddInstruction(HloInstruction::CreateDot( + dot_shape_r2, lhs_slice_r2, rhs_slice_r2, dot_dnums)); + + // Reshape Dot to R3 so we can concat along batch dimension. 
+ auto dot_r3 = computation->AddInstruction( + HloInstruction::CreateReshape(dot_shape_r3, dot_r2)); + + output_slices[i] = dot_r3; + } + + // Concatenate slices from 'output_slices' along batch dimension. + auto concat = computation->AddInstruction( + HloInstruction::CreateConcatenate(concat_shape_r3, output_slices, 0)); + // Reshape output 'new_dot' to original dimensions. + auto new_dot = computation->AddInstruction( + HloInstruction::CreateReshape(dot_shape, concat)); + + // Replace all uses of 'dot' in 'computation' with 'new_dot'. + return computation->ReplaceInstruction(dot, new_dot); +} + +} // namespace + +StatusOr DotDecomposer::Run(HloModule* module) { + XLA_VLOG_LINES(2, "DotDecomposer ENTRY\n" + module->ToString()); + // Gather all batch Dot operations. + std::vector batch_dots; + for (auto* computation : module->MakeNonfusionComputations()) { + for (auto* instruction : computation->instructions()) { + if (instruction->opcode() != HloOpcode::kDot) { + continue; + } + const DotDimensionNumbers& dnums = instruction->dot_dimension_numbers(); + if (dnums.lhs_batch_dimensions_size() > 0 && decompose_batch_dot_) { + batch_dots.push_back(instruction); + } + } + } + // Decompose each batch Dot in 'batch_dots'. + bool changed = false; + for (auto* dot : batch_dots) { + TF_RETURN_IF_ERROR(DecomposeBatchDot(dot)); + changed = true; + } + XLA_VLOG_LINES(2, "DotDecompose EXIT\n" + module->ToString()); + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/dot_decomposer.h b/tensorflow/compiler/xla/service/dot_decomposer.h new file mode 100644 index 0000000000000000000000000000000000000000..1959b687f16d6909a3283021c8635b3e65e6e412 --- /dev/null +++ b/tensorflow/compiler/xla/service/dot_decomposer.h @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DOT_DECOMPOSER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_DOT_DECOMPOSER_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// DotDecomposer is a pass which decomposes batch Dot operations into a +// sequence of smaller (R2) Dot operations. +class DotDecomposer : public HloPassInterface { + public: + // Decomposes batch Dot operations when 'decompose_batch_dot' is true. + DotDecomposer(bool decompose_batch_dot = true) + : decompose_batch_dot_(decompose_batch_dot) {} + ~DotDecomposer() = default; + tensorflow::StringPiece name() const override { return "dot_decomposer"; } + + // Run DotDecomposer pass on computations in 'module'. + // Returns whether the 'module' was changed. 
+ StatusOr Run(HloModule* module) override; + + private: + bool decompose_batch_dot_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DOT_DECOMPOSER_H_ diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index b9407818cd8bc82aabd32ed02f61ef66fe442625..4468adbadbf823f1420a8b665a26f66cb7d36b43 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -50,11 +50,161 @@ using llvm_ir::IrName; using llvm_ir::SetToFirstInsertPoint; using tensorflow::strings::StrCat; +namespace { + +llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits, + int64 mantissa_bits, + llvm::IRBuilder<>* ir_builder) { + // Integer and float types for casting and constant generation. + llvm::Type* float_type = x->getType(); + llvm::IntegerType* int_type = ir_builder->getInt32Ty(); + + // Cast the input value to an integer for bitwise manipulation. + llvm::Value* x_as_int = ir_builder->CreateBitCast(x, int_type); + + if (mantissa_bits < 23) { + // Last remaining mantissa bit. + const uint32_t last_mantissa_bit_mask = 1u << (23 - mantissa_bits); + + // Compute rounding bias for round-to-nearest with ties to even. This is + // equal to a base value of 0111... plus one bit if the last remaining + // mantissa bit is 1. + const uint32_t base_rounding_bias = (last_mantissa_bit_mask >> 1) - 1; + llvm::Value* x_last_mantissa_bit = ir_builder->CreateLShr( + ir_builder->CreateAnd( + x_as_int, llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)), + (23 - mantissa_bits)); + llvm::Value* x_rounding_bias = ir_builder->CreateAdd( + x_last_mantissa_bit, + llvm::ConstantInt::get(int_type, base_rounding_bias)); + + // Add rounding bias, and mask out truncated bits. 
Note that the case + // where adding the rounding bias overflows into the exponent bits is + // correct; the non-masked mantissa bits will all be zero, and the + // exponent will be incremented by one. + const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1); + x_as_int = ir_builder->CreateAdd(x_as_int, x_rounding_bias); + x_as_int = ir_builder->CreateAnd( + x_as_int, llvm::ConstantInt::get(int_type, truncation_mask)); + } + + if (exponent_bits < 8) { + // Masks for f32 values. + const uint32_t f32_sign_bit_mask = 1u << 31; + const uint32_t f32_exp_bits_mask = 0xffu << 23; + + // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most- + // significant bit -- is equal to 1.0f for all exponent sizes. Adding + // 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit- + // size of n, and subtracting 2^(n-1)-1 from this gives us the lowest' + // exponent (corresponding to 0.0f). + // + // Thus, the f32 exponent corresponding to the highest non-infinite + // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32 + // exponent corresponding to the lowest exponent for a bit size of n is + // (2^7-1) - 2^(n-1)-1. + // + // Note that we have already checked that exponents_bits >= 1. + const uint32_t f32_exponent_bias = (1 << 7) - 1; + const uint32_t reduced_exponent_bias = (1 << (exponent_bits - 1)) - 1; + const uint32_t reduced_max_exponent = + f32_exponent_bias + reduced_exponent_bias; + const uint32_t reduced_min_exponent = + f32_exponent_bias - reduced_exponent_bias; + + // Do we overflow or underflow? 
+ llvm::Value* x_exponent = ir_builder->CreateAnd( + x_as_int, llvm::ConstantInt::get(int_type, f32_exp_bits_mask)); + llvm::Value* x_overflows = ir_builder->CreateICmpUGT( + x_exponent, + llvm::ConstantInt::get(int_type, reduced_max_exponent << 23)); + llvm::Value* x_underflows = ir_builder->CreateICmpULE( + x_exponent, + llvm::ConstantInt::get(int_type, reduced_min_exponent << 23)); + + // Compute appropriately-signed values of zero and infinity. + llvm::Value* x_signed_zero = ir_builder->CreateAnd( + x_as_int, llvm::ConstantInt::get(int_type, f32_sign_bit_mask)); + llvm::Value* x_signed_inf = ir_builder->CreateOr( + x_signed_zero, llvm::ConstantInt::get(int_type, f32_exp_bits_mask)); + + // Force to zero or infinity if overflow or underflow. (Note that this + // truncates all denormal values to zero, rather than rounding them.) + x_as_int = ir_builder->CreateSelect(x_overflows, x_signed_inf, x_as_int); + x_as_int = ir_builder->CreateSelect(x_underflows, x_signed_zero, x_as_int); + } + + // Cast the result back to a floating-point type. + llvm::Value* result = ir_builder->CreateBitCast(x_as_int, float_type); + + // Correct result for NaN inputs. + // + // The exponent handling will "normalize" NaN values to infinities, which is + // undesirable (except in the case with no mantissa bits, in which case it + // is mandatory). This logic also handles cases where mantissa-rounding + // causes a NaN's mantissa to overflow into the exponent bits, which would + // otherwise create an erroneous zero value. + // + // If the fast-math flags are set to assume no NaNs, the comparison is likely + // to be optimized away, so there's no point in even emitting it. 
+ if (!ir_builder->getFastMathFlags().noNaNs()) { + llvm::Value* x_is_nan = ir_builder->CreateFCmpUNO(x, x); + + if (mantissa_bits > 0) { + result = ir_builder->CreateSelect(x_is_nan, x, result); + } else { + result = ir_builder->CreateSelect( + x_is_nan, llvm::ConstantFP::getInfinity(float_type), result); + } + } + return result; +} + +llvm::Value* EmitF32ToBF16(llvm::Value* f32_value, + llvm::IRBuilder<>* ir_builder) { + auto reduced_precision = EmitReducePrecisionFloat( + f32_value, + /*exponent_bits=*/primitive_util::kBFloat16ExponentBits, + /*mantissa_bits=*/primitive_util::kBFloat16MantissaBits, ir_builder); + auto as_int32 = + ir_builder->CreateBitCast(reduced_precision, ir_builder->getInt32Ty()); + auto shifted = ir_builder->CreateLShr(as_int32, 16); + auto truncated = ir_builder->CreateTrunc(shifted, ir_builder->getInt16Ty()); + return ir_builder->CreateBitCast(truncated, ir_builder->getInt16Ty()); +} + +llvm::Value* EmitBF16ToF32(llvm::Value* bf16_value, + llvm::IRBuilder<>* ir_builder) { + auto as_int16 = + ir_builder->CreateBitCast(bf16_value, ir_builder->getInt16Ty()); + auto as_int32 = ir_builder->CreateZExt(as_int16, ir_builder->getInt32Ty()); + auto shifted = ir_builder->CreateShl(as_int32, 16); + return ir_builder->CreateBitCast(shifted, ir_builder->getFloatTy()); +} + +llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value, + PrimitiveType from_type, + PrimitiveType to_type, llvm::Module* module, + llvm::IRBuilder<>* ir_builder) { + if (primitive_util::IsSignedIntegralType(from_type)) { + return ir_builder->CreateSIToFP( + integer_value, llvm_ir::PrimitiveTypeToIrType(to_type, module)); + } else { + CHECK(primitive_util::IsUnsignedIntegralType(from_type) || + from_type == PRED); + return ir_builder->CreateUIToFP( + integer_value, llvm_ir::PrimitiveTypeToIrType(to_type, module)); + } +} + +} // namespace + StatusOr ElementalIrEmitter::EmitUnaryOp( const HloInstruction* op, llvm::Value* operand_value) const { if (op->opcode() == 
HloOpcode::kCopy) { return operand_value; - } else if (operand_value->getType()->isIntegerTy()) { + } else if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) || + op->operand(0)->shape().element_type() == PRED) { return EmitIntegerUnaryOp(op, operand_value); } else if (ShapeUtil::ElementIsComplex(op->operand(0)->shape())) { return EmitComplexUnaryOp(op, operand_value); @@ -79,15 +229,14 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( primitive_util::IsSignedIntegralType(to_type)); } if (primitive_util::IsFloatingPointType(to_type)) { - if (primitive_util::IsSignedIntegralType(from_type)) { - return ir_builder_->CreateSIToFP( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); - } - if (primitive_util::IsUnsignedIntegralType(from_type) || - from_type == PRED) { - return ir_builder_->CreateUIToFP( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + if (to_type == BF16) { + return EmitF32ToBF16( + EmitIntegralToFloating(operand_value, from_type, F32, module_, + ir_builder_), + ir_builder_); } + return EmitIntegralToFloating(operand_value, from_type, to_type, + module_, ir_builder_); } if (primitive_util::IsComplexType(to_type)) { auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType( @@ -207,6 +356,17 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( llvm_ir::PrimitiveTypeToIrType(to_component_type, module_)), nullptr); } + if (from_type == BF16) { + TF_RET_CHECK(to_type != BF16); + operand_value = EmitBF16ToF32(operand_value, ir_builder_); + from_type = F32; + if (from_type == to_type) { + return operand_value; + } + } + if (from_type == F32 && to_type == BF16) { + return EmitF32ToBF16(operand_value, ir_builder_); + } if (primitive_util::IsFloatingPointType(to_type)) { return ir_builder_->CreateFPCast( operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); @@ -244,21 +404,13 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( primitive_util::BitWidth(to_type)); } case HloOpcode::kExp: - return 
llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {operand_value}, - {operand_value->getType()}, - ir_builder_); + return EmitExp(op->shape().element_type(), operand_value); case HloOpcode::kLog: - return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::log, {operand_value}, - {operand_value->getType()}, - ir_builder_); + return EmitLog(op->shape().element_type(), operand_value); case HloOpcode::kCos: - return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {operand_value}, - {operand_value->getType()}, - ir_builder_); + return EmitCos(op->shape().element_type(), operand_value); case HloOpcode::kSin: - return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {operand_value}, - {operand_value->getType()}, - ir_builder_); + return EmitSin(op->shape().element_type(), operand_value); case HloOpcode::kFloor: return llvm_ir::EmitCallToIntrinsic( llvm::Intrinsic::floor, {operand_value}, {operand_value->getType()}, @@ -276,7 +428,7 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( llvm::Intrinsic::round, {operand_value}, {operand_value->getType()}, ir_builder_); case HloOpcode::kSign: { - // TODO(b/32151903): Ensure consistent sign behavior for -0.0 + // TODO(b/32151903): Ensure consistent sign behavior for -0.0. auto type = operand_value->getType(); auto zero = llvm::ConstantFP::get(type, 0.0); auto oeq = ir_builder_->CreateFCmpOEQ(operand_value, zero); @@ -309,9 +461,25 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( StatusOr ElementalIrEmitter::EmitComplexUnaryOp( const HloInstruction* op, llvm::Value* operand_value) const { + PrimitiveType input_type = op->operand(0)->shape().element_type(); + PrimitiveType component_type = + primitive_util::IsComplexType(input_type) + ? primitive_util::ComplexComponentType(input_type) + : input_type; switch (op->opcode()) { - // TODO(b/65209142): Angle/Log require atan2. 
- // case HloOpcode::kLog: // log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a) + case HloOpcode::kLog: { + // log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a) + auto a = EmitExtractReal(operand_value); + auto b = EmitExtractImag(operand_value); + llvm::Type* llvm_ty = a->getType(); + auto sum_sq = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a), + ir_builder_->CreateFMul(b, b)); + TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq)); + TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a)); + auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5); + return EmitComposeComplex( + op, ir_builder_->CreateFMul(one_half, log_sum_sq), angle); + } case HloOpcode::kConvert: { PrimitiveType from_type = op->operand(0)->shape().element_type(); TF_RET_CHECK(primitive_util::IsComplexType(from_type)); @@ -333,15 +501,12 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( } case HloOpcode::kExp: { // e^(a+bi) = e^a*(cos(b)+sin(b)i) - auto exp_a = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::exp, {EmitExtractReal(operand_value)}, - {EmitExtractReal(operand_value)->getType()}, ir_builder_); - auto cos_b = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::cos, {EmitExtractImag(operand_value)}, - {EmitExtractImag(operand_value)->getType()}, ir_builder_); - auto sin_b = llvm_ir::EmitCallToIntrinsic( - llvm::Intrinsic::sin, {EmitExtractImag(operand_value)}, - {EmitExtractImag(operand_value)->getType()}, ir_builder_); + TF_ASSIGN_OR_RETURN( + auto exp_a, EmitExp(component_type, EmitExtractReal(operand_value))); + TF_ASSIGN_OR_RETURN( + auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value))); + TF_ASSIGN_OR_RETURN( + auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value))); return EmitComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b), ir_builder_->CreateFMul(exp_a, sin_b)); } @@ -356,16 +521,13 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( auto a = EmitExtractReal(operand_value); auto b = EmitExtractImag(operand_value); auto 
type = a->getType(); - auto exp_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {b}, - {type}, ir_builder_); + TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b)); auto half_exp_b = ir_builder_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b); auto half_exp_neg_b = ir_builder_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b); - auto cos_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {a}, - {type}, ir_builder_); - auto sin_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {a}, - {type}, ir_builder_); + TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a)); + TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a)); return EmitComposeComplex( op, ir_builder_->CreateFMul( @@ -386,16 +548,13 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( auto a = EmitExtractReal(operand_value); auto b = EmitExtractImag(operand_value); auto type = a->getType(); - auto exp_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {b}, - {type}, ir_builder_); + TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b)); auto half_exp_b = ir_builder_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b); auto half_exp_neg_b = ir_builder_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b); - auto cos_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {a}, - {type}, ir_builder_); - auto sin_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {a}, - {type}, ir_builder_); + TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a)); + TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a)); return EmitComposeComplex( op, ir_builder_->CreateFMul( @@ -403,6 +562,58 @@ StatusOr ElementalIrEmitter::EmitComplexUnaryOp( ir_builder_->CreateFMul( cos_a, ir_builder_->CreateFSub(half_exp_b, half_exp_neg_b))); } + case HloOpcode::kTanh: { + /* + tanh=(exp(x)-exp(-x)) / (exp(x)+exp(-x)) + e^(a+bi) = e^a*(cos(b)+sin(b)i) + so tanh=(((cos(b)+sin(b)i)e^a - (cos(-b)+sin(-b)i)e^-a)) / + (((cos(b)+sin(b)i)e^a + (cos(-b)+sin(-b)i)e^-a)) + 
cos(b)=cos(-b), sin(-b)=-sin(b) + so tanh=(((cos(b)+sin(b)i)e^a - (cos(b)-sin(b)i)e^-a)) / + (((cos(b)+sin(b)i)e^a + (cos(b)-sin(b)i)e^-a)) + =(cos(b)e^a+i*sin(b)e^a + cos(b)(-e^-a)+i*sin(b)e^-a) / + (cos(b)e^a+i*sin(b)e^a + cos(b)e^-a+i*sin(b)(-e^-a)) + =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) / + (cos(b)(e^a+e^-a) + i*sin(b)(e^a-e^-a)) + This is a complex division, so we can multiply by denom_conj/denom_conj + =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) * + (cos(b)(e^a+e^-a) - i*sin(b)(e^a-e^-a)) / + ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2) + =(cos(b)^2(e^(2a)-e^(-2a)) + sin(b)^2(e^(2a)-e^(-2a)) + + i*(cos(b)sin(b)(e^a+e^-a)^2 - cos(b)sin(b)(e^a-e^-a)^2)) / + ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2) + */ + auto a = EmitExtractReal(operand_value); + auto b = EmitExtractImag(operand_value); + TF_ASSIGN_OR_RETURN(auto exp_a, EmitExp(component_type, a)); + TF_ASSIGN_OR_RETURN(auto cos_b, EmitCos(component_type, b)); + TF_ASSIGN_OR_RETURN(auto sin_b, EmitSin(component_type, b)); + auto exp_neg_a = ir_builder_->CreateFDiv( + llvm::ConstantFP::get(exp_a->getType(), 1), exp_a); + auto exp_2a_minus_exp_neg_2a = ir_builder_->CreateFSub( + ir_builder_->CreateFMul(exp_a, exp_a), + ir_builder_->CreateFMul(exp_neg_a, exp_neg_a)); + auto cos_b_sq = ir_builder_->CreateFMul(cos_b, cos_b); + auto sin_b_sq = ir_builder_->CreateFMul(sin_b, sin_b); + auto real_num = ir_builder_->CreateFAdd( + ir_builder_->CreateFMul(cos_b_sq, exp_2a_minus_exp_neg_2a), + ir_builder_->CreateFMul(sin_b_sq, exp_2a_minus_exp_neg_2a)); + auto cos_b_sin_b = ir_builder_->CreateFMul(cos_b, sin_b); + auto exp_a_plus_exp_neg_a = ir_builder_->CreateFAdd(exp_a, exp_neg_a); + auto exp_a_plus_exp_neg_a_sq = + ir_builder_->CreateFMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a); + auto exp_a_minus_exp_neg_a = ir_builder_->CreateFSub(exp_a, exp_neg_a); + auto exp_a_minus_exp_neg_a_sq = + ir_builder_->CreateFMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a); + auto imag_num = ir_builder_->CreateFMul( + 
cos_b_sin_b, ir_builder_->CreateFSub(exp_a_plus_exp_neg_a_sq, + exp_a_minus_exp_neg_a_sq)); + auto denom = ir_builder_->CreateFAdd( + ir_builder_->CreateFMul(cos_b_sq, exp_a_plus_exp_neg_a_sq), + ir_builder_->CreateFMul(sin_b_sq, exp_a_minus_exp_neg_a_sq)); + return EmitComposeComplex(op, ir_builder_->CreateFDiv(real_num, denom), + ir_builder_->CreateFDiv(imag_num, denom)); + } case HloOpcode::kAbs: { auto sum_sq = ir_builder_->CreateFAdd( ir_builder_->CreateFMul(EmitExtractReal(operand_value), @@ -449,7 +660,8 @@ StatusOr ElementalIrEmitter::EmitBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) const { PrimitiveType operand_type = op->operand(0)->shape().element_type(); - if (lhs_value->getType()->isIntegerTy()) { + if (ShapeUtil::ElementIsIntegral(op->operand(0)->shape()) || + operand_type == PRED) { return EmitIntegerBinaryOp( op, lhs_value, rhs_value, primitive_util::IsSignedIntegralType(operand_type)); @@ -464,7 +676,6 @@ StatusOr ElementalIrEmitter::EmitFloatBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) const { switch (op->opcode()) { - // case HloOpcode::kAtan2: // TODO(b/65209142): CPU atan2 support case HloOpcode::kComplex: return EmitComposeComplex(op, lhs_value, rhs_value); case HloOpcode::kAdd: @@ -508,10 +719,9 @@ StatusOr ElementalIrEmitter::EmitFloatBinaryOp( case HloOpcode::kMinimum: return EmitFloatMin(lhs_value, rhs_value); case HloOpcode::kPower: - return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow, - {lhs_value, rhs_value}, - {lhs_value->getType()}, ir_builder_); - + return EmitPow(op->shape().element_type(), lhs_value, rhs_value); + case HloOpcode::kAtan2: + return EmitAtan2(op->shape().element_type(), lhs_value, rhs_value); default: return Unimplemented("binary floating point op '%s'", HloOpcodeString(op->opcode()).c_str()); @@ -607,9 +817,40 @@ StatusOr ElementalIrEmitter::EmitComplexBinaryOp( EmitExtractImag(lhs_value), EmitExtractImag(rhs_value), ir_builder_)); - 
// TODO(b/65209142): requires arg(z) -> requires atan|atan2 intrinsic - // case HloOpcode::kPower: - // // (a+bi)^(c+di) = exp(i(c+di)*arg(a+bi)) * (a*a+b*b)^(c/2+di/2) + case HloOpcode::kPower: { + // (a+bi)^(c+di) = + // (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)), + // where q = c*atan2(b,a)+0.5d*ln(a*a+b*b) + PrimitiveType component_type = + primitive_util::ComplexComponentType(op->shape().element_type()); + auto a = EmitExtractReal(lhs_value); + auto b = EmitExtractImag(lhs_value); + auto c = EmitExtractReal(rhs_value); + auto d = EmitExtractImag(rhs_value); + auto aa_p_bb = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a), + ir_builder_->CreateFMul(b, b)); + auto one_half = llvm::ConstantFP::get(a->getType(), 0.5); + auto half_c = ir_builder_->CreateFMul(one_half, c); + + TF_ASSIGN_OR_RETURN(auto aa_p_bb_to_half_c, + EmitPow(component_type, aa_p_bb, half_c)); + auto neg_d = ir_builder_->CreateFNeg(d); + TF_ASSIGN_OR_RETURN(auto arg_lhs, EmitAtan2(component_type, b, a)); + auto neg_d_arg_lhs = ir_builder_->CreateFMul(neg_d, arg_lhs); + TF_ASSIGN_OR_RETURN(auto e_to_neg_d_arg_lhs, + EmitExp(component_type, neg_d_arg_lhs)); + auto coeff = + ir_builder_->CreateFMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs); + TF_ASSIGN_OR_RETURN(auto ln_aa_p_bb, EmitLog(component_type, aa_p_bb)); + auto half_d = ir_builder_->CreateFMul(one_half, d); + auto q = + ir_builder_->CreateFAdd(ir_builder_->CreateFMul(c, arg_lhs), + ir_builder_->CreateFMul(half_d, ln_aa_p_bb)); + TF_ASSIGN_OR_RETURN(auto cos_q, EmitCos(component_type, q)); + TF_ASSIGN_OR_RETURN(auto sin_q, EmitSin(component_type, q)); + return EmitComposeComplex(op, ir_builder_->CreateFMul(coeff, cos_q), + ir_builder_->CreateFMul(coeff, sin_q)); + } default: return Unimplemented("binary complex op '%s'", HloOpcodeString(op->opcode()).c_str()); @@ -629,7 +870,10 @@ llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value, StatusOr ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type, 
llvm::Value* x) const { if (prim_type != F32) { - return Unimplemented("inverse erf only implemented for F32 (b/34339814)"); + // TODO(b/34339814): Implement inverse erf for F64. + return Unimplemented( + "Inverse erf is only implemented for element " + "type F32."); } auto getFloat = [&](const float f) { return llvm::ConstantFP::get(ir_builder_->getFloatTy(), f); @@ -712,116 +956,51 @@ StatusOr ElementalIrEmitter::EmitErfcInv( return EmitErfInv(prim_type, ir_builder_->CreateFSub(one, value)); } -StatusOr ElementalIrEmitter::EmitReducePrecision( - const HloInstruction* hlo, llvm::Value* x) const { - if (hlo->operand(0)->shape().element_type() != F32) { - return Unimplemented("reduce-precision only implemented for F32"); - } - - // Integer and float types for casting and constant generation. - llvm::Type* float_type = x->getType(); - llvm::IntegerType* int_type = ir_builder_->getInt32Ty(); - - // Cast the input value to an integer for bitwise manipulation. - llvm::Value* x_as_int = ir_builder_->CreateBitCast(x, int_type); - - if (hlo->mantissa_bits() < 23) { - // Last remaining mantissa bit. - const uint32_t last_mantissa_bit_mask = 1u << (23 - hlo->mantissa_bits()); - - // Compute rounding bias for round-to-nearest with ties to even. This is - // equal to a base value of 0111... plus one bit if the last remaining - // mantissa bit is 1. - const uint32_t base_rounding_bias = (last_mantissa_bit_mask >> 1) - 1; - llvm::Value* x_last_mantissa_bit = ir_builder_->CreateLShr( - ir_builder_->CreateAnd( - x_as_int, llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)), - (23 - hlo->mantissa_bits())); - llvm::Value* x_rounding_bias = ir_builder_->CreateAdd( - x_last_mantissa_bit, - llvm::ConstantInt::get(int_type, base_rounding_bias)); - - // Add rounding bias, and mask out truncated bits. 
Note that the case - // where adding the rounding bias overflows into the exponent bits is - // correct; the non-masked mantissa bits will all be zero, and the - // exponent will be incremented by one. - const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1); - x_as_int = ir_builder_->CreateAdd(x_as_int, x_rounding_bias); - x_as_int = ir_builder_->CreateAnd( - x_as_int, llvm::ConstantInt::get(int_type, truncation_mask)); - } - - if (hlo->exponent_bits() < 8) { - // Masks for f32 values. - const uint32_t f32_sign_bit_mask = 1u << 31; - const uint32_t f32_exp_bits_mask = 0xffu << 23; - - // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the most- - // significant bit -- is equal to 1.0f for all exponent sizes. Adding - // 2^(n-1)-1 to this gives us the highest non-infinite exponent for a bit- - // size of n, and subtracting 2^(n-1)-1 from this gives us the lowest' - // exponent (corresponding to 0.0f). - // - // Thus, the f32 exponent corresponding to the highest non-infinite - // exponent for a bit size of n is (2^7-1) + 2^(n-1)-1, and the f32 - // exponent corresponding to the lowest exponent for a bit size of n is - // (2^7-1) - 2^(n-1)-1. - // - // Note that we have already checked that exponents_bits >= 1. - const uint32_t f32_exponent_bias = (1 << 7) - 1; - const uint32_t reduced_exponent_bias = - (1 << (hlo->exponent_bits() - 1)) - 1; - const uint32_t reduced_max_exponent = - f32_exponent_bias + reduced_exponent_bias; - const uint32_t reduced_min_exponent = - f32_exponent_bias - reduced_exponent_bias; +StatusOr ElementalIrEmitter::EmitLog(PrimitiveType prim_type, + llvm::Value* value) const { + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::log, {value}, + {value->getType()}, ir_builder_); +} - // Do we overflow or underflow? 
- llvm::Value* x_exponent = ir_builder_->CreateAnd( - x_as_int, llvm::ConstantInt::get(int_type, f32_exp_bits_mask)); - llvm::Value* x_overflows = ir_builder_->CreateICmpUGT( - x_exponent, - llvm::ConstantInt::get(int_type, reduced_max_exponent << 23)); - llvm::Value* x_underflows = ir_builder_->CreateICmpULE( - x_exponent, - llvm::ConstantInt::get(int_type, reduced_min_exponent << 23)); +StatusOr ElementalIrEmitter::EmitSin(PrimitiveType prim_type, + llvm::Value* value) const { + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {value}, + {value->getType()}, ir_builder_); +} - // Compute appropriately-signed values of zero and infinity. - llvm::Value* x_signed_zero = ir_builder_->CreateAnd( - x_as_int, llvm::ConstantInt::get(int_type, f32_sign_bit_mask)); - llvm::Value* x_signed_inf = ir_builder_->CreateOr( - x_signed_zero, llvm::ConstantInt::get(int_type, f32_exp_bits_mask)); +StatusOr ElementalIrEmitter::EmitCos(PrimitiveType prim_type, + llvm::Value* value) const { + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {value}, + {value->getType()}, ir_builder_); +} - // Force to zero or infinity if overflow or underflow. (Note that this - // truncates all denormal values to zero, rather than rounding them.) - x_as_int = ir_builder_->CreateSelect(x_overflows, x_signed_inf, x_as_int); - x_as_int = ir_builder_->CreateSelect(x_underflows, x_signed_zero, x_as_int); - } +StatusOr ElementalIrEmitter::EmitExp(PrimitiveType prim_type, + llvm::Value* value) const { + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {value}, + {value->getType()}, ir_builder_); +} - // Cast the result back to a floating-point type. - llvm::Value* result = ir_builder_->CreateBitCast(x_as_int, float_type); +StatusOr ElementalIrEmitter::EmitPow(PrimitiveType prim_type, + llvm::Value* lhs, + llvm::Value* rhs) const { + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow, {lhs, rhs}, + {lhs->getType()}, ir_builder_); +} - // Correct result for NaN inputs. 
- // - // The exponent handling will "normalize" NaN values to infinities, which is - // undesirable (except in the case with no mantissa bits, in which case it - // is mandatory). This logic also handles cases where mantissa-rounding - // causes a NaN's mantissa to overflow into the exponent bits, which would - // otherwise create an erroneous zero value. - // - // If the fast-math flags are set to assume no NaNs, the comparison is likely - // to be optimized away, so there's no point in even emitting it. - if (!ir_builder_->getFastMathFlags().noNaNs()) { - llvm::Value* x_is_nan = ir_builder_->CreateFCmpUNO(x, x); +StatusOr ElementalIrEmitter::EmitAtan2(PrimitiveType prim_type, + llvm::Value* lhs, + llvm::Value* rhs) const { + return Unimplemented("atan2"); +} - if (hlo->mantissa_bits() > 0) { - result = ir_builder_->CreateSelect(x_is_nan, x, result); - } else { - result = ir_builder_->CreateSelect( - x_is_nan, llvm::ConstantFP::getInfinity(float_type), result); - } +StatusOr ElementalIrEmitter::EmitReducePrecision( + const HloInstruction* hlo, llvm::Value* x) const { + if (hlo->operand(0)->shape().element_type() != F32) { + return Unimplemented("reduce-precision only implemented for F32"); } - return result; + return EmitReducePrecisionFloat(x, /*exponent_bits=*/hlo->exponent_bits(), + /*mantissa_bits=*/hlo->mantissa_bits(), + ir_builder_); } StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( @@ -864,17 +1043,9 @@ StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( is_signed ? llvm::CmpInst::ICMP_SGE : llvm::CmpInst::ICMP_UGE, lhs_value, rhs_value, ir_builder_); case HloOpcode::kMinimum: - return ir_builder_->CreateSelect( - ir_builder_->CreateICmp( - is_signed ? llvm::ICmpInst::ICMP_SLE : llvm::ICmpInst::ICMP_ULE, - lhs_value, rhs_value), - lhs_value, rhs_value); + return EmitIntegralMin(lhs_value, rhs_value, is_signed); case HloOpcode::kMaximum: - return ir_builder_->CreateSelect( - ir_builder_->CreateICmp( - is_signed ? 
llvm::ICmpInst::ICMP_SGE : llvm::ICmpInst::ICMP_UGE, - lhs_value, rhs_value), - lhs_value, rhs_value); + return EmitIntegralMax(lhs_value, rhs_value, is_signed); case HloOpcode::kAnd: return ir_builder_->CreateAnd(lhs_value, rhs_value); case HloOpcode::kOr: @@ -891,6 +1062,26 @@ StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( } } +llvm::Value* ElementalIrEmitter::EmitIntegralMax(llvm::Value* lhs_value, + llvm::Value* rhs_value, + bool is_signed) const { + return ir_builder_->CreateSelect( + ir_builder_->CreateICmp( + is_signed ? llvm::ICmpInst::ICMP_SGE : llvm::ICmpInst::ICMP_UGE, + lhs_value, rhs_value), + lhs_value, rhs_value); +} + +llvm::Value* ElementalIrEmitter::EmitIntegralMin(llvm::Value* lhs_value, + llvm::Value* rhs_value, + bool is_signed) const { + return ir_builder_->CreateSelect( + ir_builder_->CreateICmp( + is_signed ? llvm::ICmpInst::ICMP_SLE : llvm::ICmpInst::ICMP_ULE, + lhs_value, rhs_value), + lhs_value, rhs_value); +} + llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex( const llvm_ir::IrArray::Index& target_index, const HloInstruction& hlo, int64 operand_no) const { @@ -1088,14 +1279,6 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator( get_next_uniform_float()))); return ir_builder_->CreateFAdd(ir_builder_->CreateFMul(r, s), m); } - case RNG_BERNOULLI: { - TF_ASSIGN_OR_RETURN(llvm::Value * p, - operand_to_generator.at(hlo->operand(0))(index)); - return ir_builder_->CreateZExt( - ir_builder_->CreateFCmpOLT(get_next_uniform_float(), p), - llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), - module_)); - } default: return InvalidArgument( "unhandled distribution %s", @@ -1195,7 +1378,18 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( TF_ASSIGN_OR_RETURN(llvm::Value * max_value, operand_to_generator.at(hlo->operand(2))( ElementwiseSourceIndex(index, *hlo, 2))); - return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value)); + PrimitiveType prim_type = 
hlo->shape().element_type(); + if (primitive_util::IsFloatingPointType(prim_type)) { + return EmitFloatMin(max_value, EmitFloatMax(min_value, arg_value)); + } else if (primitive_util::IsIntegralType(prim_type)) { + bool is_signed = primitive_util::IsSignedIntegralType(prim_type); + return EmitIntegralMin( + max_value, EmitIntegralMax(min_value, arg_value, is_signed), + is_signed); + } else { + return Unimplemented("Clamp unimplemented for %s", + PrimitiveType_Name(prim_type).c_str()); + } }; case HloOpcode::kReducePrecision: return [this, hlo, &operand_to_generator]( diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index cccb498f82936283a215370787907b293827ff2d..c516a826d9e382bc738e54635426db639d17108c 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -39,7 +39,7 @@ class ElementalIrEmitter { module_(module), hlo_module_config_(hlo_module_config) {} - virtual ~ElementalIrEmitter() {} + virtual ~ElementalIrEmitter() = default; virtual StatusOr EmitUnaryOp(const HloInstruction* op, llvm::Value* operand_value) const; @@ -86,12 +86,38 @@ class ElementalIrEmitter { virtual llvm::Value* EmitFloatMin(llvm::Value* lhs_value, llvm::Value* rhs_value) const; + llvm::Value* EmitIntegralMax(llvm::Value* lhs_value, llvm::Value* rhs_value, + bool is_signed) const; + + llvm::Value* EmitIntegralMin(llvm::Value* lhs_value, llvm::Value* rhs_value, + bool is_signed) const; + virtual StatusOr EmitErfInv(PrimitiveType prim_type, llvm::Value* value) const; virtual StatusOr EmitErfcInv(PrimitiveType prim_type, llvm::Value* value) const; + virtual StatusOr EmitAtan2(PrimitiveType prim_type, + llvm::Value* lhs, + llvm::Value* rhs) const; + + virtual StatusOr EmitLog(PrimitiveType prim_type, + llvm::Value* value) const; + + virtual StatusOr EmitSin(PrimitiveType prim_type, + llvm::Value* value) const; + + virtual StatusOr 
EmitCos(PrimitiveType prim_type, + llvm::Value* value) const; + + virtual StatusOr EmitExp(PrimitiveType prim_type, + llvm::Value* value) const; + + virtual StatusOr EmitPow(PrimitiveType prim_type, + llvm::Value* lhs, + llvm::Value* rhs) const; + virtual StatusOr EmitReducePrecision(const HloInstruction* hlo, llvm::Value* x) const; diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index 9c96d9eb30b5f9e51b7f5d82391c6b9f366898d6..90481c7a88f90edea5399ee44aee2d2c77fc115f 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -24,25 +24,25 @@ limitations under the License. #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" +using tensorflow::gtl::ArraySlice; + namespace xla { -StatusOr> +StatusOr>> Executable::ExecuteOnStreams( - tensorflow::gtl::ArraySlice run_options, - tensorflow::gtl::ArraySlice< - tensorflow::gtl::ArraySlice> - arguments) { + ArraySlice run_options, + ArraySlice> arguments) { TF_RET_CHECK(run_options.size() == arguments.size()); + std::vector> return_values(run_options.size()); + if (run_options.size() == 1) { - TF_ASSIGN_OR_RETURN(auto result, + TF_ASSIGN_OR_RETURN(return_values[0], ExecuteOnStream(&run_options[0], arguments[0], /*hlo_execution_profile=*/nullptr)); - return std::vector({result}); + return std::move(return_values); } - std::vector return_values( - run_options.size()); for (size_t i = 0; i < run_options.size(); ++i) { // We cannot BlockHostUntilDone() on the already-launched executions in case // of error, since if the executions communicate, the initially launched @@ -52,9 +52,77 @@ Executable::ExecuteOnStreams( } for (const auto& options : run_options) { TF_RET_CHECK(options.stream() != nullptr); - options.stream()->BlockHostUntilDone(); + TF_RETURN_IF_ERROR(options.stream()->BlockHostUntilDone()); + } + return std::move(return_values); +} + +StatusOr> 
Executable::ExecuteOnStreamWrapper( + const ServiceExecutableRunOptions* run_options, ExecutionProfile* profile, + ArraySlice arguments) { + perftools::gputools::Stream* stream = run_options->stream(); + std::unique_ptr timer; + if (profile != nullptr) { + timer.reset(new perftools::gputools::Timer(stream->parent())); + stream->InitTimer(timer.get()).ThenStartTimer(timer.get()); } - return return_values; + + VLOG(1) << "enqueueing executable on stream..."; + // If the profiling flag isn't enabled, we pass nullptr as the profile to + // indicate profiling is not requested. + std::unique_ptr profile_ptr = + module_config().debug_options().xla_hlo_profile() && + hlo_profiling_enabled() + ? MakeUnique(&hlo_profile_printer_data(), + &hlo_profile_index_map()) + : nullptr; + + StatusOr> return_value = + ExecuteOnStream(run_options, arguments, profile_ptr.get()); + + if (profile != nullptr) { + VLOG(1) << "enqueueing 'stop timer' and blocking host until done..."; + stream->ThenStopTimer(timer.get()); + TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); + VLOG(1) << "done with block-host-until-done"; + + // Merge in run-time profile information from execution_profile. + // + // TODO(b/71713097): This is buggy -- even though the mutex takes care of + // C++ level races, some other concurrent ExecuteOnStreamWrapper call could + // have rewritten the execution_profile before we get to it. + profile->MergeFrom(execution_profile()); + + // Overall execution time (in nanoseconds) from the executor timer. + if (stream->ok()) { + // Don't read timer->Nanoseconds() if the stream isn't OK -- that's + // illegal. + profile->set_compute_and_transfer_time_ns(timer->Nanoseconds()); + } + + // TODO(b/28123297): On GPU we end up including transfer time in + // the compute time this way. Instead, we should get the correct + // value by measuring it. Setting the field here at least lets + // benchmarks provide *some* value for GPU computations. 
+ // + // TODO(b/28447609): The value in compute_and_transfer_time_ns is actually + // the compute time without the transfer time, so this way we get the + // correct compute time. We should instead have the correct value for + // compute_and_transfer_time and set compute_time to the compute time. + if (profile->compute_time_ns() == 0) { + profile->set_compute_time_ns(profile->compute_and_transfer_time_ns()); + } + } + + if (profile_ptr != nullptr) { + XLA_LOG_LINES( + tensorflow::INFO, + profile_ptr->ToString(stream->parent()->GetDeviceDescription())); + hlo_graph_dumper::MaybeDumpHloModule(module(), "Service::Execute", + profile_ptr.get()); + } + + return return_value; } Status Executable::DumpSessionModule() { diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index 08862308c90af736c1adcaa9438973f858852506..0aee535ee780ef000bc5e9963ff48786b3a61eb2 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -44,13 +44,14 @@ namespace xla { // interface that is used for launching compiled programs across platforms. class Executable { public: - explicit Executable(std::unique_ptr hlo_module, - std::unique_ptr hlo_profile_printer, - std::unique_ptr hlo_profile_index_map) + explicit Executable( + std::unique_ptr hlo_module, + std::unique_ptr hlo_profile_printer_data, + std::unique_ptr hlo_profile_index_map) : hlo_module_(std::move(hlo_module)), - hlo_profile_printer_(std::move(hlo_profile_printer)), + hlo_profile_printer_data_(std::move(hlo_profile_printer_data)), hlo_profile_index_map_(std::move(hlo_profile_index_map)) { - CHECK_EQ(hlo_profile_printer_.get() == nullptr, + CHECK_EQ(hlo_profile_printer_data_.get() == nullptr, hlo_profile_index_map_.get() == nullptr); } virtual ~Executable() {} @@ -61,16 +62,7 @@ class Executable { // If the hlo_execution_profile is provided as non-nullptr, profiling will be // enabled. 
// - // Returns the device memory region that a successful execution would - // populate. - virtual StatusOr ExecuteOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments, - HloExecutionProfile* hlo_execution_profile) = 0; - - // Overload of ExecuteOnStream which returns and takes arguments as - // ShapedBuffers. Used for LocalService execution. + // Returns a shaped buffer containing the result of the computation. virtual StatusOr> ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, @@ -78,21 +70,19 @@ class Executable { // Same as ExecuteOnStream(), but this call is non-blocking and returns as // soon as all of the operations are enqueued for launch on the stream. - virtual StatusOr ExecuteAsyncOnStream( + virtual StatusOr> ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments) = 0; + tensorflow::gtl::ArraySlice arguments) = 0; // Same as ExecuteOnStream(), but runs this executable on multiple // streams. arguments[i] contains the arguments to the execution on // run_options[i]->stream() and the returned value is at index i of the // returned vector. - virtual StatusOr> - ExecuteOnStreams( + virtual StatusOr>> ExecuteOnStreams( tensorflow::gtl::ArraySlice run_options, tensorflow::gtl::ArraySlice< - tensorflow::gtl::ArraySlice> + tensorflow::gtl::ArraySlice> arguments); // Populates `hlo_execution_profile` from `executor`. This is implicit in any @@ -107,13 +97,10 @@ class Executable { // Convenience wrapper for calling Executable::ExecuteOnStream. Sets up a // timer for the execution, sets up HLO profiling if enabled, and fills in the - // given ExecutionProfile if non-null. The ExecuteOnStream overloads have - // different argument types and return types, so this method is templated on - // argument type and return type of the execute function. 
- template - StatusOr ExecuteOnStreamWrapper( + // given ExecutionProfile if non-null. + StatusOr> ExecuteOnStreamWrapper( const ServiceExecutableRunOptions* run_options, ExecutionProfile* profile, - const ArgT& arguments); + tensorflow::gtl::ArraySlice arguments); // Returns the ExecutionProfile from executing on the device. This includes // the number of cycles taken for the computation or the compilation time. @@ -130,9 +117,9 @@ class Executable { "Equality test on this executable is not implemented."); } - const HloProfilePrinter& hlo_profile_printer() const { + const HloProfilePrinterData& hlo_profile_printer_data() const { CHECK(hlo_profiling_enabled()); - return *hlo_profile_printer_; + return *hlo_profile_printer_data_; } const HloProfileIndexMap& hlo_profile_index_map() const { @@ -143,7 +130,9 @@ class Executable { // Returns whether this executable was compiled with HLO profilings support // enabled. If not, the caller should not expect an hlo_execution_profile // passed to ExecuteOnStream above to be populated during execution. - bool hlo_profiling_enabled() const { return hlo_profile_printer_ != nullptr; } + bool hlo_profiling_enabled() const { + return hlo_profile_printer_data_ != nullptr; + } const HloModule& module() const { return *hlo_module_; } @@ -193,70 +182,10 @@ class Executable { // execution. 
int64 execution_count_ = 0; - std::unique_ptr hlo_profile_printer_; + std::unique_ptr hlo_profile_printer_data_; std::unique_ptr hlo_profile_index_map_; }; -template -StatusOr Executable::ExecuteOnStreamWrapper( - const ServiceExecutableRunOptions* run_options, ExecutionProfile* profile, - const ArgT& arguments) { - perftools::gputools::Stream* stream = run_options->stream(); - std::unique_ptr timer; - if (profile != nullptr) { - timer.reset(new perftools::gputools::Timer(stream->parent())); - stream->InitTimer(timer.get()).ThenStartTimer(timer.get()); - } - - VLOG(1) << "enqueueing executable on stream..."; - // If the profiling flag isn't enabled, we pass nullptr as the profile to - // indicate profiling is not requested. - std::unique_ptr profile_ptr = - module_config().debug_options().xla_hlo_profile() && - hlo_profiling_enabled() - ? MakeUnique(&hlo_profile_printer(), - &hlo_profile_index_map()) - : nullptr; - - auto return_value = - ExecuteOnStream(run_options, arguments, profile_ptr.get()); - - if (profile != nullptr) { - VLOG(1) << "enqueueing 'stop timer' and blocking host until done..."; - stream->ThenStopTimer(timer.get()).BlockHostUntilDone(); - VLOG(1) << "done with block-host-until-done"; - - // Merge in run-time profile information from execution_profile. - profile->MergeFrom(execution_profile()); - - // Overall execution time (in nanoseconds) from the executor timer. - profile->set_compute_and_transfer_time_ns(timer->Nanoseconds()); - - // TODO(b/28123297): On GPU we end up including transfer time in - // the compute time this way. Instead, we should get the correct - // value by measuring it. Setting the field here at least lets - // benchmarks provide *some* value for GPU computations. - // - // TODO(b/28447609): The value in compute_and_transfer_time_ns is actually - // the compute time without the transfer time, so this way we get the - // correct compute time. 
We should instead have the correct value for - // compute_and_transfer_time and set compute_time to the compute time. - if (profile->compute_time_ns() == 0) { - profile->set_compute_time_ns(profile->compute_and_transfer_time_ns()); - } - } - - if (profile_ptr != nullptr) { - XLA_LOG_LINES( - tensorflow::INFO, - profile_ptr->ToString(stream->parent()->GetDeviceDescription())); - hlo_graph_dumper::MaybeDumpHloModule(module(), "Service::Execute", - profile_ptr.get()); - } - - return return_value; -} - } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_EXECUTABLE_H_ diff --git a/tensorflow/compiler/xla/service/execution_tracker.cc b/tensorflow/compiler/xla/service/execution_tracker.cc index c225e62e3e11d2d01251b0f92272b0949eff8af1..2f0b9ed2bd98fbea4e67c0a30d5aa41ff6a06979 100644 --- a/tensorflow/compiler/xla/service/execution_tracker.cc +++ b/tensorflow/compiler/xla/service/execution_tracker.cc @@ -39,9 +39,7 @@ AsyncExecution::AsyncExecution(Backend* backend, tensorflow::Status AsyncExecution::BlockUntilDone() const { for (auto& stream : streams_) { - if (!stream->BlockHostUntilDone()) { - return InternalError("failed to block until done"); - } + TF_RETURN_IF_ERROR(stream->BlockHostUntilDone()); } return tensorflow::Status::OK(); } diff --git a/tensorflow/compiler/xla/service/flatten_call_graph.cc b/tensorflow/compiler/xla/service/flatten_call_graph.cc index dfba22a6c4c5cf071c2cd8621643b8da6587ee3b..2b6caa149439a86d6d047605099bc3ff7b295a8e 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph.cc +++ b/tensorflow/compiler/xla/service/flatten_call_graph.cc @@ -26,7 +26,10 @@ namespace xla { namespace { -// Helper to replace the called computation at a while- or call-instruction. +// Helper to replace the called computation at a while-, call-, or +// conditional-instruction. This function replaces exactly one instance of +// 'computation' with 'new_computation' even if 'instruction' calls +// 'computation' more than once. 
void ReplaceCalledComputation(HloInstruction* instruction, HloComputation* computation, HloComputation* new_computation) { @@ -45,6 +48,15 @@ void ReplaceCalledComputation(HloInstruction* instruction, instruction->set_to_apply(new_computation); break; } + case HloOpcode::kConditional: { + if (computation == instruction->true_computation()) { + instruction->set_true_computation(new_computation); + } else { + CHECK_EQ(computation, instruction->false_computation()); + instruction->set_false_computation(new_computation); + } + break; + } default: LOG(FATAL) << "unexpected opcode: " << HloOpcodeString(instruction->opcode()); diff --git a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc index a68e90b7d009890012f94baa790d911871c9c960..d3854b40de3572a60df1ad99d8a4589f59ad7194 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc +++ b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc @@ -223,5 +223,35 @@ TEST_F(FlattenCallGraphTest, FlattenCalls) { EXPECT_EQ(1, b_node.caller_callsites().size()); } +TEST_F(FlattenCallGraphTest, FlattenCallsInConditional) { + auto module = CreateNewModule(); + HloComputation* sub_computation = + module->AddEmbeddedComputation(MakeScalarComputation()); + + // Create entry computation, which is a conditional that has the same + // computation in the true and false branch. 
+ HloComputation::Builder builder(TestName()); + auto pred = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(56.0f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(12.0f))); + builder.AddInstruction(HloInstruction::CreateConditional( + kScalarShape, pred, constant1, sub_computation, constant2, + sub_computation)); + module->AddEntryComputation(builder.Build()); + EXPECT_EQ(2, module->computation_count()); + + TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); + EXPECT_TRUE(result); + std::unique_ptr call_graph = CallGraph::Build(module.get()); + // The true and false computations must now be different. + EXPECT_EQ(3, module->computation_count()); + + const CallGraphNode& sub_node = call_graph->GetNode(sub_computation); + EXPECT_EQ(1, sub_node.caller_callsites().size()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index 74aa77b4f165be76fbc0a8aa1a4a7e90a8e9acec..78dc0ad4fcd167c93f19d0c2b18ea72d666897ef 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -51,83 +51,7 @@ se::Platform::Id GenericTransferManager::PlatformId() const { return platform_id_; } -Status GenericTransferManager::TransferLiteralFromDevice( - se::StreamExecutor* executor, const se::DeviceMemoryBase& source, - const Shape& device_shape, const Shape& literal_shape, Literal* literal) { - VLOG(2) << "transferring literal shape from device: " - << ShapeUtil::HumanString(literal_shape) - << "; device location: " << source.opaque(); - TF_RET_CHECK(ShapeUtil::Compatible(device_shape, literal_shape)); - - // Tuples are a special case and contain one or more shapes inside of them to - // an 
arbitrary nesting depth. - if (device_shape.element_type() == TUPLE) { - *literal->mutable_shape() = literal_shape; - TF_ASSIGN_OR_RETURN( - std::vector element_buffers, - ShallowCopyTupleFromDevice(executor, source, device_shape)); - TF_RET_CHECK(element_buffers.size() == - ShapeUtil::TupleElementCount(device_shape)); - for (int64 i = 0; i < element_buffers.size(); ++i) { - const Shape& element_device_shape = device_shape.tuple_shapes(i); - const Shape& element_literal_shape = literal_shape.tuple_shapes(i); - Literal* element_literal = literal->add_tuple_literals(); - // Recursively call TransferFromDevice to copy over the data in the - // element array. - TF_RETURN_IF_ERROR(TransferLiteralFromDevice( - executor, element_buffers[i], /*device_shape=*/element_device_shape, - /*literal_shape=*/element_literal_shape, element_literal)); - } - return Status::OK(); - } - - *literal->mutable_shape() = device_shape; - literal->Reserve(ShapeUtil::ElementsIn(device_shape)); - TF_RETURN_IF_ERROR(TransferBufferFromDevice( - executor, source, /*size=*/ShapeUtil::ByteSizeOf(device_shape), - /*destination=*/literal->MutableInternalData())); - if (!ShapeUtil::Equal(literal_shape, device_shape)) { - *literal = std::move(*literal->Relayout(literal_shape.layout())); - } - TF_RET_CHECK(ShapeUtil::Equal(literal_shape, literal->shape())); - return Status::OK(); -} - -StatusOr> -GenericTransferManager::ShallowCopyTupleFromDevice( - se::StreamExecutor* executor, const se::DeviceMemoryBase& source, - const Shape& shape) { - TF_RET_CHECK(ShapeUtil::IsTuple(shape)); - - // For devices which use the GenericTransferManager, a tuple is stored as an - // array of pointers to buffers. Copy the contents of the tuple buffer into - // a vector of void* pointers. 
- std::vector element_pointers(ShapeUtil::TupleElementCount(shape), - nullptr); - int64 tuple_size = ShapeUtil::ByteSizeOf(shape, pointer_size_); - auto copy_status = executor->SynchronousMemcpyD2H(source, tuple_size, - element_pointers.data()); - if (!copy_status.ok()) { - return AddStatus( - Status(static_cast(copy_status.code()), - copy_status.error_message()), - "failed transfer of tuple buffer " + ShapeUtil::HumanString(shape)); - } - - // Create a DeviceMemoryBase from each void* pointer. - std::vector destination; - for (size_t i = 0; i < element_pointers.size(); ++i) { - if (element_pointers[i] == nullptr && - !ShapeUtil::HasZeroElements(shape.tuple_shapes(i))) { - return FailedPrecondition("tuple contains nullptr at element %lu", i); - } - destination.emplace_back(element_pointers[i], - GetByteSizeRequirement(shape.tuple_shapes(i))); - } - return std::move(destination); -} - -Status GenericTransferManager::WriteTuplePointersToDevice( +Status GenericTransferManager::WriteSingleTupleIndexTable( perftools::gputools::StreamExecutor* executor, tensorflow::gtl::ArraySlice elements, const Shape& shape, perftools::gputools::DeviceMemoryBase* region) { @@ -145,16 +69,19 @@ StatusOr> GenericTransferManager::TransferLiteralFromDevice( se::StreamExecutor* executor, const ShapedBuffer& device_buffer) { VLOG(2) << "transferring literal from device ordinal " - << executor->device_ordinal() << "; device shape: " - << ShapeUtil::HumanStringWithLayout(device_buffer.shape()) - << "; opaque: " << device_buffer.buffer(/*index=*/{}).opaque(); + << executor->device_ordinal() << "; device buffer: " << device_buffer; TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal()); + // The on-host and on-device shape should always be the same for the generic + // transfer manager. 
+ TF_RET_CHECK(ShapeUtil::Equal(device_buffer.on_device_shape(), + device_buffer.on_host_shape())); + std::unique_ptr literal = - Literal::CreateFromShape(device_buffer.shape()); + Literal::CreateFromShape(device_buffer.on_host_shape()); TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( - device_buffer.shape(), + device_buffer.on_host_shape(), [&](const Shape& subshape, const ShapeIndex& index) -> Status { if (!ShapeUtil::IsTuple(subshape)) { TF_RETURN_IF_ERROR(TransferBufferFromDevice( @@ -162,7 +89,7 @@ GenericTransferManager::TransferLiteralFromDevice( /*source=*/device_buffer.buffer(index), /*size=*/GetByteSizeRequirement(subshape), /*destination=*/ - literal->GetSubliteral(index).MutableInternalData())); + literal->untyped_data(index))); } return Status::OK(); @@ -175,33 +102,39 @@ Status GenericTransferManager::TransferLiteralToDevice( const ShapedBuffer& device_buffer) { const Shape& shape = literal.shape(); VLOG(2) << "transferring literal shape to device: " - << ShapeUtil::HumanString(shape) << "; device location: " - << device_buffer.buffer(/*index=*/{}).opaque(); + << ShapeUtil::HumanString(shape) + << "; device buffer: " << device_buffer; + + // The on-host and on-device shape should always be the same for the generic + // transfer manager. 
+ TF_RET_CHECK(ShapeUtil::Equal(device_buffer.on_device_shape(), + device_buffer.on_host_shape())); - TF_RET_CHECK(ShapeUtil::Compatible(literal.shape(), device_buffer.shape())); + TF_RET_CHECK( + ShapeUtil::Compatible(literal.shape(), device_buffer.on_host_shape())); TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal()); TF_RETURN_IF_ERROR(WriteTupleIndexTables(executor, device_buffer)); return ShapeUtil::ForEachSubshapeWithStatus( - device_buffer.shape(), + device_buffer.on_host_shape(), [&](const Shape& device_subshape, const ShapeIndex& index) -> Status { se::DeviceMemoryBase device_memory = device_buffer.buffer(index); if (ShapeUtil::IsArray(device_subshape)) { TF_RET_CHECK(GetByteSizeRequirement(device_subshape) == device_memory.size()); // Element is array-shaped: transfer array data to device buffer. - const Literal& subliteral = literal.GetSubliteral(index); + const auto subliteral = LiteralView::Create(literal, index); std::unique_ptr relayed_out_literal; const void* source; if (LayoutUtil::Equal(device_subshape.layout(), subliteral.shape().layout())) { - source = subliteral.InternalData(); + source = subliteral.untyped_data(); } else { // Relayout data before transferring. 
relayed_out_literal = subliteral.Relayout(device_subshape.layout(), /*shape_index=*/{}); - source = relayed_out_literal->InternalData(); + source = relayed_out_literal->untyped_data(); } return TransferBufferToDevice( executor, @@ -212,33 +145,6 @@ Status GenericTransferManager::TransferLiteralToDevice( }); } -Status GenericTransferManager::TransferLiteralToDevice( - se::StreamExecutor* executor, const Literal& literal, - se::DeviceMemoryBase* destination) { - const Shape& shape = literal.shape(); - VLOG(2) << "transferring literal shape to device: " - << ShapeUtil::HumanString(shape) - << "; device location: " << destination->opaque(); - - if (ShapeUtil::IsTuple(literal.shape())) { - std::vector tuple_elements_on_device; - for (const Literal& tuple_element : literal.tuple_literals()) { - se::DeviceMemoryBase allocation = executor->AllocateArray( - GetByteSizeRequirement(tuple_element.shape())); - TF_RETURN_IF_ERROR( - TransferLiteralToDevice(executor, tuple_element, &allocation)); - tuple_elements_on_device.push_back(allocation.opaque()); - } - return TransferBufferToDevice( - executor, tuple_elements_on_device.size() * sizeof(void*), - tuple_elements_on_device.data(), destination); - } - - return TransferBufferToDevice(executor, - /*size=*/GetByteSizeRequirement(shape), - /*source=*/literal.InternalData(), destination); -} - Status GenericTransferManager::TransferLiteralToInfeed( se::StreamExecutor* executor, const Literal& literal) { return Unimplemented("Generic transfer to Infeed"); diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h index 50dca6aec5012f0b02cb54846b622f008600e48e..63a7c820cf4e5fbbdf870086a4fb5316ac50d10b 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.h +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h @@ -42,16 +42,6 @@ class GenericTransferManager : public TransferManager { perftools::gputools::Platform::Id PlatformId() const 
override; - Status TransferLiteralFromDevice( - perftools::gputools::StreamExecutor* executor, - const perftools::gputools::DeviceMemoryBase& source, - const Shape& device_shape, const Shape& literal_shape, - Literal* literal) override; - - Status TransferLiteralToDevice( - perftools::gputools::StreamExecutor* executor, const Literal& literal, - perftools::gputools::DeviceMemoryBase* destination) override; - StatusOr> TransferLiteralFromDevice( perftools::gputools::StreamExecutor* executor, const ShapedBuffer& device_buffer) override; @@ -62,9 +52,6 @@ class GenericTransferManager : public TransferManager { Status TransferLiteralToInfeed(perftools::gputools::StreamExecutor* executor, const Literal& literal) override; - Status TransferBufferToInfeed(perftools::gputools::StreamExecutor* executor, - int64 size, const void* source) override; - Status TransferLiteralFromOutfeed( perftools::gputools::StreamExecutor* executor, const Shape& literal_shape, Literal* literal) override; @@ -73,16 +60,13 @@ class GenericTransferManager : public TransferManager { tensorflow::gtl::ArraySlice executors) override; - StatusOr> - ShallowCopyTupleFromDevice( - perftools::gputools::StreamExecutor* executor, - const perftools::gputools::DeviceMemoryBase& source, - const Shape& shape) override; - int64 GetByteSizeRequirement(const Shape& shape) const override; protected: - Status WriteTuplePointersToDevice( + Status TransferBufferToInfeed(perftools::gputools::StreamExecutor* executor, + int64 size, const void* source) override; + + Status WriteSingleTupleIndexTable( perftools::gputools::StreamExecutor* executor, tensorflow::gtl::ArraySlice elements, diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index e57558b5788965214cadf5eab1024860f1a39ca1..9da4fb97fa27a238fead74985cb481a9be1f4a65 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -23,6 +23,15 @@ filegroup( 
load("//tensorflow:tensorflow.bzl", "tf_cc_test") +cc_library( + name = "gpu_constants", + srcs = ["gpu_constants.cc"], + hdrs = ["gpu_constants.h"], + deps = [ + "//tensorflow/compiler/xla:types", + ], +) + cc_library( name = "partition_assignment", srcs = [ @@ -120,9 +129,13 @@ cc_library( hdrs = [ "ir_emitter.h", "ir_emitter_context.h", + "ir_emitter_nested.h", + "ir_emitter_unnested.h", ], deps = [ + ":cudnn_convolution_runner", ":elemental_ir_emitter", + ":gpu_constants", ":gpu_executable", ":hlo_to_ir_bindings", ":ir_emission_utils", @@ -203,6 +216,7 @@ cc_library( srcs = ["buffer_allocations.cc"], hdrs = ["buffer_allocations.h"], deps = [ + ":gpu_constants", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", @@ -217,8 +231,11 @@ cc_library( cc_library( name = "gpu_executable", srcs = [ + "conditional_thunk.cc", "convolution_thunk.cc", "copy_thunk.cc", + "cudnn_batchnorm_thunk.cc", + "fft_thunk.cc", "for_thunk.cc", "gemm_thunk.cc", "gpu_executable.cc", @@ -230,8 +247,11 @@ cc_library( "while_thunk.cc", ], hdrs = [ + "conditional_thunk.h", "convolution_thunk.h", "copy_thunk.h", + "cudnn_batchnorm_thunk.h", + "fft_thunk.h", "for_thunk.h", "gemm_thunk.h", "gpu_executable.h", @@ -245,7 +265,9 @@ cc_library( ], deps = [ ":buffer_allocations", + ":cudnn_convolution_runner", ":infeed_manager", + ":ir_emission_utils", ":partition_assignment", ":stream_assignment", "//tensorflow/compiler/xla:array2d", @@ -269,6 +291,7 @@ cc_library( "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core/platform/default/build_config:cublas_plugin", "//tensorflow/core/platform/default/build_config:cudnn_plugin", + "//tensorflow/core/platform/default/build_config:cufft_plugin", "//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep ], ) @@ -290,9 +313,41 @@ cc_library( ) cc_library( - name = "convolution_folding", - srcs = ["convolution_folding.cc"], - hdrs = 
["convolution_folding.h"], + name = "cudnn_convolution_algorithm_picker", + srcs = ["cudnn_convolution_algorithm_picker.cc"], + hdrs = ["cudnn_convolution_algorithm_picker.h"], + deps = [ + ":cudnn_convolution_runner", + ":gpu_executable", + ":ir_emission_utils", + "//tensorflow/compiler/xla/service:device_memory_allocator", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + +cc_library( + name = "cudnn_convolution_runner", + srcs = ["cudnn_convolution_runner.cc"], + hdrs = ["cudnn_convolution_runner.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/core:stream_executor_no_cuda", + ], +) + +cc_library( + name = "cudnn_convolution_rewriter", + srcs = ["cudnn_convolution_rewriter.cc"], + hdrs = ["cudnn_convolution_rewriter.h"], deps = [ ":ir_emission_utils", "//tensorflow/compiler/xla:literal_util", @@ -306,15 +361,18 @@ cc_library( ) tf_cc_test( - name = "convolution_folding_test", - srcs = ["convolution_folding_test.cc"], + name = "cudnn_convolution_rewriter_test", + srcs = ["cudnn_convolution_rewriter_test.cc"], deps = [ - ":convolution_folding", + ":cudnn_convolution_rewriter", + ":ir_emission_utils", + "//tensorflow/compiler/xla:test", "//tensorflow/compiler/xla:test_helpers", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_matchers", "//tensorflow/compiler/xla/service:shape_inference", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:test", ], ) @@ -427,15 +485,18 @@ cc_library( srcs = 
["gpu_compiler.cc"], hdrs = ["gpu_compiler.h"], deps = [ - ":convolution_folding", + ":cudnn_convolution_algorithm_picker", + ":cudnn_convolution_rewriter", ":fusion_merger", + ":gpu_constants", ":gpu_copy_insertion", ":gpu_executable", + ":gpu_hlo_support_checker", + ":gpu_layout_assignment", ":hlo_schedule", ":instruction_fusion", ":ir_emission_utils", ":ir_emitter", - ":layout_assignment", ":pad_insertion", ":partition_assignment", ":stream_assignment", @@ -445,16 +506,18 @@ cc_library( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/service:algebraic_simplifier", - "//tensorflow/compiler/xla/service:batchnorm_rewriter", + "//tensorflow/compiler/xla/service:batchnorm_expander", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:buffer_liveness", "//tensorflow/compiler/xla/service:call_inliner", + "//tensorflow/compiler/xla/service:dot_decomposer", "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:flatten_call_graph", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_constant_folding", "//tensorflow/compiler/xla/service:hlo_cse", "//tensorflow/compiler/xla/service:hlo_dce", + "//tensorflow/compiler/xla/service:hlo_element_type_converter", "//tensorflow/compiler/xla/service:hlo_pass", "//tensorflow/compiler/xla/service:hlo_pass_pipeline", "//tensorflow/compiler/xla/service:hlo_proto", @@ -467,11 +530,14 @@ cc_library( "//tensorflow/compiler/xla/service:transpose_folding", "//tensorflow/compiler/xla/service:tuple_simplifier", "//tensorflow/compiler/xla/service:while_loop_simplifier", + "//tensorflow/compiler/xla/service:zero_sized_hlo_elimination", + "//tensorflow/compiler/xla/service/gpu:cudnn_batchnorm_rewriter", "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:cuda_libdevice_path", "//tensorflow/core:lib", 
"//tensorflow/core:lib_internal", + "//tensorflow/core:regexp_internal", "//tensorflow/core:stream_executor_no_cuda", "@llvm//:core", "@llvm//:support", @@ -479,6 +545,18 @@ cc_library( alwayslink = True, # Contains compiler registration ) +cc_library( + name = "cudnn_batchnorm_rewriter", + srcs = ["cudnn_batchnorm_rewriter.cc"], + hdrs = ["cudnn_batchnorm_rewriter.h"], + deps = [ + ":ir_emission_utils", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_pass", + ], +) + cc_library( name = "infeed_manager", srcs = ["infeed_manager.cc"], @@ -492,9 +570,9 @@ cc_library( ) cc_library( - name = "layout_assignment", - srcs = ["layout_assignment.cc"], - hdrs = ["layout_assignment.h"], + name = "gpu_layout_assignment", + srcs = ["gpu_layout_assignment.cc"], + hdrs = ["gpu_layout_assignment.h"], deps = [ ":ir_emission_utils", "//tensorflow/compiler/xla:shape_util", @@ -508,17 +586,18 @@ cc_library( ) tf_cc_test( - name = "layout_assignment_test", - srcs = ["layout_assignment_test.cc"], + name = "gpu_layout_assignment_test", + srcs = ["gpu_layout_assignment_test.cc"], deps = [ - ":layout_assignment", + ":gpu_layout_assignment", + ":ir_emission_utils", "//tensorflow/compiler/xla:shape_layout", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # build_cleaner: keep ], ) @@ -586,6 +665,32 @@ tf_cc_test( ], ) +cc_library( + name = "gpu_hlo_support_checker", + srcs = ["gpu_hlo_support_checker.cc"], + hdrs = ["gpu_hlo_support_checker.h"], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo_pass", + "//tensorflow/core:lib", 
+ ], +) + +tf_cc_test( + name = "gpu_hlo_support_checker_test", + srcs = ["gpu_hlo_support_checker_test.cc"], + deps = [ + ":gpu_hlo_support_checker", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc index 9fdf717b5d463010e2709b6209c070f25555de72..2029c303d47e9a62135b003c3bd9be6f8b3438d4 100644 --- a/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc +++ b/tensorflow/compiler/xla/service/gpu/buffer_allocations.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" @@ -48,6 +49,15 @@ StatusOr> BufferAllocations::Builder::Build( // If buffer #i's address is already registered (e.g. external arguments or // result buffers), use that registered buffer. 
if (registered_buffers_.count(i)) { + se::DeviceMemoryBase address = FindOrDie(registered_buffers_, i); + if (reinterpret_cast(address.opaque()) % + kCudaMallocAlignBytes != + 0) { + return InternalError( + "Address of registered buffer %lld must be a multiple of %llx, but " + "was %p", + i, kCudaMallocAlignBytes, address.opaque()); + } buffer_allocations->SetBuffer(i, FindOrDie(registered_buffers_, i)); continue; } @@ -67,6 +77,14 @@ StatusOr> BufferAllocations::Builder::Build( tensorflow::strings::HumanReadableNumBytes(buffer_size).c_str(), i); } + if (reinterpret_cast(buffer_address.opaque()) % + kCudaMallocAlignBytes != + 0) { + return InternalError( + "Address returned by memory_allocator->Allocate must be a " + "multiple of %llx, but was %p", + kCudaMallocAlignBytes, buffer_address.opaque()); + } } buffer_allocations->SetBuffer(i, buffer_address); if (allocation.IsPreallocatedTempBuffer()) { @@ -80,6 +98,14 @@ StatusOr> BufferAllocations::Builder::Build( } } + if (VLOG_IS_ON(2)) { + for (BufferAllocation::Index i = 0; i < num_buffers; ++i) { + const auto& buf = buffer_allocations->buffers_[i]; + VLOG(2) << "Buffer " << i << " -> " << buf.opaque() << " (" << buf.size() + << "B)"; + } + } + return std::move(buffer_allocations); } diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc new file mode 100644 index 0000000000000000000000000000000000000000..790ca535b11ee47724ef6227de40726d940d6153 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.cc @@ -0,0 +1,72 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h" + +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { +namespace gpu { + +ConditionalThunk::ConditionalThunk( + const BufferAllocation::Slice& predicate_buffer_index, + const BufferAllocation::Slice& true_operand_buffer_index, + const BufferAllocation::Slice& false_operand_buffer_index, + ThunkSequence true_thunk_sequence, ThunkSequence false_thunk_sequence, + const HloInstruction* hlo) + : Thunk(Kind::kConditional, hlo), + predicate_buffer_index_(predicate_buffer_index), + true_operand_buffer_index_(true_operand_buffer_index), + false_operand_buffer_index_(false_operand_buffer_index), + true_thunk_(std::move(true_thunk_sequence), hlo), + false_thunk_(std::move(false_thunk_sequence), hlo) {} + +Status ConditionalThunk::Initialize(const GpuExecutable& executable) { + TF_RETURN_IF_ERROR(true_thunk_.Initialize(executable)); + TF_RETURN_IF_ERROR(false_thunk_.Initialize(executable)); + return Status::OK(); +} + +Status ConditionalThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) { + // Copy the predicate value from device. 
+ bool predicate; + perftools::gputools::DeviceMemoryBase predicate_address = + buffer_allocations.GetDeviceAddress(predicate_buffer_index_); + stream->ThenMemcpy(&predicate, predicate_address, sizeof(bool)); + + Status block_status = stream->BlockHostUntilDone(); + if (!block_status.ok()) { + return InternalError("Failed to retrieve predicate value on stream %p: %s.", + stream, block_status.error_message().c_str()); + } + + // Execute the true or the false computation depending on the value of the + // predicate. + if (predicate) { + TF_RETURN_IF_ERROR(true_thunk_.ExecuteOnStream(buffer_allocations, stream)); + } else { + TF_RETURN_IF_ERROR( + false_thunk_.ExecuteOnStream(buffer_allocations, stream)); + } + + return Status::OK(); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/conditional_thunk.h b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h new file mode 100644 index 0000000000000000000000000000000000000000..7725c46a3b4b51af34a4dd977885353ff32c21f6 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/conditional_thunk.h @@ -0,0 +1,65 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONDITIONAL_THUNK_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONDITIONAL_THUNK_H_ + +#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +// ConditionalThunk implements the conditional instruction on GPU by reading the +// predicate of the conditional and executing the true or the false computation +// depending on the value of the predicate. +// +// ConditionalThunk assumes that the buffers of the conditional result and the +// result of the true and false computations share the same allocation. Also, +// the buffers of the true operand of the conditional and that of the parameter +// instruction of the true computation share the same allocation. Similarly, the +// buffers of the false operand and that of the parameter instruction of the +// false computation share the same allocation. 
+class ConditionalThunk : public Thunk { + public: + ConditionalThunk(const BufferAllocation::Slice& predicate_buffer_index, + const BufferAllocation::Slice& true_operand_buffer_index, + const BufferAllocation::Slice& false_operand_buffer_index, + ThunkSequence true_thunk_sequence, + ThunkSequence false_thunk_sequence, + const HloInstruction* hlo); + + ConditionalThunk(const ConditionalThunk&) = delete; + ConditionalThunk& operator=(const ConditionalThunk&) = delete; + + Status Initialize(const GpuExecutable& executable) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) override; + + private: + BufferAllocation::Slice predicate_buffer_index_; + BufferAllocation::Slice true_operand_buffer_index_; + BufferAllocation::Slice false_operand_buffer_index_; + SequentialThunk true_thunk_; + SequentialThunk false_thunk_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONDITIONAL_THUNK_H_ diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index 037eec8ef59e1aeccdfc43dbb5c1a852403780d1..15bba49b73bce8eb4a18175f8874f05049119458 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -36,364 +37,70 @@ using se::dnn::DataLayout; using se::dnn::FilterDescriptor; using se::dnn::FilterLayout; -ConvolveScratchAllocator::ConvolveScratchAllocator( - int device_ordinal, DeviceMemoryAllocator* memory_allocator) - : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {} - -ConvolveScratchAllocator::~ConvolveScratchAllocator() { - for (auto& allocated_buffer : allocated_buffers_) { - if (!memory_allocator_->Deallocate(device_ordinal_, &allocated_buffer) - .ok()) { - // The program can still continue with failed deallocation. - LOG(ERROR) << "Failed to deallocate the allocated buffer: " - << allocated_buffer.opaque(); - } - } -} - -int64 ConvolveScratchAllocator::GetMemoryLimitInBytes(se::Stream* stream) { - constexpr int64 kConvolveScratchSize = 1LL << 32; // 4GB by default. 
- return kConvolveScratchSize; -} - -se::port::StatusOr> -ConvolveScratchAllocator::AllocateBytes(se::Stream* stream, int64 byte_size) { - CHECK_GE(byte_size, 0) << "byte_size must be positive."; - if (byte_size > GetMemoryLimitInBytes(stream)) { - return se::port::Status( - se::port::error::RESOURCE_EXHAUSTED, - tensorflow::strings::Printf( - "Allocating %lld bytes exceeds the memory limit of %lld bytes.", - byte_size, GetMemoryLimitInBytes(stream))); - } - - auto status_or_memory = - memory_allocator_->Allocate(device_ordinal_, byte_size, - /*retry_on_failure=*/false); - if (!status_or_memory.ok()) { - return se::port::Status(se::port::error::RESOURCE_EXHAUSTED, - tensorflow::strings::Printf( - "Failed to allocate %lld bytes on device %d.", - byte_size, device_ordinal_)); - } - se::DeviceMemoryBase allocated_buffer = status_or_memory.ValueOrDie(); - allocated_buffers_.push_back(allocated_buffer); - total_allocated_bytes_ += byte_size; - return se::DeviceMemory(allocated_buffer); -} - -string ConvolutionKindToString( - ConvolutionThunk::ConvolutionKind convolution_kind) { - switch (convolution_kind) { - case ConvolutionThunk::ConvolutionKind::kForward: - return "forward"; - case ConvolutionThunk::ConvolutionKind::kBackwardFilter: - return "backward_filter"; - case ConvolutionThunk::ConvolutionKind::kBackwardInput: - return "backward_input"; - } - return "unknown convolution kind"; -} - ConvolutionThunk::ConvolutionThunk( - ConvolutionKind convolution_kind, - const BufferAllocation::Slice& input_buffer, + CudnnConvKind convolution_kind, const BufferAllocation::Slice& input_buffer, const BufferAllocation::Slice& filter_buffer, - const BufferAllocation::Slice& output_buffer, const Shape& input_shape, + const BufferAllocation::Slice& output_buffer, + const BufferAllocation::Slice& tuple_result_buffer, + const BufferAllocation::Slice& scratch_buffer, const Shape& input_shape, const Shape& filter_shape, const Shape& output_shape, const Window& window, - const 
ConvolutionDimensionNumbers& dim_nums, const HloInstruction* hlo) + const ConvolutionDimensionNumbers& dim_nums, int64 algorithm, + bool tensor_ops_enabled, const HloInstruction* hlo) : Thunk(Kind::kConvolution, hlo), convolution_kind_(convolution_kind), input_buffer_(input_buffer), filter_buffer_(filter_buffer), output_buffer_(output_buffer), + tuple_result_buffer_(tuple_result_buffer), + scratch_buffer_(scratch_buffer), input_shape_(input_shape), filter_shape_(filter_shape), output_shape_(output_shape), window_(window), - dim_nums_(dim_nums) {} + dim_nums_(dim_nums), + algorithm_(algorithm), + tensor_ops_enabled_(tensor_ops_enabled) {} -tensorflow::Status ConvolutionThunk::ExecuteOnStream( +Status ConvolutionThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, se::Stream* stream) { - VLOG(3) << "Convolution kind: " << ConvolutionKindToString(convolution_kind_); - VLOG(3) << "input shape: { " << input_shape_.ShortDebugString() << " }"; - VLOG(3) << "filter shape: { " << filter_shape_.ShortDebugString() << " }"; - VLOG(3) << "Output shape: { " << output_shape_.ShortDebugString() << " }"; - VLOG(3) << "Dim nums: { " << dim_nums_.ShortDebugString() << " }"; - VLOG(3) << "Window: { " << window_.ShortDebugString() << " }"; - - const int num_dimensions = window_.dimensions_size(); - CHECK_LE(num_dimensions, 3); - // cuDNN does not support 1D convolutions. We therefore express 1D - // convolutions as 2D convolutions where the first spatial dimension is 1. - // This matches the behavior of TF (see definition of conv1d in - // tensorflow/python/ops/nn_ops.py). 
- const int effective_num_dimensions = std::max(2, num_dimensions); - - CHECK_EQ(F32, output_shape_.element_type()); - CHECK_EQ(num_dimensions, dim_nums_.input_spatial_dimensions_size()); - CHECK_EQ(num_dimensions, dim_nums_.kernel_spatial_dimensions_size()); - CHECK_EQ(num_dimensions, dim_nums_.output_spatial_dimensions_size()); - for (const WindowDimension& dim : window_.dimensions()) { - CHECK_EQ(dim.padding_low(), dim.padding_high()); - } - - // cuDNN's convolution APIs support the BDYX layout for activations/output and - // the OIYX layout for weights. - BatchDescriptor input_descriptor(effective_num_dimensions); - input_descriptor.set_layout(DataLayout::kBatchDepthYX) - .set_feature_map_count( - input_shape_.dimensions(dim_nums_.input_feature_dimension())) - .set_count(input_shape_.dimensions(dim_nums_.input_batch_dimension())); - for (int dim = 0; dim < num_dimensions; ++dim) { - // Note that the dimensions are reversed. The same holds below. - input_descriptor.set_spatial_dim( - static_cast(effective_num_dimensions - dim - 1), - input_shape_.dimensions(dim_nums_.input_spatial_dimensions(dim))); - } - - FilterDescriptor filter_descriptor(effective_num_dimensions); - filter_descriptor.set_layout(FilterLayout::kOutputInputYX) - .set_input_feature_map_count( - filter_shape_.dimensions(dim_nums_.kernel_input_feature_dimension())) - .set_output_feature_map_count(filter_shape_.dimensions( - dim_nums_.kernel_output_feature_dimension())); - for (int dim = 0; dim < num_dimensions; ++dim) { - filter_descriptor.set_spatial_dim( - static_cast(effective_num_dimensions - dim - 1), - filter_shape_.dimensions(dim_nums_.kernel_spatial_dimensions(dim))); - } - - ConvolutionDescriptor convolution_descriptor(effective_num_dimensions); - for (int dim = 0; dim < num_dimensions; ++dim) { - convolution_descriptor - .set_zero_padding( - static_cast(effective_num_dimensions - dim - 1), - window_.dimensions(dim).padding_low()) - .set_filter_stride( - 
static_cast(effective_num_dimensions - dim - 1), - window_.dimensions(dim).stride()); - } - - BatchDescriptor output_descriptor(effective_num_dimensions); - output_descriptor.set_layout(DataLayout::kBatchDepthYX) - .set_feature_map_count( - output_shape_.dimensions(dim_nums_.output_feature_dimension())) - .set_count(output_shape_.dimensions(dim_nums_.output_batch_dimension())); - for (int dim = 0; dim < num_dimensions; ++dim) { - output_descriptor.set_spatial_dim( - static_cast(effective_num_dimensions - dim - 1), - output_shape_.dimensions(dim_nums_.output_spatial_dimensions(dim))); - } - - // Add a singleton dimension in the 1D convolution case. - if (num_dimensions == 1) { - input_descriptor.set_spatial_dim(static_cast(0), 1); - output_descriptor.set_spatial_dim(static_cast(0), 1); - filter_descriptor.set_spatial_dim(static_cast(0), 1); - convolution_descriptor - .set_zero_padding(static_cast(0), 0) - .set_filter_stride(static_cast(0), 1); - } - se::DeviceMemory input_data( buffer_allocations.GetDeviceAddress(input_buffer_)); se::DeviceMemory filter_data( buffer_allocations.GetDeviceAddress(filter_buffer_)); se::DeviceMemory output_data( buffer_allocations.GetDeviceAddress(output_buffer_)); - return ConvolveWithTune(input_descriptor, input_data, filter_descriptor, - filter_data, output_descriptor, output_data, - convolution_descriptor, buffer_allocations, stream); -} - -tensorflow::Status ConvolutionThunk::Convolve( - const BatchDescriptor& input_descriptor, se::DeviceMemory input_data, - const FilterDescriptor& filter_descriptor, - se::DeviceMemory filter_data, - const BatchDescriptor& output_descriptor, - se::DeviceMemory output_data, - const ConvolutionDescriptor& convolution_descriptor, - const se::dnn::AlgorithmConfig& algorithm_config, se::Stream* stream, - ConvolveScratchAllocator* scratch_allocator, - se::dnn::ProfileResult* profile_result) { - bool launch_ok; - switch (convolution_kind_) { - case ConvolutionKind::kBackwardFilter: - launch_ok = - stream 
- ->ThenConvolveBackwardFilterWithAlgorithm( - input_descriptor, input_data, output_descriptor, output_data, - convolution_descriptor, filter_descriptor, &filter_data, - scratch_allocator, algorithm_config, profile_result) - .ok(); - break; - case ConvolutionKind::kBackwardInput: - launch_ok = stream - ->ThenConvolveBackwardDataWithAlgorithm( - filter_descriptor, filter_data, output_descriptor, - output_data, convolution_descriptor, input_descriptor, - &input_data, scratch_allocator, algorithm_config, - profile_result) - .ok(); - break; - case ConvolutionKind::kForward: - launch_ok = - stream - ->ThenConvolveWithAlgorithm( - input_descriptor, input_data, filter_descriptor, filter_data, - convolution_descriptor, output_descriptor, &output_data, - scratch_allocator, algorithm_config, profile_result) - .ok(); - break; - } - if (launch_ok) { - return tensorflow::Status::OK(); - } - return InternalError( - "Unable to launch convolution for thunk %p with type %s and algorithm " - "(%lld, %lld)", - this, ConvolutionKindToString(convolution_kind_).c_str(), - algorithm_config.algorithm().algo_id(), - algorithm_config.algorithm_no_scratch().algo_id()); -} - -std::vector ConvolutionThunk::GetAlgorithms( - bool with_winograd_nonfused, se::StreamExecutor* stream_exec) const { - std::vector algorithms; - switch (convolution_kind_) { - case ConvolutionKind::kBackwardFilter: - CHECK(stream_exec->GetConvolveBackwardFilterAlgorithms( - with_winograd_nonfused, &algorithms)); - break; - case ConvolutionKind::kBackwardInput: - CHECK(stream_exec->GetConvolveBackwardDataAlgorithms( - with_winograd_nonfused, &algorithms)); - break; - case ConvolutionKind::kForward: - CHECK(stream_exec->GetConvolveAlgorithms(with_winograd_nonfused, - &algorithms)); - break; - } - return algorithms; -} - -static string AlgorithmToString(const se::dnn::AlgorithmDesc& algo) { - if (algo.tensor_ops_enabled()) { - return tensorflow::strings::StrCat(algo.algo_id(), "+TC"); - } - return 
tensorflow::strings::StrCat(algo.algo_id()); -} - -// Determines whether we can safely perform a winograd non-fused convolution for -// the given input and output descriptors. This works around b/68264959, an -// integer overflow in cuDNNv5 and cuDNNv6. -static bool ShouldIncludeWinogradNonfusedAlgo( - const BatchDescriptor& input_descriptor, - const BatchDescriptor& output_descriptor) { - int64 batch = input_descriptor.count(); - int64 in_depths = input_descriptor.feature_map_count(); - int64 in_rows = input_descriptor.height(); - int64 in_cols = input_descriptor.width(); - int64 out_depths = output_descriptor.feature_map_count(); - - int64 total_size = 16 * std::ceil(batch / 16.0) * - std::max(in_depths, out_depths) * in_cols * in_rows * - sizeof(float); - int64 threshold = 1L << 31; - - return total_size < threshold; -} - -tensorflow::Status ConvolutionThunk::ConvolveWithTune( - const BatchDescriptor& input_descriptor, se::DeviceMemory input_data, - const FilterDescriptor& filter_descriptor, - se::DeviceMemory filter_data, - const BatchDescriptor& output_descriptor, - se::DeviceMemory output_data, - const ConvolutionDescriptor& convolution_descriptor, - const BufferAllocations& buffer_allocations, se::Stream* stream) { - // TODO(b/29126320): Try cudnn v5's new auto-tuner when it's rolled out. - if (best_algorithm_.algorithm().is_default()) { - // Auto-tuning either is disabled or only happens in the first run of this - // function. 
- VLOG(2) << "Profiling for best convolution algorithm used for " - "ConvolutionThunk: " - << this; - - bool with_winograd_nonfused = - ShouldIncludeWinogradNonfusedAlgo(input_descriptor, output_descriptor); - - se::dnn::ProfileResult best_result; - se::dnn::ProfileResult best_result_without_scratch; - std::vector algorithms = - GetAlgorithms(with_winograd_nonfused, stream->parent()); - for (auto algorithm : algorithms) { - ConvolveScratchAllocator scratch_allocator( - buffer_allocations.device_ordinal(), - buffer_allocations.memory_allocator()); - se::dnn::ProfileResult profile_result; - VLOG(3) << "Trying algorithm " << AlgorithmToString(algorithm) - << " for ConvolutionThunk: " << this; - bool launch_ok = - Convolve(input_descriptor, input_data, filter_descriptor, filter_data, - output_descriptor, output_data, convolution_descriptor, - se::dnn::AlgorithmConfig(algorithm, algorithm), stream, - &scratch_allocator, &profile_result) - .ok(); - if (launch_ok && profile_result.is_valid()) { - VLOG(3) << "Run of algorithm " << AlgorithmToString(algorithm) - << " for ConvolutionThunk " << this << " succeeded, taking " - << profile_result.elapsed_time_in_ms() - << "ms. (Best result: " << best_result.elapsed_time_in_ms() - << "ms)"; - if (profile_result.elapsed_time_in_ms() < - best_result.elapsed_time_in_ms()) { - best_result = profile_result; - } - if (scratch_allocator.TotalAllocatedBytes() == 0 && - profile_result.elapsed_time_in_ms() < - best_result_without_scratch.elapsed_time_in_ms()) { - best_result_without_scratch = profile_result; - } - } else { - VLOG(3) << "Run of algorithm " << AlgorithmToString(algorithm) - << " for ConvolutionThunk " << this << " failed."; - } - } - - if (best_result.is_valid()) { - best_algorithm_.set_algorithm(best_result.algorithm()); - } else { - LOG(ERROR) << "No convolution algorithm works with profiling. 
Fall back " - "to the default algorithm."; - best_algorithm_.set_algorithm(AlgorithmDesc()); + se::DeviceMemoryBase scratch = + buffer_allocations.GetDeviceAddress(scratch_buffer_); + + se::dnn::AlgorithmConfig algorithm_config( + se::dnn::AlgorithmDesc(algorithm_, tensor_ops_enabled_)); + + TF_RETURN_IF_ERROR(RunCudnnConvolution( + convolution_kind_, input_shape_, filter_shape_, output_shape_, input_data, + filter_data, output_data, scratch, window_, dim_nums_, algorithm_config, + stream)); + + // Figure out which of output/input/filter is the result produced by this op, + // and write the result tuple. + void* result_ptr = [&] { + switch (convolution_kind_) { + case CudnnConvKind::kForward: + return output_data.opaque(); + case CudnnConvKind::kBackwardInput: + return input_data.opaque(); + case CudnnConvKind::kBackwardFilter: + return filter_data.opaque(); } + }(); + void* ptrs[] = {result_ptr, scratch.opaque()}; + se::DeviceMemory tuple_addr( + buffer_allocations.GetDeviceAddress(tuple_result_buffer_)); + stream->ThenMemcpyH2D(ptrs, &tuple_addr); - if (best_result_without_scratch.is_valid()) { - best_algorithm_.set_algorithm_no_scratch( - best_result_without_scratch.algorithm()); - } else { - LOG(ERROR) << "No convolution algorithm without scratch works with " - "profiling. 
Fall back " - "to the default algorithm."; - best_algorithm_.set_algorithm_no_scratch(AlgorithmDesc()); - } - } - - { - VLOG(2) << "Using convolution algorithm (" - << AlgorithmToString(best_algorithm_.algorithm()) << ", " - << AlgorithmToString(best_algorithm_.algorithm_no_scratch()) - << ") for ConvolutionThunk: " << this; - ConvolveScratchAllocator scratch_allocator( - buffer_allocations.device_ordinal(), - buffer_allocations.memory_allocator()); - return Convolve(input_descriptor, input_data, filter_descriptor, - filter_data, output_descriptor, output_data, - convolution_descriptor, best_algorithm_, stream, - &scratch_allocator, nullptr); + if (!stream->ok()) { + return InternalError("ConvolutionThunk::ExecuteOnStream failed."); } + return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h index 5ac5db2f04b6796c6013a7f87dd40b485233baa6..900d9cb6243088b56a1825fb3ab8c06cf8d74726 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h @@ -18,89 +18,60 @@ limitations under the License. #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" #include "tensorflow/compiler/xla/service/gpu/thunk.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" namespace xla { namespace gpu { -// A one-time scratch allocator for forward and backward convolution. The -// scratch buffers allocated are released on destruction. 
-// -// Not thread-safe. -class ConvolveScratchAllocator : public perftools::gputools::ScratchAllocator { - public: - ConvolveScratchAllocator(int device_ordinal, - DeviceMemoryAllocator* memory_allocator); - - ~ConvolveScratchAllocator() override; - - int64 GetMemoryLimitInBytes(perftools::gputools::Stream* stream) override; - - int64 TotalAllocatedBytes() { return total_allocated_bytes_; } - - perftools::gputools::port::StatusOr> - AllocateBytes(perftools::gputools::Stream* stream, int64 byte_size) override; - - private: - const int device_ordinal_; - DeviceMemoryAllocator* memory_allocator_; - std::vector allocated_buffers_; - int64 total_allocated_bytes_ = 0; -}; - // This class stores everything that StreamExecutor needs to launch a BNN // convolution. It is generated by IrEmitter. // // This is thread-compatible. class ConvolutionThunk : public Thunk { public: - // ConvolutionThunk performs one of the following types of convolution. - enum class ConvolutionKind { - kBackwardFilter, // Backward convolution for filter. - kBackwardInput, // Backward convolution for input. - kForward, // Forward convolution. - }; - - // Constructs a thunk for launching a DNN convolution. + // Constructs a thunk for launching a DNN convolution. When run, it will + // write a tuple (result, scratch_memory) into `tuple_result_buffer`. + // + // `algorithm` is a cudnn algorithm number. `algorithm == -1` indicates that + // we should use the default (i.e. baseline) cudnn algorithm. + // + // Note that "output" here doesn't refer to the output from running this + // thunk, but rather to the "output" of a hypothetical forward convolution + // that corresponds to this input+filter+output triple. That is, the result + // generated by this thunk is "output" for forward convs, "input" for + // backward-input convs, and "filter" for backward-filter convs. + // // Semantics of null hlo_instruction argument are as in Thunk. 
- ConvolutionThunk(ConvolutionKind convolution_kind, + ConvolutionThunk(CudnnConvKind convolution_kind, const BufferAllocation::Slice& input_buffer, const BufferAllocation::Slice& filter_buffer, const BufferAllocation::Slice& output_buffer, + const BufferAllocation::Slice& tuple_result_buffer, + const BufferAllocation::Slice& scratch_buffer, const Shape& input_shape, const Shape& filter_shape, const Shape& output_shape, const Window& window, - const ConvolutionDimensionNumbers& dnums, - const HloInstruction* hlo); + const ConvolutionDimensionNumbers& dim_nums, int64 algorithm, + bool tensor_ops_enabled, const HloInstruction* hlo); ConvolutionThunk(const ConvolutionThunk&) = delete; ConvolutionThunk& operator=(const ConvolutionThunk&) = delete; - // Does the convolution for the thunk on "stream". Auto-tuning happens on the - // first run of this function. - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) override; + // Does the convolution for the thunk on "stream". 
+ Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) override; private: - tensorflow::Status ConvolveWithTune( - const perftools::gputools::dnn::BatchDescriptor& input_descriptor, - perftools::gputools::DeviceMemory input_data, - const perftools::gputools::dnn::FilterDescriptor& filter_descriptor, - perftools::gputools::DeviceMemory filter_data, - const perftools::gputools::dnn::BatchDescriptor& output_descriptor, - perftools::gputools::DeviceMemory output_data, - const perftools::gputools::dnn::ConvolutionDescriptor& - convolution_descriptor, - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream); + class ScratchAllocator; - tensorflow::Status Convolve( + Status Convolve( const perftools::gputools::dnn::BatchDescriptor& input_descriptor, perftools::gputools::DeviceMemory input_data, const perftools::gputools::dnn::FilterDescriptor& filter_descriptor, @@ -110,39 +81,27 @@ class ConvolutionThunk : public Thunk { const perftools::gputools::dnn::ConvolutionDescriptor& convolution_descriptor, const perftools::gputools::dnn::AlgorithmConfig& algorithm_config, - perftools::gputools::Stream* stream, - ConvolveScratchAllocator* scratch_allocator, + perftools::gputools::Stream* stream, ScratchAllocator* scratch_allocator, perftools::gputools::dnn::ProfileResult* profile_result); - // Returns the convolve algorithms that can be used for this ConvolutionThunk. - std::vector GetAlgorithms( - bool with_winograd_nonfused, - perftools::gputools::StreamExecutor* stream_exec) const; - - // Fastest cuDNN convolution algorithm for this thunk learned from - // auto-tuning. If auto-tuning is disabled or failed, best_algorithm_ is set - // to the default value indicating cuDNN's convolution will choose - // the best algorithm from some heuristics based on its parameters. 
- perftools::gputools::dnn::AlgorithmConfig best_algorithm_; - - const ConvolutionKind convolution_kind_; + const CudnnConvKind convolution_kind_; const BufferAllocation::Slice input_buffer_; const BufferAllocation::Slice filter_buffer_; const BufferAllocation::Slice output_buffer_; + const BufferAllocation::Slice tuple_result_buffer_; + const BufferAllocation::Slice scratch_buffer_; const Shape input_shape_; const Shape filter_shape_; const Shape output_shape_; const Window window_; - const ConvolutionDimensionNumbers dim_nums_; + int64 algorithm_; + bool tensor_ops_enabled_; }; -string ConvolutionKindToString( - ConvolutionThunk::ConvolutionKind convolution_kind); - } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc new file mode 100644 index 0000000000000000000000000000000000000000..db6924c742e4a949a3e939b6d6659e92c2d1e312 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.cc @@ -0,0 +1,219 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" + +namespace xla { +namespace gpu { +namespace { + +class Visitor : public DfsHloVisitorWithDefault { + public: + explicit Visitor(HloComputation* computation) : computation_(computation) {} + + static bool Run(HloComputation* computation) { + Visitor visitor(computation); + TF_CHECK_OK(computation->Accept(&visitor)); + return visitor.changed_; + } + + Status DefaultAction(HloInstruction* /*hlo_instruction*/) override { + return Status::OK(); + } + + Status HandleBatchNormInference(HloInstruction* batch_norm) override; + Status HandleBatchNormTraining(HloInstruction* batch_norm) override; + Status HandleBatchNormGrad(HloInstruction* batch_norm) override; + + private: + bool changed_ = false; + HloComputation* computation_; +}; + +// cudnn defines CUDNN_BN_MIN_EPSILON = 1e-5 as the minimum acceptable epsilon +// for calls to its batchnorm ops. +bool EpsilonInRange(HloInstruction* batch_norm) { + return batch_norm->epsilon() >= 1e-5; +} + +Status Visitor::HandleBatchNormInference(HloInstruction* batch_norm) { + if (batch_norm->operand(0)->shape().element_type() != F32) { + VLOG(1) << "Not rewriting op with non-F32 element type: " + << batch_norm->ToString(); + return Status::OK(); + } + + // cudnn errors out on zero-sized inputs. 
+ if (ShapeUtil::ElementsIn(batch_norm->operand(0)->shape()) == 0) { + return Status::OK(); + } + + if (!EpsilonInRange(batch_norm)) { + return Status::OK(); + } + + HloInstruction* epsilon = computation_->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon()))); + HloInstruction* feature_index = + computation_->AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR0(batch_norm->feature_index()))); + + std::vector operands(batch_norm->operands().begin(), + batch_norm->operands().end()); + operands.push_back(epsilon); + operands.push_back(feature_index); + + std::unique_ptr libcall = HloInstruction::CreateCustomCall( + batch_norm->shape(), operands, kCudnnBatchNormForwardInferenceCallTarget); + TF_RETURN_IF_ERROR( + computation_->ReplaceWithNewInstruction(batch_norm, std::move(libcall))); + changed_ = true; + return Status::OK(); +} + +Status Visitor::HandleBatchNormTraining(HloInstruction* batch_norm) { + if (batch_norm->operand(0)->shape().element_type() != F32) { + VLOG(1) << "Not rewriting op with non-F32 element type: " + << batch_norm->ToString(); + return Status::OK(); + } + + // cudnn errors out on zero-sized inputs. 
+ if (ShapeUtil::ElementsIn(batch_norm->operand(0)->shape()) == 0) { + return Status::OK(); + } + + if (!EpsilonInRange(batch_norm)) { + return Status::OK(); + } + + HloInstruction* epsilon = computation_->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon()))); + HloInstruction* feature_index = + computation_->AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR0(batch_norm->feature_index()))); + + std::vector operands(batch_norm->operands().begin(), + batch_norm->operands().end()); + operands.push_back(epsilon); + operands.push_back(feature_index); + + HloInstruction* libcall = + computation_->AddInstruction(HloInstruction::CreateCustomCall( + batch_norm->shape(), operands, + kCudnnBatchNormForwardTrainingCallTarget)); + + // The cudnn libcall returns a tuple + // {output, mean, rsqrt(variance + epsilon)}, + // but the batchnorm HLO returns {output, mean, variance}. Fix it up. + HloInstruction* inverse_stddev = + computation_->AddInstruction(HloInstruction::CreateGetTupleElement( + libcall->shape().tuple_shapes(2), libcall, 2)); + HloInstruction* variance_plus_epsilon = + computation_->AddInstruction(HloInstruction::CreateBinary( + inverse_stddev->shape(), HloOpcode::kPower, inverse_stddev, + computation_->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(-2))))); + HloInstruction* variance = + computation_->AddInstruction(HloInstruction::CreateBinary( + variance_plus_epsilon->shape(), HloOpcode::kSubtract, + variance_plus_epsilon, epsilon)); + + // Repackage the results. 
+ std::unique_ptr new_tuple = HloInstruction::CreateTuple({ + computation_->AddInstruction(HloInstruction::CreateGetTupleElement( + libcall->shape().tuple_shapes(0), libcall, 0)), + computation_->AddInstruction(HloInstruction::CreateGetTupleElement( + libcall->shape().tuple_shapes(1), libcall, 1)), + variance, + }); + + TF_RETURN_IF_ERROR(computation_->ReplaceWithNewInstruction( + batch_norm, std::move(new_tuple))); + changed_ = true; + return Status::OK(); +} + +Status Visitor::HandleBatchNormGrad(HloInstruction* batch_norm) { + if (batch_norm->operand(0)->shape().element_type() != F32) { + VLOG(1) << "Not rewriting op with non-F32 element type: " + << batch_norm->ToString(); + return Status::OK(); + } + + // cudnn errors out on zero-sized inputs. + if (ShapeUtil::ElementsIn(batch_norm->operand(0)->shape()) == 0) { + return Status::OK(); + } + + if (!EpsilonInRange(batch_norm)) { + return Status::OK(); + } + + HloInstruction* epsilon = computation_->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon()))); + HloInstruction* feature_index = + computation_->AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR0(batch_norm->feature_index()))); + + // The cudnn libcall expects its input to be rsqrt(variance + epsilon), but + // the batchnorm HLO takes plain variance as input. Fix it up. 
+ HloInstruction* var_plus_epsilon = + computation_->AddInstruction(HloInstruction::CreateBinary( + batch_norm->operand(3)->shape(), HloOpcode::kAdd, + batch_norm->mutable_operand(3), epsilon)); + HloInstruction* inverse_stddev = + computation_->AddInstruction(HloInstruction::CreateBinary( + var_plus_epsilon->shape(), HloOpcode::kPower, var_plus_epsilon, + computation_->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(-.5))))); + + std::vector operands(batch_norm->operands().begin(), + batch_norm->operands().end()); + operands[3] = inverse_stddev; + operands.push_back(epsilon); + operands.push_back(feature_index); + + std::unique_ptr libcall = HloInstruction::CreateCustomCall( + batch_norm->shape(), operands, kCudnnBatchNormBackwardCallTarget); + + TF_RETURN_IF_ERROR( + computation_->ReplaceWithNewInstruction(batch_norm, std::move(libcall))); + changed_ = true; + return Status::OK(); +} + +} // anonymous namespace + +StatusOr CudnnBatchNormRewriter::Run(HloModule* module) { + VLOG(2) << "CudnnBatchNormRewriter::Run(), before:"; + XLA_VLOG_LINES(2, module->ToString()); + + bool changed = false; + for (auto* comp : module->MakeNonfusionComputations()) { + if (Visitor::Run(comp)) { + changed = true; + } + } + + VLOG(2) << "CudnnBatchNormRewriter::Run(), after:"; + XLA_VLOG_LINES(2, module->ToString()); + return changed; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h new file mode 100644 index 0000000000000000000000000000000000000000..e09cde9abf85454c7a020566cd8c2671ae12ffc3 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h @@ -0,0 +1,66 @@ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_BATCHNORM_REWRITER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_BATCHNORM_REWRITER_H_ + +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { +namespace gpu { + +// Rewrites BatchNorm HLOs into calls into cudnn where possible. +// +// A call into cudnn for performing a batchnorm op is represented as a +// CustomCall HLO with custom_call_target equal to one of +// +// - kCudnnBatchNormForwardInferenceCallTarget +// - kCudnnBatchNormForwardTrainingCallTarget, or +// - kCudnnBatchNormBackwardCallTarget. +// +// A CustomCall created by this pass has the same operands as the corresponding +// batchnorm HLO, except the epsilon() and feature_index() properties of the +// batchnorm HLO are converted into proper operands, added to the end of the +// CustomCall's operands list. +// +// The inputs/outputs of the cudnn calls for BatchNormTraining and BatchNormGrad +// do not correspond exactly to the HLOs. In particular, the training cudnn +// call returns 1/sqrt(variance + epsilon), while the HLO returns plain +// variance. Similarly, the grad cudnn call expects 1/sqrt(variance + epsilon) +// as input, whereas the HLO expects plain variance. +// +// This pass adds HLOs in front of / behind the CustomCalls to fix up the +// inputs/outputs as appropriate, and we rely on the AlgebraicSimplifier to +// remove these where possible.
+// +// Currently batchnorm ops over F32s are converted into cudnn calls, so long as +// epsilon is not too small. This pass leaves other batchnorm ops unmodified. +// +// The GPU backend does not implement a lowering for the batchnorm HLOs -- it +// expects them to be lowered to cudnn calls via this pass or to HLO soup via +// BatchNormRewriter. +class CudnnBatchNormRewriter : public HloPassInterface { + public: + tensorflow::StringPiece name() const override { + return "cudnn_batchnorm_rewriter"; + } + StatusOr Run(HloModule* module) override; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_BATCHNORM_REWRITER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc new file mode 100644 index 0000000000000000000000000000000000000000..58d9c8caff31e878487fbef01afce566e6187fd9 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.cc @@ -0,0 +1,285 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h" + +#include + +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +namespace se = ::perftools::gputools; +namespace dnn = se::dnn; + +static std::pair +MakeDescriptors(const Shape& shape, int64 feature_index) { + std::vector logical_to_physical = + LayoutUtil::MakeLogicalToPhysical(shape.layout()); + + auto physical_dim_size = [&](int64 physical_dim) { + return shape.dimensions(LayoutUtil::Major(shape.layout(), physical_dim)); + }; + + // Batchnorm only cares about the location of the depth (aka "feature") dim. + // The other dims are all treated the same. Thus we can use the kBatchDepthYX + // cudnn layout for any XLA shape+layout, even XLA shapes that don't have + // exactly 4 dimensions: We put everything that comes before the feature dim + // into "batch", and everything that comes after the feature dim into "Y". + int64 batch_size = 1; + int64 y_size = 1; + int64 physical_dim; + for (physical_dim = 0; physical_dim != logical_to_physical[feature_index]; + ++physical_dim) { + CHECK_LT(physical_dim, shape.dimensions_size()); + batch_size *= physical_dim_size(physical_dim); + } + ++physical_dim; // Skip the feature dimension. 
+ for (; physical_dim < shape.dimensions_size(); ++physical_dim) { + y_size *= physical_dim_size(physical_dim); + } + + dnn::BatchDescriptor input_desc; + input_desc.set_layout(dnn::DataLayout::kBatchDepthYX) + .set_count(batch_size) + .set_feature_map_count(shape.dimensions(feature_index)) + .set_height(y_size) + .set_width(1); + + dnn::BatchDescriptor scale_offset_desc; + scale_offset_desc.set_layout(dnn::DataLayout::kBatchDepthYX) + .set_feature_map_count(input_desc.feature_map_count()) + .set_height(1) + .set_width(1) + .set_count(1); + + return std::make_pair(input_desc, scale_offset_desc); +} + +CudnnBatchNormForwardInferenceThunk::CudnnBatchNormForwardInferenceThunk( + const BufferAllocation::Slice& operand, + const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset, + const BufferAllocation::Slice& mean, + const BufferAllocation::Slice& variance, float epsilon, int64 feature_index, + const BufferAllocation::Slice& output, const HloInstruction* hlo) + : Thunk(Thunk::Kind::kCudnnBatchNormForwardInference, hlo), + operand_(operand), + scale_(scale), + offset_(offset), + mean_(mean), + variance_(variance), + epsilon_(epsilon), + feature_index_(feature_index), + output_(output) { + CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall); + CHECK_EQ(hlo->custom_call_target(), + kCudnnBatchNormForwardInferenceCallTarget); + CHECK( + LayoutUtil::LayoutsInShapesEqual(hlo->shape(), hlo->operand(0)->shape())); + CHECK_EQ(hlo->shape().element_type(), F32) << "Not yet implemented"; +} + +Status CudnnBatchNormForwardInferenceThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, se::Stream* stream) { + dnn::BatchDescriptor operand_desc; + dnn::BatchDescriptor scale_offset_desc; + std::tie(operand_desc, scale_offset_desc) = + MakeDescriptors(hlo_instruction()->shape(), feature_index_); + + se::DeviceMemory output(buffer_allocations.GetDeviceAddress(output_)); + stream->ThenBatchNormalizationForward( + 
se::DeviceMemory(buffer_allocations.GetDeviceAddress(operand_)), + se::DeviceMemory(buffer_allocations.GetDeviceAddress(scale_)), + se::DeviceMemory(buffer_allocations.GetDeviceAddress(offset_)), + se::DeviceMemory(buffer_allocations.GetDeviceAddress(mean_)), + se::DeviceMemory(buffer_allocations.GetDeviceAddress(variance_)), + operand_desc, // + scale_offset_desc, // + epsilon_, // + &output, // + /*batch_mean=*/nullptr, // + /*batch_var=*/nullptr, // + /*saved_mean=*/nullptr, // + /*saved_inv_var=*/nullptr, // + /*is_training=*/false, // + /*var_to_inv_var=*/nullptr, // + /*inv_var_to_var=*/nullptr); + if (!stream->ok()) { + return InternalError("BatchNormalizationForward call failed."); + } + return Status::OK(); +} + +CudnnBatchNormForwardTrainingThunk::CudnnBatchNormForwardTrainingThunk( + const BufferAllocation::Slice& operand, + const BufferAllocation::Slice& scale, const BufferAllocation::Slice& offset, + float epsilon, int64 feature_index, + const BufferAllocation::Slice& output_data, + const BufferAllocation::Slice& output_mean, + const BufferAllocation::Slice& output_inv_stddev, + const BufferAllocation::Slice& output_tuple, const HloInstruction* hlo) + : Thunk(Thunk::Kind::kCudnnBatchNormForwardTraining, hlo), + operand_(operand), + scale_(scale), + offset_(offset), + epsilon_(epsilon), + feature_index_(feature_index), + output_data_(output_data), + output_mean_(output_mean), + output_inv_stddev_(output_inv_stddev), + output_tuple_(output_tuple) { + CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall); + CHECK_EQ(hlo->custom_call_target(), kCudnnBatchNormForwardTrainingCallTarget); + CHECK_EQ(hlo->shape().tuple_shapes_size(), 3); + CHECK(LayoutUtil::LayoutsInShapesEqual(hlo->shape().tuple_shapes(0), + hlo->operand(0)->shape())); + for (const auto& tuple_shape : hlo->shape().tuple_shapes()) { + CHECK_EQ(tuple_shape.element_type(), F32) << "Not yet implemented"; + } +} + +Status CudnnBatchNormForwardTrainingThunk::ExecuteOnStream( + const BufferAllocations& 
buffer_allocations, se::Stream* stream) { + dnn::BatchDescriptor operand_desc; + dnn::BatchDescriptor scale_offset_desc; + // The BatchNormTraining HLO outputs a tuple of three elements: output data, + // batch mean, and batch variance. We want to make our descriptors based on + // the shape of the output data. + std::tie(operand_desc, scale_offset_desc) = MakeDescriptors( + hlo_instruction()->shape().tuple_shapes(0), feature_index_); + + se::DeviceMemory output_data( + buffer_allocations.GetDeviceAddress(output_data_)); + se::DeviceMemory output_mean( + buffer_allocations.GetDeviceAddress(output_mean_)); + se::DeviceMemory output_inv_stddev( + buffer_allocations.GetDeviceAddress(output_inv_stddev_)); + + se::DeviceMemory null_device_ptr(nullptr); + stream->ThenBatchNormalizationForward( + se::DeviceMemory(buffer_allocations.GetDeviceAddress(operand_)), + se::DeviceMemory(buffer_allocations.GetDeviceAddress(scale_)), + se::DeviceMemory(buffer_allocations.GetDeviceAddress(offset_)), + /*estimated_mean=*/null_device_ptr, + /*estimated_variance=*/null_device_ptr, + operand_desc, // + scale_offset_desc, // + epsilon_, // + &output_data, // + /*batch_mean=*/&null_device_ptr, // + /*batch_var=*/&null_device_ptr, // + /*saved_mean=*/&output_mean, // + /*saved_inv_var=*/&output_inv_stddev, // + /*is_training=*/true, // + /*var_to_inv_var=*/nullptr, // + /*inv_var_to_var=*/nullptr); + + // Write the tuple. 
+ void* ptrs[] = {output_data.opaque(), output_mean.opaque(), + output_inv_stddev.opaque()}; + se::DeviceMemory tuple_addr( + buffer_allocations.GetDeviceAddress(output_tuple_)); + stream->ThenMemcpyH2D(ptrs, &tuple_addr); + + if (!stream->ok()) { + return InternalError("BatchNormalizationTraining call failed."); + } + return Status::OK(); +} + +CudnnBatchNormBackwardThunk::CudnnBatchNormBackwardThunk( + const BufferAllocation::Slice& operand, + const BufferAllocation::Slice& scale, const BufferAllocation::Slice& mean, + const BufferAllocation::Slice& inv_stddev, + const BufferAllocation::Slice& grad_output, float epsilon, + int64 feature_index, const BufferAllocation::Slice& output_grad_data, + const BufferAllocation::Slice& output_grad_scale, + const BufferAllocation::Slice& output_grad_offset, + const BufferAllocation::Slice& output_tuple, const HloInstruction* hlo) + : Thunk(Thunk::Kind::kCudnnBatchNormBackward, hlo), + operand_(operand), + scale_(scale), + mean_(mean), + inv_stddev_(inv_stddev), + grad_output_(grad_output), + epsilon_(epsilon), + feature_index_(feature_index), + output_grad_data_(output_grad_data), + output_grad_scale_(output_grad_scale), + output_grad_offset_(output_grad_offset), + output_tuple_(output_tuple) { + CHECK_EQ(hlo->opcode(), HloOpcode::kCustomCall); + CHECK_EQ(hlo->custom_call_target(), kCudnnBatchNormBackwardCallTarget); + CHECK_EQ(hlo->shape().tuple_shapes_size(), 3); + CHECK(LayoutUtil::LayoutsInShapesEqual(hlo->shape().tuple_shapes(0), + hlo->operand(0)->shape())); + CHECK(LayoutUtil::LayoutsInShapesEqual(hlo->shape().tuple_shapes(0), + hlo->operand(4)->shape())); + for (const auto& tuple_shape : hlo->shape().tuple_shapes()) { + CHECK_EQ(tuple_shape.element_type(), F32) << "Not yet implemented"; + } +} + +Status CudnnBatchNormBackwardThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, se::Stream* stream) { + dnn::BatchDescriptor operand_desc; + dnn::BatchDescriptor scale_offset_desc; + + // This call outputs 
a tuple of three elements: grad data, grad offset, and + // grad scale. We want to make our descriptors based on the shape of the grad + // data. + std::tie(operand_desc, scale_offset_desc) = MakeDescriptors( + hlo_instruction()->shape().tuple_shapes(0), feature_index_); + + se::DeviceMemory output_grad_data( + buffer_allocations.GetDeviceAddress(output_grad_data_)); + se::DeviceMemory output_grad_scale( + buffer_allocations.GetDeviceAddress(output_grad_scale_)); + se::DeviceMemory output_grad_offset( + buffer_allocations.GetDeviceAddress(output_grad_offset_)); + + stream->ThenBatchNormalizationBackward( + se::DeviceMemory( + buffer_allocations.GetDeviceAddress(grad_output_)), + se::DeviceMemory(buffer_allocations.GetDeviceAddress(operand_)), + se::DeviceMemory(buffer_allocations.GetDeviceAddress(scale_)), + se::DeviceMemory(buffer_allocations.GetDeviceAddress(mean_)), + se::DeviceMemory(buffer_allocations.GetDeviceAddress(inv_stddev_)), + operand_desc, scale_offset_desc, epsilon_, &output_grad_data, + &output_grad_scale, &output_grad_offset); + + // Write the output tuple. + void* ptrs[] = {output_grad_data.opaque(), output_grad_scale.opaque(), + output_grad_offset.opaque()}; + se::DeviceMemory tuple_addr( + buffer_allocations.GetDeviceAddress(output_tuple_)); + stream->ThenMemcpyH2D(ptrs, &tuple_addr); + + if (!stream->ok()) { + return InternalError("BatchNormalizationBackward call failed."); + } + return Status::OK(); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h new file mode 100644 index 0000000000000000000000000000000000000000..c5fbb6d8a3912d380172d496d8d35e80dc9f5c71 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h @@ -0,0 +1,145 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_BATCHNORM_THUNK_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_BATCHNORM_THUNK_H_ + +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status.h" + +namespace xla { +namespace gpu { + +// This file contains thunks which call into cudnn to run the various flavors of +// batch normalization: BatchNormInference, BatchNormTraining, and +// BatchNormGrad, known to cudnn as BatchNormForwardInference, +// BatchNormForwardTraining, and BatchNormBackward. +// +// As an alternative to using these thunks, XLA can decompose batchnorm HLOs +// into smaller components using the BatchNormRewriter pass. This can result in +// faster code because those individual components can fuse into their +// inputs/outputs, but it may also be slower if cudnn's batchnorm implementation +// outperforms the code XLA generates for these components. +// +// Currently these thunks require that their inputs are F32s. +// +// Note that these thunks do not take full advantage of the cudnn batchnorm +// functions. 
For example, cudnn lets you bias and/or scale the input/output, +// but these thunks don't currently support that. + +class CudnnBatchNormForwardInferenceThunk : public Thunk { + public: + CudnnBatchNormForwardInferenceThunk(const BufferAllocation::Slice& operand, + const BufferAllocation::Slice& scale, + const BufferAllocation::Slice& offset, + const BufferAllocation::Slice& mean, + const BufferAllocation::Slice& variance, + float epsilon, int64 feature_index, + const BufferAllocation::Slice& output, + const HloInstruction* hlo); + + CudnnBatchNormForwardInferenceThunk( + const CudnnBatchNormForwardInferenceThunk&) = delete; + CudnnBatchNormForwardInferenceThunk& operator=( + const CudnnBatchNormForwardInferenceThunk&) = delete; + + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) override; + + private: + BufferAllocation::Slice operand_; + BufferAllocation::Slice scale_; + BufferAllocation::Slice offset_; + BufferAllocation::Slice mean_; + BufferAllocation::Slice variance_; + float epsilon_; + int64 feature_index_; + BufferAllocation::Slice output_; +}; + +class CudnnBatchNormForwardTrainingThunk : public Thunk { + public: + CudnnBatchNormForwardTrainingThunk( + const BufferAllocation::Slice& operand, + const BufferAllocation::Slice& scale, + const BufferAllocation::Slice& offset, float epsilon, int64 feature_index, + const BufferAllocation::Slice& output_data, + const BufferAllocation::Slice& output_mean, + const BufferAllocation::Slice& output_inv_stddev, + const BufferAllocation::Slice& output_tuple, const HloInstruction* hlo); + + CudnnBatchNormForwardTrainingThunk( + const CudnnBatchNormForwardTrainingThunk&) = delete; + CudnnBatchNormForwardTrainingThunk& operator=( + const CudnnBatchNormForwardTrainingThunk&) = delete; + + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) override; + + private: + BufferAllocation::Slice operand_; + 
BufferAllocation::Slice scale_; + BufferAllocation::Slice offset_; + float epsilon_; + int64 feature_index_; + BufferAllocation::Slice output_data_; + BufferAllocation::Slice output_mean_; + BufferAllocation::Slice output_inv_stddev_; + BufferAllocation::Slice output_tuple_; +}; + +class CudnnBatchNormBackwardThunk : public Thunk { + public: + CudnnBatchNormBackwardThunk(const BufferAllocation::Slice& operand, + const BufferAllocation::Slice& scale, + const BufferAllocation::Slice& mean, + const BufferAllocation::Slice& inv_stddev, + const BufferAllocation::Slice& grad_output, + float epsilon, int64 feature_index, + const BufferAllocation::Slice& output_grad_data, + const BufferAllocation::Slice& output_grad_scale, + const BufferAllocation::Slice& output_grad_offset, + const BufferAllocation::Slice& output_tuple, + const HloInstruction* hlo); + + CudnnBatchNormBackwardThunk(const CudnnBatchNormBackwardThunk&) = delete; + CudnnBatchNormBackwardThunk& operator=(const CudnnBatchNormBackwardThunk&) = + delete; + + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) override; + + private: + BufferAllocation::Slice operand_; + BufferAllocation::Slice scale_; + BufferAllocation::Slice mean_; + BufferAllocation::Slice inv_stddev_; + BufferAllocation::Slice grad_output_; + float epsilon_; + int64 feature_index_; + BufferAllocation::Slice output_grad_data_; + BufferAllocation::Slice output_grad_scale_; + BufferAllocation::Slice output_grad_offset_; + BufferAllocation::Slice output_tuple_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_BATCHNORM_THUNK_H_ diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc new file mode 100644 index 0000000000000000000000000000000000000000..c29aa31d4ee31c88ec6d315480d4258b190bbcff --- /dev/null +++ 
b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.cc @@ -0,0 +1,379 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" +#include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/core/lib/gtl/optional.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace xla { +namespace gpu { +namespace { + +namespace se = perftools::gputools; + +using se::DeviceMemoryBase; +using se::dnn::AlgorithmConfig; +using se::dnn::AlgorithmDesc; +using tensorflow::gtl::nullopt; +using tensorflow::gtl::optional; + +class ScratchAllocator : public se::ScratchAllocator { + public: + ScratchAllocator(int device_ordinal, DeviceMemoryAllocator* memory_allocator) + : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {} + + ~ScratchAllocator() override; + + int64 GetMemoryLimitInBytes(se::Stream* stream) override { + return 1LL << 32; // 4GB. TODO(jlebar): Tune this? 
+ } + int64 TotalAllocatedBytes() { return total_allocated_bytes_; } + + se::port::StatusOr> AllocateBytes( + se::Stream* stream, int64 byte_size) override; + + private: + const int device_ordinal_; + DeviceMemoryAllocator* memory_allocator_; + std::vector allocated_buffers_; + int64 total_allocated_bytes_ = 0; +}; + +ScratchAllocator::~ScratchAllocator() { + for (auto& allocated_buffer : allocated_buffers_) { + if (!memory_allocator_->Deallocate(device_ordinal_, &allocated_buffer) + .ok()) { + // The program can still continue with failed deallocation. + LOG(ERROR) << "Failed to deallocate the allocated buffer: " + << allocated_buffer.opaque(); + } + } +} + +se::port::StatusOr> ScratchAllocator::AllocateBytes( + se::Stream* stream, int64 byte_size) { + CHECK_GE(byte_size, 0) << "byte_size must be positive."; + if (byte_size > GetMemoryLimitInBytes(stream)) { + return se::port::Status( + se::port::error::RESOURCE_EXHAUSTED, + tensorflow::strings::Printf( + "Allocating %lld bytes exceeds the memory limit of %lld bytes.", + byte_size, GetMemoryLimitInBytes(stream))); + } + + auto status_or_memory = + memory_allocator_->Allocate(device_ordinal_, byte_size, + /*retry_on_failure=*/false); + if (!status_or_memory.ok()) { + return se::port::Status(se::port::error::RESOURCE_EXHAUSTED, + tensorflow::strings::Printf( + "Failed to allocate %lld bytes on device %d.", + byte_size, device_ordinal_)); + } + se::DeviceMemoryBase allocated_buffer = status_or_memory.ValueOrDie(); + allocated_buffers_.push_back(allocated_buffer); + total_allocated_bytes_ += byte_size; + return se::DeviceMemory(allocated_buffer); +} + +// Determines whether we can safely perform a winograd non-fused convolution for +// the given input and output shapes. This works around b/68264959, an integer +// overflow in cuDNNv5 and cuDNNv6. +// +// TODO(jlebar): We shouldn't need this check for cuDNNv7. 
+bool ShouldIncludeWinogradNonfusedAlgo( + const Shape& input_shape, const Shape& output_shape, + const ConvolutionDimensionNumbers& dnums) { + int64 batch = input_shape.dimensions(dnums.input_batch_dimension()); + int64 in_depths = input_shape.dimensions(dnums.input_feature_dimension()); + int64 in_rows = input_shape.dimensions(dnums.input_spatial_dimensions(0)); + int64 in_cols = + dnums.input_spatial_dimensions_size() == 1 + ? 1 + : input_shape.dimensions(dnums.input_spatial_dimensions(1)); + int64 out_depths = output_shape.dimensions(dnums.output_feature_dimension()); + + int64 total_size = CeilOfRatio(batch, int64{16}) * + std::max(in_depths, out_depths) * in_cols * in_rows * + sizeof(float); + + const int64 threshold = 1L << 31; + return total_size < threshold; +} + +std::vector GetAlgorithms(CudnnConvKind kind, + bool with_winograd_nonfused, + se::StreamExecutor* stream_exec_) { + std::vector algorithms; + switch (kind) { + case CudnnConvKind::kBackwardFilter: + CHECK(stream_exec_->GetConvolveBackwardFilterAlgorithms( + with_winograd_nonfused, &algorithms)); + break; + case CudnnConvKind::kBackwardInput: + CHECK(stream_exec_->GetConvolveBackwardDataAlgorithms( + with_winograd_nonfused, &algorithms)); + break; + case CudnnConvKind::kForward: + CHECK(stream_exec_->GetConvolveAlgorithms(with_winograd_nonfused, + &algorithms)); + break; + } + + // Remove any algorithms with tensor math enabled. These have lower precision + // than regular algorithms, and we don't yet have a way to turn this on/off in + // XLA. 
+ algorithms.erase(std::remove_if(algorithms.begin(), algorithms.end(), + [&](const AlgorithmDesc& a) { + return a.tensor_ops_enabled(); + }), + algorithms.end()); + + return algorithms; +} + +string AlgorithmToString(const AlgorithmDesc& algo) { + if (algo.tensor_ops_enabled()) { + return tensorflow::strings::StrCat(algo.algo_id(), "+TC"); + } + return tensorflow::strings::StrCat(algo.algo_id()); +} + +string NumBytesToString(int64 bytes) { + return tensorflow::strings::StrCat( + tensorflow::strings::HumanReadableNumBytes(bytes), " (", bytes, "B)"); +} + +} // anonymous namespace + +// We could have caching here so that we don't redo this work for two identical +// convolutions. Unfortunately our cache key would have to be a tuple +// containing the protos passed to this function, and we have no utility for +// hashing protos. We could write our own hash functions, but they'd silently +// break if we ever added a field to one of the protos. Perhaps we could hack +// using the binary-encoded proto as the hash key, on the assumption that two +// protos being binary-equal is a sufficient, if not necessary, condition for +// proper equality. But that would still leave us open to having unnecessary +// cache misses and doing extra work. Overall, caching doesn't seem worth the +// trouble, but we may want to revisit this if we ever find a model where +// caching would speed up compilation a lot. +optional> +CudnnConvolutionAlgorithmPicker::PickBestAlgorithm( + CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, + const Shape& output_shape, const Window& window, + const ConvolutionDimensionNumbers& dnums, HloInstruction* instr) { + // Create a stream for us to do our work on. + se::Stream stream{stream_exec_}; + stream.Init(); + const auto device_ordinal = stream_exec_->device_ordinal(); + + // allocator either points to this->allocator_ or, if that's null, to a + // StreamExecutorMemoryAllocator for stream_exec_. 
+ DeviceMemoryAllocator* allocator; + optional se_allocator; + if (allocator_ != nullptr) { + allocator = allocator_; + } else { + se_allocator.emplace( + stream_exec_->platform(), + tensorflow::gtl::ArraySlice({stream_exec_})); + allocator = &*se_allocator; + } + + // Allocate space for the input, filter, and output of the convolution. We + // use a ScratchAllocator for this instead of calling allocator_ directly so + // that our allocations don't leak. + // + // We don't put any data in these buffers, because (in theory, anyway) the + // speed of a conv isn't affected by the data being convolved. + ScratchAllocator input_output_allocator(device_ordinal, allocator); + se::port::StatusOr input_buf = + input_output_allocator.AllocateBytes(&stream, + ShapeUtil::ByteSizeOf(input_shape)); + se::port::StatusOr filter_buf = + input_output_allocator.AllocateBytes(&stream, + ShapeUtil::ByteSizeOf(filter_shape)); + se::port::StatusOr output_buf = + input_output_allocator.AllocateBytes(&stream, + ShapeUtil::ByteSizeOf(output_shape)); + if (!input_buf.ok() || !filter_buf.ok() || !output_buf.ok()) { + LOG(WARNING) + << "Couldn't allocate space for input/filter/output of convolution " + << instr->ToString() << ". 
Falling back to default algorithm."; + return nullopt; + } + + const bool use_winograd_nonfused = + ShouldIncludeWinogradNonfusedAlgo(input_shape, output_shape, dnums); + se::dnn::ProfileResult best_result; + int64 best_result_bytes_used = 0; + for (const AlgorithmDesc& alg : + GetAlgorithms(kind, use_winograd_nonfused, stream_exec_)) { + ScratchAllocator scratch_allocator(device_ordinal, allocator); + se::dnn::ProfileResult profile_result; + VLOG(3) << "Trying algorithm " << AlgorithmToString(alg) << " for " + << instr->ToString(); + + bool launch_ok = + RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, + se::DeviceMemory(input_buf.ValueOrDie()), + se::DeviceMemory(filter_buf.ValueOrDie()), + se::DeviceMemory(output_buf.ValueOrDie()), + &scratch_allocator, window, dnums, + AlgorithmConfig(alg), &stream, &profile_result) + .ok(); + + if (launch_ok && profile_result.is_valid()) { + int64 scratch_bytes_used = scratch_allocator.TotalAllocatedBytes(); + VLOG(3) << "Run of algorithm " << AlgorithmToString(alg) + << " succeeded, taking " << profile_result.elapsed_time_in_ms() + << "ms and using " << NumBytesToString(scratch_bytes_used) + << " of scratch (Best result: " + << best_result.elapsed_time_in_ms() << "ms, " + << NumBytesToString(best_result_bytes_used) << " of scratch)"; + if (profile_result.elapsed_time_in_ms() < + best_result.elapsed_time_in_ms()) { + best_result = profile_result; + best_result_bytes_used = scratch_bytes_used; + } + } else { + VLOG(3) << "Run of algorithm " << AlgorithmToString(alg) << " failed."; + } + } + if (best_result.is_valid()) { + VLOG(2) << "Best algorithm for " << instr->ToString() << ": " + << AlgorithmToString(best_result.algorithm()) << ", takes " + << best_result.elapsed_time_in_ms() << "ms, and uses " + << best_result_bytes_used << "B of scratch memory."; + return std::make_tuple(best_result.algorithm().algo_id(), + best_result.algorithm().tensor_ops_enabled(), + best_result_bytes_used); + } + + LOG(WARNING) << 
"All algorithms tried for convolution " << instr->ToString() + << " failed. Falling back to default algorithm."; + return nullopt; +} + +StatusOr CudnnConvolutionAlgorithmPicker::RunOnInstruction( + HloInstruction* instr) { + CHECK(IsCustomCallToDnnConvolution(*instr)); + + const auto& call_target = instr->custom_call_target(); + const auto& lhs_shape = instr->operand(0)->shape(); + const auto& rhs_shape = instr->operand(1)->shape(); + const auto& conv_result_shape = instr->shape().tuple_shapes(0); + optional> alg_scratch_and_tc; + if (call_target == kCudnnConvForwardCallTarget) { + alg_scratch_and_tc = PickBestAlgorithm( + CudnnConvKind::kForward, /*input_shape=*/lhs_shape, + /*filter_shape=*/rhs_shape, /*output_shape=*/conv_result_shape, + instr->window(), instr->convolution_dimension_numbers(), instr); + } else if (call_target == kCudnnConvBackwardInputCallTarget) { + alg_scratch_and_tc = PickBestAlgorithm( + CudnnConvKind::kBackwardInput, /*input_shape=*/conv_result_shape, + /*filter_shape=*/rhs_shape, /*output_shape=*/lhs_shape, instr->window(), + instr->convolution_dimension_numbers(), instr); + } else if (call_target == kCudnnConvBackwardFilterCallTarget) { + alg_scratch_and_tc = PickBestAlgorithm( + CudnnConvKind::kBackwardFilter, /*input_shape=*/lhs_shape, + /*filter_shape=*/conv_result_shape, /*output_shape=*/rhs_shape, + instr->window(), instr->convolution_dimension_numbers(), instr); + } else { + LOG(FATAL) << "Unknown custom call target for cudnn conv: " + << instr->ToString(); + } + + if (!alg_scratch_and_tc.has_value()) { + return false; + } + + int64 algorithm; + bool tensor_ops_enabled; + int64 scratch_bytes; + + std::tie(algorithm, tensor_ops_enabled, scratch_bytes) = *alg_scratch_and_tc; + + VLOG(1) << "Setting cudnn conv to use algorithm " << algorithm << " and " + << NumBytesToString(scratch_bytes) + << " of scratch memory: " << instr->ToString() + << " tensor_ops_enabled: " << tensor_ops_enabled; + + // Replace instr with a new CustomCall 
which has the correct algorithm, and + // whose output shape has the appropriate amount of scratch memory. + HloComputation* computation = instr->parent(); + Shape new_call_shape = + ShapeUtil::MakeTupleShape({instr->shape().tuple_shapes(0), + ShapeUtil::MakeShape(U8, {scratch_bytes})}); + HloInstruction* algorithm_hlo = computation->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(algorithm))); + HloInstruction* tensor_ops_enabled_hlo = + computation->AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR0(tensor_ops_enabled))); + + HloInstruction* new_call = + computation->AddInstruction(HloInstruction::CreateCustomCall( + new_call_shape, + {instr->mutable_operand(0), instr->mutable_operand(1), algorithm_hlo, + tensor_ops_enabled_hlo}, + instr->custom_call_target())); + new_call->set_window(instr->window()); + new_call->set_convolution_dimension_numbers( + instr->convolution_dimension_numbers()); + + // Repackage new_call so it has the same shape as the original call, namely + // (conv_result, u8[0]). 
+ HloInstruction* new_tuple = + computation->AddInstruction(HloInstruction::CreateTuple( + {computation->AddInstruction(HloInstruction::CreateGetTupleElement( + new_call_shape.tuple_shapes(0), new_call, 0)), + computation->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({})))})); + + TF_RETURN_IF_ERROR(instr->parent()->ReplaceInstruction(instr, new_tuple)); + return true; +} + +StatusOr CudnnConvolutionAlgorithmPicker::RunOnComputation( + HloComputation* computation) { + std::vector convs; + for (auto* instr : computation->instructions()) { + if (IsCustomCallToDnnConvolution(*instr)) { + convs.push_back(instr); + } + } + + bool changed = false; + for (auto* instr : convs) { + TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(instr)); + changed |= result; + } + return changed; +} + +StatusOr CudnnConvolutionAlgorithmPicker::Run(HloModule* module) { + bool changed = false; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation)); + changed |= result; + } + return changed; +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h new file mode 100644 index 0000000000000000000000000000000000000000..516210ec2e500cf03774d27408300ac3346e7b4f --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h @@ -0,0 +1,62 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_ + +#include "tensorflow/compiler/xla/service/device_memory_allocator.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/core/lib/gtl/optional.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +// Modifies CustomCalls to cudnn convolutions, choosing the best algorithm for +// each and adding explicit scratch space to the CustomCalls. +class CudnnConvolutionAlgorithmPicker : public HloPassInterface { + public: + // If the `allocator` parameter is not null, we will use it to allocate temp + // memory while timing the various convolution algorithms. If it's null, + // we'll use the default allocator on the StreamExecutor. 
+ CudnnConvolutionAlgorithmPicker( + perftools::gputools::StreamExecutor* stream_exec, + DeviceMemoryAllocator* allocator) + : stream_exec_(stream_exec), allocator_(allocator) {} + + tensorflow::StringPiece name() const override { + return "cudnn-convolution-algorithm-picker"; + } + + StatusOr Run(HloModule* module) override; + + private: + StatusOr RunOnComputation(HloComputation* computation); + StatusOr RunOnInstruction(HloInstruction* instr); + tensorflow::gtl::optional> PickBestAlgorithm( + CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, + const Shape& output_shape, const Window& window, + const ConvolutionDimensionNumbers& dnums, HloInstruction* instr); + + perftools::gputools::StreamExecutor* stream_exec_; // never null + DeviceMemoryAllocator* allocator_; // may be null +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_ALGORITHM_PICKER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc similarity index 77% rename from tensorflow/compiler/xla/service/gpu/convolution_folding.cc rename to tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc index 828ae675d7ba60b4cee1c3f5312b069263d5a814..e0c73aa73acb7f3313eb54fb07390cb76590433e 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_folding.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/convolution_folding.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h" #include #include @@ -33,14 +33,32 @@ namespace xla { namespace gpu { namespace { + +bool CanImplementAsCudnnForwardConv(HloInstruction* conv) { + const ConvolutionDimensionNumbers& dnums = + conv->convolution_dimension_numbers(); + if (dnums.input_spatial_dimensions_size() > 3) { + return false; + } + + // CuDNN does not accept zero-element arguments + if (ShapeUtil::HasZeroElements(conv->operand(0)->shape()) || + ShapeUtil::HasZeroElements(conv->operand(1)->shape())) { + return false; + } + + if (window_util::HasWindowReversal(conv->window())) { + return false; + } + return true; +} + // Try to match a backward filter pattern that contains "conv". // Precondition: "conv" is a kConvolution. -std::tuple, Window, - ConvolutionDimensionNumbers> -MatchBackwardFilter(HloInstruction* conv) { +std::tuple MatchBackwardFilter( + HloInstruction* conv) { const auto no_match_result = - std::make_tuple(false, std::vector(), Window(), - ConvolutionDimensionNumbers()); + std::make_tuple(false, Window(), ConvolutionDimensionNumbers()); // Step 1: match the instruction pattern without considering the paddings and // dimension numbers just yet. We may need some generic pattern matcher // similar to third_party/llvm/llvm/include/llvm/IR/PatternMatch.h @@ -55,19 +73,7 @@ MatchBackwardFilter(HloInstruction* conv) { // v v // Convolution // conv - // | - // v - // Transpose (optional if identity transposition) CHECK_EQ(HloOpcode::kConvolution, conv->opcode()); - // If the forward convolution is followed by a transpose, we can fuse the - // transpose into the backward convolution as well. 
- HloInstruction* transpose = nullptr; - if (conv->user_count() == 1) { - HloInstruction* single_user = *conv->users().begin(); - if (single_user->opcode() == HloOpcode::kTranspose) { - transpose = single_user; - } - } // Step 2: match paddings and dimension numbers of the forward convolution. const ConvolutionDimensionNumbers& conv_dnums = @@ -75,6 +81,9 @@ MatchBackwardFilter(HloInstruction* conv) { auto input_batch_dim = conv_dnums.input_batch_dimension(); auto input_feature_dim = conv_dnums.input_feature_dimension(); auto input_spatial_dims = conv_dnums.input_spatial_dimensions(); + auto kernel_input_feature_dim = conv_dnums.kernel_input_feature_dimension(); + auto kernel_output_feature_dim = conv_dnums.kernel_output_feature_dimension(); + auto kernel_spatial_dims = conv_dnums.kernel_spatial_dimensions(); auto output_batch_dim = conv_dnums.output_batch_dimension(); auto output_feature_dim = conv_dnums.output_feature_dimension(); auto output_spatial_dims = conv_dnums.output_spatial_dimensions(); @@ -96,9 +105,14 @@ MatchBackwardFilter(HloInstruction* conv) { VLOG(1) << "Padding low should be non-negative."; return no_match_result; } + if (window_dim.window_reversal()) { + VLOG(1) << "Window reversal field not supported"; + return no_match_result; + } // Padding high will be checked in Step 3. } - if (transpose == nullptr && !window_util::HasWindowDilation(conv->window())) { + if (input_batch_dim == output_batch_dim && + !window_util::HasWindowDilation(conv->window())) { VLOG(1) << conv->ToString() << " is a regular forward convolution. No need " "to fold it to a backward filter convolution."; @@ -169,64 +183,40 @@ MatchBackwardFilter(HloInstruction* conv) { } } - // To make future HLO passes easier, we canonicalize the fused expression by - // adding an identity transposition if it's omitted in the pattern. - if (transpose == nullptr) { - // Create an identity transposition with the same rank as the forward - // convolution. 
- HloComputation* parent_computation = conv->parent(); - std::vector transpose_dimensions(ShapeUtil::Rank(conv->shape())); - std::iota(transpose_dimensions.begin(), transpose_dimensions.end(), 0); - transpose = - parent_computation->AddInstruction(HloInstruction::CreateTranspose( - conv->shape(), conv, transpose_dimensions)); - TF_CHECK_OK(conv->ReplaceAllUsesWith(transpose)); - } - // Restore the dimension numbers of the backward convolution from the forward // convolution. The two activation dimensions are reversed (batch and // feature). ConvolutionDimensionNumbers backward_conv_dnums; backward_conv_dnums.set_input_batch_dimension(input_feature_dim); backward_conv_dnums.set_input_feature_dimension(input_batch_dim); - backward_conv_dnums.set_output_batch_dimension(output_feature_dim); - backward_conv_dnums.set_output_feature_dimension(output_batch_dim); for (int i = 0; i < input_spatial_dims.size(); ++i) { backward_conv_dnums.add_input_spatial_dimensions(input_spatial_dims[i]); } - for (int i = 0; i < output_spatial_dims.size(); ++i) { - backward_conv_dnums.add_output_spatial_dimensions(output_spatial_dims[i]); + backward_conv_dnums.set_output_batch_dimension(kernel_input_feature_dim); + backward_conv_dnums.set_output_feature_dimension(kernel_output_feature_dim); + for (int i = 0; i < kernel_spatial_dims.size(); ++i) { + backward_conv_dnums.add_output_spatial_dimensions(kernel_spatial_dims[i]); } // The dimension numbering of the output of the forward convolution (before // transposition) is the same as that of the activations (according to the // semantics of kConvolution). The batch dimension of the activations should // be treated as the input feature dimension, and the feature dimension should // be treated as the output feature. - // - // The output of the forward convolution needs to be transposed to fit into - // the dimension numbering of the weight gradients. This transposition maps - // dimension i to PositionInContainer(transpose->dimensions(), i). 
- backward_conv_dnums.set_kernel_input_feature_dimension( - PositionInContainer(transpose->dimensions(), output_batch_dim)); - backward_conv_dnums.set_kernel_output_feature_dimension( - PositionInContainer(transpose->dimensions(), output_feature_dim)); + backward_conv_dnums.set_kernel_input_feature_dimension(output_batch_dim); + backward_conv_dnums.set_kernel_output_feature_dimension(output_feature_dim); for (int i = 0; i < output_spatial_dims.size(); ++i) { - backward_conv_dnums.add_kernel_spatial_dimensions( - PositionInContainer(transpose->dimensions(), output_spatial_dims[i])); + backward_conv_dnums.add_kernel_spatial_dimensions(output_spatial_dims[i]); } - return std::make_tuple(true, std::vector({transpose, conv}), - backward_conv_window, backward_conv_dnums); + return std::make_tuple(true, backward_conv_window, backward_conv_dnums); } // Try to match a backward input pattern that contains "conv". // Precondition: "conv" is a kConvolution. -std::tuple, Window, - ConvolutionDimensionNumbers> -MatchBackwardInput(HloInstruction* conv) { +std::tuple MatchBackwardInput( + HloInstruction* conv) { const auto no_match_result = - std::make_tuple(false, std::vector(), Window(), - ConvolutionDimensionNumbers()); + std::make_tuple(false, Window(), ConvolutionDimensionNumbers()); // Match instruction pattern. 
CHECK_EQ(HloOpcode::kConvolution, conv->opcode()); @@ -275,6 +265,10 @@ MatchBackwardInput(HloInstruction* conv) { << " should have no window dilation."; return no_match_result; } + if (window_dim.window_reversal()) { + VLOG(1) << "Window reversal field not supported"; + return no_match_result; + } } const auto& input_spatial_dims = dnums.input_spatial_dimensions(); @@ -395,58 +389,82 @@ MatchBackwardInput(HloInstruction* conv) { dnums.set_kernel_output_feature_dimension( conv->convolution_dimension_numbers().kernel_input_feature_dimension()); - return std::make_tuple(true, - std::vector({conv, reverse_filter}), - new_window, dnums); + return std::make_tuple(true, new_window, dnums); } -} // namespace -StatusOr ConvolutionFolding::Run(HloModule* module) { - HloComputation* entry_computation = module->entry_computation(); - std::vector convs; - for (auto* hlo : entry_computation->instructions()) { - if (hlo->opcode() == HloOpcode::kConvolution) { - convs.push_back(hlo); - } - } +// Tries to rewrite a single convolution into a call to cudnn. 
+StatusOr RunOnInstruction(HloInstruction* conv) { + CHECK_EQ(conv->opcode(), HloOpcode::kConvolution); - bool changed = false; - for (HloInstruction* conv : convs) { + HloInstruction* custom_call = [&]() -> HloInstruction* { bool match; - std::vector hlos_to_fuse; Window window; ConvolutionDimensionNumbers dnums; - std::tie(match, hlos_to_fuse, window, dnums) = MatchBackwardFilter(conv); + + std::tie(match, window, dnums) = MatchBackwardFilter(conv); if (match) { - VLOG(2) << "Fuse instructions"; - for (HloInstruction* hlo_to_fuse : hlos_to_fuse) { - VLOG(2) << " " << hlo_to_fuse->ToString(); - } - HloInstruction* backward_convolution = - entry_computation->CreateFusionInstructionForBackwardConvolution( - hlos_to_fuse, HloInstruction::FusionKind::kConvBackwardFilter, - window, dnums); - VLOG(2) << "to backward filter convolution"; - VLOG(2) << " " << backward_convolution->ToString(); - changed = true; - continue; + return CreateCudnnConvBackwardFilter( + conv->shape(), conv->mutable_operand(0), conv->mutable_operand(1), + window, dnums); } - std::tie(match, hlos_to_fuse, window, dnums) = MatchBackwardInput(conv); + std::tie(match, window, dnums) = MatchBackwardInput(conv); if (match) { - VLOG(2) << "Fuse instructions"; - for (HloInstruction* hlo_to_fuse : hlos_to_fuse) { - VLOG(2) << " " << hlo_to_fuse->ToString(); - } - HloInstruction* backward_convolution = - entry_computation->CreateFusionInstructionForBackwardConvolution( - hlos_to_fuse, HloInstruction::FusionKind::kConvBackwardInput, - window, dnums); - VLOG(2) << "to backward input convolution"; - VLOG(2) << " " << backward_convolution->ToString(); - changed = true; - continue; + // Backward input conv subsumes the conv plus the reverse in operand 1. 
+ HloInstruction* reverse = conv->mutable_operand(1); + CHECK_EQ(reverse->opcode(), HloOpcode::kReverse); + HloInstruction* rhs = reverse->mutable_operand(0); + + return CreateCudnnConvBackwardInput( + conv->shape(), conv->mutable_operand(0), rhs, window, dnums); + } + + // If all else fails, try a forward convolution. + if (CanImplementAsCudnnForwardConv(conv)) { + return CreateCudnnConvForward(conv->shape(), conv->mutable_operand(0), + conv->mutable_operand(1), conv->window(), + conv->convolution_dimension_numbers()); } + + return nullptr; + }(); + + if (custom_call == nullptr) { + return false; + } + + // The CustomCall returns a tuple (conv_result, scratch_memory). Extract out + // the conv result and replace `conv` with it. + TF_RETURN_IF_ERROR(conv->parent()->ReplaceWithNewInstruction( + conv, + HloInstruction::CreateGetTupleElement(conv->shape(), custom_call, 0))); + return true; +} + +// Rewrites the convolutions in the given computation into calls to cudnn. +// Returns true if it made any changes. 
+StatusOr RunOnComputation(HloComputation* computation) { + std::vector convs; + for (auto* hlo : computation->instructions()) { + if (hlo->opcode() == HloOpcode::kConvolution) { + convs.push_back(hlo); + } + } + + bool changed = false; + for (HloInstruction* conv : convs) { + TF_ASSIGN_OR_RETURN(bool result, RunOnInstruction(conv)); + changed |= result; + } + return changed; +} +} // namespace + +StatusOr CudnnConvolutionRewriter::Run(HloModule* module) { + bool changed = false; + for (HloComputation* computation : module->MakeNonfusionComputations()) { + TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation)); + changed |= result; } return changed; } diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h new file mode 100644 index 0000000000000000000000000000000000000000..0c0578d88840fed1d77f7456c9acef27dec380f5 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h @@ -0,0 +1,39 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_REWRITER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_REWRITER_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { +namespace gpu { + +// Rewrites plain convolutions, backwards-filter convolutions, and +// backwards-input convolutions into CustomCall HLOs that call into cuDNN. +class CudnnConvolutionRewriter : public HloPassInterface { + public: + tensorflow::StringPiece name() const override { + return "cudnn-convolution-rewriter"; + } + + StatusOr Run(HloModule* module) override; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_REWRITER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc similarity index 80% rename from tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc rename to tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc index 112c496e1f6bd17f89ac389ccf0256846dfa1971..65588b6aaf24da628ea586eb52c462b78b8daaa7 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter_test.cc @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,23 +13,29 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu/convolution_folding.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/core/platform/test.h" namespace xla { namespace gpu { +namespace { -class ConvolutionFoldingTest : public HloTestBase { +namespace op = xla::testing::opcode_matchers; + +class CudnnConvolutionRewriterTest : public HloTestBase { public: - ConvolutionFoldingTest() { + CudnnConvolutionRewriterTest() { for (int i = 0; i < 2; ++i) { WindowDimension* window_dim = default_conv_window_.add_dimensions(); window_dim->set_size(1); @@ -44,20 +50,21 @@ class ConvolutionFoldingTest : public HloTestBase { // the batch and feature dimension in the activations, and treat the batch // dimension in gradients as the input feature dimension in the filter. // - // TODO(jingyue): Add more tests on NCHW input order which TF also supports. + // TODO(jingyue): Add more tests on NCHW input order, which TF also + // supports. 
tf_default_dnums_for_backward_filter_.set_input_batch_dimension(3); - tf_default_dnums_for_backward_filter_.set_output_batch_dimension(3); tf_default_dnums_for_backward_filter_.set_input_feature_dimension(0); - tf_default_dnums_for_backward_filter_.set_output_feature_dimension(0); tf_default_dnums_for_backward_filter_.add_input_spatial_dimensions(1); - tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(1); tf_default_dnums_for_backward_filter_.add_input_spatial_dimensions(2); - tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(2); tf_default_dnums_for_backward_filter_.set_kernel_input_feature_dimension(0); tf_default_dnums_for_backward_filter_.set_kernel_output_feature_dimension( 3); tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(1); tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(2); + tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(0); + tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(1); + tf_default_dnums_for_backward_filter_.set_output_batch_dimension(2); + tf_default_dnums_for_backward_filter_.set_output_feature_dimension(3); tf_default_dnums_for_backward_input_.set_input_batch_dimension(0); tf_default_dnums_for_backward_input_.set_output_batch_dimension(0); @@ -74,9 +81,8 @@ class ConvolutionFoldingTest : public HloTestBase { } protected: - bool FoldConvolution(HloModule* module) { - ConvolutionFolding convolution_folding; - return convolution_folding.Run(module).ValueOrDie(); + bool RunPass(HloModule* module) { + return CudnnConvolutionRewriter().Run(module).ValueOrDie(); } // A convolution window with stride 1 and zero padding. 
The size fields are @@ -86,7 +92,7 @@ class ConvolutionFoldingTest : public HloTestBase { ConvolutionDimensionNumbers tf_default_dnums_for_backward_input_; }; -TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithoutTranspose) { +TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolve) { HloComputation::Builder builder(TestName()); HloInstruction* activations = builder.AddInstruction(HloInstruction::CreateParameter( @@ -108,14 +114,13 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithoutTranspose) { auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConvolution(module.get())); - EXPECT_EQ(HloOpcode::kFusion, - entry_computation->root_instruction()->opcode()); - EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardFilter == - entry_computation->root_instruction()->fusion_kind()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); } -TEST_F(ConvolutionFoldingTest, +TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolveEquivalentToForwardConvolution) { HloComputation::Builder builder(TestName()); HloInstruction* activations = @@ -135,12 +140,17 @@ TEST_F(ConvolutionFoldingTest, tf_default_dnums_for_backward_filter_)); auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConvolution(module.get())); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); } // Extracted from block35 training. 
-TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithPaddedActivations) { +TEST_F(CudnnConvolutionRewriterTest, + BackwardFilterConvolveWithPaddedActivations) { auto builder = HloComputation::Builder(TestName()); HloInstruction* activations = builder.AddInstruction(HloInstruction::CreateParameter( @@ -155,26 +165,22 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithPaddedActivations) { conv_window.mutable_dimensions(i)->set_padding_low(1); conv_window.mutable_dimensions(i)->set_padding_high(1); } - HloInstruction* convolution = - builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeUtil::MakeShape(F32, {32, 3, 3, 32}), activations, gradients, - conv_window, tf_default_dnums_for_backward_filter_)); - - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(F32, {3, 3, 32, 32}), convolution, {1, 2, 3, 0})); + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeUtil::MakeShape(F32, {32, 3, 3, 32}), activations, gradients, + conv_window, tf_default_dnums_for_backward_filter_)); auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConvolution(module.get())); - EXPECT_EQ(HloOpcode::kFusion, - entry_computation->root_instruction()->opcode()); - EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardFilter == - entry_computation->root_instruction()->fusion_kind()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); } // Extracted from inception v3 training. 
-TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithPaddedGradients) { +TEST_F(CudnnConvolutionRewriterTest, + BackwardFilterConvolveWithPaddedGradients) { auto builder = HloComputation::Builder(TestName()); HloInstruction* activations = builder.AddInstruction(HloInstruction::CreateParameter( @@ -189,25 +195,20 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithPaddedGradients) { conv_window.mutable_dimensions(i)->set_padding_high(-1); conv_window.mutable_dimensions(i)->set_window_dilation(2); } - HloInstruction* convolution = - builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeUtil::MakeShape(F32, {320, 3, 3, 192}), activations, gradients, - conv_window, tf_default_dnums_for_backward_filter_)); - - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(F32, {3, 3, 192, 320}), convolution, {1, 2, 3, 0})); + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeUtil::MakeShape(F32, {320, 3, 3, 192}), activations, gradients, + conv_window, tf_default_dnums_for_backward_filter_)); auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConvolution(module.get())); - EXPECT_EQ(HloOpcode::kFusion, - entry_computation->root_instruction()->opcode()); - EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardFilter == - entry_computation->root_instruction()->fusion_kind()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); } -TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithUnevenPadding) { +TEST_F(CudnnConvolutionRewriterTest, BackwardFilterConvolveWithUnevenPadding) { auto builder = HloComputation::Builder(TestName()); HloInstruction* activations = builder.AddInstruction(HloInstruction::CreateParameter( @@ -222,25 +223,20 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithUnevenPadding) { // Uneven 
padding: padding_low=0, padding_high=1 conv_window.mutable_dimensions(i)->set_padding_high(1); } - HloInstruction* convolution = - builder.AddInstruction(HloInstruction::CreateConvolve( - ShapeUtil::MakeShape(F32, {32, 2, 2, 32}), activations, gradients, - conv_window, tf_default_dnums_for_backward_filter_)); - - builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(F32, {2, 2, 32, 32}), convolution, {1, 2, 3, 0})); + builder.AddInstruction(HloInstruction::CreateConvolve( + ShapeUtil::MakeShape(F32, {32, 2, 2, 32}), activations, gradients, + conv_window, tf_default_dnums_for_backward_filter_)); auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConvolution(module.get())); - EXPECT_EQ(HloOpcode::kFusion, - entry_computation->root_instruction()->opcode()); - EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardFilter == - entry_computation->root_instruction()->fusion_kind()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardFilterCallTarget), 0)); } -TEST_F(ConvolutionFoldingTest, BackwardInputConvolveEvenPadding) { +TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveEvenPadding) { auto builder = HloComputation::Builder(TestName()); HloInstruction* output = builder.AddInstruction(HloInstruction::CreateParameter( @@ -284,14 +280,15 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolveEvenPadding) { auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConvolution(module.get())); - EXPECT_EQ(HloOpcode::kFusion, - entry_computation->root_instruction()->opcode()); - EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardInput == - entry_computation->root_instruction()->fusion_kind()); + EXPECT_TRUE(RunPass(module.get())); + + ASSERT_THAT(entry_computation->root_instruction(), + 
op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); + const HloInstruction* custom_call = + entry_computation->root_instruction()->operand(0); for (int i = 0; i < 2; ++i) { - const WindowDimension& window_dim = - entry_computation->root_instruction()->window().dimensions(i); + const WindowDimension& window_dim = custom_call->window().dimensions(i); // Low padding of the backward input convolution // = kernel_size - 1 - low padding on gradients. EXPECT_EQ(3, window_dim.padding_low()); @@ -303,7 +300,7 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolveEvenPadding) { // Convolve([abc], [x], base_dilation=2) // = Convolve([abc], Reverse([x]), base_dilation=2) // = BackwardInputConvolve([abc], [x], stride=2) -TEST_F(ConvolutionFoldingTest, BackwardInputConvolve1x1Filter) { +TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolve1x1Filter) { auto builder = HloComputation::Builder(TestName()); // NHWC dimension order. HloInstruction* output = @@ -328,17 +325,16 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolve1x1Filter) { auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConvolution(module.get())); - EXPECT_EQ(HloOpcode::kFusion, - entry_computation->root_instruction()->opcode()); - EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardInput == - entry_computation->root_instruction()->fusion_kind()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); } // BackwardInputConvolve([abc], [x], stride=1) is equivalent to // ForwardConvolve([abc], [x], stride=1). No need to fold it into backward input // convolution. -TEST_F(ConvolutionFoldingTest, +TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolve1x1FilterEquivalentToForwardConvolve) { auto builder = HloComputation::Builder(TestName()); // NHWC dimension order. 
@@ -359,8 +355,12 @@ TEST_F(ConvolutionFoldingTest, tf_default_dnums_for_backward_input_)); auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConvolution(module.get())); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT( + entry_computation->root_instruction(), + op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); } // Extracted from Inception V3 training. @@ -377,7 +377,8 @@ TEST_F(ConvolutionFoldingTest, // 20x10x10x192 // // Gradients are padded unevenly. -TEST_F(ConvolutionFoldingTest, BackwardInputConvolveUnevenPaddingOnGradients) { +TEST_F(CudnnConvolutionRewriterTest, + BackwardInputConvolveUnevenPaddingOnGradients) { auto builder = HloComputation::Builder(TestName()); HloInstruction* output = builder.AddInstruction(HloInstruction::CreateParameter( @@ -409,14 +410,14 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolveUnevenPaddingOnGradients) { auto module = CreateNewModule(); HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConvolution(module.get())); - EXPECT_EQ(HloOpcode::kFusion, - entry_computation->root_instruction()->opcode()); - EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardInput == - entry_computation->root_instruction()->fusion_kind()); + EXPECT_TRUE(RunPass(module.get())); + ASSERT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); + const HloInstruction* custom_call = + entry_computation->root_instruction()->operand(0); for (int i = 0; i < 2; ++i) { - const WindowDimension& window_dim = - entry_computation->root_instruction()->window().dimensions(i); + const WindowDimension& window_dim = custom_call->window().dimensions(i); EXPECT_EQ(0, window_dim.padding_low()); EXPECT_EQ(0, window_dim.padding_high()); EXPECT_EQ(2, window_dim.stride()); @@ -425,7 +426,7 @@ 
TEST_F(ConvolutionFoldingTest, BackwardInputConvolveUnevenPaddingOnGradients) { // Similar to BackwardInputConvolveUnevenPadding, but the low padding of the // gradients exceeds kernel_size - 1. Therefore, this pattern cannot be fused. -TEST_F(ConvolutionFoldingTest, BackwardInputConvolveLowPaddingTooLarge) { +TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveLowPaddingTooLarge) { auto builder = HloComputation::Builder(TestName()); HloInstruction* output = builder.AddInstruction(HloInstruction::CreateParameter( @@ -454,8 +455,12 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolveLowPaddingTooLarge) { .ValueOrDie())); auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConvolution(module.get())); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT( + entry_computation->root_instruction(), + op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); } // Extracted from //learning/brain/google/xla/benchmarks/resnet.py @@ -472,7 +477,7 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolveLowPaddingTooLarge) { // // We should fuse BC even though padding on activations is uneven, because // PadInsertion will canonicalize the fusion HLO. -TEST_F(ConvolutionFoldingTest, +TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveUnevenPaddingOnActivations) { auto builder = HloComputation::Builder(TestName()); // The gradients are in NCHW layout. 
@@ -505,13 +510,12 @@ TEST_F(ConvolutionFoldingTest, auto module = CreateNewModule(); const HloComputation* entry_computation = module->AddEntryComputation(builder.Build()); - EXPECT_TRUE(FoldConvolution(module.get())); - const HloInstruction* backward_conv = entry_computation->root_instruction(); - EXPECT_EQ(HloOpcode::kFusion, backward_conv->opcode()); - EXPECT_TRUE(HloInstruction::FusionKind::kConvBackwardInput == - backward_conv->fusion_kind()); + EXPECT_TRUE(RunPass(module.get())); + ASSERT_THAT(entry_computation->root_instruction(), + op::GetTupleElement( + op::CustomCall(kCudnnConvBackwardInputCallTarget), 0)); const WindowDimension& backward_conv_col_dim = - backward_conv->window().dimensions(1); + entry_computation->root_instruction()->operand(0)->window().dimensions(1); EXPECT_EQ(0, backward_conv_col_dim.padding_low()); EXPECT_EQ(1, backward_conv_col_dim.padding_high()); } @@ -527,7 +531,7 @@ TEST_F(ConvolutionFoldingTest, // // We currently don't fuse BC because PadInsertion doesn't support negative // padding on the gradients of backward convolution (b/32744257). -TEST_F(ConvolutionFoldingTest, +TEST_F(CudnnConvolutionRewriterTest, BackwardInputConvolveNegativePaddingHighOnActivations) { auto builder = HloComputation::Builder(TestName()); // The gradients are in NCHW layout. 
@@ -556,9 +560,14 @@ TEST_F(ConvolutionFoldingTest, .ValueOrDie())); auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(FoldConvolution(module.get())); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(RunPass(module.get())); + EXPECT_THAT( + entry_computation->root_instruction(), + op::GetTupleElement(op::CustomCall(kCudnnConvForwardCallTarget), 0)); } +} // anonymous namespace } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc new file mode 100644 index 0000000000000000000000000000000000000000..81695a6c326b922904330f33bc88260729ff67ee --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.cc @@ -0,0 +1,224 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { +namespace gpu { +namespace { + +namespace se = ::perftools::gputools; + +using se::DeviceMemory; +using se::DeviceMemoryBase; +using se::Stream; +using se::dnn::AlgorithmConfig; +using se::dnn::BatchDescriptor; +using se::dnn::ConvolutionDescriptor; +using se::dnn::DataLayout; +using se::dnn::DimIndex; +using se::dnn::FilterDescriptor; +using se::dnn::FilterLayout; +using se::dnn::ProfileResult; + +// A StreamExecutor ScratchAllocator that wraps a single XLA allocation, +// returning it (in its entirety) the first time Allocate() is called. +class ScratchBufAllocator : public se::ScratchAllocator { + public: + explicit ScratchBufAllocator(se::DeviceMemoryBase scratch) + : scratch_(scratch) {} + + ~ScratchBufAllocator() override = default; + + int64 GetMemoryLimitInBytes(se::Stream* /*stream*/) override { + return scratch_.size(); + } + + se::port::StatusOr> AllocateBytes( + se::Stream* stream, int64 byte_size) override { + if (allocated_) { + return se::port::InternalError( + "Can't allocate twice from a ScratchBufAllocator."); + } + if (byte_size > scratch_.size()) { + return se::port::InternalError(tensorflow::strings::StrCat( + "Can't allocate ", byte_size, + " bytes from a ScratchBufAllocator of size ", scratch_.size())); + } + + allocated_ = true; + return se::DeviceMemory(scratch_); + } + + private: + se::DeviceMemoryBase scratch_; + bool allocated_ = false; +}; + +} // anonymous namespace + +string CudnnConvKindToString(CudnnConvKind kind) { + switch (kind) { + case CudnnConvKind::kForward: + return "forward"; + case CudnnConvKind::kBackwardFilter: + return "backward_filter"; + case CudnnConvKind::kBackwardInput: + return "backward_input"; + 
} +} + +Status RunCudnnConvolution(CudnnConvKind kind, const Shape& input_shape, + const Shape& filter_shape, const Shape& output_shape, + DeviceMemory input_buf, + DeviceMemory filter_buf, + DeviceMemory output_buf, + DeviceMemoryBase scratch_buf, const Window& window, + const ConvolutionDimensionNumbers& dnums, + AlgorithmConfig algorithm, Stream* stream, + ProfileResult* profile_result /*= nullptr*/) { + ScratchBufAllocator scratch_allocator(scratch_buf); + return RunCudnnConvolution(kind, input_shape, filter_shape, output_shape, + input_buf, filter_buf, output_buf, + &scratch_allocator, window, dnums, algorithm, + stream, profile_result); +} + +Status RunCudnnConvolution( + CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, + const Shape& output_shape, DeviceMemory input_buf, + DeviceMemory filter_buf, DeviceMemory output_buf, + se::ScratchAllocator* scratch_allocator, const Window& window, + const ConvolutionDimensionNumbers& dnums, AlgorithmConfig algorithm, + Stream* stream, ProfileResult* profile_result /*= nullptr*/) { + VLOG(3) << "Convolution Algorithm: " << algorithm.algorithm().algo_id(); + VLOG(3) << "tensor_ops_enabled: " + << algorithm.algorithm().tensor_ops_enabled(); + VLOG(3) << "Convolution kind: " << CudnnConvKindToString(kind); + VLOG(3) << "input shape: { " << ShapeUtil::HumanString(input_shape) << " }"; + VLOG(3) << "filter shape: { " << ShapeUtil::HumanString(filter_shape) << " }"; + VLOG(3) << "Output shape: { " << ShapeUtil::HumanString(output_shape) << " }"; + VLOG(3) << "Window: { " << window.ShortDebugString() << " }"; + VLOG(3) << "Dim nums: { " << dnums.ShortDebugString() << " }"; + + const int num_dimensions = window.dimensions_size(); + CHECK_LE(num_dimensions, 3); + // cuDNN does not support 1D convolutions. We therefore express 1D + // convolutions as 2D convolutions where the first spatial dimension is 1. 
+ // This matches the behavior of TF (see definition of conv1d in + // tensorflow/python/ops/nn_ops.py). + const int effective_num_dimensions = std::max(2, num_dimensions); + + CHECK_EQ(F32, output_shape.element_type()) + << ShapeUtil::HumanString(output_shape); + CHECK_EQ(num_dimensions, dnums.input_spatial_dimensions_size()); + CHECK_EQ(num_dimensions, dnums.kernel_spatial_dimensions_size()); + CHECK_EQ(num_dimensions, dnums.output_spatial_dimensions_size()); + for (const WindowDimension& dim : window.dimensions()) { + CHECK_EQ(dim.padding_low(), dim.padding_high()); + } + + // cuDNN's convolution APIs support the BDYX layout for activations/output and + // the OIYX layout for weights. + BatchDescriptor input_descriptor(effective_num_dimensions); + input_descriptor.set_layout(DataLayout::kBatchDepthYX) + .set_feature_map_count( + input_shape.dimensions(dnums.input_feature_dimension())) + .set_count(input_shape.dimensions(dnums.input_batch_dimension())); + for (int dim = 0; dim < num_dimensions; ++dim) { + // Note that the dimensions are reversed. The same holds below. 
+ input_descriptor.set_spatial_dim( + static_cast(effective_num_dimensions - dim - 1), + input_shape.dimensions(dnums.input_spatial_dimensions(dim))); + } + + FilterDescriptor filter_descriptor(effective_num_dimensions); + filter_descriptor.set_layout(FilterLayout::kOutputInputYX) + .set_input_feature_map_count( + filter_shape.dimensions(dnums.kernel_input_feature_dimension())) + .set_output_feature_map_count( + filter_shape.dimensions(dnums.kernel_output_feature_dimension())); + for (int dim = 0; dim < num_dimensions; ++dim) { + filter_descriptor.set_spatial_dim( + static_cast(effective_num_dimensions - dim - 1), + filter_shape.dimensions(dnums.kernel_spatial_dimensions(dim))); + } + + ConvolutionDescriptor convolution_descriptor(effective_num_dimensions); + for (int dim = 0; dim < num_dimensions; ++dim) { + convolution_descriptor + .set_zero_padding( + static_cast(effective_num_dimensions - dim - 1), + window.dimensions(dim).padding_low()) + .set_filter_stride( + static_cast(effective_num_dimensions - dim - 1), + window.dimensions(dim).stride()); + } + + BatchDescriptor output_descriptor(effective_num_dimensions); + output_descriptor.set_layout(DataLayout::kBatchDepthYX) + .set_feature_map_count( + output_shape.dimensions(dnums.output_feature_dimension())) + .set_count(output_shape.dimensions(dnums.output_batch_dimension())); + for (int dim = 0; dim < num_dimensions; ++dim) { + output_descriptor.set_spatial_dim( + static_cast(effective_num_dimensions - dim - 1), + output_shape.dimensions(dnums.output_spatial_dimensions(dim))); + } + + // Add a singleton dimension in the 1D convolution case. 
+ if (num_dimensions == 1) { + input_descriptor.set_spatial_dim(static_cast(0), 1); + output_descriptor.set_spatial_dim(static_cast(0), 1); + filter_descriptor.set_spatial_dim(static_cast(0), 1); + convolution_descriptor.set_zero_padding(static_cast(0), 0) + .set_filter_stride(static_cast(0), 1); + } + + switch (kind) { + case CudnnConvKind::kForward: + stream->ThenConvolveWithAlgorithm( + input_descriptor, input_buf, filter_descriptor, filter_buf, + convolution_descriptor, output_descriptor, &output_buf, + scratch_allocator, algorithm, profile_result); + break; + case CudnnConvKind::kBackwardInput: + stream->ThenConvolveBackwardDataWithAlgorithm( + filter_descriptor, filter_buf, output_descriptor, output_buf, + convolution_descriptor, input_descriptor, &input_buf, + scratch_allocator, algorithm, profile_result); + break; + case CudnnConvKind::kBackwardFilter: + stream->ThenConvolveBackwardFilterWithAlgorithm( + input_descriptor, input_buf, output_descriptor, output_buf, + convolution_descriptor, filter_descriptor, &filter_buf, + scratch_allocator, algorithm, profile_result); + break; + } + + if (!stream->ok()) { + return InternalError( + "Unable to launch convolution with type %s and algorithm (%lld, %lld)", + CudnnConvKindToString(kind).c_str(), algorithm.algorithm().algo_id(), + algorithm.algorithm_no_scratch().algo_id()); + } + return Status::OK(); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h new file mode 100644 index 0000000000000000000000000000000000000000..b101f76510c129fd22b246e5f0348848192ecbba --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h @@ -0,0 +1,97 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_RUNNER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_RUNNER_H_ + +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +// This file contains low-level routines for running cudnn convolutions. + +// Different types of convolutions supported by cudnn. +// +// A way to think about these is that a convolution is defined by three arrays +// -- the "input", the "filter", and the "output" -- and given any two of these, +// we can compute the third. For example, a backward-input convolution takes as +// input a filter and an "output" and produces an "input" such that if one were +// to do a forward convolution of "input" using filter, the result would be +// something with the same shape as "output". +// +// This way of thinking is not correct if you look at the values produced. For +// example, a backward-input convolution is not actually the mathematical +// inverse of a forward convolution. But it's right as far as the shapes and +// "connectivity" (i.e. which elements of the input affect which elements of +// the output) are concerned. 
+enum class CudnnConvKind { + kForward, // input + filter => output + kBackwardInput, // filter + output => input + kBackwardFilter, // input + output => filter +}; + +// Converts a CudnnConvKind value to a string. +string CudnnConvKindToString(CudnnConvKind kind); + +// Calls into cudnn to run the specified convolution. +// +// Note that depending on the value of CudnnConvKind, the result of this call +// may be written into input_buf, filter_buf, or output_buf! +// +// At the moment we only support cudnn convolutions over floats. +// +// We provide one overload which takes a scratch buffer, and another which takes +// an allocator which is responsible for allocating the scratch space. In +// theory the second one shouldn't be necessary -- users of this function could +// just ask cudnn how much scratch space it needs for a particular convolution. +// But in practice, StreamExecutor does not expose such an API, and in the name +// of parsimony, perhaps it's better not to add it. Instead, the first time you +// call a convolution, you should call the version that takes a scratch +// allocator and take note of how much memory is used. The next time you call +// the same conv, you can provide an explicitly preallocated scratch buffer of +// that size, if you like. 
+Status RunCudnnConvolution( + CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, + const Shape& output_shape, + perftools::gputools::DeviceMemory input_buf, + perftools::gputools::DeviceMemory filter_buf, + perftools::gputools::DeviceMemory output_buf, + perftools::gputools::DeviceMemoryBase scratch_buf, const Window& window, + const ConvolutionDimensionNumbers& dnums, + perftools::gputools::dnn::AlgorithmConfig algorithm, + perftools::gputools::Stream* stream, + perftools::gputools::dnn::ProfileResult* profile_result = nullptr); + +Status RunCudnnConvolution( + CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape, + const Shape& output_shape, + perftools::gputools::DeviceMemory input_buf, + perftools::gputools::DeviceMemory filter_buf, + perftools::gputools::DeviceMemory output_buf, + perftools::gputools::ScratchAllocator* scratch_allocator, + const Window& window, const ConvolutionDimensionNumbers& dnums, + perftools::gputools::dnn::AlgorithmConfig algorithm, + perftools::gputools::Stream* stream, + perftools::gputools::dnn::ProfileResult* profile_result = nullptr); + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUDNN_CONVOLUTION_RUNNER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 6bf00cfb8a53723ae9608093480bf2eed10144dd..5af7a77ea858563fbea05af8efd54f96a74aee93 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -72,9 +72,27 @@ StatusOr GpuElementalIrEmitter::EmitLibdeviceMathCall( tensorflow::gtl::ArraySlice input_types, PrimitiveType output_type) const { // The libdevice math functions differentiate between "double" and "float" by - // appending an 'f' to the function's name. + // appending an 'f' to the function's name. 
libdevice doesn't have f16 math + // functions, so we convert the operands to f32 before calling the function + // and then convert the result back to f16. string munged_callee = callee_name; + bool cast_result_to_fp16 = false; + std::vector converted_operands(operands.begin(), + operands.end()); + std::vector converted_input_types(input_types.begin(), + input_types.end()); switch (output_type) { + case F16: + cast_result_to_fp16 = true; + for (int64 i = 0; i < operands.size(); ++i) { + if (input_types[i] == F16) { + converted_operands[i] = ir_builder_->CreateFPCast( + converted_operands[i], ir_builder_->getFloatTy()); + converted_input_types[i] = F32; + } + } + output_type = F32; + TF_FALLTHROUGH_INTENDED; case F32: StrAppend(&munged_callee, "f"); break; @@ -84,7 +102,13 @@ StatusOr GpuElementalIrEmitter::EmitLibdeviceMathCall( return Unimplemented("Bad type for libdevice math call: %s", PrimitiveType_Name(output_type).c_str()); } - return EmitMathCall(munged_callee, operands, input_types, output_type); + llvm::Value* result = EmitMathCall(munged_callee, converted_operands, + converted_input_types, output_type) + .ValueOrDie(); + if (cast_result_to_fp16) { + result = ir_builder_->CreateFPCast(result, ir_builder_->getHalfTy()); + } + return result; } StatusOr GpuElementalIrEmitter::EmitLlvmIntrinsicMathCall( @@ -92,10 +116,13 @@ StatusOr GpuElementalIrEmitter::EmitLlvmIntrinsicMathCall( tensorflow::gtl::ArraySlice operands, tensorflow::gtl::ArraySlice input_types, PrimitiveType output_type) const { - // llvm intrinsics differentiate between float/double functions via the ".f32" - // and ".f64" suffixes. + // llvm intrinsics differentiate between half/float/double functions via + // the suffixes ".f16", ".f32" and ".f64". 
string munged_callee = callee_name; switch (output_type) { + case F16: + StrAppend(&munged_callee, ".f16"); + break; case F32: StrAppend(&munged_callee, ".f32"); break; @@ -135,10 +162,6 @@ StatusOr GpuElementalIrEmitter::EmitFloatBinaryOp( PrimitiveType rhs_input_type = op->operand(1)->shape().element_type(); PrimitiveType output_type = op->shape().element_type(); switch (op->opcode()) { - case HloOpcode::kAtan2: - return EmitLibdeviceMathCall("__nv_atan2", {lhs_value, rhs_value}, - {lhs_input_type, rhs_input_type}, - output_type); case HloOpcode::kRemainder: { return EmitLibdeviceMathCall("__nv_fmod", {lhs_value, rhs_value}, {lhs_input_type, rhs_input_type}, @@ -199,29 +222,44 @@ StatusOr GpuElementalIrEmitter::EmitErfcInv( return EmitLibdeviceMathCall("__nv_erfcinv", {value}, {prim_type}, prim_type); } +StatusOr GpuElementalIrEmitter::EmitLog( + PrimitiveType prim_type, llvm::Value* value) const { + return EmitLibdeviceMathCall("__nv_log", {value}, {prim_type}, prim_type); +} + +StatusOr GpuElementalIrEmitter::EmitSin( + PrimitiveType prim_type, llvm::Value* value) const { + return EmitLibdeviceMathCall("__nv_sin", {value}, {prim_type}, prim_type); +} + +StatusOr GpuElementalIrEmitter::EmitCos( + PrimitiveType prim_type, llvm::Value* value) const { + return EmitLibdeviceMathCall("__nv_cos", {value}, {prim_type}, prim_type); +} + +StatusOr GpuElementalIrEmitter::EmitExp( + PrimitiveType prim_type, llvm::Value* value) const { + return EmitLibdeviceMathCall("__nv_exp", {value}, {prim_type}, prim_type); +} + +StatusOr GpuElementalIrEmitter::EmitPow(PrimitiveType prim_type, + llvm::Value* lhs, + llvm::Value* rhs) const { + return EmitLibdeviceMathCall("__nv_pow", {lhs, rhs}, {prim_type, prim_type}, + prim_type); +} + +StatusOr GpuElementalIrEmitter::EmitAtan2( + PrimitiveType prim_type, llvm::Value* lhs, llvm::Value* rhs) const { + return EmitLibdeviceMathCall("__nv_atan2", {lhs, rhs}, {prim_type, prim_type}, + prim_type); +} + StatusOr 
GpuElementalIrEmitter::EmitFloatUnaryOp( const HloInstruction* op, llvm::Value* operand_value) const { PrimitiveType input_type = op->operand(0)->shape().element_type(); PrimitiveType output_type = op->shape().element_type(); switch (op->opcode()) { - case HloOpcode::kExp: - return EmitLibdeviceMathCall("__nv_exp", {operand_value}, {input_type}, - output_type); - case HloOpcode::kFloor: - return EmitLibdeviceMathCall("__nv_floor", {operand_value}, {input_type}, - output_type); - case HloOpcode::kCeil: - return EmitLibdeviceMathCall("__nv_ceil", {operand_value}, {input_type}, - output_type); - case HloOpcode::kLog: - return EmitLibdeviceMathCall("__nv_log", {operand_value}, {input_type}, - output_type); - case HloOpcode::kCos: - return EmitLibdeviceMathCall("__nv_cos", {operand_value}, {input_type}, - output_type); - case HloOpcode::kSin: - return EmitLibdeviceMathCall("__nv_sin", {operand_value}, {input_type}, - output_type); case HloOpcode::kTanh: return EmitLibdeviceMathCall("__nv_tanh", {operand_value}, {input_type}, output_type); @@ -230,224 +268,6 @@ StatusOr GpuElementalIrEmitter::EmitFloatUnaryOp( } } -StatusOr GpuElementalIrEmitter::EmitComplexBinaryOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const { - PrimitiveType input_type = op->operand(0)->shape().element_type(); - TF_RET_CHECK(primitive_util::IsComplexType(input_type)); - PrimitiveType component_type = - primitive_util::ComplexComponentType(input_type); - switch (op->opcode()) { - case HloOpcode::kPower: { - // (a+bi)^(c+di) = - // (a*a+b*b)^(0.5c) * exp(-d*atan2(b,a)) * (cos(q) + i*sin(q)), - // where q = c*atan2(b,a)+0.5d*ln(a*a+b*b) - auto a = EmitExtractReal(lhs_value); - auto b = EmitExtractImag(lhs_value); - auto c = EmitExtractReal(rhs_value); - auto d = EmitExtractImag(rhs_value); - auto aa_p_bb = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a), - ir_builder_->CreateFMul(b, b)); - auto one_half = llvm::ConstantFP::get(a->getType(), 0.5); - auto 
half_c = ir_builder_->CreateFMul(one_half, c); - - TF_ASSIGN_OR_RETURN( - auto aa_p_bb_to_half_c, - EmitLibdeviceMathCall("__nv_pow", {aa_p_bb, half_c}, - {component_type, component_type}, - component_type)); - auto neg_d = ir_builder_->CreateFNeg(d); - TF_ASSIGN_OR_RETURN( - auto arg_lhs, EmitLibdeviceMathCall("__nv_atan2", {b, a}, - {component_type, component_type}, - component_type)); - auto neg_d_arg_lhs = ir_builder_->CreateFMul(neg_d, arg_lhs); - TF_ASSIGN_OR_RETURN( - auto e_to_neg_d_arg_lhs, - EmitLibdeviceMathCall("__nv_exp", {neg_d_arg_lhs}, {component_type}, - component_type)); - auto coeff = - ir_builder_->CreateFMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs); - TF_ASSIGN_OR_RETURN( - auto ln_aa_p_bb, - EmitLibdeviceMathCall("__nv_log", {aa_p_bb}, {component_type}, - component_type)); - auto half_d = ir_builder_->CreateFMul(one_half, d); - auto q = - ir_builder_->CreateFAdd(ir_builder_->CreateFMul(c, arg_lhs), - ir_builder_->CreateFMul(half_d, ln_aa_p_bb)); - TF_ASSIGN_OR_RETURN( - auto cos_q, EmitLibdeviceMathCall("__nv_cos", {q}, {component_type}, - component_type)); - TF_ASSIGN_OR_RETURN( - auto sin_q, EmitLibdeviceMathCall("__nv_sin", {q}, {component_type}, - component_type)); - return EmitComposeComplex(op, ir_builder_->CreateFMul(coeff, cos_q), - ir_builder_->CreateFMul(coeff, sin_q)); - } - default: - return ElementalIrEmitter::EmitComplexBinaryOp(op, lhs_value, rhs_value); - } -} - -StatusOr GpuElementalIrEmitter::EmitComplexUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const { - PrimitiveType input_type = op->operand(0)->shape().element_type(); - PrimitiveType component_type = - primitive_util::IsComplexType(input_type) - ? 
primitive_util::ComplexComponentType(input_type) - : input_type; - - switch (op->opcode()) { - case HloOpcode::kLog: { - // log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a) - auto a = EmitExtractReal(operand_value); - auto b = EmitExtractImag(operand_value); - llvm::Type* llvm_ty = a->getType(); - auto sum_sq = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a), - ir_builder_->CreateFMul(b, b)); - TF_ASSIGN_OR_RETURN( - auto log_sum_sq, - EmitLibdeviceMathCall("__nv_log", {sum_sq}, {component_type}, - component_type)); - TF_ASSIGN_OR_RETURN( - auto angle, EmitLibdeviceMathCall("__nv_atan2", {b, a}, - {component_type, component_type}, - component_type)); - auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5); - return EmitComposeComplex( - op, ir_builder_->CreateFMul(one_half, log_sum_sq), angle); - } - case HloOpcode::kExp: { - // e^(a+bi) = e^a*(cos(b)+sin(b)i) - auto b = EmitExtractImag(operand_value); - TF_ASSIGN_OR_RETURN( - auto exp_a, - EmitLibdeviceMathCall("__nv_exp", {EmitExtractReal(operand_value)}, - {component_type}, component_type)); - TF_ASSIGN_OR_RETURN( - auto cos_b, EmitLibdeviceMathCall("__nv_cos", {b}, {component_type}, - component_type)); - TF_ASSIGN_OR_RETURN( - auto sin_b, EmitLibdeviceMathCall("__nv_sin", {b}, {component_type}, - component_type)); - return EmitComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b), - ir_builder_->CreateFMul(exp_a, sin_b)); - } - case HloOpcode::kCos: { - // cos(a+bi) = .5(cos(a)*(e^-b+e^b) + i*sin(a)*(e^-b-e^b)) - auto a = EmitExtractReal(operand_value); - auto llvm_ty = a->getType(); - TF_ASSIGN_OR_RETURN( - auto exp_b, - EmitLibdeviceMathCall("__nv_exp", {EmitExtractImag(operand_value)}, - {component_type}, component_type)); - TF_ASSIGN_OR_RETURN( - auto cos_a, EmitLibdeviceMathCall("__nv_cos", {a}, {component_type}, - component_type)); - TF_ASSIGN_OR_RETURN( - auto sin_a, EmitLibdeviceMathCall("__nv_sin", {a}, {component_type}, - component_type)); - auto half_exp_b = - 
ir_builder_->CreateFMul(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b); - auto half_exp_neg_b = - ir_builder_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b); - return EmitComposeComplex( - op, - ir_builder_->CreateFMul( - cos_a, ir_builder_->CreateFAdd(half_exp_neg_b, half_exp_b)), - ir_builder_->CreateFMul( - sin_a, ir_builder_->CreateFSub(half_exp_neg_b, half_exp_b))); - } - - case HloOpcode::kSin: { - // sin(a+bi) = 0.5(sin(a)*(e^b+e^-b) + i*cos(a)*(e^b-e^-b) - auto a = EmitExtractReal(operand_value); - auto llvm_ty = a->getType(); - TF_ASSIGN_OR_RETURN( - auto exp_b, - EmitLibdeviceMathCall("__nv_exp", {EmitExtractImag(operand_value)}, - {component_type}, component_type)); - TF_ASSIGN_OR_RETURN( - auto cos_a, EmitLibdeviceMathCall("__nv_cos", {a}, {component_type}, - component_type)); - TF_ASSIGN_OR_RETURN( - auto sin_a, EmitLibdeviceMathCall("__nv_sin", {a}, {component_type}, - component_type)); - auto half_exp_b = - ir_builder_->CreateFMul(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b); - auto half_exp_neg_b = - ir_builder_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b); - return EmitComposeComplex( - op, - ir_builder_->CreateFMul( - sin_a, ir_builder_->CreateFAdd(half_exp_b, half_exp_neg_b)), - ir_builder_->CreateFMul( - cos_a, ir_builder_->CreateFSub(half_exp_b, half_exp_neg_b))); - } - case HloOpcode::kTanh: { - /* - tanh=(exp(x)-exp(-x)) / (exp(x)+exp(-x)) - e^(a+bi) = e^a*(cos(b)+sin(b)i) - so tanh=(((cos(b)+sin(b)i)e^a - (cos(-b)+sin(-b)i)e^-a)) / - (((cos(b)+sin(b)i)e^a + (cos(-b)+sin(-b)i)e^-a)) - cos(b)=cos(-b), sin(-b)=-sin(b) - so tanh=(((cos(b)+sin(b)i)e^a - (cos(b)-sin(b)i)e^-a)) / - (((cos(b)+sin(b)i)e^a + (cos(b)-sin(b)i)e^-a)) - =(cos(b)e^a+i*sin(b)e^a + cos(b)(-e^-a)+i*sin(b)e^-a) / - (cos(b)e^a+i*sin(b)e^a + cos(b)e^-a+i*sin(b)(-e^-a)) - =(cos(b)(e^a-e^-a) + i*sin(b)(e^a+e^-a)) / - (cos(b)(e^a+e^-a) + i*sin(b)(e^a-e^-a)) - This is a complex division, so we can multiply by denom_conj/denom_conj - =(cos(b)(e^a-e^-a) + 
i*sin(b)(e^a+e^-a)) * - (cos(b)(e^a+e^-a) - i*sin(b)(e^a-e^-a)) / - ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2) - =(cos(b)^2(e^(2a)-e^(-2a)) + sin(b)^2(e^(2a)-e^(-2a)) + - i*(cos(b)sin(b)(e^a+e^-a)^2 - cos(b)sin(b)(e^a-e^-a)^2)) / - ((cos(b)(e^a+e^-a))^2 + (sin(b)(e^a-e^-a))^2) - */ - auto a = EmitExtractReal(operand_value); - auto b = EmitExtractImag(operand_value); - TF_ASSIGN_OR_RETURN( - auto exp_a, EmitLibdeviceMathCall("__nv_exp", {a}, {component_type}, - component_type)); - TF_ASSIGN_OR_RETURN( - auto cos_b, EmitLibdeviceMathCall("__nv_cos", {b}, {component_type}, - component_type)); - TF_ASSIGN_OR_RETURN( - auto sin_b, EmitLibdeviceMathCall("__nv_sin", {b}, {component_type}, - component_type)); - auto exp_neg_a = ir_builder_->CreateFDiv( - llvm::ConstantFP::get(exp_a->getType(), 1), exp_a); - auto exp_2a_minus_exp_neg_2a = ir_builder_->CreateFSub( - ir_builder_->CreateFMul(exp_a, exp_a), - ir_builder_->CreateFMul(exp_neg_a, exp_neg_a)); - auto cos_b_sq = ir_builder_->CreateFMul(cos_b, cos_b); - auto sin_b_sq = ir_builder_->CreateFMul(sin_b, sin_b); - auto real_num = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(cos_b_sq, exp_2a_minus_exp_neg_2a), - ir_builder_->CreateFMul(sin_b_sq, exp_2a_minus_exp_neg_2a)); - auto cos_b_sin_b = ir_builder_->CreateFMul(cos_b, sin_b); - auto exp_a_plus_exp_neg_a = ir_builder_->CreateFAdd(exp_a, exp_neg_a); - auto exp_a_plus_exp_neg_a_sq = - ir_builder_->CreateFMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a); - auto exp_a_minus_exp_neg_a = ir_builder_->CreateFSub(exp_a, exp_neg_a); - auto exp_a_minus_exp_neg_a_sq = - ir_builder_->CreateFMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a); - auto imag_num = ir_builder_->CreateFMul( - cos_b_sin_b, ir_builder_->CreateFSub(exp_a_plus_exp_neg_a_sq, - exp_a_minus_exp_neg_a_sq)); - auto denom = ir_builder_->CreateFAdd( - ir_builder_->CreateFMul(cos_b_sq, exp_a_plus_exp_neg_a_sq), - ir_builder_->CreateFMul(sin_b_sq, exp_a_minus_exp_neg_a_sq)); - return EmitComposeComplex(op, 
ir_builder_->CreateFDiv(real_num, denom), - ir_builder_->CreateFDiv(imag_num, denom)); - } - default: - return ElementalIrEmitter::EmitComplexUnaryOp(op, operand_value); - } -} - llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall( const string& callee_name, tensorflow::gtl::ArraySlice operands, diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h index 6a537d015209bc507af36b13eeb5d69ce58d8fea..77d4569b1e8e398005e8f517ff086a77aedd382d 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h @@ -54,20 +54,31 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { StatusOr EmitFloatUnaryOp( const HloInstruction* op, llvm::Value* operand_value) const override; - StatusOr EmitComplexUnaryOp( - const HloInstruction* op, llvm::Value* operand_value) const override; - StatusOr EmitFloatBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) const override; - StatusOr EmitComplexBinaryOp( - const HloInstruction* op, llvm::Value* lhs_value, - llvm::Value* rhs_value) const override; - StatusOr EmitErfcInv(PrimitiveType prim_type, llvm::Value* value) const override; + StatusOr EmitLog(PrimitiveType prim_type, + llvm::Value* value) const override; + + StatusOr EmitSin(PrimitiveType prim_type, + llvm::Value* value) const override; + + StatusOr EmitCos(PrimitiveType prim_type, + llvm::Value* value) const override; + + StatusOr EmitExp(PrimitiveType prim_type, + llvm::Value* value) const override; + + StatusOr EmitPow(PrimitiveType prim_type, llvm::Value* lhs, + llvm::Value* rhs) const override; + + StatusOr EmitAtan2(PrimitiveType prim_type, llvm::Value* lhs, + llvm::Value* rhs) const override; + llvm::Value* EmitThreadId() const override; private: diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.cc b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc new file mode 100644 
index 0000000000000000000000000000000000000000..66931bdc8b1030b2b2e7731ce6327c1e908d4ee6 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.cc @@ -0,0 +1,234 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/fft_thunk.h" + +#include + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace se = ::perftools::gputools; + +namespace xla { +namespace gpu { + +FftScratchAllocator::FftScratchAllocator( + int device_ordinal, DeviceMemoryAllocator* memory_allocator) + : device_ordinal_(device_ordinal), memory_allocator_(memory_allocator) {} + +FftScratchAllocator::~FftScratchAllocator() { + for (auto& allocated_buffer : allocated_buffers_) { + if (!memory_allocator_->Deallocate(device_ordinal_, &allocated_buffer) + .ok()) { + // The program can still continue with failed deallocation. + LOG(ERROR) << "Failed to deallocate the allocated buffer: " + << allocated_buffer.opaque(); + } + } +} + +int64 FftScratchAllocator::GetMemoryLimitInBytes(se::Stream* stream) { + constexpr int64 kFftScratchSize = 1LL << 32; // 4GB by default. 
+ return kFftScratchSize; +} + +se::port::StatusOr> FftScratchAllocator::AllocateBytes( + se::Stream* stream, int64 byte_size) { + CHECK_GE(byte_size, 0) << "byte_size must be positive."; + if (byte_size > GetMemoryLimitInBytes(stream)) { + return se::port::Status( + se::port::error::RESOURCE_EXHAUSTED, + tensorflow::strings::Printf( + "Allocating %lld bytes exceeds the memory limit of %lld bytes.", + byte_size, GetMemoryLimitInBytes(stream))); + } + + auto status_or_memory = + memory_allocator_->Allocate(device_ordinal_, byte_size, + /*retry_on_failure=*/false); + if (!status_or_memory.ok()) { + return tensorflow::errors::ResourceExhausted( + "Failed to allocate %lld bytes on device %d.", byte_size, + device_ordinal_); + } + se::DeviceMemoryBase allocated_buffer = status_or_memory.ValueOrDie(); + allocated_buffers_.push_back(allocated_buffer); + total_allocated_bytes_ += byte_size; + return se::DeviceMemory(allocated_buffer); +} + +namespace { + +se::fft::Type FftTypeToSeType(FftType type) { + switch (type) { + case FftType::FFT: + return se::fft::Type::kC2CForward; + case FftType::IFFT: + return se::fft::Type::kC2CInverse; + case FftType::IRFFT: + return se::fft::Type::kC2R; + case FftType::RFFT: + return se::fft::Type::kR2C; + default: + LOG(FATAL) << "unsupported fft type"; + } +} + +string FftTypeToString(se::fft::Type type) { + switch (type) { + case se::fft::Type::kC2CForward: + return "FFT"; + case se::fft::Type::kC2CInverse: + return "IFFT"; + case se::fft::Type::kC2R: + return "IRFFT"; + case se::fft::Type::kR2C: + return "RFFT"; + default: + LOG(FATAL) << "unknown fft type"; + } +} + +} // namespace + +FftThunk::FftThunk(FftType fft_type, + tensorflow::gtl::ArraySlice fft_length, + const BufferAllocation::Slice& input_buffer, + const BufferAllocation::Slice& output_buffer, + const Shape& input_shape, const Shape& output_shape, + const HloInstruction* hlo) + : Thunk(Kind::kFft, hlo), + fft_type_(FftTypeToSeType(fft_type)), + 
fft_length_(fft_length.begin(), fft_length.end()), + scale_factor_(1.0f), + input_buffer_(input_buffer), + output_buffer_(output_buffer), + input_shape_(input_shape), + output_shape_(output_shape) {} + +tensorflow::Status FftThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, se::Stream* stream) { + VLOG(3) << "FFT type: " << FftTypeToString(fft_type_); + VLOG(3) << "Input shape: " << ShapeUtil::HumanStringWithLayout(input_shape_); + VLOG(3) << "Output shape: " + << ShapeUtil::HumanStringWithLayout(output_shape_); + + FftScratchAllocator scratch_allocator(buffer_allocations.device_ordinal(), + buffer_allocations.memory_allocator()); + + if (fft_plan_ == nullptr) { + const int64 fft_rank = fft_length_.size(); + CHECK_LE(fft_rank, 3); + int batch_size = 1; + for (int i = 0; i < input_shape_.dimensions_size() - fft_rank; ++i) { + batch_size *= input_shape_.dimensions(i); + } + uint64 fft_length[3]; + uint64 input_embed[3]; + const uint64 input_stride = 1; + uint64 input_distance = 1; + uint64 output_embed[3]; + const uint64 output_stride = 1; + uint64 output_distance = 1; + + for (int i = 0; i < fft_rank; ++i) { + auto dim_offset = input_shape_.dimensions_size() - fft_rank + i; + fft_length[i] = static_cast(fft_length_[i]); + input_embed[i] = input_shape_.dimensions(dim_offset); + input_distance *= input_shape_.dimensions(dim_offset); + output_embed[i] = output_shape_.dimensions(dim_offset); + output_distance *= output_shape_.dimensions(dim_offset); + } + + constexpr bool kInPlaceFft = false; + fft_plan_ = + stream->parent()->AsFft()->CreateBatchedPlanWithScratchAllocator( + stream, fft_rank, fft_length, input_embed, input_stride, + input_distance, output_embed, output_stride, output_distance, + fft_type_, kInPlaceFft, batch_size, &scratch_allocator); + scale_factor_ = 1.0f / output_distance; + } else { + stream->parent()->AsFft()->UpdatePlanWithScratchAllocator( + stream, fft_plan_.get(), &scratch_allocator); + } + + bool launch_ok; + switch 
(fft_type_) { + case se::fft::Type::kC2CForward: { + se::DeviceMemory input_data( + buffer_allocations.GetDeviceAddress(input_buffer_)); + se::DeviceMemory output_data( + buffer_allocations.GetDeviceAddress(output_buffer_)); + launch_ok = + stream->ThenFft(fft_plan_.get(), input_data, &output_data).ok(); + break; + } + case se::fft::Type::kC2CInverse: { + se::DeviceMemory input_data( + buffer_allocations.GetDeviceAddress(input_buffer_)); + se::DeviceMemory output_data( + buffer_allocations.GetDeviceAddress(output_buffer_)); + launch_ok = + stream->ThenFft(fft_plan_.get(), input_data, &output_data).ok(); + if (launch_ok) { + launch_ok = + stream + ->ThenBlasScal(ShapeUtil::ElementsIn(output_shape_), + complex64(scale_factor_), &output_data, 1) + .ok(); + } + break; + } + case se::fft::Type::kR2C: { + se::DeviceMemory input_data( + buffer_allocations.GetDeviceAddress(input_buffer_)); + se::DeviceMemory output_data( + buffer_allocations.GetDeviceAddress(output_buffer_)); + launch_ok = + stream->ThenFft(fft_plan_.get(), input_data, &output_data).ok(); + break; + } + case se::fft::Type::kC2R: { + se::DeviceMemory input_data( + buffer_allocations.GetDeviceAddress(input_buffer_)); + se::DeviceMemory output_data( + buffer_allocations.GetDeviceAddress(output_buffer_)); + launch_ok = + stream->ThenFft(fft_plan_.get(), input_data, &output_data).ok(); + if (launch_ok) { + launch_ok = stream + ->ThenBlasScal(ShapeUtil::ElementsIn(output_shape_), + scale_factor_, &output_data, 1) + .ok(); + } + break; + } + default: + LOG(FATAL) << "unsupported fft type"; + } + if (launch_ok) { + return tensorflow::Status::OK(); + } + return InternalError("Unable to launch fft for thunk %p with type %s", this, + FftTypeToString(fft_type_).c_str()); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/fft_thunk.h b/tensorflow/compiler/xla/service/gpu/fft_thunk.h new file mode 100644 index 
0000000000000000000000000000000000000000..52fb8c376d7acea0f15aaa865c23fa2382717338 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/fft_thunk.h @@ -0,0 +1,98 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FFT_THUNK_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FFT_THUNK_H_ + +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/thunk.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/optional.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { +namespace gpu { + +// A one-time scratch allocator for FFT. The scratch buffers allocated are +// released on destruction. +// +// Not thread-safe in that AllocateBytes, destructor are not locked. 
+class FftScratchAllocator : public perftools::gputools::ScratchAllocator { + public: + FftScratchAllocator(int device_ordinal, + DeviceMemoryAllocator* memory_allocator); + + ~FftScratchAllocator() override; + + int64 GetMemoryLimitInBytes(perftools::gputools::Stream* stream) override; + + int64 TotalAllocatedBytes() { return total_allocated_bytes_; } + + perftools::gputools::port::StatusOr> + AllocateBytes(perftools::gputools::Stream* stream, int64 byte_size) override; + + private: + const int device_ordinal_; + DeviceMemoryAllocator* memory_allocator_; + std::vector allocated_buffers_; + int64 total_allocated_bytes_ = 0; +}; + +// This class stores everything that StreamExecutor needs to launch an FFT. +// It is generated by IrEmitter. +// +// This is thread-compatible. +class FftThunk : public Thunk { + public: + // Constructs a thunk for launching an FFT on a stream. + // Semantics of null hlo_instruction argument are as in Thunk. + FftThunk(FftType fft_type, tensorflow::gtl::ArraySlice fft_length, + const BufferAllocation::Slice& input_buffer, + const BufferAllocation::Slice& output_buffer, + const Shape& input_shape, const Shape& output_shape, + const HloInstruction* hlo); + + FftThunk(const FftThunk&) = delete; // Cannot share fft_plan_ + FftThunk& operator=(const FftThunk&) = delete; // Cannot share fft_plan_ + + // Does the FFT for the thunk on "stream". 
+ tensorflow::Status ExecuteOnStream( + const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) override; + + private: + const perftools::gputools::fft::Type fft_type_; + const std::vector fft_length_; + + float scale_factor_; + + std::unique_ptr fft_plan_; + + const BufferAllocation::Slice input_buffer_; + const BufferAllocation::Slice output_buffer_; + + const Shape input_shape_; + const Shape output_shape_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FFT_THUNK_H_ diff --git a/tensorflow/compiler/xla/service/gpu/for_thunk.h b/tensorflow/compiler/xla/service/gpu/for_thunk.h index 525a2af941e77a27c0e01543e00e8a4c3e4b9f62..832494d17e9c4e1d9e92e18ef331df1cf3689024 100644 --- a/tensorflow/compiler/xla/service/gpu/for_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/for_thunk.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FOR_THUNK_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FOR_THUNK_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FOR_THUNK_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FOR_THUNK_H_ #include @@ -49,4 +49,4 @@ class ForThunk : public Thunk { } // namespace gpu } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FOR_THUNK_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FOR_THUNK_H_ diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.h b/tensorflow/compiler/xla/service/gpu/fusion_merger.h index bd720f8584f6254c43a3e2a1a5399aa919eebbc0..4c523a66de977cd32423b25f0d165c4f4ba51c4a 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.h +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSION_MERGER_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSION_MERGER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSION_MERGER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSION_MERGER_H_ #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" @@ -44,4 +44,4 @@ class FusionMerger : public HloPassInterface { } // namespace gpu } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSION_MERGER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_FUSION_MERGER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc index e784046450ed1cca088770c65c786e80adda869f..8e3aebbc12b5e6d746700956b9743bc94db50167 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc @@ -264,9 +264,9 @@ tensorflow::Status GemmThunk::ExecuteOnStream( auto make_descriptor = [this](se::DeviceMemoryBase data, const Shape& shape, bool transpose) -> MatrixDescriptor { - bool is_row_major = shape.layout().minor_to_major(0) != 0; - bool layout_mismatch = shape.layout().minor_to_major(0) != - output_shape_.layout().minor_to_major(0); + bool is_row_major = LayoutUtil::Minor(shape.layout(), 0) != 0; + bool layout_mismatch = LayoutUtil::Minor(shape.layout(), 0) != + LayoutUtil::Minor(output_shape_.layout(), 0); return MatrixDescriptor(data, transpose ^ layout_mismatch, shape.dimensions(is_row_major), shape.dimensions(!is_row_major)); @@ -320,7 +320,7 @@ tensorflow::Status GemmThunk::ExecuteOnStream( }; bool launch_ok; - if (output_shape_.layout().minor_to_major(0) == 0) { + if (LayoutUtil::Minor(output_shape_.layout(), 0) == 0) { launch_ok = launch( lhs_descriptor, rhs_descriptor, MatrixDescriptor(output_data, false, output_num_rows, 
output_num_cols), diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h index 983cb872924f22be0dfad8aa9ad86f233b909c46..8c6a1f51a8a09ef78950dfe7e89994a3fe247f49 100644 --- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.h @@ -52,6 +52,15 @@ class GemmThunk : public Thunk { const BufferAllocations& buffer_allocations, perftools::gputools::Stream* stream) override; + // Returns true if we'll perform autotuning if run on the given stream. If + // so, we want the GPU to be quiescent during autotuning, so as not to + // introduce noise in our results. + bool ShouldHaltAllActivityBeforeRunning( + perftools::gputools::Stream* stream) override { + return autotune_results_.count( + stream->parent()->GetDeviceDescription().name()) != 0; + } + private: const BufferAllocation::Slice lhs_buffer_; const BufferAllocation::Slice rhs_buffer_; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index fcd73fd37a2d9ae3c24b56970e3e992da5944682..28ebd034ee0c89137f4e6eb417d8a37f4a00af7a 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -18,30 +18,37 @@ limitations under the License. #include #include #include +#include // NOLINT(build/c++11): only using std::call_once, not mutex. 
#include #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/DiagnosticPrinter.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" +#include "llvm/IR/Verifier.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/algebraic_simplifier.h" -#include "tensorflow/compiler/xla/service/batchnorm_rewriter.h" +#include "tensorflow/compiler/xla/service/batchnorm_expander.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/buffer_liveness.h" #include "tensorflow/compiler/xla/service/call_inliner.h" +#include "tensorflow/compiler/xla/service/dot_decomposer.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" -#include "tensorflow/compiler/xla/service/gpu/convolution_folding.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" #include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h" #include "tensorflow/compiler/xla/service/gpu/gpu_executable.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h" #include "tensorflow/compiler/xla/service/gpu/hlo_schedule.h" #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" -#include "tensorflow/compiler/xla/service/gpu/layout_assignment.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" #include 
"tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.h" #include "tensorflow/compiler/xla/service/gpu/pad_insertion.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" @@ -52,6 +59,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_constant_folding.h" #include "tensorflow/compiler/xla/service/hlo_cse.h" #include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h" @@ -64,6 +72,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/transpose_folding.h" #include "tensorflow/compiler/xla/service/tuple_simplifier.h" #include "tensorflow/compiler/xla/service/while_loop_simplifier.h" +#include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" @@ -74,9 +83,11 @@ limitations under the License. #include "tensorflow/core/platform/cuda_libdevice_path.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/subprocess.h" #include "tensorflow/core/platform/tracing.h" +#include "tensorflow/stream_executor/cuda/cuda_diagnostics.h" namespace se = ::perftools::gputools; @@ -90,14 +101,6 @@ namespace gpu { namespace { using tensorflow::port::Tracing; -using tensorflow::strings::StrCat; - -// Any address of a variable residing in global memory or returned by one of the -// memory allocation routines from the driver or runtime API is always aligned -// to at least 256 bytes. 
-// -// http://docs.nvidia.com/cuda/cuda-c-programming-guide/#device-memory-accesses -constexpr int64 kMemoryAlignment = 256; // Returns the directory containing nvvm libdevice files. config_cuda_data_dir // should be equal to config().debug_options().xla_gpu_cuda_data_dir() of the @@ -125,31 +128,46 @@ string GetLibdeviceDir(const string& config_cuda_data_dir) { } // Runs optimization passes on the given HLO module. -tensorflow::Status OptimizeHloModule( - HloModule* hlo_module, - const HloCostAnalysis::ShapeSizeFunction& shape_size_function) { +tensorflow::Status OptimizeHloModule(HloModule* hlo_module, + se::StreamExecutor* stream_exec, + DeviceMemoryAllocator* device_allocator) { { HloPassPipeline pipeline("optimization"); - pipeline.AddInvariantChecker(shape_size_function); + pipeline.AddInvariantChecker(); + pipeline.AddPass(); ReducePrecisionInsertion::AddPasses( &pipeline, hlo_module->config().debug_options(), ReducePrecisionInsertion::PassTiming::BEFORE_OPTIMIZATION); // TODO(b/64094172): make Call work on GPU instead of inlining. pipeline.AddPass(); + // Convert BF16 operations to F32 operations so that the GPU backend can + // support BF16 operations without directly implementing a BF16 lowering for + // most ops. + pipeline.AddPass(BF16, F32); + pipeline.AddPass(); { auto& pass = pipeline.AddPass>("simplification"); - pass.AddInvariantChecker(shape_size_function); + pass.AddInvariantChecker(); - // TODO(b/62764704): Do not rewrite on GPU, use cuDNN's BatchNorm APIs - // instead. - pass.AddPass( + // If cudnn batchnorms are enabled, rewrite batchnorm HLOs to cudnn calls + // where possible. Not every batchnorm op can be implemented as a call to + // cudnn, so decompose any remaining batchnorm ops into a soup of HLOs. 
+ if (hlo_module->config().debug_options().xla_gpu_use_cudnn_batchnorm()) { + pass.AddPass(); + } + pass.AddPass( /*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, /*rewrite_grad_op=*/true, /*use_fusion=*/false); + + // BatchNormExpander can create zero-sized ops, so zero-sized HLO + // elimination has to come after that pass. + pipeline.AddPass(); + pass.AddPass( /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }); @@ -159,7 +177,7 @@ tensorflow::Status OptimizeHloModule( pass.AddPass(); pass.AddPass(); } - pipeline.AddPass(); + pipeline.AddPass( [](const HloInstruction& dot, const TransposeFolding::OperandIndices& candidate_operands) { @@ -171,16 +189,68 @@ tensorflow::Status OptimizeHloModule( pipeline.AddPass(); TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); } + + { + // Convert convolutions into CustomCalls to cudnn, then canonicalize them + // (PadInsertion). + HloPassPipeline pipeline("conv_canonicalization"); + pipeline.AddInvariantChecker(); + pipeline.AddPass(); + pipeline.AddPass(); + + // Choose the fastest algorithm for each conv. + // + // In theory doing this here is way too early: It needs to happen after + // layout assignment, because the layout of the inputs/outputs affects the + // speed of the conv. But currently we only allow only one input/output + // layout when calling cudnn, so there's no ambiguity. + // + // We pick the algorithm at this early stage so we can generate better HLO. + // After CudnnConvolutionRewriter, our convolutions are CustomCalls which + // return a tuple (conv_result, scratch_memory), and the each conv uses 0 + // bytes of scratch: + // + // customcall = (f32[...], f32[0]) + // return gte(customcall, 0) + // + // The algorithm picker then chooses the best algorithm, and potentially + // increases the scratch space. 
It replaces customcall with new_tuple, + // giving us the following: + // + // new_customcall = (f32[...], f32[N]) + // new_tuple = tuple(gte(new_customcall, 0), constant f32[0]) + // return gte(new_tuple, 0) + // + // The new tuple and gte instructions then be simplified away, because + // nobody is expected to use the scratch value. + // + // However, if we were to run CudnnConvolutionAlgorithmPicker after layout + // assignment, fusion would already have run, and the gte(customcall, 0) + // would probably already be into a fusion node. We can't simplify across + // HloComputation boundaries, so in this case we wouldn't be able to + // simplify away the new_tuple bits. + // + // We'll need to revisit this if we ever allow multiple layouts for the + // inputs/outputs of a cudnn convolution. + pipeline.AddPass(stream_exec, + device_allocator); + // Clean up new_tuple described above. + pipeline.AddPass(); + pipeline.AddPass(); + + TF_RETURN_IF_ERROR(pipeline.Run(hlo_module).status()); + } + { HloPassFix fusion("fusion"); - fusion.AddInvariantChecker(shape_size_function); + fusion.AddInvariantChecker(); fusion.AddPass(/*may_duplicate=*/false); fusion.AddPass(/*may_duplicate=*/true); fusion.AddPass(); TF_RETURN_IF_ERROR(fusion.Run(hlo_module).status()); HloPassPipeline reduce_pipeline("reduce-precision"); - reduce_pipeline.AddInvariantChecker(shape_size_function); + reduce_pipeline.AddInvariantChecker(); ReducePrecisionInsertion::AddPasses( &reduce_pipeline, hlo_module->config().debug_options(), ReducePrecisionInsertion::PassTiming::AFTER_FUSION); @@ -198,19 +268,18 @@ tensorflow::Status OptimizeHloModule( // Modifies the given HLO module so that it will be accepted by IrEmitter. // Unlike optimization passes, the passes are necessary for correctness. 
-tensorflow::Status PrepareHloModuleForIrEmitting( - HloModule* hlo_module, - const HloCostAnalysis::ShapeSizeFunction& shape_size_function) { +tensorflow::Status PrepareHloModuleForIrEmitting(HloModule* hlo_module) { // In some cases, we have to place the result of an instruction in a temporary // buffer. For instance, the buffer that holds an external parameter is // assumed immutable at this point, and should not be reused for output // (b/27180329). Therefore, in that case, we set the output to be a copy of // the parameter. HloPassPipeline pipeline("GPU-ir-emit-prepare"); - pipeline.AddInvariantChecker(shape_size_function); - pipeline.AddPass(); + pipeline.AddInvariantChecker(); + pipeline.AddPass( hlo_module->mutable_entry_computation_layout()); + // The LayoutAssignment pass may leave behind kCopy instructions which are // duplicate or NOPs, so remove them with algebraic simplification and CSE. pipeline.AddPass>( @@ -229,6 +298,103 @@ tensorflow::Status PrepareHloModuleForIrEmitting( return pipeline.Run(hlo_module).status(); } +// Prints a warning if the ptxas at ptxas_path has known bugs. +// +// Only prints a warning the first time it's called for a particular value of +// ptxas_path. +void WarnIfBadPtxasVersion(const string& ptxas_path) { + static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); + static std::unordered_set* seen_ptxas_paths GUARDED_BY(mu) = + new std::unordered_set(); + + tensorflow::mutex_lock lock(mu); + if (!seen_ptxas_paths->insert(ptxas_path).second) { + // Already checked this ptx binary, nothing to do. 
+ return; + } + + tensorflow::SubProcess ptxas; + ptxas.SetProgram(ptxas_path, {ptxas_path, "--version"}); + ptxas.SetChannelAction(tensorflow::CHAN_STDOUT, tensorflow::ACTION_PIPE); + if (!ptxas.Start()) { + LOG(WARNING) << "Couldn't invoke " << ptxas_path << " --version"; + return; + } + + string out; + int exit_code = ptxas.Communicate(/*stdin_input=*/nullptr, &out, + /*stderr_output=*/nullptr); + if (exit_code != 0) { + LOG(WARNING) << "Running " << ptxas_path << " --version returned " + << exit_code; + return; + } + + int64 vmaj, vmin, vdot; + string vmaj_str, vmin_str, vdot_str; + if (!RE2::PartialMatch(out, R"(\bV(\d+)\.(\d+)\.(\d+)\b)", &vmaj_str, + &vmin_str, &vdot_str) || + !tensorflow::strings::safe_strto64(vmaj_str, &vmaj) || + !tensorflow::strings::safe_strto64(vmin_str, &vmin) || + !tensorflow::strings::safe_strto64(vdot_str, &vdot)) { + LOG(WARNING) << "Couldn't parse ptxas version in output of " << ptxas_path + << " --version:\n" + << out; + return; + } + + // ptxas 9.0 before 9.0.276 and ptxas 9.1 before 9.1.121 miscompile some + // address calculations with large offsets (e.g. "load ptr + large_constant"), + // b/70245379. + if ((vmaj == 9 && vmin == 0 && vdot < 276) || + (vmaj == 9 && vmin == 1 && vdot < 121)) { + LOG(WARNING) << "*** WARNING *** You are using ptxas " << vmaj << "." + << vmin << "." << vdot + << ", which is in range [9.0.0, 9.0.276) + [9.1.0, 9.1.121). " + "These versions are known to miscompile XLA code, leading " + "to incorrect results or invalid-address errors."; + } +} + +// Prints a warning if the ptx->sass JIT in the driver has known bugs. +// +// Using such a driver only a problem if we fail to use ptxas to compile our ptx +// and have to use the driver instead, so you should only call this function if +// we're going to use the driver JIT. +// +// Only prints a warning the first time it's called. 
+void WarnIfBadDriverJITVersion() { + static std::once_flag run_once; + std::call_once(run_once, [] { + auto version_or_status = se::cuda::Diagnostician::FindKernelDriverVersion(); + if (!version_or_status.ok()) { + LOG(WARNING) << "Couldn't read CUDA driver version."; + return; + } + se::cuda::DriverVersion version = version_or_status.ValueOrDie(); + + // The following versions of the driver JIT miscompile some address + // calculations with large offsets (e.g. "load ptr + large_constant"), + // b/70245379: + // + // - 384.x before 384.108 + // - 387.x before 387.40 + // - 390.x before 390.10. + auto vmaj = std::get<0>(version); + auto vmin = std::get<1>(version); + if ((vmaj == 384 && vmin < 108) || // + (vmaj == 387 && vmin < 40) || // + (vmaj == 390 && vmin < 10)) { + LOG(WARNING) + << "*** WARNING *** Invoking the PTX->SASS JIT from driver version " + << se::cuda::DriverVersionToString(version) + << ", which is in range [384.0.0, 384.108.0) + [387.0.0, 387.40.0) + " + "[390.0.0, 390.10.0). These versions are known to miscompile XLA " + "code, leading to incorrect results or invalid-address errors."; + } + }); +} + // Compiles the given PTX string using ptxas and returns the resulting machine // code (i.e. a cubin) as a byte array. StatusOr> CompilePtx(const string& ptx, int cc_major, @@ -240,6 +406,8 @@ StatusOr> CompilePtx(const string& ptx, int cc_major, auto env = tensorflow::Env::Default(); TF_RETURN_IF_ERROR(env->FileExists(ptxas_path)); + WarnIfBadPtxasVersion(ptxas_path); + // Write ptx into a temporary file. 
string ptx_path; if (!env->LocalTempFilename(&ptx_path)) { @@ -263,8 +431,9 @@ StatusOr> CompilePtx(const string& ptx, int cc_major, tensorflow::Env::Default()->DeleteFile(cubin_path).IgnoreError(); }); tensorflow::SubProcess ptxas_info_dumper; - std::vector ptxas_args = {ptxas_path, ptx_path, "-o", cubin_path, - StrCat("-arch=sm_", cc_major, cc_minor)}; + std::vector ptxas_args = { + ptxas_path, ptx_path, "-o", cubin_path, + tensorflow::strings::StrCat("-arch=sm_", cc_major, cc_minor)}; if (VLOG_IS_ON(2)) { ptxas_args.push_back("-v"); } @@ -294,25 +463,28 @@ StatusOr> CompilePtx(const string& ptx, int cc_major, } // namespace GpuCompiler::GpuCompiler() - : pointer_size_(llvm::DataLayout(kDataLayout).getPointerSize()) {} + : pointer_size_(llvm::DataLayout(kDataLayout) + .getPointerSize(0 /* default address space */)) {} StatusOr> GpuCompiler::RunHloPasses( - std::unique_ptr module, se::StreamExecutor* /*stream_exec*/) { + std::unique_ptr module, se::StreamExecutor* stream_exec, + DeviceMemoryAllocator* device_allocator) { XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunHloPasses"); Tracing::TraceMe annotation("HLO Transforms", module->name(), /*is_expensive=*/true); - TF_RETURN_IF_ERROR(OptimizeHloModule(module.get(), ShapeSizeBytesFunction())); + TF_RETURN_IF_ERROR( + OptimizeHloModule(module.get(), stream_exec, device_allocator)); return std::move(module); } StatusOr> GpuCompiler::RunBackend( - std::unique_ptr module, se::StreamExecutor* stream_exec) { + std::unique_ptr module, se::StreamExecutor* stream_exec, + DeviceMemoryAllocator* device_allocator) { XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend"); TF_RET_CHECK(stream_exec != nullptr); - TF_RETURN_IF_ERROR( - PrepareHloModuleForIrEmitting(module.get(), ShapeSizeBytesFunction())); + TF_RETURN_IF_ERROR(PrepareHloModuleForIrEmitting(module.get())); llvm::LLVMContext llvm_context; std::string buffer; @@ -343,19 +515,21 @@ StatusOr> GpuCompiler::RunBackend( TF_ASSIGN_OR_RETURN( std::unique_ptr buffer_assignment, 
BufferAssigner::Run(module.get(), hlo_schedule->ConsumeHloOrdering(), - BufferSizeBytesFunction(), [](LogicalBuffer::Color) { - return kMemoryAlignment; + BufferSizeBytesFunction(), + /*color_alignment=*/[](LogicalBuffer::Color) { + return kCudaMallocAlignBytes; })); - // BufferAssignment::ToString() includes a header, so no need for us to - // print one ourselves. + // BufferAssignment::Stats::ToString() and BufferAssignment::ToString() + // include headers, so no need for us to print them ourselves. + XLA_VLOG_LINES(1, buffer_assignment->GetStats().ToString()); XLA_VLOG_LINES(2, buffer_assignment->ToString()); XLA_VLOG_LINES(2, module->ToString()); - const string xla_dump_hlo_proto_to = - module->config().debug_options().xla_dump_hlo_proto_to(); - if (!xla_dump_hlo_proto_to.empty()) { + const string xla_dump_optimized_hlo_proto_to = + module->config().debug_options().xla_dump_optimized_hlo_proto_to(); + if (!xla_dump_optimized_hlo_proto_to.empty()) { HloProto proto = MakeHloProto(*module, *buffer_assignment); TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( - proto, xla_dump_hlo_proto_to, module->name())); + proto, xla_dump_optimized_hlo_proto_to, module->name())); } IrEmitterContext ir_emitter_context(module.get(), buffer_assignment.get(), @@ -393,6 +567,20 @@ StatusOr> GpuCompiler::RunBackend( /*optimized=*/false)); } + { + XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - Running LLVM verifier"); + + std::string err; + llvm::raw_string_ostream err_stream(err); + + // verifyModule() returns true if the module is broken. + TF_RET_CHECK(!llvm::verifyModule(llvm_module, &err_stream)) + << "Invalid LLVM IR before optimizations:\n" + << err_stream.str() + << "\nThis probably indicates a bug in the HLO -> LLVM IR lowering. " + "Rerun with --xla_dump_ir_to to get the IR. "; + } + string libdevice_dir; { tensorflow::mutex_lock lock(mutex_); @@ -443,7 +631,7 @@ StatusOr> GpuCompiler::RunBackend( // Write PTX to IR dump directory, if IR dumping was requested. 
if (!ir_dump_directory.empty()) { const string ptx_outfile = tensorflow::io::JoinPath( - ir_dump_directory, StrCat(module->name(), ".ptx")); + ir_dump_directory, tensorflow::strings::StrCat(module->name(), ".ptx")); auto status = [&] { auto* env = tensorflow::Env::Default(); TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(ir_dump_directory)); @@ -466,13 +654,14 @@ StatusOr> GpuCompiler::RunBackend( XLA_VLOG_LINES(2, thunk_schedule->ToString()); std::unique_ptr profile_index_map; - std::unique_ptr profile_printer; + std::unique_ptr profile_printer; if (module->config().hlo_profiling_enabled()) { HloCostAnalysis cost_analysis(ShapeSizeBytesFunction()); + TF_RETURN_IF_ERROR(module->entry_computation()->Accept(&cost_analysis)); profile_index_map = MakeUnique(*module); profile_printer = - CreateHloProfilePrinter(*profile_index_map, cost_analysis); + CreateHloProfilePrinterData(*profile_index_map, cost_analysis); } auto* gpu_executable = new GpuExecutable( @@ -541,6 +730,10 @@ std::vector GpuCompiler::CompilePtxOrGetCachedResult(const string& ptx, "GPU driver compile the ptx. " << maybe_cubin.status(); } + + // We're going to use the driver to JIT our PTX->SASS, so warn if + // the JIT in the driver has known bugs. 
+ WarnIfBadDriverJITVersion(); } } cache_value->compilation_done = true; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h index 18e34340205b6f51497e26c45520799d21c55a46..c352d4d8462fadb266c55ad437de998e86a6528e 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h @@ -51,11 +51,13 @@ class GpuCompiler : public LLVMCompiler { StatusOr> RunHloPasses( std::unique_ptr module, - perftools::gputools::StreamExecutor* stream_exec) override; + perftools::gputools::StreamExecutor* stream_exec, + DeviceMemoryAllocator* device_allocator) override; StatusOr> RunBackend( std::unique_ptr module, - perftools::gputools::StreamExecutor* stream_exec) override; + perftools::gputools::StreamExecutor* stream_exec, + DeviceMemoryAllocator* device_allocator) override; StatusOr>> CompileAheadOfTime(std::vector> module, diff --git a/tensorflow/compiler/xla/service/gpu/gpu_constants.cc b/tensorflow/compiler/xla/service/gpu/gpu_constants.cc new file mode 100644 index 0000000000000000000000000000000000000000..aa360c7f73de2f0f9cf59c22b552b8e60ddb3a87 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_constants.cc @@ -0,0 +1,25 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" + +namespace xla { +namespace gpu { + +// http://docs.nvidia.com/cuda/cuda-c-programming-guide/#device-memory-accesses +const int64 kCudaMallocAlignBytes = 256; + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_constants.h b/tensorflow/compiler/xla/service/gpu/gpu_constants.h new file mode 100644 index 0000000000000000000000000000000000000000..eb1ca4c6c95a23d2a08f5f9c3cbc85e7d47d4f89 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_constants.h @@ -0,0 +1,31 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONSTANTS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONSTANTS_H_ + +#include "tensorflow/compiler/xla/types.h" + +namespace xla { +namespace gpu { + +// Minimum alignment of cudaMalloc. We require that buffers created by our +// DeviceMemoryAllocator, and all input/output buffers, have this alignment. 
+extern const int64 kCudaMallocAlignBytes; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_CONSTANTS_H_ diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc index 33d739b79d3664fec3586bbc924b7fa2e10d3256..916b556fd43a453a4da2c96217e74c367f8c7653 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc @@ -36,7 +36,7 @@ namespace gpu { StatusOr GpuCopyInsertion::FindOrInsertCopy( HloInstruction* hlo) { - HloInstruction*& copy = inserted_copies_[hlo]; + HloInstruction*& copy = hlo_to_copy_map_[hlo]; if (copy == nullptr) { TF_ASSIGN_OR_RETURN(copy, hlo->parent()->DeepCopyInstruction(hlo)); } @@ -55,45 +55,71 @@ StatusOr GpuCopyInsertion::Run(HloModule* module) { // in IR. for (HloInstruction* hlo : module->entry_computation()->MakeInstructionPostOrder()) { - if (ImplementedAsLibraryCall(*hlo)) { + // Inserts a copy of hlo->operand(n) if it's a constant. + auto copy_operand_if_constant = [&](int64 n) -> Status { + HloInstruction* operand = hlo->mutable_operand(n); + TF_RET_CHECK(ShapeUtil::IsArray(operand->shape())); + const auto& values = dataflow->GetValueSet(operand).values(); + if (std::any_of(values.begin(), values.end(), [](const HloValue* value) { + return value->defining_instruction()->opcode() == + HloOpcode::kConstant; + })) { + TF_ASSIGN_OR_RETURN(HloInstruction * copy, FindOrInsertCopy(operand)); + TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(n, copy)); + changed = true; + } + return Status::OK(); + }; + + if (IsCustomCallToDnnBatchNorm(*hlo)) { + // The epsilon and feature_index operands to a CUDNN batchnorm op don't + // need to be materialized in memory -- in fact, they must be constants. + // These are the last two operands of all three batchnorm ops. 
+ for (int64 i = 0; i < hlo->operand_count() - 2; ++i) { + TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); + } + } else if (IsCustomCallToDnnConvolution(*hlo)) { + // The last two arguments to a CUDNN convolution are two HLO constants for + // cudnn algorithm and tensor_ops_enabled flag, which shouldn't be copied. + for (int64 i = 0; i < hlo->operand_count() - 2; ++i) { + TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); + } + } else if (ImplementedAsLibraryCall(*hlo)) { + // For all other library calls, materialize all the operands into memory. for (int64 i = 0; i < hlo->operand_count(); ++i) { - HloInstruction* operand = hlo->mutable_operand(i); - TF_RET_CHECK(ShapeUtil::IsArray(operand->shape())); - const auto& values = dataflow->GetValueSet(operand).values(); - if (std::any_of(values.begin(), values.end(), - [](const HloValue* value) { - return value->defining_instruction()->opcode() == - HloOpcode::kConstant; - })) { - TF_ASSIGN_OR_RETURN(HloInstruction * copy, FindOrInsertCopy(operand)); - TF_RETURN_IF_ERROR(hlo->ReplaceOperandWith(i, copy)); - changed = true; - } + TF_RETURN_IF_ERROR(copy_operand_if_constant(i)); } } } - // Init values of a while node cannot be constants. Insert copies for any - // constants found at the operand of a while. - tensorflow::gtl::FlatSet copied_constants; + // Init values of while and conditional nodes cannot be constants. Insert + // copies for any constants found at the operands of these nodes. 
+ tensorflow::gtl::FlatSet inserted_copies; for (HloComputation* computation : module->computations()) { for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() != HloOpcode::kWhile) { + if (instruction->opcode() != HloOpcode::kWhile && + instruction->opcode() != HloOpcode::kConditional) { continue; } - for (auto& pair : - dataflow->GetInstructionValueSet(instruction->operand(0))) { - const HloValueSet& value_set = pair.second; - for (const HloValue* value : value_set.values()) { - if (value->defining_instruction()->opcode() == - HloOpcode::kConstant && - !ContainsKey(copied_constants, value->defining_instruction())) { - HloInstruction* constant = value->defining_instruction(); - TF_ASSIGN_OR_RETURN(HloInstruction * copy, - FindOrInsertCopy(constant)); - TF_RETURN_IF_ERROR(constant->ReplaceAllUsesWith(copy)); - copied_constants.insert(constant); - changed = true; + for (auto operand : instruction->operands()) { + // Skip the operands that have already been replaced with a copy in a + // previous iteration (which is possible when a constant is used as an + // operand in multiple places). 
+ if (ContainsKey(inserted_copies, operand)) { + continue; + } + for (auto& pair : dataflow->GetInstructionValueSet(operand)) { + const HloValueSet& value_set = pair.second; + for (const HloValue* value : value_set.values()) { + if (value->defining_instruction()->IsConstant() && + !ContainsKey(hlo_to_copy_map_, value->defining_instruction())) { + HloInstruction* constant = value->defining_instruction(); + TF_ASSIGN_OR_RETURN(HloInstruction * copy, + FindOrInsertCopy(constant)); + TF_RETURN_IF_ERROR(constant->ReplaceAllUsesWith(copy)); + inserted_copies.insert(copy); + changed = true; + } } } } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h index 4d77f337e6eb20f7d79acc0829fde26bbe443f25..0c6f9b511f3aac5f62182273b827adcd068cd633 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h @@ -32,13 +32,13 @@ class GpuCopyInsertion : public HloPassInterface { StatusOr Run(HloModule* module) override; protected: - // Returns a copy of `hlo`. Looks in inserted_copies_ first to avoid making + // Returns a copy of `hlo`. Looks in hlo_to_copy_map_ first to avoid making // duplicate copies. StatusOr FindOrInsertCopy(HloInstruction* hlo); // A map containing all copies inserted to materialize operands of library // calls. The key is the copied instruction and the value is the copy. 
- tensorflow::gtl::FlatMap inserted_copies_; + tensorflow::gtl::FlatMap hlo_to_copy_map_; }; } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index 0fd85e4fb057f144df93d53485570d67c66af0d4..623d6714de501000e38b7698620925f66425f157 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -66,10 +66,12 @@ class HloExecutionProfiler { // If profiling is enabled, sets the total cycle count on the profile from the // execution timer. - ~HloExecutionProfiler() { + void FinishExecution() { + CHECK(!finished_execution_) << "Call FinishExecution only once!"; + finished_execution_ = true; if (do_profile_) { stream_->ThenStopTimer(execution_timer_.get()); - stream_->BlockHostUntilDone(); + stream_->BlockHostUntilDone().IgnoreError(); profile_->set_total_cycles_executed( *computation_, execution_timer_->Nanoseconds() * clock_rate_ghz_); } @@ -87,7 +89,7 @@ class HloExecutionProfiler { void FinishOperation(const HloInstruction* hlo_instruction) { if (do_profile_) { stream_->ThenStopTimer(per_op_timer_.get()); - stream_->BlockHostUntilDone(); + stream_->BlockHostUntilDone().IgnoreError(); profile_->SetCyclesTakenBy( hlo_instruction, per_op_timer_->Nanoseconds() * clock_rate_ghz_); } @@ -101,6 +103,7 @@ class HloExecutionProfiler { const HloComputation* computation_; std::unique_ptr execution_timer_; std::unique_ptr per_op_timer_; + bool finished_execution_ = false; }; } // namespace @@ -113,9 +116,9 @@ GpuExecutable::GpuExecutable( std::unique_ptr thunk_schedule, std::unique_ptr hlo_module, std::unique_ptr assignment, - std::unique_ptr hlo_profile_printer, + std::unique_ptr hlo_profile_printer_data, std::unique_ptr hlo_profile_index_map) - : Executable(std::move(hlo_module), std::move(hlo_profile_printer), + : Executable(std::move(hlo_module), std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map)), 
ptx_(ptx), cubin_(cubin), @@ -143,9 +146,12 @@ Status GpuExecutable::ExecuteThunks( if (do_profile) { LOG(WARNING) << "PROFILING: profiling is enabled"; } + HloExecutionProfiler profiler(do_profile, hlo_execution_profile, main_stream, hlo_module_->entry_computation()); + uint64 start_micros = tensorflow::Env::Default()->NowMicros(); + // Stream 0 indicates `main_stream` and substreams start from stream 1. std::vector::SmartPtr> sub_streams; while (sub_streams.size() + 1 < thunk_schedule_->StreamCount()) { @@ -155,6 +161,9 @@ Status GpuExecutable::ExecuteThunks( run_options->BorrowStream(main_stream->parent()->device_ordinal())); } + // The next event enqueued on stream N must not run until the thunk at + // last_blocking_thunk_for_stream[N] completes. + std::map last_blocking_thunk_for_stream; std::map> thunk_to_finish_event; for (Thunk* thunk : thunk_schedule_->TotalOrder()) { TF_RETURN_IF_ERROR(thunk->Initialize(*this)); @@ -167,15 +176,41 @@ Status GpuExecutable::ExecuteThunks( stream->ThenWaitFor(FindOrDie(thunk_to_finish_event, dependency).get()); } + if (last_blocking_thunk_for_stream.count(stream_no)) { + stream->ThenWaitFor(FindOrDie(thunk_to_finish_event, + last_blocking_thunk_for_stream[stream_no]) + .get()); + last_blocking_thunk_for_stream.erase(stream_no); + } + + // If this thunk requests it, wait for all currently-executing thunks to + // finish. This is useful e.g. if the thunk is about to perform autotuning. 
+ if (thunk->ShouldHaltAllActivityBeforeRunning(stream)) { + TF_RETURN_IF_ERROR(main_stream->BlockHostUntilDone()); + last_blocking_thunk_for_stream.clear(); + } + profiler.StartOperation(); VLOG(2) << "Executing the thunk for " - << thunk->hlo_instruction()->ToString(); + << thunk->hlo_instruction()->ToString() << " on stream " + << stream_no; TF_RETURN_IF_ERROR(thunk->ExecuteOnStream(buffer_allocations, stream)); - if (thunk_schedule_->Depended(thunk)) { + if (thunk_schedule_->Depended(thunk) || thunk->ShouldBlockFutureThunks()) { auto finish_event = MakeUnique(main_stream->parent()); finish_event->Init(); stream->ThenRecordEvent(finish_event.get()); thunk_to_finish_event[thunk] = std::move(finish_event); + + if (thunk->ShouldBlockFutureThunks()) { + // Set last_blocking_thunk_for_stream on all streams other than this one + // so that all other streams will wait for this thunk to complete before + // executing any events that occur later in the total order. + for (int32 i = 0; i < sub_streams.size() + 1; ++i) { + if (i != stream_no) { + last_blocking_thunk_for_stream[i] = thunk; + } + } + } } profiler.FinishOperation(thunk->hlo_instruction()); } @@ -184,90 +219,32 @@ Status GpuExecutable::ExecuteThunks( // Make sure kernels are completed before deallocating temporary buffers. // TODO(b/30100571): we could potentially postpone deallocating the temp // buffers until a different computation is executed. 
- if (block_host_until_done && !main_stream->BlockHostUntilDone()) { - return InternalError("Failed to complete all kernels launched on stream %p", - main_stream); + if (block_host_until_done) { + Status block_status = main_stream->BlockHostUntilDone(); + if (!block_status.ok()) { + return InternalError( + "Failed to complete all kernels launched on stream %p: %s", + main_stream, block_status.error_message().c_str()); + } } - return Status::OK(); -} + profiler.FinishExecution(); + uint64 end_micros = tensorflow::Env::Default()->NowMicros(); -StatusOr GpuExecutable::ExecuteOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, - HloExecutionProfile* hlo_execution_profile) { - se::Stream* stream = run_options->stream(); - DeviceMemoryAllocator* memory_allocator = run_options->allocator(); + { + tensorflow::mutex_lock lock(mutex_); + const double nanoseconds = (end_micros - start_micros) * 1000.0; + execution_profile_.set_compute_time_ns(std::max(nanoseconds, 1.0)); - BufferAllocations::Builder buffer_allocations_builder; - for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size(); - ++i) { - const BufferAllocation& allocation = assignment_->GetAllocation(i); - if (allocation.is_entry_computation_parameter()) { - buffer_allocations_builder.RegisterBuffer( - i, arguments[allocation.parameter_number()]); + // If hlo profiling was disabled then the cycle count is left empty. 
+ if (do_profile) { + execution_profile_.set_compute_cycle_count( + hlo_execution_profile->total_cycles_executed( + *module().entry_computation())); } } - se::StreamExecutor* executor = stream->parent(); - TF_ASSIGN_OR_RETURN( - auto buffer_allocations, - buffer_allocations_builder.Build(*assignment_, executor->device_ordinal(), - memory_allocator)); - - bool block_host_until_done = - !memory_allocator->AllowsAsynchronousDeallocation(); - TF_RETURN_IF_ERROR(ExecuteThunks(run_options, *buffer_allocations, - block_host_until_done, - hlo_execution_profile)); - HloInstruction* root = hlo_module_->entry_computation()->root_instruction(); - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice output_slice, - assignment_->GetUniqueTopLevelOutputSlice()); - se::DeviceMemoryBase output_buffer_address = - buffer_allocations->GetDeviceAddress(output_slice.index()); - - if (ShapeUtil::IsTuple(root->shape())) { - std::set referred_by_output; - if (GetRootPointsToSet().IsAmbiguous()) { - // The points-to set of the root is ambiguous so we need to examine the - // result data to determine which buffers are contained in the result. - TF_ASSIGN_OR_RETURN( - TransferManager * transfer_manager, - TransferManager::GetForPlatform(executor->platform())); - TF_ASSIGN_OR_RETURN(referred_by_output, - transfer_manager->GatherBufferPointersFromTuple( - executor, output_buffer_address, root->shape())); - } else { - // The points-to set of the root is unambiguous so it's known statically - // which buffers are in the result. Gather these buffers using the root's - // points-to set. - TF_RETURN_IF_ERROR(GetRootPointsToSet().ForEachElementWithStatus( - [&referred_by_output, &buffer_allocations, this]( - const ShapeIndex& /*index*/, - const PointsToSet::BufferList& buffers) { - // The points to set is unambiguous so the set should be a - // singleton. That is, we know exactly which instruction produced - // the array at this element. 
- CHECK_EQ(1, buffers.size()); - HloInstruction* hlo = buffers[0]->instruction(); - TF_ASSIGN_OR_RETURN( - const BufferAllocation::Slice slice, - this->assignment_->GetUniqueSlice(hlo, buffers[0]->index())); - CHECK(!slice.allocation()->is_entry_computation_parameter()); - referred_by_output.insert( - buffer_allocations->GetDeviceAddress(slice.index())); - return Status::OK(); - })); - } - TF_RETURN_IF_ERROR( - buffer_allocations->TearDown(referred_by_output, *assignment_)); - } else { - // If the computation result is not a tuple, we can delete all temporary - // buffers that are not the output. - TF_RETURN_IF_ERROR( - buffer_allocations->TearDown({output_buffer_address}, *assignment_)); - } - return output_buffer_address; + return Status::OK(); } StatusOr> GpuExecutable::ExecuteOnStream( @@ -285,9 +262,16 @@ StatusOr> GpuExecutable::ExecuteOnStream( ++i) { const BufferAllocation& allocation = assignment_->GetAllocation(i); if (allocation.is_entry_computation_parameter()) { - auto param_no = allocation.parameter_number(); - buffer_allocations_builder.RegisterBuffer( - i, arguments[param_no]->buffer(/*index=*/{})); + // The caller must give us a buffer for ShapeIndex {} of every parameter. + // It can optionally give us a buffer for other ShapeIndices, but we + // ignore them: Because we can't rely on these sub-buffers' addresses + // being available, our generated code can't use them. Instead, it must + // chase pointers starting at the tuple root. 
+ if (allocation.param_shape_index().empty()) { + auto param_no = allocation.parameter_number(); + buffer_allocations_builder.RegisterBuffer( + i, arguments[param_no]->root_buffer()); + } } } se::StreamExecutor* executor = run_options->stream()->parent(); @@ -305,50 +289,46 @@ StatusOr> GpuExecutable::ExecuteOnStream( HloInstruction* root = hlo_module_->entry_computation()->root_instruction(); auto device_ordinal = executor->device_ordinal(); auto shaped_buffer = MakeUnique( - root->shape(), executor->platform(), device_ordinal); + root->shape(), root->shape(), executor->platform(), device_ordinal); // Copy DeviceMemoryBase values which contain the array(s) of the result into // the respective location in ShapedBuffer. std::set buffers_in_result; - TF_RETURN_IF_ERROR( - shaped_buffer->mutable_shape_index_to_buffer_entry() - ->ForEachMutableElementWithStatus( - [&buffer_allocations, &buffers_in_result, &shaped_buffer, this]( - const ShapeIndex& index, size_t* buffer_entry) { - const auto& sources = this->GetRootPointsToSet().element(index); - // The points-to set is unambiguous so the set should be a - // singleton. That is, we know exactly which instruction - // produced the array at this element. - CHECK_EQ(1, sources.size()); - auto src_hlo = sources[0]->instruction(); - - VLOG(4) << "Looking at: " << sources[0]; - - // The source instruction should have a non-parameter buffer - // assigned. 
- TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, - this->assignment_->GetUniqueSlice( - src_hlo, sources[0]->index())); - CHECK(!slice.allocation()->is_entry_computation_parameter()); - - perftools::gputools::DeviceMemoryBase src_base = - buffer_allocations->GetDeviceAddress(slice.index()); - CHECK(!src_base.is_null() || src_base.size() == 0); - shaped_buffer->mutable_buffers()->push_back(src_base); - *buffer_entry = shaped_buffer->mutable_buffers()->size() - 1; - - buffers_in_result.insert(src_base); - return Status::OK(); - })); + TF_RETURN_IF_ERROR(shaped_buffer->buffers().ForEachMutableElementWithStatus( + [&buffer_allocations, &buffers_in_result, &shaped_buffer, this]( + const ShapeIndex& index, se::DeviceMemoryBase* device_memory) { + const auto& sources = this->GetRootPointsToSet().element(index); + // The points-to set is unambiguous so the set should be a + // singleton. That is, we know exactly which instruction + // produced the array at this element. + CHECK_EQ(1, sources.size()); + auto src_hlo = sources[0]->instruction(); + + VLOG(4) << "Looking at: " << sources[0]; + + // The source instruction should have a non-parameter buffer + // assigned. 
+ TF_ASSIGN_OR_RETURN( + const BufferAllocation::Slice slice, + this->assignment_->GetUniqueSlice(src_hlo, sources[0]->index())); + CHECK(!slice.allocation()->is_entry_computation_parameter()); + + perftools::gputools::DeviceMemoryBase src_base = + buffer_allocations->GetDeviceAddress(slice.index()); + CHECK(!src_base.is_null() || src_base.size() == 0); + *device_memory = src_base; + buffers_in_result.insert(src_base); + return Status::OK(); + })); TF_RETURN_IF_ERROR( buffer_allocations->TearDown(buffers_in_result, *assignment_)); return std::move(shaped_buffer); } -StatusOr GpuExecutable::ExecuteAsyncOnStream( +StatusOr> GpuExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) { + tensorflow::gtl::ArraySlice arguments) { // TODO(b/30671675): Implement asynchronous execution mode. return Unimplemented( "Asynchronous execution on stream is not yet supported on GPU."); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index e7307e07c0b5608e31f15597d31d11c50f81c6d5..b19cfd43debd0a5490495d176fa2f1fcd625da07 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -54,7 +54,7 @@ class GpuExecutable : public Executable { std::unique_ptr thunk_schedule, std::unique_ptr hlo_module, std::unique_ptr assignment, - std::unique_ptr hlo_profile_printer, + std::unique_ptr hlo_profile_printer_data, std::unique_ptr hlo_profile_index_map); // This should be called after set_ir_module_string. @@ -72,24 +72,16 @@ class GpuExecutable : public Executable { // empty, in which case compilation is left up to the GPU driver. const std::vector& cubin() const { return cubin_; } - // Both overloads of ExecuteOnStream will fail if the compute capability of - // the stream doesn't match the compute capability passed to this object's - // constructor. 
- StatusOr ExecuteOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments, - HloExecutionProfile* hlo_execution_profile) override; - + // ExecuteOnStream will fail if the compute capability of the stream doesn't + // match the compute capability passed to this object's constructor. StatusOr> ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) override; - StatusOr ExecuteAsyncOnStream( + StatusOr> ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments) override; + tensorflow::gtl::ArraySlice arguments) override; const Status EqualOrFail(const Executable& executable) { // TODO(b/62952745) Implement equality test on GPU executable. diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc new file mode 100644 index 0000000000000000000000000000000000000000..4944c41f7d8dc7a78a3cd094aee4d7087c74857e --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.cc @@ -0,0 +1,48 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h" + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { + +StatusOr GpuHloSupportChecker::Run(HloModule* module) { + for (auto* computation : module->computations()) { + for (const auto& instruction : computation->instructions()) { + TF_RETURN_IF_ERROR( + ShapeUtil::ValidateShapeWithOptionalLayout(instruction->shape())); + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + instruction->shape(), + [&instruction](const Shape& subshape, const ShapeIndex&) { + if (LayoutUtil::IsSparseArray(subshape)) { + return xla::Unimplemented( + "GPU backend does not support HLO instruction %s with shape " + "containing a sparse layout: %s", + instruction->ToString().c_str(), + ShapeUtil::HumanStringWithLayout(instruction->shape()) + .c_str()); + } + return Status::OK(); + })); + } + } + return false; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h new file mode 100644 index 0000000000000000000000000000000000000000..d63e213d2b1efab4bcff75541cc5ab33d7a07976 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h @@ -0,0 +1,42 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SUPPORT_CHECKER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SUPPORT_CHECKER_H_ + +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// This pass should run early in the HLO pipeline and checks for HLO constructs +// which are not supported by the GPU backend and cannot be removed via HLO +// transformations (eg, sparse layouts). +class GpuHloSupportChecker : public HloPassInterface { + public: + GpuHloSupportChecker() = default; + ~GpuHloSupportChecker() override = default; + + tensorflow::StringPiece name() const override { + return "gpu_hlo_support_checker"; + } + + // Note: always returns false (no instructions are ever modified by this + // pass). + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_HLO_SUPPORT_CHECKER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..0a4089df4c954cafcbe241189ee79a0995683513 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker_test.cc @@ -0,0 +1,72 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/gpu_hlo_support_checker.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/core/lib/core/error_codes.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +using ::testing::HasSubstr; + +class GpuHloSupportCheckerTest : public HloTestBase { + protected: + GpuHloSupportChecker& checker() { return checker_; } + + private: + GpuHloSupportChecker checker_; +}; + +TEST_F(GpuHloSupportCheckerTest, Add) { + HloComputation::Builder builder(TestName()); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape, "param1")); + builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape, HloOpcode::kAdd, param0, param1)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK(checker().Run(module.get()).status()); +} + +TEST_F(GpuHloSupportCheckerTest, SparseUnimplemented) { + HloComputation::Builder builder(TestName()); + const Shape sparse_shape = ShapeUtil::MakeShapeWithSparseLayout(F32, {10}, 2); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, sparse_shape, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, sparse_shape, "param1")); + builder.AddInstruction(HloInstruction::CreateBinary( + sparse_shape, HloOpcode::kAdd, param0, param1)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + Status 
status = checker().Run(module.get()).status(); + ASSERT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED); + EXPECT_THAT(status.error_message(), + HasSubstr("GPU backend does not support")); + EXPECT_THAT(status.error_message(), + HasSubstr(ShapeUtil::HumanStringWithLayout(sparse_shape))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc new file mode 100644 index 0000000000000000000000000000000000000000..89f1e625884568bf7370b3801d851ef4846c2a98 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc @@ -0,0 +1,249 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h" + +#include + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { +namespace gpu { + +// cuDNN convolutions are called with specific layouts on the input, output, +// and filter: +// +// input: DataLayout::kBatchDepthYX +// output: DataLayout::kBatchDepthYX +// filter: FilterLayout::kOutputInputYX +// +// The order of dimensions in the constant name is major-to-minor (eg, the +// most-major dimension of the input is batch, most-minor is X). The +// specific dimension numbers these named dimensions correspond to are +// determined by the ConvolutionDimensionNumbers argument. Y is spatial +// dimension 0, and X is spatial dimension 1. +// +// TODO(b/29399649): Be more flexible about handling layouts of cuDNN calls. 
+static Status AddBackendConstraintsToDnnConvCustomCall( + HloInstruction* instr, LayoutConstraints* constraints) { + CHECK(IsCustomCallToDnnConvolution(*instr)) << instr->ToString(); + Shape input_shape; + Shape filter_shape; + Shape output_shape; + const auto& target = instr->custom_call_target(); + if (target == kCudnnConvForwardCallTarget) { + input_shape = instr->operand(0)->shape(); + filter_shape = instr->operand(1)->shape(); + output_shape = instr->shape().tuple_shapes(0); + } else if (target == kCudnnConvBackwardInputCallTarget) { + input_shape = instr->shape().tuple_shapes(0); + filter_shape = instr->operand(1)->shape(); + output_shape = instr->operand(0)->shape(); + } else if (target == kCudnnConvBackwardFilterCallTarget) { + input_shape = instr->operand(0)->shape(); + filter_shape = instr->shape().tuple_shapes(0); + output_shape = instr->operand(1)->shape(); + } else { + LOG(FATAL) << "Unexpected custom call target: " + << instr->custom_call_target(); + } + + // Construct minor-to-major dimension orders for operands and result. + // cuDNN's convolution APIs support the BDYX layout for activations/output + // and the OIYX layout for weights. + // TODO(b/29399649): Be more flexible about handling layouts of cuDNN + // calls after we switch to cuDNN v5. 
+ const ConvolutionDimensionNumbers& dimension_numbers = + instr->convolution_dimension_numbers(); + std::vector input_layout; + for (int i = dimension_numbers.input_spatial_dimensions_size() - 1; i >= 0; + --i) { + input_layout.push_back(dimension_numbers.input_spatial_dimensions(i)); + } + input_layout.push_back(dimension_numbers.input_feature_dimension()); + input_layout.push_back(dimension_numbers.input_batch_dimension()); + *input_shape.mutable_layout() = LayoutUtil::MakeLayout(input_layout); + + std::vector filter_layout; + for (int i = dimension_numbers.kernel_spatial_dimensions_size() - 1; i >= 0; + --i) { + filter_layout.push_back(dimension_numbers.kernel_spatial_dimensions(i)); + } + filter_layout.push_back(dimension_numbers.kernel_input_feature_dimension()); + filter_layout.push_back(dimension_numbers.kernel_output_feature_dimension()); + *filter_shape.mutable_layout() = LayoutUtil::MakeLayout(filter_layout); + + std::vector output_layout; + for (int i = dimension_numbers.output_spatial_dimensions_size() - 1; i >= 0; + --i) { + output_layout.push_back(dimension_numbers.output_spatial_dimensions(i)); + } + output_layout.push_back(dimension_numbers.output_feature_dimension()); + output_layout.push_back(dimension_numbers.output_batch_dimension()); + *output_shape.mutable_layout() = LayoutUtil::MakeLayout(output_layout); + + // The custom call returns a tuple of (actual_result, scratch_buffer); + // call_result_buf is the logical buffer for actual_result, the thing that + // contains the result of the conv call. + TF_ASSIGN_OR_RETURN(const LogicalBuffer* call_result_buf, + constraints->points_to_analysis().GetBufferDefinedAt( + instr, /*index=*/{0})); + + // Set layouts of the instructions' shapes. 
+ if (target == kCudnnConvForwardCallTarget) { + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(input_shape, instr, 0)); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(filter_shape, instr, 1)); + TF_RETURN_IF_ERROR( + constraints->SetBufferLayout(output_shape.layout(), *call_result_buf)); + } else if (target == kCudnnConvBackwardInputCallTarget) { + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(output_shape, instr, 0)); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(filter_shape, instr, 1)); + TF_RETURN_IF_ERROR( + constraints->SetBufferLayout(input_shape.layout(), *call_result_buf)); + } else if (target == kCudnnConvBackwardFilterCallTarget) { + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(input_shape, instr, 0)); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(output_shape, instr, 1)); + TF_RETURN_IF_ERROR( + constraints->SetBufferLayout(filter_shape.layout(), *call_result_buf)); + } else { + LOG(FATAL) << "Unexpected custom call target: " + << instr->custom_call_target(); + } + return Status::OK(); +} + +Status GpuLayoutAssignment::AddBackendConstraints( + LayoutConstraints* constraints) { + for (auto* instruction : constraints->computation()->instructions()) { + if (IsCustomCallToDnnConvolution(*instruction)) { + TF_RETURN_IF_ERROR( + AddBackendConstraintsToDnnConvCustomCall(instruction, constraints)); + } + } + return Status::OK(); +} + +bool GpuLayoutAssignment::CustomCallRequiresMajorFirstLayout( + const HloInstruction* instruction) { + // - Inputs to cudnn batchnorm custom calls don't need the major-first layout + // (i.e. {n, n-1, ...0}) -- we can handle any layout. + // - Inputs to cudnn convolution require custom layouts handled in + // AddBackendConstraints. 
+ return !IsCustomCallToDnnBatchNorm(*instruction) && + !IsCustomCallToDnnConvolution(*instruction); +} + +Status GpuLayoutAssignment::PropagateOperandConstraint( + const OperandLayoutConstraint& layout_constraint, + LayoutConstraints* constraints) { + const HloInstruction* instruction = layout_constraint.instruction(); + + // cudnn batchnorm forward inference's result must have the same layout as its + // operand 0. + if (instruction->opcode() == HloOpcode::kCustomCall && + instruction->custom_call_target() == + kCudnnBatchNormForwardInferenceCallTarget && + layout_constraint.operand_no() == 0) { + TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( + layout_constraint.shape_layout().shape(), instruction)); + } + + // cudnn batchnorm forward training returns a tuple {output, mean, + // inverse-stddev}. mean and inverse-stddev are rank 1 and so have only one + // possible layout, but output is not (necessarily) rank 1, and, like in + // batchnorm forward inference, must have the same layout as operand 0. + if (instruction->opcode() == HloOpcode::kCustomCall && + instruction->custom_call_target() == + kCudnnBatchNormForwardTrainingCallTarget && + layout_constraint.operand_no() == 0) { + TF_ASSIGN_OR_RETURN(const LogicalBuffer* out_buf, + constraints->points_to_analysis().GetBufferDefinedAt( + instruction, /*index=*/{0})); + TF_RETURN_IF_ERROR(constraints->SetBufferLayout( + layout_constraint.shape_layout().layout(), *out_buf)); + } + + // Like forward training, cudnn batchnorm backward returns a tuple {output, + // mean, inverse-stddev}, and its operand 0 and 'output' must have the same + // layout. In addition, its operand 0 and operand 4 -- the 'operand' and + // 'grad_output' parameters -- must have the same layout. 
+ if (instruction->opcode() == HloOpcode::kCustomCall && + instruction->custom_call_target() == kCudnnBatchNormBackwardCallTarget && + (layout_constraint.operand_no() == 0 || + layout_constraint.operand_no() == 4)) { + TF_ASSIGN_OR_RETURN(const LogicalBuffer* out_buf, + constraints->points_to_analysis().GetBufferDefinedAt( + instruction, /*index=*/{0})); + TF_RETURN_IF_ERROR(constraints->SetBufferLayout( + layout_constraint.shape_layout().layout(), *out_buf)); + + int64 operand_to_set = layout_constraint.operand_no() == 0 ? 4 : 0; + TF_RETURN_IF_ERROR(constraints->SetOperandLayout( + layout_constraint.shape_layout().shape(), instruction, operand_to_set)); + } + + return LayoutAssignment::PropagateOperandConstraint(layout_constraint, + constraints); +} + +Status GpuLayoutAssignment::PropagateBufferConstraint( + const BufferLayoutConstraint& buffer_constraint, + LayoutConstraints* constraints) { + const LogicalBuffer& buf = buffer_constraint.buffer(); + const HloInstruction* instruction = buf.instruction(); + + Shape shape_with_layout = buf.shape(); + *shape_with_layout.mutable_layout() = buffer_constraint.layout(); + + // Propagate output constraints to the operands of cudnn batchnorm ops. This + // is the same as PropagateOperandConstraint, just in the other direction. We + // need both to fulfill our contract to LayoutAssignment. 
+ if (instruction->opcode() == HloOpcode::kCustomCall && + instruction->custom_call_target() == + kCudnnBatchNormForwardInferenceCallTarget) { + TF_RETURN_IF_ERROR(constraints->SetOperandLayout( + shape_with_layout, instruction, /*operand_no=*/0)); + } + + if (instruction->opcode() == HloOpcode::kCustomCall && + instruction->custom_call_target() == + kCudnnBatchNormForwardTrainingCallTarget && + buf.index() == ShapeIndex({0})) { + TF_RETURN_IF_ERROR(constraints->SetOperandLayout( + shape_with_layout, instruction, /*operand_no=*/0)); + } + if (instruction->opcode() == HloOpcode::kCustomCall && + instruction->custom_call_target() == kCudnnBatchNormBackwardCallTarget && + buf.index() == ShapeIndex({0})) { + // batchnorm backward has two operands, "operand" and "grad_output" whose + // layouts must both match that of the result at tuple-index 0. + TF_RETURN_IF_ERROR(constraints->SetOperandLayout( + shape_with_layout, instruction, /*operand_no=*/0)); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout( + shape_with_layout, instruction, /*operand_no=*/4)); + } + + return LayoutAssignment::PropagateBufferConstraint(buffer_constraint, + constraints); +} + +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/layout_assignment.h b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h similarity index 70% rename from tensorflow/compiler/xla/service/gpu/layout_assignment.h rename to tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h index 169041eb85c633cb4f1f679bcea127714828308f..86a3a7111fd79494e469beecf3234f6cec9adb9c 100644 --- a/tensorflow/compiler/xla/service/gpu/layout_assignment.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LAYOUT_ASSIGNMENT_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LAYOUT_ASSIGNMENT_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_LAYOUT_ASSIGNMENT_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_LAYOUT_ASSIGNMENT_H_ #include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" @@ -33,9 +33,17 @@ class GpuLayoutAssignment : public LayoutAssignment { protected: Status AddBackendConstraints(LayoutConstraints* constraints) override; + Status PropagateOperandConstraint( + const OperandLayoutConstraint& layout_constraint, + LayoutConstraints* constraints) override; + Status PropagateBufferConstraint( + const BufferLayoutConstraint& buffer_constraint, + LayoutConstraints* constraints) override; + bool CustomCallRequiresMajorFirstLayout( + const HloInstruction* instruction) override; }; } // namespace gpu } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_LAYOUT_ASSIGNMENT_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_GPU_LAYOUT_ASSIGNMENT_H_ diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4c45d2e94aebce5496da94841f6a1ae9015615c1 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc @@ -0,0 +1,328 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h" + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/service/computation_layout.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_layout.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace gpu { +namespace { + +using LayoutAssignmentTest = HloTestBase; + +TEST_F(LayoutAssignmentTest, Elementwise) { + Shape ashape = ShapeUtil::MakeShape(F32, {42, 12}); + Shape ashape_in_row_major(ashape); + Shape ashape_in_col_major(ashape); + *ashape_in_row_major.mutable_layout() = LayoutUtil::MakeLayout({1, 0}); + *ashape_in_col_major.mutable_layout() = LayoutUtil::MakeLayout({0, 1}); + + // Enumerate all possible combinations of layouts. 
+ for (const Shape& lhs_shape_with_layout : + {ashape_in_row_major, ashape_in_col_major}) { + for (const Shape& rhs_shape_with_layout : + {ashape_in_row_major, ashape_in_col_major}) { + for (const Shape& result_shape_with_layout : + {ashape_in_row_major, ashape_in_col_major}) { + // GpuLayoutAssignment should assign the same layout to "add" and its + // two operands. + auto builder = HloComputation::Builder(TestName()); + auto x = builder.AddInstruction( + HloInstruction::CreateParameter(0, ashape, "x")); + auto y = builder.AddInstruction( + HloInstruction::CreateParameter(1, ashape, "y")); + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, x, y)); + auto module = CreateNewModule(); + HloComputation* computation = + module->AddEntryComputation(builder.Build(add)); + + ComputationLayout computation_layout( + computation->ComputeProgramShape()); + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(lhs_shape_with_layout); + *computation_layout.mutable_parameter_layout(1) = + ShapeLayout(rhs_shape_with_layout); + *computation_layout.mutable_result_layout() = + ShapeLayout(result_shape_with_layout); + + GpuLayoutAssignment layout_assignment(&computation_layout); + EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); + + for (const HloInstruction* operand : add->operands()) { + EXPECT_TRUE(LayoutUtil::Equal(add->shape().layout(), + operand->shape().layout())); + } + } + } + } +} + +// Returns a list of shapes with all the possible layouts of this shape, including +// a shape with no layout. 
+std::vector AllLayoutsOf(const Shape& s) { + std::vector layout_vec(s.dimensions_size()); + std::iota(layout_vec.begin(), layout_vec.end(), 0); + + std::vector shapes; + shapes.push_back(s); + shapes.back().clear_layout(); + + do { + shapes.push_back(s); + *shapes.back().mutable_layout() = LayoutUtil::MakeLayout(layout_vec); + } while (std::next_permutation(layout_vec.begin(), layout_vec.end())); + + return shapes; +} + +TEST_F(LayoutAssignmentTest, BatchNormInference) { + const int64 kFeatureIndex = 1; + + // The shape of the data operand to BatchNormInference and of the output of + // the BatchNormInference call. + Shape shape = ShapeUtil::MakeShape(F32, {42, 12, 1, 100}); + + // The shape of the scale, offset, mean, and variance inputs to + // BatchNormInference. These are rank 1, with as many elements as are in the + // kFeatureIndex dim of shape. + Shape aux_shape = + ShapeUtil::MakeShape(F32, {shape.dimensions(kFeatureIndex)}); + + for (const Shape& input_shape : AllLayoutsOf(shape)) { + for (const Shape& result_shape : AllLayoutsOf(shape)) { + SCOPED_TRACE(tensorflow::strings::StrCat( + "input_shape=", ShapeUtil::HumanStringWithLayout(input_shape), + ", result_shape=", ShapeUtil::HumanStringWithLayout(result_shape))); + + auto builder = HloComputation::Builder(TestName()); + auto* operand = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "operand")); + auto* scale = builder.AddInstruction( + HloInstruction::CreateParameter(1, aux_shape, "scale")); + auto* offset = builder.AddInstruction( + HloInstruction::CreateParameter(2, aux_shape, "offset")); + auto* mean = builder.AddInstruction( + HloInstruction::CreateParameter(3, aux_shape, "mean")); + auto* variance = builder.AddInstruction( + HloInstruction::CreateParameter(4, aux_shape, "variance")); + + auto* epsilon = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1))); + auto* feature_index = + builder.AddInstruction(HloInstruction::CreateConstant( + 
Literal::CreateR0(kFeatureIndex))); + + auto* batchnorm = builder.AddInstruction(HloInstruction::CreateCustomCall( + shape, + {operand, scale, offset, mean, variance, epsilon, feature_index}, + kCudnnBatchNormForwardInferenceCallTarget)); + + auto module = CreateNewModule(); + HloComputation* computation = + module->AddEntryComputation(builder.Build(batchnorm)); + + ComputationLayout computation_layout(computation->ComputeProgramShape()); + + if (input_shape.has_layout()) { + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(input_shape); + } + + if (result_shape.has_layout()) { + *computation_layout.mutable_result_layout() = ShapeLayout(result_shape); + } + + GpuLayoutAssignment layout_assignment(&computation_layout); + EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); + + // The first operand to batchnorm should have the same layout as the + // result. + EXPECT_TRUE(LayoutUtil::Equal(batchnorm->operand(0)->shape().layout(), + batchnorm->shape().layout())) + << batchnorm->ToString(); + } + } +} + +TEST_F(LayoutAssignmentTest, BatchNormTraining) { + const int64 kFeatureIndex = 1; + + // The shape of the data operand to BatchNormTraining. + Shape shape = ShapeUtil::MakeShape(F32, {42, 12, 1, 100}); + + // The shape of the offset and scale inputs to BatchNormTraining. These are + // rank 1, with as many elements are in the kFeatureIndex dim of shape. + Shape offset_scale_shape = + ShapeUtil::MakeShape(F32, {shape.dimensions(kFeatureIndex)}); + + // Shape of the output of our BatchNormTraining op. + Shape batchnorm_shape = ShapeUtil::MakeTupleShape( + {shape, offset_scale_shape, offset_scale_shape}); + + // Enumerate all combinations of shapes. 
+ for (const Shape& input_shape : AllLayoutsOf(shape)) { + for (const Shape& result_shape : AllLayoutsOf(shape)) { + SCOPED_TRACE(tensorflow::strings::StrCat( + "input_shape=", ShapeUtil::HumanStringWithLayout(input_shape), + ", result_shape=", ShapeUtil::HumanStringWithLayout(result_shape))); + + auto builder = HloComputation::Builder(TestName()); + auto* operand = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "operand")); + auto* scale = builder.AddInstruction( + HloInstruction::CreateParameter(1, offset_scale_shape, "scale")); + auto* offset = builder.AddInstruction( + HloInstruction::CreateParameter(2, offset_scale_shape, "offset")); + + auto* epsilon = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1))); + auto* feature_index = + builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR0(kFeatureIndex))); + + auto* batchnorm = builder.AddInstruction(HloInstruction::CreateCustomCall( + batchnorm_shape, {operand, scale, offset, epsilon, feature_index}, + kCudnnBatchNormForwardTrainingCallTarget)); + + auto module = CreateNewModule(); + HloComputation* computation = + module->AddEntryComputation(builder.Build(batchnorm)); + + ComputationLayout computation_layout(computation->ComputeProgramShape()); + + if (input_shape.has_layout()) { + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(input_shape); + } + + if (result_shape.has_layout()) { + *computation_layout.mutable_result_layout() = + ShapeLayout(ShapeUtil::MakeTupleShape( + {result_shape, offset_scale_shape, offset_scale_shape})); + } + + GpuLayoutAssignment layout_assignment(&computation_layout); + EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); + + // The first operand to batchnorm should have the same layout as the + // first element of the result tuple. 
+ EXPECT_TRUE( + LayoutUtil::Equal(batchnorm->operand(0)->shape().layout(), + batchnorm->shape().tuple_shapes(0).layout())) + << batchnorm->ToString(); + } + } +} + +TEST_F(LayoutAssignmentTest, BatchNormGrad) { + const int64 kFeatureIndex = 1; + + // The shape of the data operand to BatchNormTraining. + Shape shape = ShapeUtil::MakeShape(F32, {42, 12, 1, 100}); + + // The shape of the scale, mean, and variance inputs to BatchNormGrad. These + // are rank 1, with as many elements are in the kFeatureIndex dim of shape. + Shape scale_shape = + ShapeUtil::MakeShape(F32, {shape.dimensions(kFeatureIndex)}); + + // Shape of the output of our BatchNormGrad op. + Shape batchnorm_shape = + ShapeUtil::MakeTupleShape({shape, scale_shape, scale_shape}); + + // Enumerate all combinations of shapes plus whether we're constraining param + // 0 or param 4. + for (const Shape& input_shape : AllLayoutsOf(shape)) { + for (const Shape& result_shape : AllLayoutsOf(shape)) { + for (int constrained_param_no : {0, 4}) { + SCOPED_TRACE(tensorflow::strings::StrCat( + "input_shape=", ShapeUtil::HumanStringWithLayout(input_shape), + ", result_shape=", ShapeUtil::HumanStringWithLayout(result_shape))); + + auto builder = HloComputation::Builder(TestName()); + auto* operand = builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "operand")); + auto* scale = builder.AddInstruction( + HloInstruction::CreateParameter(1, scale_shape, "scale")); + auto* mean = builder.AddInstruction( + HloInstruction::CreateParameter(2, scale_shape, "mean")); + auto* var = builder.AddInstruction( + HloInstruction::CreateParameter(3, scale_shape, "var")); + auto* grad_offset = builder.AddInstruction( + HloInstruction::CreateParameter(4, shape, "var")); + + auto* epsilon = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1))); + auto* feature_index = + builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR0(kFeatureIndex))); + + auto* batchnorm = + 
builder.AddInstruction(HloInstruction::CreateCustomCall( + batchnorm_shape, + {operand, scale, mean, var, grad_offset, epsilon, + feature_index}, + kCudnnBatchNormBackwardCallTarget)); + + auto module = CreateNewModule(); + HloComputation* computation = + module->AddEntryComputation(builder.Build(batchnorm)); + + ComputationLayout computation_layout( + computation->ComputeProgramShape()); + + if (input_shape.has_layout()) { + *computation_layout.mutable_parameter_layout(constrained_param_no) = + ShapeLayout(input_shape); + } + + if (result_shape.has_layout()) { + *computation_layout.mutable_result_layout() = + ShapeLayout(ShapeUtil::MakeTupleShape( + {result_shape, scale_shape, scale_shape})); + } + + GpuLayoutAssignment layout_assignment(&computation_layout); + EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); + + // The first and fourth operands to the batchnorm call should have the + // same layout as the first element of the result tuple. + EXPECT_TRUE( + LayoutUtil::Equal(batchnorm->operand(0)->shape().layout(), + batchnorm->shape().tuple_shapes(0).layout())) + << batchnorm->ToString(); + EXPECT_TRUE( + LayoutUtil::Equal(batchnorm->operand(4)->shape().layout(), + batchnorm->shape().tuple_shapes(0).layout())) + << batchnorm->ToString(); + } + } + } +} + +} // namespace +} // namespace gpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc index f0f036f7f381db15b84db85d3efeec5d8141884e..af9897769fda371e47af06c19abce9a06015e094 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc @@ -44,7 +44,7 @@ GpuTransferManager::GpuTransferManager() : GenericTransferManager( se::cuda::kCudaPlatformId, /*pointer_size=*/llvm::DataLayout(gpu::GpuCompiler::kDataLayout) - .getPointerSize()) {} + .getPointerSize(0 /* default address space */)) {} Status 
GpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, const Literal& literal) { @@ -54,7 +54,7 @@ Status GpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, if (!ShapeUtil::IsTuple(shape)) { int64 size = GetByteSizeRequirement(shape); - return TransferBufferToInfeed(executor, size, literal.InternalData()); + return TransferBufferToInfeed(executor, size, literal.untyped_data()); } if (ShapeUtil::IsNestedTuple(shape)) { @@ -67,20 +67,21 @@ Status GpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, // enqueue the resulting destination device addresses with the // infeed manager. std::vector buffers; - buffers.reserve(literal.tuple_literals_size()); + buffers.reserve(ShapeUtil::TupleElementCount(shape)); auto cleanup = tensorflow::gtl::MakeCleanup([buffers]() { for (gpu::InfeedBuffer* b : buffers) { b->Done(); } }); - for (const auto& tuple_element : literal.tuple_literals()) { - const Shape& tuple_element_shape = tuple_element.shape(); + for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + const Shape& tuple_element_shape = + ShapeUtil::GetTupleElementShape(shape, i); int64 tuple_element_size = GetByteSizeRequirement(tuple_element_shape); TF_ASSIGN_OR_RETURN( gpu::InfeedBuffer * buffer, TransferBufferToInfeedInternal(executor, tuple_element_size, - tuple_element.InternalData())); + literal.untyped_data({i}))); buffers.push_back(buffer); } @@ -105,12 +106,13 @@ Status GpuTransferManager::EnqueueBuffersToInfeed( // infeed requests, blocking on the stream might be // heavy-handed. Figure out if finer-grained acknowledgement is // possible. 
- if (!stream->BlockHostUntilDone()) { + Status block_status = stream->BlockHostUntilDone(); + if (!block_status.ok()) { for (gpu::InfeedBuffer* b : buffers) { b->Done(); } - return InternalError("Failed to complete data transfer on stream %p", - stream); + return InternalError("Failed to complete data transfer on stream %p: %s", + stream, block_status.error_message().c_str()); } infeed_manager->EnqueueBuffers(buffers); diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index c2115c49993ef71c4b6dd584e7e0498807666613..061210352cf12e6802d066d311fd2cb481673f15 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -22,12 +22,17 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" namespace xla { namespace gpu { +using tensorflow::strings::StrAppend; +using tensorflow::strings::StrCat; + void HloToIrBindings::EmitBasePointersForHlos( tensorflow::gtl::ArraySlice io_hlos, tensorflow::gtl::ArraySlice non_io_hlos) { @@ -191,7 +196,11 @@ static bool BuffersInvariantWithinConsumer( llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo, const HloInstruction& consumer, const ShapeIndex& shape_index) { - llvm_ir::IrArray ir_array(GetBasePointer(hlo, shape_index), + llvm::Value* base_ptr = GetBasePointer(hlo, shape_index); + CHECK_NE(base_ptr, nullptr) + << "Buffer not assigned for shape_index " << shape_index.ToString() + << " of " << hlo.ToString(); + llvm_ir::IrArray ir_array(base_ptr, ShapeUtil::GetSubshape(hlo.shape(), shape_index)); 
alias_analysis_.AddAliasingInformationToIrArray(hlo, &ir_array); @@ -223,5 +232,54 @@ void HloToIrBindings::UnbindAllLocalIrValues() { } } +string HloToIrBindings::ToString() const { + string s = StrCat("** HloToIrBindings **\n"); + StrAppend(&s, " is_nested_=", is_nested_, "\n"); + StrAppend(&s, + " temp_buffer_base_=", llvm_ir::DumpToString(*temp_buffer_base_), + "\n"); + + if (base_ptrs_.empty()) { + return s; + } + + // Iterate over all computations in the module in topological order, and print + // out the base pointers we have in each computation in topological order. + for (const HloComputation* computation : + base_ptrs_.begin()->first->GetModule()->MakeComputationPostOrder()) { + bool is_first = true; + for (const HloInstruction* instr : + computation->MakeInstructionPostOrder()) { + auto it = base_ptrs_.find(instr); + if (it == base_ptrs_.end()) { + continue; + } + if (is_first) { + StrAppend(&s, " Base pointers for computation ", computation->name(), + ":\n"); + is_first = false; + } + StrAppend(&s, " ", instr->ToString()); + + const ShapeTree& shape_tree = it->second; + if (!ShapeUtil::IsTuple(instr->shape())) { + const llvm::Value* val = shape_tree.begin()->second; + StrAppend(&s, " -> ", llvm_ir::DumpToString(*val), "\n"); + continue; + } + + StrAppend(&s, "\n"); + for (auto shape_it = shape_tree.begin(); shape_it != shape_tree.end(); + ++shape_it) { + llvm::Value* val = shape_it->second; + StrAppend(&s, " ", shape_it->first.ToString(), " -> ", + (val != nullptr ? 
llvm_ir::DumpToString(*val) : "null"), + "\n"); + } + } + } + return s; +} + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h index 62ae1769a1f2fb3b9acaf35bdf18a793232500b0..3d34311b4368d17cb074aaf33c71fc865e96387e 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h @@ -66,13 +66,14 @@ class HloToIrBindings { } llvm::Value* GetTempBufferBase() const { return temp_buffer_base_; } + void SetTempBufferBase(llvm::Value* v) { temp_buffer_base_ = v; } // A helper method that returns the base pointer of the IrArray containing the // output of "inst".at the given ShapeIndex. llvm::Value* GetBasePointer(const HloInstruction& hlo, const ShapeIndex& shape_index = {}) const { auto it = base_ptrs_.find(&hlo); - CHECK(it != base_ptrs_.end()); + CHECK(it != base_ptrs_.end()) << hlo.ToString(); return it->second.element(shape_index); } @@ -87,6 +88,8 @@ class HloToIrBindings { const HloInstruction& consumer, const ShapeIndex& shape_index = {}); + string ToString() const; + private: // Emits IR to resolve (possibly) recursive GetTupleElement instructions. llvm::Value* EmitGetTupleElement(const HloInstruction* gte, @@ -111,7 +114,7 @@ class HloToIrBindings { std::unordered_map> base_ptrs_; // The address of the memory block that contains all temporary buffers. 
- llvm::Value* temp_buffer_base_; + llvm::Value* temp_buffer_base_ = nullptr; llvm_ir::AliasAnalysis alias_analysis_; }; diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc index e33e904692ca5ad41e17d2e165dbb40b6bd4aa33..2ac95ceb692447c7ac6dbbcd8b9a38876f7a77b6 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.cc @@ -30,9 +30,8 @@ InfeedThunk::InfeedThunk( tuple_element_buffers.end()), destination_buffer_(destination_buffer) {} -tensorflow::Status InfeedThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) { +Status InfeedThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) { VLOG(2) << "Infeeding to GPU "; perftools::gputools::DeviceMemoryBase destination_address = @@ -66,15 +65,16 @@ tensorflow::Status InfeedThunk::ExecuteOnStream( buffer->length()); } - if (!stream->BlockHostUntilDone()) { - return InternalError("Failed to complete data transfer on stream %p", - stream); + Status block_status = stream->BlockHostUntilDone(); + if (!block_status.ok()) { + return InternalError("Failed to complete data transfer on stream %p: %s", + stream, block_status.error_message().c_str()); } infeed_manager->ReleaseBuffers(infeed_buffers); VLOG(2) << "Infeeding to GPU complete"; - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h index 371d71f9dbdd21cb5f36cc3108c8f398a4a91c29..86918705fa0305217f11753e383200c7bd71474b 100644 --- a/tensorflow/compiler/xla/service/gpu/infeed_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/infeed_thunk.h @@ -43,9 +43,8 @@ class InfeedThunk : public Thunk { InfeedThunk(const InfeedThunk&) = delete; InfeedThunk& operator=(const InfeedThunk&) = delete; - 
tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) override; private: const std::vector tuple_element_buffers_; diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 1d47ffde4331868cbc8a8afb2d01b11e77a7fab0..2d6dad27a59978da6e4719afc50ebee5e641dde0 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -137,49 +137,6 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) { .ValueOrDie()); } -TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfConvolutionUnfused) { - HloComputation::Builder builder(TestName()); - auto input = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {1, 1, 1, 3}), "input")); - auto filter = builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {1, 1, 1, 2}), "filter")); - - Window conv_window; - WindowDimension* conv_window_row = conv_window.add_dimensions(); - conv_window_row->set_size(1); - WindowDimension* conv_window_col = conv_window.add_dimensions(); - conv_window_col->set_size(2); - conv_window_col->set_padding_high(1); - - ConvolutionDimensionNumbers conv_dnums; - conv_dnums.set_input_batch_dimension(0); - conv_dnums.set_output_batch_dimension(0); - conv_dnums.set_input_feature_dimension(1); - conv_dnums.set_output_feature_dimension(1); - conv_dnums.add_input_spatial_dimensions(2); - conv_dnums.add_output_spatial_dimensions(2); - conv_dnums.add_input_spatial_dimensions(3); - conv_dnums.add_output_spatial_dimensions(3); - conv_dnums.set_kernel_output_feature_dimension(0); - conv_dnums.set_kernel_input_feature_dimension(1); - conv_dnums.add_kernel_spatial_dimensions(2); - 
conv_dnums.add_kernel_spatial_dimensions(3); - - auto conv = builder.AddInstruction( - HloInstruction::CreateConvolve(ShapeUtil::MakeShape(F32, {1, 1, 1, 3}), - input, filter, conv_window, conv_dnums)); - auto transpose = builder.AddInstruction(HloInstruction::CreateTranspose( - ShapeUtil::MakeShape(F32, {3, 1, 1, 1}), conv, {3, 2, 1, 0})); - builder.AddInstruction( - HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {3}), transpose)); - - auto module = CreateNewModule(); - module->AddEntryComputation(builder.Build()); - EXPECT_FALSE(GpuInstructionFusion(/*may_duplicate=*/true) - .Run(module.get()) - .ValueOrDie()); -} - TEST_F(InstructionFusionTest, GetTupleElementFused) { HloComputation::Builder builder(TestName()); Shape data_shape = ShapeUtil::MakeShape(F32, {8}); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 658fd05cd4b63c923d21b4a1de16468c0aeec65d..2f65edffea81db7dba1f8545f92b27ea622044e7 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -90,41 +90,93 @@ bool ImplementedAsGemm(const HloInstruction& hlo) { return false; } -bool ImplementedAsDnnConvolution(const HloInstruction& hlo) { - // We can only do this if the HLO is unnested. 
- if (hlo.parent() != hlo.GetModule()->entry_computation()) { +const char* const kCudnnBatchNormForwardInferenceCallTarget = + "__cudnn$batchNormalizationForwardInference"; +const char* const kCudnnBatchNormForwardTrainingCallTarget = + "__cudnn$batchNormalizationForwardTraining"; +const char* const kCudnnBatchNormBackwardCallTarget = + "__cudnn$batchNormalizationBackward"; + +bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo) { + if (hlo.opcode() != HloOpcode::kCustomCall) { return false; } + const auto& target = hlo.custom_call_target(); + return target == kCudnnBatchNormForwardInferenceCallTarget || + target == kCudnnBatchNormForwardTrainingCallTarget || + target == kCudnnBatchNormBackwardCallTarget; +} - // Forward convolution. - if (hlo.opcode() == HloOpcode::kConvolution) { - const ConvolutionDimensionNumbers& dnums = - hlo.convolution_dimension_numbers(); - if (dnums.input_spatial_dimensions_size() > 3) { - return false; - } - - // CuDNN does not accept zero-element arguments - if (ShapeUtil::HasZeroElements(hlo.operand(0)->shape()) || - ShapeUtil::HasZeroElements(hlo.operand(1)->shape())) { - return false; - } +const char* const kCudnnConvForwardCallTarget = "__cudnn$convForward"; +const char* const kCudnnConvBackwardInputCallTarget = + "__cudnn$convBackwardInput"; +const char* const kCudnnConvBackwardFilterCallTarget = + "__cudnn$convBackwardFilter"; - return true; +bool IsCustomCallToDnnConvolution(const HloInstruction& hlo) { + if (hlo.opcode() != HloOpcode::kCustomCall) { + return false; } + const auto& target = hlo.custom_call_target(); + return target == kCudnnConvForwardCallTarget || + target == kCudnnConvBackwardInputCallTarget || + target == kCudnnConvBackwardFilterCallTarget; +} - // Backward convolution. 
- if (hlo.opcode() == HloOpcode::kFusion && - (hlo.fusion_kind() == HloInstruction::FusionKind::kConvBackwardFilter || - hlo.fusion_kind() == HloInstruction::FusionKind::kConvBackwardInput)) { - return true; - } +bool ImplementedAsLibraryCall(const HloInstruction& hlo) { + return ImplementedAsGemm(hlo) || IsCustomCallToDnnBatchNorm(hlo) || + IsCustomCallToDnnConvolution(hlo); +} - return false; +static HloInstruction* CreateCudnnConv( + const char* call_target, const Shape& shape, HloInstruction* lhs, + HloInstruction* rhs, const Window& window, + const ConvolutionDimensionNumbers& dnums) { + HloComputation* computation = lhs->parent(); + + // This call returns a tuple of (conv_result, scratch_memory), where + // conv_result is the actual result of the convolution, and scratch_memory is + // temporary memory used by cudnn. + // + // At the moment, we don't know how much scratch memory this conv is going to + // use, so we put u8[0] in this place. Later on another pass will choose + // which conv algorithm to use, and at that point we'll modify the shape of + // this second tuple element. + Shape call_shape = + ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U8, {0})}); + + // Our CustomCall takes three arguments: The conv lhs and rhs, and the cudnn + // algorithm to use. It's up to a later pass to choose the algorithm, so to + // indicate that we haven't yet made a choice, we specify -1 for that arg.
+ HloInstruction* negative_one = computation->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(-1))); + HloInstruction* custom_call = + computation->AddInstruction(HloInstruction::CreateCustomCall( + call_shape, {lhs, rhs, negative_one}, call_target)); + custom_call->set_window(window); + custom_call->set_convolution_dimension_numbers(dnums); + return custom_call; } -bool ImplementedAsLibraryCall(const HloInstruction& hlo) { - return ImplementedAsGemm(hlo) || ImplementedAsDnnConvolution(hlo); +HloInstruction* CreateCudnnConvForward( + const Shape& shape, HloInstruction* input, HloInstruction* kernel, + const Window& window, const ConvolutionDimensionNumbers& dnums) { + return CreateCudnnConv(kCudnnConvForwardCallTarget, shape, input, kernel, + window, dnums); +} + +HloInstruction* CreateCudnnConvBackwardInput( + const Shape& shape, HloInstruction* output, HloInstruction* reverse_filter, + const Window& window, const ConvolutionDimensionNumbers& dnums) { + return CreateCudnnConv(kCudnnConvBackwardInputCallTarget, shape, output, + reverse_filter, window, dnums); +} + +HloInstruction* CreateCudnnConvBackwardFilter( + const Shape& shape, HloInstruction* input, HloInstruction* output, + const Window& window, const ConvolutionDimensionNumbers& dnums) { + return CreateCudnnConv(kCudnnConvBackwardFilterCallTarget, shape, input, + output, window, dnums); } bool IsReductionToVector(const HloInstruction& reduce) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index 06c3205296e4546e39525ec093cc17e2fc375d0d..59455f389e733fee2d6cace7486f919a0c5e834e 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -22,6 +22,9 @@ limitations under the License. 
#include "llvm/IR/Value.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +// TODO(jlebar): Move functions related to cublas/cudnn to a separate file; they +// don't belong in "ir_emission_utils". + namespace xla { namespace gpu { @@ -30,8 +33,85 @@ constexpr int64 kWarpSize = 32; // Returns true if `hlo` will be implemented as a call to BLAS gemm. bool ImplementedAsGemm(const HloInstruction& hlo); -// Returns true if `hlo` will be implemented as a call to cuDNN convolution. -bool ImplementedAsDnnConvolution(const HloInstruction& hlo); +// A call to cuDNN for batch normalization is represented as CustomCall HLO with +// a call target equal to one of these strings. +// +// The operands to and outputs of these calls are the same as those of the +// corresponding HLOs, except: +// +// - epsilon and feature_index are proper operands, at the end of the operands +// list. They must be HLO constants. +// - The cuDNN forward training call returns inv_stddev = +// 1/sqrt(variance + epsilon) in place of plain variance. +// - Similarly, BatchNormGrad accepts inv_stddev in place of the variance +// operand. +extern const char* const kCudnnBatchNormForwardInferenceCallTarget; +extern const char* const kCudnnBatchNormForwardTrainingCallTarget; +extern const char* const kCudnnBatchNormBackwardCallTarget; + +// Returns true if `hlo` will be implemented as a call to a cuDNN batch +// normalization routine. +// +// This returns true if `hlo` is a CustomCall HLO with a call target equal to +// one of the kCudnnBatchNormFoo constants above, but returns *false* for HLOs +// with one of the kBatchNorm opcodes, because these are lowered either to a +// sequence of generic HLOs or to a cuDNN CustomCall. +bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo); + +// A call to cuDNN for convolution (forward, backward filter, or backward input) +// is represented as a CustomCall HLO with a call target equal to one of these +// strings. 
+// +// These CustomCalls have window() and convolution_dimension_numbers() set like +// regular convolution ops. They have the same LHS and RHS operands, plus two +// additional constant operands: an int64 operand for the cudnn algorithm and +// a bool operand for whether tensor_ops is enabled. A value of -1 for the cudnn +// algorithm means that the implementation is free to choose the best algorithm +// it can. +// +// These calls output a tuple (conv_result, scratch_memory), where conv_result +// is the actual result of the convolution, and scratch_memory is temporary +// memory used by cudnn. Callers shouldn't inspect scratch_memory, as its value +// is not well-defined. +// +// CudnnConvolutionRewriter lowers kConvolution HLOs to these custom calls. +// When it does so, it chooses algorithm -1 and 0 bytes of scratch space. Later +// on in the pipeline, CudnnConvolutionAlgorithmChooser chooses an explicit +// algorithm for each conv and sets the amount of scratch space needed. +// +// (Representing the scratch memory as an output may seem strange at first, but +// it's quite sensible, from a certain point of view. The scratch buffer is a +// location in memory that the conv can write into, but which it can't legally +// read from, at least until it's written something first. But that's exactly +// the definition of an output buffer.) +extern const char* const kCudnnConvForwardCallTarget; +extern const char* const kCudnnConvBackwardInputCallTarget; +extern const char* const kCudnnConvBackwardFilterCallTarget; + +// Returns true if `hlo` will be implemented as a call to a cuDNN convolution +// routine. +// +// This returns true if `hlo` is a CustomCall HLO with a call target equal to +// one of the kCudnnConvFoo constants above, but returns *false* for HLOs with a +// kConvolution opcode. +bool IsCustomCallToDnnConvolution(const HloInstruction& hlo); + +// Creates a CustomCall for a cudnn forward/backward-input/backward-filter conv. 
+// Note that these CustomCalls return a tuple (conv_result, scratch_memory). If +// you want just the conv result, you'll need to get-tuple-element the value +// returned by this function. +// +// The created cudnn call will use the default cudnn algorithm and no scratch +// space. +HloInstruction* CreateCudnnConvForward( + const Shape& shape, HloInstruction* input, HloInstruction* kernel, + const Window& window, const ConvolutionDimensionNumbers& dnums); +HloInstruction* CreateCudnnConvBackwardInput( + const Shape& shape, HloInstruction* output, HloInstruction* reverse_filter, + const Window& window, const ConvolutionDimensionNumbers& dnums); +HloInstruction* CreateCudnnConvBackwardFilter( + const Shape& shape, HloInstruction* input, HloInstruction* output, + const Window& window, const ConvolutionDimensionNumbers& dnums); // Returns true if `hlo` will be implemented as a library call, e.g. cuBLAS gemm // or cuDNN convolution. diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index 6e2bd4e11d3c4ff576edb0df3b724abebfc0e424..a3df67a87344d6ece2ea9047321ad9542c13f8cf 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -27,6 +27,8 @@ limitations under the License. 
#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" #include "tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" @@ -173,7 +175,7 @@ Status IrEmitter::EmitCallToNestedComputation( return Status::OK(); } -bool IrEmitter::MaybeEmitSpecialAtomicOperation( +bool IrEmitter::MaybeEmitDirectAtomicOperation( const HloComputation& computation, llvm::Value* output_address, llvm::Value* source_address) { CHECK_EQ(2, computation.num_parameters()); @@ -233,102 +235,189 @@ bool IrEmitter::MaybeEmitSpecialAtomicOperation( return false; } -Status IrEmitter::EmitAtomicOperationForNestedComputation( - const HloComputation& computation, llvm::Value* output_address, - llvm::Value* source_address) { - if (computation.num_parameters() != 2) { - // TODO(b/30258929): We only accept binary computations so far. - return Unimplemented( - "We only support atomic functions with exactly two parameters, but " - "computation %s has %lld.", - computation.name().c_str(), computation.num_parameters()); - } - - if (MaybeEmitSpecialAtomicOperation(computation, output_address, - source_address)) { - return Status::OK(); - } +// Implements atomic binary operations using atomic compare-and-swap +// (atomicCAS) as follows: +// 1. Reads the value from the memory pointed to by output_address and +// records it as old_output. +// 2. Uses old_output as one of the source operands to perform the binary +// operation and stores the result in new_output. +// 3. Calls atomicCAS which implements compare-and-swap as an atomic +// operation.
In particular, atomicCAS reads the value from the memory +// pointed to by output_address, and compares the value with old_output. If +// the two values equal, new_output is written to the same memory location +// and true is returned to indicate that the atomic operation succeeds. +// Otherwise, the new value read from the memory is returned. In this case, +// the new value is copied to old_output, and steps 2. and 3. are repeated +// until atomicCAS succeeds. +// +// On Nvidia GPUs, atomicCAS can only operate on 32 bit and 64 bit integers. If +// the element type of the binary operation is 32 bits or 64 bits, the integer +// type of the same size is used for the atomicCAS operation. On the other hand, +// if the element type is smaller than 32 bits, int32 is used for the atomicCAS +// operation. In this case, atomicCAS reads and writes 32 bit values from +// the memory, which is larger than the memory size required by the original +// atomic binary operation. We mask off the last two bits of the output_address +// and use the result as an address to read the 32 bit values from the memory. +// This can avoid out of bound memory accesses if tensor buffers are 4 byte +// aligned and have a size of 4N, an assumption that the runtime can guarantee. +// +// The pseudo code is shown below. Variables *_address are pointers to a memory +// region with a size equal to the size of the atomicCAS operation, with the +// exception that new_output_address is a pointer to a memory region with a size +// equal to the element size of the binary operation. 
+// +// element_size = sizeof(element_type); +// atomic_size = max(32, element_size); +// cas_new_output_address = alloca(atomic_size); +// cas_old_output_address = alloca(atomic_size); +// if (atomic_size != element_size) { +// atomic_address = output_address & ((int64)(-4)); +// new_output_address = cas_new_output_address + (output_address & 3); +// } else { +// atomic_address = output_address; +// new_output_address = cas_new_output_address; +// } +// +// *cas_old_output_address = *atomic_address; +// do { +// *cas_new_output_address = *cas_old_output_address; +// *new_output_address = operation(*new_output_address, *source_address); +// (*cas_old_output_address, success) = +// atomicCAS(atomic_address, *cas_old_output_address, +// *cas_new_output_address); +// } while (!success); +// +Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation, + llvm::Value* output_address, + llvm::Value* source_address) { + llvm::PointerType* output_address_type = + llvm::dyn_cast(output_address->getType()); + CHECK_NE(output_address_type, nullptr); + + // element_type is the data type for the binary operation. + llvm::Type* element_type = output_address_type->getPointerElementType(); + int element_size = llvm_ir::GetSizeInBits(element_type); + llvm::Type* element_address_type = element_type->getPointerTo(); + + int atomic_size = (element_size < 32) ? 32 : element_size; + llvm::Type* atomic_type = ir_builder_.getIntNTy(atomic_size); + llvm::Type* atomic_address_type = + atomic_type->getPointerTo(output_address_type->getPointerAddressSpace()); + + // cas_old_output_address and cas_new_output_address point to the scratch + // memory where we store the old and new values for the repeated atomicCAS + // operations. 
+ llvm::Value* cas_old_output_address = ir_builder_.CreateAlloca( + atomic_type, /*ArraySize=*/nullptr, "cas_old_output_address"); + llvm::Value* cas_new_output_address = ir_builder_.CreateAlloca( + atomic_type, /*ArraySize=*/nullptr, "cas_new_output_address"); - // Other binary computations can be made atomic as following (labels are basic - // block names used in the IR emitting code later). - // - // atomic_op_loop_preheader: - // ... - // source = *source_address; - // old_output = *output_address; - // do { - // atomic_op_loop_body_entry: - // new_output = computation(old_output, source); - // (old_output, success) = - // atomicCAS(output_address, old_output, new_output); - // } while (!success); - // - // atomic_op_loop_exit: - // ... - // - // TODO(jingyue): Consider encapsulate the logic of emitting control flow to - // something similar to llvm_ir::ForLoop. - // // Emit preparation code to the preheader. llvm::BasicBlock* loop_preheader_bb = ir_builder_.GetInsertBlock(); - llvm::Type* element_ir_type = - output_address->getType()->getPointerElementType(); - // old_output = *output_address; - llvm::Value* old_output_location = ir_builder_.CreateAlloca( - element_ir_type, /*ArraySize=*/nullptr, "old_output_location"); - ir_builder_.CreateStore(ir_builder_.CreateLoad(output_address, "old_output"), - old_output_location); + + llvm::Value* atomic_memory_address; + // binop_output_address points to the scratch memory that stores the + // result of the binary operation. + llvm::Value* binop_output_address; + if (element_size < 32) { + // Assume the element size is an integer number of bytes. 
+ CHECK_EQ((element_size % sizeof(char)), 0); + llvm::Type* address_int_type = + module_->getDataLayout().getIntPtrType(output_address_type); + atomic_memory_address = + ir_builder_.CreatePtrToInt(output_address, address_int_type); + llvm::Value* mask = llvm::ConstantInt::get(address_int_type, 3); + llvm::Value* offset = ir_builder_.CreateAnd(atomic_memory_address, mask); + mask = llvm::ConstantInt::get(address_int_type, -4); + atomic_memory_address = ir_builder_.CreateAnd(atomic_memory_address, mask); + atomic_memory_address = + ir_builder_.CreateIntToPtr(atomic_memory_address, atomic_address_type); + binop_output_address = ir_builder_.CreateAdd( + ir_builder_.CreatePtrToInt(cas_new_output_address, address_int_type), + offset); + binop_output_address = + ir_builder_.CreateIntToPtr(binop_output_address, element_address_type); + } else { + atomic_memory_address = + ir_builder_.CreateBitCast(output_address, atomic_address_type); + binop_output_address = + ir_builder_.CreateBitCast(cas_new_output_address, element_address_type); + } + + // Use the value from the memory that atomicCAS operates on to initialize + // cas_old_output. + llvm::Value* cas_old_output = + ir_builder_.CreateLoad(atomic_memory_address, "cas_old_output"); + ir_builder_.CreateStore(cas_old_output, cas_old_output_address); + llvm::BasicBlock* loop_exit_bb = loop_preheader_bb->splitBasicBlock( ir_builder_.GetInsertPoint(), "atomic_op_loop_exit"); - - // Emit the body of the loop that repeatedly invokes atomicCAS. llvm::BasicBlock* loop_body_bb = llvm::BasicBlock::Create(ir_builder_.getContext(), "atomic_op_loop_body", ir_builder_.GetInsertBlock()->getParent()); ir_builder_.SetInsertPoint(loop_body_bb); // Change preheader's successor from loop_exit_bb to loop_body_bb. 
loop_preheader_bb->getTerminator()->setSuccessor(0, loop_body_bb); - // new_output = computation(old_output, source); - llvm::Value* new_output_location = ir_builder_.CreateAlloca( - element_ir_type, /*ArraySize=*/nullptr, "new_output_location"); + + // Emit the body of the loop that repeatedly invokes atomicCAS. + // + // Use cas_old_output to initialize cas_new_output. + cas_old_output = + ir_builder_.CreateLoad(cas_old_output_address, "cas_old_output"); + ir_builder_.CreateStore(cas_old_output, cas_new_output_address); + // Emits code to calculate new_output = operation(old_output, source); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( - computation, {old_output_location, source_address}, new_output_location)); - - // (old_output, success) = atomicCAS(output_address, old_output, new_output); - int num_bits = llvm_ir::GetSizeInBits(element_ir_type); - llvm::Type* element_int_ir_type = ir_builder_.getIntNTy(num_bits); - // cmpxchg accepts integer only, and bitcast refuses to operate on aggregate - // types, so we bitcast load and store addresses to intN* of the same bit - // width. 
- llvm::Value* old_output = ir_builder_.CreateLoad( - ir_builder_.CreateBitCast(old_output_location, - element_int_ir_type->getPointerTo()), - "old_output"); - llvm::Value* new_output = ir_builder_.CreateLoad( - ir_builder_.CreateBitCast(new_output_location, - element_int_ir_type->getPointerTo()), - "new_output"); + computation, {binop_output_address, source_address}, + binop_output_address)); + + llvm::Value* cas_new_output = + ir_builder_.CreateLoad(cas_new_output_address, "cas_new_output"); + + // Emit code to perform the atomicCAS operation + // (cas_old_output, success) = atomicCAS(memory_address, cas_old_output, + // cas_new_output); llvm::Value* ret_value = ir_builder_.CreateAtomicCmpXchg( - ir_builder_.CreateBitCast(output_address, - element_int_ir_type->getPointerTo()), - old_output, new_output, llvm::AtomicOrdering::SequentiallyConsistent, + atomic_memory_address, cas_old_output, cas_new_output, + llvm::AtomicOrdering::SequentiallyConsistent, llvm::AtomicOrdering::SequentiallyConsistent); - // cmpxchg returns a pair. The first element is the original value at - // output_address and the second element is whether the swap is successful. + + // Extract the memory value returned from atomicCAS and store it as + // cas_old_output. ir_builder_.CreateStore( - ir_builder_.CreateExtractValue(ret_value, 0, "old_output"), - ir_builder_.CreateBitCast(old_output_location, - element_int_ir_type->getPointerTo())); + ir_builder_.CreateExtractValue(ret_value, 0, "cas_old_output"), + cas_old_output_address); + // Extract the success bit returned from atomicCAS and generate a + // conditional branch on the success bit. ir_builder_.CreateCondBr( ir_builder_.CreateExtractValue(ret_value, 1, "success"), loop_exit_bb, loop_body_bb); - // Restore the insertion point to the exit basic block so that the caller of + // Set the insertion point to the exit basic block so that the caller of // this method can continue emitting code to the right place. 
SetToFirstInsertPoint(loop_exit_bb, &ir_builder_); return Status::OK(); } +Status IrEmitter::EmitAtomicOperationForNestedComputation( + const HloComputation& computation, llvm::Value* output_address, + llvm::Value* source_address) { + if (computation.num_parameters() != 2) { + // TODO(b/30258929): We only accept binary computations so far. + return Unimplemented( + "We only support atomic functions with exactly two parameters, but " + "computation %s has %lld.", + computation.name().c_str(), computation.num_parameters()); + } + + if (MaybeEmitDirectAtomicOperation(computation, output_address, + source_address)) { + return Status::OK(); + } + + return EmitAtomicOperationUsingCAS(computation, output_address, + source_address); +} + Status IrEmitter::HandleSelect(HloInstruction* select) { auto pred = select->operand(0); auto on_true = select->operand(1); @@ -518,10 +607,17 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution) { "Hit a case for convolution that is not implemented on GPU."); } +Status IrEmitter::HandleFft(HloInstruction* fft) { + if (ShapeUtil::HasZeroElements(fft->shape())) { + // Emit no code for an empty output. + return Status::OK(); + } + return Unimplemented("Hit a case for fft that is not implemented on GPU."); +} + Status IrEmitter::HandleCrossReplicaSum(HloInstruction* crs) { // TODO(b/33011107): Support cross replica sum on GPU. - return Unimplemented( - "Cross replica sum not implemented on GPU. See b/33011107."); + return Unimplemented("CrossReplicaSum is not implemented on GPU."); } Status IrEmitter::HandleParameter(HloInstruction* parameter) { @@ -615,11 +711,13 @@ Status IrEmitter::HandleCustomCall(HloInstruction*) { } Status IrEmitter::HandleInfeed(HloInstruction*) { - return Unimplemented("Infeed is not supported on GPU (b/30467474)."); + // TODO(b/30467474): Implement infeed on GPU. 
+ return Unimplemented("Infeed is not supported on GPU."); } Status IrEmitter::HandleOutfeed(HloInstruction*) { - return Unimplemented("Outfeed is not supported on GPU (b/34359662)."); + // TODO(b/34359662): Implement outfeed on GPU. + return Unimplemented("Outfeed is not supported on GPU."); } Status IrEmitter::HandleRng(HloInstruction* random) { @@ -640,6 +738,29 @@ Status IrEmitter::HandleRng(HloInstruction* random) { .EmitLoop(IrName(random)); } +Status IrEmitter::HandleBatchNormInference(HloInstruction*) { + return Unimplemented( + "The GPU backend does not implement BatchNormInference directly. It " + "should be lowered before IR emission to HLO-soup using " + "BatchNormRewriter or to a cudnn CustomCall using " + "CudnnBatchNormRewriter."); +} + +Status IrEmitter::HandleBatchNormTraining(HloInstruction*) { + return Unimplemented( + "The GPU backend does not implement BatchNormTraining directly. It " + "should be lowered before IR emission to HLO-soup using " + "BatchNormRewriter or to a cudnn CustomCall using " + "CudnnBatchNormRewriter."); +} + +Status IrEmitter::HandleBatchNormGrad(HloInstruction*) { + return Unimplemented( + "The GPU backend does not implement BatchNormGrad directly. It should " + "be lowered before IR emission to HLO-soup (using BatchNormRewriter) or " + "to a cudnn CustomCall using CudnnBatchNormRewriter."); +} + llvm_ir::IrArray::Index IrEmitter::EmitOperandArrayLoopNest( const llvm_ir::IrArray& operand_array, int64 reduction_dimension, tensorflow::StringPiece name_suffix, llvm_ir::ForLoopNest* loop_nest) { @@ -648,8 +769,8 @@ llvm_ir::IrArray::Index IrEmitter::EmitOperandArrayLoopNest( // reduction dimension. 
std::vector dimensions; const Shape& shape = operand_array.GetShape(); - for (int i = shape.layout().minor_to_major_size() - 1; i >= 0; --i) { - int64 dimension = shape.layout().minor_to_major(i); + for (int i = 0; i < LayoutUtil::MinorToMajor(shape).size(); ++i) { + int64 dimension = LayoutUtil::Major(shape.layout(), i); if (dimension != reduction_dimension) { dimensions.push_back(dimension); } diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index 9c01f5b7c72f429822300af28bfd5261150d33d1..b0accc08d479258d65a18202122e4c9e90ff78d0 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -13,19 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// An XLA HLO graph may contain multiple computations. These computations -// fall into two types, nested and unnested. We translate each nested -// computation (e.g. the computation operand of a Map operator) to a device -// function. For each unnested computation composed of top-level -// HloInstructions, we generate a CUDA kernel for each HloInstruction. -// -// This file declares classes that translate an XLA HLO graph to LLVM IR for -// GPUs. IrEmitterNested emits LLVM IR for nested computations, and -// IrEmitterUnnested for unnested computations. The logic of emitting LLVM IR -// for each individual HloInstruction is largely the same between these two -// classes. Therefore, we implement the common logic in the Handle* functions in -// the superclass IrEmitter. - #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_H_ @@ -60,25 +47,35 @@ limitations under the License. namespace xla { namespace gpu { -// This class is the top-level API for the XLA HLO --> LLVM IR compiler. 
-// It implements the DfsHloVisitor interface and emits an LLVM IR program that -// implements the input HLO graph. +// Abstract base class for translating HLO graphs to LLVM IR for a GPU. +// +// There are two concrete subclasses of IrEmitter: IrEmitterNested and +// IrEmitterUnnested. In the unnested variety, each HLO gets its own kernel +// function, whereas in the nested version the whole computation is emitted as +// one *non-kernel* function. +// +// In XLA, kernel functions never call other kernel functions. This means that +// if we have a kernel -- e.g. implementing a kReduce HLO -- that wants to use +// an HLO computation as a "subroutine" -- e.g. the HLO computation that +// specifies how to reduce two elements -- then the subroutine computation must +// be emitted using IrEmitterNested. // -// Note: if `T` is a subclass of `IrEmitter` and a handler is not overridden in -// either `IrEmitter` or `T`, the handler in `DfsHloVisitorWithDefault` -// calls `T::DefaultAction`. +// Fusion nodes are a special case. A fusion node is emitted using +// IrEmitterUnnested, but the code is generated using FusedIrEmitter, which is +// not a subclass of gpu::IrEmitter, and in fact is better understood as an IR +// generator generator. See comments on that class. class IrEmitter : public DfsHloVisitorWithDefault { public: IrEmitter(const IrEmitter&) = delete; IrEmitter& operator=(const IrEmitter&) = delete; - // The following methods implement the DfsHloVisitorWithDefault interface. 
Status DefaultAction(HloInstruction* hlo) override; Status HandleConstant(HloInstruction* constant) override; Status HandleBitcast(HloInstruction* bitcast) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleDot(HloInstruction* dot) override; Status HandleConvolution(HloInstruction* convolution) override; + Status HandleFft(HloInstruction* fft) override; Status HandleCrossReplicaSum(HloInstruction* crs) override; Status HandleInfeed(HloInstruction* infeed) override; Status HandleOutfeed(HloInstruction* outfeed) override; @@ -95,6 +92,9 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status HandleCall(HloInstruction* call) override; Status HandleCustomCall(HloInstruction* custom_call) override; Status HandleRng(HloInstruction* random) override; + Status HandleBatchNormInference(HloInstruction* batch_norm) override; + Status HandleBatchNormTraining(HloInstruction* batch_norm) override; + Status HandleBatchNormGrad(HloInstruction* batch_norm) override; Status FinishVisit(HloInstruction* root) override { return Status::OK(); } @@ -185,9 +185,16 @@ class IrEmitter : public DfsHloVisitorWithDefault { // be simply implemented using an LLVM atomic instruction. If "computation" is // one of this kind, emits code to do that and returns true; otherwise, // returns false. - bool MaybeEmitSpecialAtomicOperation(const HloComputation& computation, - llvm::Value* output_address, - llvm::Value* source_address); + bool MaybeEmitDirectAtomicOperation(const HloComputation& computation, + llvm::Value* output_address, + llvm::Value* source_address); + + // A helper method for EmitAtomicOperationForNestedComputation. It implements + // binary atomic operations using atomicCAS with special handling to support + // small data types. 
+ Status EmitAtomicOperationUsingCAS(const HloComputation& computation, + llvm::Value* output_address, + llvm::Value* source_address); StatusOr ComputeNestedElement( const HloComputation& computation, @@ -206,185 +213,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { std::map computation_to_ir_function_; }; -// Emits LLVM IR for unnested computations. Each HloInstruction is translated to -// a separate CUDA kernel. These kernels are inserted into the resultant module -// sorted in reverse postorder of the XLA HLO graph. -class IrEmitterUnnested : public IrEmitter { - public: - IrEmitterUnnested(const HloModuleConfig& hlo_module_config, - const HloComputation* hlo_computation, - IrEmitterContext* ir_emitter_context); - IrEmitterUnnested(const IrEmitterUnnested&) = delete; - IrEmitterUnnested& operator=(const IrEmitterUnnested&) = delete; - - // Transfers the ownship of thunk_sequence_ out. - std::unique_ptr ConsumeThunkSequence() { - return std::move(thunk_sequence_); - } - - Status DefaultAction(HloInstruction* hlo) override; - - // IrEmitterUnnested handles the following instructions differently from - // IrEmitter. 
- Status HandleCopy(HloInstruction* copy) override; - Status HandleConvolution(HloInstruction* convolution) override; - Status HandleDot(HloInstruction* dot) override; - Status HandleFusion(HloInstruction* fusion) override; - Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; - Status HandleReduce(HloInstruction* reduce) override; - Status HandleSelectAndScatter(HloInstruction* instruction) override; - Status HandleTuple(HloInstruction* tuple) override; - Status HandleWhile(HloInstruction* xla_while) override; - Status HandleInfeed(HloInstruction* xla_infeed) override; - Status HandleRng(HloInstruction* random) override; - Status HandleSelect(HloInstruction* select) override; - - Status EmitTargetElementLoop( - const HloInstruction& hlo, - const llvm_ir::ElementGenerator& body_emitter) override; - - // Same as `EmitTargetElementLoop`, but in given `thunk` rather than - // `LastThunk()`. - Status EmitTargetElementLoopInThunk( - const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter, - KernelThunk* thunk); - - private: - // Builds the appropriate thunk for the instruction hlo and returns the owning - // pointer to it. The caller needs to make sure `inst` outlives the lifetime - // of the returned Thunk object. - std::unique_ptr BuildThunk(const HloInstruction* hlo); - - // Builds the prototype of the IR kernel for `inst` and adds it to the module. - llvm::Function* BuildKernelPrototype( - const HloInstruction& inst, - tensorflow::gtl::ArraySlice escaped_hlos); - - // Emits the base pointers for `hlo` and its operands. `io_hlos` will store - // all input/output HLOs among `hlo` and its operands. - llvm::Function* EmitBasePointersForHloAndItsOperands( - const HloInstruction& hlo, std::vector* io_hlos); - - // EmitColumnReduction and EmitRowReduction emit code for column and row - // reduction of a matrix and/or 3D tensor. 
Row and column reduction have - // different memory access pattern, so for performance their implementations - // are significantly different. - // - // Emits code that reduces a matrix of shape [height x width] to a vector of - // [width]. Other parameters have the same meaning as those of - // `EmitReductionToVector`. Note that input shape might not be - // [height x width], but can be bitcast to [height x weight] with "height" - // being the major dimension. - Status EmitColumnReduction(int64 height, int64 width, HloInstruction* reduce, - const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, - HloComputation* reducer); - - // Emits code that reduces a 3D tensor of shape [depth x height x width] to a - // vector of shape [height]. Other parameters have the same meaning as those - // of `EmitReductionToVector`. Note that input shape might not be - // [depth x height x width], but can be bitcast to [depth x height x weight] - // with "depth" being the most major dimension. - Status EmitRowReduction(int64 depth, int64 height, int64 width, - HloInstruction* reduce, const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, - HloComputation* reducer); - - // Figures out whether `reduce` is a row or column reduction, and which - // dimensions to reduce, and calls either `EmitRowReduction` or - // `EmitColumnReduction` as appropriate. `input_shape` is the shape of the - // input array, which is the operand of the Reduce instruction if unfused or - // of the Fusion instruction if fused. `input_gen` and `init_value_gen` - // generate elements of the input and the initial value. Other parameters mean - // the same as for `HandleReduce`. 
- // - // Prerequisite: `IsReductionToVector(*reduce)` - Status EmitReductionToVector( - HloInstruction* reduce, const Shape& input_shape, - const llvm_ir::ElementGenerator& input_gen, - const llvm_ir::ElementGenerator& init_value_gen, - tensorflow::gtl::ArraySlice dimensions_to_reduce, - HloComputation* reducer); - - // Emits code to initialize buffer of `inst` in given `thunk`. - Status EmitInitializer(const HloInstruction* inst, KernelThunk* thunk); - - // Returns a KernelThunk that invokes the kernel emitted for `inst`. The - // caller needs to make sure `inst` outlives the lifetime of the returned - // Thunk object. - std::unique_ptr BuildKernelThunk(const HloInstruction* inst); - - // Returns a ConvolutionThunk that calls DNN to implement `inst`. - std::unique_ptr BuildConvolutionThunk(const HloInstruction* inst); - - // Returns a GemmThunk that calls gemm to implement `inst`. The caller needs - // to make sure `inst` outlives the lifetime of the returned Thunk object. - std::unique_ptr BuildGemmThunk(const HloInstruction* inst); - - // Returns a thunk that calls host-to-device cuMemcpy to implement `inst`. - std::unique_ptr BuildHostToDeviceCopyThunk(const HloInstruction* inst); - - // Returns a thunk that calls device-to-device cuMemcpy to implement `inst`. - std::unique_ptr BuildDeviceToDeviceCopyThunk( - const HloInstruction* inst); - - // Returns an InfeedThunk that performs device-to-device memcpy to implement - // `inst`. - std::unique_ptr BuildInfeedThunk(const HloInstruction* inst); - - // Returns a WhileThunk that invokes thunk sequences for 'condition' and - // 'body' sub-computations of while instruction 'hlo'. - std::unique_ptr BuildWhileThunk(const HloInstruction* hlo); - - // Returns a ForThunk which executes 'loop_limit' invocations of a thunk - // sequence from the 'body' sub-computation of the while instruction 'hlo'. 
- std::unique_ptr BuildForThunk(const HloInstruction* hlo, - const int64 loop_limit); - - Status Postprocess(HloInstruction* hlo) override; - - // Returns the last generated thunk. - Thunk* LastThunk() const { return thunk_sequence_->back().get(); } - - // The thunk sequence this IrEmitter generates for the input computation. - std::unique_ptr thunk_sequence_; - - // The HloComputation that this IrEmitter emits code for. - const HloComputation* hlo_computation_; -}; - -// Emits LLVM IR for a nested computation to the resultant function. -class IrEmitterNested : public IrEmitter { - public: - // Constructs an LLVM IR emitter for a nested HLO computation. `function` is - // the containing IR function this emitter produces IR to. See - // IrEmitter::IrEmitter for the meanings of other arguments. - IrEmitterNested(const HloModuleConfig& hlo_module_config, - const HloComputation& nested_computation, - IrEmitterContext* ir_emitter_context); - IrEmitterNested(const IrEmitterNested&) = delete; - IrEmitterNested& operator=(const IrEmitterNested&) = delete; - - // Overrides the default empty implementation. Binds the given instruction - // "parameter" with the parameter of the IR function. 
- Status HandleParameter(HloInstruction* parameter) override; - - llvm::Function* GetEmittedFunction() const { return emitted_function_; } - - Status EmitTargetElementLoop( - const HloInstruction& hlo, - const llvm_ir::ElementGenerator& body_emitter) override; - - private: - llvm::Function* EmitBasePointersForNestedComputation( - const HloComputation& nested_computation, - std::vector* io_hlos); - - llvm::Function* emitted_function_; -}; - } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc index 5225ff36ff3a8a1b049479c34aa301de8724f73e..71aada080ae8df70bffce3e1854b5fbd833efd23 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc @@ -16,12 +16,13 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h" + #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h new file mode 100644 index 0000000000000000000000000000000000000000..ca11cf2c182b0600b931b19d2d7fb3983e36441a --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.h @@ -0,0 +1,72 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_NESTED_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_NESTED_H_ + +#include "llvm/IR/Function.h" +#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h" + +namespace xla { +namespace gpu { + +// Emits LLVM IR for a "nested computation" into a non-kernel device function. +// +// This is used to emit code for HloComputations that don't require a separate +// kernel call. For example, IrEmitterNested is used to emit code for a kReduce +// HLO's elementwise reduction computation. Notably, IrEmitterNested is *not* +// used to emit code for fusion nodes -- fusion nodes use FusedIrEmitter, which +// is a different beast altogether. +// +// IrEmitterNested generates a non-kernel function with the following +// parameters: +// +// - N pointers to the buffers of each of the N parameters to the computation, +// - a pointer to the output buffer of the computation, and +// - a pointer to the top-level temp buffer. +// +class IrEmitterNested : public IrEmitter { + public: + // Constructs an LLVM IR emitter for a nested HLO computation. `function` is + // the containing IR function this emitter produces IR to. See + // IrEmitter::IrEmitter for the meanings of other arguments. 
+ IrEmitterNested(const HloModuleConfig& hlo_module_config, + const HloComputation& nested_computation, + IrEmitterContext* ir_emitter_context); + IrEmitterNested(const IrEmitterNested&) = delete; + IrEmitterNested& operator=(const IrEmitterNested&) = delete; + + // Overrides the default empty implementation. Binds the given instruction + // "parameter" with the parameter of the IR function. + Status HandleParameter(HloInstruction* parameter) override; + + llvm::Function* GetEmittedFunction() const { return emitted_function_; } + + Status EmitTargetElementLoop( + const HloInstruction& hlo, + const llvm_ir::ElementGenerator& body_emitter) override; + + private: + llvm::Function* EmitBasePointersForNestedComputation( + const HloComputation& nested_computation, + std::vector* io_hlos); + + llvm::Function* emitted_function_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_NESTED_H_ diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 1b863c9e3c51d6e757751154abd653cd1fdcb8a7..aa2a0a9800bab142481e1def785c9052526fcd8c 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -17,6 +17,8 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h" + #include "llvm/ADT/StringRef.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" @@ -28,14 +30,18 @@ limitations under the License. 
#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" +#include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h" #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h" #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_runner.h" +#include "tensorflow/compiler/xla/service/gpu/fft_thunk.h" #include "tensorflow/compiler/xla/service/gpu/for_thunk.h" #include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h" #include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h" #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h" #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" #include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h" #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h" @@ -69,6 +75,10 @@ namespace gpu { namespace { using llvm_ir::IrName; +using tensorflow::gtl::ArraySlice; +using tensorflow::gtl::nullopt; +using tensorflow::gtl::optional; +using tensorflow::strings::StrCat; // If a dimensions is smaller than this, untiled transposition may be more // efficient. @@ -123,12 +133,46 @@ void UpdateLaunchDimensions(const LaunchDimensions& launch_dims, Thunk* thunk, llvm::ConstantInt* threads_per_block_ir_value = llvm::ConstantInt::get( llvm::IntegerType::get(llvm_context, /*NumBits=*/32), launch_dims.threads_per_block()); + // Our launch bounds are exact, so we can specify them as reqntidx rather than + // maxntidx. 
nvvm_annotations_node->addOperand(llvm::MDNode::get( llvm_context, {llvm::ConstantAsMetadata::get(ir_kernel), - llvm::MDString::get(llvm_context, "maxntidx"), + llvm::MDString::get(llvm_context, "reqntidx"), llvm::ConstantAsMetadata::get(threads_per_block_ir_value)})); } + +// Tries to get a Slice for the given instruction at the given index, but +// returns nullopt if we might not know the slice's address at runtime without +// dereferencing a containing tuple. +// +// In particular, when XLA accepts a parameter of tuple type, the caller has the +// option of telling XLA what are the values inside of the tuple, or just giving +// XLA a pointer to the top-level tuple and letting us chase the pointers on the +// GPU. We therefore cannot rely having these pointers to parameter sub-buffers +// being present when we run the program. +optional GetKnownAtRuntimeSlice( + const HloInstruction* instr, const ShapeIndex& index, + const BufferAssignment& buffer_assn) { + auto maybe_slice = buffer_assn.GetUniqueSlice(instr, index); + if (!maybe_slice.ok()) { + return nullopt; + } + // BufferAllocation gives a slice and alloc to every buffer accessed by XLA, + // but we don't necessarily know the runtime address of sub-buffers of input + // parameters. + const BufferAllocation::Slice& slice = maybe_slice.ValueOrDie(); + const BufferAllocation* alloc = slice.allocation(); + if (alloc->IsInputOrOutput() && !alloc->maybe_live_out() && + !alloc->param_shape_index().empty()) { + return nullopt; + } + + // Otherwise, we will know the address of this slice at runtime without having + // to dereference a tuple. 
+ return slice; +} + } // namespace IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config, @@ -146,16 +190,20 @@ Status IrEmitterUnnested::Postprocess(HloInstruction* hlo) { } namespace { -bool ImplementedAsHostToDeviceMemcpy(const HloInstruction& hlo) { - // `hlo` needs to satisfy three conditions to be implemented as a +bool ImplementedAsHostToDeviceMemcpy(const BufferAssignment& buffer_assignment, + const HloInstruction& hlo) { + // `hlo` needs to satisfy the following conditions to be implemented as a // host-to-device cuMemcpy. // // 1. `hlo` is a kCopy instruction. // 2. `hlo`'s only operand is a kConstant instruction. // 3. `hlo` and its operand have the same shape (thus the same layout too). + // 4. The address of `hlo`'s buffer is known at runtime (without dereferencing + // pointers in a tuple). return hlo.opcode() == HloOpcode::kCopy && hlo.operand(0)->opcode() == HloOpcode::kConstant && - ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()); + ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()) && + GetKnownAtRuntimeSlice(&hlo, {}, buffer_assignment).has_value(); } bool ImplementedAsDeviceToDeviceMemcpy( @@ -169,52 +217,50 @@ bool ImplementedAsDeviceToDeviceMemcpy( // instance) which means the source buffer also resides on the device. return hlo.opcode() == HloOpcode::kCopy && ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()) && - buffer_assignment.HasTopLevelAllocation(hlo.operand(0)); + GetKnownAtRuntimeSlice(&hlo, {}, buffer_assignment).has_value() && + GetKnownAtRuntimeSlice(hlo.operand(0), {}, buffer_assignment) + .has_value(); } } // namespace llvm::Function* IrEmitterUnnested::BuildKernelPrototype( const HloInstruction& inst, - tensorflow::gtl::ArraySlice escaped_hlos) { + tensorflow::gtl::ArraySlice args) { // Compute the kernel name. The opcode string may contain "-" which cannot be // in a PTX function name, so sanitize the name before uniquifying it. 
string kernel_name = ir_emitter_context_->name_uniquer()->GetUniqueName( llvm_ir::SanitizeFunctionName(inst.name())); - // Create the kernel and adds it to the module. + // Create the kernel and add it to the module. llvm::Module* module = ir_emitter_context_->llvm_module(); llvm::LLVMContext& context = module->getContext(); - int num_escaped_hlos = escaped_hlos.size(); llvm::FunctionType* kernel_type = llvm::FunctionType::get( - llvm::Type::getVoidTy(context), // The type of function result. - std::vector(num_escaped_hlos + 1, - ir_builder_.getInt8PtrTy()), - false); // Not a variadic argument function. + /*Result=*/llvm::Type::getVoidTy(context), + std::vector(args.size(), ir_builder_.getInt8PtrTy()), + /*isVarArg=*/false); llvm::Function* kernel = llvm::Function::Create(kernel_type, llvm::GlobalValue::ExternalLinkage, kernel_name.c_str(), module); - // Add dereferenceable information to each of the escaped HLO parameters. - for (size_t arg_no = 0; arg_no < escaped_hlos.size(); ++arg_no) { - const HloInstruction* escaped_hlo = escaped_hlos[arg_no]; - const Shape& escaped_hlo_shape = escaped_hlo->shape(); - int64 escaped_hlo_size = llvm_ir::ByteSizeOf( - escaped_hlo_shape, ir_emitter_context_->llvm_module()->getDataLayout()); - kernel->addDereferenceableAttr(arg_no + 1, escaped_hlo_size); - } - - // The last argument is a pointer to the temporary buffer memory block. - // We know that it doesn't alias any of the escaped arguments (the inputs + - // the result). We also know how many bytes can be dereferenced in it. - const llvm::Argument& temp_buffer = *std::prev(kernel->arg_end()); - int64 temp_buffer_arg_no = temp_buffer.getArgNo(); - int64 temp_allocation_total_size = - ir_emitter_context_->buffer_assignment().temp_allocation_total_size(); - if (temp_allocation_total_size != 0) { - kernel->addDereferenceableAttr(temp_buffer_arg_no + 1, - temp_allocation_total_size); + // Add dereferenceable and alignment information to each of the kernel's + // parameters. 
+ auto arg_it = kernel->arg_begin(); + for (size_t arg_no = 0; arg_no < args.size(); ++arg_no) { + const BufferAllocation* alloc = args[arg_no]; + llvm::Argument* fn_arg = &*arg_it; + ++arg_it; + + kernel->addDereferenceableAttr(arg_no + 1, alloc->size()); + kernel->addParamAttr( + arg_no, llvm::Attribute::get(context, llvm::Attribute::Alignment, + kCudaMallocAlignBytes)); + + if (alloc->IsPreallocatedTempBuffer()) { + fn_arg->setName("temp_buf"); + } else { + fn_arg->setName(llvm_ir::AsStringRef(StrCat("alloc", alloc->index()))); + } } - kernel->addAttribute(temp_buffer_arg_no + 1, llvm::Attribute::NoAlias); // TODO(b/65380986): Investigate if adding fast math flags for generated // kernels makes sense. @@ -230,10 +276,9 @@ llvm::Function* IrEmitterUnnested::BuildKernelPrototype( // Update the insert point to the entry basic block. llvm::BasicBlock* entry_bb = - llvm::BasicBlock::Create(context, - "entry", // The name of the basic block. - kernel); // The parent/owner of "entry_bb". - // Emit a "return void" at entry_bb's end, and sets the insert point before + llvm::BasicBlock::Create(context, /*Name=*/"entry", /*Parent=*/kernel); + + // Emit a "return void" at entry_bb's end, and set the insert point before // that return instruction. 
ir_builder_.SetInsertPoint(llvm::ReturnInst::Create(context, entry_bb)); @@ -246,6 +291,11 @@ Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) { } Status IrEmitterUnnested::HandleDot(HloInstruction* dot) { + const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); + if (dnums.lhs_batch_dimensions_size() > 0 || + dnums.rhs_batch_dimensions_size() > 0) { + return Unimplemented("Dot with batch dimensions not implemented."); + } if (ImplementedAsGemm(*dot)) { thunk_sequence_->emplace_back(BuildGemmThunk(dot)); return Status::OK(); @@ -254,15 +304,191 @@ Status IrEmitterUnnested::HandleDot(HloInstruction* dot) { return IrEmitter::HandleDot(dot); } +Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) { + thunk_sequence_->emplace_back(BuildConditionalThunk(conditional)); + return Status::OK(); +} + Status IrEmitterUnnested::HandleConvolution(HloInstruction* convolution) { - if (ImplementedAsDnnConvolution(*convolution)) { - thunk_sequence_->emplace_back(BuildConvolutionThunk(convolution)); - return Status::OK(); - } thunk_sequence_->emplace_back(BuildKernelThunk(convolution)); return IrEmitter::HandleConvolution(convolution); } +Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) { + // A CustomCall on the GPU backend can either be a custom-call to a + // user-supplied kernel, or a call into a library like cudnn. + + // Lower custom-calls to cudnn batchnorm ops to specialized thunks. It's part + // of the contract of these cudnn batchnorm calls that the epsilon and + // feature_index operands be constants. 
+ if (custom_call->custom_call_target() == + kCudnnBatchNormForwardInferenceCallTarget) { + const HloInstruction* epsilon = custom_call->operand(5); + CHECK(epsilon->IsConstant()); + float epsilon_value = epsilon->literal().Get({}); + + const HloInstruction* feature_index = custom_call->operand(6); + CHECK(feature_index->IsConstant()); + int64 feature_index_value = feature_index->literal().Get({}); + + thunk_sequence_->emplace_back( + MakeUnique( + /*operand=*/GetAllocationSlice(*custom_call->operand(0)), + /*scale=*/GetAllocationSlice(*custom_call->operand(1)), + /*offset=*/GetAllocationSlice(*custom_call->operand(2)), + /*mean=*/GetAllocationSlice(*custom_call->operand(3)), + /*variance=*/GetAllocationSlice(*custom_call->operand(4)), + /*epsilon=*/epsilon_value, + /*feature_index=*/feature_index_value, + /*output=*/GetAllocationSlice(*custom_call), + /*hlo=*/custom_call)); + return Status::OK(); + } + + if (custom_call->custom_call_target() == + kCudnnBatchNormForwardTrainingCallTarget) { + const HloInstruction* epsilon = custom_call->operand(3); + CHECK(epsilon->IsConstant()); + float epsilon_value = epsilon->literal().Get({}); + + const HloInstruction* feature_index = custom_call->operand(4); + CHECK(feature_index->IsConstant()); + int64 feature_index_value = feature_index->literal().Get({}); + + // BatchNormTraining returns a tuple of three elements: data, calculated + // mean, and calculated 1/sqrt(variance + epsilon). 
+ const auto& assn = ir_emitter_context_->buffer_assignment(); + auto output_data = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie(); + auto output_mean = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); + auto output_inv_stddev = assn.GetUniqueSlice(custom_call, {2}).ValueOrDie(); + thunk_sequence_->emplace_back( + MakeUnique( + /*operand=*/GetAllocationSlice(*custom_call->operand(0)), + /*scale=*/GetAllocationSlice(*custom_call->operand(1)), + /*offset=*/GetAllocationSlice(*custom_call->operand(2)), + /*epsilon=*/epsilon_value, + /*feature_index=*/feature_index_value, + /*output_data=*/output_data, + /*output_mean=*/output_mean, + /*output_inv_stddev=*/output_inv_stddev, + /*output_tuple=*/GetAllocationSlice(*custom_call), + /*hlo=*/custom_call)); + return Status::OK(); + } + + if (custom_call->custom_call_target() == kCudnnBatchNormBackwardCallTarget) { + const HloInstruction* epsilon = custom_call->operand(5); + CHECK(epsilon->IsConstant()); + float epsilon_value = epsilon->literal().Get({}); + + const HloInstruction* feature_index = custom_call->operand(6); + CHECK(feature_index->IsConstant()); + int64 feature_index_value = feature_index->literal().Get({}); + + // BatchNormGrad returns a tuple of three elements: grad_data, grad_scale, + // grad_offset. 
+ const auto& assn = ir_emitter_context_->buffer_assignment(); + auto output_grad_data = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie(); + auto output_grad_scale = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); + auto output_grad_offset = + assn.GetUniqueSlice(custom_call, {2}).ValueOrDie(); + thunk_sequence_->emplace_back(MakeUnique( + /*operand=*/GetAllocationSlice(*custom_call->operand(0)), + /*scale=*/GetAllocationSlice(*custom_call->operand(1)), + /*mean=*/GetAllocationSlice(*custom_call->operand(2)), + /*inv_stddev=*/GetAllocationSlice(*custom_call->operand(3)), + /*grad_output=*/GetAllocationSlice(*custom_call->operand(4)), + /*epsilon=*/epsilon_value, + /*feature_index=*/feature_index_value, + /*output_grad_data=*/output_grad_data, + /*output_grad_scale=*/output_grad_scale, + /*output_grad_offset=*/output_grad_offset, + /*output_tuple=*/GetAllocationSlice(*custom_call), + /*hlo=*/custom_call)); + return Status::OK(); + } + + if (IsCustomCallToDnnConvolution(*custom_call)) { + const auto& assn = ir_emitter_context_->buffer_assignment(); + const auto& lhs_shape = custom_call->operand(0)->shape(); + const auto& rhs_shape = custom_call->operand(1)->shape(); + const auto& conv_result_shape = custom_call->shape().tuple_shapes(0); + auto lhs_slice = GetAllocationSlice(*custom_call->operand(0)); + auto rhs_slice = GetAllocationSlice(*custom_call->operand(1)); + auto tuple_result_slice = GetAllocationSlice(*custom_call); + auto conv_result_slice = assn.GetUniqueSlice(custom_call, {0}).ValueOrDie(); + auto scratch_slice = assn.GetUniqueSlice(custom_call, {1}).ValueOrDie(); + + const HloInstruction* algorithm_inst = custom_call->operand(2); + CHECK(algorithm_inst->IsConstant()) << algorithm_inst->ToString(); + int64 algorithm = algorithm_inst->literal().Get({}); + + const HloInstruction* tensor_ops_enabled_inst = custom_call->operand(3); + CHECK(tensor_ops_enabled_inst->IsConstant()) + << tensor_ops_enabled_inst->ToString(); + bool tensor_ops_enabled = 
tensor_ops_enabled_inst->literal().Get({}); + + const auto& target = custom_call->custom_call_target(); + std::unique_ptr thunk; + if (target == kCudnnConvForwardCallTarget) { + thunk = MakeUnique( + CudnnConvKind::kForward, + /*input_buffer=*/lhs_slice, + /*filter_buffer=*/rhs_slice, + /*output_buffer=*/conv_result_slice, + /*tuple_result_buffer=*/tuple_result_slice, + /*scratch_buffer=*/scratch_slice, + /*input_shape=*/lhs_shape, + /*filter_shape=*/rhs_shape, + /*output_shape=*/conv_result_shape, // + custom_call->window(), custom_call->convolution_dimension_numbers(), + algorithm, tensor_ops_enabled, custom_call); + } else if (target == kCudnnConvBackwardInputCallTarget) { + thunk = MakeUnique( + CudnnConvKind::kBackwardInput, + /*input_buffer=*/conv_result_slice, + /*filter_buffer=*/rhs_slice, + /*output_buffer=*/lhs_slice, + /*tuple_result_buffer=*/tuple_result_slice, + /*scratch_buffer=*/scratch_slice, + /*input_shape=*/conv_result_shape, + /*filter_shape=*/rhs_shape, + /*output_shape=*/lhs_shape, // + custom_call->window(), custom_call->convolution_dimension_numbers(), + algorithm, tensor_ops_enabled, custom_call); + } else if (target == kCudnnConvBackwardFilterCallTarget) { + thunk = MakeUnique( + CudnnConvKind::kBackwardFilter, + /*input_buffer=*/lhs_slice, + /*filter_buffer=*/conv_result_slice, + /*output_buffer=*/rhs_slice, + /*tuple_result_buffer=*/tuple_result_slice, + /*scratch_buffer=*/scratch_slice, + /*input_shape=*/lhs_shape, + /*filter_shape=*/conv_result_shape, + /*output_shape=*/rhs_shape, // + custom_call->window(), custom_call->convolution_dimension_numbers(), + algorithm, tensor_ops_enabled, custom_call); + } else { + LOG(FATAL) << "Unexpected custom call target: " + << custom_call->custom_call_target(); + } + + thunk_sequence_->emplace_back(std::move(thunk)); + return Status::OK(); + } + + return IrEmitter::HandleCustomCall(custom_call); +} + +Status IrEmitterUnnested::HandleFft(HloInstruction* fft) { + TF_RET_CHECK( + 
LayoutUtil::IsMonotonicWithDim0Major(fft->operand(0)->shape().layout())); + TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(fft->shape().layout())); + thunk_sequence_->emplace_back(BuildFftThunk(fft)); + return Status::OK(); +} + Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { HloInstruction* root = fusion->fused_expression_root(); // HandleFusion specializes reduction from a multi-dimensional array to a 1D @@ -372,10 +598,6 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { thunk_sequence_->emplace_back(BuildGemmThunk(fusion)); return Status::OK(); } - if (ImplementedAsDnnConvolution(*fusion)) { - thunk_sequence_->emplace_back(BuildConvolutionThunk(fusion)); - return Status::OK(); - } thunk_sequence_->emplace_back(BuildKernelThunk(fusion)); return IrEmitter::HandleFusion(fusion); } @@ -407,8 +629,8 @@ Shape MergeDimensions(tensorflow::gtl::ArraySlice segs, (segs.size() == i ? shape.dimensions().size() : segs[i]), 1, std::multiplies())); } - return ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout(shape.element_type(), - dimensions); + return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(), + dimensions); } // Returns whether the given shapes and permutation are a 0-2-1 transpose, and @@ -421,20 +643,22 @@ std::tuple IsTranspose021(const Shape& a, const Shape& b) { CHECK(ShapeUtil::Compatible(a, b)); std::vector perm(a.dimensions().size()); { - std::vector layout_a(a.layout().minor_to_major().rbegin(), - a.layout().minor_to_major().rend()); - std::vector layout_b(b.layout().minor_to_major().rbegin(), - b.layout().minor_to_major().rend()); + auto layout_a_orig = LayoutUtil::MinorToMajor(a); + std::vector layout_a(layout_a_orig.rbegin(), layout_a_orig.rend()); + auto layout_b_orig = LayoutUtil::MinorToMajor(b); + std::vector layout_b(layout_b_orig.rbegin(), layout_b_orig.rend()); for (size_t i = 0; i < perm.size(); ++i) { perm[i] = PositionInContainer(layout_b, layout_a[i]); } } auto segs = ConsecutiveSegments(perm); - 
Shape norm_a = ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(a); - Shape norm_b = ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(b); + Shape norm_a = + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a); + Shape norm_b = + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(b); if (3 == segs.size() && 0 == perm[0]) { Shape reduced_a = MergeDimensions(segs, norm_a); - Shape reduced_b = ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( + Shape reduced_b = ShapeUtil::MakeShapeWithDescendingLayout( b.element_type(), Permute({0, 2, 1}, AsInt64Slice(reduced_a.dimensions()))); return std::make_tuple(true, reduced_a, reduced_b); @@ -448,10 +672,11 @@ std::tuple IsTranspose021(const Shape& a, const Shape& b) { bool AreShapesForTranspose021(const Shape& a, const Shape& b) { return 3 == b.dimensions().size() && ShapeUtil::Compatible( - ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(a), + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a), ShapeUtil::PermuteDimensions( {0, 2, 1}, - ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(b))); + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( + b))); } // Emits a tiled 0-2-1 transpose, assuming both input and output lain out from @@ -483,9 +708,11 @@ int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output, CHECK(AreShapesForTranspose021(input.GetShape(), output.GetShape())); Shape input_shape = - ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(input.GetShape()); + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( + input.GetShape()); Shape output_shape = - ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(output.GetShape()); + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( + output.GetShape()); input = input.CastToShape(input_shape, builder); output = output.CastToShape(output_shape, builder); @@ -603,7 +830,7 @@ int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output, 
llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, builder))), builder->getInt64Ty(), /*isSigned=*/true, "block.id.x"), - ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( + ShapeUtil::MakeShapeWithDescendingLayout( PRED /*arbitrary*/, AsInt64Slice(input_dims_in_tiles)), builder); const llvm_ir::IrArray::Index input_tile_origin = ({ @@ -672,7 +899,8 @@ int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output, } // namespace Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { - if (ImplementedAsHostToDeviceMemcpy(*copy)) { + if (ImplementedAsHostToDeviceMemcpy(ir_emitter_context_->buffer_assignment(), + *copy)) { thunk_sequence_->emplace_back(BuildHostToDeviceCopyThunk(copy)); return Status::OK(); } @@ -706,6 +934,194 @@ Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { return IrEmitter::HandleCopy(copy); } +Status IrEmitterUnnested::EmitReductionToScalar( + HloInstruction* reduce, const Shape& input_shape, + const llvm_ir::ElementGenerator& input_gen, + const llvm_ir::ElementGenerator& init_value_gen, HloComputation* reducer) { + // Number of elements processed by a single thread. + constexpr int64 kTileSize = 16; + int64 num_elems = ShapeUtil::ElementsIn(input_shape); + + // Round up the number of tiles to a multiple of the warp size. This is + // necessary for correctness. We launch one thread per tile, and if the + // number of threads isn't a multiple of the number of the warp size, our + // shuffles will read from inactive threads, producing undefined values. + int64 num_tiles = + RoundUpToNearest(CeilOfRatio(num_elems, kTileSize), kWarpSize); + + // Check whether every thread will process a full tile's worth of elements + // without reading outside the bounds of the input. If this is true, we can + // skip some bounds checks in the final algorithm. 
+ bool all_threads_in_bounds = num_tiles * kTileSize == num_elems; + + // __global__ void full_reduce_kernel() { + // x_in_tiles = threadIdx.x + blockIdx.x * blockDim.x; + // x = x_in_tiles * kTileSize; + // + // partial_result = init_value; + // if (all_threads_in_bounds || x + kTileSize <= num_elems) { + // for (i = 0; i < kTileSize; ++i) { + // partial_result = Reducer(partial_result, input[x + i]); + // } + // } else { + // for (i = 0; i < kTileSize; ++i) { + // if (x + i < num_elems) { + // partial_result = Reducer(partial_result, input[x + i]); + // } + // } + // } + // for (i = warpSize / 2; i > 0; i /= 2) { + // partial_result = Reducer(partial_result, + // __shfl_down(partial_result, i)); + // } + // if (lane_id == 0) { + // AtomicReducer(&output[y], partial_result); + // } + // } + // + // // Choose num_blocks and threads_per_block such that: + // // + // // num_blocks * threads_per_block = + // // RoundUpToNextMultipleOf(Ceil(num_elems / kTileSize), warpSize), + // // + // // and threads_per_block is a multiple of warpSize. + // reduce_kernel<<>>(); + // + auto loop_body_emitter = + [=](const llvm_ir::IrArray::Index& tile_index) -> Status { + llvm::Type* element_ir_type = + llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_); + llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( + element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result"); + { + TF_ASSIGN_OR_RETURN(llvm::Value * init_ir_value, + init_value_gen(llvm_ir::IrArray::Index({}))); + ir_builder_.CreateStore(init_ir_value, partial_reduction_result_address); + } + + llvm::Value* x_in_tiles = tile_index[0]; + + // Emit an inner for-loop that reduces the elements in the tile. 
+ auto emit_tile_element_loop = [=](bool tile_in_bounds) -> Status { + std::unique_ptr tile_element_loop = + llvm_ir::ForLoop::EmitForLoop("element_id_in_tile", + ir_builder_.getInt64(0), + ir_builder_.getInt64(kTileSize), + ir_builder_.getInt64(1), &ir_builder_); + + // Emit the body of the partial reduction loop. + llvm_ir::SetToFirstInsertPoint(tile_element_loop->GetBodyBasicBlock(), + &ir_builder_); + llvm::Value* x = ir_builder_.CreateNSWAdd( + ir_builder_.CreateNSWMul(x_in_tiles, ir_builder_.getInt64(kTileSize)), + tile_element_loop->GetIndVarValue()); + // Unless we know the tile is entirely in bounds, we have to emit a + // x-in-bounds check before reading from the input. + if (!tile_in_bounds) { + llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( + ir_builder_.CreateICmpULT(x, ir_builder_.getInt64(num_elems)), + "x_in_bounds", &ir_builder_); + + // Emit code that reads the input element and accumulates it to + // the partial reduction result. + llvm_ir::SetToFirstInsertPoint(if_data.true_block, &ir_builder_); + } + llvm_ir::IrArray::Index input_index( + /*linear=*/x, input_shape, &ir_builder_); + llvm::Value* input_address = ir_builder_.CreateAlloca(element_ir_type); + TF_ASSIGN_OR_RETURN(llvm::Value * input_ir_value, input_gen(input_index)); + ir_builder_.CreateStore(input_ir_value, input_address); + return (EmitCallToNestedComputation( + *reducer, {partial_reduction_result_address, input_address}, + partial_reduction_result_address)); + }; + + // x_end = kTileSize + x_in_tiles * kTileSize, i.e., the location that's + // immediately beyond the tile. + llvm::Value* x_end = ir_builder_.CreateNSWAdd( + ir_builder_.getInt64(kTileSize), + ir_builder_.CreateNSWMul(x_in_tiles, ir_builder_.getInt64(kTileSize))); + // The tile is entirely in bound if all_threads_in_bounds or + // x_end <= num_elems. 
+ llvm::Value* tile_in_bounds = ir_builder_.CreateOr( + ir_builder_.CreateICmpULE(x_end, ir_builder_.getInt64(num_elems)), + ir_builder_.getInt1(all_threads_in_bounds)); + llvm_ir::LlvmIfData if_tile_in_bounds_data = + llvm_ir::EmitIfThenElse(tile_in_bounds, "tile_in_bounds", &ir_builder_); + llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.true_block, + &ir_builder_); + TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/true)); + llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.false_block, + &ir_builder_); + TF_RETURN_IF_ERROR(emit_tile_element_loop(/*tile_in_bounds=*/false)); + + // After the if-then-else statement on tile_in_bounds, emit calls to + // shfl_down that accumulate the partial reduction results of all threads + // from the warp. + llvm_ir::SetToFirstInsertPoint(if_tile_in_bounds_data.after_block, + &ir_builder_); + int bit_width = llvm_ir::GetSizeInBits(element_ir_type); + // bitcast cannot be applied to aggregate types (even packed ones), so we + // instead bitcast addresses of load/store to intN* of the same bit-width. + llvm::Type* shuffle_ir_type = element_ir_type->isStructTy() + ? 
ir_builder_.getIntNTy(bit_width) + : element_ir_type; + for (int shuffle_distance = kWarpSize / 2; shuffle_distance >= 1; + shuffle_distance /= 2) { + llvm::Value* partial_reduction_result = ir_builder_.CreateLoad( + ir_builder_.CreateBitCast(partial_reduction_result_address, + shuffle_ir_type->getPointerTo()), + "partial_reduction_result"); + llvm::Value* result_from_other_lane = ir_builder_.CreateAlloca( + element_ir_type, nullptr, "result_from_other_lane"); + ir_builder_.CreateStore( + EmitShuffleDown(partial_reduction_result, + ir_builder_.getInt32(shuffle_distance), &ir_builder_), + ir_builder_.CreateBitCast(result_from_other_lane, + shuffle_ir_type->getPointerTo())); + TF_RETURN_IF_ERROR(EmitCallToNestedComputation( + *reducer, {partial_reduction_result_address, result_from_other_lane}, + partial_reduction_result_address)); + } + + const HloInstruction* output = + reduce->IsFused() ? reduce->parent()->FusionInstruction() : reduce; + + // Emit an atomic operation that accumulates the partial reduction result of + // lane 0 (which holds the partially accumulated result for its warp) to the + // output element. + llvm::Value* lane_id = ir_builder_.CreateURem( + x_in_tiles, ir_builder_.getInt64(kWarpSize), "lane_id"); + llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse( + ir_builder_.CreateICmpEQ(lane_id, ir_builder_.getInt64(0)), + "lane_id_is_zero", &ir_builder_); + llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, + &ir_builder_); + llvm::Value* output_address = + GetIrArray(*output, *output) + .EmitArrayElementAddress( + llvm_ir::IrArray::Index(/*linear=*/ir_builder_.getInt64(0), + output->shape(), &ir_builder_), + &ir_builder_, "output_element_address"); + return EmitAtomicOperationForNestedComputation( + *reducer, output_address, partial_reduction_result_address); + }; + + // Emit a parallel loop that iterates through all input tiles, one per thread. 
+ Shape tiled_input_shape = ShapeUtil::MakeShapeWithLayout( + reduce->shape().element_type(), {num_tiles}, {0}); + LaunchDimensions launch_dimensions = CalculateLaunchDimensions( + tiled_input_shape, ir_emitter_context_->device_description()); + CHECK(LastThunk()->kind() == Thunk::Kind::kSequential); + UpdateLaunchDimensions( + launch_dimensions, + static_cast(LastThunk())->thunks().back().get(), + ir_emitter_context_->llvm_module()); + return ParallelLoopEmitter(loop_body_emitter, tiled_input_shape, + launch_dimensions, &ir_builder_) + .EmitLoop(IrName(reduce)); +} + Status IrEmitterUnnested::EmitColumnReduction( int64 height, int64 width, HloInstruction* reduce, const Shape& input_shape, const llvm_ir::ElementGenerator& input_gen, @@ -799,14 +1215,15 @@ Status IrEmitterUnnested::EmitColumnReduction( // input_shape to normalized_input_shape and a reshape from // normalized_input_shape to input_matrix_shape. const Shape normalized_input_shape = - ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(input_shape); + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( + input_shape); + auto input_shape_min2maj = LayoutUtil::MinorToMajor(input_shape); const std::vector transpose_dimension_mapping( - input_shape.layout().minor_to_major().rbegin(), - input_shape.layout().minor_to_major().rend()); + input_shape_min2maj.rbegin(), input_shape_min2maj.rend()); const Shape input_matrix_shape = - ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( - input_shape.element_type(), {height, width}); + ShapeUtil::MakeShapeWithDescendingLayout(input_shape.element_type(), + {height, width}); const llvm_ir::IrArray::Index input_matrix_index( {y, x}, input_matrix_shape, &ir_builder_); const llvm_ir::IrArray::Index input_index = @@ -901,7 +1318,7 @@ Status IrEmitterUnnested::EmitRowReduction( // // Three optimizations are performed. // - // 1. To coalesc global memory accesses, dilate the tile with a factor of 32 + // 1. 
To coalesce global memory accesses, dilate the tile with a factor of 32 // (i.e. the warp size). For example, suppose the width is 8x32=256. Instead // of making each tile consecutive, we let make tile 0 column // [0,32,64,...,224], tile 1 column [1,33,65,...,225], and so on. This ensures @@ -1042,13 +1459,14 @@ Status IrEmitterUnnested::EmitRowReduction( // from input_shape to normalized_input_shape and a reshape from // normalized_input_shape to input_3d_tensor_shape. const Shape normalized_input_shape = - ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout(input_shape); + ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( + input_shape); + auto input_shape_min2maj = LayoutUtil::MinorToMajor(input_shape); const std::vector transpose_dimension_mapping( - input_shape.layout().minor_to_major().rbegin(), - input_shape.layout().minor_to_major().rend()); + input_shape_min2maj.rbegin(), input_shape_min2maj.rend()); const Shape input_3d_tensor_shape = - ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( - input_shape.element_type(), {depth, height, width}); + ShapeUtil::MakeShapeWithDescendingLayout(input_shape.element_type(), + {depth, height, width}); const llvm_ir::IrArray::Index input_3d_tensor_index( {z, y, x}, input_3d_tensor_shape, &ir_builder_); const llvm_ir::IrArray::Index input_index = @@ -1177,9 +1595,9 @@ Status IrEmitterUnnested::EmitReductionToVector( // whether another dimension is major or minor of them. 
std::sort(input_dims_to_keep.begin(), input_dims_to_keep.end(), [&input_shape](int64 dim_a, int64 dim_b) { - return PositionInContainer(input_shape.layout().minor_to_major(), + return PositionInContainer(LayoutUtil::MinorToMajor(input_shape), dim_a) < - PositionInContainer(input_shape.layout().minor_to_major(), + PositionInContainer(LayoutUtil::MinorToMajor(input_shape), dim_b); }); // Now, if output rank is at least 1, `input_dims_to_keep.front()` is @@ -1189,14 +1607,11 @@ Status IrEmitterUnnested::EmitReductionToVector( // the dimensions to keep are contiguous, by prerequisite of // `EmitReductionToVector`, we only need to check whether the minormost // dimension of the input is to keep. - // - // If the output is scalar, we could emit either a row or a column reduction. - // Some tests have shown scalar reduction is no more efficient as row - // reduction, and is simpler to emit as column reduction, so we emit a column - // reduction in this case. - if (input_dims_to_keep.empty() || - input_dims_to_keep.front() == - LayoutUtil::Minor(input_shape.layout(), 0)) { + if (input_dims_to_keep.empty()) { + return EmitReductionToScalar(reduce, input_shape, input_gen, init_value_gen, + reducer); + } else if (input_dims_to_keep.front() == + LayoutUtil::Minor(input_shape.layout(), 0)) { // Column reduction. Treat the result of "input" as a matrix whose width // is the most minor dimension and height the product of other dimensions, // and treat "reduce" as a column reduction of the input matrix. 
@@ -1224,14 +1639,14 @@ Status IrEmitterUnnested::EmitReductionToVector( int64 width = 1; for (int64 input_dim = 0; input_dim < ShapeUtil::Rank(input_shape); ++input_dim) { - if (PositionInContainer(input_shape.layout().minor_to_major(), + if (PositionInContainer(LayoutUtil::MinorToMajor(input_shape), input_dim) > - PositionInContainer(input_shape.layout().minor_to_major(), + PositionInContainer(LayoutUtil::MinorToMajor(input_shape), input_dims_to_keep.back())) { depth *= input_shape.dimensions(input_dim); - } else if (PositionInContainer(input_shape.layout().minor_to_major(), + } else if (PositionInContainer(LayoutUtil::MinorToMajor(input_shape), input_dim) < - PositionInContainer(input_shape.layout().minor_to_major(), + PositionInContainer(LayoutUtil::MinorToMajor(input_shape), input_dims_to_keep.front())) { width *= input_shape.dimensions(input_dim); } @@ -1279,24 +1694,24 @@ Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { } Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) { - tensorflow::gtl::ArraySlice operands(tuple->operands()); - bool all_tuple_elements_have_buffer = std::all_of( - operands.begin(), operands.end(), [this](HloInstruction* tuple_element) { + bool all_tuple_elements_have_buffer = + c_all_of(tuple->operands(), [&](HloInstruction* tuple_element) { return ir_emitter_context_->buffer_assignment().HasTopLevelAllocation( tuple_element); }); - // Tuples (especially output tuples) can take too many tuple elements, - // causing the kernel emitted exceeds the parameter space limit - // (b/31336476). As an optimization, if all tuple elements have a buffer, we - // collect their buffer addresses in a host array, and then copy that array - // to the tuple's buffer. + // Tuples (especially tuples that are the final result of a computation) can + // be so huge that if we were to emit a kernel that took each tuple element as + // a parameter, we would exceed the max allowable number of parameters to a + // GPU kernel, b/31336476. 
As an optimization, if all tuple elements have a + // buffer, we collect their buffer addresses in a host array, and then copy + // that array to the tuple's buffer. // // Some tuple elements (e.g. const or bitcast of const) might not have a - // buffer -- their contents are stored in code. In that case, we fall back - // to emitting kernels which have access to their buffer addresses in code. + // buffer -- their contents are stored in code. In that case, we fall back to + // emitting kernels which have access to their buffer addresses in code. if (all_tuple_elements_have_buffer) { std::vector tuple_element_buffers; - for (const HloInstruction* tuple_element : operands) { + for (const HloInstruction* tuple_element : tuple->operands()) { tuple_element_buffers.push_back(GetAllocationSlice(*tuple_element)); } thunk_sequence_->emplace_back(MakeUnique( @@ -1338,8 +1753,7 @@ Status IrEmitterUnnested::HandleSelectAndScatter( // TODO(b/31410564): Implement dilation rate for select-and-scatter. if (window_util::HasDilation(window)) { return Unimplemented( - "Dilation for select-and-scatter not implemented on GPU. " - "See b/31410564."); + "Dilation for SelectAndScatter not implemented on GPU."); } // kSelectAndScatter is implemented as two kernel launches: the first launch @@ -1548,62 +1962,202 @@ Status IrEmitterUnnested::HandleInfeed(HloInstruction* infeed) { return Status::OK(); } -llvm::Function* IrEmitterUnnested::EmitBasePointersForHloAndItsOperands( - const HloInstruction& hlo, std::vector* io_hlos) { - const BufferAssignment& buffer_assignment = - ir_emitter_context_->buffer_assignment(); - // GetTupleElement instructions are implemented by emitting IR that indexes - // and loads the target tuple element pointer from its operand (possibly - // recursively). For this reason, GetTupleElement instructions are associated - // with their operand buffer in 'io_hlos' and 'non_io_hlos' below. 
- std::vector non_io_hlos; - for (const HloInstruction* operand : hlo.operands()) { - const HloInstruction* to_lookup = operand->LatestNonGteAncestor(); - if (buffer_assignment.HasTopLevelAllocation(to_lookup) && - buffer_assignment.GetUniqueTopLevelSlice(to_lookup) - .ConsumeValueOrDie() - .allocation() - ->IsInputOrOutput()) { - io_hlos->push_back(operand); - } else { - non_io_hlos.push_back(operand); +// Figures out how to access the buffers for all subshapes of hlo's operands and +// for hlo itself (i.e. all the buffers produced by HLO). +// +// Returns a map keyed on the pair {HloInstruction, ShapeIndex}. The value for +// this key is a pair {Slice, ShapeIndex}, where the slice tells you the root +// buffer to look in, and the ShapeIndex describes how to dereference starting +// at that buffer to get to the buffer in question. +// +// For example, if {hlo, {1}} is mapped to {slice, {3, 4}}, then the buffer for +// hlo at ShapeIndex {1} (i.e. the buffer for the second tuple element of hlo) +// is found at slice[3][4]. That is, slice is a void***, which we dereference +// twice -- first at index 3, and then at index 4 -- to get the address of our +// buffer. +// +// This function conservatively assumes that we'll touch all sub-buffers of +// every operand and of the output. +static std::map, + std::pair> +GetHloBufferSlices(const HloInstruction* hlo, + const BufferAssignment& buffer_assn) { + std::map, + std::pair> + slices; + + // Tries to find a slice plus an array of indices i1, ..., iN such that the + // sub-buffer for instr at index can be found at slice[i1]...[iN]. + auto find_slice_for = [&](const HloInstruction* instr, + const ShapeIndex& index) + -> optional> { + // Simple, common case: Is the buffer for instr known at runtime? If so, + // we're done. 
+ auto slice = GetKnownAtRuntimeSlice(instr, index, buffer_assn); + if (slice.has_value()) { + return {{*slice, ShapeIndex()}}; } - } - CHECK_NE(HloOpcode::kGetTupleElement, hlo.opcode()); - if (buffer_assignment.HasTopLevelAllocation(&hlo) && - buffer_assignment.GetUniqueTopLevelSlice(&hlo) - .ConsumeValueOrDie() - .allocation() - ->IsInputOrOutput()) { - io_hlos->push_back(&hlo); - } else { - non_io_hlos.push_back(&hlo); + // If we don't know the buffer for instr at index, see if we know the buffer + // for instr at index without its last element. If so, we can dynamically + // find the buffer for instr by dereferencing a pointer in that buffer. + // Continue looking this way until we run out of elements in 'index'. + ShapeIndex new_index = index; + ShapeIndex gte_indices; + while (!new_index.empty()) { + gte_indices.push_front(new_index.back()); + new_index.pop_back(); + auto slice = GetKnownAtRuntimeSlice(instr, new_index, buffer_assn); + if (slice.has_value()) { + return {{*slice, gte_indices}}; + } + } + + // If *that* didn't work, check whether instr is a GTE instruction. If it + // is, see if we can get a buffer for its parent, and continue walking up + // parents until we find a defined buffer or we hit something that's not a + // GTE. + const HloInstruction* parent = instr; + while (parent->opcode() == HloOpcode::kGetTupleElement) { + gte_indices.push_front(parent->tuple_index()); + parent = parent->operand(0); + + auto slice = GetKnownAtRuntimeSlice(parent, {}, buffer_assn); + if (slice.has_value()) { + return {{*slice, gte_indices}}; + } + } + + return nullopt; + }; + + // Adds entries for all subshapes of instr to `slices`. + auto add_slices_for = [&](const HloInstruction* instr) { + // GPU constants don't have buffers; don't bother looking for one. 
+ if (instr->IsConstant()) { + return; + } + + ShapeUtil::ForEachSubshape( + instr->shape(), [&](const Shape& /*shape*/, const ShapeIndex& index) { + if (slices.count({instr, index})) { + // HLOs can have duplicate operands; don't bother redoing work. + return; + } + auto maybe_slice = find_slice_for(instr, index); + if (maybe_slice.has_value()) { + slices[{instr, index}] = *maybe_slice; + } else { + VLOG(1) << "Couldn't find buffer for " << instr->ToString() + << " at index " << index.ToString(); + } + }); + }; + + add_slices_for(hlo); + for (const HloInstruction* operand : hlo->operands()) { + // Conservatively assume we'll need the buffers for all subshapes of the + // operand. + add_slices_for(operand); } - llvm::Function* kernel = BuildKernelPrototype(hlo, *io_hlos); - // bindings_ is reused because the bindings of kConstant to their underlying - // llvm::Constant can be shared for all HLOs in this computation. - bindings_.EmitBasePointersForHlos(*io_hlos, non_io_hlos); - return kernel; + return slices; } std::unique_ptr IrEmitterUnnested::BuildKernelThunk( const HloInstruction* inst) { - std::vector io_hlos; - llvm::Function* kernel = - EmitBasePointersForHloAndItsOperands(*inst, &io_hlos); + const BufferAssignment& buffer_assn = + ir_emitter_context_->buffer_assignment(); + + std::map, + std::pair> + hlo_slices = GetHloBufferSlices(inst, buffer_assn); + + // Figure out which buffer allocations need to be passed as arguments to our + // kernel. This is simply all of the allocations referenced in hlo_slices, + // plus the XLA temp buffer (if we have it). We always include the temp + // buffer because even if the kernel itself doesn't use it, a nested + // subcomputation within the kernel (e.g. a kMap's computation) might. 
+ std::unordered_set buffers_needed; + for (const auto& kv : hlo_slices) { + buffers_needed.insert(kv.second.first.allocation()); + } + tensorflow::gtl::optional temp_buffer; + for (const BufferAllocation& alloc : buffer_assn.Allocations()) { + if (alloc.IsPreallocatedTempBuffer()) { + if (!temp_buffer.has_value()) { + temp_buffer = &alloc; + } else { + LOG(FATAL) << "Multiple temp buffers found, but only one is allowed!"; + } + } + } + if (temp_buffer.has_value()) { + buffers_needed.insert(*temp_buffer); + } + + // We'll pass a pointer to each of the elements of `buffers` to our kernel, in + // this order. + std::vector buffers(buffers_needed.begin(), + buffers_needed.end()); + std::sort(buffers.begin(), buffers.end(), + [](const BufferAllocation* a, const BufferAllocation* b) { + return a->index() < b->index(); + }); + + llvm::Function* kernel = BuildKernelPrototype(*inst, buffers); + + // Build a map from a BufferAllocation to the corresponding argument in our + // kernel. + std::unordered_map kernel_args; + { + auto arg_it = kernel->arg_begin(); + auto buffers_it = buffers.begin(); + for (; arg_it != kernel->arg_end(); ++arg_it, ++buffers_it) { + kernel_args[*buffers_it] = arg_it; + } + } + + // For each buffer our kernel might want to touch, bind it to a value derived + // from our kernel args. + for (const auto& kv : hlo_slices) { + const HloInstruction* instr = kv.first.first; + const ShapeIndex& index = kv.first.second; + const BufferAllocation::Slice& slice = kv.second.first; + const ShapeIndex& gte_index = kv.second.second; + + VLOG(3) << "Buffer for " << instr->ToString() << " at " << index.ToString() + << " is found in slice " << slice.ToString() << " at GTE index " + << gte_index.ToString(); + + llvm::Value* loc = + ir_builder_.CreateInBoundsGEP(kernel_args.at(slice.allocation()), + {ir_builder_.getInt64(slice.offset())}); + + // If gte_index is nonempty, we have to dereference `loc` to get to the + // value we're ultimately interested in. 
+ llvm::Type* int8_double_pointer = + llvm::PointerType::get(ir_builder_.getInt8PtrTy(), /*AddressSpace=*/0); + for (int64 idx : gte_index) { + loc = ir_builder_.CreateBitCast(loc, int8_double_pointer); + loc = ir_builder_.CreateLoad( + ir_builder_.CreateInBoundsGEP(loc, {ir_builder_.getInt64(idx)})); + } + + bindings_.BindHloToIrValue(*instr, loc, index); + } - // Compute the input buffer indices. - std::vector io_buffers; - io_buffers.reserve(io_hlos.size()); - for (const HloInstruction* io_hlo : io_hlos) { - io_buffers.push_back(GetAllocationSlice(*io_hlo->LatestNonGteAncestor())); + // Bind the temp buffer so that nested subcomputations can find it if they + // need. + if (temp_buffer.has_value()) { + bindings_.SetTempBufferBase(kernel_args.at(*temp_buffer)); + } else { + bindings_.SetTempBufferBase( + llvm::ConstantPointerNull::get(ir_builder_.getInt8PtrTy())); } - // Create a KernelThunk that launches the kernel that implements "inst". - return MakeUnique(io_buffers, - llvm_ir::AsString(kernel->getName()), inst); + return MakeUnique(buffers, llvm_ir::AsString(kernel->getName()), + inst); } std::unique_ptr IrEmitterUnnested::BuildHostToDeviceCopyThunk( @@ -1611,7 +2165,7 @@ std::unique_ptr IrEmitterUnnested::BuildHostToDeviceCopyThunk( const HloInstruction* operand = inst->operand(0); CHECK_EQ(HloOpcode::kConstant, operand->opcode()); return MakeUnique( - /*source_address=*/operand->literal().InternalData(), + /*source_address=*/operand->literal().untyped_data(), /*destination_buffer=*/GetAllocationSlice(*inst), /*mem_size=*/ llvm_ir::ByteSizeOf(operand->shape(), @@ -1692,50 +2246,14 @@ std::unique_ptr IrEmitterUnnested::BuildGemmThunk( LOG(FATAL) << "Cannot build a GemmThunk for " << inst->ToString(); } -std::unique_ptr IrEmitterUnnested::BuildConvolutionThunk( +std::unique_ptr IrEmitterUnnested::BuildFftThunk( const HloInstruction* inst) { - const HloInstruction* lhs = inst->operand(0); - const HloInstruction* rhs = inst->operand(1); - if (inst->opcode() == 
HloOpcode::kConvolution) { - // Forward covolution. - return MakeUnique( - ConvolutionThunk::ConvolutionKind::kForward, - /*input_buffer=*/GetAllocationSlice(*lhs), - /*filter_buffer=*/GetAllocationSlice(*rhs), - /*output_buffer=*/GetAllocationSlice(*inst), - /*input_shape=*/lhs->shape(), - /*filter_shape=*/rhs->shape(), - /*output_shape=*/inst->shape(), inst->window(), - inst->convolution_dimension_numbers(), inst); - } - - // Backward filter convolution, which takes the input (activations) and the - // gradients, and computes the filter. - CHECK_EQ(HloOpcode::kFusion, inst->opcode()); - switch (inst->fusion_kind()) { - case HloInstruction::FusionKind::kConvBackwardFilter: - return MakeUnique( - ConvolutionThunk::ConvolutionKind::kBackwardFilter, - /*input_buffer=*/GetAllocationSlice(*lhs), - /*filter_buffer=*/GetAllocationSlice(*inst), - /*output_buffer=*/GetAllocationSlice(*rhs), - /*input_shape=*/lhs->shape(), - /*filter_shape=*/inst->shape(), - /*output_shape=*/rhs->shape(), inst->window(), - inst->convolution_dimension_numbers(), inst); - case HloInstruction::FusionKind::kConvBackwardInput: - return MakeUnique( - ConvolutionThunk::ConvolutionKind::kBackwardInput, - /*input_buffer=*/GetAllocationSlice(*inst), - /*filter_buffer=*/GetAllocationSlice(*rhs), - /*output_buffer=*/GetAllocationSlice(*lhs), - /*input_shape=*/inst->shape(), - /*filter_shape=*/rhs->shape(), - /*output_shape=*/lhs->shape(), inst->window(), - inst->convolution_dimension_numbers(), inst); - default: - LOG(FATAL) << "Not a convolution-fusion"; - } + const HloInstruction* operand = inst->operand(0); + return MakeUnique(inst->fft_type(), inst->fft_length(), + /*input_buffer=*/GetAllocationSlice(*operand), + /*output_buffer=*/GetAllocationSlice(*inst), + /*input_shape=*/operand->shape(), + /*output_shape=*/inst->shape(), inst); } Status IrEmitterUnnested::EmitInitializer(const HloInstruction* hlo, @@ -1773,6 +2291,24 @@ Status IrEmitterUnnested::EmitInitializer(const HloInstruction* hlo, 
namespace { +// Checks that the buffers corresponding to the given two HLOs share the same +// allocation. +Status CheckHloBuffersShareAllocation( + const HloInstruction* a, const HloInstruction* b, const ShapeIndex& index, + const BufferAssignment& buffer_assignment) { + const BufferAllocation::Slice slice_a = + buffer_assignment.GetUniqueSlice(a, index).ConsumeValueOrDie(); + const BufferAllocation::Slice slice_b = + buffer_assignment.GetUniqueSlice(b, index).ConsumeValueOrDie(); + if (slice_a != slice_b) { + return InternalError( + "instruction %s %s does not share allocation with instruction %s %s", + a->ToString().c_str(), slice_a.ToString().c_str(), + b->ToString().c_str(), slice_b.ToString().c_str()); + } + return Status::OK(); +} + // Checks that all buffers used during while loop iteration share the same // buffer allocation. This includes buffers for while result, while init // operand, condition parameter, body parameter and body result. @@ -1782,37 +2318,65 @@ Status CheckWhileBuffersShareAllocation( const BufferAssignment& buffer_assignment) { return ShapeUtil::ForEachSubshapeWithStatus( xla_while->shape(), - [&buffer_assignment, &xla_while](const Shape& /*subshape*/, - const ShapeIndex& index) -> Status { - auto check = [&buffer_assignment](const HloInstruction* a, - const HloInstruction* b, - const ShapeIndex& index) -> Status { - const BufferAllocation::Slice slice_a = - buffer_assignment.GetUniqueSlice(a, index).ConsumeValueOrDie(); - const BufferAllocation::Slice slice_b = - buffer_assignment.GetUniqueSlice(b, index).ConsumeValueOrDie(); - if (slice_a != slice_b) { - return InternalError( - "instruction %s %s does not share allocation with " - "instruction %s %s", - a->ToString().c_str(), slice_a.ToString().c_str(), - b->ToString().c_str(), slice_b.ToString().c_str()); - } - return Status::OK(); - }; + [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status { const HloInstruction* condition_parameter = 
xla_while->while_condition()->parameter_instruction(0); const HloComputation* body = xla_while->while_body(); const HloInstruction* body_parameter = body->parameter_instruction(0); const HloInstruction* body_result = body->root_instruction(); - TF_RETURN_IF_ERROR(check(xla_while, xla_while->operand(0), index)); - TF_RETURN_IF_ERROR(check(xla_while, condition_parameter, index)); - TF_RETURN_IF_ERROR(check(xla_while, body_parameter, index)); - TF_RETURN_IF_ERROR(check(xla_while, body_result, index)); + TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation( + xla_while, xla_while->operand(0), index, buffer_assignment)); + TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation( + xla_while, condition_parameter, index, buffer_assignment)); + TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation( + xla_while, body_parameter, index, buffer_assignment)); + TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation( + xla_while, body_result, index, buffer_assignment)); return Status::OK(); }); } +// Checks that the buffers used in a conditional instruction are shared with the +// operands and result as follows: +// * The result buffer of the conditional should share the allocation with the +// result buffers of the true and false computations. +// * The buffer of operand 1 should share the allocation with the buffer of +// the parameter 0 instruction of the true computation. +// * The buffer of operand 2 should share the allocation with the buffer of +// the parameter 0 instruction of the false computation. 
+Status CheckConditionalBuffersShareAllocation( + const HloInstruction* conditional, + const BufferAssignment& buffer_assignment) { + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + conditional->shape(), + [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status { + TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation( + conditional, conditional->true_computation()->root_instruction(), + index, buffer_assignment)); + TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation( + conditional, conditional->false_computation()->root_instruction(), + index, buffer_assignment)); + return Status::OK(); + })); + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + conditional->operand(1)->shape(), + [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status { + return CheckHloBuffersShareAllocation( + conditional->operand(1), + conditional->true_computation()->parameter_instruction(0), index, + buffer_assignment); + })); + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + conditional->operand(2)->shape(), + [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status { + return CheckHloBuffersShareAllocation( + conditional->operand(2), + conditional->false_computation()->parameter_instruction(0), index, + buffer_assignment); + })); + return Status::OK(); +} + } // namespace std::unique_ptr IrEmitterUnnested::BuildWhileThunk( @@ -1855,9 +2419,36 @@ std::unique_ptr IrEmitterUnnested::BuildForThunk( ir_emitter_body.ConsumeThunkSequence(), hlo); } +std::unique_ptr IrEmitterUnnested::BuildConditionalThunk( + const HloInstruction* hlo) { + // Check that the buffers used in conditional are shared with the operands and + // result appropriately. 
+ TF_CHECK_OK(CheckConditionalBuffersShareAllocation( + hlo, ir_emitter_context_->buffer_assignment())); + + HloComputation* true_computation = hlo->true_computation(); + IrEmitterUnnested ir_emitter_true(hlo_module_config_, true_computation, + ir_emitter_context_); + TF_CHECK_OK(true_computation->root_instruction()->Accept(&ir_emitter_true)); + + HloComputation* false_computation = hlo->false_computation(); + IrEmitterUnnested ir_emitter_false(hlo_module_config_, false_computation, + ir_emitter_context_); + TF_CHECK_OK(false_computation->root_instruction()->Accept(&ir_emitter_false)); + + return MakeUnique( + GetAllocationSlice(*hlo->operand(0)), + GetAllocationSlice(*hlo->operand(1)), + GetAllocationSlice(*hlo->operand(2)), + std::move(*ir_emitter_true.ConsumeThunkSequence()), + std::move(*ir_emitter_false.ConsumeThunkSequence()), hlo); +} + Status IrEmitterUnnested::EmitTargetElementLoopInThunk( const HloInstruction& hlo, const llvm_ir::ElementGenerator& element_generator, KernelThunk* thunk) { + VLOG(3) << bindings_.ToString(); + const Shape& element_shape = hlo.IsMultiOutputFusion() ? ShapeUtil::GetSubshape(hlo.shape(), {0}) : hlo.shape(); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h new file mode 100644 index 0000000000000000000000000000000000000000..688760efbd2c725a4bf48e45eb6f2734b63d25e1 --- /dev/null +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h @@ -0,0 +1,205 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_ + +#include "tensorflow/compiler/xla/service/gpu/ir_emitter.h" +#include "tensorflow/compiler/xla/service/gpu/thunk.h" + +namespace xla { +namespace gpu { + +// Emits LLVM IR for an "unnested computation". +// +// An unnested computation is an HloComputation which you run by executing one +// or more kernels for each HloInstruction it contains. Examples of unnested +// computations: +// +// - An HloModule's root computation, +// - The body of an HLO while loop, +// - The true/false computation of an HLO conditional. +// +// Note the opportunity for confusion -- the while loop's computation is nested +// within the root computation, but it's emitted using IrEmitterUnnested! Don't +// think about it too hard. +// +// Examples of things that are not unnested computations: +// +// - The reducer of a kReduce HLO. This is emited using IrEmitterNested. +// - The body of a fusion node. IrEmitterUnenested emits the relevant code +// within a kernel function using FusedIrEmitter. (FusedIrEmitter is not +// really an IrEmitter, but is more an "IR generator generator".) +// +class IrEmitterUnnested : public IrEmitter { + public: + IrEmitterUnnested(const HloModuleConfig& hlo_module_config, + const HloComputation* hlo_computation, + IrEmitterContext* ir_emitter_context); + IrEmitterUnnested(const IrEmitterUnnested&) = delete; + IrEmitterUnnested& operator=(const IrEmitterUnnested&) = delete; + + // Transfers the ownship of thunk_sequence_ out. 
+ std::unique_ptr ConsumeThunkSequence() { + return std::move(thunk_sequence_); + } + + Status DefaultAction(HloInstruction* hlo) override; + + // IrEmitterUnnested handles the following instructions differently from + // IrEmitter. + Status HandleCopy(HloInstruction* copy) override; + Status HandleConditional(HloInstruction* conditional) override; + Status HandleConvolution(HloInstruction* convolution) override; + Status HandleCustomCall(HloInstruction* custom_call) override; + Status HandleDot(HloInstruction* dot) override; + Status HandleFft(HloInstruction* fft) override; + Status HandleFusion(HloInstruction* fusion) override; + Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; + Status HandleReduce(HloInstruction* reduce) override; + Status HandleSelectAndScatter(HloInstruction* instruction) override; + Status HandleTuple(HloInstruction* tuple) override; + Status HandleWhile(HloInstruction* xla_while) override; + Status HandleInfeed(HloInstruction* xla_infeed) override; + Status HandleRng(HloInstruction* random) override; + Status HandleSelect(HloInstruction* select) override; + + Status EmitTargetElementLoop( + const HloInstruction& hlo, + const llvm_ir::ElementGenerator& body_emitter) override; + + // Same as `EmitTargetElementLoop`, but in given `thunk` rather than + // `LastThunk()`. + Status EmitTargetElementLoopInThunk( + const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter, + KernelThunk* thunk); + + private: + // Builds the appropriate thunk for the instruction hlo and returns the owning + // pointer to it. The caller needs to make sure `inst` outlives the lifetime + // of the returned Thunk object. + std::unique_ptr BuildThunk(const HloInstruction* hlo); + + // Builds the prototype of the IR kernel for `inst` and adds it to the module. + // This kernel takes as arguments pointers to the given buffer allocations. 
+ llvm::Function* BuildKernelPrototype( + const HloInstruction& inst, + tensorflow::gtl::ArraySlice args); + + // EmitColumnReduction and EmitRowReduction emit code for column and row + // reduction of a matrix and/or 3D tensor. Row and column reduction have + // different memory access pattern, so for performance their implementations + // are significantly different. + // + // Emits code that reduces a matrix of shape [height x width] to a vector of + // [width]. Other parameters have the same meaning as those of + // `EmitReductionToVector`. Note that input shape might not be + // [height x width], but can be bitcast to [height x weight] with "height" + // being the major dimension. + Status EmitColumnReduction(int64 height, int64 width, HloInstruction* reduce, + const Shape& input_shape, + const llvm_ir::ElementGenerator& input_gen, + const llvm_ir::ElementGenerator& init_value_gen, + HloComputation* reducer); + + // Emits code that reduces a 3D tensor of shape [depth x height x width] to a + // vector of shape [height]. Other parameters have the same meaning as those + // of `EmitReductionToVector`. Note that input shape might not be + // [depth x height x width], but can be bitcast to [depth x height x weight] + // with "depth" being the most major dimension. + Status EmitRowReduction(int64 depth, int64 height, int64 width, + HloInstruction* reduce, const Shape& input_shape, + const llvm_ir::ElementGenerator& input_gen, + const llvm_ir::ElementGenerator& init_value_gen, + HloComputation* reducer); + + // Emits code that reduces a tensor of arbitrary rank to a scalar. + Status EmitReductionToScalar(HloInstruction* reduce, const Shape& input_shape, + const llvm_ir::ElementGenerator& input_gen, + const llvm_ir::ElementGenerator& init_value_gen, + HloComputation* reducer); + + // Figures out whether `reduce` is a row or column reduction, and which + // dimensions to reduce, and calls either `EmitRowReduction` or + // `EmitColumnReduction` as appropriate. 
`input_shape` is the shape of the + // input array, which is the operand of the Reduce instruction if unfused or + // of the Fusion instruction if fused. `input_gen` and `init_value_gen` + // generate elements of the input and the initial value. Other parameters mean + // the same as for `HandleReduce`. + // + // Prerequisite: `IsReductionToVector(*reduce)` + Status EmitReductionToVector( + HloInstruction* reduce, const Shape& input_shape, + const llvm_ir::ElementGenerator& input_gen, + const llvm_ir::ElementGenerator& init_value_gen, + tensorflow::gtl::ArraySlice dimensions_to_reduce, + HloComputation* reducer); + + // Emits code to initialize buffer of `inst` in given `thunk`. + Status EmitInitializer(const HloInstruction* inst, KernelThunk* thunk); + + // Returns a KernelThunk that invokes the kernel emitted for `inst`. The + // caller needs to make sure `inst` outlives the lifetime of the returned + // Thunk object. + std::unique_ptr BuildKernelThunk(const HloInstruction* inst); + + // Returns a FftThunk that calls cuFFT to implement `inst`. + std::unique_ptr BuildFftThunk(const HloInstruction* inst); + + // Returns a GemmThunk that calls gemm to implement `inst`. The caller needs + // to make sure `inst` outlives the lifetime of the returned Thunk object. + std::unique_ptr BuildGemmThunk(const HloInstruction* inst); + + // Returns a thunk that calls host-to-device cuMemcpy to implement `inst`. + std::unique_ptr BuildHostToDeviceCopyThunk(const HloInstruction* inst); + + // Returns a thunk that calls device-to-device cuMemcpy to implement `inst`. + std::unique_ptr BuildDeviceToDeviceCopyThunk( + const HloInstruction* inst); + + // Returns an InfeedThunk that performs device-to-device memcpy to implement + // `inst`. + std::unique_ptr BuildInfeedThunk(const HloInstruction* inst); + + // Returns a WhileThunk that invokes thunk sequences for 'condition' and + // 'body' sub-computations of while instruction 'hlo'. 
+ std::unique_ptr BuildWhileThunk(const HloInstruction* hlo); + + // Returns a ForThunk which executes 'loop_limit' invocations of a thunk + // sequence from the 'body' sub-computation of the while instruction 'hlo'. + std::unique_ptr BuildForThunk(const HloInstruction* hlo, + const int64 loop_limit); + + // Returns a ConditionalThunk that executes the thunk sequence for + // 'true_computation' or 'false_computation' depending on the value of the + // predicate in the given conditional instruction. + std::unique_ptr BuildConditionalThunk(const HloInstruction* hlo); + + Status Postprocess(HloInstruction* hlo) override; + + // Returns the last generated thunk. + Thunk* LastThunk() const { return thunk_sequence_->back().get(); } + + // The thunk sequence this IrEmitter generates for the input computation. + std::unique_ptr thunk_sequence_; + + // The HloComputation that this IrEmitter emits code for. + const HloComputation* hlo_computation_; +}; + +} // namespace gpu +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_IR_EMITTER_UNNESTED_H_ diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc index 96606993696354f36e143b3b994bbe6afb902df3..c20a781a33fe89af4740ed31dd5bfb1a64473057 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.cc @@ -29,10 +29,10 @@ namespace xla { namespace gpu { KernelThunk::KernelThunk( - tensorflow::gtl::ArraySlice io_buffers, + tensorflow::gtl::ArraySlice args, const string& kernel_name, const HloInstruction* hlo_instruction) : Thunk(Kind::kKernel, hlo_instruction), - io_buffers_(io_buffers.begin(), io_buffers.end()), + args_(args.begin(), args.end()), kernel_name_(kernel_name) {} tensorflow::Status KernelThunk::Initialize(const GpuExecutable& executable) { @@ -42,7 +42,7 @@ tensorflow::Status KernelThunk::Initialize(const GpuExecutable& executable) { return tensorflow::Status::OK(); } - 
loader_spec_.reset(new se::MultiKernelLoaderSpec(io_buffers_.size() + 1)); + loader_spec_.reset(new se::MultiKernelLoaderSpec(args_.size())); tensorflow::StringPiece ptx = executable.ptx(); // Convert tensorflow::StringPiece to se::port::StringPiece because // StreamExecutor uses the latter. @@ -81,15 +81,16 @@ tensorflow::Status KernelThunk::ExecuteOnStream( kernel = &it->second; } + VLOG(3) << "Launching " << kernel->name(); // Launch the kernel with potentially multiple blocks and threads. static constexpr int kKernelArgsLimit = 1024; auto kernel_args = MakeUnique>(); - for (const BufferAllocation::Slice io_buffer : io_buffers_) { - kernel_args->add_device_memory_argument( - buffer_allocations.GetDeviceAddress(io_buffer)); + for (const BufferAllocation* arg : args_) { + const auto& buf = buffer_allocations.GetDeviceAddress(arg->index()); + kernel_args->add_device_memory_argument(buf); + VLOG(3) << " Arg: alloc #" << arg->index() << ": " << buf.opaque() << " (" + << buf.size() << "B)"; } - kernel_args->add_device_memory_argument( - buffer_allocations.GetTempBufferBase()); if (!stream->parent()->Launch( stream, se::ThreadDim(launch_dimensions.threads_per_block()), se::BlockDim(launch_dimensions.block_count()), *kernel, diff --git a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h index 350b5aaf360b0dad7f7b04d73f4c32bad55d3ce9..9ae455e2fcc253a7a08ff95764721048a16b0bf7 100644 --- a/tensorflow/compiler/xla/service/gpu/kernel_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/kernel_thunk.h @@ -46,7 +46,7 @@ class KernelThunk : public Thunk { // Constructs a thunk for the given kernel. // // `hlo_instruction` is as in Thunk. Other arguments are as the class members. 
- KernelThunk(tensorflow::gtl::ArraySlice io_buffers, + KernelThunk(tensorflow::gtl::ArraySlice args, const string& kernel_name, const HloInstruction* hlo_instruction); KernelThunk(const KernelThunk&) = delete; KernelThunk& operator=(const KernelThunk&) = delete; @@ -63,8 +63,8 @@ class KernelThunk : public Thunk { perftools::gputools::Stream* stream) override; private: - // The indices of the input/output buffers. - const std::vector io_buffers_; + // Buffers passed to the kernel as arguments. + const std::vector args_; // Entry kernel name for the computation. const string kernel_name_; diff --git a/tensorflow/compiler/xla/service/gpu/layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/layout_assignment.cc deleted file mode 100644 index d475c4171b56ceedf5fdbda8b4d6221af844261c..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/gpu/layout_assignment.cc +++ /dev/null @@ -1,153 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/service/gpu/layout_assignment.h" - -#include - -#include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/core/errors.h" - -namespace xla { -namespace gpu { - -Status GpuLayoutAssignment::AddBackendConstraints( - LayoutConstraints* constraints) { - for (auto* instruction : constraints->computation()->instructions()) { - // cuDNN is called with specific layouts on the input, output, and filter: - // - // input: DataLayout::kBatchDepthYX - // output: DataLayout::kBatchDepthYX - // filter: FilterLayout::kOutputInputYX - // - // The order dimensions in the constant name is major-to-minor (eg, the - // most-major dimension of the input is batch, most-minor is X). The - // specific dimension numbers these named dimensions correspond to is - // determined by the ConvolutionDimensionNumbers argument. Y is spatial - // dimension 0, and X is spatial dimension 1. - // - // TODO(b/29399649): Be more flexible about handling layouts of cuDNN calls. 
- if (ImplementedAsDnnConvolution(*instruction)) { - HloInstruction* input = nullptr; - HloInstruction* filter = nullptr; - HloInstruction* output = nullptr; - if (instruction->opcode() == HloOpcode::kConvolution) { - input = instruction->mutable_operand(0); - filter = instruction->mutable_operand(1); - output = instruction; - } else { - CHECK_EQ(HloOpcode::kFusion, instruction->opcode()); - switch (instruction->fusion_kind()) { - case HloInstruction::FusionKind::kConvBackwardFilter: - // filter = BackwardFilterConvolve(input, output) - input = instruction->mutable_operand(0); - filter = instruction; - output = instruction->mutable_operand(1); - break; - case HloInstruction::FusionKind::kConvBackwardInput: - // input = BackwardInputConvolve(output, filter) - input = instruction; - filter = instruction->mutable_operand(1); - output = instruction->mutable_operand(0); - break; - default: - LOG(FATAL) << "Not a convolution-fusion"; - } - } - - // Construct minor-to-major dimension orders for operands and result. - // cuDNN's convolution APIs support the BDYX layout for activations/output - // and the OIYX layout for weights. - // TODO(b/29399649): Be more flexible about handling layouts of cuDNN - // calls after we switch to cuDNN v5. 
- const ConvolutionDimensionNumbers& dimension_numbers = - instruction->convolution_dimension_numbers(); - std::vector input_layout; - for (int i = dimension_numbers.input_spatial_dimensions_size() - 1; - i >= 0; --i) { - input_layout.push_back(dimension_numbers.input_spatial_dimensions(i)); - } - input_layout.push_back(dimension_numbers.input_feature_dimension()); - input_layout.push_back(dimension_numbers.input_batch_dimension()); - Shape input_shape(input->shape()); - *input_shape.mutable_layout() = LayoutUtil::MakeLayout(input_layout); - - std::vector filter_layout; - for (int i = dimension_numbers.kernel_spatial_dimensions_size() - 1; - i >= 0; --i) { - filter_layout.push_back(dimension_numbers.kernel_spatial_dimensions(i)); - } - filter_layout.push_back( - dimension_numbers.kernel_input_feature_dimension()); - filter_layout.push_back( - dimension_numbers.kernel_output_feature_dimension()); - Shape filter_shape(filter->shape()); - *filter_shape.mutable_layout() = LayoutUtil::MakeLayout(filter_layout); - - std::vector output_layout; - for (int i = dimension_numbers.output_spatial_dimensions_size() - 1; - i >= 0; --i) { - output_layout.push_back(dimension_numbers.output_spatial_dimensions(i)); - } - output_layout.push_back(dimension_numbers.output_feature_dimension()); - output_layout.push_back(dimension_numbers.output_batch_dimension()); - Shape output_shape(output->shape()); - *output_shape.mutable_layout() = LayoutUtil::MakeLayout(output_layout); - - // Set layouts of the instructions' shapes. 
- if (instruction->opcode() == HloOpcode::kConvolution) { - TF_RETURN_IF_ERROR( - constraints->SetOperandLayout(input_shape, output, 0)); - TF_RETURN_IF_ERROR( - constraints->SetOperandLayout(filter_shape, output, 1)); - TF_RETURN_IF_ERROR( - constraints->SetInstructionLayout(output_shape, output)); - } else { - CHECK_EQ(HloOpcode::kFusion, instruction->opcode()); - switch (instruction->fusion_kind()) { - case HloInstruction::FusionKind::kConvBackwardFilter: - // filter = BackwardFilterConvolve(input, output) - TF_RETURN_IF_ERROR( - constraints->SetOperandLayout(input_shape, filter, 0)); - TF_RETURN_IF_ERROR( - constraints->SetInstructionLayout(filter_shape, filter)); - TF_RETURN_IF_ERROR( - constraints->SetOperandLayout(output_shape, filter, 1)); - break; - case HloInstruction::FusionKind::kConvBackwardInput: - // input = BackwardInputConvolve(output, filter) - TF_RETURN_IF_ERROR( - constraints->SetInstructionLayout(input_shape, input)); - TF_RETURN_IF_ERROR( - constraints->SetOperandLayout(output_shape, input, 0)); - TF_RETURN_IF_ERROR( - constraints->SetOperandLayout(filter_shape, input, 1)); - break; - default: - LOG(FATAL) << "Not a convolution-fusion"; - } - } - } - } - return Status::OK(); -} - -} // namespace gpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/layout_assignment_test.cc deleted file mode 100644 index ac206b89d329d7e4ac91ee51162c9694f6899d78..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/gpu/layout_assignment_test.cc +++ /dev/null @@ -1,85 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/gpu/layout_assignment.h" - -#include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/service/computation_layout.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_module.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/shape_layout.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" - -namespace xla { -namespace gpu { -namespace { - -using LayoutAssignmentTest = HloTestBase; - -TEST_F(LayoutAssignmentTest, Elementwise) { - Shape ashape = ShapeUtil::MakeShape(F32, {42, 12}); - Shape ashape_in_row_major(ashape); - Shape ashape_in_col_major(ashape); - *ashape_in_row_major.mutable_layout() = LayoutUtil::MakeLayout({1, 0}); - *ashape_in_col_major.mutable_layout() = LayoutUtil::MakeLayout({0, 1}); - - // Enumerate all possible combinations of layouts. - for (const Shape& lhs_shape_with_layout : - {ashape_in_row_major, ashape_in_col_major}) { - for (const Shape& rhs_shape_with_layout : - {ashape_in_row_major, ashape_in_col_major}) { - for (const Shape& result_shape_with_layout : - {ashape_in_row_major, ashape_in_col_major}) { - // GpuLayoutAssignment should assign the same layout to "add" and its - // two operands. 
- auto builder = HloComputation::Builder(TestName()); - auto x = builder.AddInstruction( - HloInstruction::CreateParameter(0, ashape, "x")); - auto y = builder.AddInstruction( - HloInstruction::CreateParameter(1, ashape, "y")); - auto add = builder.AddInstruction( - HloInstruction::CreateBinary(ashape, HloOpcode::kAdd, x, y)); - auto module = CreateNewModule(); - HloComputation* computation = - module->AddEntryComputation(builder.Build(add)); - - ComputationLayout computation_layout( - computation->ComputeProgramShape()); - *computation_layout.mutable_parameter_layout(0) = - ShapeLayout(lhs_shape_with_layout); - *computation_layout.mutable_parameter_layout(1) = - ShapeLayout(rhs_shape_with_layout); - *computation_layout.mutable_result_layout() = - ShapeLayout(result_shape_with_layout); - - GpuLayoutAssignment layout_assignment(&computation_layout); - EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie()); - - for (const HloInstruction* operand : add->operands()) { - EXPECT_TRUE(LayoutUtil::Equal(add->shape().layout(), - operand->shape().layout())); - } - } - } - } -} - -} // namespace -} // namespace gpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index 059943d48cd34b0ac487b91c3f3079ee3f761229..cfabae791d26d0eb49826085ad7ad166a19109a1 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -440,7 +440,7 @@ StatusOr CompileModuleToPtx(llvm::Module* module, // One-time module initializer. // Must be called only once -- DO NOT CALL DIRECTLY. -void GPUBackendInit() { +void GPUBackendInit(const HloModuleConfig& hlo_module_config) { // Feed all customized flags here, so we can override them with llvm_cl_opts // without redeploy the compiler for development purpose. 
@@ -466,6 +466,8 @@ void GPUBackendInit() { // between those loads. FeedLLVMWithFlags({"-memdep-block-scan-limit=500"}); + llvm_ir::InitializeLLVMCommandLineOptions(hlo_module_config); + // Initialize the NVPTX target; it's the only target we link with, so call its // specific initialization functions instead of the catch-all InitializeAll*. LLVMInitializeNVPTXTarget(); @@ -485,7 +487,7 @@ StatusOr CompileToPtx(llvm::Module* module, const HloModuleConfig& hlo_module_config, const string& libdevice_dir_path) { static std::once_flag backend_init_flag; - std::call_once(backend_init_flag, GPUBackendInit); + std::call_once(backend_init_flag, GPUBackendInit, hlo_module_config); string ptx; { diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc index 11290eda4ffcd579c03acd531b493bb7b1d34ed4..25846dc6cd4633c7becb6e62d6bc9585348a6eac 100644 --- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc +++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc @@ -27,8 +27,8 @@ namespace gpu { namespace { bool IsForwardConvolutionCanonical(const HloInstruction& conv) { - CHECK_EQ(HloOpcode::kConvolution, conv.opcode()); - return window_util::HasEvenPadding(conv.window()) && + CHECK_EQ(conv.custom_call_target(), kCudnnConvForwardCallTarget); + return window_util::HasSymmetricPadding(conv.window()) && !window_util::HasNegativePadding(conv.window()) && !window_util::HasDilation(conv.window()); } @@ -43,10 +43,16 @@ HloInstruction* MaybePaddedAndSlicedInput( const Window& conv_window, const ConvolutionDimensionNumbers& conv_dnums, HloInstruction* input) { HloComputation* computation = input->parent(); - if (!window_util::HasEvenPadding(conv_window) || + if (!window_util::HasSymmetricPadding(conv_window) || window_util::HasBaseDilation(conv_window)) { // If padding is uneven or has dilation, we insert a kPad instruction that // applies positive padding and dilation. 
+ // + // TODO(phawkins): If conv_window has asymmetric padding, perhaps instead of + // moving all the padding into an explicit pad op, we should keep as much + // padding inside of cudnn as possible, on the assumption that padding + // within cudnn is basically free, whereas a kPad's cost increases as the + // amount of padding increases. PaddingConfig padding_config = MakeNoPaddingConfig(input->shape().dimensions_size()); for (size_t i = 0; i < conv_dnums.input_spatial_dimensions().size(); ++i) { @@ -167,14 +173,17 @@ bool PadInsertion::CanonicalizeForwardConvolution(HloInstruction* conv) { dim->set_window_dilation(1); } + // The conv CustomCall returns a tuple (conv_result, scratch_buffer). Extract + // out the shape of conv_result. + Shape old_conv_shape = conv->shape().tuple_shapes(0); + VLOG(1) << "Canonicalizing forward conv"; - auto new_conv = HloInstruction::CreateConvolve( - conv->shape(), new_input, new_kernel, new_conv_window, - conv->convolution_dimension_numbers()); + auto new_conv = CreateCudnnConvForward(old_conv_shape, new_input, new_kernel, + new_conv_window, + conv->convolution_dimension_numbers()); VLOG(1) << "Replacing:\n " << conv->ToString() << "\nwith:\n " << new_conv->ToString(); - TF_CHECK_OK( - conv->parent()->ReplaceWithNewInstruction(conv, std::move(new_conv))); + TF_CHECK_OK(conv->parent()->ReplaceInstruction(conv, new_conv)); return true; } @@ -190,7 +199,9 @@ void IncreasePaddingHighBy(int64 delta, WindowDimension* window_dim) { bool PadInsertion::CanonicalizeBackwardFilterConvolution( HloInstruction* backward_conv) { - if (window_util::HasEvenPadding(backward_conv->window())) { + CHECK_EQ(backward_conv->custom_call_target(), + kCudnnConvBackwardFilterCallTarget); + if (window_util::HasSymmetricPadding(backward_conv->window())) { return false; } @@ -202,16 +213,11 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( // ABCD0 = Pad(ABCD, padding_high=1) // BackwardFilterConv(ABCD0, xyz, padding_low=pading_high=1) // We choose 
the lesser of padding_low and padding_high as the new padding. - HloInstruction* transpose = backward_conv->fused_expression_root(); - HloInstruction* forward_conv = transpose->mutable_operand(0); HloInstruction* input = backward_conv->mutable_operand(0); - Window new_forward_conv_window = forward_conv->window(); Window new_backward_conv_window = backward_conv->window(); // input_padding_config is the config of the kPad to be inserted. PaddingConfig input_padding_config = MakeNoPaddingConfig(ShapeUtil::Rank(input->shape())); - ConvolutionDimensionNumbers forward_conv_dnums = - forward_conv->convolution_dimension_numbers(); ConvolutionDimensionNumbers backward_conv_dnums = backward_conv->convolution_dimension_numbers(); for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) { @@ -223,11 +229,7 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( // cuDNN convolution (which doesn't support negative padding) to fail. return false; } - // If the backward convolution has uneven padding on the activations, we - // move some padding on the larger end to "internal" padding, so that the - // backward convolution produces larger weight gradients which get sliced - // later. Therefore, the amount of new padding (low or high) is the minimum - // of the amount of old padding low and old padding high. + // Compute the new, even padding for the backward conv operation. int64 new_conv_padding = std::min(padding_low, padding_high); int64 dim = backward_conv_dnums.input_spatial_dimensions(i); input_padding_config.mutable_dimensions(dim)->set_edge_padding_low( @@ -238,14 +240,9 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( // Since we move some padding from the backward convolution to the kPad, we // need to accordingly reduce the padding amount of the backward convolution // and its inner forward convolution. 
- IncreasePaddingLowBy(-(padding_low - new_conv_padding), - new_backward_conv_window.mutable_dimensions(i)); - IncreasePaddingHighBy(-(padding_high - new_conv_padding), - new_backward_conv_window.mutable_dimensions(i)); - IncreasePaddingLowBy(-(padding_low - new_conv_padding), - new_forward_conv_window.mutable_dimensions(i)); - IncreasePaddingHighBy(-(padding_high - new_conv_padding), - new_forward_conv_window.mutable_dimensions(i)); + auto* new_dim = new_backward_conv_window.mutable_dimensions(i); + new_dim->set_padding_low(new_conv_padding); + new_dim->set_padding_high(new_conv_padding); } // Create a new backward convolution replacing the old one. @@ -261,28 +258,12 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( .ConsumeValueOrDie(), input, padding, input_padding_config)); - HloInstruction* new_forward_conv = - computation->AddInstruction(HloInstruction::CreateConvolve( - ShapeInference::InferConvolveShape( - padded_input->shape(), output->shape(), new_forward_conv_window, - forward_conv_dnums) - .ConsumeValueOrDie(), - padded_input, output, new_forward_conv_window, forward_conv_dnums)); - - HloInstruction* new_transpose = - computation->AddInstruction(HloInstruction::CreateTranspose( - ShapeInference::InferTransposeShape(new_forward_conv->shape(), - transpose->dimensions()) - .ConsumeValueOrDie(), - new_forward_conv, transpose->dimensions())); - - // Fuse the new forward convolution and the new transpose to the new backward - // convolution. - HloInstruction* new_backward_conv = - computation->CreateFusionInstructionForBackwardConvolution( - {new_transpose, new_forward_conv}, - HloInstruction::FusionKind::kConvBackwardFilter, - new_backward_conv_window, backward_conv_dnums); + // The shape of the backward_conv CustomCall is a tuple (conv_result, + // scratch_buffer). Extract out the shape of conv_result. 
+ Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0); + HloInstruction* new_backward_conv = CreateCudnnConvBackwardFilter( + backward_conv_shape, padded_input, output, new_backward_conv_window, + backward_conv_dnums); VLOG(1) << "Canonicalizing backward filter conv"; VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n " @@ -295,18 +276,19 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution( bool PadInsertion::CanonicalizeBackwardInputConvolution( HloInstruction* backward_conv) { - if (window_util::HasEvenPadding(backward_conv->window())) { + if (window_util::HasSymmetricPadding(backward_conv->window())) { return false; } - HloInstruction* forward_conv = backward_conv->fused_expression_root(); - HloInstruction* reverse_filter = forward_conv->mutable_operand(1); - Window new_forward_conv_window = forward_conv->window(); Window new_backward_conv_window = backward_conv->window(); - ConvolutionDimensionNumbers forward_conv_dnums = - forward_conv->convolution_dimension_numbers(); ConvolutionDimensionNumbers backward_conv_dnums = backward_conv->convolution_dimension_numbers(); + + // The backward_conv CustomCall returns a tuple (conv_result, scratch_memory). + // Get the shape of conv_result. + Shape backward_conv_shape = backward_conv->shape().tuple_shapes(0); + + Shape new_backward_conv_shape = backward_conv_shape; for (size_t i = 0; i < backward_conv->window().dimensions_size(); ++i) { int64 padding_low = backward_conv->window().dimensions(i).padding_low(); int64 padding_high = backward_conv->window().dimensions(i).padding_high(); @@ -325,41 +307,38 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution( // where the amount of padding low is larger, we can canonicalize it to // [B A] = BackwardInputConvolve([a b], [x y z], padding=(low=1,high=1)) // [A] = Slice([B A]) - // For consistency, we need to increase the low padding of the inner - // convolution by 1 as well because the input is larger now. 
if (padding_low > padding_high) { IncreasePaddingLowBy(padding_high - padding_low, new_backward_conv_window.mutable_dimensions(i)); - IncreasePaddingLowBy(padding_low - padding_high, - new_forward_conv_window.mutable_dimensions(i)); } else if (padding_low < padding_high) { IncreasePaddingHighBy(padding_low - padding_high, new_backward_conv_window.mutable_dimensions(i)); - IncreasePaddingHighBy(padding_high - padding_low, - new_forward_conv_window.mutable_dimensions(i)); } + // Decreasing the padding by X *increases* the size of our output by X. + int64 dim = backward_conv_dnums.output_spatial_dimensions(i); + new_backward_conv_shape.set_dimensions( + dim, new_backward_conv_shape.dimensions(dim) + + std::abs(padding_low - padding_high)); } // Create a new backward convolution replacing the old one. HloComputation* computation = backward_conv->parent(); HloInstruction* output = backward_conv->mutable_operand(0); HloInstruction* filter = backward_conv->mutable_operand(1); - HloInstruction* new_reverse_filter = - computation->AddInstruction(HloInstruction::CreateReverse( - filter->shape(), filter, reverse_filter->dimensions())); - HloInstruction* new_forward_conv = - computation->AddInstruction(HloInstruction::CreateConvolve( - ShapeInference::InferConvolveShape( - output->shape(), new_reverse_filter->shape(), - new_forward_conv_window, forward_conv_dnums) - .ConsumeValueOrDie(), - output, new_reverse_filter, new_forward_conv_window, - forward_conv_dnums)); + + HloInstruction* new_backward_conv_call = CreateCudnnConvBackwardInput( + new_backward_conv_shape, output, filter, new_backward_conv_window, + backward_conv_dnums); + + // The CustomCall created above returns a tuple (conv_result, scratch_memory). + // Extract out the two elements. 
HloInstruction* new_backward_conv = - computation->CreateFusionInstructionForBackwardConvolution( - {new_forward_conv, new_reverse_filter}, - HloInstruction::FusionKind::kConvBackwardInput, - new_backward_conv_window, backward_conv_dnums); + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + new_backward_conv_shape, new_backward_conv_call, 0)); + HloInstruction* new_backward_conv_scratch = + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + new_backward_conv_call->shape().tuple_shapes(1), + new_backward_conv_call, 1)); // Slice the new backward convolution. // @@ -387,22 +366,25 @@ bool PadInsertion::CanonicalizeBackwardInputConvolution( } // Replace the old backward convolution with the slice. - CHECK(ShapeUtil::Compatible( + Shape slice_shape = ShapeInference::InferSliceShape(new_backward_conv->shape(), start_indices, limit_indices, strides) - .ConsumeValueOrDie(), - backward_conv->shape())); + .ConsumeValueOrDie(); + CHECK(ShapeUtil::Compatible(slice_shape, backward_conv_shape)) + << ShapeUtil::HumanString(slice_shape) << " vs " + << ShapeUtil::HumanString(backward_conv_shape); - auto slice = - HloInstruction::CreateSlice(backward_conv->shape(), new_backward_conv, - start_indices, limit_indices, strides); + HloInstruction* slice = computation->AddInstruction( + HloInstruction::CreateSlice(backward_conv_shape, new_backward_conv, + start_indices, limit_indices, strides)); + HloInstruction* new_tuple = computation->AddInstruction( + HloInstruction::CreateTuple({slice, new_backward_conv_scratch})); VLOG(1) << "Canonicalizing backward input conv"; VLOG(1) << "Replacing:\n " << backward_conv->ToString() << "\nwith:\n " - << slice->ToString(); + << new_tuple->ToString(); - TF_CHECK_OK( - computation->ReplaceWithNewInstruction(backward_conv, std::move(slice))); + TF_CHECK_OK(computation->ReplaceInstruction(backward_conv, new_tuple)); return true; } @@ -410,18 +392,17 @@ StatusOr PadInsertion::Run(HloModule* module) { bool changed = 
false; for (HloInstruction* instruction : module->entry_computation()->MakeInstructionPostOrder()) { - if (instruction->opcode() == HloOpcode::kConvolution) { - changed |= CanonicalizeForwardConvolution(instruction); - } else if (instruction->opcode() == HloOpcode::kFusion) { - switch (instruction->fusion_kind()) { - case HloInstruction::FusionKind::kConvBackwardFilter: - changed |= CanonicalizeBackwardFilterConvolution(instruction); - break; - case HloInstruction::FusionKind::kConvBackwardInput: - changed |= CanonicalizeBackwardInputConvolution(instruction); - break; - default: - break; + if (IsCustomCallToDnnConvolution(*instruction)) { + const auto& target = instruction->custom_call_target(); + if (target == kCudnnConvForwardCallTarget) { + changed |= CanonicalizeForwardConvolution(instruction); + } else if (target == kCudnnConvBackwardFilterCallTarget) { + changed |= CanonicalizeBackwardFilterConvolution(instruction); + } else if (target == kCudnnConvBackwardInputCallTarget) { + changed |= CanonicalizeBackwardInputConvolution(instruction); + } else { + LOG(FATAL) << "Unknown custom call target for cudnn conv: " + << instruction->ToString(); } } } diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc index 457e6094d90413440658452937bff2ccfe6cbe5c..388dcc008b07a76ff9ed07df04181e49a8734f51 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.cc @@ -88,6 +88,23 @@ llvm_ir::IrArray::Index ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock( /*HasNUW=*/true, /*HasNSW=*/true), thread_id, "linear_index", /*HasNUW=*/true, /*HasNSW=*/true); + // Add an @llvm.assume(linear_index < threads_per_block * num_blocks). + // + // This might seem obvious from the computation above, but LLVM does not + // currently determine the range of linear_index precisely. 
InstCombine uses + // known-bits, which, when applied to the task of determining a value's range, + // is imprecise for everything other than powers of 2. And + // CorrelatedValuePropagation is, as a cost-saving measure, disabled for + // conditions in the same basic block as their operands. + llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::assume, + {ir_builder_->CreateICmpULT( + linear_index, + ir_builder_->getInt64(launch_dimensions_.threads_per_block() * + launch_dimensions_.block_count()), + "linear_index_in_range")}, + {}, ir_builder_); + auto if_in_bounds = llvm_ir::EmitIfThenElse( ir_builder_->CreateICmpULT( linear_index, ir_builder_->getInt64(ShapeUtil::ElementsIn(shape_))), diff --git a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h index 934e7e1919f08a16daf09ec634e2f9dc0c7cc723..8ed63a854a74fc06c3c389f40fe1f5970885deac 100644 --- a/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h @@ -42,6 +42,11 @@ class ParallelLoopEmitter : public llvm_ir::LoopEmitter { const LaunchDimensions& launch_dimensions, llvm::IRBuilder<>* ir_builder); + // Constructs a loop emitter for a loop that generates on element of each of N + // arrays on each iteration. + // + // This is used in multi-output fusion. target_element_generator should + // produce a struct with N elements, one for each of target_arrays. 
ParallelLoopEmitter( const llvm_ir::ElementGenerator& target_element_generator, tensorflow::gtl::ArraySlice target_arrays, diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc index d0d2deee24848184278e3e51dcaa3bb673b5fadc..6cf280df05496716a0780d61ded92efd9982734c 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc @@ -44,37 +44,41 @@ std::ostream& operator<<(std::ostream& out, // Calculates the launch dimensions used to invoke `hlo`. LaunchDimensions CalculateLaunchDimensions( - const Shape& shape, const se::DeviceDescription& device_desc, - PartitionStrategy partition_strategy) { - int64 warp_size = device_desc.threads_per_warp(); - + const Shape& shape, const se::DeviceDescription& device_desc) { int64 num_elements = ShapeUtil::ElementsIn(shape); if (num_elements <= 1) { return LaunchDimensions(); } - // Calculate the number of threads per block. - // Initialize threads_per_block as the threads-per-block limit. - int64 threads_per_block = device_desc.threads_per_block_limit(); - VLOG(2) << "Initial # of threads per block = " << threads_per_block; - - if (partition_strategy == PartitionStrategy::kLatency) { - // Limit the thread count to allow maximum number of registers per thread. - // TODO(b/28560520): We don't have to assume the emitted kernel will use up - // all the registers. We could use ptxas to examine the actual number of - // register used, and set the thread count accordingly. - int64 threads_per_block_limit_due_to_registers = - device_desc.registers_per_core_limit() / - device_desc.registers_per_thread_limit(); - CHECK_NE(0, threads_per_block_limit_due_to_registers); - if (threads_per_block_limit_due_to_registers < threads_per_block) { - threads_per_block = - // Make `threads_per_block` a multiple of warp size to use GPU - // efficiently. 
- warp_size * - std::max(1LL, threads_per_block_limit_due_to_registers / warp_size); - VLOG(2) << "Update # of threads per block due to register pressure = " - << threads_per_block; + // Since we don't do any inter-warp communication, we're free to choose any + // block size we want, subject to hardware constraints. We choose the + // smallest block size that allows the GPU to reach full occupancy (assuming + // the kernel uses sufficiently few registers). This gives us max performance + // when the kernel uses few registers, and lets us scale down gracefully as + // the kernel uses more registers. + // + // Specifically, we choose the number of threads per block such that + // + // * = + + auto threads_per_core = device_desc.threads_per_core_limit(); + auto blocks_per_core = device_desc.blocks_per_core_limit(); + int64 threads_per_block; + if (threads_per_core != 0 && blocks_per_core != 0) { + threads_per_block = device_desc.threads_per_core_limit() / + device_desc.blocks_per_core_limit(); + } else { + static std::atomic log_count{0}; + if (log_count.fetch_add(1) < 8) { + LOG(WARNING) << "Attempting to calculate launch dimensions for GPU " + "without full information about its capabilities. " + "StreamExecutor's PopulateDeviceDescription should be " + "updated for this device."; + } + threads_per_block = device_desc.threads_per_warp(); + if (threads_per_block == 0) { + // Fall back to *something* if we can't even get num threads per warp. + threads_per_block = 32; } } @@ -84,8 +88,6 @@ LaunchDimensions CalculateLaunchDimensions( << threads_per_block << ") because the latter is smaller."; } - // Calculate the block count. 
We copy the strategy used by Eigen: - // eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorExecutor.h int64 block_count = CeilOfRatio(num_elements, threads_per_block); VLOG(2) << tensorflow::strings::Printf( "Initialized the block count to ceil(# of elements / threads per " diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.h b/tensorflow/compiler/xla/service/gpu/partition_assignment.h index 8f7fce884acc93fd39510ad0826b819a6d9731a7..0bf463a6ef95d5a32784838c08ad239752fd1acf 100644 --- a/tensorflow/compiler/xla/service/gpu/partition_assignment.h +++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.h @@ -30,14 +30,6 @@ limitations under the License. namespace xla { namespace gpu { -enum class PartitionStrategy { - // Optimized for latency by allowing maximum number of registers per thread. - kLatency, - // Optimized for throughput. This may limit registers per thread and cause - // longer latency. - kThroughput -}; - // Encapsulates the launch dimensions of a kernel, e.g., the block count and the // number of threads per block. 
class LaunchDimensions { @@ -66,8 +58,7 @@ std::ostream& operator<<(std::ostream& out, LaunchDimensions CalculateLaunchDimensions( const Shape& shape, - const perftools::gputools::DeviceDescription& device_desc, - PartitionStrategy partition_strategy = PartitionStrategy::kLatency); + const perftools::gputools::DeviceDescription& device_desc); } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/thunk.h b/tensorflow/compiler/xla/service/gpu/thunk.h index 0ff27888ad72f8190400c22a9086d1965448662c..2c3032d79be221e8cacb178ffb1817459b603cc0 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk.h +++ b/tensorflow/compiler/xla/service/gpu/thunk.h @@ -41,8 +41,13 @@ class GpuExecutable; class Thunk { public: enum class Kind { + kConditional, kConvolution, kCopy, + kCudnnBatchNormBackward, + kCudnnBatchNormForwardInference, + kCudnnBatchNormForwardTraining, + kFft, kGemm, kInfeed, kKernel, @@ -70,6 +75,29 @@ class Thunk { return tensorflow::Status::OK(); } + // Users of Thunk should call ShouldHaltAllActivityBeforeRunning(stream) + // before calling ExecuteOnStream(stream). If it returns true, it's the + // user's responsibility to wait for all activity on the GPU to finish before + // calling ExecuteOnStream. + // + // This value is not required to be constant for a given Thunk. For example, + // a Thunk that performs autotuning may return true for its first run and + // false thereafter. + virtual bool ShouldHaltAllActivityBeforeRunning( + perftools::gputools::Stream* /*stream*/) { + return false; + } + + // Indicates whether thunks scheduled after this one should wait for this one + // to complete before running. For example, a convolution thunk creates a + // scratch allocator, then kicks off a convolution in cudnn via the stream + // executor. When the stream executor call returns, the scratch allocator goes + // out of scope, and the scratch memory is deallocated. 
In this case, the + // convolution thunk needs to return true so that future thunks wait for the + // convolution thunk to avoid reusing the deallocated memory until the + // convolution thunk is done with it. + virtual bool ShouldBlockFutureThunks() { return false; } + // Execute the kernel for the thunk on the given stream. This method must be // called after Initialize and can be called multiple times over Thunk's // lifetime. Stream argument must be non-null. diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.cc b/tensorflow/compiler/xla/service/gpu/while_thunk.cc index 0d2412096abf7838b7b0e7617811c789f507a4a1..c21559af6d2e5dfb5aaf62afcdcaed514e0914c9 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.cc @@ -34,16 +34,14 @@ WhileThunk::WhileThunk( body_thunk_sequence_( MakeUnique(std::move(*body_thunk_sequence), hlo)) {} -tensorflow::Status WhileThunk::Initialize(const GpuExecutable& executable) { +Status WhileThunk::Initialize(const GpuExecutable& executable) { TF_RETURN_IF_ERROR(condition_thunk_sequence_->Initialize(executable)); TF_RETURN_IF_ERROR(body_thunk_sequence_->Initialize(executable)); - return tensorflow::Status::OK(); + return Status::OK(); } -tensorflow::Status WhileThunk::ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) { - +Status WhileThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) { perftools::gputools::DeviceMemoryBase condition_result_data = buffer_allocations.GetDeviceAddress(condition_result_buffer_index_); @@ -55,9 +53,11 @@ tensorflow::Status WhileThunk::ExecuteOnStream( // Copy the result of condition computation and break the loop if 'false'. 
bool condition_result; stream->ThenMemcpy(&condition_result, condition_result_data, sizeof(bool)); - if (!stream->BlockHostUntilDone()) { + Status block_status = stream->BlockHostUntilDone(); + if (!block_status.ok()) { return InternalError( - "Failed to complete all kernels launched on stream %p", stream); + "Failed to complete all kernels launched on stream %p: %s", stream, + block_status.error_message().c_str()); } if (!condition_result) { @@ -68,7 +68,7 @@ tensorflow::Status WhileThunk::ExecuteOnStream( TF_RETURN_IF_ERROR( body_thunk_sequence_->ExecuteOnStream(buffer_allocations, stream)); } - return tensorflow::Status::OK(); + return Status::OK(); } } // namespace gpu diff --git a/tensorflow/compiler/xla/service/gpu/while_thunk.h b/tensorflow/compiler/xla/service/gpu/while_thunk.h index 95ed5497cea4fa3ba5dcdc6762cbd53cec88339a..4c9f45de9e42494df58706d0a4a3eb0c4220b8b8 100644 --- a/tensorflow/compiler/xla/service/gpu/while_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/while_thunk.h @@ -45,10 +45,9 @@ class WhileThunk : public Thunk { WhileThunk(const WhileThunk&) = delete; WhileThunk& operator=(const WhileThunk&) = delete; - tensorflow::Status Initialize(const GpuExecutable& executable) override; - tensorflow::Status ExecuteOnStream( - const BufferAllocations& buffer_allocations, - perftools::gputools::Stream* stream) override; + Status Initialize(const GpuExecutable& executable) override; + Status ExecuteOnStream(const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) override; private: const BufferAllocation::Slice condition_result_buffer_index_; diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.cc b/tensorflow/compiler/xla/service/gpu/while_transformer.cc index ccdd1717593e4fa7c1d1deb3f0f9ebfab1bf7209..e6caec8625f0d622dbb92bcc20802d254fe23f94 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer.cc @@ -44,7 +44,7 @@ namespace { // // 
Parameter // | -// Const GetTupleElemet +// Const GetTupleElement // \ / // Add (root) // @@ -62,7 +62,7 @@ namespace { // &tagged_instructions)); // // Instructions that are "tagged" with a context-specific string will -// be returned in 'tagged_instructions' for further procesing (i.e. parsing +// be returned in 'tagged_instructions' for further processing (i.e. parsing // constants or recording the tuple_index). // class ExprTree { diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer.h b/tensorflow/compiler/xla/service/gpu/while_transformer.h index a4f527fce0e4e280e24efc1f33ea68a0b71555b9..fe3a954e1828ee4a323872eea81f64c7e780ad24 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer.h +++ b/tensorflow/compiler/xla/service/gpu/while_transformer.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_ #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/statusor.h" @@ -40,4 +40,4 @@ StatusOr> CanTransformWhileToFor( } // namespace gpu } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_WHILE_TRANSFORMER_H_ diff --git a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc index f16daa0b5481474e754c880ead1945297ca50168..2f290f61bd527e9827472a78256f015e066e44be 100644 --- a/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc +++ b/tensorflow/compiler/xla/service/gpu/while_transformer_test.cc @@ -117,9 
+117,7 @@ class WhileTransformerTest : public HloTestBase { } void RunCopyInsertionPass() { - HloVerifier verifier([](const Shape& shape) { - return ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/sizeof(void*)); - }); + HloVerifier verifier; TF_ASSERT_OK(verifier.Run(module_.get()).status()); CopyInsertion copy_insertion; TF_ASSERT_OK(copy_insertion.Run(module_.get()).status()); diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc index 049e8d80d80c835bca4a4d38592564ba82a3ecf9..05017008e2ddbe0b9e78d06275fdec5d08d94bfa 100644 --- a/tensorflow/compiler/xla/service/graphviz_example.cc +++ b/tensorflow/compiler/xla/service/graphviz_example.cc @@ -108,8 +108,11 @@ std::unique_ptr MakeBigGraph() { HloInstruction::CreateUnary(vshape, HloOpcode::kCopy, param_v0)); auto clamp = builder.AddInstruction(HloInstruction::CreateTernary( vshape, HloOpcode::kClamp, copy, param_v1, param_v2)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto dot = builder.AddInstruction( - HloInstruction::CreateBinary(vshape, HloOpcode::kDot, clamp, param_v0)); + HloInstruction::CreateDot(vshape, clamp, param_v0, dot_dnums)); auto tuple = builder.AddInstruction( HloInstruction::CreateTuple({dot, param_s, clamp})); auto scalar = builder.AddInstruction( diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc index 34e2f7ee206c6a74073d8f4e867e862feb4aff49..cde5877e29f36abc61c5417ce960e2c7699e2749 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.cc +++ b/tensorflow/compiler/xla/service/heap_simulator.cc @@ -64,10 +64,8 @@ StatusOr HeapSimulator::Run( std::unique_ptr algorithm, const HloModule& module, const SequentialHloOrdering::HloModuleSequence& module_sequence, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_fn, - const FlatSet* 
buffers_to_assign) { - HeapSimulator heap(std::move(algorithm), size_fn, buffers_to_assign, - &module_sequence); + const LogicalBuffer::SizeFunction& size_fn, const Options& options) { + HeapSimulator heap(std::move(algorithm), size_fn, options, &module_sequence); const HloComputation* entry_computation = module.entry_computation(); const std::vector& instruction_sequence = FindOrDie(module_sequence, entry_computation); @@ -81,9 +79,8 @@ StatusOr HeapSimulator::Run( std::unique_ptr algorithm, const HloComputation& computation, const std::vector& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_fn, - const FlatSet* buffers_to_assign) { - HeapSimulator heap(std::move(algorithm), size_fn, buffers_to_assign, + const LogicalBuffer::SizeFunction& size_fn, const Options& options) { + HeapSimulator heap(std::move(algorithm), size_fn, options, /*module_sequence=*/nullptr); TF_RETURN_IF_ERROR(heap.RunComputation(computation, instruction_sequence, points_to_analysis)); @@ -199,15 +196,17 @@ Status HeapSimulator::RunComputation( // We can only share with the operand buffer if it is about to be freed; // we must be the last user of the buffer. 
bool shared = false; - for (const LogicalBuffer* operand_buffer : operand_buffers_to_free) { - if (buffer->instruction()->IsUserOf(operand_buffer->instruction()) && - buffer->instruction()->opcode() != HloOpcode::kCopy && - CanShareOperandBufferWithUser( - operand_buffer->instruction(), operand_buffer->index(), - buffer->instruction(), buffer->index(), points_to_analysis)) { - ShareBuffer(buffer, operand_buffer, instruction); - shared = true; - break; + if (options_.may_reuse_operand_buffers) { + for (const LogicalBuffer* operand_buffer : operand_buffers_to_free) { + if (buffer->instruction()->IsUserOf(operand_buffer->instruction()) && + buffer->instruction()->opcode() != HloOpcode::kCopy && + CanShareOperandBufferWithUser( + operand_buffer->instruction(), operand_buffer->index(), + buffer->instruction(), buffer->index(), points_to_analysis)) { + ShareBuffer(buffer, operand_buffer, instruction); + shared = true; + break; + } } } @@ -266,13 +265,12 @@ Status HeapSimulator::RunComputation( HeapSimulator::HeapSimulator( std::unique_ptr algorithm, - const LogicalBuffer::SizeFunction& size_fn, - const FlatSet* buffers_to_assign, + const LogicalBuffer::SizeFunction& size_fn, const Options& options, const SequentialHloOrdering::HloModuleSequence* module_sequence) : no_fragmentation_stats_(MakeUnique()), algorithm_(std::move(algorithm)), size_fn_(size_fn), - buffers_to_assign_(buffers_to_assign), + options_(options), module_sequence_(module_sequence) { debug_trace_.set_whole_module_simulation(module_sequence_ != nullptr); } @@ -280,13 +278,16 @@ HeapSimulator::HeapSimulator( HeapSimulator::~HeapSimulator() {} bool HeapSimulator::IgnoreBuffer(const LogicalBuffer* buffer) const { - // Buffers for constants are ignored, as with BufferAssigner. Also ignore - // buffers that we're not meant to assign. + // Buffers for constants are ignored unless the alloc_constants option is + // set. Also ignore buffers that we're not meant to assign. 
// // TODO(b/32248867): For consistency, constants should get allocations. - return buffer->instruction()->opcode() == HloOpcode::kConstant || - (buffers_to_assign_ != nullptr && - buffers_to_assign_->count(buffer) == 0); + if (!options_.alloc_constants && + buffer->instruction()->opcode() == HloOpcode::kConstant) { + return true; + } + return options_.buffers_to_assign != nullptr && + options_.buffers_to_assign->count(buffer) == 0; } // Alloc always calls the underlying heap algorithm. @@ -400,8 +401,8 @@ HeapSimulator::Result HeapSimulator::Finish() { } // If we were told to assign specific buffers, make sure we've assigned // exactly that many buffers. - if (buffers_to_assign_ != nullptr) { - CHECK_EQ(buffers_to_assign_->size(), result.chunk_map.size()); + if (options_.buffers_to_assign != nullptr) { + CHECK_EQ(options_.buffers_to_assign->size(), result.chunk_map.size()); } } diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h index a03ad2f37cf5ede35275ea019ab3d5998fb85d0a..636f19dd39f09721bd82fc4b44785f196f281ad7 100644 --- a/tensorflow/compiler/xla/service/heap_simulator.h +++ b/tensorflow/compiler/xla/service/heap_simulator.h @@ -67,6 +67,23 @@ class HeapSimulator { HeapSimulatorTrace debug_trace; }; + // The different options to be passed to the Run() APIs. + struct Options { + Options() + : may_reuse_operand_buffers(true), + alloc_constants(false), + buffers_to_assign(nullptr) {} + + // Whether a buffer about to be Free()-ed, can be recycled for a new born + // one, hence collapsing Free()+Alloc() calls (default true). + bool may_reuse_operand_buffers; + // Whether to issue Alloc() and Free() calls for constants (default false). + bool alloc_constants; + // If 'buffers_to_assign' is provided, only those buffers are assigned + // offsets, otherwise all buffers defined by the instructions are assigned. 
+ const tensorflow::gtl::FlatSet* buffers_to_assign; + }; + // Run the heap simulation with the given algorithm, assuming the given // module_sequence, which must contain a topologically-consistent total // ordering of all instructions within each computation. The result is invalid @@ -76,15 +93,12 @@ class HeapSimulator { // to running on a per-computation basis, since we can re-use buffer space for // called sub-computations. // - // If 'buffers_to_assign' is provided, only those buffers are assigned - // offsets, otherwise all buffers defined by the instructions are assigned. static StatusOr Run( std::unique_ptr algorithm, const HloModule& module, const SequentialHloOrdering::HloModuleSequence& module_sequence, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_fn, - const tensorflow::gtl::FlatSet* buffers_to_assign = - nullptr); + const Options& options = Options()); // Same as above, but runs on a single computation. The 'instruction_sequence' // must contain a topologically-consistent total ordering of all instructions @@ -96,8 +110,7 @@ class HeapSimulator { const std::vector& instruction_sequence, const TuplePointsToAnalysis& points_to_analysis, const LogicalBuffer::SizeFunction& size_fn, - const tensorflow::gtl::FlatSet* buffers_to_assign = - nullptr); + const Options& options = Options()); private: // If 'module_sequence' is non-null, it is used to find kCall and kWhile @@ -105,8 +118,7 @@ class HeapSimulator { // be run recursively. I.e. the simulation is run over the whole module. 
HeapSimulator( std::unique_ptr algorithm, - const LogicalBuffer::SizeFunction& size_fn, - const tensorflow::gtl::FlatSet* buffers_to_assign, + const LogicalBuffer::SizeFunction& size_fn, const Options& options, const SequentialHloOrdering::HloModuleSequence* module_sequence); ~HeapSimulator(); @@ -130,7 +142,7 @@ class HeapSimulator { const std::unique_ptr no_fragmentation_stats_; const std::unique_ptr algorithm_; const LogicalBuffer::SizeFunction size_fn_; - const tensorflow::gtl::FlatSet* buffers_to_assign_; + const Options options_; const SequentialHloOrdering::HloModuleSequence* module_sequence_; // In addition to Alloc and Free, the heap simulator exposes a concept of @@ -264,7 +276,7 @@ class LazyBestFitHeap : public HeapAlgorithm { enum { kLazyAllocOffset = -1 }; struct OrderChunkByIncreasingSize { - bool operator()(const Chunk& a, const Chunk& b) { + bool operator()(const Chunk& a, const Chunk& b) const { if (a.size != b.size) return a.size < b.size; return a.offset < b.offset; } diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc index 17b926c8748e45b55f380e7595711b9e7a748f64..387b649a731ebcbfd8307807469f39f22d192b06 100644 --- a/tensorflow/compiler/xla/service/heap_simulator_test.cc +++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc @@ -259,8 +259,11 @@ TEST_F(HeapSimulatorTest, MultiplyDot) { HloInstruction::CreateParameter(2, f32scalar_, "paramY")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec4_, HloOpcode::kMultiply, paramA, paramX)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto dot = builder.AddInstruction( - HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, mul, paramY)); + HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); // The buffer for dot is the output, and it cannot be shared with the buffer // for mul, since dot isn't elementwise. 
@@ -292,8 +295,11 @@ TEST_F(HeapSimulatorTest, MultiplyDotAdd) { HloInstruction::CreateParameter(2, f32scalar_, "paramY")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec4_, HloOpcode::kMultiply, paramA, paramX)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto dot = builder.AddInstruction( - HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, mul, paramY)); + HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); auto add = builder.AddInstruction( HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, dot, paramA)); @@ -327,10 +333,13 @@ TEST_F(HeapSimulatorTest, MultiplyDotDot) { HloInstruction::CreateParameter(2, f32scalar_, "paramY")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec4_, HloOpcode::kMultiply, paramA, paramX)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto dot0 = builder.AddInstruction( - HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, mul, paramY)); + HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); auto dot1 = builder.AddInstruction( - HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, dot0, paramY)); + HloInstruction::CreateDot(f32vec4_, dot0, paramY, dot_dnums)); // The buffer for dot1 is the output. No buffers can be shared. 
The buffer // for mul is freed before the end, since it's no longer used after dot0 @@ -365,10 +374,13 @@ TEST_F(HeapSimulatorTest, MultiplyDotDotTuple) { HloInstruction::CreateParameter(2, f32scalar_, "paramY")); auto mul = builder.AddInstruction(HloInstruction::CreateBinary( f32vec4_, HloOpcode::kMultiply, paramA, paramX)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto dot0 = builder.AddInstruction( - HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, mul, paramY)); + HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums)); auto dot1 = builder.AddInstruction( - HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, dot0, paramY)); + HloInstruction::CreateDot(f32vec4_, dot0, paramY, dot_dnums)); auto tuple = builder.AddInstruction(HloInstruction::CreateTuple({dot0, dot1})); diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index e984bdb5f75f714fb7b4453a97178158d9b8a8b8..36db711c6c3570efdf678261ad38bbdb08cf94aa 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -36,6 +36,9 @@ option cc_enable_arenas = true; // Serialization of HloInstruction. message HloInstructionProto { + reserved 10; + reserved "parameter_name"; + string name = 1; string opcode = 2; xla.Shape shape = 3; @@ -50,9 +53,8 @@ message HloInstructionProto { // Literal, only present for kConstant. xla.LiteralProto literal = 8; - // Parameter info, only present for kParameter. + // Parameter number is only present for kParameter. int64 parameter_number = 9; - string parameter_name = 10; // Fusion state, only present for kFusion. string fusion_kind = 11; @@ -118,6 +120,15 @@ message HloInstructionProto { // Shape of outfeed request. 
xla.Shape outfeed_shape = 29; + + // Describes the dimension numbers used for a dot operation + xla.DotDimensionNumbers dot_dimension_numbers = 30; + + // FFT type (FFT, IFFT, etc). + xla.FftType fft_type = 31; + + // FFT length. + repeated int64 fft_length = 32; } // Serialization of HloComputation. @@ -189,6 +200,7 @@ message BufferAllocationProto { bool is_reusable = 4; bool is_entry_computation_parameter = 5; int64 parameter_number = 6; + repeated int64 parameter_shape_index = 10; bool maybe_live_out = 7; int64 color = 8; repeated Assigned assigned = 9; diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index c215cc48d60b93a88d64b7c4aecb2aa3bb460443..5432419e4a2dd2916da32ac6566851bf52fd68ca 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -131,9 +131,9 @@ Status HloComputation::RemoveParameter(int64 param_no) { while (param_no < param_instructions_.size()) { param_instruction = param_instructions_[param_no]; - string param_name = param_instruction->parameter_name(); + string param_name = param_instruction->name(); // Fusion parameters are named foo.param_1, bar.param_2, etc. We are - // renumbering the parameters so replace the final number in the name with + // renumbering the parameters, so replace the final number in the name with // the updated value. 
const string param_underscore = ".param_"; size_t index = param_name.rfind(param_underscore); @@ -176,10 +176,6 @@ bool HloComputation::IsRemovable(const HloInstruction* instruction) { return false; } - if (instruction->HasSideEffect()) { - return false; - } - return true; } @@ -207,7 +203,8 @@ Status HloComputation::RemoveInstructionAndUnusedOperands( worklist.pop(); if (removed.count(item) != 0 || item->user_count() != 0 || - item == root_instruction() || !IsRemovable(item)) { + item == root_instruction() || !IsRemovable(item) || + item->HasSideEffect()) { continue; } for (int i = 0; i < item->operand_count(); ++i) { @@ -367,26 +364,27 @@ std::list HloComputation::MakeEmbeddedComputationsList() return post_order; } -string HloComputation::ToString(int nested_level, - bool include_large_constants) const { +string HloComputation::ToString(const HloPrintOptions& options) const { std::ostringstream s; - for (int i = 0; i < nested_level; i++) { + for (int i = 0; i < options.indent_amount(); i++) { s << " "; } - s << "%" << name() << " " << ShapeUtil::HumanString(ComputeProgramShape()) - << " {\n"; + if (options.print_percent()) { + s << "%"; + } + s << name(); + if (options.print_program_shape()) { + s << " " << ShapeUtil::HumanString(ComputeProgramShape()); + } + s << " {\n"; for (const HloInstruction* instruction : MakeInstructionPostOrder()) { - for (int i = 0; i < nested_level; i++) { + for (int i = 0; i < options.indent_amount(); i++) { s << " "; } s << " " << (instruction == root_instruction_ ? 
"ROOT " : "") - << instruction->ToString( - /*compact_operands=*/false, - /*include_metadata=*/true, - /*include_large_constants=*/include_large_constants) - << "\n"; + << instruction->ToString(options) << "\n"; } - for (int i = 0; i < nested_level; i++) { + for (int i = 0; i < options.indent_amount(); i++) { s << " "; } s << "}"; @@ -463,20 +461,6 @@ HloInstruction* HloComputation::CreateFusionInstruction( return fusion_instruction; } -HloInstruction* HloComputation::CreateFusionInstructionForBackwardConvolution( - tensorflow::gtl::ArraySlice instructions_to_fuse, - HloInstruction::FusionKind fusion_kind, const Window& window, - const ConvolutionDimensionNumbers& conv_dnums) { - CHECK(HloInstruction::FusionKind::kConvBackwardFilter == fusion_kind || - HloInstruction::FusionKind::kConvBackwardInput == fusion_kind); - HloInstruction* root = instructions_to_fuse.front(); - HloInstruction* fusion_instruction = - AddInstruction(HloInstruction::CreateFusionForBackwardConvolution( - root->shape(), fusion_kind, window, conv_dnums, root)); - FuseInstructionsInto(instructions_to_fuse, fusion_instruction); - return fusion_instruction; -} - StatusOr HloComputation::DeepCopyHelper( HloInstruction* instruction, const ShapeTree* indices_to_copy, ShapeTree* copies_added, ShapeIndex* index) { @@ -543,7 +527,7 @@ ProgramShape HloComputation::ComputeProgramShape() const { for (auto* param_instruction : param_instructions_) { *program_shape.add_parameters() = param_instruction->shape(); - *program_shape.add_parameter_names() = param_instruction->parameter_name(); + *program_shape.add_parameter_names() = param_instruction->name(); } *program_shape.mutable_result() = root_instruction_->shape(); @@ -579,8 +563,11 @@ Status HloComputation::ReplaceWithNewInstruction( Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction, HloInstruction* new_instruction) { - TF_RET_CHECK(ShapeUtil::Compatible(old_instruction->shape(), - new_instruction->shape())); + TF_RET_CHECK( + 
ShapeUtil::Compatible(old_instruction->shape(), new_instruction->shape())) + << ShapeUtil::HumanString(old_instruction->shape()) << " vs " + << ShapeUtil::HumanString(new_instruction->shape()); + VLOG(10) << "transformed " << old_instruction->ToString() << " to " << new_instruction->ToString(); // Try to add metadata for HLO instructions that are created to replace diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 353b30bc69d98556311635d6097e3d6ad5fb2aaa..061c59abe5e315917161ed737f89de53d71bb1b6 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -138,8 +138,11 @@ class HloComputation { void UniquifyName(NameUniquer* name_uniquer); // Return a string representation of the computation. - string ToString(int nested_level = 0, - bool include_large_constants = false) const; + // + // (We express the default options using an overload rather than a default + // param because gdb ignores default params, but does resolve overloads.) + string ToString() const { return ToString(HloPrintOptions()); } + string ToString(const HloPrintOptions& options) const; // Returns a serialized representation of this computation. HloComputationProto ToProto() const; @@ -221,15 +224,6 @@ class HloComputation { tensorflow::gtl::ArraySlice instructions_to_fuse, HloInstruction::FusionKind fusion_kind); - // Creates a fusion instruction that represents a backward convolution. This - // is similar to CreateFusionInstruction but takes window and conv_dnums which - // indicate the window and convolution dimension numbers of the backward - // convolution. 
- HloInstruction* CreateFusionInstructionForBackwardConvolution( - tensorflow::gtl::ArraySlice instructions_to_fuse, - HloInstruction::FusionKind fusion_kind, const Window& window, - const ConvolutionDimensionNumbers& conv_dnums); - // Create a deep copy of the given instruction and return the instruction // producing the copied result. All instructions performing the copy are added // to the computation. For array-shaped values, this method trivially returns @@ -313,11 +307,17 @@ class HloComputation { replacements, HloModule* module = nullptr, const string& suffix = "clone"); - // Returns true if the given instruction can be removed from the - // computation. Instructions such as parameters and send/receive instructions - // cannot be removed without violating invariants of the HLO computation or - // module with the exception of fusion computation. A parameter instruction - // is removable for a fusion computation. + // Returns true if the given instruction can be removed from the computation. + // Parameter instructions cannot be removed without violating invariants of + // the HLO computation with the exception of fusion computation. A parameter + // instruction is removable for a fusion computation. + // + // Note that IsRemovable() is a necessary condition to remove an instruction + // rather than a sufficient condition. For example, instructions with + // side-effect (e.g., Send, Infeed) may be removed from a computation, but the + // transformation must guarantee the invariants relevant to the instructions + // still hold (e.g., Send and Recv must be removed together to make each + // channel complete). bool IsRemovable(const HloInstruction* instruction); // Returns true if this computation has a side effect. 
A computation has a diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 6fcc01dd64d1ac041e99eedb8b1de476409b257d..9cd5a1e2b71a7aa768e478289e8e4cc13030fcc3 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -201,10 +201,11 @@ Status HloCostAnalysis::HandleCopy(const HloInstruction*) { Status HloCostAnalysis::HandleDot(const HloInstruction* dot) { const Shape& lhs_shape = dot->operand(0)->shape(); const Shape& rhs_shape = dot->operand(1)->shape(); + const DotDimensionNumbers& dnums = dot->dot_dimension_numbers(); // Count of elements along the reduction dimension (last dimension for the // rhs). - int64 reduction_width = lhs_shape.dimensions(ShapeUtil::Rank(lhs_shape) - 1); - + int64 reduction_width = + lhs_shape.dimensions(dnums.lhs_contracting_dimensions(0)); // First divide by reduction width before multiplying by rhs elements to avoid // overflow. int64 fma_count; @@ -391,13 +392,35 @@ Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) { return Status::OK(); } +Status HloCostAnalysis::HandleFft(const HloInstruction* fft) { + auto real_shape = + ShapeUtil::IsTuple(fft->operand(0)->shape()) + ? ShapeUtil::GetTupleElementShape(fft->operand(0)->shape(), 0) + : fft->operand(0)->shape(); + constexpr int kFmaPerComplexMul = 4; + int64 log_factors = 1; + for (int64 dim : fft->fft_length()) { + log_factors *= tensorflow::Log2Floor(dim); + } + current_properties_[kFlopsKey] = kFmaFlops * kFmaPerComplexMul * log_factors * + ShapeUtil::ElementsIn(real_shape); + return Status::OK(); +} + Status HloCostAnalysis::HandleCrossReplicaSum(const HloInstruction* crs) { // We assume 2 replicas, so that each output element is the sum of two input // elements. // // TODO(b/33004697): Compute correct cost here, taking the actual number of // replicas into account. 
- current_properties_[kFlopsKey] = ShapeUtil::ElementsIn(crs->shape()); + double flops = 0.0; + ShapeUtil::ForEachSubshape( + crs->shape(), [&, this](const Shape& subshape, const ShapeIndex&) { + if (ShapeUtil::IsArray(subshape)) { + flops += ShapeUtil::ElementsIn(subshape); + } + }); + current_properties_[kFlopsKey] = flops; return Status::OK(); } @@ -446,7 +469,13 @@ Status HloCostAnalysis::HandleCall(const HloInstruction* call) { } Status HloCostAnalysis::HandleCustomCall(const HloInstruction*) { - return Unimplemented("Custom-call is not implemented for HLO cost analysis."); + // We can't do anything sane with CustomCalls, since we don't know what they + // do, and returning an error status will stop iteration over this + // computation, which is probably also not what we want. So just punt and + // return OK. This will cause all of the properties to be reported as 0, + // which is fine. + current_should_compute_bottleneck_time_ = false; + return Status::OK(); } Status HloCostAnalysis::HandleSort(const HloInstruction* sort) { diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index fade19522cf0c30eab037aa355de1f9203f80014..e5783539e5436f09fa58bf7889118380ee90fea0 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -67,6 +67,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleCopy(const HloInstruction* copy) override; Status HandleDot(const HloInstruction* dot) override; Status HandleConvolution(const HloInstruction* convolution) override; + Status HandleFft(const HloInstruction* fft) override; Status HandleCrossReplicaSum(const HloInstruction* crs) override; Status HandleInfeed(const HloInstruction* infeed) override; Status HandleOutfeed(const HloInstruction* outfeed) override; diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index 
d35ba19a730555433099072c51ca5cf3774d4b99..279edd4ba8772a9c576f76f554de8ec68631b953 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/inlined_vector.h" namespace xla { @@ -91,6 +92,10 @@ bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) { StatusOr HloCSE::Run(HloModule* module) { bool changed = false; + const std::function + eq_instructions = std::equal_to(); + const std::function + eq_computations = std::equal_to(); for (auto* computation : module->computations()) { changed |= CombineConstants(computation, is_layout_sensitive_); @@ -110,11 +115,12 @@ StatusOr HloCSE::Run(HloModule* module) { // of this instruction. const HloInstruction* operand = instruction->operand(0); - std::vector equivalent_instructions; + tensorflow::gtl::InlinedVector + equivalent_instructions; for (HloInstruction* user : operand->users()) { - if (user != instruction && user->Identical(*instruction) && - (!is_layout_sensitive_ || - ShapeUtil::Equal(user->shape(), instruction->shape()))) { + if (user != instruction && + user->Identical(*instruction, eq_instructions, eq_computations, + is_layout_sensitive_)) { equivalent_instructions.push_back(user); } } diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 3f34b9ceb34abc89fca5b896bb8fbe3a06cd6ed4..ccbbe8f1966d59b4ab2904dcc6ea724aaf4a7603 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -154,7 +154,11 @@ bool HloDataflowAnalysis::Phi( tensorflow::gtl::ArraySlice inputs) { CHECK(ssa_form_); VLOG(4) << "Phi(" << instruction->name() << ")"; - + VLOG(5) << "instruction value set = " + 
<< GetInstructionValueSet(instruction).ToString(); + for (const InstructionValueSet* input : inputs) { + VLOG(5) << "input value set = " << input->ToString(); + } for (const InstructionValueSet* input : inputs) { DCHECK(ShapeUtil::Compatible(instruction->shape(), input->shape())); } @@ -171,9 +175,14 @@ bool HloDataflowAnalysis::Phi( value_set.values().size() == 1 ? value_set.values()[0] : nullptr; // Construct a vector of unique value IDs of the inputs. + // Don't add value ids where the input is equal to the definition. std::vector input_value_ids; for (const InstructionValueSet* input : inputs) { for (const HloValue* value : input->element(index).values()) { + if (value->defining_instruction() == instruction && + value->defining_index() == index) { + continue; + } input_value_ids.push_back(value->id()); } } @@ -190,6 +199,7 @@ bool HloDataflowAnalysis::Phi( current_value->defining_instruction() == instruction && current_value->defining_index() == index); if (current_value_defined_here) { + VLOG(5) << "current_value_defined_here: " << current_value->ToString(); CHECK(current_value->is_phi()); auto it = std::find(input_value_ids.begin(), input_value_ids.end(), current_value->id()); @@ -197,7 +207,7 @@ bool HloDataflowAnalysis::Phi( input_value_ids.erase(it); } } - + VLOG(5) << "after input_value_ids.size = " << input_value_ids.size(); if (input_value_ids.empty()) { // A value set which has at least one element should never have its value // set reduced to zero elements. During dataflow value sets only can go @@ -276,6 +286,23 @@ bool HloDataflowAnalysis::UpdateBitcastValueSet(HloInstruction* bitcast) { return false; } +bool HloDataflowAnalysis::UpdateSliceValueSet(HloInstruction* slice) { + CHECK_EQ(slice->opcode(), HloOpcode::kSlice); + if (!slice->IsInPlaceSlice()) { + return false; + } + // If this slice is lowered to an in-place version, then it forwards the + // operand value to the output. 
+ const InstructionValueSet& operand_set = + GetInstructionValueSet(slice->operand(0)); + InstructionValueSet& slice_set = GetInstructionValueSet(slice); + if (operand_set != slice_set) { + slice_set = operand_set; + return true; + } + return false; +} + bool HloDataflowAnalysis::UpdateSendValueSet(HloInstruction* send) { CHECK_EQ(send->opcode(), HloOpcode::kSend); bool changed = false; @@ -333,6 +360,21 @@ bool HloDataflowAnalysis::UpdateCallValueSet(HloInstruction* call) { return false; } +bool HloDataflowAnalysis::UpdateConditionalValueSet( + HloInstruction* conditional) { + CHECK_EQ(conditional->opcode(), HloOpcode::kConditional); + std::vector inputs = { + &GetInstructionValueSet( + conditional->true_computation()->root_instruction()), + &GetInstructionValueSet( + conditional->false_computation()->root_instruction())}; + // A phi-node is not defined for a kConditional instruction even though it + // represents a join point. This is because the current approach is to define + // a phi-node only for kWhile to account for the dataflow through back-edges + // and deal with the ambiguity in other cases. 
+ return GetInstructionValueSet(conditional).AssignUnionOf(inputs); +} + bool HloDataflowAnalysis::UpdateCopyValueSet(HloInstruction* copy) { CHECK_EQ(copy->opcode(), HloOpcode::kCopy); bool changed = false; @@ -394,7 +436,7 @@ bool HloDataflowAnalysis::UpdateParameterValueSet(HloInstruction* parameter) { CHECK_EQ(call_graph_node.context(), CallContext::kSequential); std::vector inputs; - bool called_from_while = false; + bool need_phi = false; for (const CallSite& callsite : call_graph_node.caller_callsites()) { if (callsite.instruction()->opcode() == HloOpcode::kCall) { // The operand values of a call instruction are forwarded to the @@ -416,14 +458,32 @@ bool HloDataflowAnalysis::UpdateParameterValueSet(HloInstruction* parameter) { inputs.push_back(&GetInstructionValueSet( callsite.instruction()->while_body()->root_instruction())); } - called_from_while = true; + need_phi = true; + } else if (callsite.instruction()->opcode() == HloOpcode::kConditional) { + CHECK_EQ(parameter->parameter_number(), 0); + auto conditional = callsite.instruction(); + // Conditional has 3 operands. Operand 0 is the predicate, operand 1 is + // the argument to the true computation and operand 2 is the argument to + // the false computation. + // + // If the parameter belongs to conditional's true computation, then + // operand 1 is forwarded to this parameter instruction. If the parameter + // belongs to conditional's false computation, then operand 2 is forwarded + // to this parameter instruction. 
+ if (parameter->parent() == conditional->true_computation()) { + inputs.push_back(&GetInstructionValueSet(conditional->operand(1))); + } else { + CHECK_EQ(parameter->parent(), conditional->false_computation()); + inputs.push_back(&GetInstructionValueSet(conditional->operand(2))); + } + need_phi = true; } else { LOG(FATAL) << "CallContext::kSequential computations should only be " - "called from call or while instructions"; + "called from call, while, or conditional instructions"; } } - if (ssa_form_ && called_from_while) { + if (ssa_form_ && need_phi) { return Phi(parameter, inputs); } else { return GetInstructionValueSet(parameter).AssignUnionOf(inputs); @@ -494,6 +554,8 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet( switch (instruction->opcode()) { case HloOpcode::kBitcast: return UpdateBitcastValueSet(instruction); + case HloOpcode::kSlice: + return UpdateSliceValueSet(instruction); case HloOpcode::kCopy: return UpdateCopyValueSet(instruction); case HloOpcode::kGetTupleElement: @@ -512,6 +574,8 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet( return UpdateSendValueSet(instruction); case HloOpcode::kRecvDone: return UpdateRecvDoneValueSet(instruction); + case HloOpcode::kConditional: + return UpdateConditionalValueSet(instruction); default: // Instruction does not forward HloValues (it defines all values in its // output). No update is necessary. 
@@ -521,16 +585,23 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet( void HloDataflowAnalysis::Propagate() { std::queue worklist; + tensorflow::gtl::FlatSet workset; + auto add_to_worklist = [&worklist, &workset](HloInstruction* instruction) { + if (workset.insert(instruction).second) { + worklist.push(instruction); + } + }; for (HloComputation* computation : module_->computations()) { for (HloInstruction* instruction : computation->instructions()) { - worklist.push(instruction); + add_to_worklist(instruction); } } while (!worklist.empty()) { HloInstruction* instruction = worklist.front(); worklist.pop(); + workset.erase(workset.find(instruction)); VLOG(3) << "Worklist top: " << instruction->name(); VLOG(3) << ToString(); @@ -544,19 +615,38 @@ void HloDataflowAnalysis::Propagate() { VLOG(4) << "New value set for " << instruction->name() << ": " << GetInstructionValueSet(instruction); - // Instruction value was updated. Add users to work list. + // Instruction value was updated. Add users to work list if we haven't + // already. for (HloInstruction* user : instruction->users()) { - worklist.push(user); + add_to_worklist(user); // If user sequentially calls a computation, then the respective // parameter(s) of the computation need to be updated. - for (HloComputation* called_computation : user->called_computations()) { - const CallGraphNode& call_graph_node = - call_graph_->GetNode(called_computation); - if (call_graph_node.context() == CallContext::kSequential) { - for (int64 operand_number : user->OperandIndices(instruction)) { - worklist.push( - called_computation->parameter_instruction(operand_number)); + if (user->opcode() == HloOpcode::kConditional) { + // If operand 0 is the use of instruction, then no parameters need to be + // updated, since that is the predicate of the conditional. + // If operand 1 is the use of instruction, then the true_computation's + // parameter need to be updated. 
+ // If operand 2 is the use of instruction, then the false_computation's + // parameter need to be updated. + // + // Note that the same instruction can be used in both operand 1 and + // operand 2. + if (user->operand(1) == instruction) { + add_to_worklist(user->true_computation()->parameter_instruction(0)); + } + if (user->operand(2) == instruction) { + add_to_worklist(user->false_computation()->parameter_instruction(0)); + } + } else { + for (HloComputation* called_computation : user->called_computations()) { + const CallGraphNode& call_graph_node = + call_graph_->GetNode(called_computation); + if (call_graph_node.context() == CallContext::kSequential) { + for (int64 operand_number : user->OperandIndices(instruction)) { + add_to_worklist( + called_computation->parameter_instruction(operand_number)); + } } } } @@ -568,14 +658,15 @@ void HloDataflowAnalysis::Propagate() { const CallGraphNode& call_graph_node = call_graph_->GetNode(instruction->parent()); for (const CallSite& callsite : call_graph_node.caller_callsites()) { - if (callsite.instruction()->opcode() == HloOpcode::kCall) { - worklist.push(callsite.instruction()); + if ((callsite.instruction()->opcode() == HloOpcode::kCall) || + (callsite.instruction()->opcode() == HloOpcode::kConditional)) { + add_to_worklist(callsite.instruction()); } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) { // Add the while itself, and the body and condition parameters. 
- worklist.push(callsite.instruction()); - worklist.push( + add_to_worklist(callsite.instruction()); + add_to_worklist( callsite.instruction()->while_body()->parameter_instruction(0)); - worklist.push( + add_to_worklist( callsite.instruction()->while_condition()->parameter_instruction( 0)); } @@ -634,8 +725,14 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { define_all_values(); } break; + case HloOpcode::kSlice: + if (!instruction->IsInPlaceSlice()) { + define_all_values(); + } + break; case HloOpcode::kWhile: case HloOpcode::kCall: + case HloOpcode::kConditional: case HloOpcode::kGetTupleElement: // These instructions define no values. The values in their output // flow from their operands or from cross computation dataflow. diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index dfd81ae951042f7a4d6d3c24af4d5b7e046c272d..89d318188f0855c7924836a51cfe98d531e08cb4 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -145,7 +145,9 @@ class HloDataflowAnalysis { // Updates the value set for a particular instruction type. Returns whether // the instruction value set changed. 
bool UpdateBitcastValueSet(HloInstruction* bitcast); + bool UpdateSliceValueSet(HloInstruction* slice); bool UpdateCallValueSet(HloInstruction* call); + bool UpdateConditionalValueSet(HloInstruction* conditional); bool UpdateCopyValueSet(HloInstruction* copy); bool UpdateGetTupleElementValueSet(HloInstruction* gte); bool UpdateParameterValueSet(HloInstruction* parameter); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc index f08f0b1d6833b028baa5f997929a17eb5abae205..e714b2567fd1b3eab607a19f0bb7e3288150dc64 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc @@ -34,6 +34,7 @@ limitations under the License. namespace xla { namespace { +using ::testing::ElementsAre; using ::testing::UnorderedElementsAre; // Test is parameterized on a bool which is whether the dataflow analysis is @@ -77,11 +78,23 @@ class HloDataflowAnalysisTest : public HloTestBase, analysis_->GetValueDefinedAt(b), *analysis_); } + std::unique_ptr CreateR0F32UnaryOpComputation( + HloOpcode opcode) { + HloComputation::Builder builder(TestName() + "." 
+ HloOpcodeString(opcode)); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param0")); + builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape_, opcode, param0)); + return builder.Build(); + } + std::unique_ptr module_; std::unique_ptr analysis_; const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {}); const Shape vector_shape_ = ShapeUtil::MakeShape(F32, {42}); + const Shape tuple_shape_ = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}); }; TEST_P(HloDataflowAnalysisTest, BinaryOperation) { @@ -1528,6 +1541,315 @@ TEST_P(HloDataflowAnalysisTest, EmbeddedComputationInterference) { EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, embedded_log)); } +TEST_P(HloDataflowAnalysisTest, ConditionalWithIdentity) { + // Test conditional with identity computations in both true and false cases. + // + // true_computation(F32[] %true_param): + // return %true_param + // + // false_computation(F32[] %false_param): + // return %false_param + // + // entry: + // %pred = Constant(true) + // %constant1 = Constant(56.0) + // %constant2 = Constant(12.0) + // return Conditional(%pred, %constant1, true_computation, + // %constant2, false_computation) + + auto true_builder = HloComputation::Builder(TestName() + "_true"); + auto true_param = true_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "true_param")); + HloComputation* true_computation = + module_->AddEmbeddedComputation(true_builder.Build()); + + auto false_builder = HloComputation::Builder(TestName() + "_false"); + auto false_param = false_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "false_param")); + HloComputation* false_computation = + module_->AddEmbeddedComputation(false_builder.Build()); + + auto builder = HloComputation::Builder(TestName()); + auto pred = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))); + 
auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(56.0f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(12.0f))); + auto conditional = builder.AddInstruction(HloInstruction::CreateConditional( + scalar_shape_, pred, constant1, true_computation, constant2, + false_computation)); + module_->AddEntryComputation(builder.Build()); + + const HloDataflowAnalysis& analysis = RunAnalysis(GetParam()); + + EXPECT_TRUE(analysis.ValueIsDefinedAt(pred)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2)); + + EXPECT_FALSE(analysis.ValueIsDefinedAt(true_param)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(false_param)); + + EXPECT_EQ(analysis.GetUniqueValueAt(true_param), + analysis.GetValueDefinedAt(constant1)); + EXPECT_EQ(analysis.GetUniqueValueAt(false_param), + analysis.GetValueDefinedAt(constant2)); + + EXPECT_THAT(analysis.GetValueDefinedAt(pred).uses(), + ElementsAre(HloUse{conditional, 0, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), + ElementsAre(HloUse{conditional, 1, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), + ElementsAre(HloUse{conditional, 2, {}})); + + EXPECT_EQ(analysis.values().size(), 3); + EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional)); + EXPECT_THAT(HloValuesAt(conditional), + UnorderedElementsAre(analysis.GetValueDefinedAt(constant1), + analysis.GetValueDefinedAt(constant2))); +} + +TEST_P(HloDataflowAnalysisTest, ConditionalTakingTupleOperand) { + // Test conditional with true and false computations taking a tuple operand. 
+ // + // true_computation((F32[], F32[]) %true_param): + // %true_x = GetTupleElement(%true_param, 0) + // %true_y = GetTupleElement(%true_param, 1) + // return Add(%true_x, %true_y) + // + // false_computation((F32[], F32[]) %false_param): + // %false_x = GetTupleElement(%false_param, 0) + // %false_y = GetTupleElement(%false_param, 1) + // return Subtract(%false_x, %false_y) + // + // entry: + // %pred = Constant(true) + // %constant1 = Constant(56.0) + // %constant2 = Constant(12.0) + // %tuple_operand = Tuple(%constant1, %constant2) + // return Conditional(%pred, %tuple_operand, true_computation, + // %tuple_operand, false_computation) + + auto true_builder = HloComputation::Builder(TestName() + "_true"); + auto true_param = true_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape_, "true_param")); + auto true_x = true_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, true_param, 0)); + auto true_y = true_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, true_param, 1)); + auto add = true_builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape_, HloOpcode::kAdd, true_x, true_y)); + HloComputation* true_computation = + module_->AddEmbeddedComputation(true_builder.Build()); + + auto false_builder = HloComputation::Builder(TestName() + "_false"); + auto false_param = false_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape_, "false_param")); + auto false_x = false_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, false_param, 0)); + auto false_y = false_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, false_param, 1)); + auto sub = false_builder.AddInstruction(HloInstruction::CreateBinary( + scalar_shape_, HloOpcode::kSubtract, false_x, false_y)); + HloComputation* false_computation = + module_->AddEmbeddedComputation(false_builder.Build()); + + auto builder = 
HloComputation::Builder(TestName()); + auto pred = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(56.0f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(12.0f))); + auto tuple_operand = builder.AddInstruction( + HloInstruction::CreateTuple({constant1, constant2})); + auto conditional = builder.AddInstruction(HloInstruction::CreateConditional( + scalar_shape_, pred, tuple_operand, true_computation, tuple_operand, + false_computation)); + module_->AddEntryComputation(builder.Build()); + + const HloDataflowAnalysis& analysis = RunAnalysis(GetParam()); + + EXPECT_TRUE(analysis.ValueIsDefinedAt(pred)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(tuple_operand)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(add)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(sub)); + + EXPECT_FALSE(analysis.ValueIsDefinedAt(true_param)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(false_param)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(true_x)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(true_y)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(false_x)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(false_y)); + + EXPECT_EQ(analysis.GetUniqueValueAt(true_param), + analysis.GetValueDefinedAt(tuple_operand)); + EXPECT_EQ(analysis.GetUniqueValueAt(false_param), + analysis.GetValueDefinedAt(tuple_operand)); + EXPECT_EQ(analysis.GetUniqueValueAt(true_x), + analysis.GetValueDefinedAt(constant1)); + EXPECT_EQ(analysis.GetUniqueValueAt(true_y), + analysis.GetValueDefinedAt(constant2)); + EXPECT_EQ(analysis.GetUniqueValueAt(false_x), + analysis.GetValueDefinedAt(constant1)); + EXPECT_EQ(analysis.GetUniqueValueAt(false_y), + analysis.GetValueDefinedAt(constant2)); + + EXPECT_THAT(analysis.GetValueDefinedAt(pred).uses(), + 
ElementsAre(HloUse{conditional, 0, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(), + UnorderedElementsAre(HloUse{conditional, 1, {0}}, + HloUse{conditional, 2, {0}}, + HloUse{add, 0, {}}, HloUse{sub, 0, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(constant2).uses(), + UnorderedElementsAre(HloUse{conditional, 1, {1}}, + HloUse{conditional, 2, {1}}, + HloUse{add, 1, {}}, HloUse{sub, 1, {}})); + EXPECT_THAT(analysis.GetValueDefinedAt(tuple_operand).uses(), + UnorderedElementsAre( + HloUse{conditional, 1, {}}, HloUse{conditional, 2, {}}, + HloUse{true_x, 0, {}}, HloUse{true_y, 0, {}}, + HloUse{false_x, 0, {}}, HloUse{false_y, 0, {}})); + + EXPECT_EQ(analysis.values().size(), 6); + EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional)); + EXPECT_THAT(HloValuesAt(conditional), + UnorderedElementsAre(analysis.GetValueDefinedAt(add), + analysis.GetValueDefinedAt(sub))); +} + +TEST_P(HloDataflowAnalysisTest, NestedConditionals) { + // computation1(F32[] %param1): + // %ceil = Ceil(%param1) + // return %ceil + // + // computation2(F32[] %param2): + // %floor = Floor(%param2) + // return %floor + // + // computation3(F32[] %param3): + // %negate = Negate(%param3) + // return %negate + // + // inner_conditional((PRED, F32[], F32[]) %param_cond): + // %pred_cond = GetTupleElement(%param_cond, 0) + // %true_operand_cond = GetTupleElement(%param_cond, 1) + // %false_operand_cond = GetTupleElement(%param_cond, 2) + // return Conditional(%pred_cond, %true_operand_cond, computation1, + // %false_operand_cond, computation2) + // + // entry: + // %pred1 = Constant(true) + // %pred2 = Constant(false) + // %constant1 = Constant(1.1); + // %constant2 = Constant(2.2); + // %constant3 = Constant(3.3); + // return Conditional(%pred1, (%pred2, %constant1, %constant2), + // inner_conditional, %constant3, computation3) + + auto computation1 = module_->AddEmbeddedComputation( + CreateR0F32UnaryOpComputation(HloOpcode::kCeil)); + auto computation2 = 
module_->AddEmbeddedComputation( + CreateR0F32UnaryOpComputation(HloOpcode::kFloor)); + auto computation3 = module_->AddEmbeddedComputation( + CreateR0F32UnaryOpComputation(HloOpcode::kNegate)); + + // Build inner_conditional computation. + const Shape scalar_bool_shape = ShapeUtil::MakeShape(PRED, {}); + const Shape tuple_param_shape = ShapeUtil::MakeTupleShape( + {scalar_bool_shape, scalar_shape_, scalar_shape_}); + auto inner_builder = + HloComputation::Builder(TestName() + "_inner_conditional"); + auto param_cond = inner_builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_param_shape, "param_cond")); + auto pred_cond = inner_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_bool_shape, param_cond, 0)); + auto true_operand_cond = inner_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param_cond, 1)); + auto false_operand_cond = inner_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param_cond, 2)); + auto inner_conditional = + inner_builder.AddInstruction(HloInstruction::CreateConditional( + scalar_shape_, pred_cond, true_operand_cond, computation1, + false_operand_cond, computation2)); + auto inner_conditional_computation = + module_->AddEmbeddedComputation(inner_builder.Build()); + + // Build entry computation. 
+ auto builder = HloComputation::Builder(TestName()); + auto pred1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))); + auto pred2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + auto constant1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); + auto constant2 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(2.2f))); + auto constant3 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(3.3f))); + auto tuple_operand = builder.AddInstruction( + HloInstruction::CreateTuple({pred2, constant1, constant2})); + auto conditional = builder.AddInstruction(HloInstruction::CreateConditional( + scalar_shape_, pred1, tuple_operand, inner_conditional_computation, + constant3, computation3)); + module_->AddEntryComputation(builder.Build()); + + const HloDataflowAnalysis& analysis = RunAnalysis(GetParam()); + + EXPECT_TRUE(analysis.ValueIsDefinedAt(pred1)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(pred2)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant1)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant2)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(constant3)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(tuple_operand)); + EXPECT_TRUE(analysis.ValueIsDefinedAt(computation1->root_instruction())); + EXPECT_TRUE(analysis.ValueIsDefinedAt(computation2->root_instruction())); + EXPECT_TRUE(analysis.ValueIsDefinedAt(computation3->root_instruction())); + + auto computation1_param = computation1->parameter_instruction(0); + auto computation2_param = computation2->parameter_instruction(0); + auto computation3_param = computation3->parameter_instruction(0); + EXPECT_FALSE(analysis.ValueIsDefinedAt(computation1_param)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(computation2_param)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(computation3_param)); + EXPECT_EQ(analysis.GetUniqueValueAt(computation1_param), + 
analysis.GetValueDefinedAt(constant1)); + EXPECT_EQ(analysis.GetUniqueValueAt(computation2_param), + analysis.GetValueDefinedAt(constant2)); + EXPECT_EQ(analysis.GetUniqueValueAt(computation3_param), + analysis.GetValueDefinedAt(constant3)); + + EXPECT_FALSE(analysis.ValueIsDefinedAt(param_cond)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(pred_cond)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(true_operand_cond)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(false_operand_cond)); + EXPECT_EQ(analysis.GetUniqueValueAt(param_cond), + analysis.GetValueDefinedAt(tuple_operand)); + EXPECT_EQ(analysis.GetUniqueValueAt(pred_cond), + analysis.GetValueDefinedAt(pred2)); + EXPECT_EQ(analysis.GetUniqueValueAt(true_operand_cond), + analysis.GetValueDefinedAt(constant1)); + EXPECT_EQ(analysis.GetUniqueValueAt(false_operand_cond), + analysis.GetValueDefinedAt(constant2)); + + EXPECT_EQ(analysis.values().size(), 9); + EXPECT_FALSE(analysis.ValueIsDefinedAt(inner_conditional)); + EXPECT_FALSE(analysis.ValueIsDefinedAt(conditional)); + EXPECT_THAT( + HloValuesAt(inner_conditional), + UnorderedElementsAre( + analysis.GetValueDefinedAt(computation1->root_instruction()), + analysis.GetValueDefinedAt(computation2->root_instruction()))); + EXPECT_THAT( + HloValuesAt(conditional), + UnorderedElementsAre( + analysis.GetValueDefinedAt(computation1->root_instruction()), + analysis.GetValueDefinedAt(computation2->root_instruction()), + analysis.GetValueDefinedAt(computation3->root_instruction()))); +} + INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation, HloDataflowAnalysisTest, ::testing::Values(false, true)); diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc index 40e67c87807b3e13d8ac09206bf6be02e4f9ff31..1e5f0f797a13fd7e7ce1cc934387a274a74153bc 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_dce.cc @@ -55,7 +55,8 @@ StatusOr HloDCE::Run(HloModule* module) { for (auto* instruction : 
computation->instructions()) { if (instruction->user_count() == 0 && live_instructions.count(instruction) == 0 && - computation->IsRemovable(instruction)) { + computation->IsRemovable(instruction) && + !instruction->HasSideEffect()) { dead_roots.push_back(instruction); } } diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc index d54b9a27087a42fd23eab0bd06e8deaca567312b..5a56607a665c4cbeb7b2572f182b88e890602968 100644 --- a/tensorflow/compiler/xla/service/hlo_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc @@ -70,6 +70,26 @@ TEST_F(HloDceTest, NoDeadCode) { EXPECT_EQ(3, computation->instruction_count()); } +TEST_F(HloDceTest, InstructionsWithSideEffect) { + // Verify that side-effect instructions (Send in this test) are not removed. + auto builder = HloComputation::Builder(TestName()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + builder.AddInstruction( + HloInstruction::CreateSend(constant, /*channel_id=*/0)); + builder.AddInstruction(HloInstruction::CreateTuple({})); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_EQ(3, computation->instruction_count()); + + HloDCE dce; + EXPECT_FALSE(dce.Run(module.get()).ValueOrDie()); + + EXPECT_EQ(3, computation->instruction_count()); +} + TEST_F(HloDceTest, DeadParameters) { // Verify that dead parameters are not removed, but use of the dead parameters // are. diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc new file mode 100644 index 0000000000000000000000000000000000000000..c782d1b0add17c70e0f54826917df251d5a613e2 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc @@ -0,0 +1,207 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_query.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { +namespace { + +HloInstruction* ToElementType(HloInstruction* hlo, PrimitiveType type) { + if (hlo->shape().element_type() != type) { + Shape shape = ShapeUtil::ChangeElementType(hlo->shape(), type); + hlo = hlo->parent()->AddInstruction( + HloInstruction::CreateConvert(shape, hlo)); + } + CHECK_EQ(hlo->shape().element_type(), type); + return hlo; +} + +bool HasOperandType(HloInstruction* hlo, PrimitiveType type) { + for (HloInstruction* operand : hlo->operands()) { + if (operand->shape().element_type() == type) { + return true; + } + } + return false; +} + +// Finds out the Tuple Shape of the new instruction after converting 
the element +// type of the operands of the original instruction from `from_type` to +// `to_type`. +// +// This routine assumes the resulting `shape` of the original instruction is a +// non-nested tuple. This assumption is currently safe as only kTuple, kInfeed, +// kOutfeed, kCall, kCustomCall and kBatchNorm* HLO instructions can produce +// results with tuple shapes, and this routine is only called to convert the +// result shapes of kBatchNorm* HLO instructions, which are non-nested tuples. +Shape GetConvertedTupleShape(const Shape& shape, PrimitiveType from_type, + PrimitiveType to_type) { + std::vector new_tuple_subshapes; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + Shape subshape = ShapeUtil::GetTupleElementShape(shape, i); + CHECK(!ShapeUtil::IsTuple(subshape)); + if (subshape.element_type() == from_type) { + subshape = ShapeUtil::ChangeElementType(subshape, to_type); + } + new_tuple_subshapes.push_back(subshape); + } + return ShapeUtil::MakeTupleShape(new_tuple_subshapes); +} + +// Converts the elements of the result of `hlo` to produce a new tuple with +// shape `to_shape`. +// +// This routine assumes `hlo` is an instruction that produces a non-nested Tuple +// as a result. 
+HloInstruction* ConvertTupleElements(HloInstruction* hlo, + const Shape& to_shape) { + const Shape& shape = hlo->shape(); + HloComputation* computation = hlo->parent(); + std::vector tuple_elements; + for (int64 i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) { + const Shape& ele_shape = ShapeUtil::GetTupleElementShape(shape, i); + HloInstruction* element = computation->AddInstruction( + HloInstruction::CreateGetTupleElement(ele_shape, hlo, i)); + const Shape& to_ele_shape = ShapeUtil::GetTupleElementShape(to_shape, i); + CHECK(!ShapeUtil::IsTuple(ele_shape)); + if (ele_shape.element_type() != to_ele_shape.element_type()) { + element = computation->AddInstruction( + HloInstruction::CreateConvert(to_ele_shape, element)); + } + tuple_elements.push_back(element); + } + return computation->AddInstruction( + HloInstruction::CreateTuple(tuple_elements)); +} + +} // namespace + +HloElementTypeConverter::HloElementTypeConverter( + PrimitiveType eliminate_type, PrimitiveType replace_with_type) + : eliminate_type_(eliminate_type), replace_with_type_(replace_with_type) {} + +// This routine converts the arithmetic operations in the given module that use +// eliminate_type_ to operations that use replace_with_type_. +StatusOr HloElementTypeConverter::Run(HloModule* module) { + XLA_VLOG_LINES( + 3, "HloElementTypeConverter::Run(), before:\n" + module->ToString()); + + if (eliminate_type_ == replace_with_type_) { + return false; + } + + bool changed = false; + for (auto* computation : module->computations()) { + for (auto* hlo : computation->MakeInstructionPostOrder()) { + const auto opcode = hlo->opcode(); + // These are ops where it does not make sense to convert them. 
+ if (opcode == HloOpcode::kParameter || opcode == HloOpcode::kConstant || + opcode == HloOpcode::kTuple || opcode == HloOpcode::kConvert || + opcode == HloOpcode::kGetTupleElement || + opcode == HloOpcode::kInfeed || opcode == HloOpcode::kOutfeed) { + continue; + } + + // We cannot change a CustomCall since we have no way of adjusting the + // called binary to expect the updated type. + if (opcode == HloOpcode::kCustomCall) { + continue; + } + + // These are ops with embedded computations where it suffices to convert + // the embedded computations instead of converting the ops themselves. + if (opcode == HloOpcode::kWhile || opcode == HloOpcode::kCall || + opcode == HloOpcode::kFusion || opcode == HloOpcode::kMap || + opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow || + opcode == HloOpcode::kSelectAndScatter || + opcode == HloOpcode::kConditional) { + continue; + } + TF_RET_CHECK(hlo->called_computations().empty()) << hlo->ToString(); + + if (!HasOperandType(hlo, eliminate_type_)) { + // If this CHECK fires, then this was an instruction that does not take + // the elimination type as an operand but it does return it. This pass + // does not have a feature to change the output type in that case, so + // instead of silently failing to eliminate the type, it fails loudly. + TF_RET_CHECK(hlo->shape().element_type() != eliminate_type_); + continue; + } + + // Handle instructions that perform arithmetic operations and contain + // operands with eliminate_type_. + // + // First, convert the operands with eliminate_type_ to operands with + // replace_with_type_. 
+ std::vector new_operands; + for (HloInstruction* operand : hlo->operands()) { + if (operand->shape().element_type() == eliminate_type_) { + operand = ToElementType(operand, replace_with_type_); + } + new_operands.push_back(operand); + } + + // Then find out the result type of the new instruction with the same + // opcode but using the converted operands, create the new instruction, + // and convert the result of the new instruction back to match the result + // type of the original instruction. + HloInstruction* new_hlo; + if (hlo->shape().element_type() == eliminate_type_) { + Shape shape = + ShapeUtil::ChangeElementType(hlo->shape(), replace_with_type_); + new_hlo = computation->AddInstruction( + hlo->CloneWithNewOperands(shape, new_operands, hlo->GetModule())); + new_hlo = ToElementType(new_hlo, eliminate_type_); + } else if (ShapeUtil::IsTuple(hlo->shape())) { + Shape old_shape = hlo->shape(); + Shape new_shape = GetConvertedTupleShape(hlo->shape(), eliminate_type_, + replace_with_type_); + new_hlo = computation->AddInstruction(hlo->CloneWithNewOperands( + new_shape, new_operands, hlo->GetModule())); + // Convert the elements of the result of `new_hlo` to produce a new + // tuple with shape `old_shape`. 
+ new_hlo = ConvertTupleElements(new_hlo, old_shape); + } else { + new_hlo = computation->AddInstruction(hlo->CloneWithNewOperands( + hlo->shape(), new_operands, hlo->GetModule())); + } + + TF_RETURN_IF_ERROR(computation->ReplaceInstruction(hlo, new_hlo)); + changed = true; + } + } + XLA_VLOG_LINES( + 2, "HloElementTypeConverter::Run(), after:\n" + module->ToString()); + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.h b/tensorflow/compiler/xla/service/hlo_element_type_converter.h new file mode 100644 index 0000000000000000000000000000000000000000..2b109225d0b192e5c9e4f6d841377ffad8078dc2 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.h @@ -0,0 +1,49 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ELEMENT_TYPE_CONVERTER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ELEMENT_TYPE_CONVERTER_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// A pass that eliminates certain element types as the input or output of ops by +// inserting Convert ops. This allows a backend to support an element type while +// only actually implementing the Convert op for that element type. 
This is +// generally not the fastest approach, but it works. +class HloElementTypeConverter : public HloPassInterface { + public: + // eliminate_type is the type to eliminate as the input or output of ops, + // using Convert ops to replace it with replace_with_type. + HloElementTypeConverter(PrimitiveType eliminate_type, + PrimitiveType replace_with_type); + + tensorflow::StringPiece name() const override { + return "element_type_converter"; + } + + // Runs the pass on the module and returns whether the module was modified. + StatusOr Run(HloModule* module) override; + + private: + PrimitiveType eliminate_type_; + PrimitiveType replace_with_type_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_ELEMENT_TYPE_CONVERTER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..cb94d9f19b825d1321263a4737b66a6bf198a772 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_element_type_converter_test.cc @@ -0,0 +1,121 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" + +namespace xla { +namespace { + +namespace op = xla::testing::opcode_matchers; + +class HloElementTypeConverterTest : public HloTestBase { + public: + std::unique_ptr CreateModuleFromHloString( + const string& hlo_string) { + return HloRunner::CreateModuleFromString(hlo_string, + GetDebugOptionsForTest()) + .ValueOrDie(); + } +}; + +TEST_F(HloElementTypeConverterTest, CustomCallsNotConverted) { + const string& hlo_string = R"( + HloModule custom_call + ENTRY CustomCall { + constant = bf16[1]{0} constant({12345}) + ROOT custom-call = bf16[1,2,3]{0,2,1} custom-call(constant), + custom_call_target="foo" + } + )"; + auto module = CreateModuleFromHloString(hlo_string); + HloElementTypeConverter type_converter(BF16, F32); + TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get())); + EXPECT_FALSE(converted); +} + +TEST_F(HloElementTypeConverterTest, InfeedsOutfeedsNotConverted) { + const string& hlo_string = R"( + HloModule InfeedOutfeed + ENTRY RoundTrip16MiBR1.v2 { + ROOT infeed = bf16[4]{0} infeed() + outfeed = () outfeed(infeed) + } + )"; + auto module = CreateModuleFromHloString(hlo_string); + HloElementTypeConverter type_converter(BF16, F32); + TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get())); + EXPECT_FALSE(converted); +} + +TEST_F(HloElementTypeConverterTest, OperationsInNestedTuplesConverted) { + const string& hlo_string = R"( + HloModule NestedTuples + ENTRY NestedTuples.v5 { + constant.4 = bf16[] constant(42) + constant.2 = f32[2]{0} constant({1, 2}) + constant.3 = bf16[] constant(42) + add = bf16[] add(constant.2, constant.3) + tuple = (f32[2]{0}, bf16[]) tuple(constant.2, add) + constant.5 = bf16[2]{0} constant({22, 44}) + ROOT tuple.1 = 
((f32[2]{0}, bf16[]), bf16[2]{0}) tuple(tuple, constant.5) + } + )"; + + auto module = CreateModuleFromHloString(hlo_string); + HloElementTypeConverter type_converter(BF16, F32); + TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get())); + EXPECT_TRUE(converted); + const HloInstruction* bf16_op = + module->entry_computation()->root_instruction()->operand(0)->operand(1); + EXPECT_THAT(bf16_op, op::Convert(op::Add(op::Constant(), op::Convert()))); +} + +TEST_F(HloElementTypeConverterTest, BatchNormGradBF16Converted) { + const string& hlo_string = R"( + HloModule BatchNormGrad + ENTRY BatchNormGrad.v6 { + constant.4 = bf16[2,2,2,1]{3,2,1,0} constant(bf16[2,2,2,1] { { /*i0=0*/ + { /*i1=0*/ {0}, {0} }, { /*i1=1*/ {0}, {0} } }, { /*i0=1*/ { /*i1=0*/ {0}, + {0} }, { /*i1=1*/ {0}, {0} } } }) + constant.5 = bf16[2]{0} constant({1, 1}) + constant.6 = bf16[2]{0} constant({0, 0}) + constant.7 = bf16[2]{0} constant({1, 1}) + constant.8 = bf16[2,2,2,1]{3,2,1,0} constant(bf16[2,2,2,1] { { /*i0=0*/ + { /*i1=0*/ {1}, {2} }, { /*i1=1*/ {3}, {4} } }, { /*i0=1*/ { /*i1=0*/ + {5}, {6} }, { /*i1=1*/ {7}, {8} } } }) + ROOT batch-norm-grad = (bf16[2,2,2,1]{3,2,1,0}, bf16[2]{0}, bf16[2]{0}) + batch-norm-grad(constant.4, constant.5, constant.6, constant.7, + constant.8), epsilon=0, feature_index=2 + } + )"; + + auto module = CreateModuleFromHloString(hlo_string); + HloElementTypeConverter type_converter(BF16, F32); + TF_ASSERT_OK_AND_ASSIGN(bool converted, type_converter.Run(module.get())); + EXPECT_TRUE(converted); + const HloInstruction* tuple_instr = + module->entry_computation()->root_instruction(); + ::testing::Matcher batch_norm = + op::BatchNormGrad(); + EXPECT_THAT(tuple_instr, + op::Tuple(op::Convert(op::GetTupleElement(batch_norm, 0)), + op::Convert(op::GetTupleElement(batch_norm, 1)), + op::Convert(op::GetTupleElement(batch_norm, 2)))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc 
b/tensorflow/compiler/xla/service/hlo_evaluator.cc index e693d167a1f96f65b894d07fb2c8f33e61ff8c49..81212cda4266ec820230d0d84fc2a395edaf411e 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/hlo_query.h" #include "tensorflow/compiler/xla/service/shape_inference.h" @@ -39,9 +40,11 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/bitmap.h" +#include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/types.h" @@ -165,13 +168,67 @@ StatusOr> ElementWiseUnaryOpImpl( return std::move(result); } +// For one particular placement of a window in a base shape (the placement is +// represented as `window_count_index`), iterates inside the window. Translates +// the window index into base index. If the base index is within bound, call `f` +// with the base index. 
+void IterateThroughWindow( + const Shape& window_shape, const Window& window, const Shape& base_shape, + const tensorflow::gtl::ArraySlice& window_count_index, + const std::function&)>& f) { + const int64 rank = ShapeUtil::Rank(base_shape); + DimensionVector window_index(rank); + std::fill(window_index.begin(), window_index.end(), 0); + do { + std::vector base_index(rank); + bool out_of_bound = false; + for (int64 i = 0; i < rank; ++i) { + base_index[i] = window_count_index[i] * window.dimensions(i).stride() + + window_index[i] - window.dimensions(i).padding_low(); + if (base_index[i] < 0 || base_index[i] >= base_shape.dimensions(i)) { + out_of_bound = true; + break; + } + } + if (!out_of_bound) { + f(base_index); + } + } while (IndexUtil::BumpIndices(window_shape, &window_index)); +} + } // namespace -template +template class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { public: explicit TypedVisitor(HloEvaluator* p) : parent_(p) {} + // The following higher-order functions convert a function with ElementwiseT + // to a function with ReturnT. 
+ std::function ConvertUnaryFunction( + const std::function& unary_op) { + return [&unary_op](ReturnT arg) { + return static_cast(unary_op(static_cast(arg))); + }; + } + std::function ConvertBinaryFunction( + const std::function& + binary_op) { + return [&binary_op](ReturnT arg1, ReturnT arg2) { + return static_cast(binary_op(static_cast(arg1), + static_cast(arg2))); + }; + } + std::function ConvertTernaryFunction( + const std::function& ternary_op) { + return [&ternary_op](ReturnT arg1, ReturnT arg2, ReturnT arg3) { + return static_cast(ternary_op(static_cast(arg1), + static_cast(arg2), + static_cast(arg3))); + }; + } + Status DefaultAction(HloInstruction* hlo_instruction) override { return Unimplemented("unhandled HLO ops for HloEvaluator: %s.", HloOpcodeString(hlo_instruction->opcode()).c_str()); @@ -197,24 +254,25 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { is_complex_t::value>::type* = nullptr> Status HandleAbs(HloInstruction* abs) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs], - ElementWiseUnaryOp(abs, [](NativeT elem_operand) { + ElementWiseUnaryOp(abs, [](ElementwiseT elem_operand) { return std::abs(elem_operand); })); return Status::OK(); } Status HandleAbs(HloInstruction* abs) override { - return HandleAbs(abs); + return HandleAbs(abs); } template < typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleRound(HloInstruction* round) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[round], - ElementWiseUnaryOp(round, [](ReturnT elem_operand) { - return std::round(elem_operand); - })); + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[round], + ElementWiseUnaryOp(round, [](ElementwiseT elem_operand) { + return std::round(elem_operand); + })); return Status::OK(); } @@ -233,7 +291,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { parent_->evaluated_[broadcast] = Literal::CreateFromShape(broadcast->shape()); auto output = parent_->evaluated_[broadcast].get(); - auto operand_to_broadcast = + 
const Literal& operand_to_broadcast = parent_->GetEvaluatedLiteralFor(broadcast->operand(0)); std::vector broadcast_indices( ShapeUtil::Rank(broadcast->operand(0)->shape()), 0); @@ -264,7 +322,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { typename std::enable_if::value>::type* = nullptr> Status HandleCeil(HloInstruction* ceil) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[ceil], - ElementWiseUnaryOp(ceil, [](ReturnT elem_operand) { + ElementWiseUnaryOp(ceil, [](ElementwiseT elem_operand) { return std::ceil(elem_operand); })); return Status::OK(); @@ -299,7 +357,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { Status HandleExp(HloInstruction* exp) override { TF_ASSIGN_OR_RETURN(parent_->evaluated_[exp], - ElementWiseUnaryOp(exp, [](ReturnT elem_operand) { + ElementWiseUnaryOp(exp, [](ElementwiseT elem_operand) { return std::exp(elem_operand); })); return Status::OK(); @@ -309,10 +367,11 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleFloor(HloInstruction* floor) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[floor], - ElementWiseUnaryOp(floor, [](ReturnT elem_operand) { - return std::floor(elem_operand); - })); + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[floor], + ElementWiseUnaryOp(floor, [](ElementwiseT elem_operand) { + return std::floor(elem_operand); + })); return Status::OK(); } @@ -329,7 +388,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { Status HandleLog(HloInstruction* log) override { TF_ASSIGN_OR_RETURN(parent_->evaluated_[log], - ElementWiseUnaryOp(log, [](ReturnT elem_operand) { + ElementWiseUnaryOp(log, [](ElementwiseT elem_operand) { return std::log(elem_operand); })); return Status::OK(); @@ -341,7 +400,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { !std::is_same::value>::type* = nullptr> Status HandleNot(HloInstruction* not_) { 
TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], - ElementWiseUnaryOp(not_, [](ReturnT elem_operand) { + ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) { return ~elem_operand; })); return Status::OK(); @@ -351,7 +410,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { NativeT>::value>::type* = nullptr> Status HandleNot(HloInstruction* not_) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], - ElementWiseUnaryOp(not_, [](ReturnT elem_operand) { + ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) { return !elem_operand; })); return Status::OK(); @@ -362,7 +421,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { nullptr> Status HandleNot(HloInstruction* not_) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], - ElementWiseUnaryOp(not_, [](ReturnT elem_operand) { + ElementWiseUnaryOp(not_, [](ElementwiseT elem_operand) { return !elem_operand; })); return Status::OK(); @@ -376,7 +435,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } Status HandleNot(HloInstruction* not_) override { - return HandleNot(not_); + return HandleNot(not_); } template ::value>::type* = nullptr> Status HandleNegate(HloInstruction* negate) { using type = typename std::make_unsigned::type; - TF_ASSIGN_OR_RETURN(parent_->evaluated_[negate], - ElementWiseUnaryOp(negate, [](ReturnT elem_operand) { - return NativeT(-type(elem_operand)); - })); + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[negate], + ElementWiseUnaryOp(negate, [](ElementwiseT elem_operand) { + return NativeT(-type(elem_operand)); + })); return Status::OK(); } @@ -397,10 +457,10 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { !std::is_signed::value || std::is_floating_point::value>::type* = nullptr> Status HandleNegate(HloInstruction* negate) { - TF_ASSIGN_OR_RETURN(parent_->evaluated_[negate], - ElementWiseUnaryOp(negate, [](ReturnT elem_operand) { - return -elem_operand; - })); + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[negate], + 
ElementWiseUnaryOp( + negate, [](ElementwiseT elem_operand) { return -elem_operand; })); return Status::OK(); } @@ -413,9 +473,9 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { typename std::enable_if::value>::type* = nullptr> Status HandleSign(HloInstruction* sign) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], - ElementWiseUnaryOp(sign, [](ReturnT elem_operand) { - return (ReturnT(0) < elem_operand) - - (elem_operand < ReturnT(0)); + ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) { + return (ElementwiseT(0) < elem_operand) - + (elem_operand < ElementwiseT(0)); })); return Status::OK(); } @@ -425,9 +485,9 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { typename std::enable_if::value>::type* = nullptr> Status HandleSign(HloInstruction* sign) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], - ElementWiseUnaryOp(sign, [](ReturnT elem_operand) { + ElementWiseUnaryOp(sign, [](ElementwiseT elem_operand) { auto abs_val = std::abs(elem_operand); - return 0 == abs_val ? ReturnT(0) + return 0 == abs_val ? 
ElementwiseT(0) : elem_operand / abs_val; })); return Status::OK(); @@ -437,9 +497,30 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return HandleSign(sign); } + template ::value>::type* = nullptr> + Status HandleAtan2(HloInstruction* atan2) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[atan2], + ElementWiseBinaryOp(atan2, [](ElementwiseT lhs_elem, + ElementwiseT rhs_elem) { + return std::atan2(lhs_elem, rhs_elem); + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> + Status HandleAtan2(HloInstruction* atan2) { + return InvalidArgument("Unsupported type for Atan2"); + } + + Status HandleAtan2(HloInstruction* atan2) override { + return HandleAtan2(atan2); + } + Status HandleTanh(HloInstruction* tanh) override { TF_ASSIGN_OR_RETURN(parent_->evaluated_[tanh], - ElementWiseUnaryOp(tanh, [](ReturnT elem_operand) { + ElementWiseUnaryOp(tanh, [](ElementwiseT elem_operand) { return std::tanh(elem_operand); })); return Status::OK(); @@ -453,9 +534,10 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { using type = typename std::make_unsigned::type; TF_ASSIGN_OR_RETURN( parent_->evaluated_[multiply], - ElementWiseBinaryOp(multiply, [](ReturnT lhs_elem, ReturnT rhs_elem) { - return NativeT(type(lhs_elem) * type(rhs_elem)); - })); + ElementWiseBinaryOp(multiply, + [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { + return NativeT(type(lhs_elem) * type(rhs_elem)); + })); return Status::OK(); } @@ -467,40 +549,42 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { Status HandleMultiply(HloInstruction* multiply) { TF_ASSIGN_OR_RETURN( parent_->evaluated_[multiply], - ElementWiseBinaryOp(multiply, [](ReturnT lhs_elem, ReturnT rhs_elem) { - return lhs_elem * rhs_elem; - })); + ElementWiseBinaryOp(multiply, + [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { + return lhs_elem * rhs_elem; + })); return Status::OK(); } Status HandleMultiply(HloInstruction* multiply) override { - return 
HandleMultiply(multiply); + return HandleMultiply(multiply); } Status HandleSubtract(HloInstruction* subtract) override { TF_ASSIGN_OR_RETURN( parent_->evaluated_[subtract], - ElementWiseBinaryOp(subtract, [](ReturnT lhs_elem, ReturnT rhs_elem) { - return lhs_elem - rhs_elem; - })); + ElementWiseBinaryOp(subtract, + [](ElementwiseT lhs_elem, ElementwiseT rhs_elem) { + return lhs_elem - rhs_elem; + })); return Status::OK(); } Status HandleAdd(HloInstruction* add) override { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[add], - ElementWiseBinaryOp(add, [](ReturnT lhs_elem, ReturnT rhs_elem) { - return lhs_elem + rhs_elem; - })); + TF_ASSIGN_OR_RETURN(parent_->evaluated_[add], + ElementWiseBinaryOp(add, [](ElementwiseT lhs_elem, + ElementwiseT rhs_elem) { + return lhs_elem + rhs_elem; + })); return Status::OK(); } Status HandleDivide(HloInstruction* divide) override { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[divide], - ElementWiseBinaryOp(divide, [](ReturnT lhs_elem, ReturnT rhs_elem) { - return lhs_elem / rhs_elem; - })); + TF_ASSIGN_OR_RETURN(parent_->evaluated_[divide], + ElementWiseBinaryOp(divide, [](ElementwiseT lhs_elem, + ElementwiseT rhs_elem) { + return lhs_elem / rhs_elem; + })); return Status::OK(); } @@ -510,7 +594,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { Status HandleMaximum(HloInstruction* maximum) { TF_ASSIGN_OR_RETURN( parent_->evaluated_[maximum], - ElementWiseBinaryOp(maximum, [](ReturnT lhs, ReturnT rhs) { + ElementWiseBinaryOp(maximum, [](ElementwiseT lhs, ElementwiseT rhs) { return std::fmax(lhs, rhs); })); return Status::OK(); @@ -524,18 +608,18 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } Status HandleMaximum(HloInstruction* maximum) override { - return HandleMaximum(maximum); + return HandleMaximum(maximum); } template < typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleMinimum(HloInstruction* minimum) { - TF_ASSIGN_OR_RETURN( - 
parent_->evaluated_[minimum], - ElementWiseBinaryOp(minimum, [](ReturnT lhs_el, ReturnT rhs_el) { - return std::fmin(lhs_el, rhs_el); - })); + TF_ASSIGN_OR_RETURN(parent_->evaluated_[minimum], + ElementWiseBinaryOp(minimum, [](ElementwiseT lhs_el, + ElementwiseT rhs_el) { + return std::fmin(lhs_el, rhs_el); + })); return Status::OK(); } @@ -547,15 +631,15 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } Status HandleMinimum(HloInstruction* minimum) override { - return HandleMinimum(minimum); + return HandleMinimum(minimum); } Status HandlePower(HloInstruction* power) override { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[power], - ElementWiseBinaryOp(power, [](ReturnT lhs_el, ReturnT rhs_el) { - return std::pow(lhs_el, rhs_el); - })); + TF_ASSIGN_OR_RETURN(parent_->evaluated_[power], + ElementWiseBinaryOp(power, [](ElementwiseT lhs_el, + ElementwiseT rhs_el) { + return std::pow(lhs_el, rhs_el); + })); return Status::OK(); } @@ -563,11 +647,11 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleRemainder(HloInstruction* remainder) { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[remainder], - ElementWiseBinaryOp(remainder, [](ReturnT lhs_el, ReturnT rhs_el) { - return std::fmod(lhs_el, rhs_el); - })); + TF_ASSIGN_OR_RETURN(parent_->evaluated_[remainder], + ElementWiseBinaryOp(remainder, [](ElementwiseT lhs_el, + ElementwiseT rhs_el) { + return std::fmod(lhs_el, rhs_el); + })); return Status::OK(); } @@ -579,7 +663,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } Status HandleRemainder(HloInstruction* remainder) override { - return HandleRemainder(remainder); + return HandleRemainder(remainder); } template evaluated_[and_], - ElementWiseBinaryOp(and_, [](ReturnT lhs_el, ReturnT rhs_el) { + ElementWiseBinaryOp(and_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { return lhs_el & rhs_el; })); return Status::OK(); @@ 
-599,7 +683,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { Status HandleAnd(HloInstruction* and_) { TF_ASSIGN_OR_RETURN( parent_->evaluated_[and_], - ElementWiseBinaryOp(and_, [](ReturnT lhs_el, ReturnT rhs_el) { + ElementWiseBinaryOp(and_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { return lhs_el && rhs_el; })); return Status::OK(); @@ -613,7 +697,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } Status HandleAnd(HloInstruction* and_) override { - return HandleAnd(and_); + return HandleAnd(and_); } template evaluated_[or_], - ElementWiseBinaryOp(or_, [](ReturnT lhs_el, ReturnT rhs_el) { + ElementWiseBinaryOp(or_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { return lhs_el | rhs_el; })); return Status::OK(); @@ -633,7 +717,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { Status HandleOr(HloInstruction* or_) { TF_ASSIGN_OR_RETURN( parent_->evaluated_[or_], - ElementWiseBinaryOp(or_, [](ReturnT lhs_el, ReturnT rhs_el) { + ElementWiseBinaryOp(or_, [](ElementwiseT lhs_el, ElementwiseT rhs_el) { return lhs_el || rhs_el; })); return Status::OK(); @@ -647,7 +731,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } Status HandleOr(HloInstruction* or_) override { - return HandleOr(or_); + return HandleOr(or_); } template (shl); + return HandleShiftLeft(shl); } template (shra); + return HandleShiftRightArithmetic(shra); } template (shrl); + return HandleShiftRightLogical(shrl); } template < typename NativeT, typename std::enable_if::value>::type* = nullptr> Status HandleClamp(HloInstruction* clamp) { - std::function clamp_op = - [](ReturnT low, ReturnT value, ReturnT high) { + std::function + clamp_op = [](ElementwiseT low, ElementwiseT value, ElementwiseT high) { return std::fmax(low, std::fmin(value, high)); }; - TF_ASSIGN_OR_RETURN(parent_->evaluated_[clamp], - ElementWiseTernaryOp(clamp, std::move(clamp_op))); + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[clamp], + 
ElementwiseTernaryOp(clamp, + std::move(ConvertTernaryFunction(clamp_op)))); return Status::OK(); } @@ -749,7 +835,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } Status HandleClamp(HloInstruction* clamp) override { - return HandleClamp(clamp); + return HandleClamp(clamp); } Status HandleSelect(HloInstruction* select) override { @@ -762,7 +848,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return on_false; }; TF_ASSIGN_OR_RETURN(parent_->evaluated_[select], - ElementWiseTernaryOp(select, std::move(select_op))); + ElementwiseTernaryOp(select, std::move(select_op))); return Status::OK(); } @@ -780,7 +866,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { << " but is inferred to be: " << ShapeUtil::HumanString(inferred_return_shape); - auto operand_literal = parent_->GetEvaluatedLiteralFor(operand); + const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); auto result = Literal::CreateFromShape(result_shape); TF_RETURN_IF_ERROR(result->Populate( @@ -860,7 +946,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size()); auto func = [&](tensorflow::gtl::ArraySlice out_index) { - ReturnT result_val = static_cast(0); + ElementwiseT result_val = static_cast(0); std::fill(lhs_index.begin(), lhs_index.end(), 0); std::fill(rhs_index.begin(), rhs_index.end(), 0); @@ -889,14 +975,21 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { out_index[output_spatial_dim] * window_dim.stride() - window_dim.padding_low() + rhs_spatial_index[ki] * window_dim.window_dilation(); - // Skip if the lhs (input) index is to be dilated. - if (undilated_index % window_dim.base_dilation() != 0) { + // Skip if the lhs (input) index is to be dilated. As an + // optimization, skip this mod if there's no dilation. 
+ if (window_dim.base_dilation() > 1 && + undilated_index % window_dim.base_dilation() != 0) { goto cnt; } - // Calculate the actual lhs (input) index after dilation. - lhs_index[input_spatial_dim] = - undilated_index / window_dim.base_dilation(); + // Calculate the actual lhs (input) index after dilation. As an + // optimization, skip this integer divide if there's no dilation. + if (window_dim.base_dilation() > 1) { + lhs_index[input_spatial_dim] = + undilated_index / window_dim.base_dilation(); + } else { + lhs_index[input_spatial_dim] = undilated_index; + } // Skip if input index is not in bound. if (!(lhs_index[input_spatial_dim] >= 0 && @@ -911,13 +1004,14 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { : rhs_spatial_index[ki]; } - result_val += lhs_literal.Get(lhs_index) * - rhs_literal.Get(rhs_index); + result_val += + static_cast(lhs_literal.Get(lhs_index)) * + static_cast(rhs_literal.Get(rhs_index)); } cnt : {} } while (IndexUtil::BumpIndices(window_shape, &rhs_spatial_index)); - return result_val; + return static_cast(result_val); }; auto result = Literal::CreateFromShape(result_shape); @@ -934,61 +1028,126 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { CHECK(ShapeUtil::IsArray(lhs->shape())); CHECK(ShapeUtil::IsArray(rhs->shape())); - // Dot only supports operands of rank 1 and 2. - const auto dot_rank = ShapeUtil::Rank(dot->shape()); + const auto& dnums = dot->dot_dimension_numbers(); + const auto lhs_rank = ShapeUtil::Rank(lhs->shape()); const auto rhs_rank = ShapeUtil::Rank(rhs->shape()); - CHECK(lhs_rank > 0 && lhs_rank <= 2); - CHECK(rhs_rank > 0 && rhs_rank <= 2); - CHECK_EQ(dot_rank, lhs_rank + rhs_rank - 2); CHECK(ShapeUtil::SameElementType(lhs->shape(), rhs->shape())); CHECK(ShapeUtil::SameElementType(lhs->shape(), dot->shape())); - // Check contracted dimensions are the same. - // - // Determine the index of the contracted dimensions for input tensors. 
- // dimensions -1 of lhs and dimension 0 of rhs are contracted. - const int64 lhs_contracted_dimension = - ShapeUtil::GetDimensionNumber(lhs->shape(), -1); - const int64 rhs_contracted_dimension = 0; - CHECK_EQ(lhs->shape().dimensions(lhs_contracted_dimension), - rhs->shape().dimensions(rhs_contracted_dimension)) + // There must be 1 and only 1 Contracting dimension for lhs and rhs. + CHECK_EQ(dnums.lhs_contracting_dimensions_size(), 1); + CHECK_EQ(dnums.rhs_contracting_dimensions_size(), 1); + const int64 lhs_contracting_dimension = dnums.lhs_contracting_dimensions(0); + const int64 rhs_contracting_dimension = dnums.rhs_contracting_dimensions(0); + // Contracted dimension sizes must be the same. + CHECK_EQ(lhs->shape().dimensions(lhs_contracting_dimension), + rhs->shape().dimensions(rhs_contracting_dimension)) << "lhs contracted dimension: " - << lhs->shape().dimensions(lhs_contracted_dimension) + << lhs->shape().dimensions(lhs_contracting_dimension) << " rhs contracted dimension: " - << rhs->shape().dimensions(rhs_contracted_dimension); + << rhs->shape().dimensions(rhs_contracting_dimension); const int64 contracted_dimension_size = - lhs->shape().dimensions(lhs_contracted_dimension); + lhs->shape().dimensions(lhs_contracting_dimension); const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); auto result = Literal::CreateFromShape(dot->shape()); - TF_RETURN_IF_ERROR(result->Populate( - [&](tensorflow::gtl::ArraySlice multi_index) { - ReturnT result_val = static_cast(0); - std::vector lhs_index(lhs_rank, 0); - std::vector rhs_index(rhs_rank, 0); - // Set index for non-contracted dimension for lhs and rhs. 
- if (lhs_rank > 1) { - lhs_index[0] = multi_index[0]; + CHECK_EQ(dnums.lhs_batch_dimensions_size(), + dnums.rhs_batch_dimensions_size()); + + std::vector lhs_non_contracting_dims; + for (int64 i = 0; i < lhs_rank; i++) { + if (i != lhs_contracting_dimension) { + lhs_non_contracting_dims.push_back(i); + } + } + + std::vector rhs_non_batch_non_contracting_dims; + tensorflow::gtl::FlatSet batch_dims_set( + dnums.rhs_batch_dimensions().begin(), + dnums.rhs_batch_dimensions().end()); + for (int64 i = 0; i < rhs_rank; i++) { + if (i != rhs_contracting_dimension && batch_dims_set.count(i) == 0) { + rhs_non_batch_non_contracting_dims.push_back(i); + } + } + + const int64 batch_dim_size = dnums.lhs_batch_dimensions_size(); + const int64 lhs_non_contracting_size = lhs_non_contracting_dims.size(); + + DimensionVector lhs_index(lhs_rank); + DimensionVector rhs_index(rhs_rank); + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice result_index) { + ElementwiseT result_val = static_cast(0); + + // Find the corresponding non-contracting indices for lhs and rhs. + // + // For `result_index`, its batch dimension, if it exists, will be at the + // same dimension as the batch dimension of lhs and rhs. More + // specifically: + // - For lhs, the non-contracting dimensions, including the batch + // dimension have the same index as the `result_index`. + // - For rhs, the batch dimension is set separately from other + // non-contracting dimensions, since these other non-contracting + // dimensions in rhs follow the non-contracting dimensions of lhs in + // the resulting index.
+ // + // As an example, for a resulting index: + // result_index [result_batch, result_x, result_y] + // the effective lhs and rhs indices are: + // lhs [result_batch, lhs_non_contracting_dim, contracting_dim] + // rhs [result_batch, contracting_dim, rhs_non_contracting_dim] + // `result_x` is only affected by the lhs_non_contracting_dim and + // likewise `result_y` only depends on rhs_non_contracting_dim. + // + // so we can look up the lhs and rhs indices by: + // + // lhs: + // batch index is the same as `result_batch`. + // non-contracting dimension is the same as + // result_index[lhs_non_contracting_dim] + // rhs: + // batch index: the same as `result_batch`. + // non-contracting dimension index: *not* the same as + // result_index[rhs_non_contracting_dim], since the + // non-contracting dimensions of lhs are included in the + // result_index first. Instead, the non_contracting_dim of rhs must + // be calculated as follows: + // lhs_non_contracting_dimensions_size + + // (rhs_non_batch_non_contracting_dim - batch_dim_size) - 1 + // + // Note that (rhs_non_batch_non_contracting_dim - batch_dim_size) is + // the index offset to the result_index that only depends on + // the non_batch and non-contracting dimensions of rhs. -1 at the + // end translates size to index. + for (auto i : lhs_non_contracting_dims) { + lhs_index[i] = result_index[i]; + } + for (auto i : dnums.rhs_batch_dimensions()) { + rhs_index[i] = result_index[i]; } - if (rhs_rank > 1) { - rhs_index[1] = multi_index[multi_index.size() - 1]; + for (auto i : rhs_non_batch_non_contracting_dims) { + const int64 rhs_non_batch_non_contracting_dim = + lhs_non_contracting_size + (i - batch_dim_size) - 1; + rhs_index[i] = result_index[rhs_non_batch_non_contracting_dim]; } // Accumulates resulting product along the contracted dimension.
for (int64 i = 0; i < contracted_dimension_size; ++i) { - lhs_index[lhs_contracted_dimension] = i; - rhs_index[rhs_contracted_dimension] = i; + lhs_index[lhs_contracting_dimension] = i; + rhs_index[rhs_contracting_dimension] = i; - result_val += lhs_literal.Get(lhs_index) * - rhs_literal.Get(rhs_index); + result_val += + static_cast(lhs_literal.Get(lhs_index)) * + static_cast(rhs_literal.Get(rhs_index)); } - return result_val; + return static_cast(result_val); })); parent_->evaluated_[dot] = std::move(result); @@ -1021,7 +1180,8 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return scalar; })); - auto evaluated_operand = parent_->GetEvaluatedLiteralFor(pad->operand(0)); + const Literal& evaluated_operand = + parent_->GetEvaluatedLiteralFor(pad->operand(0)); std::vector input_index(ShapeUtil::Rank(evaluated_operand.shape()), 0); @@ -1174,6 +1334,97 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + template + StatusOr> MapImpl(HloInstruction* map) { + auto operands = map->operands(); + HloComputation* computation = map->to_apply(); + + auto result = Literal::CreateFromShape(map->shape()); + + HloEvaluator embedded_evaluator; + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + std::vector> arg_literals; + arg_literals.reserve(operands.size()); + + // Construct scalar literal parameters to be passed to the map + // computation. + for (auto operand : operands) { + const Literal& arg_literal = + parent_->GetEvaluatedLiteralFor(operand); + + auto curr_val = arg_literal.Get(multi_index); + auto curr_val_literal = Literal::CreateR0(curr_val); + + arg_literals.push_back(std::move(curr_val_literal)); + } + + std::unique_ptr computed_result = + embedded_evaluator + .Evaluate>(*computation, + arg_literals) + .ConsumeValueOrDie(); + // Clear visit states so that the we can use the evaluate again on + // the same computation. 
+ embedded_evaluator.ResetVisitStates(); + + return computed_result->Get({}); + })); + return std::move(result); + } + + Status HandleMap(HloInstruction* map) override { + switch (map->operand(0)->shape().element_type()) { + case PRED: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case U8: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case U32: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case U64: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case S8: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case S32: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case S64: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case F32: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case F64: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + case C64: { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[map], MapImpl(map)); + break; + } + default: + LOG(FATAL) << "HandleMap: unhandled primitive type for " + "input operand: " + << PrimitiveType_Name( + map->operand(0)->shape().element_type()); + } + + return Status::OK(); + } + Status HandleReduce(HloInstruction* reduce) override { auto arg = reduce->operand(0); auto init_value = reduce->operand(1); @@ -1220,6 +1471,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } } + HloEvaluator embedded_evaluator; // For each resulting dimension, calculate and assign computed value. 
TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice multi_index) { @@ -1239,13 +1491,12 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { std::vector args = {curr_val_literal.get(), result_val_literal.get()}; - // We need a new visitor for each evaluation, so that the same - // computation can be visited more than once (with different - // inputs). - HloEvaluator embedded_evaluator; std::unique_ptr computed_result = - embedded_evaluator.Evaluate(*function, args) + embedded_evaluator.Evaluate(*function, args) .ConsumeValueOrDie(); + // Clear visit states so that the we can use the evaluate again on + // the same computation. + embedded_evaluator.ResetVisitStates(); // Assign computed result to result_val. result_val = computed_result->Get({}); @@ -1263,6 +1514,111 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override { + auto operand = select_and_scatter->operand(0); + auto source = select_and_scatter->operand(1); + const Window& window = select_and_scatter->window(); + + const Literal& init_literal = + parent_->GetEvaluatedLiteralFor(select_and_scatter->operand(2)); + TF_RET_CHECK(ShapeUtil::IsScalar(init_literal.shape())); + auto init_scalar = init_literal.Get({}); + + auto result = Literal::CreateFromShape(select_and_scatter->shape()); + + // Initialize result array with the init value. 
+ TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice output_index) { + return init_scalar; + })); + + std::vector window_dimension_sizes; + for (const auto& window_dimension : window.dimensions()) { + window_dimension_sizes.push_back(window_dimension.size()); + } + const Shape window_shape = ShapeUtil::MakeShape( + operand->shape().element_type(), window_dimension_sizes); + + HloComputation* select = select_and_scatter->select(); + HloComputation* scatter = select_and_scatter->scatter(); + + const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); + const Literal& source_literal = parent_->GetEvaluatedLiteralFor(source); + + int64 rank = ShapeUtil::Rank(operand_literal.shape()); + + HloEvaluator embedded_evaluator; + DimensionVector source_index(rank); + + std::fill(source_index.begin(), source_index.end(), 0); + do { + // For each element in `source`, we place a window in `operand`. For each + // window placement, we iterate inside the window twice: + // + // 1. Find the selected index by applying `select` function to all + // elements. E.g., If the `select` function is GreaterEqual, the first + // iteration through the window finds the biggest value and returns its + // index. + // + // 2. Using the selected index, scatter value from `source` to result. We + // do this by iterating through the window, and compare each index with + // the selected index. 
+ tensorflow::gtl::optional selected_val; + tensorflow::gtl::optional> selected_index; + + IterateThroughWindow( + window_shape, window, operand_literal.shape(), source_index, + [&](const std::vector& operand_index) { + auto curr_val = operand_literal.Get(operand_index); + if (!selected_val) { + selected_val = curr_val; + selected_index = operand_index; + } + const auto curr_val_literal = Literal::CreateR0(curr_val); + const auto selected_val_literal = + Literal::CreateR0(*selected_val); + + const std::vector args = { + curr_val_literal.get(), selected_val_literal.get()}; + std::unique_ptr computed_result = + embedded_evaluator.Evaluate(*select, args) + .ConsumeValueOrDie(); + bool selected = computed_result->Get({}); + if (selected) { + selected_val = curr_val; + selected_index = operand_index; + } + embedded_evaluator.ResetVisitStates(); + }); + + IterateThroughWindow( + window_shape, window, operand_literal.shape(), source_index, + [&](const std::vector& operand_index) { + if (std::equal(operand_index.begin(), operand_index.end(), + selected_index->begin())) { + auto source = source_literal.Get(source_index); + auto scattered = result->Get(operand_index); + const auto source_literal = Literal::CreateR0(source); + const auto scattered_literal = + Literal::CreateR0(scattered); + + const std::vector args = { + source_literal.get(), scattered_literal.get()}; + std::unique_ptr computed_result = + embedded_evaluator.Evaluate(*scatter, args) + .ConsumeValueOrDie(); + result->Set(operand_index, computed_result->Get({})); + // Clear visit states so that the we can use the evaluator again + // on the same computation. 
+ embedded_evaluator.ResetVisitStates(); + } + }); + } while (IndexUtil::BumpIndices(source->shape(), &source_index)); + + parent_->evaluated_[select_and_scatter] = std::move(result); + return Status::OK(); + } + Status HandleReduceWindow(HloInstruction* reduce_window) override { auto operand = reduce_window->operand(0); const Window& window = reduce_window->window(); @@ -1302,6 +1658,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { DimensionVector window_index(window.dimensions_size()); DimensionVector operand_index(ShapeUtil::Rank(operand_literal.shape())); + HloEvaluator embedded_evaluator; // For each resulting dimension, calculate and assign computed value. TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice output_index) { @@ -1310,39 +1667,28 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { std::fill(window_index.begin(), window_index.end(), 0); std::fill(operand_index.begin(), operand_index.end(), 0); - do { - // Set curr_val to 0 if out of bound (padded). - ReturnT curr_val = static_cast(0); - bool out_of_bound = false; - for (int i = 0; i < operand_index.size(); ++i) { - operand_index[i] = - output_index[i] * window.dimensions(i).stride() + - window_index[i] - window.dimensions(i).padding_low(); - if (operand_index[i] < 0 || - operand_index[i] >= operand_literal.shape().dimensions(i)) { - out_of_bound = true; - break; - } - } - if (!out_of_bound) { - curr_val = operand_literal.Get(operand_index); - } - // Evaluate computation with specified literal operands. - const auto curr_val_literal = Literal::CreateR0(curr_val); - const auto result_val_literal = - Literal::CreateR0(result_val); - const std::vector args = {curr_val_literal.get(), - result_val_literal.get()}; - // We need a new visitor for each evaluation, so that the same - // computation can be visited more than once (with different - // inputs). 
- HloEvaluator embedded_evaluator; - std::unique_ptr computed_result = - embedded_evaluator.Evaluate(*function, args) - .ConsumeValueOrDie(); - - result_val = computed_result->Get({}); - } while (IndexUtil::BumpIndices(window_shape, &window_index)); + IterateThroughWindow( + window_shape, window, operand_literal.shape(), output_index, + [&](const std::vector& operand_index) { + auto curr_val = operand_literal.Get(operand_index); + + // Evaluate computation with specified literal operands. + const auto curr_val_literal = + Literal::CreateR0(curr_val); + const auto result_val_literal = + Literal::CreateR0(result_val); + const std::vector args = { + curr_val_literal.get(), result_val_literal.get()}; + std::unique_ptr computed_result = + embedded_evaluator.Evaluate(*function, args) + .ConsumeValueOrDie(); + + // Clear visit states so that the we can use the evaluate again + // on the same computation. + embedded_evaluator.ResetVisitStates(); + + result_val = computed_result->Get({}); + }); return result_val; })); @@ -1364,7 +1710,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { << ShapeUtil::HumanString(inferred_return_shape); const int64 rank = ShapeUtil::Rank(operand->shape()); - auto operand_literal = parent_->GetEvaluatedLiteralFor(operand); + const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(operand); auto func = [&](tensorflow::gtl::ArraySlice out_index) { DimensionVector operand_index(rank); for (int64 i = 0; i < rank; ++i) { @@ -1385,7 +1731,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { NativeT>::value>::type* = nullptr> Status HandleSin(HloInstruction* sin) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[sin], - ElementWiseUnaryOp(sin, [](ReturnT elem_operand) { + ElementWiseUnaryOp(sin, [](ElementwiseT elem_operand) { return std::sin(elem_operand); })); return Status::OK(); @@ -1400,14 +1746,14 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } Status 
HandleSin(HloInstruction* sin) override { - return HandleSin(sin); + return HandleSin(sin); } template ::value>::type* = nullptr> Status HandleCos(HloInstruction* cos) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[cos], - ElementWiseUnaryOp(cos, [](ReturnT elem_operand) { + ElementWiseUnaryOp(cos, [](ElementwiseT elem_operand) { return std::cos(elem_operand); })); return Status::OK(); @@ -1422,7 +1768,116 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } Status HandleCos(HloInstruction* cos) override { - return HandleCos(cos); + return HandleCos(cos); + } + + template ::value>::type* = nullptr> + Status HandleReducePrecision(HloInstruction* reduce_precision) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[reduce_precision], + ElementWiseUnaryOp(reduce_precision, [reduce_precision]( + ElementwiseT elem) { + uint32_t value_as_int = tensorflow::bit_cast(elem); + const uint32_t mantissa_bits = reduce_precision->mantissa_bits(); + const uint32_t exponent_bits = reduce_precision->exponent_bits(); + + // Code is based on the CPU/GPU implementation in LLVM-emitting code. + // + // Bits in float type: + // mantissa : bits [0:22] + // exponent : bits [23:30] + // sign : bits [31] + if (mantissa_bits < 23) { + const uint32_t last_mantissa_bit_mask = 1u << (23 - mantissa_bits); + + // Compute rounding bias for round-to-nearest with ties to even. + // This is equal to a base value of 0111... plus one bit if the last + // remaining mantissa bit is 1. + const uint32_t base_rounding_bias = + (last_mantissa_bit_mask >> 1) - 1; + const uint32_t x_last_mantissa_bit = + (value_as_int & last_mantissa_bit_mask) >> (23 - mantissa_bits); + const uint32_t x_rounding_bias = + x_last_mantissa_bit + base_rounding_bias; + + // Add rounding bias, and mask out truncated bits. Note that the + // case where adding the rounding bias overflows into the exponent + // bits is correct; the non-masked mantissa bits will all be zero, + // and the exponent will be incremented by one. 
+ const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1); + value_as_int = value_as_int + x_rounding_bias; + value_as_int = value_as_int & truncation_mask; + } + if (exponent_bits < 8) { + // Masks for f32 values. + const uint32_t f32_sign_bit_mask = 1u << 31; + const uint32_t f32_exp_bits_mask = 0xffu << 23; + + // An exponent of 2^(n-1)-1 -- that is, 0111... with the zero in the + // most-significant bit -- is equal to 1.0f for all exponent sizes. + // Adding 2^(n-1)-1 to this gives us the highest non-infinite + // exponent for a bit-size of n, and subtracting 2^(n-1)-1 from + // this gives us the lowest exponent (corresponding to 0.0f). + // + // Thus, the f32 exponent corresponding to the highest non-infinite + // exponent for a bit size of n is (2^7-1) + (2^(n-1)-1), and the f32 + // exponent corresponding to the lowest exponent for a bit size of n + // is (2^7-1) - (2^(n-1)-1). + // + // Note that we have already checked that exponent_bits >= 1. + const uint32_t f32_exponent_bias = (1 << 7) - 1; + const uint32_t reduced_exponent_bias = + (1 << (exponent_bits - 1)) - 1; + const uint32_t reduced_max_exponent = + f32_exponent_bias + reduced_exponent_bias; + const uint32_t reduced_min_exponent = + f32_exponent_bias - reduced_exponent_bias; + + // Do we overflow or underflow? + const uint32_t x_exponent = value_as_int & f32_exp_bits_mask; + const bool x_overflows = x_exponent > (reduced_max_exponent << 23); + const bool x_underflows = + x_exponent <= (reduced_min_exponent << 23); + + // Compute appropriately-signed values of zero and infinity. + const uint32_t x_signed_zero = value_as_int & f32_sign_bit_mask; + const uint32_t x_signed_inf = x_signed_zero | f32_exp_bits_mask; + + // Force to zero or infinity if overflow or underflow. (Note that + // this truncates all denormal values to zero, rather than rounding + // them.) + value_as_int = x_overflows ? x_signed_inf : value_as_int; + value_as_int = x_underflows ?
x_signed_zero : value_as_int; + } + + float reduced_result = tensorflow::bit_cast(value_as_int); + if (std::isnan(elem)) { + reduced_result = mantissa_bits > 0 + ? elem + : std::numeric_limits::infinity(); + } + return reduced_result; + })); + return Status::OK(); + } + + template ::value>::type* = nullptr> + Status HandleReducePrecision(HloInstruction* reduce_precision) { + return InvalidArgument("Double not supported for reduce precision"); + } + + template < + typename NativeT, + typename std::enable_if::value || + is_complex_t::value>::type* = nullptr> + Status HandleReducePrecision(HloInstruction* reduce_precision) { + return InvalidArgument("Unsupported type for reduce precision"); + } + + Status HandleReducePrecision(HloInstruction* reduce_precision) override { + return HandleReducePrecision(reduce_precision); } private: @@ -1430,8 +1885,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { StatusOr> DynamicSlice( const Literal& operand_literal, const Literal& start_indices_literal, const Shape& result_shape) { - const auto& start_indices_typed = - start_indices_literal.GetArraySlice(); + auto start_indices_typed = start_indices_literal.data(); std::vector start(start_indices_typed.begin(), start_indices_typed.end()); @@ -1459,12 +1913,11 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { StatusOr> DynamicUpdateSlice( const Literal& operand_literal, const Literal& update_literal, const Literal& start_indices_literal) { - const auto& start_indices_typed = - start_indices_literal.GetArraySlice(); + auto start_indices_typed = start_indices_literal.data(); const std::vector start(start_indices_typed.begin(), start_indices_typed.end()); - auto result = MakeUnique(operand_literal); + auto result = operand_literal.CloneToUnique(); std::vector result_index(ShapeUtil::Rank(result->shape()), 0); auto func = [&](const std::vector& update_index) { @@ -1487,22 +1940,27 @@ class HloEvaluator::TypedVisitor : public 
DfsHloVisitorWithDefault { StatusOr> ElementWiseUnaryOp( HloInstruction* instruction, - const std::function& unary_op) { + const std::function& unary_op) { const Literal& operand_literal = parent_->GetEvaluatedLiteralFor(instruction->operand(0)); - return ElementWiseUnaryOpImpl(instruction, unary_op, - operand_literal); + TF_ASSIGN_OR_RETURN( + auto result_literal, + (ElementWiseUnaryOpImpl( + instruction, ConvertUnaryFunction(unary_op), operand_literal))); + + return std::move(result_literal); } StatusOr> ElementWiseBinaryOp( HloInstruction* instruction, - const std::function& binary_op) { + const std::function& + binary_op) { const auto shape = instruction->shape(); const auto* lhs = instruction->operand(0); const auto* rhs = instruction->operand(1); - // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is - // removed. + // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast + // is removed. if (!(ShapeUtil::SameDimensions(shape, rhs->shape()) && ShapeUtil::SameDimensions(lhs->shape(), rhs->shape()))) { return Unimplemented( @@ -1520,14 +1978,15 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { TF_RETURN_IF_ERROR(result->Populate( [&](tensorflow::gtl::ArraySlice multi_index) { - return binary_op(lhs_literal.Get(multi_index), - rhs_literal.Get(multi_index)); + return ConvertBinaryFunction(binary_op)( + lhs_literal.Get(multi_index), + rhs_literal.Get(multi_index)); })); return std::move(result); } template - StatusOr> ElementWiseTernaryOp( + StatusOr> ElementwiseTernaryOp( HloInstruction* instruction, const std::function& ternary_op) { const auto shape = instruction->shape(); @@ -1589,9 +2048,11 @@ HloEvaluator::HloEvaluator() { typed_visitors_[F64] = MakeUnique>(this); typed_visitors_[C64] = MakeUnique>(this); - typed_visitors_[BF16] = MakeUnique([](HloInstruction*) { - return Unimplemented("HloEvaluator: unhandled primitive type: BF16."); - }); + // Most of the evaluator computations we use don't 
support BF16 (e.g., + // std::ceil, std::tanh). To make evaluator work with BF16, we set all + // elementwise computations to be done in F32 and do BF16<->F32 conversion + // around the input and the output of the computations. + typed_visitors_[BF16] = MakeUnique>(this); typed_visitors_[TUPLE] = MakeUnique([](HloInstruction*) { return Unimplemented("HloEvaluator: unhandled primitive type: TUPLE."); }); @@ -1600,41 +2061,53 @@ HloEvaluator::HloEvaluator() { }); } +template StatusOr> HloEvaluator::Evaluate( const HloModule& module, - tensorflow::gtl::ArraySlice arg_literals) { + tensorflow::gtl::ArraySlice arg_literals) { XLA_VLOG_LINES(2, "HloEvaluator::Evaluate module:\n" + module.ToString()); - arg_literals_ = arg_literals; evaluated_.clear(); + arg_literals_.clear(); + for (const auto& literal_ptr : arg_literals) { + arg_literals_.push_back(&*literal_ptr); + } TF_RETURN_IF_ERROR(module.entry_computation()->Accept(this)); - return MakeUnique( - GetEvaluatedLiteralFor(module.entry_computation()->root_instruction())); + return GetEvaluatedLiteralFor(module.entry_computation()->root_instruction()) + .CloneToUnique(); } +template StatusOr> HloEvaluator::Evaluate( const HloComputation& computation, - tensorflow::gtl::ArraySlice arg_literals) { + tensorflow::gtl::ArraySlice arg_literals) { XLA_VLOG_LINES( 2, "HloEvaluator::Evaluate computation:\n" + computation.ToString()); - arg_literals_ = arg_literals; + evaluated_.clear(); + arg_literals_.clear(); + for (const auto& literal_ptr : arg_literals) { + arg_literals_.push_back(&*literal_ptr); + } TF_RETURN_IF_ERROR(computation.Accept(this)); - return MakeUnique( - GetEvaluatedLiteralFor(computation.root_instruction())); + return GetEvaluatedLiteralFor(computation.root_instruction()).CloneToUnique(); } +template StatusOr> HloEvaluator::Evaluate( HloInstruction* instruction, - tensorflow::gtl::ArraySlice operands) { + tensorflow::gtl::ArraySlice arg_literals) { 
TF_RET_CHECK(hlo_query::AllOperandsAreParametersOrConstants(*instruction)); TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape())); - arg_literals_ = operands; evaluated_.clear(); + arg_literals_.clear(); + for (const auto& literal_ptr : arg_literals) { + arg_literals_.push_back(&*literal_ptr); + } // Evaluate operands of Parameter type against the input literals which // caches the evaluated literal results. @@ -1645,14 +2118,14 @@ StatusOr> HloEvaluator::Evaluate( << input_literal->ToString(); TF_RET_CHECK(ShapeUtil::Equal(operand->shape(), input_literal->shape())); - evaluated_[operand] = MakeUnique(*input_literal); + evaluated_[operand] = input_literal->CloneToUnique(); } } TF_RETURN_IF_ERROR(Preprocess(instruction)); TF_RETURN_IF_ERROR(instruction->Visit(this)); TF_RETURN_IF_ERROR(Postprocess(instruction)); - return MakeUnique(GetEvaluatedLiteralFor(instruction)); + return GetEvaluatedLiteralFor(instruction).CloneToUnique(); } StatusOr> HloEvaluator::Evaluate( @@ -1673,7 +2146,7 @@ StatusOr> HloEvaluator::Evaluate( TF_RETURN_IF_ERROR(Preprocess(instruction)); TF_RETURN_IF_ERROR(instruction->Visit(this)); TF_RETURN_IF_ERROR(Postprocess(instruction)); - return MakeUnique(GetEvaluatedLiteralFor(instruction)); + return GetEvaluatedLiteralFor(instruction).CloneToUnique(); } std::unique_ptr HloEvaluator::TryEvaluate( @@ -1722,11 +2195,15 @@ StatusOr> HloEvaluator::EvaluateWithSubstitutions( } Status HloEvaluator::HandleParameter(HloInstruction* parameter) { + CHECK_LT(parameter->parameter_number(), arg_literals_.size()); const Literal* input_literal = arg_literals_[parameter->parameter_number()]; VLOG(2) << "Parameter evaluated to: " << input_literal->ToString(); - DCHECK(ShapeUtil::Equal(parameter->shape(), input_literal->shape())); + DCHECK(ShapeUtil::Equal(parameter->shape(), input_literal->shape())) + << "parameter shape is: " << ShapeUtil::HumanString(parameter->shape()) + << ", but input literal shape is: " + << 
ShapeUtil::HumanString(input_literal->shape()); - evaluated_[parameter] = MakeUnique(*input_literal); + evaluated_[parameter] = input_literal->CloneToUnique(); return Status::OK(); } @@ -1749,8 +2226,8 @@ Status HloEvaluator::HandleTranspose(HloInstruction* transpose) { Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) { tensorflow::gtl::ArraySlice operands( concatenate->operands()); - // The result concatenate dimension is going to be the sum of all concatenate - // dimensions of the operands taking part of the operation. + // The result concatenate dimension is going to be the sum of all + // concatenate dimensions of the operands taking part of the operation. const Shape& reference_shape = operands[0]->shape(); CHECK(!ShapeUtil::IsTuple(reference_shape)); const int64 rank = ShapeUtil::Rank(reference_shape); @@ -1777,7 +2254,7 @@ Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) { for (auto operand : operands) { const Shape& operand_shape = operand->shape(); - TF_RETURN_IF_ERROR(result_literal->Copy( + TF_RETURN_IF_ERROR(result_literal->CopySliceFrom( GetEvaluatedLiteralFor(operand), source_indices, dest_indices, AsInt64Slice(operand_shape.dimensions()))); dest_indices[concat_dim] += @@ -1935,16 +2412,17 @@ Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) { const Literal& operand_tuple_literal = GetEvaluatedLiteralFor(operand); - evaluated_[get_tuple_element] = - MakeUnique(operand_tuple_literal.tuple_literals(index)); - - return Status::OK(); + evaluated_[get_tuple_element] = MakeUnique( + ShapeUtil::GetTupleElementShape(operand->shape(), index)); + return evaluated_[get_tuple_element]->CopyFrom(operand_tuple_literal, + /*dest_shape_index=*/{}, + /*src_shape_index=*/{index}); } Status HloEvaluator::HandleCopy(HloInstruction* copy) { TF_RET_CHECK(ShapeUtil::Compatible(copy->shape(), copy->operand(0)->shape())); - auto result = MakeUnique(GetEvaluatedLiteralFor(copy->operand(0))); + auto result = 
GetEvaluatedLiteralFor(copy->operand(0)).CloneToUnique(); evaluated_[copy] = std::move(result); return Status::OK(); } @@ -1960,4 +2438,30 @@ Status HloEvaluator::Postprocess(HloInstruction* hlo) { return Status::OK(); } +// Explicit instantiation of templatized Evaluate* methods. +// +template StatusOr> HloEvaluator::Evaluate< + const Literal*>(const HloModule& module, + tensorflow::gtl::ArraySlice arg_literals); +template StatusOr> +HloEvaluator::Evaluate>( + const HloModule& module, + tensorflow::gtl::ArraySlice> arg_literals); + +template StatusOr> HloEvaluator::Evaluate< + const Literal*>(const HloComputation& computation, + tensorflow::gtl::ArraySlice arg_literals); +template StatusOr> +HloEvaluator::Evaluate>( + const HloComputation& computation, + tensorflow::gtl::ArraySlice> arg_literals); + +template StatusOr> HloEvaluator::Evaluate< + const Literal*>(HloInstruction* instruction, + tensorflow::gtl::ArraySlice arg_literals); +template StatusOr> +HloEvaluator::Evaluate>( + HloInstruction* instruction, + tensorflow::gtl::ArraySlice> arg_literals); + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 7557aaa2484d184555411a79d8dce2c9241427b0..3b2b697e492a78a06a4e5ae6bf056ff8676f2ff5 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_ #include @@ -42,9 +42,12 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // Precondition: The indices of arg_literals correspond to the parameter // numbers of the HLO parameters in the computation. See comment below for an // example. + // `LiteralPtr` accepts either std::unique_ptr or const Literal* + // type. + template StatusOr> Evaluate( const HloModule& module, - tensorflow::gtl::ArraySlice arg_literals); + tensorflow::gtl::ArraySlice arg_literals); // Evaluates an HLO computation and an array of pointers to literals. // Returns the evaluated result as a literal if successful. @@ -62,9 +65,12 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // where Parameter0 has parameter_number 0 and Parameter1 has parameter_number // 1 in this computation. The input literals array will then have its first // literal map to Parameter0 and the second map to Parameter1. + // `LiteralPtr` accepts either std::unique_ptr or const Literal* + // type. + template StatusOr> Evaluate( const HloComputation& computation, - tensorflow::gtl::ArraySlice arg_literals); + tensorflow::gtl::ArraySlice arg_literals); // Evaluates a single HLO instruction and an array of pointers to literals. // Return the evaluated result as literal if successful. @@ -72,10 +78,12 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // 1. argument literals correspond to the input instruction's parameters in // their post-ordering. // 2. the instruction's operands must be of either Parameter or Constant type. - // TODO(b/35950897): implement more ops other than element-wise ops. + // `LiteralPtr` accepts either std::unique_ptr or const Literal* + // type. 
+ template StatusOr> Evaluate( HloInstruction* instruction, - tensorflow::gtl::ArraySlice arg_literals); + tensorflow::gtl::ArraySlice arg_literals); // Evaluates a single HLO instruction with constant operands. // Returns the evaluated result as literal if successful. @@ -100,12 +108,16 @@ class HloEvaluator : public DfsHloVisitorWithDefault { protected: // Templated DfsHloVisitor. Typically ReturnT here indicates the resulting // literal type of each evaluated Handle* method of a TypedVisitor. - // There are however a few notable exceptions to this is rule, notably: + // There are however a few notable exceptions to this rule, notably: // - HandleCompare and HandleIsFinite: where the resulting literal type is // always boolean. // These operations are handled outside of the parent HloEvaluator handlers // instead of from within TypedVisitor. - template + // + // Type params: + // - ReturnT: The type of input and output of each operation. + // - ElementwiseT: The type in which internal computation are done. + template class TypedVisitor; // Wraps around instruction handling to infer types before dispatching to @@ -134,6 +146,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault { Status HandleIsFinite(HloInstruction* is_finite) override; Status HandleCompare(HloInstruction* compare) override; + Status HandleTuple(HloInstruction* tuple) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; @@ -167,17 +180,19 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // TODO(b/35950897): have better memory management here to free instructions // that are no longer a parent for any other subsequent instruction in // post-orderring. + // Must be cleared for each evaluation. tensorflow::gtl::FlatMap> evaluated_; - // Stores input literals, assuming they are in post-order. Literals are not - // owned by this class, and they must outlive the lifetime of the instance of - // this class. 
- tensorflow::gtl::ArraySlice arg_literals_; + // Caches pointers to input literals, assuming they are in post-order. + // Literals are not owned by this class, and they must outlive the lifetime of + // each invocation to the Evaluate* method. + // Must be cleared for each evaluation. + std::vector arg_literals_; TF_DISALLOW_COPY_AND_ASSIGN(HloEvaluator); }; } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_EVALUATOR_H_ diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index b2c4351896764fa8683e91396f526d97ba208df6..97765d65909cee192f65069777f8f195081603b2 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -25,8 +25,10 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_element_type_converter.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" @@ -35,15 +37,33 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" namespace xla { namespace { -class HloEvaluatorTest : public HloVerifiedTestBase { +static std::array use_bf16_params{true, false}; + +class HloEvaluatorTest : public ::testing::WithParamInterface, + public HloVerifiedTestBase { protected: - HloEvaluatorTest() { evaluator_ = MakeUnique(); } + HloEvaluatorTest() : use_bfloat16_(GetParam()) { + evaluator_ = MakeUnique(); + } + + std::unique_ptr Evaluate( + tensorflow::gtl::ArraySlice arg_literals = {}) { + if (use_bfloat16_) { + // In BF16 mode, we convert all F32 type to BF16 and evaluate the module. + auto type_converter = HloElementTypeConverter(F32, BF16); + type_converter.Run(&module()).ValueOrDie(); + } + return evaluator_->Evaluate(*module().entry_computation(), arg_literals) + .ConsumeValueOrDie(); + } std::unique_ptr evaluator_; @@ -52,12 +72,11 @@ class HloEvaluatorTest : public HloVerifiedTestBase { HloComputation::Builder b(TestName()); auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(input))); - auto instruction = b.AddInstruction( + b.AddInstruction( HloInstruction::CreateUnary(expected->shape(), opcode, c1)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto element_type = expected->shape().element_type(); if (element_type == F32 || element_type == F64) { @@ -74,20 +93,24 @@ class HloEvaluatorTest : public HloVerifiedTestBase { HloComputation::Builder b(TestName()); auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs))); auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs))); - auto instruction = 
b.AddInstruction( + b.AddInstruction( HloInstruction::CreateBinary(expected->shape(), opcode, c1, c2)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); LiteralTestUtil::ExpectEqual(*expected, *result); } + + bool use_bfloat16_; }; +#define XLA_TYPED_TEST_P(test_case_name, test_name, test_type1) \ + TEST_P(test_case_name, test_name) + // Verifies that HloEvaluator evaluates a HLO instruction that performs clamp // with 3 operands. -TEST_F(HloEvaluatorTest, DoesClamp) { +TEST_P(HloEvaluatorTest, DoesClamp) { auto low = Literal::CreateR2({{0.f, 2.f}, {2.f, 4.f}}); auto value = Literal::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); auto high = Literal::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); @@ -97,19 +120,18 @@ TEST_F(HloEvaluatorTest, DoesClamp) { auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low))); auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value))); auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(high))); - auto instruction = b.AddInstruction( + b.AddInstruction( HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR2({{0, 4}, {2, 4}}); LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { +TEST_P(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { auto low = Literal::CreateR0(0.f); auto value = Literal::CreateR2({{-1.f, 0.f}, {1.f, 2.f}}); auto high = Literal::CreateR0(1.f); @@ -119,12 +141,11 @@ TEST_F(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { auto c1 = b.AddInstruction(HloInstruction::CreateConstant(std::move(low))); auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value))); auto c3 
= b.AddInstruction(HloInstruction::CreateConstant(std::move(high))); - auto instruction = b.AddInstruction( + b.AddInstruction( HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR2({{0, 0}, {1, 1}}); @@ -133,7 +154,7 @@ TEST_F(HloEvaluatorTest, DISABLED_DoesClampSpecialBroadcast) { // Verifies that HloEvaluator evaluates a HLO instruction that performs select // with 3 operands. -TEST_F(HloEvaluatorTest, DoesSelect) { +TEST_P(HloEvaluatorTest, DoesSelect) { auto pred = Literal::CreateR2({{true, false}, {false, true}}); auto on_true = Literal::CreateR2({{2.f, 4.f}, {4.f, 4.f}}); auto on_false = Literal::CreateR2({{0.f, 5.f}, {0.f, 4.f}}); @@ -145,12 +166,11 @@ TEST_F(HloEvaluatorTest, DoesSelect) { b.AddInstruction(HloInstruction::CreateConstant(std::move(on_true))); auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(on_false))); - auto instruction = b.AddInstruction( + b.AddInstruction( HloInstruction::CreateTernary(shape, HloOpcode::kSelect, c1, c2, c3)); module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate({}); auto expected = Literal::CreateR2({{2, 5}, {0, 4}}); @@ -159,7 +179,7 @@ TEST_F(HloEvaluatorTest, DoesSelect) { // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise addition with 2 operands. -TEST_F(HloEvaluatorTest, DoesAdd) { +TEST_P(HloEvaluatorTest, DoesAdd) { auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); auto expected = Literal::CreateR2({{3, 4}, {-96, 8}}); @@ -168,7 +188,7 @@ TEST_F(HloEvaluatorTest, DoesAdd) { } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise and with 2 operands. 
-TEST_F(HloEvaluatorTest, DoesAnd) { +TEST_P(HloEvaluatorTest, DoesAnd) { auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); auto expected = Literal::CreateR2({{0, 0}, {4, 4}}); @@ -177,7 +197,7 @@ TEST_F(HloEvaluatorTest, DoesAnd) { } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise or with 2 operands. -TEST_F(HloEvaluatorTest, DoesOr) { +TEST_P(HloEvaluatorTest, DoesOr) { auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); auto expected = Literal::CreateR2({{3, 4}, {-100, 4}}); @@ -186,7 +206,7 @@ TEST_F(HloEvaluatorTest, DoesOr) { } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise multiply with 2 operands. -TEST_F(HloEvaluatorTest, DoesMultiply) { +TEST_P(HloEvaluatorTest, DoesMultiply) { auto lhs = Literal::CreateR2({{-1, 0}, {-100, 4}}); auto rhs = Literal::CreateR2( {{std::numeric_limits::min(), 4}, {4, 4}}); @@ -197,14 +217,14 @@ TEST_F(HloEvaluatorTest, DoesMultiply) { } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise divide with 2 operands. -TEST_F(HloEvaluatorTest, DoesDivideInt64) { +TEST_P(HloEvaluatorTest, DoesDivideInt64) { auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); auto expected = Literal::CreateR2({{0, 0}, {-25, 1}}); TestBinaryOp(HloOpcode::kDivide, std::move(expected), std::move(lhs), std::move(rhs)); } -TEST_F(HloEvaluatorTest, DoesDivideDouble) { +TEST_P(HloEvaluatorTest, DoesDivideDouble) { auto lhs = Literal::CreateR2({{1.0, 0.0}, {-100.0, 4.0}}); auto rhs = Literal::CreateR2({{2.2, 4.0}, {4.0, 4.0}}); auto expected = @@ -215,40 +235,41 @@ TEST_F(HloEvaluatorTest, DoesDivideDouble) { // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise abs op with 1 operand. 
-TEST_F(HloEvaluatorTest, DoesAbsR2) { +TEST_P(HloEvaluatorTest, DoesAbsR2) { auto operand = Literal::CreateR2({{1, -20}, {-100, 4}}); auto expected = Literal::CreateR2({{1, 20}, {100, 4}}); TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand)); } -TEST_F(HloEvaluatorTest, DoesAbsR0) { +TEST_P(HloEvaluatorTest, DoesAbsR0) { auto operand = Literal::CreateR0(-1.0f); auto expected = Literal::CreateR0(1.0f); TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand)); } -TEST_F(HloEvaluatorTest, DoesAbsR1WithZeroSize) { +TEST_P(HloEvaluatorTest, DoesAbsR1WithZeroSize) { auto operand = Literal::CreateR1({}); auto expected = Literal::CreateR1({}); TestUnaryOp(HloOpcode::kAbs, std::move(expected), std::move(operand)); } -TEST_F(HloEvaluatorTest, DoesNegateR2) { +TEST_P(HloEvaluatorTest, DoesNegateR2) { auto operand = Literal::CreateR2( {{0, std::numeric_limits::min()}, {-1, 4}}); auto expected = Literal::CreateR2({{0, std::numeric_limits::min()}, {1, -4}}); TestUnaryOp(HloOpcode::kNegate, std::move(expected), std::move(operand)); } -TEST_F(HloEvaluatorTest, DoesCosR2) { +TEST_P(HloEvaluatorTest, DoesCosR2) { auto operand = Literal::CreateR2({{0, M_PI}, {-M_PI, 2 * M_PI}}); auto expected = Literal::CreateR2({{1, -1}, {-1, 1}}); - TestUnaryOp(HloOpcode::kCos, std::move(expected), std::move(operand)); + TestUnaryOp(HloOpcode::kCos, std::move(expected), std::move(operand), + use_bfloat16_ ? 0x1.0P-5 : 0x1.0P-20); } -TEST_F(HloEvaluatorTest, DoesSinR2) { +TEST_P(HloEvaluatorTest, DoesSinR2) { auto operand = Literal::CreateR2({{0, M_PI}, {-M_PI, 2 * M_PI}}); auto expected = Literal::CreateR2({{0, 0}, {0, 0}}); TestUnaryOp(HloOpcode::kSin, std::move(expected), std::move(operand), - 0x1.0P-20); + use_bfloat16_ ? 
0x1.0P-5 : 0x1.0P-20); } -TEST_F(HloEvaluatorTest, DoesNotR2) { +TEST_P(HloEvaluatorTest, DoesNotR2) { auto operand = Literal::CreateR2({{0, std::numeric_limits::min()}, {-1, std::numeric_limits::max()}}); @@ -259,7 +280,7 @@ TEST_F(HloEvaluatorTest, DoesNotR2) { } // Verifies that HloEvaluator evaluates a HLO Computation with non-parameter nor // constant operands. -TEST_F(HloEvaluatorTest, DoesTraverseInstructions) { +TEST_P(HloEvaluatorTest, DoesTraverseInstructions) { auto lhs = Literal::CreateR2({{1, 0}, {-100, 4}}); auto rhs = Literal::CreateR2({{2, 4}, {4, 4}}); auto rhs2 = Literal::CreateR2({{1, -20}, {-100, 4}}); @@ -279,10 +300,9 @@ TEST_F(HloEvaluatorTest, DoesTraverseInstructions) { b.AddInstruction(HloInstruction::CreateParameter(2, shape, "rhs2")); b.AddInstruction(HloInstruction::CreateBinary(shape, HloOpcode::kAdd, lhs_instruction, param_rhs2)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, args).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(args); auto expected = Literal::CreateR2({{4, -16}, {-196, 12}}); @@ -290,7 +310,7 @@ TEST_F(HloEvaluatorTest, DoesTraverseInstructions) { } // Verifies Reshape operation is correctly evaluated. 
-TEST_F(HloEvaluatorTest, DoesReshape) { +TEST_P(HloEvaluatorTest, DoesReshape) { HloComputation::Builder b(TestName()); const int64 dimensions[] = {11, 8, 7, 5, 9}; TF_ASSERT_OK_AND_ASSIGN(auto literal, @@ -304,21 +324,20 @@ TEST_F(HloEvaluatorTest, DoesReshape) { const int64 permutation[] = {1, 2, 0, 4, 3}; b.AddInstruction( HloInstruction::CreateTranspose(shape, literal_instruction, permutation)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate({}); using NativeT = typename primitive_util::PrimitiveTypeToNative::type; result->EachCell( [&](tensorflow::gtl::ArraySlice indices, NativeT value) { std::vector rindexes = Permute(permutation, indices); - EXPECT_TRUE(value == literal_clone->Get(rindexes)); + EXPECT_NEAR(value, literal_clone->Get(rindexes), 0x1.0P-5); }); } // Verifies Broadcast operation is correctly evaluated. 
-TEST_F(HloEvaluatorTest, DoesBroadcast) { +TEST_P(HloEvaluatorTest, DoesBroadcast) { HloComputation::Builder b(TestName()); auto input_literal = Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}); auto output_literal = Literal::CreateR3( @@ -327,15 +346,14 @@ TEST_F(HloEvaluatorTest, DoesBroadcast) { HloInstruction::CreateConstant(std::move(input_literal))); b.AddInstruction(HloInstruction::CreateBroadcast( output_literal->shape(), literal_instruction, {1, 2})); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate({}); LiteralTestUtil::ExpectEqual(*result, *output_literal); } -TEST_F(HloEvaluatorTest, DoesBroadcastScalar) { +TEST_P(HloEvaluatorTest, DoesBroadcastScalar) { HloComputation::Builder b(TestName()); auto input_literal = Literal::CreateR0(111); auto output_literal = Literal::CreateR2( @@ -347,15 +365,14 @@ TEST_F(HloEvaluatorTest, DoesBroadcastScalar) { b.AddInstruction(HloInstruction::CreateBroadcast( output_literal->shape(), literal_instruction, /*broadcast_dimensions=*/{})); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate({}); LiteralTestUtil::ExpectEqual(*result, *output_literal); } -TEST_F(HloEvaluatorTest, DoesConcatenateSimple) { +TEST_P(HloEvaluatorTest, DoesConcatenateSimple) { HloComputation::Builder b(TestName()); HloInstruction* operand1 = b.AddInstruction(HloInstruction::CreateConstant( @@ -368,17 +385,16 @@ TEST_F(HloEvaluatorTest, DoesConcatenateSimple) { Shape shape = ShapeUtil::MakeShape(S64, {4, 2}); b.AddInstruction(HloInstruction::CreateConcatenate(shape, operands, 0)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - 
std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR2({{-1, -2}, {100, 200}, {-2, -3}, {-100, -200}}); LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { +TEST_P(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { HloComputation::Builder b(TestName()); HloInstruction* operand1 = b.AddInstruction( @@ -391,16 +407,15 @@ TEST_F(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { Shape shape = ShapeUtil::MakeShape(S64, {2}); b.AddInstruction(HloInstruction::CreateConcatenate(shape, operands, 0)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR1({100, 200}); LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, ConvertWithSameLayout) { +TEST_P(HloEvaluatorTest, ConvertWithSameLayout) { HloComputation::Builder b(TestName()); auto input_literal = Literal::CreateR2({{1, 2}, {3, 4}, {5, 6}}); @@ -412,15 +427,14 @@ TEST_F(HloEvaluatorTest, ConvertWithSameLayout) { HloInstruction* constant = b.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); b.AddInstruction(HloInstruction::CreateConvert(expected->shape(), constant)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); LiteralTestUtil::ExpectEqual(*result, *expected); } -TEST_F(HloEvaluatorTest, ConvertWithDifferentLayout) { +TEST_P(HloEvaluatorTest, ConvertWithDifferentLayout) { HloComputation::Builder b(TestName()); auto input_literal = Literal::CreateR2WithLayout( @@ -433,10 +447,9 @@ 
TEST_F(HloEvaluatorTest, ConvertWithDifferentLayout) { HloInstruction* constant = b.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); b.AddInstruction(HloInstruction::CreateConvert(expected->shape(), constant)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); LiteralTestUtil::ExpectEqual(*result, *expected); } @@ -454,7 +467,7 @@ PaddingConfig CreatePaddingConfig( return padding_config; } -TEST_F(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { +TEST_P(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { auto operand = Literal::CreateR2({{}, {}}); HloComputation::Builder b(TestName()); auto operand_instruction = @@ -467,11 +480,11 @@ TEST_F(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { auto padding_config = CreatePaddingConfig({{{1, 0, 2}}, {{0, 2, 1}}}); Shape shape = ShapeUtil::MakeShape(S32, {5, 2}); - auto pad_instruction = b.AddInstruction(HloInstruction::CreatePad( + b.AddInstruction(HloInstruction::CreatePad( shape, operand_instruction, padding_value_instruction, padding_config)); module().AddEntryComputation(b.Build()); - auto result = evaluator_->Evaluate(pad_instruction).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR2( {{10, 10}, {10, 10}, {10, 10}, {10, 10}, {10, 10}}); @@ -479,7 +492,7 @@ TEST_F(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { +TEST_P(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { HloComputation::Builder b(TestName()); Array4D input_array(3, 2, 1, 1, {1, 2, 3, 4, 5, 6}); @@ -496,10 +509,9 @@ TEST_F(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { CreatePaddingConfig({{{1, 0, 2}}, {{0, 2, 1}}, {{0, 0, 0}}, {{0, 0, 0}}}); 
b.AddInstruction(HloInstruction::CreatePad( shape, input_instruction, pad_instruction, r4_padding_on_dim0_dim1)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto expected_array = MakeUnique>(8, 5, 1, 1); expected_array->Fill(kPadValue); @@ -515,7 +527,7 @@ TEST_F(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, NegativePadding2D) { +TEST_P(HloEvaluatorTest, NegativePadding2D) { HloComputation::Builder b(TestName()); // input_array: @@ -541,10 +553,9 @@ TEST_F(HloEvaluatorTest, NegativePadding2D) { pad_value_instruction, r2_padding_on_dim0_dim1)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); // f32[1,5] { 7.0, 2.718, 2.718, 2.718, 2.718 } auto expected_array = MakeUnique>(1, 5); @@ -555,10 +566,10 @@ TEST_F(HloEvaluatorTest, NegativePadding2D) { (*expected_array)(0, 4) = 2.718f; auto expected = Literal::CreateR2FromArray2D(*expected_array); - LiteralTestUtil::ExpectEqual(*expected, *result); + LiteralTestUtil::ExpectNear(*expected, *result, ErrorSpec(0x1.0P-5)); } -TEST_F(HloEvaluatorTest, NegativeAndInteriorPadding2D) { +TEST_P(HloEvaluatorTest, NegativeAndInteriorPadding2D) { HloComputation::Builder b(TestName()); // f32[4,3] { @@ -587,10 +598,9 @@ TEST_F(HloEvaluatorTest, NegativeAndInteriorPadding2D) { pad_value_instruction, r2_padding_on_dim0_dim1)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto expected_array = 
MakeUnique>(0, 9); auto expected = Literal::CreateR2FromArray2D(*expected_array); @@ -598,7 +608,7 @@ TEST_F(HloEvaluatorTest, NegativeAndInteriorPadding2D) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, DotRank2AndRank1) { +TEST_P(HloEvaluatorTest, DotRank2AndRank1) { HloComputation::Builder b(TestName()); // lhs: @@ -621,12 +631,14 @@ TEST_F(HloEvaluatorTest, DotRank2AndRank1) { b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); Shape shape = ShapeUtil::MakeShape(F32, {4, 2}); - b.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, lhs_instruction, rhs_instruction)); - auto computation = module().AddEntryComputation(b.Build()); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, + rhs_instruction, dot_dnums)); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); // clang-format off auto expected_array = Array2D({ @@ -641,7 +653,7 @@ TEST_F(HloEvaluatorTest, DotRank2AndRank1) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, DotRank1AndRank2) { +TEST_P(HloEvaluatorTest, DotRank1AndRank2) { HloComputation::Builder b(TestName()); // lhs: @@ -664,19 +676,21 @@ TEST_F(HloEvaluatorTest, DotRank1AndRank2) { b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); Shape shape = ShapeUtil::MakeShape(F32, {2}); - b.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, lhs_instruction, rhs_instruction)); - auto computation = module().AddEntryComputation(b.Build()); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); + b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, + rhs_instruction, 
dot_dnums)); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR1({22.f, 28.f}); LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, DotRank2AndRank2) { +TEST_P(HloEvaluatorTest, DotRank2AndRank2) { HloComputation::Builder b(TestName()); // lhs: @@ -705,12 +719,14 @@ TEST_F(HloEvaluatorTest, DotRank2AndRank2) { b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal))); Shape shape = ShapeUtil::MakeShape(F32, {4, 2}); - b.AddInstruction(HloInstruction::CreateBinary( - shape, HloOpcode::kDot, lhs_instruction, rhs_instruction)); - auto computation = module().AddEntryComputation(b.Build()); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction, + rhs_instruction, dot_dnums)); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto expected_array = Array2D({ {22.f, 28.f}, @@ -723,7 +739,7 @@ TEST_F(HloEvaluatorTest, DotRank2AndRank2) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, SimpleConv1D) { +TEST_P(HloEvaluatorTest, SimpleConv1D) { HloComputation::Builder b(TestName()); Array3D lhs_array = {{{1, 2, 3}}}; @@ -761,10 +777,9 @@ TEST_F(HloEvaluatorTest, SimpleConv1D) { const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 3}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, window, dnums)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); Array3D expected_array = 
{{{11.f, 18.f, 9.f}}}; auto expected = Literal::CreateR3FromArray3D(expected_array); @@ -772,7 +787,7 @@ TEST_F(HloEvaluatorTest, SimpleConv1D) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { +TEST_P(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { HloComputation::Builder b(TestName()); Array4D lhs_array(1, 1, 4, 4); @@ -816,10 +831,9 @@ TEST_F(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, window, dnums)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); Array4D expected_array(1, 1, 4, 4); // clang-format off @@ -835,7 +849,7 @@ TEST_F(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { +TEST_P(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { HloComputation::Builder b(TestName()); // clang-format off @@ -900,21 +914,22 @@ TEST_F(HloEvaluatorTest, Conv2DGeneralDimensionsReversed) { const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, window, dnums)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); // clang-format off // Result dimensions: [feature=1, height=1, batch=1, width=2] Array4D expected_array({{{{2514, 2685}}}}); + Array4D expected_array_bf16({{{{2512, 2672}}}}); // clang-format on - auto expected = Literal::CreateR4FromArray4D(expected_array); + auto expected = 
Literal::CreateR4FromArray4D( + use_bfloat16_ ? expected_array_bf16 : expected_array); LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, Conv2DGeneralDimensions) { +TEST_P(HloEvaluatorTest, Conv2DGeneralDimensions) { HloComputation::Builder b(TestName()); // clang-format off @@ -976,21 +991,22 @@ TEST_F(HloEvaluatorTest, Conv2DGeneralDimensions) { const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, window, dnums)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); // clang-format off // Result dimensions: [feature=1, height=1, batch=1, width=2] Array4D expected_array({{{{2514, 2685}}}}); + Array4D expected_array_bf16({{{{2512, 2672}}}}); // clang-format on - auto expected = Literal::CreateR4FromArray4D(expected_array); + auto expected = Literal::CreateR4FromArray4D( + use_bfloat16_ ? 
expected_array_bf16 : expected_array); LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { +TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { HloComputation::Builder b(TestName()); Array4D lhs_array(1, 1, 4, 4); @@ -1034,10 +1050,9 @@ TEST_F(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, window, dnums)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); Array4D expected_array(1, 1, 7, 7); expected_array.FillWithYX(Array2D({ @@ -1054,7 +1069,7 @@ TEST_F(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { +TEST_P(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { HloComputation::Builder b(TestName()); Array4D lhs_array(1, 1, 4, 4); @@ -1098,10 +1113,9 @@ TEST_F(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, window, dnums)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); Array4D expected_array(1, 1, 8, 8); expected_array.FillWithYX(Array2D({ @@ -1119,7 +1133,7 @@ TEST_F(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, +TEST_P(HloEvaluatorTest, 
DilatedWindowAndBaseConv2DWithDifferentLowAndHighPaddingAndStrides) { HloComputation::Builder b(TestName()); @@ -1170,10 +1184,9 @@ TEST_F(HloEvaluatorTest, const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, window, dnums)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); Array4D expected_array(1, 1, 9, 3); expected_array.FillWithYX(Array2D({ @@ -1192,7 +1205,7 @@ TEST_F(HloEvaluatorTest, LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, ReduceAdd) { +TEST_P(HloEvaluatorTest, ReduceAdd) { HloComputation::Builder b(TestName()); // arg: @@ -1225,17 +1238,16 @@ TEST_F(HloEvaluatorTest, ReduceAdd) { HloInstruction::CreateReduce(shape, arg_instruction, init_value, /*dimensions_to_reduce=*/{1}, add_func)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR1({6, 18}); LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, ReduceWindowMax) { +TEST_P(HloEvaluatorTest, ReduceWindowMax) { HloComputation::Builder b(TestName()); // arg: @@ -1278,15 +1290,15 @@ TEST_F(HloEvaluatorTest, ReduceWindowMax) { b.AddInstruction(HloInstruction::CreateReduceWindow( shape, arg_instruction, init_value, window, max_func)); - auto computation = module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + module().AddEntryComputation(b.Build()); + + std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR2({{6, 7}}); LiteralTestUtil::ExpectEqual(*expected, *result); } 
-TEST_F(HloEvaluatorTest, ReduceWindowAdd) { +TEST_P(HloEvaluatorTest, ReduceWindowAdd) { HloComputation::Builder b(TestName()); // arg: @@ -1335,21 +1347,21 @@ TEST_F(HloEvaluatorTest, ReduceWindowAdd) { b.AddInstruction(HloInstruction::CreateReduceWindow( shape, arg_instruction, init_value, window, add_func)); - auto computation = module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + module().AddEntryComputation(b.Build()); + + std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR2({{1, 3, 5}, {5, 11, 13}}); LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, ReduceWindowAdd6D) { +TEST_P(HloEvaluatorTest, ReduceWindowAdd6D) { HloComputation::Builder b(TestName()); // arg: f32[4,4,4,4,4,4] full of ones. Using small dims to limit run-time. std::vector input_dims(6, 4); std::unique_ptr arg_literal = - Literal::CreateFullWithMonotonicDim0MajorLayout(input_dims, 1.0f); + Literal::CreateFullWithDescendingLayout(input_dims, 1.0f); HloInstruction* arg_instruction = b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal))); @@ -1396,17 +1408,17 @@ TEST_F(HloEvaluatorTest, ReduceWindowAdd6D) { b.AddInstruction(HloInstruction::CreateReduceWindow( shape, arg_instruction, init_value, window, add_func)); - auto computation = module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + module().AddEntryComputation(b.Build()); + + std::unique_ptr result = Evaluate(); std::vector output_dims = {4, 3, 3, 3, 4, 4}; std::unique_ptr result_literal = - Literal::CreateFullWithMonotonicDim0MajorLayout(output_dims, 8.0f); + Literal::CreateFullWithDescendingLayout(output_dims, 8.0f); LiteralTestUtil::ExpectEqual(*result_literal, *result); } -TEST_F(HloEvaluatorTest, StridedSlice) { +TEST_P(HloEvaluatorTest, StridedSlice) { HloComputation::Builder b(TestName()); // arg: @@ -1427,10 
+1439,9 @@ TEST_F(HloEvaluatorTest, StridedSlice) { /*start_indices=*/{0, 2}, /*limit_indices=*/{3, 5}, /*strides=*/{2, 3})); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR2({ {3}, @@ -1440,7 +1451,7 @@ TEST_F(HloEvaluatorTest, StridedSlice) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, DynamicSlice) { +TEST_P(HloEvaluatorTest, DynamicSlice) { HloComputation::Builder b(TestName()); // arg: @@ -1461,10 +1472,9 @@ TEST_F(HloEvaluatorTest, DynamicSlice) { Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); b.AddInstruction(HloInstruction::CreateDynamicSlice(shape, operand, start_indices, {2, 3})); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR2({ {2, 3, 4}, @@ -1476,7 +1486,7 @@ TEST_F(HloEvaluatorTest, DynamicSlice) { // Verifies that the HloEvaluator's implementation goes along with existing // backends' behavior, although this is not required by the spec. 
-TEST_F(HloEvaluatorTest, DynamicSliceModSlice) { +TEST_P(HloEvaluatorTest, DynamicSliceModSlice) { HloComputation::Builder b(TestName()); // arg: @@ -1497,10 +1507,9 @@ TEST_F(HloEvaluatorTest, DynamicSliceModSlice) { Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); b.AddInstruction(HloInstruction::CreateDynamicSlice(shape, operand, start_indices, {2, 3})); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR2({ {2, 3, 4}, @@ -1510,7 +1519,7 @@ TEST_F(HloEvaluatorTest, DynamicSliceModSlice) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, DynamicSliceUpdate) { +TEST_P(HloEvaluatorTest, DynamicSliceUpdate) { HloComputation::Builder b(TestName()); // arg: @@ -1534,10 +1543,9 @@ TEST_F(HloEvaluatorTest, DynamicSliceUpdate) { Shape shape = ShapeUtil::MakeShape(F64, {2, 3}); b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( shape, operand, update, start_indices)); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR2({ {1, -2, -3}, @@ -1547,7 +1555,7 @@ TEST_F(HloEvaluatorTest, DynamicSliceUpdate) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, SetAndGetTuples) { +TEST_P(HloEvaluatorTest, SetAndGetTuples) { HloComputation::Builder b(TestName()); // arg: @@ -1570,9 +1578,9 @@ TEST_F(HloEvaluatorTest, SetAndGetTuples) { Shape shape = ShapeUtil::MakeShape(F64, {2, 3}); b.AddInstruction(HloInstruction::CreateGetTupleElement(shape, tuple, 1)); - auto computation = module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, 
{}).ConsumeValueOrDie(); + module().AddEntryComputation(b.Build()); + + std::unique_ptr result = Evaluate(); auto expected = Literal::CreateR2({ {1, 2, 3}, @@ -1582,7 +1590,7 @@ TEST_F(HloEvaluatorTest, SetAndGetTuples) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, SetAndGetNestedTuples) { +TEST_P(HloEvaluatorTest, SetAndGetNestedTuples) { HloComputation::Builder b(TestName()); // arg: @@ -1609,9 +1617,9 @@ TEST_F(HloEvaluatorTest, SetAndGetNestedTuples) { b.AddInstruction( HloInstruction::CreateGetTupleElement(tuple2->shape(), outer_tuple, 1)); - auto computation = module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + module().AddEntryComputation(b.Build()); + + std::unique_ptr result = Evaluate(); auto result_inner_literal = Literal::CreateR2FromArray2D(*operand_array); @@ -1623,7 +1631,7 @@ TEST_F(HloEvaluatorTest, SetAndGetNestedTuples) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, Reverse) { +TEST_P(HloEvaluatorTest, Reverse) { HloComputation::Builder b(TestName()); // Input shape is float[4x3x2x1]. 
@@ -1649,10 +1657,9 @@ TEST_F(HloEvaluatorTest, Reverse) { const Shape shape = ShapeUtil::MakeShape(F32, {4, 3, 2, 1}); b.AddInstruction(HloInstruction::CreateReverse(shape, operand, {0, 1})); - auto computation = module().AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); - std::unique_ptr result = - evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); + std::unique_ptr result = Evaluate(); // clang-format off auto expected = Literal::CreateR4FromArray4D({ @@ -1677,7 +1684,7 @@ TEST_F(HloEvaluatorTest, Reverse) { LiteralTestUtil::ExpectEqual(*expected, *result); } -TEST_F(HloEvaluatorTest, EvaluateWithSubstitutions) { +TEST_P(HloEvaluatorTest, EvaluateWithSubstitutions) { HloComputation::Builder b(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4}); @@ -1700,7 +1707,7 @@ TEST_F(HloEvaluatorTest, EvaluateWithSubstitutions) { // Check that EvaluateWithSubstitutions works if one of the operands to the op // we're evaluating is a constant. -TEST_F(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) { +TEST_P(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) { HloComputation::Builder b(TestName()); Shape shape = ShapeUtil::MakeShape(F32, {4}); @@ -1722,5 +1729,8 @@ TEST_F(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) { *result.ValueOrDie()); } +INSTANTIATE_TEST_CASE_P(HloEvaluatorTest_Instantiation, HloEvaluatorTest, + ::testing::ValuesIn(use_bf16_params)); + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc index ba75e2ef1b485f015f3b8f8dbd76f214d6ab0130..f0df93b61d29c1535d8a89fbd65e669de5b43729 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc @@ -32,7 +32,7 @@ HloProfileIndexMap::HloProfileIndexMap(const HloModule& module) { InsertOrDie(&computation_to_profile_idx_, computation, 
current_profile_index++); for (const HloInstruction* instruction : computation->instructions()) { - // For simplicity we track all instrutions here, but we could skip + // For simplicity we track all instructions here, but we could skip // non-executing instructions like constants and parameters. InsertOrDie(&instruction_to_profile_idx_, instruction, current_profile_index++); @@ -40,82 +40,75 @@ HloProfileIndexMap::HloProfileIndexMap(const HloModule& module) { } } -std::unique_ptr CreateHloProfilePrinter( +std::unique_ptr CreateHloProfilePrinterData( const HloProfileIndexMap& hlo_profile_index_map, const HloCostAnalysis& cost_analysis) { - using HloComputationInfo = HloProfilePrinter::HloComputationInfo; - using HloInstructionInfo = HloProfilePrinter::HloInstructionInfo; - - HloComputationInfo* computation_infos = - new HloComputationInfo[hlo_profile_index_map.computation_count()]; - - // There are two "indices" in play here. The first one is the index of the - // HloComputationInfo or HloInstructionInfo in the array that contains said - // HloComputationInfo or HloInstructionInfo. The second index is the index of - // the HloComputationInfo or HloInstructionInfo in the profile counters array, - // as decided by hlo_profile_index_map. The latter index is always referred - // to as "profile_index". 
- - size_t computation_index_in_static_data = 0; - size_t max_profile_index = hlo_profile_index_map.total_count(); - for (const auto& pair : hlo_profile_index_map.computation_to_profile_idx()) { - CHECK_LT(pair.second, max_profile_index); + using HloComputationInfo = HloProfilePrinterData::HloComputationInfo; + using HloInstructionInfo = HloProfilePrinterData::HloInstructionInfo; + + size_t profile_counters_size = hlo_profile_index_map.total_count(); + + std::unique_ptr profile_printer_data = + MakeUnique(); + profile_printer_data->set_profile_counters_size(profile_counters_size); + profile_printer_data->mutable_computation_infos()->Reserve( + hlo_profile_index_map.computation_count()); + + const auto& computation_to_profile_idx_map = + hlo_profile_index_map.computation_to_profile_idx(); + + // computation_to_profile_idx_map's order is not deterministic so create a + // deterministic computation_and_profile_idx_list so that we end up with a + // deterministic HloProfilePrinterData protobuf. + + std::vector> + computation_and_profile_idx_list(computation_to_profile_idx_map.begin(), + computation_to_profile_idx_map.end()); + + // The profile indices were computed deterministically in + // HloProfileIndexMap::HloProfileIndexMap. 
+ c_sort(computation_and_profile_idx_list, + [](const std::pair& left, + const std::pair& right) { + return left.second < right.second; + }); + + for (const auto& pair : computation_and_profile_idx_list) { + CHECK_LT(pair.second, profile_counters_size); const HloComputation* computation = pair.first; - size_t current_computation_index = computation_index_in_static_data++; HloComputationInfo* computation_info = - &computation_infos[current_computation_index]; + profile_printer_data->add_computation_infos(); - computation_info->name = strdup(computation->name().c_str()); - computation_info->profile_index = pair.second; - computation_info->instructions = - new HloInstructionInfo[computation->instruction_count()]; - computation_info->instructions_size = computation->instruction_count(); + computation_info->set_name(computation->name()); + computation_info->set_profile_index(pair.second); + computation_info->mutable_instruction_infos()->Reserve( + computation->instruction_count()); - size_t instruction_index_in_static_data = 0; for (const HloInstruction* hlo : computation->instructions()) { - HloProfilePrinter::HloInstructionInfo* instruction_info = - &computation_info->instructions[instruction_index_in_static_data++]; - instruction_info->long_name = strdup(hlo->ToString().c_str()); - instruction_info->short_name = - strdup(hlo->ToString(/*compact_operands=*/true).c_str()); - instruction_info->category = strdup(hlo->ToCategory().c_str()); - instruction_info->flop_count = cost_analysis.flop_count(*hlo); - instruction_info->transcendental_count = - cost_analysis.transcendental_count(*hlo); - instruction_info->bytes_accessed = cost_analysis.bytes_accessed(*hlo); - instruction_info->optimal_seconds = cost_analysis.optimal_seconds(*hlo); - instruction_info->profile_index = - hlo_profile_index_map.GetProfileIndexFor(*hlo); - CHECK_LT(instruction_info->profile_index, max_profile_index); + HloInstructionInfo* instruction_info = + computation_info->add_instruction_infos(); + 
instruction_info->set_long_name(hlo->ToString()); + instruction_info->set_short_name( + hlo->ToString(HloPrintOptions().set_compact_operands(true))); + instruction_info->set_category(hlo->ToCategory()); + instruction_info->set_flop_count(cost_analysis.flop_count(*hlo)); + instruction_info->set_transcendental_count( + cost_analysis.transcendental_count(*hlo)); + instruction_info->set_bytes_accessed(cost_analysis.bytes_accessed(*hlo)); + instruction_info->set_optimal_seconds( + cost_analysis.optimal_seconds(*hlo)); + instruction_info->set_profile_index( + hlo_profile_index_map.GetProfileIndexFor(*hlo)); } } - auto deleter = [](HloProfilePrinter::HloComputationInfo* computation_infos, - int64 computation_infos_size) { - for (int64 i = 0; i < computation_infos_size; i++) { - HloInstructionInfo* instruction_infos = computation_infos[i].instructions; - for (int64 j = 0; j < computation_infos[i].instructions_size; j++) { - // We can't make instruction_infos[j].long_name etc. non-const pointers - // since they may point into static storage, so we have a const_cast - // here. 
- free(const_cast(instruction_infos[j].long_name)); - free(const_cast(instruction_infos[j].short_name)); - free(const_cast(instruction_infos[j].category)); - } - delete[] instruction_infos; - free(const_cast(computation_infos[i].name)); - } - delete[] computation_infos; - }; - - return MakeUnique( - computation_infos, hlo_profile_index_map.computation_count(), deleter); + return profile_printer_data; } HloExecutionProfile::HloExecutionProfile( - const HloProfilePrinter* hlo_profile_printer, + const HloProfilePrinterData* hlo_profile_printer_data, const HloProfileIndexMap* hlo_profile_index_map) - : hlo_profile_printer_(*hlo_profile_printer), + : hlo_profile_printer_data_(*hlo_profile_printer_data), hlo_profile_index_map_(*hlo_profile_index_map), profile_counters_( /*count*/ hlo_profile_index_map_.total_count(), diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.h b/tensorflow/compiler/xla/service/hlo_execution_profile.h index 470fd4ce3c205d84152238f4b18daad77e403f68..6fb91b9bef9d1df82b8806ce79cc147823edeb3d 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.h +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.h @@ -77,8 +77,8 @@ class HloProfileIndexMap { std::unordered_map computation_to_profile_idx_; }; -// Create an instance of `HloProfilePrinter` that owns its memory. -std::unique_ptr CreateHloProfilePrinter( +// Create an instance of `HloProfilePrinterData`. +std::unique_ptr CreateHloProfilePrinterData( const HloProfileIndexMap& hlo_profile_index_map, const HloCostAnalysis& cost_analysis); @@ -90,7 +90,7 @@ class HloExecutionProfile { public: using DeviceDescription = perftools::gputools::DeviceDescription; - HloExecutionProfile(const HloProfilePrinter* hlo_profile_printer, + HloExecutionProfile(const HloProfilePrinterData* hlo_profile_printer_data, const HloProfileIndexMap* hlo_profile_index_map); // Record how many cycles this HLO took to execute. @@ -117,17 +117,19 @@ class HloExecutionProfile { // debugging; e.g. 
emits cycle counts, execution time at the nominal device // frequency, and the effective throughput given the provided cost_analysis // for the operations in a given computation. Returns an empty string if it - // wasn't possible to generate a printable version. cost_analysis should be a - // clean analysis that can be used to visit the computation. + // wasn't possible to generate a printable version. string ToString(const DeviceDescription& device_description) const { - return hlo_profile_printer_.ToString(profile_counters_.data(), - device_description.clock_rate_ghz()); + return PrintHloProfile(hlo_profile_printer_data_, profile_counters_.data(), + device_description.clock_rate_ghz()); } std::vector* mutable_profile_counters() { return &profile_counters_; } + const std::vector& profile_counters() const { + return profile_counters_; + } private: - const HloProfilePrinter& hlo_profile_printer_; + const HloProfilePrinterData& hlo_profile_printer_data_; const HloProfileIndexMap& hlo_profile_index_map_; // Stores per-Hlo profile counters. 
This is the only thing that changes when diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc index b1e6729e2bccad4bdbe075a635d8a9b1ede6fecb..a0cb28246d3be541e798e85552436f64a3521f22 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile_test.cc @@ -73,8 +73,8 @@ TEST_F(HloExecutionProfileTest, Basic) { HloCostAnalysis cost_analysis(shape_size_function); HloProfileIndexMap profile_index_map(*hlo_module); - std::unique_ptr profile_printer = - CreateHloProfilePrinter(profile_index_map, cost_analysis); + std::unique_ptr profile_printer = + CreateHloProfilePrinterData(profile_index_map, cost_analysis); HloExecutionProfile execution_profile(profile_printer.get(), &profile_index_map); diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 84187d578346eafd5e32727a15f5eab9cc79feef..44fcd36370dcd0cf77601aa1cd2b92810947bd5f 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/window_util.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/numbers.h" @@ -508,8 +509,17 @@ stylesheet=" // The "to_node" value may be a NULL, indicating that this points to the // "root" tag rather than a normal node. - int64 from_node_id = node_ids_.at(from_node); - int64 to_node_id = to_node ? 
node_ids_.at(to_node) : root_node_id_; + int64 from_node_id = + tensorflow::gtl::FindWithDefault(node_ids_, from_node, -1); + if (from_node_id == -1) { + LOG(FATAL) << from_node->name() << " was added to edges but not to nodes"; + } + int64 to_node_id = + to_node ? tensorflow::gtl::FindWithDefault(node_ids_, to_node, -1) + : root_node_id_; + if (to_node != nullptr && to_node_id == -1) { + LOG(FATAL) << to_node->name() << " was added to edges but not to nodes"; + } add_hover_css_rule("node", from_node_id, kBlue); add_hover_css_rule("node", to_node_id, kRed); @@ -653,12 +663,15 @@ string HloDotDumper::DumpComputation(const HloComputation* comp) { string HloDotDumper::DumpRootTag() { const HloInstruction* from = GetNodeForEdge(computation_->root_instruction()); - auto from_id = InstructionId(from); - if (!filter_.Show(from)) { + // We didn't display constants as separate nodes; so if the root is a + // constant, we don't add root tag or edge for it. + if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant) { return ""; } + auto from_id = InstructionId(from); + // The ID of the root computation is otherwise unused, so it makes a good ID // to use for the root-tag node. However, the edge_ids_ map requires a // HloInstruction* pointer for the 'to' value, so we use a NULL value there @@ -784,7 +797,7 @@ string HloDotDumper::GetInstructionNodeInlinedOperands( // Otherwise, print e.g. "%constant.42 (s32[100])". 
string constant_name; - if (tensorflow::StringPiece(constant->name()).starts_with("%constant")) { + if (tensorflow::StringPiece(constant->name()).starts_with("constant")) { constant_name = constant->name(); } else { constant_name = StrCat("constant ", constant->name()); @@ -948,6 +961,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { return kGreen; case HloOpcode::kConvolution: case HloOpcode::kDot: + case HloOpcode::kFft: return kDarkBlue; case HloOpcode::kReducePrecision: return kRed; @@ -1000,7 +1014,7 @@ string HloDotDumper::GetInstructionNodeLabel(const HloInstruction* instr) { // The HLO instruction name contains usually the opcode, e.g. "%add.42" is // an add instruction. In this case we render just the name. if (tensorflow::StringPiece(instr->name()) - .starts_with(StrCat("%", HloOpcodeString(instr->opcode())))) { + .starts_with(HloOpcodeString(instr->opcode()))) { return Printf("%s", HtmlLikeStringSanitize(instr->name())); } string extended_opcode = @@ -1036,62 +1050,32 @@ string HloDotDumper::GetInstructionNodeMetadata(const HloInstruction* instr) { } string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { - string opcode_specific_info = [&]() -> string { - switch (instr->opcode()) { - case HloOpcode::kRng: - return RandomDistribution_Name(instr->random_distribution()); - case HloOpcode::kConvolution: - return StrCat( - HtmlLikeStringSanitize( - instr->ConvolutionDimensionNumbersToString()), - "
", - HtmlLikeStringSanitize(window_util::ToString(instr->window()))); - case HloOpcode::kBroadcast: - case HloOpcode::kTranspose: - case HloOpcode::kReduce: - return Printf("dims={%s}", Join(instr->dimensions(), ",")); - case HloOpcode::kGetTupleElement: - return Printf("index=%lld", instr->tuple_index()); - case HloOpcode::kBatchNormTraining: - case HloOpcode::kBatchNormGrad: - return Printf("feature_index=%lld", instr->feature_index()); - case HloOpcode::kCustomCall: - return Printf("custom_call_target=%s", instr->custom_call_target()); - case HloOpcode::kSlice: - return std::all_of(instr->slice_strides().begin(), - instr->slice_strides().end(), - [](int64 stride) { return stride == 1; }) - ? "" - : StrCat("stride=", VectorString(instr->slice_strides())); - case HloOpcode::kSend: - case HloOpcode::kSendDone: - case HloOpcode::kRecv: - case HloOpcode::kRecvDone: - return StrCat("channel_id=", instr->channel_id()); - default: - return ""; - } - }(); - std::vector lines; - if (!opcode_specific_info.empty()) { - lines.push_back(opcode_specific_info); - } - if (instr->has_sharding()) { - lines.push_back(StrCat("sharding=", instr->sharding().ToString())); + + // Get the instruction's extra attributes excluding the names of its + // subcomputations, since those are drawn explicitly in the graph. + for (const auto& line : instr->ExtraAttributesToString( + HloPrintOptions().set_print_subcomputation_references(false))) { + lines.push_back(HtmlLikeStringSanitize(line)); } + // Show the shape and layout of the instruction, unless it's an inlined fusion // node -- there the shape and layout is present in the output node. if (instr->opcode() != HloOpcode::kFusion || !ShouldShowFusionSubcomputation(instr)) { - string instr_shape = ShapeUtil::HumanString(instr->shape()); - - // Show layout of non-tuple shapes with more than one dimension. 
- if (LayoutUtil::HasLayout(instr->shape()) && - instr->shape().dimensions_size() > 1 && - !ShapeUtil::IsTuple(instr->shape())) { - StrAppend(&instr_shape, "{", - Join(instr->shape().layout().minor_to_major(), ","), "}"); + // Show layout of instructions with more than one dimension. Don't show + // layout on tuples or tensors with just one dimension (which only have one + // possible layout) to avoid visual noise. + bool shape_is_multidim = false; + ShapeUtil::ForEachSubshape(instr->shape(), + [&](const Shape& s, const ShapeIndex&) { + shape_is_multidim |= s.dimensions_size() > 1; + }); + string instr_shape; + if (instr->opcode() != HloOpcode::kTuple && shape_is_multidim) { + instr_shape = ShapeUtil::HumanStringWithLayout(instr->shape()); + } else { + instr_shape = ShapeUtil::HumanString(instr->shape()); } // Some instructions have giant tuples as their shapes, so truncate the @@ -1353,19 +1337,16 @@ string SaveGraph(const string& graph, file_extension = ".pbtxt"; break; } - string path = JoinPath( - dest_path, StrCat("hlo_graph_", output_num++, ".XXXXXX", file_extension)); + string path = JoinPath(dest_path, StrCat("hlo_graph_", output_num++, ".")); auto status = Status::OK(); - int fd = mkstemps(&path[0], file_extension.length()); - if (fd < 0) { + auto env = tensorflow::Env::Default(); + if (!env->CreateUniqueFileName(&path, file_extension)) { status = Status(tensorflow::error::Code::UNKNOWN, StrCat("Failed to create temporary file to dump HLO graph: ", strerror(errno))); } else { - status = - tensorflow::WriteStringToFile(tensorflow::Env::Default(), path, graph); - close(fd); + status = tensorflow::WriteStringToFile(env, path, graph); } if (!status.ok()) { LOG(WARNING) << "Saving HLO graph failed: " << status; @@ -1438,15 +1419,18 @@ void DumpText(const HloModule& module, const string& label, do_prefix ? 
StrCat(prefix, "-", label, ".txt") : StrCat(label, ".txt"); string path = JoinPath(directory_path, filename); TF_CHECK_OK(WriteStringToFile( - env, path, module.ToString(/*include_large_constants=*/true))); + env, path, + module.ToString(HloPrintOptions().set_print_large_constants(true)))); LOG(INFO) << "dumping module '" << module.name() << "' to " << path; } string MaybeDumpHloModule(const HloModule& module, const string& label, const HloExecutionProfile* profile) { - VLOG(2) << "MaybeDumpHloModule called on module " << module.name(); - string graph_url; const DebugOptions& debug_options = module.config().debug_options(); + VLOG(2) << "MaybeDumpHloModule called on module " << module.name() + << " with generate_hlo_graph regex \"" + << debug_options.xla_generate_hlo_graph() << "\""; + string graph_url; if (!debug_options.xla_generate_hlo_graph().empty() && RE2::PartialMatch(module.name(), debug_options.xla_generate_hlo_graph())) { diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc index 8e1531c87f9c6e133e2d6763b046b1d5dcbcd09f..1f00aa41dc783f9e5657f5fa654884a31fae0fe7 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc @@ -117,5 +117,18 @@ TEST(HloGraphDumperTest, NestedFusion) { HasSubstr(inner_sum->name())); } +TEST(HloGraphDumperTest, Constant) { + HloComputation::Builder b("b"); + auto instruction = b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(-42))); + instruction->set_name("i_am_a_constant_root_instruction"); + HloModule m(TestName()); + HloComputation* root_computation = m.AddEntryComputation(b.Build()); + string graph = hlo_graph_dumper::DumpGraph( + *root_computation, /*label=*/"an_empty_graph", DebugOptions()); + EXPECT_THAT(graph, HasSubstr("an_empty_graph")); + EXPECT_THAT(graph, Not(HasSubstr("i_am_a_constant_root_instruction"))); +} + } // anonymous namespace } // 
namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index c30c4326547bbeae4f7054974f0d3fade65e3382..0981f1f4fe57751d5b7059b4b08099385369e4b9 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -101,10 +101,10 @@ StatusOr> HloInstruction::CreateFromProto( instruction->metadata_ = proto.metadata(); if (proto.has_literal()) { - instruction->literal_ = MakeUnique(proto.literal()); + TF_ASSIGN_OR_RETURN(instruction->literal_, + Literal::CreateFromProto(proto.literal())); } instruction->parameter_number_ = proto.parameter_number(); - instruction->parameter_name_ = proto.parameter_name(); instruction->tuple_index_ = proto.tuple_index(); for (int64 dimension : proto.dimensions()) { @@ -118,6 +118,10 @@ StatusOr> HloInstruction::CreateFromProto( MakeUnique( proto.convolution_dimension_numbers()); } + if (proto.has_dot_dimension_numbers()) { + instruction->dot_dimension_numbers_ = + MakeUnique(proto.dot_dimension_numbers()); + } for (const HloInstructionProto::SliceDimensions& slice_dimensions : proto.slice_dimensions()) { instruction->slice_starts_.push_back(slice_dimensions.start()); @@ -141,6 +145,10 @@ StatusOr> HloInstruction::CreateFromProto( instruction->infeed_config_ = proto.infeed_config(); instruction->custom_call_target_ = proto.custom_call_target(); instruction->outfeed_shape_ = proto.outfeed_shape(); + instruction->fft_type_ = proto.fft_type(); + for (int64 fft_len : proto.fft_length()) { + instruction->fft_length_.push_back(fft_len); + } return std::move(instruction); } @@ -150,7 +158,6 @@ StatusOr> HloInstruction::CreateFromProto( auto instruction = WrapUnique(new HloInstruction(HloOpcode::kParameter, shape)); instruction->parameter_number_ = parameter_number; - instruction->parameter_name_ = name; instruction->name_ = name; return instruction; } @@ -160,8 +167,7 @@ StatusOr> HloInstruction::CreateFromProto( 
auto instruction = WrapUnique(new HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil())); instruction->operands_.push_back(operand); - instruction->literal_.reset(new Literal); - instruction->literal_->append_u8s(tag); + instruction->literal_ = Literal::CreateR1U8(tag); return instruction; } @@ -332,6 +338,41 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, return instruction; } +/* static */ std::unique_ptr HloInstruction::CreateFft( + const Shape& shape, HloInstruction* operand, FftType fft_type, + tensorflow::gtl::ArraySlice fft_length) { + auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFft, shape)); + instruction->AppendOperand(operand); + instruction->fft_type_ = fft_type; + instruction->fft_length_.assign(fft_length.begin(), fft_length.end()); + return instruction; +} + +/* static */ std::unique_ptr HloInstruction::CreateDot( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, + const DotDimensionNumbers& dimension_numbers) { + auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDot, shape)); + instruction->AppendOperand(lhs); + instruction->AppendOperand(rhs); + instruction->dot_dimension_numbers_ = + MakeUnique(dimension_numbers); + return instruction; +} + +/* static */ std::unique_ptr HloInstruction::CreateCanonicalDot( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs) { + CHECK_EQ(ShapeUtil::Rank(lhs->shape()), 2); + CHECK_EQ(ShapeUtil::Rank(rhs->shape()), 2); + + auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDot, shape)); + instruction->AppendOperand(lhs); + instruction->AppendOperand(rhs); + instruction->dot_dimension_numbers_ = MakeUnique(); + instruction->dot_dimension_numbers_->add_lhs_contracting_dimensions(1); + instruction->dot_dimension_numbers_->add_rhs_contracting_dimensions(0); + return instruction; +} + /* static */ std::unique_ptr HloInstruction::CreateReducePrecision(const Shape& shape, HloInstruction* operand, @@ -346,12 +387,9 @@ 
HloInstruction::CreateReducePrecision(const Shape& shape, } /* static */ std::unique_ptr -HloInstruction::CreateCrossReplicaSum(const Shape& shape, - HloInstruction* operand) { - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kCrossReplicaSum, shape)); - instruction->AppendOperand(operand); - return instruction; +HloInstruction::CreateCrossReplicaSum( + const Shape& shape, tensorflow::gtl::ArraySlice operands) { + return CreateNary(shape, HloOpcode::kCrossReplicaSum, operands); } /* static */ std::unique_ptr HloInstruction::CreateInfeed( @@ -366,6 +404,9 @@ HloInstruction::CreateCrossReplicaSum(const Shape& shape, tensorflow::StringPiece outfeed_config) { std::unique_ptr instruction = WrapUnique(new HloInstruction(HloOpcode::kOutfeed, ShapeUtil::MakeNil())); + CHECK(ShapeUtil::Compatible(operand->shape(), shape)) + << "Outfeed shape " << shape << " must be compatible with operand shape " + << operand->shape(); instruction->AppendOperand(operand); instruction->outfeed_config_ = outfeed_config.ToString(); instruction->outfeed_shape_ = shape; @@ -631,6 +672,58 @@ HloInstruction::CreateSelectAndScatter( return instruction; } +/* static */ std::unique_ptr +HloInstruction::CreateBroadcastSequence( + const Shape& output_shape, HloInstruction* operand, + const std::function)>& + adder) { + CHECK(ShapeUtil::IsScalar(operand->shape()) || + ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(output_shape)); + Shape broadcast_shape = ShapeUtil::ChangeElementType( + output_shape, operand->shape().element_type()); + // Do explicit broadcast for scalar. + if (ShapeUtil::IsScalar(operand->shape())) { + auto broadcast = + HloInstruction::CreateBroadcast(broadcast_shape, operand, {}); + broadcast->set_metadata(operand->metadata()); + if (operand->has_sharding()) { + broadcast->set_sharding(operand->sharding()); + } + return broadcast; + } + // Do explicit broadcast for degenerate broadcast. 
+ std::vector broadcast_dimensions; + std::vector reshaped_dimensions; + for (int i = 0; i < ShapeUtil::Rank(operand->shape()); i++) { + if (operand->shape().dimensions(i) == output_shape.dimensions(i)) { + broadcast_dimensions.push_back(i); + reshaped_dimensions.push_back(operand->shape().dimensions(i)); + } else { + CHECK_EQ(operand->shape().dimensions(i), 1) + << "An explicit broadcast sequence requires the broadcasted " + "dimensions to be trivial; operand: " + << operand->ToString() << "; output_shape: " << output_shape; + } + } + // Eliminate the size one dimensions. + HloInstruction* reshaped_operand = adder(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(operand->shape().element_type(), + reshaped_dimensions), + operand)); + reshaped_operand->set_metadata(operand->metadata()); + if (operand->has_sharding()) { + reshaped_operand->set_sharding(operand->sharding()); + } + // Broadcast 'reshape' up to the larger size. + auto broadcast = HloInstruction::CreateBroadcast( + broadcast_shape, reshaped_operand, broadcast_dimensions); + broadcast->set_metadata(operand->metadata()); + if (operand->has_sharding()) { + broadcast->set_sharding(operand->sharding()); + } + return broadcast; +} + /* static */ std::unique_ptr HloInstruction::CreatePad( const Shape& shape, HloInstruction* operand, HloInstruction* padding_value, const PaddingConfig& padding_config) { @@ -670,10 +763,23 @@ HloInstruction::CreateSelectAndScatter( return instruction; } +// We put the fusion kind into the instruction's name for transpose-dot fusions, +// since those fusions are really just describing a type of dot rather than +// generating a novel computation. 
+static string FusionNodeName(HloInstruction::FusionKind fusion_kind) { + switch (fusion_kind) { + case HloInstruction::FusionKind::kTransposeDot: + return "dot_fusion"; + default: + return "fusion"; + } +} + /* static */ std::unique_ptr HloInstruction::CreateFusion( const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) { auto instruction = WrapUnique(new HloInstruction(HloOpcode::kFusion, shape)); instruction->fusion_kind_ = fusion_kind; + instruction->name_ = FusionNodeName(fusion_kind); instruction->set_parent(fused_root->parent()); instruction->set_metadata(fused_root->metadata()); instruction->CloneAndFuseInternal(fused_root); @@ -689,23 +795,12 @@ HloInstruction::CreateSelectAndScatter( instruction->AppendOperand(operand); } instruction->fusion_kind_ = fusion_kind; + instruction->name_ = FusionNodeName(fusion_kind); instruction->called_computations_.push_back(fusion_computation); fusion_computation->SetFusionInstruction(instruction.get()); return instruction; } -/* static */ std::unique_ptr -HloInstruction::CreateFusionForBackwardConvolution( - const Shape& shape, FusionKind fusion_kind, const Window& window, - const ConvolutionDimensionNumbers& conv_dnums, HloInstruction* fused_root) { - std::unique_ptr fusion = - CreateFusion(shape, fusion_kind, fused_root); - fusion->window_ = MakeUnique(window); - fusion->convolution_dimension_numbers_ = - MakeUnique(conv_dnums); - return fusion; -} - void HloInstruction::MergeFusionInstruction( HloInstruction* instruction_to_merge) { CHECK_EQ(opcode_, HloOpcode::kFusion); @@ -985,6 +1080,7 @@ bool HloInstruction::HasSideEffect() const { case HloOpcode::kSendDone: case HloOpcode::kRecv: case HloOpcode::kRecvDone: + case HloOpcode::kRng: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kTrace: @@ -1086,7 +1182,6 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kLe: case HloOpcode::kLt: case HloOpcode::kNe: - case HloOpcode::kDot: case HloOpcode::kMaximum: case 
HloOpcode::kMinimum: case HloOpcode::kPower: @@ -1138,9 +1233,16 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( clone = CreateConvolve(shape, new_operands[0], new_operands[1], *window_, *convolution_dimension_numbers_); break; - case HloOpcode::kCrossReplicaSum: + case HloOpcode::kDot: + CHECK_EQ(new_operands.size(), 2); + clone = CreateDot(shape, new_operands[0], new_operands[1], + *dot_dimension_numbers_); + break; + case HloOpcode::kFft: CHECK_EQ(new_operands.size(), 1); - clone = CreateCrossReplicaSum(shape, new_operands[0]); + return CreateFft(shape, new_operands[0], fft_type_, fft_length_); + case HloOpcode::kCrossReplicaSum: + clone = CreateCrossReplicaSum(shape, new_operands); break; case HloOpcode::kGetTupleElement: CHECK_EQ(new_operands.size(), 1); @@ -1215,7 +1317,7 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( clone = CloneFusionWithNewOperands(shape, new_operands, module); break; case HloOpcode::kParameter: - clone = CreateParameter(parameter_number_, shape, parameter_name_); + clone = CreateParameter(parameter_number_, shape, name_); break; case HloOpcode::kBatchNormTraining: CHECK_EQ(new_operands.size(), 3); @@ -1244,10 +1346,27 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( new_operands[4], epsilon(), feature_index()); break; case HloOpcode::kConditional: - case HloOpcode::kRecv: - case HloOpcode::kRecvDone: + CHECK_EQ(new_operands.size(), 3); + clone = CreateConditional(shape, new_operands[0], new_operands[1], + true_computation(), new_operands[2], + false_computation()); + break; case HloOpcode::kSend: + CHECK_EQ(new_operands.size(), 1); + clone = CreateSend(new_operands[0], channel_id()); + break; case HloOpcode::kSendDone: + CHECK_EQ(new_operands.size(), 1); + clone = CreateSendDone(new_operands[0]); + break; + case HloOpcode::kRecv: + CHECK_EQ(new_operands.size(), 0); + clone = CreateRecv(shape, channel_id()); + break; + case HloOpcode::kRecvDone: + CHECK_EQ(new_operands.size(), 1); + clone = 
CreateRecvDone(new_operands[0]); + break; case HloOpcode::kTrace: LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_); } @@ -1492,8 +1611,9 @@ bool HloInstruction::HasConstantOperand() const { bool HloInstruction::IdenticalSlowPath( const HloInstruction& other, - std::function - eq_computations) const { + const std::function& + eq_computations, + const std::function& eq_shapes) const { // Perform opcode specific checks. switch (opcode()) { // The result of these instructions only depend upon their opcode and @@ -1509,7 +1629,6 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kCos: case HloOpcode::kCrossReplicaSum: case HloOpcode::kDivide: - case HloOpcode::kDot: case HloOpcode::kEq: case HloOpcode::kExp: case HloOpcode::kFloor: @@ -1542,8 +1661,12 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kTuple: return true; - // These opcodes have complex or special behavior so just return false. case HloOpcode::kFusion: + return fusion_kind() == other.fusion_kind() && + eq_computations(fused_instructions_computation(), + other.fused_instructions_computation()); + + // These opcodes have complex or special behavior so just return false. case HloOpcode::kRng: case HloOpcode::kTrace: case HloOpcode::kWhile: @@ -1553,7 +1676,7 @@ bool HloInstruction::IdenticalSlowPath( return parameter_number() == other.parameter_number() && // Check the shape too because `this` and `other` may be in // different HloComputations. - ShapeUtil::Compatible(shape(), other.shape()); + eq_shapes(shape(), other.shape()); case HloOpcode::kBatchNormTraining: case HloOpcode::kBatchNormInference: @@ -1582,6 +1705,15 @@ bool HloInstruction::IdenticalSlowPath( protobuf_util::ProtobufEquals( convolution_dimension_numbers(), other.convolution_dimension_numbers()); + // Check dot dimension numbers. + case HloOpcode::kDot: + return protobuf_util::ProtobufEquals(dot_dimension_numbers(), + other.dot_dimension_numbers()); + + // FFT has various types & lengths. 
+ case HloOpcode::kFft: + return fft_type() == other.fft_type() && + fft_length() == other.fft_length(); // Reduction results are determined by the reduction dimension and the // reduction computation. @@ -1600,18 +1732,18 @@ bool HloInstruction::IdenticalSlowPath( protobuf_util::ProtobufEquals(window(), other.window()); case HloOpcode::kReshape: - return ShapeUtil::Compatible(shape(), other.shape()); + return eq_shapes(shape(), other.shape()); // Transpose result is determined by the final shape and the permutation. case HloOpcode::kTranspose: - return ShapeUtil::Compatible(shape(), other.shape()) && + return eq_shapes(shape(), other.shape()) && dimensions() == other.dimensions(); // Remaining instructions with special values. case HloOpcode::kBitcast: - return ShapeUtil::Equal(shape(), other.shape()); + return eq_shapes(shape(), other.shape()); case HloOpcode::kBroadcast: - return ShapeUtil::Compatible(shape(), other.shape()) && + return eq_shapes(shape(), other.shape()) && dimensions() == other.dimensions(); case HloOpcode::kConcatenate: return dimensions() == other.dimensions(); @@ -1625,10 +1757,10 @@ bool HloInstruction::IdenticalSlowPath( slice_limits_ == other.slice_limits_ && slice_strides_ == other.slice_strides_; case HloOpcode::kDynamicSlice: - return ShapeUtil::Compatible(shape(), other.shape()) && + return eq_shapes(shape(), other.shape()) && dynamic_slice_sizes_ == other.dynamic_slice_sizes_; case HloOpcode::kDynamicUpdateSlice: - return ShapeUtil::Compatible(shape(), other.shape()); + return eq_shapes(shape(), other.shape()); case HloOpcode::kCall: case HloOpcode::kMap: return eq_computations(to_apply(), other.to_apply()); @@ -1636,9 +1768,11 @@ bool HloInstruction::IdenticalSlowPath( return custom_call_target_ == other.custom_call_target_; case HloOpcode::kReverse: return dimensions() == other.dimensions(); + case HloOpcode::kConditional: + return eq_computations(true_computation(), other.true_computation()) && + 
eq_computations(false_computation(), other.false_computation()); // These opcodes are not yet supported. - case HloOpcode::kConditional: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kSort: @@ -1671,7 +1805,8 @@ void HloInstruction::RemoveUser(HloInstruction* user) { Status HloInstruction::ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer) { - TF_RET_CHECK(ShapeUtil::Compatible(shape(), new_producer->shape())) + TF_RET_CHECK( + ShapeUtil::CompatibleIgnoringFpPrecision(shape(), new_producer->shape())) << "this shape: " << ShapeUtil::HumanString(shape()) << ", replacement shape: " << ShapeUtil::HumanString(new_producer->shape()); @@ -1694,8 +1829,8 @@ Status HloInstruction::ReplaceOperandWith(int64 operand_num, TF_RET_CHECK(operand_num >= 0); TF_RET_CHECK(operand_num < operand_count()); HloInstruction* old_operand = mutable_operand(operand_num); - TF_RET_CHECK( - ShapeUtil::Compatible(old_operand->shape(), new_operand->shape())) + TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(old_operand->shape(), + new_operand->shape())) << old_operand->shape().ShortDebugString() << " is not compatible with " << new_operand->shape().ShortDebugString(); operands_[operand_num] = new_operand; @@ -1882,16 +2017,23 @@ string HloInstruction::SignatureString() const { return StrCat("(", operands, ") -> ", ShapeUtil::HumanString(shape())); } -string HloInstruction::ToString(bool compact_operands, bool include_metadata, - bool include_large_constants) const { +namespace { + +string PrintName(const string& name, const HloPrintOptions& options) { + return StrCat(options.print_percent() ? 
"%" : "", name); +} + +} // namespace + +string HloInstruction::ToString(const HloPrintOptions& options) const { string result = - StrCat("%", name(), " = ", ShapeUtil::HumanStringWithLayout(shape()), " ", - HloOpcodeString(opcode()), "(", - OperandsToString(compact_operands, include_large_constants), ")"); - for (const string& extra : ExtraAttributesToString()) { + StrCat(PrintName(name(), options), " = ", + ShapeUtil::HumanStringWithLayout(shape()), " ", + HloOpcodeString(opcode()), "(", OperandsToString(options), ")"); + for (const string& extra : ExtraAttributesToString(options)) { StrAppend(&result, ", ", extra); } - if (include_metadata && + if (options.print_metadata() && (!metadata_.op_type().empty() || !metadata_.op_name().empty() || !metadata_.source_file().empty())) { StrAppend(&result, ", metadata={", xla::OpMetadataToString(metadata_), "}"); @@ -1899,14 +2041,13 @@ string HloInstruction::ToString(bool compact_operands, bool include_metadata, return result; } -string HloInstruction::OperandsToString(bool compact, - bool include_large_constants) const { +string HloInstruction::OperandsToString(const HloPrintOptions& options) const { string operands; if (opcode() == HloOpcode::kConstant) { // For constants, show the actual value in place of an empty operand list. if ((!ShapeUtil::IsTuple(shape()) && ShapeUtil::ElementsIn(shape()) <= 10) || - include_large_constants) { + options.print_large_constants()) { // Literal::ToString emits multidimensional arrays over multiple // lines. Compact this into one line by stripping out white space. 
string tmp = literal().ToString(); @@ -1931,14 +2072,19 @@ string HloInstruction::OperandsToString(bool compact, } else { tensorflow::gtl::ArraySlice slice(operands_); const int64 kMaxOperandsToShowIfCompact = 4; - if (compact && slice.size() > kMaxOperandsToShowIfCompact) { + if (options.compact_operands() && + slice.size() > kMaxOperandsToShowIfCompact) { slice.remove_suffix(slice.size() - kMaxOperandsToShowIfCompact); } operands = Join(slice, ", ", [&](string* out, HloInstruction* operand) { - *out += ShapeUtil::HumanStringWithLayout(operand->shape()); - if (!compact) { - StrAppend(out, " %", operand->name()); + std::vector str; + if (options.print_operand_shape()) { + str.push_back(ShapeUtil::HumanStringWithLayout(operand->shape())); + } + if (!options.compact_operands()) { + str.push_back(PrintName(operand->name(), options)); } + StrAppend(out, Join(str, " ")); }); const int64 remaining = operands_.size() - slice.size(); if (slice.size() != operands_.size()) { @@ -1948,7 +2094,8 @@ string HloInstruction::OperandsToString(bool compact, return operands; } -std::vector HloInstruction::ExtraAttributesToString() const { +std::vector HloInstruction::ExtraAttributesToString( + const HloPrintOptions& options) const { std::vector extra; if (opcode() == HloOpcode::kFusion) { extra.push_back(StrCat("kind=", xla::ToString(fusion_kind()))); @@ -1990,23 +2137,42 @@ std::vector HloInstruction::ExtraAttributesToString() const { if (convolution_dimension_numbers_ != nullptr) { extra.push_back(ConvolutionDimensionNumbersToString()); } - - if (opcode() == HloOpcode::kWhile) { - extra.push_back(StrCat("condition=%", while_condition()->name())); - extra.push_back(StrCat("body=%", while_body()->name())); - } else if (opcode() == HloOpcode::kSelectAndScatter) { - extra.push_back(StrCat("select=%", select()->name())); - extra.push_back(StrCat("scatter=%", scatter()->name())); - } else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap || - opcode() == 
HloOpcode::kReduceWindow || - opcode() == HloOpcode::kReduce) { - extra.push_back(StrCat("to_apply=%", to_apply()->name())); - } else if (!called_computations().empty()) { - extra.push_back(StrCat( - "calls=", Join(called_computations(), ", ", - [](string* out, const HloComputation* computation) { - StrAppend(out, "%", computation->name()); - }))); + if (dot_dimension_numbers_ != nullptr) { + extra.push_back(DotDimensionNumbersToString()); + } + if (opcode() == HloOpcode::kFft) { + extra.push_back(StrCat("fft_type=", FftType_Name(fft_type()))); + extra.push_back(StrCat("fft_length={", Join(fft_length(), ","), "}")); + } + + if (options.print_subcomputation_references()) { + if (opcode() == HloOpcode::kWhile) { + extra.push_back( + StrCat("condition=", PrintName(while_condition()->name(), options))); + extra.push_back( + StrCat("body=", PrintName(while_body()->name(), options))); + } else if (opcode() == HloOpcode::kSelectAndScatter) { + extra.push_back(StrCat("select=", PrintName(select()->name(), options))); + extra.push_back( + StrCat("scatter=", PrintName(scatter()->name(), options))); + } else if (opcode() == HloOpcode::kConditional) { + extra.push_back(StrCat("true_computation=", + PrintName(true_computation()->name(), options))); + extra.push_back(StrCat("false_computation=", + PrintName(false_computation()->name(), options))); + } else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap || + opcode() == HloOpcode::kReduceWindow || + opcode() == HloOpcode::kReduce) { + extra.push_back( + StrCat("to_apply=", PrintName(to_apply()->name(), options))); + } else if (!called_computations().empty()) { + extra.push_back(StrCat( + "calls=", Join(called_computations(), ", ", + [&](string* out, const HloComputation* computation) { + StrAppend(out, + PrintName(computation->name(), options)); + }))); + } } if (opcode() == HloOpcode::kSend || opcode() == HloOpcode::kRecv || @@ -2023,8 +2189,9 @@ std::vector HloInstruction::ExtraAttributesToString() const { if 
(!control_predecessors_.empty()) { extra.push_back(StrCat("control-predecessors={", Join(control_predecessors_, ", ", - [](string* out, HloInstruction* pre) { - StrAppend(out, "%", pre->name()); + [&](string* out, HloInstruction* pre) { + StrAppend(out, + PrintName(pre->name(), options)); }), "}")); } @@ -2035,6 +2202,22 @@ std::vector HloInstruction::ExtraAttributesToString() const { extra.push_back( StrCat("outfeed_config=\"", CEscape(outfeed_config_), "\"")); } + if (opcode() == HloOpcode::kRng) { + extra.push_back( + StrCat("distribution=", RandomDistributionToString(distribution_))); + } + if (opcode() == HloOpcode::kReducePrecision) { + extra.push_back(StrCat("exponent_bits=", exponent_bits_)); + extra.push_back(StrCat("mantissa_bits=", mantissa_bits_)); + } + + // By contract, we print the custom call target even if + // !options.print_subcomputation_references(), because the call target is not + // an HloComputation. + if (opcode() == HloOpcode::kCustomCall) { + extra.push_back( + StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\"")); + } return extra; } @@ -2064,7 +2247,6 @@ HloInstructionProto HloInstruction::ToProto() const { *proto.mutable_literal() = literal_->ToProto(); } proto.set_parameter_number(parameter_number_); - proto.set_parameter_name(parameter_name_); if (opcode() == HloOpcode::kFusion) { proto.set_fusion_kind(xla::ToString(fusion_kind())); *proto.mutable_fused_instructions_computation() = @@ -2086,6 +2268,9 @@ HloInstructionProto HloInstruction::ToProto() const { *proto.mutable_convolution_dimension_numbers() = *convolution_dimension_numbers_; } + if (dot_dimension_numbers_ != nullptr) { + *proto.mutable_dot_dimension_numbers() = *dot_dimension_numbers_; + } for (int i = 0; i < slice_starts_.size(); ++i) { auto* slice_dimension = proto.add_slice_dimensions(); slice_dimension->set_start(slice_starts_[i]); @@ -2110,6 +2295,10 @@ HloInstructionProto HloInstruction::ToProto() const { proto.set_infeed_config(infeed_config_); 
proto.set_custom_call_target(custom_call_target_); *proto.mutable_outfeed_shape() = outfeed_shape_; + proto.set_fft_type(fft_type_); + for (int64 fft_len : fft_length_) { + proto.add_fft_length(fft_len); + } return proto; } @@ -2131,42 +2320,27 @@ string HloInstruction::ToCategory() const { return category; } + // Give transpose-dot and backwards-conv fusions the categories "dot" and + // "convolution" so they match the categories of proper kDot and kConvolution + // ops. These fusion categories are really just a way of expressing a + // particular kind of dot or conv, so they should have the same category as a + // vanilla dot/conv. if (opcode() == HloOpcode::kFusion) { - if (operands().size() == 2) { - bool saw_rank_1 = false; - bool saw_higher_rank = false; - for (const auto* operand : operands()) { - if (!ShapeUtil::IsTuple(operand->shape())) { - saw_rank_1 |= ShapeUtil::Rank(operand->shape()) == 1; - saw_higher_rank |= ShapeUtil::Rank(operand->shape()) > 1; - } - } - if (saw_rank_1 && saw_higher_rank) { - return "rank-1-broadcast binary fusion"; - } - } switch (fusion_kind()) { case FusionKind::kLoop: - if (IsElementwise()) { - return "elementwise fusion"; - } else { - return "non-elementwise fusion"; - } + return "loop fusion"; case FusionKind::kInput: return "input fusion"; case FusionKind::kOutput: return "output fusion"; case FusionKind::kTransposeDot: - return "dot fusion"; - case FusionKind::kConvBackwardFilter: - case FusionKind::kConvBackwardInput: - return "convolution fusion"; + return "dot"; case FusionKind::kCustom: return "custom fusion"; } } - if (IsElementwise() && opcode() != HloOpcode::kFusion) { + if (IsElementwise()) { return "non-fusion elementwise"; } @@ -2182,7 +2356,7 @@ void HloInstruction::set_tracing(HloInstruction* trace_instruction) { string HloInstruction::TracingTag() const { CHECK_EQ(HloOpcode::kTrace, opcode()); CHECK(literal_ != nullptr); - return literal_->u8s_string(); + return literal_->GetR1U8AsString(); } bool 
HloInstruction::IsFused() const { return parent_->IsFusionComputation(); } @@ -2325,6 +2499,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { return visitor->HandleSelect(this); case HloOpcode::kConvolution: return visitor->HandleConvolution(this); + case HloOpcode::kFft: + return visitor->HandleFft(this); case HloOpcode::kCrossReplicaSum: return visitor->HandleCrossReplicaSum(this); case HloOpcode::kTuple: @@ -2933,10 +3109,6 @@ string ToString(HloInstruction::FusionKind kind) { return "kOutput"; case HloInstruction::FusionKind::kTransposeDot: return "kTransposeDot"; - case HloInstruction::FusionKind::kConvBackwardFilter: - return "kConvBackwardFilter"; - case HloInstruction::FusionKind::kConvBackwardInput: - return "kConvBackwardInput"; case HloInstruction::FusionKind::kCustom: return "kCustom"; } @@ -2956,12 +3128,6 @@ StatusOr StringToFusionKind( if (kind_name == "kTransposeDot") { return HloInstruction::FusionKind::kTransposeDot; } - if (kind_name == "kConvBackwardFilter") { - return HloInstruction::FusionKind::kConvBackwardFilter; - } - if (kind_name == "kConvBackwardInput") { - return HloInstruction::FusionKind::kConvBackwardInput; - } if (kind_name == "kCustom") { return HloInstruction::FusionKind::kCustom; } @@ -3001,6 +3167,28 @@ string OpMetadataToString(const OpMetadata& metadata) { return Join(result, " "); } +string RandomDistributionToString(const RandomDistribution& distribution) { + return tensorflow::str_util::Lowercase(RandomDistribution_Name(distribution)); +} + +StatusOr StringToRandomDistribution(const string& name) { + static std::unordered_map* map = [] { + static auto* map = new std::unordered_map; + for (int i = 0; i < RandomDistribution_ARRAYSIZE; i++) { + if (RandomDistribution_IsValid(i)) { + auto value = static_cast(i); + (*map)[RandomDistributionToString(value)] = value; + } + } + return map; + }(); + auto found = map->find(tensorflow::str_util::Lowercase(name)); + if (found == map->end()) { + return 
InvalidArgument("Unknown distribution"); + } + return found->second; +} + std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) { return os << ToString(kind); } @@ -3047,10 +3235,39 @@ string HloInstruction::ConvolutionDimensionNumbersToString() const { result += "_"; append_dims(rhs_dims, operand(1)->shape()); result += "->"; - append_dims(output_dims, shape()); + + // A convolution can be represented as a kConvolution HLO or as a CustomCall + // that returns a tuple, the first element of which is the result of the + // convolution. + Shape this_shape = + ShapeUtil::IsTuple(shape()) ? shape().tuple_shapes(0) : shape(); + append_dims(output_dims, this_shape); return result; } +string HloInstruction::DotDimensionNumbersToString() const { + std::vector result; + if (dot_dimension_numbers_ == nullptr) { + return ""; + } + const DotDimensionNumbers& dnums = *dot_dimension_numbers_; + if (!dnums.lhs_batch_dimensions().empty()) { + result.push_back(StrCat("lhs_batch_dims={", + Join(dnums.lhs_batch_dimensions(), ","), "}")); + } + result.push_back(StrCat("lhs_contracting_dims={", + Join(dnums.lhs_contracting_dimensions(), ","), "}")); + + if (!dnums.rhs_batch_dimensions().empty()) { + result.push_back(StrCat("rhs_batch_dims={", + Join(dnums.rhs_batch_dimensions(), ","), "}")); + } + result.push_back(StrCat("rhs_contracting_dims={", + Join(dnums.rhs_contracting_dimensions(), ","), "}")); + + return Join(result, ", "); +} + bool HloInstruction::CouldBeBitcast() const { switch (opcode_) { case HloOpcode::kTranspose: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index cda8b07c61e2b36a83184648f6f3744deeb86812..3170746157fbcfa7d0a7eaba6d226d46691105f9 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -25,6 +25,7 @@ limitations under the License. 
#include #include #include +#include #include #include #include @@ -56,21 +57,119 @@ namespace xla { class HloComputation; class HloModule; +// A bunch of switches that control how the hlo text should be printed. +class HloPrintOptions { + public: + // Constructs the default print options: don't print large constants, don't + // compact operands, no indentation. + HloPrintOptions() + : print_large_constants_(false), + print_subcomputation_references_(true), + print_metadata_(true), + compact_operands_(false), + print_operand_shape_(true), + print_program_shape_(true), + print_percent_(true), + indent_amount_(0) {} + + static HloPrintOptions ShortParsable() { + return HloPrintOptions() + .set_print_large_constants(true) + .set_print_subcomputation_references(true) + .set_print_metadata(false) + .set_print_operand_shape(false) + .set_print_program_shape(false) + .set_print_percent(false); + } + + // If true, large constants will be printed out. + HloPrintOptions& set_print_large_constants(bool value) { + print_large_constants_ = value; + return *this; + } + + // If true, the names of subcomputations (e.g. a fusion node's fused + // computation) won't be printed. This makes the resulting text not parsable. + // + // A CustomCall's call target is printed even if + // print_subcomputation_references is false, because the call target isn't an + // HloComputation. + HloPrintOptions& set_print_subcomputation_references(bool value) { + print_subcomputation_references_ = value; + return *this; + } + + // If true, metatdata will be printed. + HloPrintOptions& set_print_metadata(bool value) { + print_metadata_ = value; + return *this; + } + + // If true, operands' shapes will be printed. + HloPrintOptions& set_print_operand_shape(bool value) { + print_operand_shape_ = value; + return *this; + } + + // If true, program shape of hlo computations will be printed. 
+ HloPrintOptions& set_print_program_shape(bool value) { + print_program_shape_ = value; + return *this; + } + + // If true, names will be printed with prefix '%'. + HloPrintOptions& set_print_percent(bool value) { + print_percent_ = value; + return *this; + } + + // If true, only a part of operands will be printed out, and their names will + // be omitted (note that in this case the text will not be parsable). + HloPrintOptions& set_compact_operands(bool value) { + compact_operands_ = value; + return *this; + } + + // The indent of the hlo text block. + HloPrintOptions& set_indent_amount(int value) { + indent_amount_ = value; + return *this; + } + + bool print_large_constants() const { return print_large_constants_; } + bool print_subcomputation_references() const { + return print_subcomputation_references_; + } + bool print_metadata() const { return print_metadata_; } + bool compact_operands() const { return compact_operands_; } + bool print_operand_shape() const { return print_operand_shape_; } + bool print_program_shape() const { return print_program_shape_; } + bool print_percent() const { return print_percent_; } + int indent_amount() const { return indent_amount_; } + + private: + bool print_large_constants_; + bool print_subcomputation_references_; + bool print_metadata_; + bool compact_operands_; + bool print_operand_shape_; + bool print_program_shape_; + bool print_percent_; + int indent_amount_; +}; + // HLO instructions are the IR used by the high-level compiler. class HloInstruction { public: enum class FusionKind { - kLoop, // Fused into a loop. - kInput, // Op's input is fused into the op itself. - kOutput, // Op's output is fused into the op itself. - // REQUIRES: At least one operand buffer must be able - // to alias the output buffer. - kTransposeDot, // Fused into a dot with transposed operands. - kConvBackwardFilter, // Fused into a backward filter convolution. - kConvBackwardInput, // Fused into a backward input convolution. 
- - kCustom, // Custom category for backend-specific fusions that - // do not match any of the more specific ones. + kLoop, // Fused into a loop. + kInput, // Op's input is fused into the op itself. + kOutput, // Op's output is fused into the op itself. + // REQUIRES: At least one operand buffer must be able + // to alias the output buffer. + kTransposeDot, // Fused into a dot with transposed operands. + kCustom, // Custom category for backend-specific fusions that + // do not match any of the more specific ones. }; ~HloInstruction(); @@ -160,6 +259,23 @@ class HloInstruction { const Window& window, const ConvolutionDimensionNumbers& dimension_numbers); + // Creates an FFT op, of the type indicated by fft_type. + static std::unique_ptr CreateFft( + const Shape& shape, HloInstruction* operand, FftType fft_type, + tensorflow::gtl::ArraySlice fft_length); + + // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch + // dimensions specified in 'dimension_numbers'. + static std::unique_ptr CreateDot( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs, + const DotDimensionNumbers& dimension_numbers); + + // Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1 + // of the LHS with dimension 0 of the RHS with no batch dimensions. Both LHS + // and the RHS must be of rank 2. + static std::unique_ptr CreateCanonicalDot( + const Shape& shape, HloInstruction* lhs, HloInstruction* rhs); + // Creates a reduce-precision op, where operand is the data to reduce in // precision, and exponent_bits and mantissa_bits describe the precision to // reduce it to. @@ -169,7 +285,8 @@ class HloInstruction { // Creates a cross replica sum op. static std::unique_ptr CreateCrossReplicaSum( - const Shape& shape, HloInstruction* operand); + const Shape& shape, + tensorflow::gtl::ArraySlice operands); // Creates a conversion instruction, where operand is the data to convert and // shape is the target shape for the conversion. 
@@ -289,6 +406,20 @@ class HloInstruction { const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice broadcast_dimensions); + // Creates a sequence of instructions that performs an explicit broadcast of + // the operand to the target shape. + // + // Interior HLOs are passed to "adder", but the "root" HLO of the sequence is + // returned as a unique_ptr for API consistency with other factory methods in + // this interface. + // + // TODO(b/72173833) Ideally HloComputations would always be present, and so + // the adder being passed by the caller would not be necessary. + static std::unique_ptr CreateBroadcastSequence( + const Shape& output_shape, HloInstruction* operand, + const std::function)>& + adder); + // Creates a pad instruction, where the operand is padded on the edges and // between the elements with the given padding value. static std::unique_ptr CreatePad( @@ -332,14 +463,6 @@ class HloInstruction { tensorflow::gtl::ArraySlice operands, HloComputation* fusion_computation); - // Creates a fusion instruction that represents backward convolution. This is - // similar to CreateFusion, but with extra arguments indicating the window and - // dimemsion mapping of the backward convolution. - static std::unique_ptr CreateFusionForBackwardConvolution( - const Shape& shape, FusionKind fusion_kind, const Window& window, - const ConvolutionDimensionNumbers& conv_dnums, - HloInstruction* fused_root); - // Creates a call instruction that applies the given computation on the given // operands. "shape" is the resultant shape. static std::unique_ptr CreateCall( @@ -421,7 +544,7 @@ class HloInstruction { Status RemoveControlDependencyTo(HloInstruction* instruction); // Returns the set of control predecessors (successors) of this - // instruction. Control predecessors (sucessors) must execute before (after) + // instruction. Control predecessors (successors) must execute before (after) // the current instruction. 
const std::vector& control_predecessors() const { return control_predecessors_; @@ -431,28 +554,42 @@ class HloInstruction { } // Returns true if "other" performs the same computation as this instruction. - // Layout of the instructions' output array is not considered. bool Identical( const HloInstruction& other, - std::function + const std::function& eq_operands = std::equal_to(), - std::function - eq_computations = std::equal_to()) const { + const std::function& + eq_computations = std::equal_to(), + bool layout_sensitive = true) const { // An instruction is always identical to itself. if (this == &other) { return true; } - // Identical instruction must have the same opcode and identical operands. - // In general, there is no need to check shape because shape is inferred - // from the shape of the operands. - if (opcode() != other.opcode() || - !ContainersEqual(operands(), other.operands(), - std::move(eq_operands))) { + // Identical instruction must have the same opcode, shape, and identical + // operands. + if (opcode() != other.opcode()) { + return false; + } + using EqShapeFuncType = bool (*)(const Shape&, const Shape&); + EqShapeFuncType eq_shapes = + layout_sensitive ? ShapeUtil::Equal : ShapeUtil::Compatible; + if (!eq_shapes(shape(), other.shape())) { + return false; + } + if (operands().size() != other.operands().size()) { return false; } - return IdenticalSlowPath(other, eq_computations); + // Use an explicit loop rather than ContainerEquals, because copying around + // std::functions may be too expensive in some cases. + for (size_t i = 0; i < operands().size(); ++i) { + if (!eq_operands(operand(i), other.operand(i))) { + return false; + } + } + + return IdenticalSlowPath(other, eq_computations, eq_shapes); } // Returns whether the instruction has a constant operand. 
@@ -540,16 +677,6 @@ class HloInstruction { return parameter_number_; } - const string& parameter_name() const { - CHECK_EQ(HloOpcode::kParameter, opcode_); - return parameter_name_; - } - - void set_parameter_name(const string& str) { - CHECK_EQ(HloOpcode::kParameter, opcode_); - parameter_name_ = str; - } - // Returns the dimension sizes or numbers associated with this instruction. // // Precondition: opcode() is one of: concatenate, reduce, broadcast, reshape, @@ -637,18 +764,20 @@ class HloInstruction { string SignatureString() const; // Returns a debugging string that represents this instruction. - string ToString(bool compact_operands = false, bool include_metadata = true, - bool include_large_constants = false) const; + // + // (We express the default options using an overload rather than a default + // param because gdb ignores default params, but does resolve overloads.) + string ToString() const { return ToString(HloPrintOptions()); } + string ToString(const HloPrintOptions& options) const; // Components of the ToString() representation: // Returns a string representation of the operand list. - string OperandsToString(bool compact, bool include_large_constants) const; + string OperandsToString(const HloPrintOptions& options) const; // Returns string representation of op-specific attributes. - std::vector ExtraAttributesToString() const; - - string ToStringNoMetadata() const { return ToString(false, false); } + std::vector ExtraAttributesToString( + const HloPrintOptions& options) const; // As ToString, but returns a shorter string. string ToShortString() const; @@ -676,13 +805,15 @@ class HloInstruction { // Returns feature_index field associated with the instruction. The index // represents the index of the feature dimension. // - // Precondition: opcode() == HloOpcode::kBatchNormTraining + // Precondition: opcode() is one of kBatchNormTraining, kBatchNormInference, + // or kBatchNormGrad. 
int64 feature_index() const { return feature_index_; } // Returns a epsilon value associated with the instruction. The is a small // number added to the variance to avoid divide-by-zero error. // - // Precondition: opcode() == HloOpcode::kBatchNormTraining + // Precondition: opcode() is one of kBatchNormTraining, kBatchNormInference, + // or kBatchNormGrad. float epsilon() const { return epsilon_; } // Returns the infeed configuration string. The infeed configuration includes @@ -749,8 +880,8 @@ class HloInstruction { // Returns true if this instruction is a fusion instruction that generates // multiple outputs. const bool IsMultiOutputFusion() const { - return (opcode() == HloOpcode::kFusion && - fused_expression_root()->opcode() == HloOpcode::kTuple); + return opcode() == HloOpcode::kFusion && + fused_expression_root()->opcode() == HloOpcode::kTuple; } FusionKind fusion_kind() const { @@ -856,6 +987,17 @@ class HloInstruction { } const std::vector& slice_strides() const { return slice_strides_; } + // Returns the flag that describes whether a slice must be lowered into an + // offset into the original operand. + bool IsInPlaceSlice() const { return is_in_place_slice_; } + + // Sets and returns the flag that describes whether a slice must be lowered + // into an offset into the original operand. + bool SetIsInPlaceSlice(bool value) { + is_in_place_slice_ = value; + return value; + } + // Returns the size of the slice in the given dimension for a dynamic // slice node. // @@ -905,16 +1047,45 @@ class HloInstruction { return *padding_config_; } - // Returns data on the dimension numbers used for a convolution - // operation. + // Returns data on the dimension numbers used for a convolution operation, + // which may be a kConvolution instruction or a kCustomCall that implements a + // convolution. 
const ConvolutionDimensionNumbers& convolution_dimension_numbers() const { CHECK(convolution_dimension_numbers_ != nullptr); return *convolution_dimension_numbers_; } + // Sets the convolution dimension numbers on this instruction. In general you + // shouldn't need to call this; instead, specify the convolution dimension + // numbers when you create the instruction. + void set_convolution_dimension_numbers( + const ConvolutionDimensionNumbers& dnums) { + convolution_dimension_numbers_ = + MakeUnique(dnums); + } + + FftType fft_type() const { + CHECK_EQ(HloOpcode::kFft, opcode_); + return fft_type_; + } + + const std::vector& fft_length() const { + CHECK_EQ(HloOpcode::kFft, opcode_); + return fft_length_; + } + // Returns the dump string of the convolution dimension numbers. string ConvolutionDimensionNumbersToString() const; + // Returns data on the dimension numbers used for a dot operation. + const DotDimensionNumbers& dot_dimension_numbers() const { + CHECK(dot_dimension_numbers_ != nullptr); + return *dot_dimension_numbers_; + } + + // Returns the dump string of the dot dimension numbers. + string DotDimensionNumbersToString() const; + // Returns the random distribution for this rng node. // // Precondition: opcode() == HloOpcode::kRng @@ -1006,10 +1177,9 @@ class HloInstruction { std::tuple, std::vector> ReshapeMerelyInsertsOrDeletes1SizedDimensions() const; - // Returns a string identifier for this instruction. If no string identifier - // has been explicitly set, then the identifier is the serialized pointer to - // this instruction. + // Gets/sets the string identifier for this instruction. const string& name() const { return name_; } + void set_name(tensorflow::StringPiece name) { name_ = name.ToString(); } // Use the given NameUniquer to select a unique name for the instruction based // on the instruction's existing name. @@ -1068,10 +1238,14 @@ class HloInstruction { class FusionReusesParamElements; // See comments on Identical(). 
+ // eq_shapes() is used to check shapes for equality, and would normally be + // expected to be ShapeUtil::Equals or ShapeUtil::Compatible, depending on + // whether we want a layout-sensitive check or not. bool IdenticalSlowPath( const HloInstruction& other, - std::function - eq_computations) const; + const std::function& + eq_computations, + const std::function& eq_shapes) const; // Creates an n-ary elementwise operation. static std::unique_ptr CreateNary( @@ -1173,11 +1347,23 @@ class HloInstruction { // Describes the dimension numbers used for a convolution. std::unique_ptr convolution_dimension_numbers_; + // Describes the dimension numbers used for a dot. + std::unique_ptr dot_dimension_numbers_; + + // Describes FFT type for an FFT instruction. + FftType fft_type_ = FftType::FFT; + + // Indicates the FFT length for an FFT instruction. + std::vector fft_length_; + // Describes the [begin, end) index range for a slice. std::vector slice_starts_; std::vector slice_limits_; std::vector slice_strides_; + // Describes whether the slice can be lowered to an offset into the operand. + bool is_in_place_slice_ = false; + // The bit sizes for a reduce-precision operation. int32 exponent_bits_ = 0; int32 mantissa_bits_ = 0; @@ -1198,7 +1384,6 @@ class HloInstruction { // For parameter instructions this field holds the parameter number. int64 parameter_number_ = 0; - string parameter_name_; // Name of a global symbol to call, only present for kCustomCall. string custom_call_target_; @@ -1267,9 +1452,12 @@ string ToString(HloInstruction::FusionKind kind); StatusOr StringToFusionKind( const string& kind_name); -// Custom stringification functions for protos that live inside HloInstruction. +// Custom (de)stringification functions for protos that live inside +// HloInstruction. 
string PaddingConfigToString(const PaddingConfig& padding); string OpMetadataToString(const OpMetadata& metadata); +string RandomDistributionToString(const RandomDistribution& distribution); +StatusOr StringToRandomDistribution(const string& name); std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); @@ -1295,6 +1483,10 @@ template using ConstHloInstructionMap = std::map; +using HloInstructionSet = std::set; +using ConstHloInstructionSet = + std::set; + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_ diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 76b12fc8d3aadc0a874ce059851666fbcd6a4e94..94e9bfe56eb445ec0b459a55342cd3cc4c6f68ef 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -712,8 +712,8 @@ TEST_F(HloInstructionTest, PreserveOutfeedShapeThroughClone) { {1, 2}, {3, 4}, }))); - auto shape10 = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0}); - auto shape01 = ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {0, 1}); + auto shape10 = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0}); + auto shape01 = ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1}); auto outfeed10 = builder.AddInstruction( HloInstruction::CreateOutfeed(shape10, constant, "")); auto outfeed01 = builder.AddInstruction( @@ -825,17 +825,42 @@ TEST_F(HloInstructionTest, ComplexFusionOp) { EXPECT_THAT(c1->users(), ElementsAre(fusion)); } -// Convenience function for comparing two HloInstructions inside of -// std::unique_ptrs. -static bool Identical(std::unique_ptr instruction1, - std::unique_ptr instruction2) { +// Convenience function for comparing two HloInstructions. +static bool Identical(const HloInstruction& instruction1, + const HloInstruction& instruction2) { // Verify Identical is reflexive for both instructions. 
- EXPECT_TRUE(instruction1->Identical(*instruction1)); - EXPECT_TRUE(instruction2->Identical(*instruction2)); + EXPECT_TRUE(instruction1.Identical(instruction1)); + EXPECT_TRUE(instruction2.Identical(instruction2)); - bool is_equal = instruction1->Identical(*instruction2); + bool is_equal = instruction1.Identical(instruction2); // Verify Identical is symmetric. - EXPECT_EQ(is_equal, instruction2->Identical(*instruction1)); + EXPECT_EQ(is_equal, instruction2.Identical(instruction1)); + return is_equal; +} + +// Convenience function for comparing two HloInstructions for structural +// equality. +static bool StructuralEqual(const HloInstruction& instruction1, + const HloInstruction& instruction2) { + auto eq_operand_shapes = [](const HloInstruction* a, + const HloInstruction* b) { + return ShapeUtil::Equal(a->shape(), b->shape()); + }; + auto eq_computations = [](const HloComputation* a, const HloComputation* b) { + return *a == *b; + }; + + // Verify Identical is reflexive for both instructions. + EXPECT_TRUE( + instruction1.Identical(instruction1, eq_operand_shapes, eq_computations)); + EXPECT_TRUE( + instruction2.Identical(instruction2, eq_operand_shapes, eq_computations)); + + bool is_equal = + instruction1.Identical(instruction2, eq_operand_shapes, eq_computations); + // Verify Identical is symmetric. + EXPECT_EQ(is_equal, instruction2.Identical(instruction1, eq_operand_shapes, + eq_computations)); return is_equal; } @@ -858,42 +883,42 @@ TEST_F(HloInstructionTest, IdenticalInstructions) { // Operations which only depend on their operands and opcode. 
EXPECT_TRUE( - Identical(HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1), - HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1))); + Identical(*HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1), + *HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1))); EXPECT_FALSE( - Identical(HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1), - HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op2))); + Identical(*HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1), + *HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op2))); EXPECT_FALSE( - Identical(HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1), - HloInstruction::CreateUnary(shape, HloOpcode::kNegate, op1))); + Identical(*HloInstruction::CreateUnary(shape, HloOpcode::kCopy, op1), + *HloInstruction::CreateUnary(shape, HloOpcode::kNegate, op1))); // Tuples. - EXPECT_TRUE(Identical(HloInstruction::CreateTuple({op1, op2}), - HloInstruction::CreateTuple({op1, op2}))); - EXPECT_FALSE(Identical(HloInstruction::CreateTuple({op1, op2}), - HloInstruction::CreateTuple({op2, op1}))); + EXPECT_TRUE(Identical(*HloInstruction::CreateTuple({op1, op2}), + *HloInstruction::CreateTuple({op1, op2}))); + EXPECT_FALSE(Identical(*HloInstruction::CreateTuple({op1, op2}), + *HloInstruction::CreateTuple({op2, op1}))); // Broadcasts. 
- EXPECT_TRUE(Identical(HloInstruction::CreateBroadcast(shape, op1, {0, 1}), - HloInstruction::CreateBroadcast(shape, op1, {0, 1}))); - EXPECT_FALSE(Identical(HloInstruction::CreateBroadcast(shape, op1, {0, 1}), - HloInstruction::CreateBroadcast(shape, op1, {1, 0}))); + EXPECT_TRUE(Identical(*HloInstruction::CreateBroadcast(shape, op1, {0, 1}), + *HloInstruction::CreateBroadcast(shape, op1, {0, 1}))); + EXPECT_FALSE(Identical(*HloInstruction::CreateBroadcast(shape, op1, {0, 1}), + *HloInstruction::CreateBroadcast(shape, op1, {1, 0}))); Shape bcast_shape1 = ShapeUtil::MakeShape(F32, {2, 2, 42}); Shape bcast_shape2 = ShapeUtil::MakeShape(F32, {2, 2, 123}); EXPECT_FALSE( - Identical(HloInstruction::CreateBroadcast(bcast_shape1, op1, {0, 1}), - HloInstruction::CreateBroadcast(bcast_shape2, op1, {0, 1}))); + Identical(*HloInstruction::CreateBroadcast(bcast_shape1, op1, {0, 1}), + *HloInstruction::CreateBroadcast(bcast_shape2, op1, {0, 1}))); // Binary operands. EXPECT_TRUE(Identical( - HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2), - HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2))); + *HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2), + *HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2))); EXPECT_FALSE(Identical( - HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2), - HloInstruction::CreateBinary(shape, HloOpcode::kDivide, op2, op1))); + *HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2), + *HloInstruction::CreateBinary(shape, HloOpcode::kDivide, op2, op1))); EXPECT_FALSE(Identical( - HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2), - HloInstruction::CreateBinary(shape, HloOpcode::kDivide, op1, op2))); + *HloInstruction::CreateBinary(shape, HloOpcode::kAdd, op1, op2), + *HloInstruction::CreateBinary(shape, HloOpcode::kDivide, op1, op2))); } TEST_F(HloInstructionTest, FunctionVisitor) { @@ -1068,8 +1093,11 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { 
builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); HloInstruction* reshape = builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateBinary(sout, HloOpcode::kDot, x, reshape)); + HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); HloModule module(TestName()); auto* computation = module.AddEntryComputation(builder.Build()); @@ -1086,49 +1114,71 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) { ShapeUtil::Equal(root->operand(1)->shape(), root2->operand(1)->shape())); EXPECT_TRUE(ShapeUtil::Equal(root->operand(1)->operand(0)->shape(), root2->operand(1)->operand(0)->shape())); + EXPECT_TRUE(StructuralEqual(*fusion, *fusion2)); } -TEST_F(HloInstructionTest, IsRandomFusable) { - auto shape = ShapeUtil::MakeShape(F32, {2, 2}); - { - auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); - auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(0.0))); - auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(1.0))); - auto rng = builder.AddInstruction(HloInstruction::CreateRng( - shape, RandomDistribution::RNG_NORMAL, {const0, const1})); - - auto* computation = hlo_module->AddEntryComputation(builder.Build()); - computation->CreateFusionInstruction({rng, const0, const1}, - HloInstruction::FusionKind::kLoop); - - auto* root = computation->root_instruction(); - - EXPECT_EQ(HloOpcode::kFusion, root->opcode()); - } - { - auto builder = HloComputation::Builder(TestName()); - auto hlo_module = CreateNewModule(); - auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(0.0))); - auto const1 = builder.AddInstruction(HloInstruction::CreateConstant( - Literal::CreateR0(1.0))); - auto rng = 
builder.AddInstruction(HloInstruction::CreateRng( - shape, RandomDistribution::RNG_NORMAL, {const0, const1})); - builder.AddInstruction(HloInstruction::CreateUnary( - shape, HloOpcode::kNegate, rng)); - auto* computation = hlo_module->AddEntryComputation(builder.Build()); - computation->CreateFusionInstruction({rng, const0, const1}, - HloInstruction::FusionKind::kLoop); - - auto* root = computation->root_instruction(); - - EXPECT_EQ(HloOpcode::kFusion, root->operand(0)->opcode()); - } +TEST_F(HloInstructionTest, FusionEquality) { + HloModule module(TestName()); + HloComputation::Builder builder(TestName()); + + // Create two fusion instructions containing a single unary operation. + auto parameter = + builder.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x")); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32_, HloOpcode::kExp, parameter)); + auto neg = builder.AddInstruction( + HloInstruction::CreateUnary(r0f32_, HloOpcode::kNegate, parameter)); + auto* computation = module.AddEntryComputation(builder.Build()); + auto* fusion = computation->CreateFusionInstruction( + {exp}, HloInstruction::FusionKind::kLoop); + auto* fusion2 = computation->CreateFusionInstruction( + {neg}, HloInstruction::FusionKind::kLoop); + EXPECT_FALSE(StructuralEqual(*fusion, *fusion2)); + + auto clone = fusion->Clone(); + EXPECT_TRUE(StructuralEqual(*fusion, *clone)); } +TEST_F(HloInstructionTest, NestedFusionEquality) { + HloModule module(TestName()); + HloComputation::Builder builder(TestName()); + + // Build a nested fusion computation. 
+ Shape data_shape = ShapeUtil::MakeShape(F32, {2, 2}); + auto a = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{1.0, 0.0}, {0.0, 1.0}}))); + auto b = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + auto b_t = builder.AddInstruction( + HloInstruction::CreateTranspose(data_shape, b, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto dot = builder.AddInstruction( + HloInstruction::CreateDot(data_shape, a, b_t, dot_dnums)); + auto one = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto add_operand = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape, one, {1})); + auto add = builder.AddInstruction(HloInstruction::CreateBinary( + data_shape, HloOpcode::kAdd, dot, add_operand)); + auto sub = builder.AddInstruction(HloInstruction::CreateBinary( + data_shape, HloOpcode::kSubtract, dot, add_operand)); + builder.AddInstruction( + HloInstruction::CreateBinary(data_shape, HloOpcode::kMultiply, add, sub)); + auto computation = module.AddEntryComputation(builder.Build()); + + auto nested_fusion = computation->CreateFusionInstruction( + {dot, b_t}, HloInstruction::FusionKind::kTransposeDot); + + auto fusion = computation->CreateFusionInstruction( + {add, nested_fusion}, HloInstruction::FusionKind::kOutput); + auto fusion2 = computation->CreateFusionInstruction( + {sub, nested_fusion}, HloInstruction::FusionKind::kOutput); + auto clone = fusion->Clone(); + EXPECT_TRUE(StructuralEqual(*fusion, *clone)); + EXPECT_FALSE(StructuralEqual(*fusion, *fusion2)); +} TEST_F(HloInstructionTest, CloneSuffixNames) { // Test that the suffix string added to cloned instructions is not @@ -1169,7 +1219,7 @@ TEST_F(HloInstructionTest, CloneSuffixNames) { } TEST_F(HloInstructionTest, Stringification) { - // Tests stringification of a simple op, fusion, and while. 
+ // Tests stringification of a simple op, fusion, while, and conditional. const Shape s1 = ShapeUtil::MakeShape(F32, {5, 10}); const Shape s2 = ShapeUtil::MakeShape(F32, {20, 10}); const Shape s2t = ShapeUtil::MakeShape(F32, {10, 20}); @@ -1182,12 +1232,17 @@ TEST_F(HloInstructionTest, Stringification) { builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y")); HloInstruction* reshape = builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateBinary(sout, HloOpcode::kDot, x, reshape)); + HloInstruction::CreateDot(sout, x, reshape, dot_dnums)); + + auto options = HloPrintOptions().set_print_metadata(false); - EXPECT_EQ(dot->ToString(false, false), + EXPECT_EQ(dot->ToString(options), "%dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} " - "%transpose)"); + "%transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}"); HloModule module(TestName()); auto* computation = module.AddEntryComputation(builder.Build()); @@ -1195,15 +1250,25 @@ TEST_F(HloInstructionTest, Stringification) { {dot, reshape}, HloInstruction::FusionKind::kTransposeDot); EXPECT_EQ( - fusion->ToString(false, false), - "%fusion = f32[5,20]{1,0} fusion(f32[5,10]{1,0} %x, " + fusion->ToString(options), + "%dot_fusion = f32[5,20]{1,0} fusion(f32[5,10]{1,0} %x, " "f32[20,10]{1,0} %y), kind=kTransposeDot, calls=%fused_computation"); HloInstruction* loop = builder.AddInstruction( HloInstruction::CreateWhile(sout, computation, computation, x)); - EXPECT_EQ(loop->ToString(false, false), + EXPECT_EQ(loop->ToString(options), "%while = f32[5,20]{1,0} while(f32[5,10]{1,0} %x), " "condition=%TransposeDot, body=%TransposeDot"); + + auto pred = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))); + HloInstruction* conditional = + 
builder.AddInstruction(HloInstruction::CreateConditional( + sout, pred, x, computation, x, computation)); + EXPECT_EQ(conditional->ToString(options), + "%conditional = f32[5,20]{1,0} conditional(pred[] %constant, " + "f32[5,10]{1,0} %x, f32[5,10]{1,0} %x), " + "true_computation=%TransposeDot, false_computation=%TransposeDot"); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_matchers.cc b/tensorflow/compiler/xla/service/hlo_matchers.cc index 4255d6086625dfb9a045e4431e968a5ee0106ac7..bc74c4bc10cad20eab20b5caf8550b17048a5276 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers.cc @@ -102,6 +102,36 @@ bool HloGetTupleElementMatcher::MatchAndExplain( return true; } +void HloCustomCallMatcher::DescribeTo(std::ostream* os) const { + HloMatcher::DescribeTo(os); + *os << " with call target that "; + call_target_matcher_.DescribeTo(os); +} + +bool HloCustomCallMatcher::MatchAndExplain( + const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const { + if (!HloMatcher::MatchAndExplain(instruction, listener)) { + return false; + } + ::testing::StringMatchResultListener sub_listener; + bool result = ExplainMatchResult( + call_target_matcher_, instruction->custom_call_target(), &sub_listener); + if (sub_listener.str().empty()) { + sub_listener << " that "; + + std::stringstream desc_stream; + if (result) { + call_target_matcher_.DescribeTo(&desc_stream); + } else { + call_target_matcher_.DescribeNegationTo(&desc_stream); + } + sub_listener << desc_stream.str(); + } + *listener << "custom-call with call target" << sub_listener.str(); + return result; +} + } // namespace testing void PrintTo(const HloInstruction* inst, ::std::ostream* os) { diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index 992f55788b4900949f4994ba5b7be015bcd0d3de..103f04a2cb7a1a5ae877d8bf259692f7cbed3408 100644 --- 
a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -56,8 +56,8 @@ class HloParameterMatcher : public HloMatcher { // index to match. class HloGetTupleElementMatcher : public HloMatcher { public: - explicit HloGetTupleElementMatcher( - ::testing::Matcher operand, int64 tuple_index) + HloGetTupleElementMatcher(::testing::Matcher operand, + int64 tuple_index) : HloMatcher(HloOpcode::kGetTupleElement, /*operands=*/{operand}), tuple_index_(tuple_index) {} @@ -68,6 +68,24 @@ class HloGetTupleElementMatcher : public HloMatcher { int64 tuple_index_; }; +// Custom matcher for custom-call instructions, which accepts a matcher for its +// call target. +class HloCustomCallMatcher : public HloMatcher { + public: + HloCustomCallMatcher( + ::testing::Matcher call_target_matcher, + std::vector<::testing::Matcher> operands) + : HloMatcher(HloOpcode::kCustomCall, operands), + call_target_matcher_(call_target_matcher) {} + + bool MatchAndExplain(const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const override; + void DescribeTo(std::ostream* os) const override; + + private: + ::testing::Matcher call_target_matcher_; +}; + // HloInstruction* matchers for opcode and operands. Example: // namespace op = xla::opcode_matchers; // EXPECT_THAT(instruction, @@ -83,6 +101,7 @@ HLO_MATCHER(Abs); HLO_MATCHER(Add); HLO_MATCHER(Bitcast); HLO_MATCHER(Broadcast); +HLO_MATCHER(BatchNormGrad); HLO_MATCHER(Call); HLO_MATCHER(Ceil); HLO_MATCHER(Clamp); @@ -93,7 +112,6 @@ HLO_MATCHER(Convert); HLO_MATCHER(Convolution); HLO_MATCHER(Copy); HLO_MATCHER(CrossReplicaSum); -HLO_MATCHER(CustomCall); HLO_MATCHER(Divide); HLO_MATCHER(Dot); HLO_MATCHER(DynamicSlice); @@ -183,6 +201,36 @@ inline ::testing::Matcher GetTupleElement() { new ::xla::testing::HloMatcher(HloOpcode::kGetTupleElement, {})); } +// - CustomCall(T, operand1, ..., operandN) matches a CustomCall with call +// target T and the given operands. 
+// +// - CustomCall(operand1, ..., operandN) matches any CustomCall HLO with the +// given operands. +// +// - CustomCall() matches any CustomCall HLO at all. +template +inline ::testing::Matcher CustomCall( + ::testing::Matcher call_target_matcher, M... operands) { + return ::testing::MakeMatcher(new ::xla::testing::HloCustomCallMatcher( + call_target_matcher, {operands...})); +} +// This overload of CustomCall(A, B, C, ...) exists iff A is not convertible to +// ::testing::Matcher. In that case, we want to prefer the overload +// above. +template >::value, + void>::type*> +inline ::testing::Matcher CustomCall( + FirstM operands_first, M... operands_rest) { + return ::testing::MakeMatcher(new ::xla::testing::HloMatcher( + HloOpcode::kCustomCall, {operands_first, operands_rest...})); +} +inline ::testing::Matcher CustomCall() { + return ::testing::MakeMatcher( + new ::xla::testing::HloMatcher(HloOpcode::kCustomCall, {})); +} + #undef HLO_MATCHER } // namespace opcode_matchers diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc index 1465d1cacdc971a04c620bc48bed33239a67a955..1c21703a45e11914854153bc14fabd85e9ea57f2 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers_test.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc @@ -23,6 +23,12 @@ using ::testing::Eq; namespace xla { namespace { +string DescribeHloMatcher(const ::testing::Matcher& m) { + std::stringstream ss; + m.DescribeTo(&ss); + return ss.str(); +} + template string Explain(const T& t, const M& m) { ::testing::StringMatchResultListener listener; @@ -67,5 +73,32 @@ TEST(HloMatchersTest, Test) { "add")); } +TEST(HloMatchersTest, CustomCallMatcher) { + auto c1 = HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3})); + auto c2 = HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3})); + auto call = HloInstruction::CreateCustomCall( + ShapeUtil::MakeShape(F32, {1}), {c1.get(), c2.get()}, "foo_target"); + + 
EXPECT_THAT(call.get(), op::CustomCall()); + EXPECT_THAT(call.get(), op::CustomCall(c1.get(), c2.get())); + EXPECT_THAT(call.get(), op::CustomCall("foo_target")); + EXPECT_THAT(call.get(), op::CustomCall("foo_target", c1.get(), c2.get())); + EXPECT_THAT(call.get(), op::CustomCall(::testing::StartsWith("foo"))); + EXPECT_THAT(call.get(), + op::CustomCall(::testing::Not(::testing::StartsWith("bar")))); + + // Wrong number of operands. + EXPECT_THAT(call.get(), ::testing::Not(op::CustomCall(c1.get()))); + + // Call target does not match. + EXPECT_THAT(call.get(), + ::testing::Not(op::CustomCall(::testing::StartsWith("bar")))); + + EXPECT_THAT(Explain(call.get(), op::CustomCall("bar")), + R"(custom-call with call target that isn't equal to "bar")"); + EXPECT_THAT(DescribeHloMatcher(op::CustomCall("foo_target")), + R"(custom-call with call target that is equal to "foo_target")"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index faaf73ea1ce5c77b0522cb3276b4efd78aabde16..60270b0595dcfca8f1fcea5ab0914428880f35b5 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -35,14 +35,19 @@ namespace xla { HloModule::HloModule(const string& name, const VersionedComputationHandle& entry_computation_handle, const HloModuleConfig& config) - : name_(name), + : name_(NameUniquer::GetSanitizedName(name)), config_(config), has_entry_computation_handle_(true), - entry_computation_handle_(entry_computation_handle) {} + entry_computation_handle_(entry_computation_handle), + unique_id_(next_unique_module_id_++) {} -HloModule::HloModule(const string& name) : name_(name) {} +HloModule::HloModule(const string& name) + : name_(NameUniquer::GetSanitizedName(name)), + unique_id_(next_unique_module_id_++) {} HloModule::HloModule(const string& name, const HloModuleConfig& config) - : name_(name), config_(config) {} + : 
name_(NameUniquer::GetSanitizedName(name)), + config_(config), + unique_id_(next_unique_module_id_++) {} HloComputation* HloModule::AddComputationInternal( std::unique_ptr computation, bool is_entry, @@ -170,17 +175,14 @@ void HloModule::ReplaceComputations( computations_ = std::move(new_computations); } -string HloModule::ToString(bool include_large_constants) const { +string HloModule::ToString(const HloPrintOptions& options) const { std::ostringstream s; - s << "HloModule " << name() << ":\n\n"; + s << "HloModule " << name() << "\n\n"; for (const HloComputation* computation : MakeComputationPostOrder()) { if (computation == entry_computation()) { s << "ENTRY "; } - s << computation->ToString( - /*nested_level=*/0, - /*include_large_constants=*/include_large_constants) - << "\n\n"; + s << computation->ToString(options) << "\n\n"; } return s.str(); } @@ -232,8 +234,8 @@ StatusOr ProgramShapeFromProto(const HloModuleProto& module) { << "Entry computation has more than one parameter instruction " "with parameter number " << instruction.parameter_number(); - parameters[instruction.parameter_number()] = { - instruction.parameter_name(), &instruction.shape()}; + parameters[instruction.parameter_number()] = {instruction.name(), + &instruction.shape()}; } } TF_RET_CHECK(root != nullptr) @@ -459,6 +461,14 @@ HloInstruction* HloModule::OutlineExpressionFromComputation( return call; } +int64 HloModule::instruction_count() const { + int64 n = 0; + for (const auto& computation : computations_) { + n += computation->instruction_count(); + } + return n; +} + std::list HloModule::MakeComputationPostOrder() const { // First determine all root computations by building a set of nonroot // computations (computations which are called by an instruction in the @@ -517,7 +527,15 @@ std::unique_ptr HloModule::Clone(const string& suffix) const { std::unordered_map clone_map; for (auto& computation : computations_) { - auto cloned_computation = computation->Clone(suffix); + if 
(computation->IsFusionComputation()) { + // Cloning of a fused computation is handled by its fusion instruction. + continue; + } + + // When cloning a computation, pass in the new module, so that for any + // fusion instruction in this computation, the fused computation will be + // deep cloned to the new module. + auto cloned_computation = computation->Clone(suffix, module.get()); InsertOrDie(&clone_map, computation.get(), cloned_computation.get()); if (entry_computation_ == computation.get()) { @@ -531,8 +549,15 @@ std::unique_ptr HloModule::Clone(const string& suffix) const { for (auto* instruction : cloned_computation->instructions()) { // Rewrite instruction's called_computation to point to the cloned // computations. - instruction->ReplaceCalledComputations( - [&](HloComputation* hlo) { return FindOrDie(clone_map, hlo); }); + instruction->ReplaceCalledComputations([&](HloComputation* hlo) { + if (hlo->IsFusionComputation()) { + // Cloning of a fused computation has already been handled when its + // fusion instruction is cloned. So this hlo computation is already + // the cloned one. + return hlo; + } + return FindOrDie(clone_map, hlo); + }); } } return module; @@ -543,4 +568,6 @@ uint64 HloModule::RandomNew64() const { return rng_(); } +/* static */ std::atomic HloModule::next_unique_module_id_(0); + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index 5141e7bc8d4cf0ef4cd83310772e0c5d66b5da12..4bfe8d89ce0a285de6d05d4867aaa6b266d78d12 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -16,6 +16,7 @@ limitations under the License. 
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_H_ +#include #include #include #include @@ -98,6 +99,10 @@ class HloModule { return config_.mutable_entry_computation_layout(); } + ComputationLayout entry_computation_layout() const { + return config_.entry_computation_layout(); + } + const VersionedComputationHandle& entry_computation_handle() const { return entry_computation_handle_; } @@ -125,6 +130,9 @@ class HloModule { // Gets the number of computations in this module. int64 computation_count() const { return computations_.size(); } + // Gets the number of instructions in this module. + int64 instruction_count() const; + // Compute and return a post order of all computations in the module. The sort // is defined like so: if computation A has an instruction which calls // computation B, then A will appear after B in the sort. @@ -143,7 +151,12 @@ class HloModule { const HloModuleConfig& config() const { return config_; } - string ToString(bool include_large_constants = false) const; + // Return a string representation of the module. + // + // (We express the default options using an overload rather than a default + // param because gdb ignores default params, but does resolve overloads.) + string ToString() const { return ToString(HloPrintOptions()); } + string ToString(const HloPrintOptions& options) const; // Convert an HloModule to or from a proto. HloModuleProto ToProto() const; @@ -189,6 +202,10 @@ class HloModule { // this point are guaranteed to be in the range [0..NumUniqueInstructionIds()) int NumUniqueInstructionIds() const { return next_unique_id_; } + // Returns an id that is unique to this module across all modules created over + // the lifetime of this process. 
+ int unique_id() const { return unique_id_; } + private: HloComputation* AddComputationInternal( std::unique_ptr computation, bool is_entry, @@ -215,6 +232,11 @@ class HloModule { NameUniquer computation_name_uniquer_{/*separator=*/"."}; NameUniquer instruction_name_uniquer_{/*separator=*/"."}; int next_unique_id_ = 0; + + // Used to keep track of the next unique module id that should be assigned. + static std::atomic next_unique_module_id_; + // A unique id to label modules with. + int unique_id_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc index bf6440d66cac0d3a929c377202b212aba262f887..7f28a804bfec9c2f1bbb5fa08f7dd4e68be14d35 100644 --- a/tensorflow/compiler/xla/service/hlo_module_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_test.cc @@ -105,6 +105,48 @@ TEST_F(HloModuleTest, CloneTest) { } } +TEST_F(HloModuleTest, CloneHasFusion) { + auto module = CreateNewModule(); + + // Create the fused computation. + HloComputation* fused_computation; + { + auto b = HloComputation::Builder("Fused"); + auto x = b.AddInstruction(HloInstruction::CreateParameter(0, r0f32_, "x")); + b.AddInstruction( + HloInstruction::CreateBinary(r0f32_, HloOpcode::kAdd, x, x)); + fused_computation = module->AddEmbeddedComputation(b.Build()); + } + + // Create the entry computation. 
+ { + auto b = HloComputation::Builder("Entry"); + auto input = b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0f))); + b.AddInstruction( + HloInstruction::CreateFusion(r0f32_, HloInstruction::FusionKind::kInput, + /*operands=*/{input}, fused_computation)); + module->AddEntryComputation(b.Build()); + } + + auto post_order = module->MakeComputationPostOrder(); + auto cloned_module = module->Clone("copy"); + auto post_order_copied = cloned_module->MakeComputationPostOrder(); + + EXPECT_EQ(post_order.size(), post_order_copied.size()); + for (auto origin = post_order.begin(), copied = post_order_copied.begin(); + origin != post_order.end() && copied != post_order_copied.end(); + ++origin, ++copied) { + if ((*origin)->name() == "Fused") { + // Clone of the fused computation is handled when its fusion instruction + // is cloned, which always uses the suffix ".clone". + EXPECT_EQ((*origin)->name() + ".clone", (*copied)->name()); + } else { + EXPECT_EQ((*origin)->name() + ".copy", (*copied)->name()); + } + } +} + TEST_F(HloModuleTest, DiamondComputationsPostOrder) { // Create a module with a diamond call graph of computations. 
auto module = CreateNewModule(); @@ -135,14 +177,21 @@ TEST_F(HloModuleTest, LargeConstantToString) { module->AddEntryComputation(builder.Build()); EXPECT_EQ( - "HloModule LargeConstantToString:\n\nENTRY %Constant () -> f32[16] {\n " + "HloModule LargeConstantToString\n\nENTRY %Constant () -> f32[16] {\n " "ROOT %constant = f32[16]{0} constant({...})\n}\n\n", - module->ToString(/*include_large_constants=*/false)); + module->ToString(HloPrintOptions().set_print_large_constants(false))); + EXPECT_EQ( - "HloModule LargeConstantToString:\n\nENTRY %Constant () -> f32[16] {\n " + "HloModule LargeConstantToString\n\nENTRY %Constant () -> f32[16] {\n " "ROOT %constant = f32[16]{0} constant({42, 42, 42, 42, 42, 42, 42, 42, " "42, 42, 42, 42, 42, 42, 42, 42})\n}\n\n", - module->ToString(/*include_large_constants=*/true)); + module->ToString(HloPrintOptions().set_print_large_constants(true))); +} + +TEST_F(HloModuleTest, UniqueModuleId) { + auto module_a = CreateNewModule(); + auto module_b = CreateNewModule(); + EXPECT_NE(module_a->unique_id(), module_b->unique_id()); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index f3f79357582ac7661a532e94031acdbca0b86784..3d64523a79fc50638fdf378b5d521a5cd4482b90 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -73,6 +73,7 @@ namespace xla { V(kDynamicUpdateSlice, "dynamic-update-slice") \ V(kEq, "equal-to", kHloOpcodeIsComparison) \ V(kExp, "exponential") \ + V(kFft, "fft") \ V(kFloor, "floor") \ V(kFusion, "fusion", kHloOpcodeIsVariadic) \ V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \ diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 6f6e679a21870e46da85963c3b2998465ac43420..68e3c9618c1fe9daacb0aee3ee98862c8b9e4bc4 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ 
b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -249,7 +249,7 @@ bool PredecessorHloOrdering::ExecutesBeforeInSameComputation( string PredecessorHloOrdering::ToStringHelper(const string& name) const { std::vector pieces; pieces.push_back(name); - for (auto* computation : module_->computations()) { + for (auto* computation : module_->MakeNonfusionComputations()) { pieces.push_back(tensorflow::strings::Printf("computation %s:", computation->name().c_str())); const auto all = computation->MakeInstructionPostOrder(); diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc index 33bafd05c15c47abaa313f92eb53a791de43d7d9..aba66114de649ce7667ae77174e9c4073b010b90 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -310,5 +311,56 @@ TEST_F(HloOrderingTest, ValuesInWhileComputations) { *dataflow)); } +// Regression test for HloOrdering::ToString() crashing when fed a computation +// containing a fusion node. 
+TEST_F(HloOrderingTest, ToStringDoesNotCrash) { + const char* module_str = R"( +HloModule test_module + +body.v8 { + prev.1 = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) parameter(0) + get-tuple-element.4 = s32[] get-tuple-element(prev.1), index=0 + constant.1 = s32[] constant(1) + add = s32[] add(get-tuple-element.4, constant.1) + get-tuple-element.5 = f32[3]{0} get-tuple-element(prev.1), index=3 + get-tuple-element.6 = f32[3]{0} get-tuple-element(prev.1), index=1 + get-tuple-element.7 = f32[3]{0} get-tuple-element(prev.1), index=2 + ROOT tuple = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) tuple(add, get-tuple-element.5, get-tuple-element.6, get-tuple-element.7) +} + +condition.v4 { + constant.2 = s32[] constant(2) + prev.2 = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) parameter(0) + get-tuple-element.8 = s32[] get-tuple-element(prev.2), index=0 + ROOT greater-than = pred[] greater-than(constant.2, get-tuple-element.8) +} + +fused_computation { + get-tuple-element.5.param_1 = f32[3]{0} parameter(1) + get-tuple-element.6.param_2 = f32[3]{0} parameter(2) + add.4 = f32[3]{0} add(get-tuple-element.5.param_1, get-tuple-element.6.param_2) + get-tuple-element.7.param_1.1 = f32[3]{0} parameter(0) + ROOT add.5 = f32[3]{0} add(add.4, get-tuple-element.7.param_1.1) +} + +ENTRY while.v11 { + constant.5 = s32[] constant(0) + constant.6 = f32[3]{0} constant({1, 1, 1}) + constant.7 = f32[3]{0} constant({2, 2, 2}) + constant.8 = f32[3]{0} constant({3, 3, 3}) + tuple.1 = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) tuple(constant.5, constant.6, constant.7, constant.8) + while = (s32[], f32[3]{0}, f32[3]{0}, f32[3]{0}) while(tuple.1), condition=condition.v4, body=body.v8 + get-tuple-element.9 = f32[3]{0} get-tuple-element(while), index=3 + get-tuple-element.10 = f32[3]{0} get-tuple-element(while), index=1 + get-tuple-element.11 = f32[3]{0} get-tuple-element(while), index=2 + ROOT fusion = f32[3]{0} fusion(get-tuple-element.9, get-tuple-element.10, get-tuple-element.11), kind=kLoop, 
calls=fused_computation +})"; + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + tools::Parse(module_str)); + DependencyHloOrdering ordering(module.get()); + ordering.ToString(); // Shouldn't crash. +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index 53bd46a641afcba1b9551895955742e74a9f374b..5120775737bfa32bbb656421216f2b3fbef590ea 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "tensorflow/compiler/xla/service/hlo_proto_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" @@ -32,12 +33,28 @@ using ::tensorflow::strings::StrCat; namespace xla { namespace { -void DumpModule(const HloModule& module, - const string& message) { +void DumpModuleGraph(const HloModule& module, const string& message) { hlo_graph_dumper::MaybeDumpHloModule(module, message); VLOG(3) << "HLO " << message << ":"; XLA_VLOG_LINES(3, module.ToString()); } + +void DumpModuleProto(const HloModule& module, const string& dump_to, + const string& pipeline_name, const string& pass_name) { + static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); + static auto* const module_id_to_pass_number = + new tensorflow::gtl::FlatMap(); + + tensorflow::mutex_lock lock(mu); + const int64 pass_number = (*module_id_to_pass_number)[module.unique_id()]++; + + const string mod_name = SanitizeFileName(tensorflow::strings::Printf( + "module_%04d.%04lld.%s.after_%s", module.unique_id(), pass_number, + pipeline_name.c_str(), pass_name.c_str())); + + TF_QCHECK_OK(protobuf_util::DumpProtoToDirectory(MakeHloProto(module), + dump_to, mod_name)); +} } // namespace StatusOr HloPassPipeline::Run(HloModule* module) { 
@@ -78,6 +95,13 @@ StatusOr HloPassPipeline::Run(HloModule* module) { string message; TF_RETURN_IF_ERROR( run_invariant_checkers(StrCat("before running pipeline: ", name()))); + const string xla_dump_per_pass_hlo_proto_to = + module->config().debug_options().xla_dump_per_pass_hlo_proto_to(); + if (!xla_dump_per_pass_hlo_proto_to.empty()) { + DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, name().ToString(), + "pipeline_start"); + } + for (auto& pass : passes_) { if (disabled_passes.count(pass->name().ToString()) > 0) { VLOG(1) << " Skipping HLO pass " << pass->name() @@ -90,17 +114,21 @@ StatusOr HloPassPipeline::Run(HloModule* module) { // Emit label containing: "after foo-pass, before bar-pass". message.clear(); StrAppend(&message, prefix, ", before ", pass->name()); - DumpModule(*module, message); + DumpModuleGraph(*module, message); TF_ASSIGN_OR_RETURN(bool changed_this_pass, pass->Run(module)); TF_RETURN_IF_ERROR( run_invariant_checkers(StrCat("after running pass: ", pass->name()))); + if (!xla_dump_per_pass_hlo_proto_to.empty()) { + DumpModuleProto(*module, xla_dump_per_pass_hlo_proto_to, + name().ToString(), pass->name().ToString()); + } changed |= changed_this_pass; prefix.clear(); StrAppend(&prefix, name(), ": after ", pass->name()); } - DumpModule(*module, prefix + ", pipeline end"); + DumpModuleGraph(*module, prefix + ", pipeline end"); return changed; } diff --git a/tensorflow/compiler/xla/service/hlo_profile_printer.cc b/tensorflow/compiler/xla/service/hlo_profile_printer.cc index e944ad15139af0d2f98e8e68d3d48303f47ecf1c..dcc22793015147aaf3229875078b2989e4ef7559 100644 --- a/tensorflow/compiler/xla/service/hlo_profile_printer.cc +++ b/tensorflow/compiler/xla/service/hlo_profile_printer.cc @@ -18,20 +18,20 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/human_readable_profile_builder.h" namespace xla { -string HloProfilePrinter::ToString(const int64* counters, - double clock_rate_ghz) const { +string PrintHloProfile(const HloProfilePrinterData& hlo_profile_printer_data, + const int64* counters, double clock_rate_ghz) { + using HloComputationInfo = HloProfilePrinterData::HloComputationInfo; + using HloInstructionInfo = HloProfilePrinterData::HloInstructionInfo; + string result; - for (int computation_idx = 0; computation_idx < computation_infos_size_; - computation_idx++) { - const HloComputationInfo& computation = computation_infos_[computation_idx]; - const HloInstructionInfo* instructions_begin = computation.instructions; - const HloInstructionInfo* instructions_end = - computation.instructions + computation.instructions_size; + for (const HloComputationInfo& computation_info : + hlo_profile_printer_data.computation_infos()) { + const auto& instruction_infos = computation_info.instruction_infos(); bool any_instruction_profiled = - std::any_of(instructions_begin, instructions_end, + std::any_of(instruction_infos.begin(), instruction_infos.end(), [&](const HloInstructionInfo& instruction_info) { - return counters[instruction_info.profile_index] != 0; + return counters[instruction_info.profile_index()] != 0; }); if (!any_instruction_profiled) { @@ -41,16 +41,19 @@ string HloProfilePrinter::ToString(const int64* counters, // Once we start using this in AOT for real, we will probably need a more // minimal version of HumanReadableProfileBuilder. 
HumanReadableProfileBuilder builder( - computation.name, counters[computation.profile_index], clock_rate_ghz); + computation_info.name(), counters[computation_info.profile_index()], + clock_rate_ghz); - for (const auto* instruction = instructions_begin; - instruction != instructions_end; instruction++) { + for (const auto& instruction_info : instruction_infos) { builder.AddOp( - /*op_name=*/instruction->long_name, - /*short_name=*/instruction->short_name, instruction->category, - counters[instruction->profile_index], instruction->flop_count, - instruction->transcendental_count, instruction->bytes_accessed, - instruction->optimal_seconds); + /*op_name=*/instruction_info.long_name(), + /*short_name=*/instruction_info.short_name(), + instruction_info.category(), + counters[instruction_info.profile_index()], + instruction_info.flop_count(), + instruction_info.transcendental_count(), + instruction_info.bytes_accessed(), + instruction_info.optimal_seconds()); } result += builder.ToString(); @@ -58,10 +61,4 @@ string HloProfilePrinter::ToString(const int64* counters, return result; } - -HloProfilePrinter::~HloProfilePrinter() { - if (deleter_) { - deleter_(computation_infos_, computation_infos_size_); - } -} } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_profile_printer.h b/tensorflow/compiler/xla/service/hlo_profile_printer.h index 316753a82ab2a9b5459b71c723a8e817ee2cacbf..b72325c7554acad258c2da55a18e5e18ec1b06a6 100644 --- a/tensorflow/compiler/xla/service/hlo_profile_printer.h +++ b/tensorflow/compiler/xla/service/hlo_profile_printer.h @@ -13,85 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PROFILE_PRINTER_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PROFILE_PRINTER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PROFILE_PRINTER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PROFILE_PRINTER_H_ #include #include #include +#include "tensorflow/compiler/xla/service/hlo_profile_printer_data.pb.h" #include "tensorflow/compiler/xla/types.h" namespace xla { -// Instances of this class can pretty-print profile counters gathered from -// running an XLA computation without having access to the backing module. -class HloProfilePrinter { - public: - // Holds meta information about an HloInstruction. - // - // The pointer-typed fields can be owning or non-owning -- this decision is - // manifested as the deleter_ function in the containing HloProfilePrinter. - struct HloInstructionInfo { - // Textual information for pretty printing. - const char* long_name; - const char* short_name; - const char* category; - - // Metrics computed by HloCostAnalysis. - float flop_count; - float transcendental_count; - float bytes_accessed; - float optimal_seconds; - - // The index into the profile counters array for the HloInstruction - // corresponding to this HloInstructionInfo. - int64 profile_index; - }; - - // Holds meta information about an HloComputation. - // - // The pointer-typed fields can be owning or non-owning -- this decision is - // manifested as the deleter_ function in the containing HloProfilePrinter. - struct HloComputationInfo { - const char* name; - - // The index into the profile counters array for the HloInstruction - // corresponding to this HloComputationInfo. 
- int64 profile_index; - - HloInstructionInfo* instructions; - int64 instructions_size; - }; - - HloProfilePrinter( - HloComputationInfo* computation_infos, int64 computation_infos_size, - std::function deleter = nullptr) - : computation_infos_(computation_infos), - computation_infos_size_(computation_infos_size), - deleter_(std::move(deleter)) {} - - HloProfilePrinter(HloProfilePrinter&& other) { - std::swap(other.computation_infos_, computation_infos_); - std::swap(other.computation_infos_size_, computation_infos_size_); - std::swap(other.deleter_, deleter_); - } - - HloProfilePrinter(const HloProfilePrinter&) = delete; - HloProfilePrinter& operator=(const HloProfilePrinter&) = delete; - - // Convert the profile counter sequence `counters` to a human readable string - // representation. - string ToString(const int64* counters, double clock_rate_ghz) const; - - ~HloProfilePrinter(); - - private: - // The `computation_infos_` field can be owning or non-owning -- this decision - // is manifested as the deleter_ function. - HloComputationInfo* computation_infos_ = nullptr; - int64 computation_infos_size_ = 0; - std::function deleter_; -}; +// Pretty-print an array of profile counters using hlo_profile_printer_data. +string PrintHloProfile(const HloProfilePrinterData& hlo_profile_printer_data, + const int64* counters, double clock_rate_ghz); } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PROFILE_PRINTER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PROFILE_PRINTER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_profile_printer_data.proto b/tensorflow/compiler/xla/service/hlo_profile_printer_data.proto new file mode 100644 index 0000000000000000000000000000000000000000..9f22b733fe1d676b177039a9d7a3064b8638d7bc --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_profile_printer_data.proto @@ -0,0 +1,60 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +syntax = "proto3"; + +package xla; + +option cc_enable_arenas = true; + +// Describes how to pretty-print a profile counter array gathered for a specific +// HloModule. +message HloProfilePrinterData { + // Pretty-printer information about an HloInstruction. + message HloInstructionInfo { + string long_name = 1; + string short_name = 2; + string category = 3; + + // Metrics computed by HloCostAnalysis. + float flop_count = 4; + float transcendental_count = 5; + float bytes_accessed = 6; + float optimal_seconds = 7; + + // The index into the profile counters array for the HloInstruction + // corresponding to this HloInstructionInfo. + int64 profile_index = 8; + } + + // Pretty-printer information about an HloComputation. + message HloComputationInfo { + string name = 1; + + // The index into the profile counters array for the HloComputation + // corresponding to this HloComputationInfo. + int64 profile_index = 2; + + // HloInstructionInfos for every HloInstruction in the HloComputation + // corresponding to this HloComputationInfo. + repeated HloInstructionInfo instruction_infos = 3; + } + + // HloComputationInfos for every HloComputation in the HloModule. + repeated HloComputationInfo computation_infos = 1; + + // The size of the profile counters array we will pretty-print. 
+ int64 profile_counters_size = 2; +} diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.cc b/tensorflow/compiler/xla/service/hlo_proto_util.cc index 727ad0178c6227cd2e64c31a4618e781671b9393..78e6a101c10a1e812e3e2631d520139fd0bc425c 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util.cc +++ b/tensorflow/compiler/xla/service/hlo_proto_util.cc @@ -19,15 +19,20 @@ namespace xla { HloProto MakeHloProto(const HloModule& module, const BufferAssignment& assignment) { - HloModuleProto proto_module = module.ToProto(); HloOrderingProto proto_ordering = assignment.liveness().hlo_ordering().ToProto(); BufferAssignmentProto proto_assignment = assignment.ToProto(); - HloProto proto; - proto.mutable_hlo_module()->Swap(&proto_module); + HloProto proto = MakeHloProto(module); proto.mutable_hlo_ordering()->Swap(&proto_ordering); proto.mutable_buffer_assignment()->Swap(&proto_assignment); return proto; } +HloProto MakeHloProto(const HloModule& module) { + HloModuleProto proto_module = module.ToProto(); + HloProto proto; + proto.mutable_hlo_module()->Swap(&proto_module); + return proto; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_proto_util.h b/tensorflow/compiler/xla/service/hlo_proto_util.h index 603259a11fcdca59f58653d9a7a164c983711a57..320288fdb9aa0810b306b1d78bd1ff4cfc366ed2 100644 --- a/tensorflow/compiler/xla/service/hlo_proto_util.h +++ b/tensorflow/compiler/xla/service/hlo_proto_util.h @@ -31,6 +31,10 @@ namespace xla { HloProto MakeHloProto(const HloModule& module, const BufferAssignment& assignment); +// Returns a serialized representation of the HLO state, but buffer assignment +// will not be included in the output. 
+HloProto MakeHloProto(const HloModule& module); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_PROTO_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/hlo_reachability.h b/tensorflow/compiler/xla/service/hlo_reachability.h index d7bdac9c86579f19afbba133772c2c50894853d1..553ec11f6f9a2997ab7113f9b8241e04c7fe20d5 100644 --- a/tensorflow/compiler/xla/service/hlo_reachability.h +++ b/tensorflow/compiler/xla/service/hlo_reachability.h @@ -30,11 +30,17 @@ namespace xla { class HloInstruction; -// A class for computing and representing reachability between HloInstructions. +// A class for representing reachability between HloInstructions. +// +// !!! THIS CLASS DOES NOT COMPUTE REACHABILITY !!! It has an adjacency matrix +// and it is up to the user of the class to set the adjacency matrix such that +// it represents reachability, i.e. such that it is transitive. That the graph +// be transitive is thus not an invariant of this class, but it is required for +// the name of the class and its methods to make sense. class HloReachabilityMap { public: - // Sets up an empty reachable matrix for the full set of instructions - // specified in 'instructions'. + // Sets up a graph with no edges and where the nodes correspond to the given + // instructions. explicit HloReachabilityMap(const std::list& instructions); // Set the reachability set of 'instruction' to the union of the reachability @@ -42,17 +48,33 @@ class HloReachabilityMap { // 'x' is not 'instruction' will return true iff IsReachable(x, input) is true // for some 'input' in 'inputs'. Also sets 'instruction' to be reachable from // itself. Returns whether the reachability set of 'instruction' changed. + // + // !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency + // vector in the internal graph of this HloReachabilityMap for the given + // instruction and does not transitively update any other part of the + // adjacency matrix. 
bool SetReachabilityToUnion( tensorflow::gtl::ArraySlice inputs, const HloInstruction* instruction); // Sets entry so that IsReachable(a, b) will return true + // + // !!! THIS FUNCTION DOES NOT COMPUTE REACHABILITY !!! It sets the adjacency + // matrix in the internal graph of this HloReachabilityMap to have an edge + // from a to b and does not transitively update any other part of the + // adjacency matrix. void SetReachable(const HloInstruction* a, const HloInstruction* b); // Returns true if "b" is reachable from "a" + // + // Note that this function only correctly answers queries about reachability + // if the set of edges that have been provided to this class are transitive. bool IsReachable(const HloInstruction* a, const HloInstruction* b) const; // Returns true if "b" is reachable from "a" or "a" is reachable from "b" + // + // Note that this function only correctly answers queries about reachability + // if the set of edges that have been provided to this class are transitive. 
bool IsConnected(const HloInstruction* a, const HloInstruction* b) const; private: diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 017f996bc4d1902c81f96425b7bc28d52622df0f..c6b4dc0368d92fd477decdfb38045f74f8696803 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -566,7 +566,9 @@ Status MemoryUsageTracker::BeginInstruction(Item* item) { VLOG(3) << " memory usage = " << memory_usage_; VLOG(10) << ToString(); - DCHECK(Check()); + if (VLOG_IS_ON(1)) { + DCHECK(Check()); + } return Status::OK(); } @@ -603,8 +605,9 @@ Status MemoryUsageTracker::EndInstruction() { VLOG(3) << " memory usage = " << memory_usage_; VLOG(10) << ToString(); - DCHECK(Check()); - + if (VLOG_IS_ON(1)) { + DCHECK(Check()); + } return Status::OK(); } @@ -1021,7 +1024,9 @@ StatusOr HloRematerialization::RematerializeComputation( HloInstruction* best = best_item->instruction; VLOG(1) << "Rematerializing instruction " << best->name() << " (saving " - << memory_tracker.MemoryReducedIfRematerialized(best_item) << ")"; + << HumanReadableNumBytes( + memory_tracker.MemoryReducedIfRematerialized(best_item)) + << ")"; changed = true; remat_count++; @@ -1101,8 +1106,8 @@ StatusOr HloRematerialization::RematerializeComputation( net_instructions_added++; } - VLOG(3) << "memory_usage after rematerialization = " - << memory_tracker.memory_usage(); + VLOG(1) << "memory_usage after rematerialization = " + << HumanReadableNumBytes(memory_tracker.memory_usage()); } const CallSite* callsite = call_graph_node.GetCallSite(instruction); @@ -1208,11 +1213,12 @@ StatusOr HloRematerialization::Run( XLA_VLOG_LINES(3, "Before HloRematerialization:\n" + module->ToString()); // Create initial sequence of HLO instructions. 
- TF_ASSIGN_OR_RETURN(*sequence, - CreateMemoryMinimizingSequence( - *module, [this](const LogicalBuffer& buffer) { - return size_function_(buffer.shape()); - })); + TF_ASSIGN_OR_RETURN(*sequence, CreateMemoryMinimizingSequence( + *module, + [this](const LogicalBuffer& buffer) { + return size_function_(buffer.shape()); + }, + scheduler_algorithm_)); // Compute peak memory usage of all computations in the module called in a // sequential context. call_graph_ = CallGraph::Build(module); @@ -1313,9 +1319,10 @@ StatusOr HloRematerialization::Run( /* static */ StatusOr HloRematerialization::RematerializeAndSchedule( const HloRematerialization::ShapeSizeFunction& size_function, int64 memory_limit_bytes, HloModule* hlo_module, + SchedulerAlgorithm scheduler_algorithm, SequentialHloOrdering::HloModuleSequence* sequence, RematerializationSizes* sizes) { - HloRematerialization remat(size_function); + HloRematerialization remat(scheduler_algorithm, size_function); return remat.Run(hlo_module, sequence, memory_limit_bytes, sizes); } diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h index 11f79a6d4158c6251c2faf63e9cac4e742440863..52553439033a3bcfa4b472f13f9cd4b1ecf5ed96 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.h +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h @@ -20,6 +20,7 @@ #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_scheduling.h" #include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" namespace xla { @@ -65,12 +66,15 @@ class HloRematerialization { // code generation. 
static StatusOr RematerializeAndSchedule( const ShapeSizeFunction& size_function, int64 memory_limit_bytes, - HloModule* hlo_module, SequentialHloOrdering::HloModuleSequence* sequence, + HloModule* hlo_module, SchedulerAlgorithm scheduler_algorithm, + SequentialHloOrdering::HloModuleSequence* sequence, RematerializationSizes* sizes = nullptr); protected: - HloRematerialization(const ShapeSizeFunction& size_function) - : size_function_(size_function) {} + HloRematerialization(SchedulerAlgorithm scheduler_algorithm, + const ShapeSizeFunction& size_function) + : scheduler_algorithm_(scheduler_algorithm), + size_function_(size_function) {} ~HloRematerialization() {} // Runs rematerialization on the given module. Returns whether the module was @@ -103,6 +107,9 @@ class HloRematerialization { StatusOr CalledComputationsMemoryUsage( const HloInstruction* instruction) const; + // Selects an algorithm to use for HLO scheduling. + SchedulerAlgorithm scheduler_algorithm_; + // Function which computes the size of the top-level buffer of a shape. const ShapeSizeFunction size_function_; diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index d88aa4bb567c6c5f6eab54f12239bf7040339c39..1b7d26dde501a6a0955d62ea0938e0683a32d49d 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -158,11 +158,11 @@ TEST_F(HloRematerializationTest, SingleComputation) { SequentialHloOrdering::HloModuleSequence sequence; // Computation requires 16KB without rematerialization, but uses only 12KB // with rematerialization so pick a memory limit between these values (14KB). 
- TF_ASSERT_OK_AND_ASSIGN( - bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/14 * 1024, module.get(), &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + HloRematerialization::RematerializeAndSchedule( + ByteSizeOf, + /*memory_limit_bytes=*/14 * 1024, module.get(), + SchedulerAlgorithm::kAuto, &sequence)); EXPECT_TRUE(changed); // Root should not have changed. @@ -191,11 +191,11 @@ TEST_F(HloRematerializationTest, SingleComputationNoRematerialization) { EXPECT_EQ(computation->instruction_count(), 7); SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN( - bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/20 * 1024, module.get(), &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + HloRematerialization::RematerializeAndSchedule( + ByteSizeOf, + /*memory_limit_bytes=*/20 * 1024, module.get(), + SchedulerAlgorithm::kAuto, &sequence)); // No instructions should have been materialized. EXPECT_FALSE(changed); @@ -232,11 +232,11 @@ TEST_F(HloRematerializationTest, RematerializeAroundWhile) { // while so the peak memory use of the module is 18KB. Set the memory limit a // bit lower (17KB) to force rematerialization of the entry computation. SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN( - bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/17 * 1024, module.get(), &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + HloRematerialization::RematerializeAndSchedule( + ByteSizeOf, + /*memory_limit_bytes=*/17 * 1024, module.get(), + SchedulerAlgorithm::kAuto, &sequence)); EXPECT_TRUE(changed); // Only the entry computation should have a rematerialized instruction added. 
@@ -268,11 +268,11 @@ TEST_F(HloRematerializationTest, RematerializeEntryAndWhileBody) { EXPECT_EQ(body_computation->instruction_count(), 7); SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN( - bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/15 * 1024, module.get(), &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + HloRematerialization::RematerializeAndSchedule( + ByteSizeOf, + /*memory_limit_bytes=*/15 * 1024, module.get(), + SchedulerAlgorithm::kAuto, &sequence)); EXPECT_TRUE(changed); // Both computations should have a rematerialized instruction added. @@ -310,11 +310,11 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) { // If all computations are maximally rematerialized then peak memory usage is // ~12K so pick something slightly larger. SequentialHloOrdering::HloModuleSequence sequence; - TF_ASSERT_OK_AND_ASSIGN( - bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/13 * 1024, module.get(), &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + HloRematerialization::RematerializeAndSchedule( + ByteSizeOf, + /*memory_limit_bytes=*/13 * 1024, module.get(), + SchedulerAlgorithm::kAuto, &sequence)); EXPECT_TRUE(changed); // All computations should have a rematerialized instruction added. 
@@ -323,6 +323,76 @@ TEST_F(HloRematerializationTest, RematerializeNestedComputations) { EXPECT_EQ(inner_computation->instruction_count(), 8); } +TEST_F(HloRematerializationTest, RngNotRematerialized) { + // Test that a single rng is not rematerialized: + // + // Entry computation: + // F32[] %param = {...} + // F32[1024] rng = rng(param) + // F32[1024] tanh = tanh(rng) + // F32[1024] exp = exp(rng) + // F32[1024] add_0 = add(rng, tanh) // LIVE: add_0 + rng + + // // tanh + exp + // + // F32[1024] add_1 = add(rng, add(exp, add_0)) // LIVE: add_1 + add_0 + + // // rng + tanh + exp + // + // F32[1024] add_2 = add(rng, add(tanh, add_1)) // LIVE: add_2 + add_1 + + // // rng + tanh + exp + auto module = CreateNewModule(); + + auto builder = HloComputation::Builder(TestName()); + auto param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param")); + auto rng = builder.AddInstruction(HloInstruction::CreateRng( + vec1024_shape_, RandomDistribution::RNG_UNIFORM, {param, param})); + auto tanh = builder.AddInstruction( + HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kTanh, rng)); + auto exp = builder.AddInstruction( + HloInstruction::CreateUnary(vec1024_shape_, HloOpcode::kExp, rng)); + auto add_0 = builder.AddInstruction( + HloInstruction::CreateBinary(vec1024_shape_, HloOpcode::kAdd, rng, tanh)); + auto add_1 = builder.AddInstruction(HloInstruction::CreateBinary( + vec1024_shape_, HloOpcode::kAdd, rng, + builder.AddInstruction(HloInstruction::CreateBinary( + vec1024_shape_, HloOpcode::kAdd, exp, add_0)))); + builder.AddInstruction(HloInstruction::CreateBinary( + vec1024_shape_, HloOpcode::kAdd, rng, + builder.AddInstruction(HloInstruction::CreateBinary( + vec1024_shape_, HloOpcode::kAdd, tanh, add_1)))); + HloComputation* entry_computation = + module->AddEntryComputation(builder.Build()); + + auto count_rngs = [](const HloComputation* computation) { + int64 rng_count = 0; + for (auto* instruction : computation->instructions()) { + if 
(instruction->opcode() == HloOpcode::kRng) { + ++rng_count; + } + } + return rng_count; + }; + // Before rematerialization there should be a single broadcast rng in + // the graph. + ASSERT_EQ(count_rngs(entry_computation), 1); + const int64 original_instruction_count = + entry_computation->instruction_count(); + SequentialHloOrdering::HloModuleSequence sequence; + // Pick a memory limit some where between 24KB (initial peak memory including + // parameter and output) and 20KB (peak memory possible with + // rematerialization). + TF_ASSERT_OK_AND_ASSIGN( + bool changed, HloRematerialization::RematerializeAndSchedule( + ByteSizeOf, + /*memory_limit_bytes=*/4 * ByteSizeOf(vec1024_shape_), + module.get(), SchedulerAlgorithm::kAuto, &sequence)); + EXPECT_TRUE(changed); + // The rng should not have been rematerialized. + EXPECT_EQ(count_rngs(entry_computation), 1); + // There should have been rematerialization. + EXPECT_GT(entry_computation->instruction_count(), original_instruction_count); +} + TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { // Test that a single instruction is rematerialized several times. Module: // @@ -406,11 +476,11 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). - TF_ASSERT_OK_AND_ASSIGN( - bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/22 * 1024, module.get(), &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + HloRematerialization::RematerializeAndSchedule( + ByteSizeOf, + /*memory_limit_bytes=*/22 * 1024, module.get(), + SchedulerAlgorithm::kAuto, &sequence)); EXPECT_TRUE(changed); // The broadcast should have been rematerialized 3 times. 
@@ -503,11 +573,11 @@ TEST_P(IndirectUseTest, IndirectUseNotRematerialized) { // Pick a memory limit some where between 24KB (initial peak memory including // parameter and output) and 20KB (peak memory possible with // rematerialization). - TF_ASSERT_OK_AND_ASSIGN( - bool changed, - HloRematerialization::RematerializeAndSchedule( - ByteSizeOf, - /*memory_limit_bytes=*/22 * 1024, module.get(), &sequence)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, + HloRematerialization::RematerializeAndSchedule( + ByteSizeOf, + /*memory_limit_bytes=*/22 * 1024, module.get(), + SchedulerAlgorithm::kAuto, &sequence)); // Rematerialization should only occur if the rematerializable instruction has // no indirect uses. if (indirectly_used) { diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc index 6b6d48233a7da50927207b8334186ee5105db268..41b079eb799d06321a31f7d7ae0630dc8d58c46b 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.cc +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -40,21 +40,18 @@ namespace se = ::perftools::gputools; namespace xla { /*static*/ StatusOr> -HloRunner::ReadModuleFromHloProtoFile(const std::string& filename, - const DebugOptions& debug_options) { - HloProto proto; - - const Status s = - tensorflow::ReadBinaryProto(tensorflow::Env::Default(), filename, &proto); +HloRunner::CreateModuleFromString(const tensorflow::StringPiece hlo_string, + const DebugOptions& debug_options) { + HloModuleConfig config; + config.set_debug_options(debug_options); + return tools::Parse(hlo_string, config); +} - if (!s.ok()) { - const Status s2 = - tensorflow::ReadTextProto(tensorflow::Env::Default(), filename, &proto); - if (!s2.ok()) { - return Status(s2.code(), s.error_message() + "\n" + s2.error_message()); - } - } +namespace { +// Creates an HloModule from the given proto. 
+StatusOr> HloProtoToModule( + const HloProto& proto, const DebugOptions& debug_options) { TF_ASSIGN_OR_RETURN( HloModuleConfig config, HloModule::CreateModuleConfigFromProto(proto.hlo_module())); @@ -64,9 +61,29 @@ HloRunner::ReadModuleFromHloProtoFile(const std::string& filename, return std::move(module); } +} // namespace + /*static*/ StatusOr> -HloRunner::ReadModuleFromHloTextDumpFile(const std::string& filename, +HloRunner::ReadModuleFromBinaryProtoFile(const std::string& filename, const DebugOptions& debug_options) { + HloProto proto; + TF_RETURN_IF_ERROR(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), + filename, &proto)); + return HloProtoToModule(proto, debug_options); +} + +/*static*/ StatusOr> +HloRunner::ReadModuleFromTextProtoFile(const std::string& filename, + const DebugOptions& debug_options) { + HloProto proto; + TF_RETURN_IF_ERROR( + tensorflow::ReadTextProto(tensorflow::Env::Default(), filename, &proto)); + return HloProtoToModule(proto, debug_options); +} + +/*static*/ StatusOr> +HloRunner::ReadModuleFromHloTextFile(const std::string& filename, + const DebugOptions& debug_options) { string hlo_string; TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(tensorflow::Env::Default(), filename, &hlo_string)); @@ -75,19 +92,6 @@ HloRunner::ReadModuleFromHloTextDumpFile(const std::string& filename, return tools::Parse(hlo_string, config); } -/*static*/ StatusOr> HloRunner::ReadModule( - const std::string& filename, const DebugOptions& debug_options) { - auto module = HloRunner::ReadModuleFromHloProtoFile(filename, debug_options); - if (module.ok()) { - return module; - } - const std::string e = module.status().error_message(); - module = HloRunner::ReadModuleFromHloTextDumpFile(filename, debug_options); - return module.ok() ? std::move(module) - : Status(module.status().code(), - e + "\n" + module.status().error_message()); -} - // Define this in .cc file to avoid having to include eigen or forward declare // these types in the header. 
struct HloRunner::EigenThreadPoolWrapper { @@ -104,31 +108,29 @@ HloRunner::HloRunner(se::Platform* platform) { VLOG(1) << "Created HloRunner for platform: " << platform->Name(); } -HloRunner::~HloRunner() { - // Deallocate all the memory allocated during the tests. - for (auto& allocation : allocations_) { - backend().default_stream_executor()->Deallocate(&allocation); - } -} +HloRunner::~HloRunner() {} -StatusOr HloRunner::Execute( +StatusOr> HloRunner::ExecuteInternal( std::unique_ptr module, - tensorflow::gtl::ArraySlice arguments, - Shape* result_shape, bool run_hlo_passes) { + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes) { if (run_hlo_passes) { TF_ASSIGN_OR_RETURN( module, backend().compiler()->RunHloPasses( - std::move(module), backend().default_stream_executor())); + std::move(module), backend().default_stream_executor(), + /*device_allocator=*/nullptr)); } TF_ASSIGN_OR_RETURN( std::unique_ptr executable, backend().compiler()->RunBackend(std::move(module), - backend().default_stream_executor())); + backend().default_stream_executor(), + /*device_allocator=*/nullptr)); se::Stream stream(backend().default_stream_executor()); stream.Init(); ExecutableRunOptions run_options; + run_options.set_device_ordinal(backend().default_device_ordinal()); run_options.set_stream(&stream); run_options.set_allocator(backend().memory_allocator()); run_options.set_inter_op_thread_pool(backend().inter_op_thread_pool()); @@ -138,73 +140,43 @@ StatusOr HloRunner::Execute( ServiceExecutableRunOptions service_run_options( run_options, backend().StreamBorrower(), backend().inter_op_thread_pool()); - TF_ASSIGN_OR_RETURN( - se::DeviceMemoryBase result, - executable->ExecuteOnStream(&service_run_options, arguments, - /*hlo_execution_profile=*/nullptr)); - TF_RET_CHECK(stream.BlockHostUntilDone()); - allocations_.push_back(result); - - *result_shape = executable->result_shape(); - - if (ShapeUtil::IsTuple(*result_shape)) { - // We must record element buffers of 
tuples as well to avoid leaks. - DCHECK(!ShapeUtil::IsNestedTuple(*result_shape)); + // Copy arguments to device. + std::vector> argument_buffers; + std::vector argument_buffer_ptrs; + for (Literal* argument : arguments) { TF_ASSIGN_OR_RETURN( - std::vector element_buffers, - backend().transfer_manager()->ShallowCopyTupleFromDevice( - backend().default_stream_executor(), result, *result_shape)); - - // A tuple may contain the same buffer in more than one element. Keep track - // of the buffers already added to avoid duplicates in allocations_. - std::set added_opaques; - for (auto element_buffer : element_buffers) { - if (added_opaques.count(element_buffer.opaque()) == 0) { - CHECK(element_buffer.opaque() != nullptr); - added_opaques.insert(element_buffer.opaque()); - allocations_.push_back(element_buffer); - } - } + std::unique_ptr argument_buffer, + backend().transfer_manager()->AllocateScopedShapedBuffer( + argument->shape(), run_options.allocator(), + run_options.device_ordinal())); + TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( + stream.parent(), *argument, *argument_buffer)); + argument_buffers.push_back(std::move(argument_buffer)); + argument_buffer_ptrs.push_back(argument_buffers.back().get()); } - return result; -} - -StatusOr HloRunner::TransferToDevice( - const Literal& literal) { - // Allocate memory on the device using the stream executor. 
- int64 allocation_size = - backend().transfer_manager()->GetByteSizeRequirement(literal.shape()); - se::DeviceMemoryBase allocation = - backend().default_stream_executor()->AllocateArray( - allocation_size); - allocations_.push_back(allocation); - - TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( - backend().default_stream_executor(), literal, &allocation)); - - return allocation; -} - -StatusOr> HloRunner::TransferFromDevice( - const Shape& shape, se::DeviceMemoryBase device_base) { - auto literal = MakeUnique(); - TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromDevice( - backend().default_stream_executor(), device_base, shape, shape, - literal.get())); - return std::move(literal); -} + TF_ASSIGN_OR_RETURN( + std::unique_ptr result, + executable->ExecuteOnStream(&service_run_options, argument_buffer_ptrs, + /*hlo_execution_profile=*/nullptr)); -StatusOr> HloRunner::ExecuteAndTransfer( - std::unique_ptr module, - tensorflow::gtl::ArraySlice arguments, - bool run_hlo_passes) { - Shape result_shape; + // Create a ScopedShapedBuffer of the result to manage deallocation. This will + // deallocate all the device memory when it goes out of scope. 
TF_ASSIGN_OR_RETURN( - se::DeviceMemoryBase device_base, - Execute(std::move(module), arguments, &result_shape, run_hlo_passes)); - return TransferFromDevice(result_shape, device_base); + std::unique_ptr scoped_result, + ScopedShapedBuffer::MakeScoped(result.get(), run_options.allocator())); + + auto result_literal = backend().transfer_manager()->TransferLiteralFromDevice( + stream.parent(), *scoped_result); + if (result_literal.ok()) { + VLOG(4) << "Executed binary and got result: " + << result_literal.ValueOrDie()->ToString(); + } else { + VLOG(4) << "Executed binary and got status: " + << result_literal.status().ToString(); + } + return result_literal; } Backend& HloRunner::backend() { diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h index 95cddafc91ff40948efc4b0744343d994cf84f3a..cbaebc68bee708090b8ccb2eae19b556c4d6d453 100644 --- a/tensorflow/compiler/xla/service/hlo_runner.h +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -35,7 +35,8 @@ namespace xla { // A base class for running an HloModule. This executes the given HloModule on a // certain backend directly without using the client interface. HloModule can be -// explicitly built, or loaded from a serialization file (e.g., hlo proto file). +// explicitly built, or loaded from a serialization file (e.g., hlo proto +// file), or parsed from a hlo textual IR string. class HloRunner { public: HloRunner(); @@ -44,56 +45,34 @@ class HloRunner { ~HloRunner(); + // Converts an HloModule from the given hlo textual IR string (in + // HloModule::ToString format). + static StatusOr> CreateModuleFromString( + const tensorflow::StringPiece hlo_string, + const DebugOptions& debug_options); + // Reads the proto file in xla.HloProto format, creates and returns the - // HloModule. Will try to parse the filename as binary proto, then try as - // text proto if that fails. - static StatusOr> ReadModuleFromHloProtoFile( + // HloModule. 
+ static StatusOr> ReadModuleFromBinaryProtoFile( + const std::string& filename, const DebugOptions& debug_options); + static StatusOr> ReadModuleFromTextProtoFile( const std::string& filename, const DebugOptions& debug_options); // Reads the hlo text dump file in HloModule::ToString format, creates and // returns the HloModule. - static StatusOr> ReadModuleFromHloTextDumpFile( - const std::string& filename, const DebugOptions& debug_options); - - // Tries to parse the filename specified first as binary proto format, then - // as a textual proto format, then textual IR, then gives up if both fail. - // ReadModuleFromHloProtoFile or ReadModuleFromHloTextDumpFile should be used - // explicitly when you know the format, this if you don't. - static StatusOr> ReadModule( + static StatusOr> ReadModuleFromHloTextFile( const std::string& filename, const DebugOptions& debug_options); // Executes the given module with given literals as input and returns the // result as a Literal. The LiteralPtr type accepts Literal* or // std::unique_ptr. - // If run_hlo_passes is true, the module will be executed without Hlo + // + // If run_hlo_passes is false, the module will be executed without Hlo // optimization. template StatusOr> Execute( std::unique_ptr module, - const tensorflow::gtl::ArraySlice literals, - bool run_hlo_passes = true); - - // Executes the given module and returns a global data handle. - StatusOr Execute( - std::unique_ptr module, - tensorflow::gtl::ArraySlice - arguments, - Shape* result_shape, bool run_hlo_passes = true); - - // Transfers the given literal to the device and returns the data handle. - StatusOr TransferToDevice( - const Literal& literal); - - // Transfers the array referred to by the given handle from the device and - // returns as a Literal. - StatusOr> TransferFromDevice( - const Shape& shape, perftools::gputools::DeviceMemoryBase device_base); - - // Executes the given module and return the result as a Literal. 
- StatusOr> ExecuteAndTransfer( - std::unique_ptr module, - tensorflow::gtl::ArraySlice - arguments, + const tensorflow::gtl::ArraySlice arguments, bool run_hlo_passes = true); // If backend is not created in the constructor, creates and returns the @@ -104,9 +83,12 @@ class HloRunner { Backend& backend(); private: - struct EigenThreadPoolWrapper; + StatusOr> ExecuteInternal( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + bool run_hlo_passes = true); - std::vector allocations_; + struct EigenThreadPoolWrapper; std::unique_ptr thread_pool_wrapper_; @@ -116,15 +98,14 @@ class HloRunner { template StatusOr> HloRunner::Execute( std::unique_ptr module, - const tensorflow::gtl::ArraySlice literals, + const tensorflow::gtl::ArraySlice arguments, bool run_hlo_passes) { - std::vector arguments; - for (const auto& literal : literals) { - TF_ASSIGN_OR_RETURN(perftools::gputools::DeviceMemoryBase argument, - TransferToDevice(*literal)); - arguments.push_back(argument); + // Construct a vector of plain pointers for the arguments. + std::vector argument_pointers; + for (const auto& argument : arguments) { + argument_pointers.push_back(&*argument); } - return ExecuteAndTransfer(std::move(module), arguments, run_hlo_passes); + return ExecuteInternal(std::move(module), argument_pointers, run_hlo_passes); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc index 8ccbcaeee4a9c9e94b344231953e20ac8f4b2053..5f5a930dad002c215a5332286ade97ef19cc67af 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_scheduling.h" +#include #include #include @@ -31,6 +32,8 @@ limitations under the License. 
#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" +using ::tensorflow::strings::HumanReadableNumBytes; + namespace xla { StatusOr MinimumMemoryForSequence( @@ -215,32 +218,26 @@ class ListScheduler { } } - std::list ready_list; + auto priority_comparator = [this](const ReadyListEntry& lhs, + const ReadyListEntry& rhs) { + return GetPriority(lhs) < GetPriority(rhs); + }; + std::priority_queue, + decltype(priority_comparator)> + ready_queue(priority_comparator); for (auto* instruction : computation_.instructions()) { // Instruction with no operands or control predecessors will // not be in the map. if (unscheduled_pred_count.count(instruction) == 0) { - ready_list.push_back(MakeReadyListEntry(instruction)); + ready_queue.emplace(MakeReadyListEntry(instruction)); } } - while (!ready_list.empty()) { - // Select the highest priority HLO instruction from the ready list. - auto best_it = ready_list.begin(); - Priority best_priority = GetPriority(*best_it); - for (auto ready_it = std::next(ready_list.begin()); - ready_it != ready_list.end(); ++ready_it) { - Priority priority = GetPriority(*ready_it); - if (priority > best_priority) { - best_it = ready_it; - best_priority = priority; - } - } - + while (!ready_queue.empty()) { // Remove the selected instruction from the ready list and add it to the // schedule. 
- const HloInstruction* best = best_it->instruction; - ready_list.erase(best_it); + const HloInstruction* best = ready_queue.top().instruction; + ready_queue.pop(); schedule.push_back(best); scheduled_instructions_.insert(best); @@ -255,7 +252,7 @@ class ListScheduler { int64 pred_count = --unscheduled_pred_count.at(inst); CHECK_GE(pred_count, 0); if (pred_count == 0) { - ready_list.push_back(MakeReadyListEntry(inst)); + ready_queue.emplace(MakeReadyListEntry(inst)); } }; // TODO(b/34466113): Replace this and above with successors() or @@ -367,7 +364,17 @@ StatusOr MinimumMemoryForComputation( StatusOr> CreateMemoryMinimizingSequence( const HloComputation& computation, const TuplePointsToAnalysis& points_to_analysis, - const LogicalBuffer::SizeFunction& size_function) { + const LogicalBuffer::SizeFunction& size_function, + SchedulerAlgorithm algorithm) { + VLOG(2) << "Computation: " << computation.name(); + if (algorithm == SchedulerAlgorithm::kListSchedule) { + return ListScheduler::Run(computation, points_to_analysis, size_function); + } + if (algorithm == SchedulerAlgorithm::kDfsSchedule) { + return RunDFSMemoryScheduler(computation, points_to_analysis, + size_function); + } + // We try both a list-scheduler based ordering and a DFS based ordering, and // choose whichever returns a lower min-memory, not accounting for // fragmentation. 
@@ -382,7 +389,7 @@ StatusOr> CreateMemoryMinimizingSequence( const int64 list_memory, MinimumMemoryForComputation(computation, list_sequence, points_to_analysis, size_function)); - VLOG(2) << "Min-memory list sequence: " << list_memory << " bytes"; + VLOG(2) << "Min-memory list sequence: " << HumanReadableNumBytes(list_memory); TF_ASSIGN_OR_RETURN( std::vector dfs_sequence, @@ -391,13 +398,15 @@ StatusOr> CreateMemoryMinimizingSequence( const int64 dfs_memory, MinimumMemoryForComputation(computation, dfs_sequence, points_to_analysis, size_function)); - VLOG(2) << "Min-memory dfs sequence: " << dfs_memory << " bytes"; + VLOG(2) << "Min-memory dfs sequence: " << HumanReadableNumBytes(dfs_memory); if (list_memory <= dfs_memory) { - VLOG(2) << "Chose min-memory list sequence: " << list_memory << " bytes"; + VLOG(2) << "Chose min-memory list sequence: " + << HumanReadableNumBytes(list_memory); return list_sequence; } else { - VLOG(2) << "Chose min-memory dfs sequence: " << dfs_memory << " bytes"; + VLOG(2) << "Chose min-memory dfs sequence: " + << HumanReadableNumBytes(dfs_memory); return dfs_sequence; } } @@ -405,27 +414,30 @@ StatusOr> CreateMemoryMinimizingSequence( } // namespace StatusOr -CreateMemoryMinimizingSequence( - const HloModule& module, const LogicalBuffer::SizeFunction& size_function) { +CreateMemoryMinimizingSequence(const HloModule& module, + const LogicalBuffer::SizeFunction& size_function, + SchedulerAlgorithm algorithm) { SequentialHloOrdering::HloModuleSequence sequence; TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(&module)); for (const auto* computation : module.MakeNonfusionComputations()) { - TF_ASSIGN_OR_RETURN(sequence[computation], - CreateMemoryMinimizingSequence( - *computation, *points_to_analysis, size_function)); + TF_ASSIGN_OR_RETURN( + sequence[computation], + CreateMemoryMinimizingSequence(*computation, *points_to_analysis, + size_function, algorithm)); } return sequence; } StatusOr> 
CreateMemoryMinimizingSequence( const HloComputation& computation, - const LogicalBuffer::SizeFunction& size_function) { + const LogicalBuffer::SizeFunction& size_function, + SchedulerAlgorithm algorithm) { CHECK(!computation.IsFusionComputation()); TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(computation.parent())); return CreateMemoryMinimizingSequence(computation, *points_to_analysis, - size_function); + size_function, algorithm); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.h b/tensorflow/compiler/xla/service/hlo_scheduling.h index ec92a56b962152b15981f868369683144aa7c76a..1d1eb1e064f75c2220b39e84b010e720a0c37880 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.h +++ b/tensorflow/compiler/xla/service/hlo_scheduling.h @@ -33,17 +33,28 @@ StatusOr MinimumMemoryForSequence( const SequentialHloOrdering::HloModuleSequence& module_sequence, const LogicalBuffer::SizeFunction& size_function); +enum class SchedulerAlgorithm { + kListSchedule, + kDfsSchedule, + + // Selects the available scheduler algorithm that had the minimum memory in + // the resulting sequence (a la MinimumMemoryForSequence). + kAuto, +}; + // Returns an HloModuleSequence which seeks to minimize the memory required for // the computation. size_function is the function returning the number of bytes // required for a LogicalBuffer. StatusOr CreateMemoryMinimizingSequence( - const HloModule& module, const LogicalBuffer::SizeFunction& size_function); + const HloModule& module, const LogicalBuffer::SizeFunction& size_function, + SchedulerAlgorithm algorithm = SchedulerAlgorithm::kAuto); // Overload of above that computes the sequence for a single computation. 
StatusOr> CreateMemoryMinimizingSequence( const HloComputation& computation, - const LogicalBuffer::SizeFunction& size_function); + const LogicalBuffer::SizeFunction& size_function, + SchedulerAlgorithm algorithm = SchedulerAlgorithm::kAuto); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc index d1adec31c21fe55001db4d522ddda27dd538bc95..447c2446668253c932b44b51b2db22bfd47f9957 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.cc +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -246,7 +246,8 @@ Status HloSharding::ValidateNonTuple(const Shape& shape, // The tile rank must be the same as the input rank. if (ShapeUtil::Rank(shape) != ShapeUtil::Rank(tile_shape_)) { return tensorflow::errors::InvalidArgument( - "Tile rank is different to the input rank"); + "Tile rank is different to the input rank. sharding=", ToString(), + ", input_shape=", ShapeUtil::HumanString(shape)); } // The tile shape must not be the same as the input shape without maximal_ diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h index 1a6988a2dc872a39ff6b0551adf7ddb871f0d72a..7263198385cf0c84b1dac1e15177dcac99adaafb 100644 --- a/tensorflow/compiler/xla/service/hlo_sharding.h +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -80,6 +80,17 @@ class HloSharding { return HloSharding(flattened_list); } + // Creates a new sharding for a tuple type. The requested tuple shape must not + // be nested. For nested tuples, use the ShapeTree overload. 
+ static HloSharding Tuple(const Shape& tuple_shape, + tensorflow::gtl::ArraySlice shardings) { + CHECK(ShapeUtil::IsTuple(tuple_shape)); + CHECK(!ShapeUtil::IsNestedTuple(tuple_shape)); + std::vector flattened_list(shardings.begin(), shardings.end()); + CHECK_EQ(flattened_list.size(), ShapeUtil::TupleElementCount(tuple_shape)); + return HloSharding(flattened_list); + } + // Create a new sharding from a protobuf OpSharding. static StatusOr FromProto(const OpSharding& proto); diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc index 101a710d1cad9401134fdfe1d0ec9df241bc01e1..3dc733940fc89952bd5e75a9b28d9cbf356f8000 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc @@ -166,7 +166,7 @@ void HloTfGraphBuilder::SetNodeAttrs(const HloInstruction* instruction, layout_string = ShapeUtil::HumanStringWithLayout(instruction->shape()); } else { layout_string = StrCat( - "{", Join(instruction->shape().layout().minor_to_major(), ","), "}"); + "{", Join(LayoutUtil::MinorToMajor(instruction->shape()), ","), "}"); } attrs["layout"].set_s(layout_string); } diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h index 9aa3e501d5f85e3b61b20555e3d13c5687f33f2f..c4876b852e32d34693202f4023aa20ad2b301ffd 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_ #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/xla.pb.h" @@ -56,4 +56,4 @@ class HloTfGraphBuilder { } // namespace hlo_graph_dumper } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_TFGRAPH_BUILDER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 15188c4057eca8eea1805e599cd020c045fdd10a..e2b3bb9d71497c352b0b92add2d2f6b4b777bee8 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -13,413 +13,529 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "tensorflow/compiler/xla/service/hlo_verifier.h" -#include "tensorflow/compiler/xla/service/shape_inference.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { -namespace { +Status ShapeVerifier::HandleElementwiseUnary(HloInstruction* hlo) { + return CheckUnaryShape(hlo); +} -// Visitor which verifies that the output shape is correctly set. Verifies -// against the inferred shape for the instruction. -// TODO(b/26024837): Check output shape for all instruction types. 
-class ShapeVerifier : public DfsHloVisitor { - public: - explicit ShapeVerifier( - const std::function& shape_size_fn) - : shape_size_fn_(shape_size_fn) {} - - Status HandleElementwiseUnary(HloInstruction* hlo) override { - return CheckUnaryShape(hlo); - } +Status ShapeVerifier::HandleElementwiseBinary(HloInstruction* hlo) { + return CheckBinaryShape(hlo); +} - Status HandleElementwiseBinary(HloInstruction* hlo) override { - return CheckBinaryShape(hlo); - } +Status ShapeVerifier::HandleClamp(HloInstruction* clamp) { + return CheckTernaryShape(clamp); +} - Status HandleClamp(HloInstruction* clamp) override { - return CheckTernaryShape(clamp); - } +Status ShapeVerifier::HandleSelect(HloInstruction* select) { + return CheckTernaryShape(select); +} - Status HandleSelect(HloInstruction* select) override { - return CheckTernaryShape(select); +Status ShapeVerifier::HandleConcatenate(HloInstruction* concatenate) { + std::vector operand_shapes; + for (const HloInstruction* operand : concatenate->operands()) { + operand_shapes.push_back(&operand->shape()); } + return CheckShape(concatenate, + ShapeInference::InferConcatOpShape( + operand_shapes, concatenate->concatenate_dimension())); +} - Status HandleConcatenate(HloInstruction* concatenate) override { - std::vector operand_shapes; - for (const HloInstruction* operand : concatenate->operands()) { - operand_shapes.push_back(&operand->shape()); - } - return CheckShape( - concatenate, ShapeInference::InferConcatOpShape( - operand_shapes, concatenate->concatenate_dimension())); - } +Status ShapeVerifier::HandleConvert(HloInstruction* convert) { + return CheckShape(convert, ShapeInference::InferConvertShape( + convert->operand(0)->shape(), + convert->shape().element_type())); +} - Status HandleConvert(HloInstruction* convert) override { - return CheckShape(convert, ShapeInference::InferConvertShape( - convert->operand(0)->shape(), - convert->shape().element_type())); - } +Status 
ShapeVerifier::HandleBitcastConvert(HloInstruction* convert) { + return CheckShape(convert, ShapeInference::InferBitcastConvertShape( + convert->operand(0)->shape(), + convert->shape().element_type())); +} - Status HandleBitcastConvert(HloInstruction* convert) override { - return CheckShape(convert, ShapeInference::InferBitcastConvertShape( - convert->operand(0)->shape(), - convert->shape().element_type())); - } +Status ShapeVerifier::HandleCopy(HloInstruction* copy) { + return CheckUnaryShape(copy); +} - Status HandleCopy(HloInstruction* copy) override { - return CheckUnaryShape(copy); - } +Status ShapeVerifier::HandleDot(HloInstruction* dot) { + TF_ASSIGN_OR_RETURN(const Shape expected, + ShapeInference::InferDotOpShape( + dot->operand(0)->shape(), dot->operand(1)->shape(), + dot->dot_dimension_numbers())); + return CheckShape(dot, expected); +} - Status HandleDot(HloInstruction* dot) override { - return CheckBinaryShape(dot); - } +Status ShapeVerifier::HandleConvolution(HloInstruction* convolution) { + TF_ASSIGN_OR_RETURN( + const Shape expected, + ShapeInference::InferConvolveShape( + convolution->operand(0)->shape(), convolution->operand(1)->shape(), + convolution->window(), convolution->convolution_dimension_numbers())); + return CheckShape(convolution, expected); +} - Status HandleConvolution(HloInstruction* convolution) override { - TF_ASSIGN_OR_RETURN( - const Shape expected, - ShapeInference::InferConvolveShape( - convolution->operand(0)->shape(), convolution->operand(1)->shape(), - convolution->window(), - convolution->convolution_dimension_numbers())); - return CheckShape(convolution, expected); - } +Status ShapeVerifier::HandleFft(HloInstruction* fft) { + TF_ASSIGN_OR_RETURN( + const Shape expected, + ShapeInference::InferFftShape(fft->operand(0)->shape(), fft->fft_type(), + fft->fft_length())); + return CheckShape(fft, expected); +} - Status HandleCrossReplicaSum(HloInstruction* crs) override { - return CheckShape(crs, 
ShapeInference::InferCrossReplicaSumShape( - crs->operand(0)->shape())); +Status ShapeVerifier::HandleCrossReplicaSum(HloInstruction* crs) { + std::vector operand_shapes; + for (const HloInstruction* operand : crs->operands()) { + operand_shapes.push_back(&operand->shape()); } + return CheckShape(crs, + ShapeInference::InferCrossReplicaSumShape(operand_shapes)); +} - Status HandleReducePrecision(HloInstruction* reduce_precision) override { - return CheckShape(reduce_precision, - ShapeInference::InferReducePrecisionShape( - reduce_precision->operand(0)->shape(), - reduce_precision->exponent_bits(), - reduce_precision->mantissa_bits())); - } +Status ShapeVerifier::HandleReducePrecision(HloInstruction* reduce_precision) { + return CheckShape(reduce_precision, ShapeInference::InferReducePrecisionShape( + reduce_precision->operand(0)->shape(), + reduce_precision->exponent_bits(), + reduce_precision->mantissa_bits())); +} - Status HandleInfeed(HloInstruction*) override { - return tensorflow::Status::OK(); - } +Status ShapeVerifier::HandleInfeed(HloInstruction*) { + return tensorflow::Status::OK(); +} - Status HandleOutfeed(HloInstruction*) override { - return tensorflow::Status::OK(); - } +Status ShapeVerifier::HandleOutfeed(HloInstruction* outfeed) { + // Outfeed has a separate shape field for the value which is outfed to the + // host. The shape of the instruction itself is always nil because the outfeed + // produces no HLO value in the graph. 
+ if (!ShapeUtil::Compatible(outfeed->outfeed_shape(), + outfeed->operand(0)->shape())) { + return InvalidArgument( + "Expected outfeed to have shape compatible with operand's shape %s, " + "actual shape is %s:\n%s", + ShapeUtil::HumanString(outfeed->operand(0)->shape()).c_str(), + ShapeUtil::HumanString(outfeed->outfeed_shape()).c_str(), + outfeed->ToString().c_str()); + } + return CheckShape(outfeed, ShapeUtil::MakeNil()); +} - Status HandleRng(HloInstruction*) override { - return tensorflow::Status::OK(); - } +Status ShapeVerifier::HandleRng(HloInstruction*) { + return tensorflow::Status::OK(); +} - Status HandleReverse(HloInstruction* reverse) override { - return CheckShape( - reverse, ShapeInference::InferReverseShape(reverse->operand(0)->shape(), - reverse->dimensions())); - } +Status ShapeVerifier::HandleReverse(HloInstruction* reverse) { + return CheckShape( + reverse, ShapeInference::InferReverseShape(reverse->operand(0)->shape(), + reverse->dimensions())); +} - Status HandleSort(HloInstruction* sort) override { - return CheckUnaryShape(sort); - } +Status ShapeVerifier::HandleSort(HloInstruction* sort) { + return CheckUnaryShape(sort); +} - Status HandleConstant(HloInstruction* constant) override { - return CheckShape(constant, constant->literal().shape()); - } +Status ShapeVerifier::HandleConstant(HloInstruction* constant) { + return CheckShape(constant, constant->literal().shape()); +} - Status HandleGetTupleElement(HloInstruction* get_tuple_element) override { - return CheckShape(get_tuple_element, - ShapeInference::InferGetTupleElementShape( - get_tuple_element->operand(0)->shape(), - get_tuple_element->tuple_index())); - } +Status ShapeVerifier::HandleGetTupleElement(HloInstruction* get_tuple_element) { + return CheckShape(get_tuple_element, + ShapeInference::InferGetTupleElementShape( + get_tuple_element->operand(0)->shape(), + get_tuple_element->tuple_index())); +} - Status HandleReduce(HloInstruction* reduce) override { - return CheckShape( - 
reduce, - ShapeInference::InferReduceShape( - reduce->operand(0)->shape(), reduce->operand(1)->shape(), - reduce->dimensions(), reduce->to_apply()->ComputeProgramShape())); - } +Status ShapeVerifier::HandleReduce(HloInstruction* reduce) { + return CheckShape( + reduce, + ShapeInference::InferReduceShape( + reduce->operand(0)->shape(), reduce->operand(1)->shape(), + reduce->dimensions(), reduce->to_apply()->ComputeProgramShape())); +} - Status HandleBitcast(HloInstruction* bitcast) override { - // Bitcasts can be any shape, as long as the size matches the operand size. - TF_RET_CHECK(shape_size_fn_(bitcast->shape()) == - shape_size_fn_(bitcast->operand(0)->shape())); - return tensorflow::Status::OK(); - } +Status ShapeVerifier::HandleBitcast(HloInstruction* bitcast) { + return tensorflow::Status::OK(); +} - Status HandleBroadcast(HloInstruction* broadcast) override { - // HLO broadcast has no exact analog at the proto level so there is no - // ShapeInference method. Check the output shape explicitly. - const Shape& operand_shape = broadcast->operand(0)->shape(); - TF_RET_CHECK(ShapeUtil::Rank(operand_shape) == - broadcast->dimensions().size()); - for (int64 operand_dimension = 0; - operand_dimension < ShapeUtil::Rank(operand_shape); - ++operand_dimension) { - int64 output_dimension = broadcast->dimensions()[operand_dimension]; - TF_RET_CHECK(broadcast->shape().dimensions(output_dimension) == - operand_shape.dimensions(operand_dimension)); - } - return tensorflow::Status::OK(); +Status ShapeVerifier::HandleBroadcast(HloInstruction* broadcast) { + // HLO broadcast has no exact analog at the proto level so there is no + // ShapeInference method. Check the output shape explicitly. + const Shape& operand_shape = broadcast->operand(0)->shape(); + // Check for mixed precision. 
+ TF_RETURN_IF_ERROR(CheckShape(broadcast, broadcast->shape())); + TF_RET_CHECK(ShapeUtil::Rank(operand_shape) == + broadcast->dimensions().size()); + for (int64 operand_dimension = 0; + operand_dimension < ShapeUtil::Rank(operand_shape); + ++operand_dimension) { + int64 output_dimension = broadcast->dimensions()[operand_dimension]; + TF_RET_CHECK(broadcast->shape().dimensions(output_dimension) == + operand_shape.dimensions(operand_dimension)) + << broadcast->ToString() << " operand shape " << operand_shape; } + return tensorflow::Status::OK(); +} - Status HandleReshape(HloInstruction* reshape) override { - TF_RET_CHECK(ShapeUtil::ElementsIn(reshape->shape()) == - ShapeUtil::ElementsIn(reshape->operand(0)->shape())); - return tensorflow::Status::OK(); - } +Status ShapeVerifier::HandleReshape(HloInstruction* reshape) { + // Check for mixed precision. + TF_RETURN_IF_ERROR(CheckShape(reshape, reshape->shape())); + TF_RET_CHECK(ShapeUtil::ElementsIn(reshape->shape()) == + ShapeUtil::ElementsIn(reshape->operand(0)->shape())); + return tensorflow::Status::OK(); +} - Status HandleTranspose(HloInstruction* transpose) override { - return CheckShape(transpose, ShapeInference::InferTransposeShape( - transpose->operand(0)->shape(), - transpose->dimensions())); - } +Status ShapeVerifier::HandleTranspose(HloInstruction* transpose) { + return CheckShape( + transpose, ShapeInference::InferTransposeShape( + transpose->operand(0)->shape(), transpose->dimensions())); +} - Status HandleParameter(HloInstruction*) override { - return tensorflow::Status::OK(); - } +Status ShapeVerifier::HandleParameter(HloInstruction*) { + return tensorflow::Status::OK(); +} - Status HandleFusion(HloInstruction*) override { - return tensorflow::Status::OK(); - } +Status ShapeVerifier::HandleFusion(HloInstruction*) { + return tensorflow::Status::OK(); +} - Status HandleCall(HloInstruction* call) override { - // The shape of kCall should match the shape of the computation it calls. 
- return CheckShape(call, call->to_apply()->ComputeProgramShape().result()); - } +Status ShapeVerifier::HandleCall(HloInstruction* call) { + // The shape of kCall should match the shape of the computation it calls. + return CheckShape(call, call->to_apply()->ComputeProgramShape().result()); +} - Status HandleCustomCall(HloInstruction*) override { - return tensorflow::Status::OK(); - } +Status ShapeVerifier::HandleCustomCall(HloInstruction*) { + return tensorflow::Status::OK(); +} - Status HandleSlice(HloInstruction* slice) override { - return CheckShape(slice, - ShapeInference::InferSliceShape( - slice->operand(0)->shape(), slice->slice_starts(), - slice->slice_limits(), slice->slice_strides())); - } +Status ShapeVerifier::HandleSlice(HloInstruction* slice) { + return CheckShape(slice, + ShapeInference::InferSliceShape( + slice->operand(0)->shape(), slice->slice_starts(), + slice->slice_limits(), slice->slice_strides())); +} - Status HandleDynamicSlice(HloInstruction* dynamic_slice) override { - return CheckShape(dynamic_slice, ShapeInference::InferDynamicSliceShape( - dynamic_slice->operand(0)->shape(), - dynamic_slice->operand(1)->shape(), - dynamic_slice->dynamic_slice_sizes())); - } +Status ShapeVerifier::HandleDynamicSlice(HloInstruction* dynamic_slice) { + return CheckShape(dynamic_slice, ShapeInference::InferDynamicSliceShape( + dynamic_slice->operand(0)->shape(), + dynamic_slice->operand(1)->shape(), + dynamic_slice->dynamic_slice_sizes())); +} - Status HandleDynamicUpdateSlice( - HloInstruction* dynamic_update_slice) override { - return CheckShape(dynamic_update_slice, - ShapeInference::InferDynamicUpdateSliceShape( - dynamic_update_slice->operand(0)->shape(), - dynamic_update_slice->operand(1)->shape(), - dynamic_update_slice->operand(2)->shape())); - } +Status ShapeVerifier::HandleDynamicUpdateSlice( + HloInstruction* dynamic_update_slice) { + return CheckShape(dynamic_update_slice, + ShapeInference::InferDynamicUpdateSliceShape( + 
dynamic_update_slice->operand(0)->shape(), + dynamic_update_slice->operand(1)->shape(), + dynamic_update_slice->operand(2)->shape())); +} - Status HandleTuple(HloInstruction* tuple) override { - return CheckVariadicShape(tuple); - } +Status ShapeVerifier::HandleTuple(HloInstruction* tuple) { + return CheckVariadicShape(tuple); +} - Status HandleMap(HloInstruction* map) override { - std::vector operand_shapes; - int64 max_operand_rank = 0; - for (const HloInstruction* operand : map->operands()) { - operand_shapes.push_back(&operand->shape()); - max_operand_rank = - std::max(max_operand_rank, ShapeUtil::Rank(operand->shape())); - } - // TODO(b/65689298) Remove code below once Map is generalized to accept - // arbitrary map dimensions. - std::vector map_dims(max_operand_rank); - std::iota(map_dims.begin(), map_dims.end(), 0); - return CheckShape( - map, - ShapeInference::InferMapShape( - operand_shapes, map->to_apply()->ComputeProgramShape(), map_dims)); - } +Status ShapeVerifier::HandleMap(HloInstruction* map) { + std::vector operand_shapes; + int64 max_operand_rank = 0; + for (const HloInstruction* operand : map->operands()) { + operand_shapes.push_back(&operand->shape()); + max_operand_rank = + std::max(max_operand_rank, ShapeUtil::Rank(operand->shape())); + } + // TODO(b/65689298) Remove code below once Map is generalized to accept + // arbitrary map dimensions. 
+ std::vector map_dims(max_operand_rank); + std::iota(map_dims.begin(), map_dims.end(), 0); + return CheckShape(map, ShapeInference::InferMapShape( + operand_shapes, + map->to_apply()->ComputeProgramShape(), map_dims)); +} - Status HandleReduceWindow(HloInstruction* reduce_window) override { - return CheckShape( - reduce_window, - ShapeInference::InferReduceWindowShape( - reduce_window->operand(0)->shape(), - reduce_window->operand(1)->shape(), reduce_window->window(), - reduce_window->to_apply()->ComputeProgramShape())); - } +Status ShapeVerifier::HandleReduceWindow(HloInstruction* reduce_window) { + return CheckShape( + reduce_window, + ShapeInference::InferReduceWindowShape( + reduce_window->operand(0)->shape(), + reduce_window->operand(1)->shape(), reduce_window->window(), + reduce_window->to_apply()->ComputeProgramShape())); +} - Status HandleSelectAndScatter(HloInstruction* instruction) override { - return CheckShape( - instruction, - ShapeInference::InferSelectAndScatterShape( - instruction->operand(0)->shape(), - instruction->select()->ComputeProgramShape(), instruction->window(), - instruction->operand(1)->shape(), instruction->operand(2)->shape(), - instruction->scatter()->ComputeProgramShape())); - } +Status ShapeVerifier::HandleSelectAndScatter(HloInstruction* instruction) { + return CheckShape( + instruction, + ShapeInference::InferSelectAndScatterShape( + instruction->operand(0)->shape(), + instruction->select()->ComputeProgramShape(), instruction->window(), + instruction->operand(1)->shape(), instruction->operand(2)->shape(), + instruction->scatter()->ComputeProgramShape())); +} - Status HandleWhile(HloInstruction* xla_while) override { - // The shape of kWhile should match the shape of the body computation it - // calls. 
- return CheckShape(xla_while, - xla_while->while_body()->ComputeProgramShape().result()); - } +Status ShapeVerifier::HandleWhile(HloInstruction* xla_while) { + // The shape of kWhile should match the shape of the body computation it + // calls. + return CheckShape(xla_while, + xla_while->while_body()->ComputeProgramShape().result()); +} - Status HandleConditional(HloInstruction* conditional) override { - TF_RETURN_IF_ERROR(CheckShape( - conditional, - conditional->true_computation()->ComputeProgramShape().result())); - return CheckShape( - conditional, - conditional->false_computation()->ComputeProgramShape().result()); - } +Status ShapeVerifier::HandleConditional(HloInstruction* conditional) { + TF_RETURN_IF_ERROR(CheckShape( + conditional, + conditional->true_computation()->ComputeProgramShape().result())); + return CheckShape( + conditional, + conditional->false_computation()->ComputeProgramShape().result()); +} - Status HandlePad(HloInstruction* pad) override { - return CheckShape(pad, - ShapeInference::InferPadShape(pad->operand(0)->shape(), - pad->operand(1)->shape(), - pad->padding_config())); - } +Status ShapeVerifier::HandlePad(HloInstruction* pad) { + return CheckShape(pad, ShapeInference::InferPadShape(pad->operand(0)->shape(), + pad->operand(1)->shape(), + pad->padding_config())); +} - Status HandleSend(HloInstruction* send) override { - TF_RET_CHECK(send->users().size() == 1); - const HloInstruction* send_done = send->users().front(); - TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone); - TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done)); - return CheckShape( - send, ShapeUtil::MakeTupleShape( - {send->operand(0)->shape(), ShapeUtil::MakeShape(U32, {})})); - } +Status ShapeVerifier::HandleSend(HloInstruction* send) { + TF_RET_CHECK(send->users().size() == 1); + const HloInstruction* send_done = send->users().front(); + TF_RET_CHECK(send_done->opcode() == HloOpcode::kSendDone); + TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done)); + 
return CheckShape( + send, ShapeUtil::MakeTupleShape( + {send->operand(0)->shape(), ShapeUtil::MakeShape(U32, {})})); +} - Status HandleSendDone(HloInstruction* send_done) override { - TF_RET_CHECK(send_done->operands().size() == 1); - const HloInstruction* send = send_done->operand(0); - TF_RET_CHECK(send->opcode() == HloOpcode::kSend); - TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done)); - return CheckShape(send_done, ShapeUtil::MakeNil()); - } +Status ShapeVerifier::HandleSendDone(HloInstruction* send_done) { + TF_RET_CHECK(send_done->operands().size() == 1); + const HloInstruction* send = send_done->operand(0); + TF_RET_CHECK(send->opcode() == HloOpcode::kSend); + TF_RETURN_IF_ERROR(CheckSameChannel(send, send_done)); + return CheckShape(send_done, ShapeUtil::MakeNil()); +} - Status HandleRecv(HloInstruction* recv) override { - TF_RET_CHECK(recv->users().size() == 1); - const HloInstruction* recv_done = recv->users().front(); - TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone); - TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done)); - return CheckShape(recv, - ShapeUtil::MakeTupleShape( - {recv_done->shape(), ShapeUtil::MakeShape(U32, {})})); - } +Status ShapeVerifier::HandleRecv(HloInstruction* recv) { + TF_RET_CHECK(recv->users().size() == 1); + const HloInstruction* recv_done = recv->users().front(); + TF_RET_CHECK(recv_done->opcode() == HloOpcode::kRecvDone); + TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done)); + return CheckShape(recv, + ShapeUtil::MakeTupleShape( + {recv_done->shape(), ShapeUtil::MakeShape(U32, {})})); +} - Status HandleRecvDone(HloInstruction* recv_done) override { - TF_RET_CHECK(recv_done->operands().size() == 1); - const HloInstruction* recv = recv_done->operand(0); - TF_RET_CHECK(recv->opcode() == HloOpcode::kRecv); - TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done)); - return CheckShape(recv_done, recv->shape().tuple_shapes(0)); - } +Status ShapeVerifier::HandleRecvDone(HloInstruction* recv_done) { + 
TF_RET_CHECK(recv_done->operands().size() == 1); + const HloInstruction* recv = recv_done->operand(0); + TF_RET_CHECK(recv->opcode() == HloOpcode::kRecv); + TF_RETURN_IF_ERROR(CheckSameChannel(recv, recv_done)); + return CheckShape(recv_done, recv->shape().tuple_shapes(0)); +} - Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override { - return CheckShape(batch_norm_training, - ShapeInference::InferBatchNormTrainingShape( - batch_norm_training->operand(0)->shape(), - batch_norm_training->operand(1)->shape(), - batch_norm_training->operand(2)->shape(), - batch_norm_training->feature_index())); - } +Status ShapeVerifier::HandleBatchNormTraining( + HloInstruction* batch_norm_training) { + return CheckShape(batch_norm_training, + ShapeInference::InferBatchNormTrainingShape( + batch_norm_training->operand(0)->shape(), + batch_norm_training->operand(1)->shape(), + batch_norm_training->operand(2)->shape(), + batch_norm_training->feature_index())); +} - Status HandleBatchNormInference( - HloInstruction* batch_norm_inference) override { - return CheckShape(batch_norm_inference, - ShapeInference::InferBatchNormInferenceShape( - batch_norm_inference->operand(0)->shape(), - batch_norm_inference->operand(1)->shape(), - batch_norm_inference->operand(2)->shape(), - batch_norm_inference->operand(3)->shape(), - batch_norm_inference->operand(4)->shape(), - batch_norm_inference->feature_index())); - } +Status ShapeVerifier::HandleBatchNormInference( + HloInstruction* batch_norm_inference) { + return CheckShape(batch_norm_inference, + ShapeInference::InferBatchNormInferenceShape( + batch_norm_inference->operand(0)->shape(), + batch_norm_inference->operand(1)->shape(), + batch_norm_inference->operand(2)->shape(), + batch_norm_inference->operand(3)->shape(), + batch_norm_inference->operand(4)->shape(), + batch_norm_inference->feature_index())); +} - Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override { - return CheckShape(batch_norm_grad, 
ShapeInference::InferBatchNormGradShape( - batch_norm_grad->operand(0)->shape(), - batch_norm_grad->operand(1)->shape(), - batch_norm_grad->operand(2)->shape(), - batch_norm_grad->operand(3)->shape(), - batch_norm_grad->operand(4)->shape(), - batch_norm_grad->feature_index())); - } +Status ShapeVerifier::HandleBatchNormGrad(HloInstruction* batch_norm_grad) { + return CheckShape(batch_norm_grad, ShapeInference::InferBatchNormGradShape( + batch_norm_grad->operand(0)->shape(), + batch_norm_grad->operand(1)->shape(), + batch_norm_grad->operand(2)->shape(), + batch_norm_grad->operand(3)->shape(), + batch_norm_grad->operand(4)->shape(), + batch_norm_grad->feature_index())); +} - Status FinishVisit(HloInstruction*) override { - return tensorflow::Status::OK(); - } +namespace { - private: - // Check the instruction's shape against the given expected shape and return - // an appropriate error if there is a mismatch. - Status CheckShape(const HloInstruction* instruction, - const Shape& expected_shape) { - if (!ShapeUtil::Compatible(instruction->shape(), expected_shape)) { - return InvalidArgument( - "Expected instruction to have shape compatible with %s, actual " - "shape is %s:\n%s", - ShapeUtil::HumanString(expected_shape).c_str(), - ShapeUtil::HumanString(instruction->shape()).c_str(), - instruction->ToString().c_str()); +// Checks that the instruction does not have mixed precision floating point +// inputs. +Status CheckMixedPrecisionOperands(const HloInstruction* instruction) { + switch (instruction->opcode()) { + // White list the following opcodes for mixed-precision check, because they + // involve data pass through or grouping via tuples, where the precisions + // of buffers can be different. 
+ case HloOpcode::kCall: + case HloOpcode::kConditional: + case HloOpcode::kConstant: + case HloOpcode::kCrossReplicaSum: + case HloOpcode::kCustomCall: + case HloOpcode::kFusion: + case HloOpcode::kGetTupleElement: + case HloOpcode::kInfeed: + case HloOpcode::kOutfeed: + case HloOpcode::kParameter: + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kReducePrecision: + case HloOpcode::kSelect: + case HloOpcode::kSend: + case HloOpcode::kSendDone: + case HloOpcode::kTuple: + case HloOpcode::kWhile: + break; + default: { + PrimitiveType fp_type = PRIMITIVE_TYPE_INVALID; + for (auto operand : instruction->operands()) { + TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( + operand->shape(), + [&](const Shape& subshape, const ShapeIndex& index) { + if (!ShapeUtil::ElementIsFloating(subshape)) { + return Status::OK(); + } + if (fp_type == PRIMITIVE_TYPE_INVALID) { + fp_type = subshape.element_type(); + } else if (fp_type != subshape.element_type()) { + return FailedPrecondition( + "Seen floating point types of different precisions in " + "%s, but mixed precision is disallowed.", + instruction->ToString().c_str()); + } + return Status::OK(); + })); + } } - return tensorflow::Status::OK(); } + return Status::OK(); +} - // Overload which takes a StatusOr to reduce boilerplate in the caller. - Status CheckShape(const HloInstruction* instruction, - const StatusOr& expected_shape_status) { - if (!expected_shape_status.ok()) { - Status s = expected_shape_status.status(); - tensorflow::errors::AppendToMessage(&s, ", for instruction ", - instruction->ToString()); - return s; - } - return CheckShape(instruction, expected_shape_status.ValueOrDie()); - } +} // namespace - // Check a unary (binary, etc) instruction's shape against the inferred shape. 
- Status CheckUnaryShape(const HloInstruction* instruction) { - return CheckShape(instruction, - ShapeInference::InferUnaryOpShape( - instruction->opcode(), instruction->operand(0))); - } - Status CheckBinaryShape(const HloInstruction* instruction) { - return CheckShape(instruction, - ShapeInference::InferBinaryOpShape( - instruction->opcode(), instruction->operand(0), - instruction->operand(1))); - } - Status CheckTernaryShape(const HloInstruction* instruction) { - return CheckShape(instruction, - ShapeInference::InferTernaryOpShape( - instruction->opcode(), instruction->operand(0), - instruction->operand(1), instruction->operand(2))); +Status ShapeVerifier::CheckShape(const HloInstruction* instruction, + const Shape& inferred_shape) { + // If allow_mixed_precision_ is false, check if there are operands with + // different precisions. We need this check because ShapeInference allows + // mixed precision inputs. + if (!allow_mixed_precision_) { + TF_RETURN_IF_ERROR(CheckMixedPrecisionOperands(instruction)); + } + + // Check if the output shape matches the expected shape. + bool compatible; + // We treat BF16 and F32 as compatible types if mixed precision is allowed, + // but only when the instruction defines the BF16/F32 buffer. + switch (instruction->opcode()) { + case HloOpcode::kSelect: + if (ShapeUtil::IsTuple(inferred_shape) || !allow_mixed_precision_) { + // Select only defines the top-level buffer, which in this case is the + // tuple, so we cannot allow mixed precision. + compatible = + ShapeUtil::Compatible(instruction->shape(), inferred_shape); + } else { + compatible = ShapeUtil::CompatibleIgnoringFpPrecision( + instruction->shape(), inferred_shape); + } + break; + case HloOpcode::kGetTupleElement: + case HloOpcode::kTuple: + // Tuple and GetTupleElement do not define BF16/F32 buffers, so mixed + // precision is disallowed. 
+ case HloOpcode::kConstant: + case HloOpcode::kBitcast: + case HloOpcode::kBitcastConvert: + case HloOpcode::kCall: + case HloOpcode::kConditional: + case HloOpcode::kConvert: + case HloOpcode::kCustomCall: + case HloOpcode::kInfeed: + case HloOpcode::kOutfeed: + case HloOpcode::kParameter: + case HloOpcode::kRecv: + case HloOpcode::kRecvDone: + case HloOpcode::kSend: + case HloOpcode::kSendDone: + case HloOpcode::kWhile: + // The above opcodes should match the expected shapes exactly. + compatible = ShapeUtil::Compatible(instruction->shape(), inferred_shape); + break; + default: + if (allow_mixed_precision_) { + compatible = ShapeUtil::CompatibleIgnoringFpPrecision( + instruction->shape(), inferred_shape); + } else { + compatible = + ShapeUtil::Compatible(instruction->shape(), inferred_shape); + } } - Status CheckVariadicShape(const HloInstruction* instruction) { - return CheckShape(instruction, - ShapeInference::InferVariadicOpShape( - instruction->opcode(), instruction->operands())); + if (!compatible) { + return InvalidArgument( + "Expected instruction to have shape compatible with %s, actual " + "shape is %s:\n%s", + ShapeUtil::HumanString(inferred_shape).c_str(), + ShapeUtil::HumanString(instruction->shape()).c_str(), + instruction->ToString().c_str()); } + return tensorflow::Status::OK(); +} - // Checks if the given two instructions shares the same channel id. 
- Status CheckSameChannel(const HloInstruction* instr1, - const HloInstruction* instr2) { - if (instr1->channel_id() != instr2->channel_id()) { - return FailedPrecondition( - "Expected to have the same channel id, actual channel ids are: %s " - "(%lld), %s (%lld)", - instr1->ToString().c_str(), instr1->channel_id(), - instr2->ToString().c_str(), instr2->channel_id()); - } - return tensorflow::Status::OK(); +Status ShapeVerifier::CheckShape(const HloInstruction* instruction, + const StatusOr& inferred_shape_status) { + if (!inferred_shape_status.ok()) { + Status s = inferred_shape_status.status(); + tensorflow::errors::AppendToMessage(&s, ", for instruction ", + instruction->ToString()); + return s; } + return CheckShape(instruction, inferred_shape_status.ValueOrDie()); +} - // Returns the size of a Shape in bytes. - const std::function shape_size_fn_; -}; +Status ShapeVerifier::CheckUnaryShape(const HloInstruction* instruction) { + return CheckShape(instruction, + ShapeInference::InferUnaryOpShape(instruction->opcode(), + instruction->operand(0))); +} + +Status ShapeVerifier::CheckBinaryShape(const HloInstruction* instruction) { + return CheckShape( + instruction, ShapeInference::InferBinaryOpShape(instruction->opcode(), + instruction->operand(0), + instruction->operand(1))); +} + +Status ShapeVerifier::CheckTernaryShape(const HloInstruction* instruction) { + return CheckShape(instruction, + ShapeInference::InferTernaryOpShape( + instruction->opcode(), instruction->operand(0), + instruction->operand(1), instruction->operand(2))); +} + +Status ShapeVerifier::CheckVariadicShape(const HloInstruction* instruction) { + return CheckShape(instruction, + ShapeInference::InferVariadicOpShape( + instruction->opcode(), instruction->operands())); +} + +// Checks if the given two instructions shares the same channel id. 
+Status ShapeVerifier::CheckSameChannel(const HloInstruction* instr1, + const HloInstruction* instr2) { + if (instr1->channel_id() != instr2->channel_id()) { + return FailedPrecondition( + "Expected to have the same channel id, actual channel ids are: %s " + "(%lld), %s (%lld)", + instr1->ToString().c_str(), instr1->channel_id(), + instr2->ToString().c_str(), instr2->channel_id()); + } + return tensorflow::Status::OK(); +} string ComputationsToString( tensorflow::gtl::ArraySlice computations) { @@ -429,7 +545,62 @@ string ComputationsToString( }); } -} // namespace +// Verifies various invariants about the structure of the HLO: +// +// (1) each instruction has a non-null parent() set to the HloComputation which +// contains it. +// +// (2) each computation has a non-null parent() set to the HloModule which +// contains it. +// +// (3) the operands of each instruction are in the same computation as the +// instruction. +Status VerifyHloStructure(HloModule* module) { + for (const HloComputation* computation : module->computations()) { + if (computation->parent() == nullptr) { + return FailedPrecondition("Computation %s has a null parent pointer", + computation->name().c_str()); + } + if (computation->parent() != module) { + return FailedPrecondition( + "Computation %s parent() does not point to parent module", + computation->name().c_str()); + } + + for (const HloInstruction* instruction : computation->instructions()) { + if (instruction->parent() == nullptr) { + return FailedPrecondition("Instruction %s has a null parent pointer", + instruction->name().c_str()); + } + if (instruction->parent() != computation) { + return FailedPrecondition( + "Instruction %s parent() does not point to parent computation", + instruction->name().c_str()); + } + } + } + + // Check that operands are in the same computation separately from verifying + // parent() correctness so conditions like a null HloInstruction::parent() are + // identified and reported explicitly above rather than 
reporting a mismatched + // operand. + for (const HloComputation* computation : module->computations()) { + for (const HloInstruction* instruction : computation->instructions()) { + for (int i = 0; i < instruction->operand_count(); ++i) { + const HloInstruction* operand = instruction->operand(i); + if (operand->parent() != instruction->parent()) { + return FailedPrecondition( + "Operand %d (%s) of instruction %s is in a different " + "computation: %s vs %s", + i, operand->name().c_str(), instruction->name().c_str(), + operand->parent()->name().c_str(), + instruction->parent()->name().c_str()); + } + } + } + } + return tensorflow::Status::OK(); +} Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { // The parent fusion instruction of the fusion computation must be 'fusion'. @@ -549,8 +720,9 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { } StatusOr HloVerifier::Run(HloModule* module) { + TF_RETURN_IF_ERROR(VerifyHloStructure(module)); + tensorflow::gtl::FlatMap instructions; - ShapeVerifier shape_verifier(shape_size_fn_); for (auto* computation : module->computations()) { for (const auto& instruction : computation->instructions()) { @@ -630,7 +802,8 @@ StatusOr HloVerifier::Run(HloModule* module) { instructions[instruction->name()] = instruction; } - TF_RETURN_IF_ERROR(computation->Accept(&shape_verifier)); + std::unique_ptr shape_verifier = shape_verifier_factory_(); + TF_RETURN_IF_ERROR(computation->Accept(shape_verifier.get())); } return false; diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index e35a7f3642ccf91df37f69a3a11bd8c8e428b846..7eccf834bbd3ac6af0d5762a7241758b416a3523 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -13,19 +13,124 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_ #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" + namespace xla { +// Visitor which verifies that the output shape is correctly set. Verifies +// against the inferred shape for the instruction. +// TODO(b/26024837): Check output shape for all instruction types. +class ShapeVerifier : public DfsHloVisitor { + public: + explicit ShapeVerifier() : allow_mixed_precision_(false) {} + explicit ShapeVerifier(bool allow_mixed_precision) + : allow_mixed_precision_(allow_mixed_precision) {} + + Status HandleElementwiseUnary(HloInstruction* hlo) override; + Status HandleElementwiseBinary(HloInstruction* hlo) override; + Status HandleClamp(HloInstruction* clamp) override; + Status HandleSelect(HloInstruction* select) override; + Status HandleConcatenate(HloInstruction* concatenate) override; + Status HandleConvert(HloInstruction* convert) override; + Status HandleBitcastConvert(HloInstruction* convert) override; + Status HandleCopy(HloInstruction* copy) override; + Status HandleDot(HloInstruction* dot) override; + Status HandleConvolution(HloInstruction* convolution) override; + Status HandleFft(HloInstruction* fft) override; + Status HandleCrossReplicaSum(HloInstruction* crs) override; + Status HandleReducePrecision(HloInstruction* reduce_precision) override; + Status HandleInfeed(HloInstruction*) override; + Status HandleOutfeed(HloInstruction*) override; + Status HandleRng(HloInstruction*) override; + Status HandleReverse(HloInstruction* reverse) override; + Status HandleSort(HloInstruction* sort) override; + Status HandleConstant(HloInstruction* constant) override; + 
Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; + Status HandleReduce(HloInstruction* reduce) override; + Status HandleBitcast(HloInstruction* bitcast) override; + Status HandleBroadcast(HloInstruction* broadcast) override; + Status HandleReshape(HloInstruction* reshape) override; + Status HandleTranspose(HloInstruction* transpose) override; + Status HandleParameter(HloInstruction*) override; + Status HandleFusion(HloInstruction*) override; + Status HandleCall(HloInstruction* call) override; + Status HandleCustomCall(HloInstruction*) override; + Status HandleSlice(HloInstruction* slice) override; + Status HandleDynamicSlice(HloInstruction* dynamic_slice) override; + Status HandleDynamicUpdateSlice( + HloInstruction* dynamic_update_slice) override; + Status HandleTuple(HloInstruction* tuple) override; + Status HandleMap(HloInstruction* map) override; + Status HandleReduceWindow(HloInstruction* reduce_window) override; + Status HandleSelectAndScatter(HloInstruction* instruction) override; + Status HandleWhile(HloInstruction* xla_while) override; + Status HandleConditional(HloInstruction* conditional) override; + Status HandlePad(HloInstruction* pad) override; + Status HandleSend(HloInstruction* send) override; + Status HandleSendDone(HloInstruction* send_done) override; + Status HandleRecv(HloInstruction* recv) override; + Status HandleRecvDone(HloInstruction* recv_done) override; + Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override; + Status HandleBatchNormInference( + HloInstruction* batch_norm_inference) override; + Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; + + Status FinishVisit(HloInstruction*) override { + return tensorflow::Status::OK(); + } + + protected: + // Check the instruction's shape against the shape given by ShapeInference + // and return an appropriate error if there is a mismatch. 
+ Status CheckShape(const HloInstruction* instruction, + const Shape& inferred_shape); + + // Overload which takes a StatusOr to reduce boilerplate in the caller. + Status CheckShape(const HloInstruction* instruction, + const StatusOr& inferred_shape_status); + + // Check a unary (binary, etc) instruction's shape against the inferred shape. + Status CheckUnaryShape(const HloInstruction* instruction); + Status CheckBinaryShape(const HloInstruction* instruction); + Status CheckTernaryShape(const HloInstruction* instruction); + Status CheckVariadicShape(const HloInstruction* instruction); + + // Checks if the given two instructions shares the same channel id. + Status CheckSameChannel(const HloInstruction* instr1, + const HloInstruction* instr2); + + private: + // Whether the inputs and output of an instruction can contain both F32s and + // BF16s. Tuples that include both F32s and BF16s are allowed regardless of + // this flag. + bool allow_mixed_precision_; +}; + // HLO pass that verifies invariants of HLO instructions for each computation in // the module. class HloVerifier : public HloPassInterface { public: - explicit HloVerifier(const std::function& shape_size_fn) - : shape_size_fn_(shape_size_fn) {} + using ShapeVerifierFactory = std::function()>; + + // Uses standard shape inference. + explicit HloVerifier() + : shape_verifier_factory_( + [] { return MakeUnique(false); }) {} + + explicit HloVerifier(bool allow_mixed_precision) + : shape_verifier_factory_([allow_mixed_precision] { + return MakeUnique(allow_mixed_precision); + }) {} + + // Uses custom shape verification. + explicit HloVerifier(ShapeVerifierFactory shape_verifier_factory) + : shape_verifier_factory_(std::move(shape_verifier_factory)) {} + ~HloVerifier() override = default; tensorflow::StringPiece name() const override { return "verifier"; } @@ -37,10 +142,13 @@ class HloVerifier : public HloPassInterface { // CHECKs various invariants of a fusion instruction. 
Status CheckFusionInstruction(HloInstruction* fusion) const; - // Returns the size of a Shape in bytes. - const std::function shape_size_fn_; + // Creates a ShapeVerifier that checks that shapes match inferred + // expectations. This is a factory function because ShapeVerifier, being a + // DfsHloVisitor, is stateful, and we want a clean object for each run of + // the verifier. + ShapeVerifierFactory shape_verifier_factory_; }; } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_VERIFIER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_verifier_test.cc b/tensorflow/compiler/xla/service/hlo_verifier_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..c92db0be14dceb32ea86521dcc99b8f63738e4a5 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_verifier_test.cc @@ -0,0 +1,127 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_verifier.h" + +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +using ::testing::HasSubstr; + +using HloVerifierTest = HloTestBase; + +TEST_F(HloVerifierTest, NullInstructionParent) { + HloComputation::Builder builder(TestName()); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param")); + HloInstruction* negate = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + TF_ASSERT_OK(verifier().Run(module.get()).status()); + + negate->set_parent(nullptr); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), HasSubstr("has a null parent pointer")); +} + +TEST_F(HloVerifierTest, NullComputationParent) { + HloComputation::Builder builder(TestName()); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param")); + builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); + auto module = CreateNewModule(); + HloComputation* computation = module->AddEntryComputation(builder.Build()); + + 
TF_ASSERT_OK(verifier().Run(module.get()).status()); + + computation->set_parent(nullptr); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), HasSubstr("has a null parent pointer")); +} + +TEST_F(HloVerifierTest, DifferentOperandParents) { + HloComputation::Builder builder(TestName()); + const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param")); + HloInstruction* negate = builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape, HloOpcode::kNegate, param)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + HloComputation::Builder emb_builder(TestName()); + HloInstruction* emb_param = emb_builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape, "param")); + module->AddEmbeddedComputation(emb_builder.Build()); + + TF_ASSERT_OK(verifier().Run(module.get()).status()); + TF_ASSERT_OK(negate->ReplaceOperandWith(0, emb_param)); + + auto status = verifier().Run(module.get()).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + HasSubstr("is in a different computation")); +} + +TEST_F(HloVerifierTest, ResetsShapeVerifierState) { + HloComputation::Builder builder(TestName()); + Shape s1 = ShapeUtil::MakeShape(F32, {1}); + Shape s2 = ShapeUtil::MakeShape(F32, {2}); + + HloInstruction* param = + builder.AddInstruction(HloInstruction::CreateParameter(0, s1, "param")); + + // Create an add instruction with the incorrect shape. + HloInstruction* add = builder.AddInstruction( + HloInstruction::CreateBinary(s2, HloOpcode::kAdd, param, param)); + + // In order to trigger the bug we're checking for, the instruction with the + // bad shape can't be the root of the computation. 
+ builder.AddInstruction( + HloInstruction::CreateBinary(s2, HloOpcode::kMultiply, add, add)); + + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + // Run the verifier twice. It should fail both times, because it shouldn't + // carry state in its DFS visitor between runs. + EXPECT_FALSE(verifier().Run(module.get()).status().ok()); + EXPECT_FALSE(verifier().Run(module.get()).status().ok()); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc index b7c40fdeeb157fc74900bd9cf9d68a06a2cb1d56..13e4557317f74b3fb46f07fb91c339fd2f34752f 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc @@ -25,6 +25,7 @@ namespace xla { using tensorflow::strings::Appendf; using tensorflow::strings::HumanReadableElapsedTime; using tensorflow::strings::HumanReadableNumBytes; +using tensorflow::strings::Printf; using tensorflow::strings::StrAppend; string HumanReadableProfileBuilder::ToString() const { @@ -43,7 +44,12 @@ string HumanReadableProfileBuilder::ToString() const { } else { bytes_per_sec = HumanReadableNumBytes(op.bytes_accessed / CyclesToSeconds(op.cycles)); - bytes_per_cycle = HumanReadableNumBytes(op.bytes_accessed / op.cycles); + if (op.bytes_accessed > op.cycles) { + bytes_per_cycle = HumanReadableNumBytes(op.bytes_accessed / op.cycles); + } else { + bytes_per_cycle = + Printf("%.3fB", static_cast(op.bytes_accessed) / op.cycles); + } } double cycles_percent = 0; diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover.cc b/tensorflow/compiler/xla/service/implicit_broadcast_remover.cc new file mode 100644 index 0000000000000000000000000000000000000000..ada21345014dac70d61129aaf7bbc7466a7db914 --- /dev/null +++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover.cc @@ -0,0 +1,124 @@ +/* 
Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/implicit_broadcast_remover.h" + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_dce.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +namespace { + +// Visitor for removing implicit broadcasts. +class ImplicitBroadcastVisitor : public DfsHloVisitorWithDefault { + public: + Status DefaultAction(HloInstruction* hlo_instruction) override { + return Status::OK(); + } + + Status HandleElementwiseBinary(HloInstruction* hlo) override { + return ReplaceImplicitBroadcastOperands(hlo); + } + + Status HandleClamp(HloInstruction* hlo) override { + // Clamp is the only element-wise ternary operation. 
+ return ReplaceImplicitBroadcastOperands(hlo); + } + + // Returns whether any modification has been made to any visited instruction. + bool changed() const { return changed_; } + + private: + // Iterates through the operands of 'hlo' and replaces any operands which are + // implicitly broadcast with the equivalent sequence of broadcast and reshape + // instructions. An operand is considered to be implicitly broadcast if the + // operand shape does not have the same dimensions as the shape of 'hlo'. + Status ReplaceImplicitBroadcastOperands(HloInstruction* hlo) { + auto fadd = [hlo](std::unique_ptr x) { + return hlo->parent()->AddInstruction(std::move(x)); + }; + std::vector operands; + bool operands_changed = false; + for (int i = 0; i < hlo->operand_count(); ++i) { + HloInstruction* operand = hlo->mutable_operand(i); + if (!ShapeUtil::SameDimensions(hlo->shape(), operand->shape())) { + HloInstruction* new_operand = hlo->parent()->AddInstruction( + HloInstruction::CreateBroadcastSequence(hlo->shape(), operand, + fadd)); + operands.push_back(new_operand); + operands_changed = true; + } else { + operands.push_back(operand); + } + } + if (operands_changed) { + // Create a new HLO instruction because the HloInstruction::Replace* + // methods check that the shape does not change with the replacement. 
+ HloInstruction* new_hlo = hlo->parent()->AddInstruction( + hlo->CloneWithNewOperands(hlo->shape(), operands)); + TF_RETURN_IF_ERROR(hlo->ReplaceAllUsesWith(new_hlo)); + changed_ = true; + } + return Status::OK(); + } + + bool changed_ = false; +}; + +} // namespace + +StatusOr ImplicitBroadcastRemover::Run(HloModule* module) { + VLOG(1) << "Removing implicit broadcast from module " << module->name(); + XLA_VLOG_LINES(2, + "Before removing implicit broadcasts:\n" + module->ToString()); + + ImplicitBroadcastVisitor visitor; + for (HloComputation* computation : module->computations()) { + TF_RETURN_IF_ERROR(computation->Accept(&visitor)); + } + + if (visitor.changed()) { + // HLO instructions with implicitly broadcast operands are cloned and left + // for dead. Remove them. + HloDCE dce; + TF_RETURN_IF_ERROR(dce.Run(module).status()); + } + + XLA_VLOG_LINES(2, + "After removing implicit broadcasts:\n" + module->ToString()); + + return visitor.changed(); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover.h b/tensorflow/compiler/xla/service/implicit_broadcast_remover.h new file mode 100644 index 0000000000000000000000000000000000000000..aa325dc8a353c5bfbfded0c2774c66bfcc71c9cb --- /dev/null +++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover.h @@ -0,0 +1,42 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_IMPLICIT_BROADCAST_REMOVER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_IMPLICIT_BROADCAST_REMOVER_H_ + +#include + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// Pass which replaces all implicit broadcasts with their equivalent sequence of +// explicit broadcast and reshape instructions. +class ImplicitBroadcastRemover : public HloPassInterface { + public: + ImplicitBroadcastRemover() {} + ~ImplicitBroadcastRemover() override {} + + tensorflow::StringPiece name() const override { + return "implicit-broadcast-remover"; + } + + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_IMPLICIT_BROADCAST_REMOVER_H_ diff --git a/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..8c7b38dd1bf73e0be7b669d7215812aaef1cee17 --- /dev/null +++ b/tensorflow/compiler/xla/service/implicit_broadcast_remover_test.cc @@ -0,0 +1,176 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/implicit_broadcast_remover.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" + +namespace op = xla::testing::opcode_matchers; + +namespace xla { +namespace { + +class ImplicitBroadcastRemoverTest : public HloVerifiedTestBase { + protected: + ImplicitBroadcastRemover remover_; +}; + +TEST_F(ImplicitBroadcastRemoverTest, NoImplicitBroadcast) { + auto builder = HloComputation::Builder(TestName()); + + const Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + auto param0 = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")); + auto param1 = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1")); + builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, param1)); + + HloComputation* computation = module().AddEntryComputation(builder.Build()); + + EXPECT_FALSE(remover_.Run(&module()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), + op::Add(op::Parameter(), op::Parameter())); +} + +TEST_F(ImplicitBroadcastRemoverTest, ScalarBroadcast) { + auto builder = HloComputation::Builder(TestName()); + + const Shape shape = ShapeUtil::MakeShape(F32, {2, 4}); + auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "scalar_param")); + auto param1 = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1")); + builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kPower, param0, param1)); + + HloComputation* computation = module().AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + + EXPECT_FALSE(ShapeUtil::Compatible(root->shape(), root->operand(0)->shape())); + 
EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(1)->shape())); + + EXPECT_TRUE(remover_.Run(&module()).ValueOrDie()); + root = computation->root_instruction(); + + EXPECT_THAT(root, op::Power(op::Broadcast(op::Parameter()), op::Parameter())); + + EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(0)->shape())); + EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(1)->shape())); +} + +TEST_F(ImplicitBroadcastRemoverTest, DegenerateDimensionBroadcast) { + auto builder = HloComputation::Builder(TestName()); + + const Shape shape = ShapeUtil::MakeShape(F32, {2, 4, 6}); + auto param0 = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "p0")); + auto param1 = builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {1, 4, 1}), "p1")); + builder.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kSubtract, param0, param1)); + + HloComputation* computation = module().AddEntryComputation(builder.Build()); + + EXPECT_TRUE(remover_.Run(&module()).ValueOrDie()); + + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, op::Subtract(op::Parameter(), + op::Broadcast(op::Reshape(op::Parameter())))); + EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(0)->shape())); + EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(1)->shape())); +} + +TEST_F(ImplicitBroadcastRemoverTest, ScalarBroadcastToDegenerateDimensions) { + auto builder = HloComputation::Builder(TestName()); + + const Shape shape = ShapeUtil::MakeShape(F32, {1, 4, 1}); + auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {}), "scalar_param")); + auto param1 = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "p1")); + builder.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kSubtract, param0, param1)); + + HloComputation* computation = module().AddEntryComputation(builder.Build()); + + 
EXPECT_TRUE(remover_.Run(&module()).ValueOrDie()); + + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, + op::Subtract(op::Broadcast(op::Parameter()), op::Parameter())); + EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(0)->shape())); + EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(1)->shape())); +} + +TEST_F(ImplicitBroadcastRemoverTest, TernaryDegenerateDimensionBroadcast) { + auto builder = HloComputation::Builder(TestName()); + + const Shape shape = ShapeUtil::MakeShape(F32, {2, 4, 6, 8}); + auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 4, 1, 8}), "p0")); + auto param1 = builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {1, 1, 6, 8}), "p1")); + auto param2 = builder.AddInstruction(HloInstruction::CreateParameter( + 2, ShapeUtil::MakeShape(F32, {2, 1, 6, 8}), "p2")); + builder.AddInstruction(HloInstruction::CreateTernary(shape, HloOpcode::kClamp, + param0, param1, param2)); + + HloComputation* computation = module().AddEntryComputation(builder.Build()); + + EXPECT_TRUE(remover_.Run(&module()).ValueOrDie()); + + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, op::Clamp(op::Broadcast(op::Reshape(op::Parameter())), + op::Broadcast(op::Reshape(op::Parameter())), + op::Broadcast(op::Reshape(op::Parameter())))); + EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(0)->shape())); + EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(1)->shape())); + EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(2)->shape())); +} + +TEST_F(ImplicitBroadcastRemoverTest, + TernaryScalarAndDegenerateDimensionBroadcast) { + auto builder = HloComputation::Builder(TestName()); + + const Shape shape = ShapeUtil::MakeShape(F32, {2, 4, 6}); + auto param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {}), "p0")); + auto param1 = 
builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(F32, {1, 4, 6}), "p1")); + auto param2 = + builder.AddInstruction(HloInstruction::CreateParameter(2, shape, "p2")); + builder.AddInstruction(HloInstruction::CreateTernary(shape, HloOpcode::kClamp, + param0, param1, param2)); + + HloComputation* computation = module().AddEntryComputation(builder.Build()); + + EXPECT_TRUE(remover_.Run(&module()).ValueOrDie()); + + HloInstruction* root = computation->root_instruction(); + EXPECT_THAT(root, op::Clamp(op::Broadcast(op::Parameter()), + op::Broadcast(op::Reshape(op::Parameter())), + op::Parameter())); + EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(0)->shape())); + EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(1)->shape())); + EXPECT_TRUE(ShapeUtil::Compatible(root->shape(), root->operand(2)->shape())); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index ba901b99e4f3c72c84c1ecdf4e19e58ad9ab6506..90e1f0acdc4cdeda280dabaab2df66b181d0f407 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -100,6 +100,7 @@ namespace xla { case HloOpcode::kDivide: case HloOpcode::kDot: case HloOpcode::kExp: + case HloOpcode::kFft: case HloOpcode::kFusion: case HloOpcode::kLog: case HloOpcode::kMap: diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index 2704a805a91b93c69b751cdb61305ea7780f0ef2..0819ab3b90b2360c6b0b2afaa89f322afe566eb3 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -92,6 +92,7 @@ cc_library( "//tensorflow/compiler/xla/service:hlo_execution_profile", "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service:shaped_buffer", + 
"//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", ], diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index dc63a2224d659fa427d4d1a30c5dc0f94d643b36..9171e859c6f84ceef9664aa1eb90a07c87dfab40 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -44,41 +44,26 @@ namespace interpreter { namespace se = ::perftools::gputools; namespace sep = ::perftools::gputools::interpreter; -/* - * Run optimization passes on the module. The graph is transformed by - * each pass in the optimization pipeline. The service subdirectory - * contains useful optimization passes. - */ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) { HloPassPipeline pipeline("Interpreter"); - pipeline.AddPass(); - pipeline.AddPass(); - pipeline.AddPass(false); - - pipeline.AddPass>( - false, [](const Shape&, const Shape&) { return false; }); - pipeline.AddPass(); - pipeline.AddPass(); - pipeline.AddPass(); - pipeline.AddPass(true); + pipeline.AddPass( hlo_module->mutable_entry_computation_layout()); - pipeline.AddPass(); - pipeline.AddPass(); return pipeline.Run(hlo_module).status(); } StatusOr> InterpreterCompiler::RunHloPasses( - std::unique_ptr hlo_module, - se::StreamExecutor* /*stream_exec*/) { + std::unique_ptr hlo_module, se::StreamExecutor* /*stream_exec*/, + DeviceMemoryAllocator* /*device_allocator*/) { VLOG(1) << "Run hlo passes on graph " << hlo_module->name(); TF_RETURN_IF_ERROR(RunHloOptimization(hlo_module.get())); return std::move(hlo_module); } StatusOr> InterpreterCompiler::RunBackend( - std::unique_ptr hlo_module, se::StreamExecutor* stream_exec) { + std::unique_ptr hlo_module, se::StreamExecutor* stream_exec, + DeviceMemoryAllocator* /*device_allocator*/) { TF_RET_CHECK(stream_exec != nullptr); VLOG(1) << "Run backend " << 
hlo_module->name(); @@ -96,7 +81,8 @@ StatusOr> InterpreterCompiler::RunBackend( StatusOr>> InterpreterCompiler::Compile( std::vector> /*hlo_modules*/, - std::vector> /*stream_execs*/) { + std::vector> /*stream_execs*/, + DeviceMemoryAllocator* /*device_allocator*/) { return tensorflow::errors::Unimplemented( "Compilation of multiple HLO modules is not supported on Interpreter."); } diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.h b/tensorflow/compiler/xla/service/interpreter/compiler.h index 278cf5184227ae25518b1d46c0e16e4cce7bd1a8..c8660c04d86a82e7dfcfd1658310c2a0e4fa0083 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.h +++ b/tensorflow/compiler/xla/service/interpreter/compiler.h @@ -45,16 +45,19 @@ class InterpreterCompiler : public Compiler { StatusOr> RunHloPasses( std::unique_ptr hlo_module, - perftools::gputools::StreamExecutor* stream_exec) override; + perftools::gputools::StreamExecutor* stream_exec, + DeviceMemoryAllocator* device_allocator) override; StatusOr> RunBackend( std::unique_ptr hlo_module, - perftools::gputools::StreamExecutor* stream_exec) override; + perftools::gputools::StreamExecutor* stream_exec, + DeviceMemoryAllocator* device_allocator) override; StatusOr>> Compile( std::vector> hlo_modules, std::vector> - stream_exec) override; + stream_exec, + DeviceMemoryAllocator* device_allocator) override; StatusOr>> CompileAheadOfTime(std::vector> hlo_modules, diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index 9183a1d1bfb8c2f6e1933c004f9c9f5f9ad8eced..0cb9b5d8107cd8bf468b07d5fe2a22930d9e8b8c 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/interpreter/executor.h" +#include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/lib/core/errors.h" @@ -38,7 +39,6 @@ namespace xla { namespace interpreter { namespace se = ::perftools::gputools; -namespace sep = ::perftools::gputools::interpreter; InterpreterExecutable::InterpreterExecutable( std::unique_ptr hlo_module) @@ -47,44 +47,18 @@ InterpreterExecutable::InterpreterExecutable( InterpreterExecutable::~InterpreterExecutable() {} -static se::DeviceMemoryBase AllocateSingleOutput( - sep::InterpreterExecutor* executor, const Literal& literal) { - int64 size(xla::ShapeUtil::ByteSizeOf(literal.shape())); - void* buf = executor->Allocate(size); - const void* src = literal.InternalData(); - memcpy(buf, src, size); - return se::DeviceMemoryBase(buf, size); -} - -static se::DeviceMemoryBase AllocateOutputBuffer( - sep::InterpreterExecutor* executor, const Literal& literal) { - const Shape& shape = literal.shape(); - if (shape.element_type() != xla::TUPLE) { - return AllocateSingleOutput(executor, literal); - } else { - int64 size(xla::ShapeUtil::ByteSizeOf(shape, sizeof(void*))); - void** buf = reinterpret_cast(executor->Allocate(size)); - void** buf_rc = buf; - for (int64 n = 0; n < xla::ShapeUtil::TupleElementCount(shape); n++) { - se::DeviceMemoryBase out = - AllocateSingleOutput(executor, literal.tuple_literals(n)); - *buf++ = out.opaque(); - } - - return se::DeviceMemoryBase(buf_rc, size); - } -} - -StatusOr InterpreterExecutable::ExecuteOnStream( +StatusOr> InterpreterExecutable::ExecuteOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, + tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) { 
se::Stream* stream = run_options->stream(); + se::StreamExecutor* executor = stream->parent(); + const se::Platform* platform = executor->platform(); VLOG(1) << "Execute " << module().name(); if (VLOG_IS_ON(2)) { for (const auto& a : arguments) { - VLOG(2) << "-- argument " << a.opaque(); + VLOG(2) << "-- argument " << *a; } } @@ -96,33 +70,32 @@ StatusOr InterpreterExecutable::ExecuteOnStream( "Mismatch between argument count and graph parameter count."); } - // Create the arguments as an vector of XLA literals + TF_ASSIGN_OR_RETURN(TransferManager * transfer_manager, + TransferManager::GetForPlatform(platform)); + + // Transform the ShapedBuffer arguments into literals which the evaluator + // consumes. std::vector> arg_literals; - std::vector arg_literals_ptrs; for (int64 p = 0; p < computation->num_parameters(); ++p) { - // Create the input literal for the parameter - HloInstruction* param = computation->parameter_instruction(p); - arg_literals.emplace_back(Literal::CreateFromShape(param->shape())); - arg_literals_ptrs.push_back(arg_literals.back().get()); - - // Copy in the data from the stream_executor buffers - void* buffer = arg_literals.back()->MutableInternalData(); - memcpy(buffer, arguments[p].opaque(), - ShapeUtil::ByteSizeOf(param->shape())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr arg_literal, + transfer_manager->TransferLiteralFromDevice(executor, *arguments[p])); + arg_literals.push_back(std::move(arg_literal)); } // Execute the graph using the HloEvaluator. 
HloEvaluator evaluator; - TF_ASSIGN_OR_RETURN(std::unique_ptr output, - evaluator.Evaluate(*computation, arg_literals_ptrs)); - - // Copy the result into the return buffer - perftools::gputools::StreamExecutor* executor(stream->parent()); - sep::InterpreterExecutor* interpreter_executor( - static_cast(executor->implementation())); - - se::DeviceMemoryBase ret = - AllocateOutputBuffer(interpreter_executor, *(output.get())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr result_literal, + evaluator.Evaluate>(*computation, arg_literals)); + + // Transform the result literal back into a ShapedBuffer. + TF_ASSIGN_OR_RETURN(std::unique_ptr result, + transfer_manager->AllocateShapedBuffer( + result_literal->shape(), run_options->allocator(), + run_options->device_ordinal())); + TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice( + executor, *result_literal, *result)); uint64 end_micros = tensorflow::Env::Default()->NowMicros(); @@ -132,20 +105,13 @@ StatusOr InterpreterExecutable::ExecuteOnStream( execution_profile_.set_compute_time_ns(std::max(nanoseconds, 1.0)); } - return ret; -} - -StatusOr> InterpreterExecutable::ExecuteOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments, - HloExecutionProfile* hlo_execution_profile) { - return tensorflow::errors::Unimplemented( - "ExecuteOnStream is not yet supported on Interpreter."); + return std::move(result); } -StatusOr InterpreterExecutable::ExecuteAsyncOnStream( +StatusOr> +InterpreterExecutable::ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice arguments) { + tensorflow::gtl::ArraySlice arguments) { return tensorflow::errors::Unimplemented( "ExecuteAsyncOnStream is not yet supported on Interpreter."); } diff --git a/tensorflow/compiler/xla/service/interpreter/executable.h b/tensorflow/compiler/xla/service/interpreter/executable.h index 0e87eb90bff4b896fc4bc0efc4fa7b851631be6f..410110a1adf04c83001c38ed03f5d60dd203dc7e 100644 
--- a/tensorflow/compiler/xla/service/interpreter/executable.h +++ b/tensorflow/compiler/xla/service/interpreter/executable.h @@ -43,21 +43,14 @@ class InterpreterExecutable : public Executable { InterpreterExecutable(std::unique_ptr hlo_module); ~InterpreterExecutable() override; - StatusOr ExecuteOnStream( - const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments, - HloExecutionProfile* hlo_execution_profile) override; - StatusOr> ExecuteOnStream( const ServiceExecutableRunOptions* run_options, tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) override; - StatusOr ExecuteAsyncOnStream( + StatusOr> ExecuteAsyncOnStream( const ServiceExecutableRunOptions* run_options, - tensorflow::gtl::ArraySlice - arguments) override; + tensorflow::gtl::ArraySlice arguments) override; static int64 ShapeSizeBytes(const Shape& shape); diff --git a/tensorflow/compiler/xla/service/interpreter/executor.cc b/tensorflow/compiler/xla/service/interpreter/executor.cc index 0bb3259ef43915067e614e72038387e8300ecc41..68371910d76f42c0b6d4b1adad9d6a83bdb858e6 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.cc +++ b/tensorflow/compiler/xla/service/interpreter/executor.cc @@ -85,7 +85,7 @@ bool InterpreterExecutor::HostCallback(Stream *stream, bool InterpreterExecutor::CreateStreamDependency(Stream *dependent, Stream *other) { AsExecutorStream(dependent)->EnqueueTask( - [other]() { other->BlockHostUntilDone(); }); + [other]() { SE_CHECK_OK(other->BlockHostUntilDone()); }); AsExecutorStream(dependent)->BlockUntilDone(); return true; } @@ -100,9 +100,9 @@ bool InterpreterExecutor::StopTimer(Stream *stream, Timer *timer) { return true; } -bool InterpreterExecutor::BlockHostUntilDone(Stream *stream) { +port::Status InterpreterExecutor::BlockHostUntilDone(Stream *stream) { AsExecutorStream(stream)->BlockUntilDone(); - return true; + return port::Status::OK(); } DeviceDescription 
*InterpreterExecutor::PopulateDeviceDescription() const { diff --git a/tensorflow/compiler/xla/service/interpreter/executor.h b/tensorflow/compiler/xla/service/interpreter/executor.h index c59b2ccb1505b78be0c459ac9311428d65cc7e44..c5d07e906dafb033905c50c604069e80e1ce80cd 100644 --- a/tensorflow/compiler/xla/service/interpreter/executor.h +++ b/tensorflow/compiler/xla/service/interpreter/executor.h @@ -157,7 +157,7 @@ class InterpreterExecutor : public internal::StreamExecutorInterface { bool StartTimer(Stream *stream, Timer *timer) override; bool StopTimer(Stream *stream, Timer *timer) override; - bool BlockHostUntilDone(Stream *stream) override; + port::Status BlockHostUntilDone(Stream *stream) override; int PlatformDeviceCount() override { return 1; } diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 7eda7c2284c2457703fcfcd4226172e41dd4ae01..fce135ef61a7868386b869def1a79167c428d928 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -61,8 +61,8 @@ std::ostream& operator<<(std::ostream& out, BufferLayoutConstraint::BufferLayoutConstraint(const Layout& layout, const LogicalBuffer& buffer, - bool mandatory) - : LayoutConstraint(mandatory), layout_(layout), buffer_(&buffer) { + bool mandatory, bool dfs) + : LayoutConstraint(mandatory, dfs), layout_(layout), buffer_(&buffer) { CHECK(LayoutUtil::ValidateLayoutForShape(layout, buffer.shape()).ok()); } @@ -74,14 +74,17 @@ string BufferLayoutConstraint::ToString() const { OperandLayoutConstraint::OperandLayoutConstraint( const ShapeLayout& shape_layout, const HloInstruction* instruction, - int64 operand_no, bool mandatory) - : LayoutConstraint(mandatory), + int64 operand_no, bool mandatory, bool dfs) + : LayoutConstraint(mandatory, dfs), shape_layout_(shape_layout), instruction_(instruction), operand_no_(operand_no) { CHECK(shape_layout_.LayoutIsSet()); 
CHECK(ShapeUtil::Compatible(shape_layout.shape(), - instruction->operand(operand_no)->shape())); + instruction->operand(operand_no)->shape())) + << shape_layout.shape() << " is not compatible with " + << instruction->operand(operand_no)->shape() << " (for operand " + << operand_no << " of instruction " << instruction->ToString() << ")"; } string OperandLayoutConstraint::ToString() const { @@ -131,7 +134,7 @@ bool LayoutConstraints::OperandBufferForwarded( Status LayoutConstraints::SetBufferLayout(const Layout& layout, const LogicalBuffer& buffer, - bool mandatory) { + bool mandatory, bool dfs) { VLOG(3) << "SetBufferLayout : " << buffer << " : " << LayoutUtil::HumanString(layout); @@ -168,10 +171,11 @@ Status LayoutConstraints::SetBufferLayout(const Layout& layout, if (!overwrite) { iter = buffer_constraints_ .insert(std::make_pair( - &buffer, BufferLayoutConstraint(layout, buffer, mandatory))) + &buffer, + BufferLayoutConstraint(layout, buffer, mandatory, dfs))) .first; } else { - iter->second = BufferLayoutConstraint(layout, buffer, /*mandatory=*/true); + iter->second = BufferLayoutConstraint(layout, buffer, mandatory, dfs); } added_constraints_.push_back(&iter->second); @@ -185,7 +189,8 @@ Status LayoutConstraints::SetBufferLayout(const Layout& layout, Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout, const HloInstruction* instruction, - int64 operand_no, bool mandatory) { + int64 operand_no, bool mandatory, + bool dfs) { VLOG(3) << "SetOperandLayout : " << instruction->name() << ", operand " << operand_no << " : " << ShapeUtil::HumanStringWithLayout(shape_with_layout); @@ -223,12 +228,12 @@ Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout, if (iter == operand_constraints_.end()) { auto pair = std::make_pair( key, OperandLayoutConstraint(ShapeLayout(shape_with_layout), - instruction, operand_no, mandatory)); + instruction, operand_no, mandatory, dfs)); iter = operand_constraints_.insert(pair).first; } else { 
iter->second = OperandLayoutConstraint(ShapeLayout(shape_with_layout), instruction, - operand_no, /*mandatory=*/true); + operand_no, mandatory, dfs); } added_constraints_.push_back(&iter->second); @@ -237,16 +242,17 @@ Status LayoutConstraints::SetOperandLayout(const Shape& shape_with_layout, Status LayoutConstraints::SetArrayOperandLayout( const Layout& layout, const HloInstruction* instruction, int64 operand_no, - bool mandatory) { + bool mandatory, bool dfs) { const HloInstruction* operand = instruction->operand(operand_no); TF_RET_CHECK(ShapeUtil::IsArray(operand->shape())); Shape shape(operand->shape()); *shape.mutable_layout() = layout; TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutInShape(shape)); - return SetOperandLayout(shape, instruction, operand_no, mandatory); + return SetOperandLayout(shape, instruction, operand_no, mandatory, dfs); } -Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout) { +Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout, + bool dfs) { VLOG(3) << "SetResultLayout : " << ShapeUtil::HumanStringWithLayout(shape_with_layout); @@ -264,14 +270,15 @@ Status LayoutConstraints::SetResultLayout(const Shape& shape_with_layout) { } result_constraint_.reset( - new ResultLayoutConstraint(ShapeLayout(shape_with_layout))); + new ResultLayoutConstraint(ShapeLayout(shape_with_layout), dfs)); added_constraints_.push_back(result_constraint_.get()); return Status::OK(); } Status LayoutConstraints::SetInstructionLayout( - const Shape& shape_with_layout, const HloInstruction* instruction) { + const Shape& shape_with_layout, const HloInstruction* instruction, + bool mandatory, bool dfs) { VLOG(3) << "SetInstructionLayout : " << instruction->name() << ", " << ShapeUtil::HumanStringWithLayout(shape_with_layout); @@ -287,8 +294,8 @@ Status LayoutConstraints::SetInstructionLayout( // instruction. 
return ShapeUtil::ForEachSubshapeWithStatus( shape_with_layout, - [this, instruction](const Shape& subshape, - const ShapeIndex& index) -> Status { + [this, instruction, mandatory](const Shape& subshape, + const ShapeIndex& index) -> Status { // The precondition for this method is that the instruction defines all // buffers in its output. auto buffers = @@ -297,7 +304,7 @@ Status LayoutConstraints::SetInstructionLayout( CHECK_EQ(buffers[0]->instruction(), instruction); if (ShapeUtil::IsArray(subshape)) { - return SetBufferLayout(subshape.layout(), *buffers[0]); + return SetBufferLayout(subshape.layout(), *buffers[0], mandatory); } else { return Status::OK(); } @@ -369,8 +376,9 @@ string LayoutConstraints::ToString() const { } Status LayoutAssignment::AddMandatoryConstraints( - const ComputationLayout& computation_layout, HloComputation* computation, - LayoutConstraints* constraints) { + const ComputationLayout& computation_layout, + const ChannelLayoutConstraints* channel_constraints, + HloComputation* computation, LayoutConstraints* constraints) { VLOG(3) << "Adding mandatory layout constraints to computation " << computation->name(); @@ -390,8 +398,7 @@ Status LayoutAssignment::AddMandatoryConstraints( // Constrain the input to the Outfeed instruction to be the expected // layout of the Outfeed. TF_RETURN_IF_ERROR(constraints->SetOperandLayout( - instruction->outfeed_shape(), instruction, 0, - /*mandatory=*/true)); + instruction->outfeed_shape(), instruction, 0)); } else if (instruction->opcode() == HloOpcode::kParameter) { // Parameter layouts must match the respective layout in // ComputationLayout. 
@@ -403,6 +410,37 @@ Status LayoutAssignment::AddMandatoryConstraints( TF_RETURN_IF_ERROR( constraints->SetInstructionLayout(*shape_with_layout, instruction)); } + + if (instruction->opcode() == HloOpcode::kSend || + instruction->opcode() == HloOpcode::kRecv) { + CHECK(channel_constraints) + << "Multi-module layout assignment requires ChannelLayoutConstraints"; + int64 channel_id = instruction->channel_id(); + if (!channel_constraints->IsChannelConstrained(channel_id)) { + continue; + } + if (instruction->opcode() == HloOpcode::kSend) { + // TODO(b/68493863): Change to use SetOperandLayout(). + const Shape send_buffer_shape = instruction->operand(0)->shape(); + TF_RET_CHECK(ShapeUtil::IsArray(send_buffer_shape)); + Shape new_buffer_shape = channel_constraints->LayoutShapeForChannel( + send_buffer_shape, instruction->channel_id()); + TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( + new_buffer_shape, instruction->operand(0))); + } else { + const Shape recv_buffer_shape = + ShapeUtil::GetTupleElementShape(instruction->shape(), 0); + TF_RET_CHECK(ShapeUtil::IsArray(recv_buffer_shape)); + TF_ASSIGN_OR_RETURN( + const LogicalBuffer* buffer, + constraints->points_to_analysis().GetBufferDefinedAt(instruction, + {0})); + Shape new_shape = channel_constraints->LayoutShapeForChannel( + recv_buffer_shape, instruction->channel_id()); + TF_RETURN_IF_ERROR( + constraints->SetBufferLayout(new_shape.layout(), *buffer)); + } + } } // Constrain layouts of instructions which call computations which have @@ -422,7 +460,7 @@ Status LayoutAssignment::AddMandatoryConstraints( for (int64 i = 0; i < instruction->operand_count(); ++i) { TF_RETURN_IF_ERROR(constraints->SetOperandLayout( called_computation_layout.parameter_layout(i).shape(), instruction, - i, /*mandatory=*/true)); + i)); } } else if (instruction->opcode() == HloOpcode::kWhile) { // Layout of input and output of kWhile instruction must be equal and must @@ -473,20 +511,16 @@ Status 
LayoutAssignment::AddMandatoryConstraints( TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( body_layout.result_shape(), instruction)); TF_RETURN_IF_ERROR(constraints->SetOperandLayout( - body_layout.result_shape(), instruction, 0, - /*mandatory=*/true)); + body_layout.result_shape(), instruction, 0)); } else if (instruction->opcode() == HloOpcode::kCustomCall) { + if (!CustomCallRequiresMajorFirstLayout(instruction)) { + continue; + } // Add constraints for kCustomCall instruction operands and instructions. - // For now we only support row major layouts for all inputs and outputs. - auto row_major_shape = [](const Shape& old_shape) { - Shape new_shape(old_shape); - std::vector dimension_order(new_shape.dimensions_size()); - std::iota(dimension_order.rbegin(), dimension_order.rend(), 0); - *new_shape.mutable_layout() = LayoutUtil::MakeLayout(dimension_order); - return new_shape; - }; - - Shape result_shape(row_major_shape(instruction->shape())); + // For now we only support major-first layouts for all inputs and outputs. 
+ Shape result_shape = ShapeUtil::MakeShapeWithDescendingLayout( + instruction->shape().element_type(), + AsInt64Slice(instruction->shape().dimensions())); TF_RETURN_IF_ERROR( constraints->SetInstructionLayout(result_shape, instruction)); for (int64 i = 0; i < instruction->operand_count(); ++i) { @@ -496,9 +530,12 @@ Status LayoutAssignment::AddMandatoryConstraints( continue; } - Shape row_major_operand_shape(row_major_shape(operand_shape)); + Shape row_major_operand_shape = + ShapeUtil::MakeShapeWithDescendingLayout( + operand_shape.element_type(), + AsInt64Slice(operand_shape.dimensions())); TF_RETURN_IF_ERROR(constraints->SetOperandLayout( - row_major_operand_shape, instruction, i, /*mandatory=*/true)); + row_major_operand_shape, instruction, i)); } } } @@ -530,9 +567,11 @@ Status CheckCallLayout(HloInstruction* call, Status CheckCustomCallLayout(HloInstruction* custom_call) { for (const HloInstruction* operand : custom_call->operands()) { TF_RET_CHECK( + ShapeUtil::IsOpaque(operand->shape()) || LayoutUtil::IsMonotonicWithDim0Major(operand->shape().layout())); } TF_RET_CHECK( + ShapeUtil::IsOpaque(custom_call->shape()) || LayoutUtil::IsMonotonicWithDim0Major(custom_call->shape().layout())); return Status::OK(); } @@ -601,11 +640,9 @@ Status CheckConstantLayout(HloInstruction* constant) { return Status::OK(); } -// Check that all layouts in the module have been set and satisfy all necessary -// conditions. 
-Status CheckLayouts( - HloModule* module, - const std::map& computation_layouts) { +} // namespace + +Status LayoutAssignment::CheckLayouts(HloModule* module) { TF_ASSIGN_OR_RETURN(auto points_to_analysis, TuplePointsToAnalysis::Run(module)); for (auto* computation : module->MakeNonfusionComputations()) { @@ -649,10 +686,12 @@ Status CheckLayouts( case HloOpcode::kCall: TF_RETURN_IF_ERROR(CheckCallLayout( instruction, - FindOrDie(computation_layouts, instruction->to_apply()))); + FindOrDie(computation_layouts_, instruction->to_apply()))); break; case HloOpcode::kCustomCall: - TF_RETURN_IF_ERROR(CheckCustomCallLayout(instruction)); + if (CustomCallRequiresMajorFirstLayout(instruction)) { + TF_RETURN_IF_ERROR(CheckCustomCallLayout(instruction)); + } break; case HloOpcode::kFusion: TF_RETURN_IF_ERROR(CheckFusionLayout(instruction)); @@ -660,7 +699,7 @@ Status CheckLayouts( case HloOpcode::kParameter: TF_RETURN_IF_ERROR(CheckParameterLayout( instruction, - FindOrDie(computation_layouts, instruction->parent()))); + FindOrDie(computation_layouts_, instruction->parent()))); break; case HloOpcode::kConstant: TF_RETURN_IF_ERROR(CheckConstantLayout(instruction)); @@ -668,8 +707,8 @@ Status CheckLayouts( case HloOpcode::kWhile: TF_RETURN_IF_ERROR(CheckWhileLayout( instruction, - FindOrDie(computation_layouts, instruction->while_condition()), - FindOrDie(computation_layouts, instruction->while_body()))); + FindOrDie(computation_layouts_, instruction->while_condition()), + FindOrDie(computation_layouts_, instruction->while_body()))); break; default: break; @@ -681,17 +720,18 @@ Status CheckLayouts( // computation root. 
TF_RET_CHECK(ShapeUtil::Equal( module->entry_computation()->root_instruction()->shape(), - FindOrDie(computation_layouts, module->entry_computation()) + FindOrDie(computation_layouts_, module->entry_computation()) .result_layout() .shape())); return Status::OK(); } -} // namespace - -LayoutAssignment::LayoutAssignment(ComputationLayout* entry_computation_layout) - : entry_computation_layout_(entry_computation_layout) { +LayoutAssignment::LayoutAssignment( + ComputationLayout* entry_computation_layout, + ChannelLayoutConstraints* channel_constraints) + : entry_computation_layout_(entry_computation_layout), + channel_layout_constraints_(channel_constraints) { VLOG(1) << "entry computation layout given to layout assignment: " << entry_computation_layout_->ToString(); // Layouts of all parameter instructions must be set. @@ -711,8 +751,8 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( int64 operand_no) { const HloInstruction* operand = instruction->operand(operand_no); - CHECK(ShapeUtil::IsArray(instruction->shape()) && - ShapeUtil::IsArray(operand->shape())); + CHECK(ShapeUtil::IsArray(instruction->shape())); + CHECK(ShapeUtil::IsArray(operand->shape())); if (instruction->IsElementwiseOnOperand(operand_no) && !ShapeUtil::IsScalar(operand->shape()) && @@ -742,7 +782,7 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( const Shape& output_shape = instruction->shape(); Shape output_shape_with_layout = ShapeUtil::MakeShapeWithLayout( output_shape.element_type(), AsInt64Slice(output_shape.dimensions()), - AsInt64Slice(output_layout.minor_to_major())); + LayoutUtil::MinorToMajor(output_layout)); Shape operand_shape = operand->shape(); *operand_shape.mutable_layout() = LayoutUtil::GetDefaultLayoutForShape(operand_shape); @@ -771,7 +811,7 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( int64 rank = ShapeUtil::Rank(instruction->shape()); std::vector new_minor_to_major(rank); for (int64 i = 0; i < rank; 
++i) { - int64 output_dim = output_layout.minor_to_major(i); + int64 output_dim = LayoutUtil::Minor(output_layout, i); int64 operand_dim = instruction->dimensions(output_dim); new_minor_to_major[i] = operand_dim; } @@ -814,7 +854,7 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( Shape operand_shape_with_layout = ShapeUtil::MakeShapeWithLayout( operand->shape().element_type(), AsInt64Slice(operand->shape().dimensions()), - AsInt64Slice(operand_layout.minor_to_major())); + LayoutUtil::MinorToMajor(operand_layout)); Shape output_shape = user->shape(); *output_shape.mutable_layout() = LayoutUtil::GetDefaultLayoutForShape(output_shape); @@ -844,7 +884,7 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( std::vector new_minor_to_major(rank); auto inverse_dimensions = InversePermutation(user->dimensions()); for (int64 i = 0; i < rank; ++i) { - int64 operand_dim = operand_layout.minor_to_major(i); + int64 operand_dim = LayoutUtil::Minor(operand_layout, i); int64 user_dim = inverse_dimensions[operand_dim]; new_minor_to_major[i] = user_dim; } @@ -869,7 +909,11 @@ Status LayoutAssignment::PropagateConstraints(LayoutConstraints* constraints) { auto add_new_constraints_to_worklist = [constraints, &worklist]() { // Add constraints to the front of the deque for DFS ordering. for (auto* constraint : constraints->ConsumeAddedConstraints()) { - worklist.push_front(constraint); + if (constraint->dfs()) { + worklist.push_front(constraint); + } else { + worklist.push_back(constraint); + } } }; add_new_constraints_to_worklist(); @@ -1198,7 +1242,8 @@ Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout, // instruction itself. 
Status SetFusionLayouts(HloInstruction* fusion) { TF_RET_CHECK(fusion->opcode() == HloOpcode::kFusion); - for (auto* fused_instruction : fusion->fused_instructions()) { + for (auto* fused_instruction : + fusion->fused_instructions_computation()->MakeInstructionPostOrder()) { if (fused_instruction->opcode() == HloOpcode::kParameter) { const HloInstruction* fusion_operand = fusion->operand(fused_instruction->parameter_number()); @@ -1213,11 +1258,22 @@ Status SetFusionLayouts(HloInstruction* fusion) { ShapeUtil::Compatible(fusion->shape(), fused_instruction->shape())); TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( fusion->shape(), fused_instruction->mutable_shape())); - } else if (fused_instruction->opcode() != HloOpcode::kConstant && - fused_instruction->opcode() != HloOpcode::kGetTupleElement && - fused_instruction->opcode() != HloOpcode::kInfeed) { - // Internal fused instructions with the exception of constants - // and infeed need no layout. + } else if (fused_instruction->opcode() == HloOpcode::kGetTupleElement) { + // A GTE inherits its layout from its operand (which should ultimately be + // a parameter). + TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( + fused_instruction->operand(0)->shape().tuple_shapes( + fused_instruction->tuple_index()), + fused_instruction->mutable_shape())); + } else if (fused_instruction->opcode() == HloOpcode::kConstant) { + // Give constants the layout of their literal. + TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( + fused_instruction->literal().shape(), + fused_instruction->mutable_shape())); + } else if (fused_instruction->opcode() == HloOpcode::kInfeed) { + // Nop; leave the infeed layout alone. + } else { + // Other instructions don't have layouts inside of fusion nodes. 
LayoutUtil::ClearLayout(fused_instruction->mutable_shape()); } } @@ -1303,8 +1359,8 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints, TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape())); } - // Copy the root instrucion's result if the it does not match the result - // layout constraint + // Copy the root instruction's result if its layout does not match the result + // layout constraint. if (constraints.ResultLayout() != nullptr && !constraints.ResultLayout()->MatchesLayoutInShape( computation->root_instruction()->shape())) { @@ -1321,7 +1377,8 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints, Status LayoutAssignment::RunOnComputation( const ComputationLayout& computation_layout, const TuplePointsToAnalysis& points_to_analysis, - HloComputation* computation) { + HloComputation* computation, + ChannelLayoutConstraints* channel_constraints) { DCHECK(computation_layout.LayoutIsSet()); InsertOrDie(&computation_layouts_, computation, computation_layout); VLOG(2) << "LayoutAssignment::RunOnComputation(" << computation->name() @@ -1333,13 +1390,13 @@ Status LayoutAssignment::RunOnComputation( // Add constraints required for correctness on all backends (eg, entry // parameter layout constraints). - TF_RETURN_IF_ERROR( - AddMandatoryConstraints(computation_layout, computation, &constraints)); + TF_RETURN_IF_ERROR(AddMandatoryConstraints( + computation_layout, channel_constraints, computation, &constraints)); // Add any backend-specific constraints. TF_RETURN_IF_ERROR(AddBackendConstraints(&constraints)); - // Propagates layouts from an HLO to its neighbors. + // Propagates layouts from mandatory and backend constraints. TF_RETURN_IF_ERROR(PropagateConstraints(&constraints)); // While any unconstrained buffers remain, pick an arbitrary buffer, give it a @@ -1373,7 +1430,20 @@ Status LayoutAssignment::RunOnComputation( // All logical buffers should have constraints at this point. 
All that // remains is assign the constraints to the buffers and infer layouts for // aliased buffers. - return AssignLayouts(constraints, computation); + TF_RETURN_IF_ERROR(AssignLayouts(constraints, computation)); + + // Record the layouts assigned for any communication ops in + // channel_constraints so that they are constrained for future modules. + for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kSend) { + channel_constraints->ConstrainChannel( + instruction->channel_id(), instruction->operand(0)->shape().layout()); + } else if (instruction->opcode() == HloOpcode::kRecvDone) { + channel_constraints->ConstrainChannel(instruction->channel_id(), + instruction->shape().layout()); + } + } + return Status::OK(); } StatusOr LayoutAssignment::Run(HloModule* module) { @@ -1391,24 +1461,39 @@ StatusOr LayoutAssignment::Run(HloModule* module) { // Assign layouts to computations in an order such that a callee computation // is handled before its caller computation. This ensures that the layout of // all callers of a computation will agree. + std::list computation_post_order = + module->MakeComputationPostOrder(); for (auto* computation : module->MakeComputationPostOrder()) { - if (computation == module->entry_computation()) { - TF_RETURN_IF_ERROR(RunOnComputation(*entry_computation_layout_, - *points_to_analysis, - module->entry_computation())); - } else if (computation->IsFusionComputation()) { + if (computation->IsFusionComputation()) { continue; + } + // Clear existing layouts of the instructions. All layouts must be assigned + // by the LayoutAssignment pass, except for those on infeeds, parameters, + // and the computation result. The latter two are specified in + // computation_layout, so we only need to keep the existing layouts for + // infeeds. Clearing the layouts here avoids hiding potential bugs in the + // layout assignment pass that may accidently use the existing layout. 
+ for (HloInstruction* instruction : computation->instructions()) { + if (instruction->opcode() != HloOpcode::kInfeed) { + LayoutUtil::ClearLayout(instruction->mutable_shape()); + } + } + if (computation == module->entry_computation()) { + TF_RETURN_IF_ERROR(RunOnComputation( + *entry_computation_layout_, *points_to_analysis, + module->entry_computation(), channel_layout_constraints_)); } else { ComputationLayout computation_layout(computation->ComputeProgramShape()); // Setting all embedded computations to the default layout is potentially // suboptimal. computation_layout.SetToDefaultLayout(); TF_RETURN_IF_ERROR(RunOnComputation(computation_layout, - *points_to_analysis, computation)); + *points_to_analysis, computation, + channel_layout_constraints_)); } } - TF_RETURN_IF_ERROR(CheckLayouts(module, computation_layouts_)); + TF_RETURN_IF_ERROR(CheckLayouts(module)); VLOG(3) << "After layout assignment:"; XLA_VLOG_LINES(3, module->ToString()); diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index 0b97fba744923b8afc3fb539566b68f1bca47d38..29018584487cabfd740d7914625c2a50f552d6ff 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -46,7 +46,8 @@ namespace xla { // gathered together in LayoutConstraints object. class LayoutConstraint { public: - LayoutConstraint(bool mandatory) : mandatory_(mandatory) {} + LayoutConstraint(bool mandatory, bool dfs) + : mandatory_(mandatory), dfs_(dfs) {} virtual ~LayoutConstraint() = default; virtual string ToString() const = 0; @@ -54,8 +55,12 @@ class LayoutConstraint { // True if this constraint cannot be overwritten by a different constraint. bool mandatory() const { return mandatory_; } + // When true, propagate in DFS. When false, constraint will propagate in BFS. 
+ bool dfs() const { return dfs_; } + private: bool mandatory_; + bool dfs_; }; std::ostream& operator<<(std::ostream& out, const LayoutConstraint& constraint); @@ -65,7 +70,7 @@ std::ostream& operator<<(std::ostream& out, const LayoutConstraint& constraint); class BufferLayoutConstraint : public LayoutConstraint { public: BufferLayoutConstraint(const Layout& layout, const LogicalBuffer& buffer, - bool mandatory); + bool mandatory, bool dfs); const LogicalBuffer& buffer() const { return *buffer_; } const Layout& layout() const { return layout_; } @@ -86,7 +91,7 @@ class OperandLayoutConstraint : public LayoutConstraint { public: OperandLayoutConstraint(const ShapeLayout& shape_layout, const HloInstruction* instruction, int64 operand_no, - bool mandatory); + bool mandatory, bool dfs); const ShapeLayout& shape_layout() const { return shape_layout_; } const HloInstruction* instruction() const { return instruction_; } @@ -106,8 +111,10 @@ class OperandLayoutConstraint : public LayoutConstraint { // Constraint on the layout of the result of the entry computation. class ResultLayoutConstraint : public LayoutConstraint { public: - explicit ResultLayoutConstraint(const ShapeLayout& shape_layout) - : LayoutConstraint(/*mandatory=*/true), shape_layout_(shape_layout) {} + explicit ResultLayoutConstraint(const ShapeLayout& shape_layout, + bool dfs = false) + : LayoutConstraint(/*mandatory=*/true, dfs), + shape_layout_(shape_layout) {} const ShapeLayout& shape_layout() const { return shape_layout_; } string ToString() const override; @@ -157,23 +164,25 @@ class LayoutConstraints { // operand of the instruction, or the layout of the result of the computation, // respectively. 
Status SetBufferLayout(const Layout& layout, const LogicalBuffer& buffer, - bool mandatory = true); + bool mandatory = true, bool dfs = true); Status SetOperandLayout(const Shape& shape_with_layout, const HloInstruction* instruction, int64 operand_no, - bool mandatory = true); - Status SetResultLayout(const Shape& shape_with_layout); + bool mandatory = true, bool dfs = true); + Status SetResultLayout(const Shape& shape_with_layout, bool dfs = true); // Convenience wrapper around SetOperandLayout for setting the layout of a // operand using a Layout object. The operand must be array-shaped. Status SetArrayOperandLayout(const Layout& layout, const HloInstruction* instruction, - int64 operand_no, bool mandatory = true); + int64 operand_no, bool mandatory = true, + bool dfs = true); // Convenience wrapper around SetBufferLayout. Sets the layouts of all buffers // created by the instruction to the layouts in the given shape. The // instruction must define every logical buffer in its output. Status SetInstructionLayout(const Shape& shape_with_layout, - const HloInstruction* instruction); + const HloInstruction* instruction, + bool mandatory = true, bool dfs = true); // Returns true if any buffer in the given operand is forwarded to the output // of the given instruction. For example, the Tuple instruction forwards the @@ -215,13 +224,62 @@ class LayoutConstraints { HloComputation* computation_; }; +// Contains constraints on the layout of channels; sends and recvs. +class ChannelLayoutConstraints { + public: + // Construct an empty constraint set. + ChannelLayoutConstraints() {} + + // Returns true if channel_id has a layout constraint. + bool IsChannelConstrained(int64 channel_id) const { + return constraints_.count(channel_id) > 0; + } + + // Given `shape`, apply the layout for `channel_id`. `channel_id` must already + // be constrained. 
+ Shape LayoutShapeForChannel(Shape shape, int64 channel_id) const { + CHECK(IsChannelConstrained(channel_id)); + *shape.mutable_layout() = constraints_.at(channel_id); + return shape; + } + + // Returns the layout constraint for `channel_id`, which must already be + // constrained. + Layout LayoutForChannel(int64 channel_id) const { + CHECK(IsChannelConstrained(channel_id)); + return constraints_.at(channel_id); + } + + // Adds a new layout constraint for `channel_id`. If a constraint for + // `channel_id` already exists, this operation requires that the new layout is + // the same as the previously constrained layout. + void ConstrainChannel(int64 channel_id, const Layout& layout) { + CHECK(!IsChannelConstrained(channel_id) || + LayoutUtil::Equal(layout, constraints_[channel_id])); + constraints_[channel_id] = layout; + } + + private: + std::unordered_map constraints_; +}; + // HLO pass which assigns layouts to all instructions in the HLO module while // satisfying all necessary invariants and minimizing cost. class LayoutAssignment : public HloPassInterface { public: // entry_computation_layout is modified to populate a layout for the result in // the case that no particular layout is requested. - explicit LayoutAssignment(ComputationLayout* entry_computation_layout); + // + // channel_constraints is both an input and output. Any sends or recvs that + // are present in channel_constraints will be layed out as constrained. Any + // unconstrained sends or recvs will be layed out as locally optimal and their + // layout will be added as a constraint to channel_constraints. + // + // If channel_constraints is nullptr, no kSend or kRecvs must be contained + // within any module passed to `Run`. 
+ explicit LayoutAssignment( + ComputationLayout* entry_computation_layout, + ChannelLayoutConstraints* channel_constraints = nullptr); ~LayoutAssignment() override {} tensorflow::StringPiece name() const override { return "layout-assignment"; } @@ -247,6 +305,19 @@ class LayoutAssignment : public HloPassInterface { const ResultLayoutConstraint& layout_constraint, LayoutConstraints* constraints); + // By default LayoutAssignment ensures that inputs and outputs of CustomCalls + // have the "major-first" layout (i.e. {n, n-1, ..., 0}). + // + // If this function returns true, LayoutAssignment does not set a layout for + // the given CustomCall. It's up to the backend to set one in + // AddBackendConstraints, if necessary. + // + // Precondition: instruction->opcode() == HloOpcode::kCustomCall. + virtual bool CustomCallRequiresMajorFirstLayout( + const HloInstruction* /*instruction*/) { + return true; + } + // Called after layouts of an instruction have been finalized to allow // subclasses to check for platform specific assumptions. virtual Status Verify(const HloInstruction* instruction) { @@ -283,9 +354,10 @@ class LayoutAssignment : public HloPassInterface { private: // Adds constraints which must be satisfied for correctness on all // backends. Called once prior to propagating constraints. - Status AddMandatoryConstraints(const ComputationLayout& computation_layout, - HloComputation* computation, - LayoutConstraints* constraints); + Status AddMandatoryConstraints( + const ComputationLayout& computation_layout, + const ChannelLayoutConstraints* channel_constraints, + HloComputation* computation, LayoutConstraints* constraints); // This method can be overridden to add backend-specific constraints to the // layout of the instructions of a computation. This method is called after @@ -301,7 +373,8 @@ class LayoutAssignment : public HloPassInterface { // constrained. 
Status RunOnComputation(const ComputationLayout& computation_layout, const TuplePointsToAnalysis& points_to_analysis, - HloComputation* computation); + HloComputation* computation, + ChannelLayoutConstraints* channel_constraints); // Assign layouts to the instructions of a computation which satisfy the given // layout constraints. Copies may be added to satisfy the constraints. The @@ -315,7 +388,12 @@ class LayoutAssignment : public HloPassInterface { // required for correctness. Status PropagateConstraints(LayoutConstraints* constraints); + // Check that all layouts in the module have been set and satisfy all + // necessary conditions. + Status CheckLayouts(HloModule* module); + ComputationLayout* entry_computation_layout_; + ChannelLayoutConstraints* channel_layout_constraints_; protected: // Map containing the layouts of all computations assigned so diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index d51c0d1dfb727801d6d2a8328eba60838373479f..e269a13459f1146f1d2952870399827d9e705e38 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -35,9 +35,11 @@ limitations under the License. 
#include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/gtl/array_slice.h" namespace op = xla::testing::opcode_matchers; @@ -587,5 +589,74 @@ TEST_F(LayoutAssignmentTest, TransposeToBitcastToUser) { EXPECT_TRUE(ShapeUtil::TransposeIsBitcast(transpose->operand(0)->shape(), transpose->shape(), {2, 3, 0, 1})); } + +// A GTE inside of a fusion node inherits the layout of its operand (which +// should, if we keep following operands, eventually be a parameter). +TEST_F(LayoutAssignmentTest, GTEInheritsLayoutFromOperand) { + const char* module_str = R"( + HloModule test_module + + fused_computation { + fparam = (f32[2,2,2], (f32[2,2,2], f32[2,2,2])) parameter(0) + gte0 = f32[2,2,2] get-tuple-element(fparam), index=0 + gte1 = (f32[2,2,2], f32[2,2,2]) get-tuple-element(fparam), index=1 + gte1a = f32[2,2,2] get-tuple-element(gte1), index=0 + gte1b = f32[2,2,2] get-tuple-element(gte1), index=1 + add = f32[2,2,2] add(gte1a, gte1b) + ROOT fresult = f32[2,2,2] add(gte0, add) + } + + ENTRY entry_computation { + param = (f32[2,2,2], (f32[2,2,2], f32[2,2,2])) parameter(0) + ROOT fusion = + f32[2,2,2] fusion(param), kind=kLoop, calls=fused_computation + } + )"; + + auto module = tools::Parse(module_str).ValueOrDie(); + ComputationLayout computation_layout( + module->entry_computation()->ComputeProgramShape()); + Shape param_shape = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {0, 1, 2}), + ShapeUtil::MakeTupleShape({ + ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {1, 2, 0}), + ShapeUtil::MakeShapeWithLayout(F32, {2, 2, 2}, {2, 0, 1}), + })}); + TF_ASSERT_OK( + 
computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape( + param_shape)); + computation_layout.mutable_result_layout()->ResetLayout( + LayoutUtil::MakeLayout({2, 1, 0})); + AssignLayouts(module.get(), &computation_layout); + + HloComputation* fused_computation = *std::find_if( + module->computations().begin(), module->computations().end(), + [](const HloComputation* c) { return c->name() == "fused_computation"; }); + + auto fused_instr = [&](const string& name) { + auto it = std::find_if( + fused_computation->instructions().begin(), + fused_computation->instructions().end(), + [&](const HloInstruction* i) { return i->name() == name; }); + CHECK(it != fused_computation->instructions().end()); + return *it; + }; + + EXPECT_THAT(fused_instr("gte0")->shape().layout().minor_to_major(), + ElementsAre(0, 1, 2)); + EXPECT_THAT( + fused_instr("gte1")->shape().tuple_shapes(0).layout().minor_to_major(), + ElementsAre(1, 2, 0)); + EXPECT_THAT( + fused_instr("gte1")->shape().tuple_shapes(1).layout().minor_to_major(), + ElementsAre(2, 0, 1)); + EXPECT_THAT(fused_instr("gte1a")->shape().layout().minor_to_major(), + ElementsAre(1, 2, 0)); + EXPECT_THAT(fused_instr("gte1b")->shape().layout().minor_to_major(), + ElementsAre(2, 0, 1)); + EXPECT_THAT(fused_instr("fresult")->shape().layout().minor_to_major(), + ElementsAre(2, 1, 0)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/liveness_util_test.cc b/tensorflow/compiler/xla/service/liveness_util_test.cc index 476e86fa72ad691cda52097c953ba15132f206a7..2c2a02f6375343d67dfb155bbb03729ff6e490d2 100644 --- a/tensorflow/compiler/xla/service/liveness_util_test.cc +++ b/tensorflow/compiler/xla/service/liveness_util_test.cc @@ -277,8 +277,11 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) { auto b = builder.AddInstruction(HloInstruction::CreateConstant( Literal::CreateR2({{2.0, 2.0}, {2.0, 2.0}}))); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + 
dot_dnums.add_rhs_contracting_dimensions(0); auto dot = builder.AddInstruction( - HloInstruction::CreateBinary(data_shape, HloOpcode::kDot, a, b)); + HloInstruction::CreateDot(data_shape, a, b, dot_dnums)); auto one = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(1.0))); @@ -312,8 +315,11 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedTransposeDotAdd) { auto b_t = builder.AddInstruction( HloInstruction::CreateTranspose(data_shape, b, {1, 0})); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto dot = builder.AddInstruction( - HloInstruction::CreateBinary(data_shape, HloOpcode::kDot, a, b_t)); + HloInstruction::CreateDot(data_shape, a, b_t, dot_dnums)); auto one = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(1.0))); diff --git a/tensorflow/compiler/xla/service/llvm_compiler.cc b/tensorflow/compiler/xla/service/llvm_compiler.cc index 34f3419269abbc73cd0ddb13c723a8da38ab19ff..f98fc0400a7d827a29dcddc5eecf9a4a01e76590 100644 --- a/tensorflow/compiler/xla/service/llvm_compiler.cc +++ b/tensorflow/compiler/xla/service/llvm_compiler.cc @@ -18,8 +18,8 @@ limitations under the License. 
namespace xla { StatusOr>> LLVMCompiler::Compile( std::vector> modules, - std::vector> - stream_execs) { + std::vector> stream_execs, + DeviceMemoryAllocator* device_allocator) { std::vector> result; for (size_t i = 0; i < modules.size(); i++) { if (stream_execs[i].size() != 1) { @@ -27,10 +27,12 @@ StatusOr>> LLVMCompiler::Compile( "Model partitioning not implemented for the CPU/GPU compilers!"); } - TF_ASSIGN_OR_RETURN( - modules[i], RunHloPasses(std::move(modules[i]), stream_execs[i][0])); + TF_ASSIGN_OR_RETURN(modules[i], + RunHloPasses(std::move(modules[i]), stream_execs[i][0], + device_allocator)); TF_ASSIGN_OR_RETURN(std::unique_ptr executable, - RunBackend(std::move(modules[i]), stream_execs[i][0])); + RunBackend(std::move(modules[i]), stream_execs[i][0], + device_allocator)); result.push_back(std::move(executable)); } diff --git a/tensorflow/compiler/xla/service/llvm_compiler.h b/tensorflow/compiler/xla/service/llvm_compiler.h index c5393cef4f961c5d04c32d0d4291732b8ec702f1..d74e81bb7f622ac5e89203a3d02ca5ad839da07e 100644 --- a/tensorflow/compiler/xla/service/llvm_compiler.h +++ b/tensorflow/compiler/xla/service/llvm_compiler.h @@ -60,17 +60,20 @@ class LLVMCompiler : public Compiler { // Bring in // StatusOr> RunBackend( // std::unique_ptr module, - // perftools::gputools::StreamExecutor* stream_exec) + // perftools::gputools::StreamExecutor* stream_exec, + // DeviceMemoryAllocator* device_allocator) // StatusOr> RunHloPasses( // std::unique_ptr module, - // perftools::gputools::StreamExecutor* stream_exec) + // perftools::gputools::StreamExecutor* stream_exec, + // DeviceMemoryAllocator* device_allocator) using Compiler::RunBackend; using Compiler::RunHloPasses; StatusOr>> Compile( std::vector> modules, std::vector> - stream_execs) override; + stream_execs, + DeviceMemoryAllocator* device_allocator) override; protected: ModuleHook user_pre_optimization_hook_; diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD 
b/tensorflow/compiler/xla/service/llvm_ir/BUILD index d878061f724de1c82f8285b0f082d0be4d5778df..37261ed1e665ebed9685751161a412ad114a9e96 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -48,11 +48,13 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_module_config", "//tensorflow/compiler/xla/service:name_uniquer", "//tensorflow/core:lib", "@llvm//:core", "@llvm//:support", "@llvm//:target", + "@llvm//:transform_utils", ], ) @@ -156,18 +158,6 @@ cc_library( ], ) -cc_library( - name = "vector_support_library", - srcs = ["vector_support_library.cc"], - hdrs = ["vector_support_library.h"], - deps = [ - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", - "@llvm//:core", - ], -) - cc_library( name = "kernel_support_library", srcs = ["kernel_support_library.cc"], diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h index 9ad7cd82cb8ca862fd7acec3dfb12c9fd61f6e27..b3b6026ef17daa184c0a015fdea618597ef068b3 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h @@ -32,8 +32,23 @@ limitations under the License. namespace xla { -// Unlike IrEmitter, this creates host functions which emit IR to generate the -// output element at the given index. It is used to generate fused operations. +// FusedIrEmitter is used to generate code for fusion nodes. +// +// Unlike IrEmitter and its ilk, which directly create LLVM IR in an LLVM +// Module, FusedIrEmitter is better understood as "IR generator generator". +// FusedIrEmitter recursively creates a generator (a host function) which the +// compiler can invoke at a later time. 
Invoking the generator emits LLVM IR +// that, when run, produces the value at a particular index of the output. +// +// After building this generator, the compiler creates a loop (or its moral +// equivalent, e.g. a GPU kernel) and calls the generator from within the loop. +// This generates code that produces each element of the output. +// +// This class handles both vanilla fusion and multi-output fusion. In the MOF +// case, the fusion node ends with a kTuple instruction, and the generator +// created produces an LLVM struct with N elements, one for each element of the +// arrays in the tuple. It follows that the arrays in the tuple must have the +// same length. class FusedIrEmitter : public DfsHloVisitorWithDefault { public: using Generator = llvm_ir::ElementGenerator; diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index 7224bd689842d89563b374f3db3d4e314be18764..6384c7f46f5ebbedaeda232b40095611a5d738a4 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -39,13 +39,27 @@ IrArray::Index::Index(llvm::Value* linear, const Shape& shape, << "Shape " << ShapeUtil::HumanStringWithLayout(shape) << " should have a layout."; int64 divisor = 1; - for (int64 dimension : layout_.minor_to_major()) { + for (int64 i = 0; i < layout_.minor_to_major_size(); ++i) { + int64 dimension = layout_.minor_to_major(i); int64 size_of_current_dimension = shape.dimensions(dimension); - // Emit IR instructions that compute - // (linear_index / divisor) % current_dimension - multidim_[dimension] = ir_builder->CreateURem( - ir_builder->CreateUDiv(linear, ir_builder->getInt64(divisor)), - ir_builder->getInt64(size_of_current_dimension)); + + // If i is not the last dimension, compute + // (linear_index / divisor) % current_dimension. + // If i is the last dimension, we can skip the mod, because we assume that + // linear is in bounds. 
+ // + // TODO(jlebar): We could add bounds checks here and elsewhere in this file, + // guarded under some sort of xla-memcheck flag. This might be particularly + // useful because cuda-memcheck can't help us much in XLA: Most of our + // memory lives in one big allocation, so cuda-memcheck can't detect + // out-of-bounds accesses. + auto* quot = ir_builder->CreateUDiv(linear, ir_builder->getInt64(divisor)); + if (i < layout_.minor_to_major_size() - 1) { + multidim_[dimension] = ir_builder->CreateURem( + quot, ir_builder->getInt64(size_of_current_dimension)); + } else { + multidim_[dimension] = quot; + } divisor *= size_of_current_dimension; } } @@ -244,8 +258,8 @@ llvm::Value* IrArray::EmitArrayElementAddress( // // getelementptr base_ptr_, 0, most major index, ..., most minor index std::vector gep_indices(1, ir_builder->getInt64(0)); - for (int64 i = shape_->layout().minor_to_major_size() - 1; i >= 0; --i) { - int64 dimension = shape_->layout().minor_to_major(i); + for (int64 i = 0; i < LayoutUtil::MinorToMajor(*shape_).size(); ++i) { + int64 dimension = LayoutUtil::Major(shape_->layout(), i); gep_indices.push_back(actual_index[dimension]); } return ir_builder->CreateInBoundsGEP(base_ptr_, gep_indices, diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc index 29cc0f81bd2c06538e28d1b593ee6a897fea0f27..23d2d4e87d26f4988ebddcf20f5a27af6a7fe0d6 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" namespace xla { void KernelSupportLibrary::For( @@ -62,4 +63,72 @@ void KernelSupportLibrary::If( false_block_generator(); llvm_ir::SetToLastInsertPoint(if_data.after_block, ir_builder_); } + +void KernelSupportLibrary::EmitAndCallOutlinedKernel( + bool enable_fast_math, bool optimize_for_size, + llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name, + KernelSupportLibrary::ArgumentVector arguments, + const std::function& + kernel_body_generator) { + llvm::Module* module = ir_builder->GetInsertBlock()->getModule(); + llvm::Function* function = + module->getFunction(llvm_ir::AsStringRef(kernel_name)); + + int64 null_arg_idx = -1; + std::vector sanitized_args; + sanitized_args.reserve(arguments.size()); + for (int64 i = 0, e = arguments.size(); i < e; i++) { + if (arguments[i]) { + sanitized_args.push_back(arguments[i]); + } else { + CHECK_EQ(null_arg_idx, -1); + null_arg_idx = i; + } + } + + if (!function) { + VLOG(2) << "Generating kernel for " << kernel_name; + std::vector arg_types; + std::transform(sanitized_args.begin(), sanitized_args.end(), + std::back_inserter(arg_types), + [](llvm::Value* arg) { return arg->getType(); }); + + auto* function_type = llvm::FunctionType::get( + ir_builder->getVoidTy(), arg_types, /*isVarArg=*/false); + + function = llvm_ir::CreateFunction( + function_type, llvm::GlobalValue::InternalLinkage, + /*enable_fast_math=*/enable_fast_math, + /*optimize_for_size=*/optimize_for_size, kernel_name, module); + + llvm::IRBuilder<>::InsertPointGuard guard(*ir_builder); + + auto* entry_bb = + llvm::BasicBlock::Create(ir_builder->getContext(), "entry", function); + auto* return_inst = llvm::ReturnInst::Create(ir_builder->getContext(), + /*retVal=*/nullptr, entry_bb); + // Set the insert point to before return_inst. 
+ ir_builder->SetInsertPoint(return_inst); + + std::vector arg_values; + /* + * clang on OSX doesn't like std::transform or range for loop here. + * See https://github.com/tensorflow/tensorflow/issues/15196 + */ + for (llvm::Function::arg_iterator arg = function->arg_begin(), + arg_e = function->arg_end(); + arg != arg_e; ++arg) { + arg_values.push_back(arg); + } + if (null_arg_idx != -1) { + arg_values.insert(arg_values.begin() + null_arg_idx, nullptr); + } + kernel_body_generator(arg_values); + } else { + VLOG(3) << "Re-using kernel for " << kernel_name; + } + + ir_builder->CreateCall(function, llvm_ir::AsArrayRef(sanitized_args)); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h index 9bafb7b57740b7acd0286c113c8a0585c0f93689..1c00b2aabd182da72e78d2c9c01cbe70cfd8e33c 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h +++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_ #include @@ -118,6 +118,60 @@ class KernelSupportLibrary { const std::function& true_block_generator, const std::function& false_block_generator = []() {}); + using ArgumentVector = tensorflow::gtl::ArraySlice; + + // Generates the following control flow structure: + // + // define @`kernel_name`(arg0, arg1, ... arg`arguments.size()`) { + // kernel_body_generator({arg0, arg1, ... arg`arguments.size()`}); + // } + // + // ... 
+ // call @`kernel_name`(arguments[0], arguments[1] ...) + // ... + // + // If a function called `kernel_name` is already present in the module then + // that function is re-used. In that sense we're using the llvm::Module as a + // cache of outlined kernels, keyed by function name. + // + // If any of the values in `arguments` is nullptr (i.e. a nullptr + // llvm::Value*) then we ignore it when generating LLVM IR, and instead pass + // in a nullptr llvm::Value* in its position to `kernel_body_generator`. + // Currently we only support at most one nullptr value in `arguments`. + static void EmitAndCallOutlinedKernel( + bool enable_fast_math, bool optimize_for_size, + llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name, + ArgumentVector arguments, + const std::function& kernel_body_generator); + + // Thin wrappers around the more general EmitAndCallOutlinedKernel above. + static void EmitAndCallOutlinedKernel( + bool enable_fast_math, bool optimize_for_size, + llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name, + llvm::Value* arg0, llvm::Value* arg1, llvm::Value* arg2, + const std::function& + kernel_body_generator) { + EmitAndCallOutlinedKernel( + enable_fast_math, optimize_for_size, ir_builder, kernel_name, + {arg0, arg1, arg2}, [&](ArgumentVector args) { + kernel_body_generator(args[0], args[1], args[2]); + }); + } + + static void EmitAndCallOutlinedKernel( + bool enable_fast_math, bool optimize_for_size, + llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name, + llvm::Value* arg0, llvm::Value* arg1, llvm::Value* arg2, + llvm::Value* arg3, + const std::function& kernel_body_generator) { + EmitAndCallOutlinedKernel( + enable_fast_math, optimize_for_size, ir_builder, kernel_name, + {arg0, arg1, arg2, arg3}, [&](ArgumentVector args) { + kernel_body_generator(args[0], args[1], args[2], args[3]); + }); + } + private: llvm::IRBuilder<>* ir_builder_; bool prevent_unrolling_; @@ -125,4 +179,4 @@ class KernelSupportLibrary { }; 
} // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_KERNEL_SUPPORT_LIBRARY_H_ diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index cd0c4a371e2b1cd0e1c52b77e47e8b081ab8e836..22141e7e00756483957f9cd4bc065a64556e854c 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -20,9 +20,11 @@ limitations under the License. #include #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/GlobalValue.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Operator.h" #include "llvm/Target/TargetOptions.h" +#include "llvm/Transforms/Utils/Cloning.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" @@ -61,6 +63,16 @@ llvm::StringRef AsStringRef(tensorflow::StringPiece str) { return llvm::StringRef(str.data(), str.size()); } +std::unique_ptr DropConstantInitializers( + const llvm::Module& module) { + std::unique_ptr cloned_module = CloneModule(&module); + for (llvm::GlobalVariable& global_var : cloned_module->globals()) { + global_var.setInitializer(nullptr); + global_var.setLinkage(llvm::GlobalValue::LinkageTypes::ExternalLinkage); + } + return cloned_module; +} + string DumpModuleToString(const llvm::Module& module) { std::string buffer_string; llvm::raw_string_ostream ostream(buffer_string); @@ -142,7 +154,16 @@ llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, return llvm::Type::getInt8Ty(module->getContext()); case S16: case U16: + case BF16: + // For BF16 we just need some type that is 16 bits wide so that it will + // take up the right amount of space in memory. LLVM does not have a BF16 + // type (the LLVM half type is IEEE 16 bit floating point, not bfloat), so + // we can't map it directly to an LLVM type. 
We will not map a BF16 + // addition to an addition on this type (int16) - this is just the type + // used for storage. return llvm::Type::getInt16Ty(module->getContext()); + case F16: + return llvm::Type::getHalfTy(module->getContext()); case S32: case U32: return llvm::Type::getInt32Ty(module->getContext()); @@ -200,8 +221,8 @@ llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module) { if (ShapeUtil::IsTuple(shape)) { // A tuple buffer is an array of pointers. result_type = llvm::ArrayType::get(result_type, shape.tuple_shapes_size()); - } else { - for (int64 dimension : shape.layout().minor_to_major()) { + } else if (ShapeUtil::IsArray(shape)) { + for (int64 dimension : LayoutUtil::MinorToMajor(shape)) { result_type = llvm::ArrayType::get(result_type, shape.dimensions(dimension)); } @@ -280,6 +301,16 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index, value = llvm::ConstantFP::get(ir_element_type, literal.Get(*multi_index)); break; + case BF16: + value = llvm::ConstantInt::get( + ir_element_type, + tensorflow::bit_cast(literal.Get(*multi_index))); + break; + case F16: + value = llvm::ConstantFP::get( + ir_element_type, + static_cast(literal.Get(*multi_index))); + break; case F64: value = llvm::ConstantFP::get(ir_element_type, literal.Get(*multi_index)); @@ -304,7 +335,7 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index, // decrements with each recursive call. We want to iterate through the // dimensions in major-to-minor order as we recurse so just index into // minor_to_major to get the dimension number for this level of the recursion. - int64 dimension = shape.layout().minor_to_major(dimension_index); + int64 dimension = LayoutUtil::Minor(shape.layout(), dimension_index); // Recursively call LiteralToConstant to construct subarrays for the // more-minor dimensions. 
Gather the subarrays into a vector for bundling into @@ -320,7 +351,7 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index, if (elements.empty()) { element_type = ir_element_type; for (int i = 0; i < dimension_index; ++i) { - int64 index = shape.layout().minor_to_major(i); + int64 index = LayoutUtil::Minor(shape.layout(), i); element_type = llvm::ArrayType::get(element_type, shape.dimensions(index)); } @@ -653,6 +684,19 @@ static string GetProcessUniqueIrFileName(tensorflow::StringPiece prefix) { return uniquer->GetUniqueName(prefix); } +static Status CreateAndWriteStringToFile(const string& directory_name, + const string& file_name, + const string& text) { + std::unique_ptr f; + TF_RETURN_IF_ERROR( + tensorflow::Env::Default()->RecursivelyCreateDir(directory_name)); + TF_RETURN_IF_ERROR( + tensorflow::Env::Default()->NewWritableFile(file_name, &f)); + TF_RETURN_IF_ERROR(f->Append(text)); + TF_RETURN_IF_ERROR(f->Close()); + return Status::OK(); +} + Status DumpIRToDirectory(const string& directory_name, const string& hlo_module_name, const llvm::Module& llvm_module, bool optimized) { @@ -667,13 +711,70 @@ Status DumpIRToDirectory(const string& directory_name, directory_name, tensorflow::strings::StrCat(unique_and_safe_file_name, ".ll")); - std::unique_ptr f; - TF_RETURN_IF_ERROR( - tensorflow::Env::Default()->RecursivelyCreateDir(directory_name)); - TF_RETURN_IF_ERROR( - tensorflow::Env::Default()->NewWritableFile(ir_file_name, &f)); - TF_RETURN_IF_ERROR(f->Append(DumpModuleToString(llvm_module))); - return f->Close(); + // For some models the embedded constants can be huge, so also dump the module + // with the constants stripped to get IR that is easier to manipulate. 
+ string ir_no_constant_initializers_file_name = tensorflow::io::JoinPath( + directory_name, + tensorflow::strings::StrCat(unique_and_safe_file_name, "-noconst.ll")); + + TF_RETURN_IF_ERROR(CreateAndWriteStringToFile( + directory_name, ir_file_name, DumpModuleToString(llvm_module))); + return CreateAndWriteStringToFile( + directory_name, ir_no_constant_initializers_file_name, + DumpModuleToString(*DropConstantInitializers(llvm_module))); +} + +llvm::Function* CreateFunction(llvm::FunctionType* function_type, + llvm::GlobalValue::LinkageTypes linkage, + bool enable_fast_math, bool optimize_for_size, + tensorflow::StringPiece name, + llvm::Module* module) { + llvm::Function* function = + llvm::Function::Create(function_type, linkage, AsStringRef(name), module); + function->setCallingConv(llvm::CallingConv::C); + function->addFnAttr("no-frame-pointer-elim", "false"); + + if (enable_fast_math) { + function->addFnAttr("unsafe-fp-math", "true"); + function->addFnAttr("no-infs-fp-math", "true"); + function->addFnAttr("no-nans-fp-math", "true"); + function->addFnAttr("no-signed-zeros-fp-math", "true"); + } + + // Add the optsize attribute to the function if optimizing for size. This + // controls internal behavior of some optimization passes (e.g. loop + // unrolling). + if (optimize_for_size) { + function->addFnAttr(llvm::Attribute::OptimizeForSize); + } + + return function; +} + +void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) { + auto options = config.debug_options().xla_backend_extra_options(); + if (!options.empty()) { + std::vector fake_argv_storage; + fake_argv_storage.push_back(""); + for (const auto& it : options) { + // Skip options the XLA backend itself consumes. 
+ if (!tensorflow::StringPiece(it.first).starts_with("xla_")) { + if (it.second.empty()) { + fake_argv_storage.push_back(it.first); + } else { + fake_argv_storage.push_back(it.first + "=" + it.second); + } + } + } + + VLOG(2) << "Passing argv to LLVM:"; + std::vector fake_argv; + for (const auto& s : fake_argv_storage) { + fake_argv.push_back(s.c_str()); + VLOG(2) << s; + } + llvm::cl::ParseCommandLineOptions(fake_argv.size(), &fake_argv[0]); + } } } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h index 063ead2b647d8fc5cc4f67004aaded80a2191fe9..4a10ec466dae6fdb56546fb8d8b353dcff6a5b8d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h @@ -29,6 +29,7 @@ limitations under the License. #include "llvm/Support/raw_ostream.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module_config.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -281,6 +282,16 @@ Status DumpIRToDirectory(const string& directory_name, const string& hlo_module_name, const llvm::Module& llvm_module, bool optimized); +llvm::Function* CreateFunction(llvm::FunctionType* function_type, + llvm::GlobalValue::LinkageTypes linkage, + bool enable_fast_math, bool optimize_for_size, + tensorflow::StringPiece name, + llvm::Module* module); + +// Extracts the xla_backend_extra_options from `config` and passes those that +// don't start with xla_ to LLVM. 
+void InitializeLLVMCommandLineOptions(const HloModuleConfig& config); + } // namespace llvm_ir } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc index 6fa4cd08c9e0ac30b83c0e2b49d98d930c2e15df..b6b918ec78a27b90325f72eea14b97f9aee43c54 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc @@ -51,37 +51,40 @@ LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, shape_(target_array.GetShape()), ir_builder_(ir_builder) {} +static LoopEmitter::BodyEmitter MakeBodyEmitterForMultiOutputFusion( + const ElementGenerator& target_element_generator, + const std::vector& target_arrays, llvm::IRBuilder<>* ir_builder) { + return [=](const llvm_ir::IrArray::Index array_index) { + TF_ASSIGN_OR_RETURN(llvm::Value * target_element, + target_element_generator(array_index)); + CHECK(target_element->getType()->isStructTy()) + << "This BodyEmitter is for multi-output fusion, but target element " + "generator does not produce values of struct type."; + CHECK_EQ(target_element->getType()->getStructNumElements(), + target_arrays.size()); + + for (int64 i = 0; i < target_arrays.size(); ++i) { + target_arrays[i].EmitWriteArrayElement( + array_index, ir_builder->CreateExtractValue(target_element, i), + ir_builder); + } + return Status::OK(); + }; +} + LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator, tensorflow::gtl::ArraySlice target_arrays, llvm::IRBuilder<>* ir_builder) - : body_emitter_([=](const llvm_ir::IrArray::Index array_index) - -> ::tensorflow::Status { - // Convert target_element_generator to a BodyEmitter. 
- TF_ASSIGN_OR_RETURN(llvm::Value * target_element, - target_element_generator(array_index)); - if (target_arrays.size() == 1) { - target_arrays[0].EmitWriteArrayElement(array_index, target_element, - ir_builder); - return tensorflow::Status::OK(); - } - - for (int64 i = 0; i < target_arrays.size(); ++i) { - target_arrays[i].EmitWriteArrayElement( - array_index, ir_builder_->CreateExtractValue(target_element, i), - ir_builder); - } - return tensorflow::Status::OK(); - }), + : body_emitter_(MakeBodyEmitterForMultiOutputFusion( + target_element_generator, + std::vector(target_arrays.begin(), target_arrays.end()), + ir_builder)), + shape_(target_arrays[0].GetShape()), ir_builder_(ir_builder) { - if (target_arrays.size() > 1) { - // The sanity check for multiple outputs. - shape_ = target_arrays[0].GetShape(); - for (int64 i = 1; i < target_arrays.size(); ++i) { - const Shape& element_shape = target_arrays[i].GetShape(); - CHECK(ShapeUtil::SameDimensions(shape_, element_shape)); - } - } else { - shape_ = target_arrays[0].GetShape(); + // Sanity check: In multi-output fusion, all shapes produced must have the + // same dimensions. + for (const IrArray& array : target_arrays) { + CHECK(ShapeUtil::SameDimensions(shape_, array.GetShape())); } } @@ -99,8 +102,8 @@ IrArray::Index LoopEmitter::EmitIndexAndSetExitBasicBlock( // dimension (of the target shape). 
ForLoopNest loop_nest(loop_name, ir_builder_); IrArray::Index array_index(shape_.dimensions_size()); - for (int i = shape_.layout().minor_to_major_size() - 1; i >= 0; --i) { - int64 dimension = shape_.layout().minor_to_major(i); + for (int i = 0; i < LayoutUtil::MinorToMajor(shape_).size(); ++i) { + int64 dimension = LayoutUtil::Major(shape_.layout(), i); std::unique_ptr loop = loop_nest.AddLoop( /*start_index=*/0, /*end_index=*/shape_.dimensions(dimension), diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h index 1ef1dc246442041698d96f6aff48794c8788f1d1..0fc528439a0d5bf8382dfcf2d8b3051f8900bf1d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h @@ -47,10 +47,16 @@ class LoopEmitter { // element of the given target array. LoopEmitter(const ElementGenerator& target_element_generator, const IrArray& target_array, llvm::IRBuilder<>* ir_builder); - // Same as previous method except emits multiple targets in an array. + + // Constructs a LoopEmitter that emits one element into each of N separate + // arrays on each iteration of the loop. + // + // This is used for multi-output fusion. target_element_generator must + // produce an LLVM struct with N elements. 
LoopEmitter(const ElementGenerator& target_element_generator, tensorflow::gtl::ArraySlice target_arrays, llvm::IRBuilder<>* ir_builder); + LoopEmitter(const LoopEmitter&) = delete; LoopEmitter& operator=(const LoopEmitter&) = delete; virtual ~LoopEmitter() = default; diff --git a/tensorflow/compiler/xla/service/llvm_ir/ops.h b/tensorflow/compiler/xla/service/llvm_ir/ops.h index f72f482e3128c61e53cc454e7da8b5795ba6f695..175b081e84d31779b15560cb0998011fe046ca01 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ops.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ops.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_ #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" @@ -90,4 +90,4 @@ Status EmitParallelFusedDynamicUpdateSliceInPlace( } // namespace llvm_ir } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_ diff --git a/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.cc deleted file mode 100644 index e8c6a83618eaa8430521197f1c166cb7eb11a28e..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.cc +++ /dev/null @@ -1,150 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h" - -#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" - -namespace xla { -VectorSupportLibrary::VectorSupportLibrary(PrimitiveType primitive_type, - int64 vector_size, - llvm::IRBuilder<>* ir_builder, - std::string name) - : vector_size_(vector_size), - primitive_type_(primitive_type), - ir_builder_(ir_builder), - name_(std::move(name)) { - scalar_type_ = llvm_ir::PrimitiveTypeToIrType( - primitive_type, ir_builder_->GetInsertBlock()->getModule()); - scalar_pointer_type_ = llvm::PointerType::getUnqual(scalar_type_); - vector_type_ = llvm::VectorType::get(scalar_type_, vector_size); - vector_pointer_type_ = llvm::PointerType::getUnqual(vector_type_); -} - -llvm::Value* VectorSupportLibrary::Mul(llvm::Value* lhs, llvm::Value* rhs) { - if (scalar_type_->isFloatingPointTy()) { - return ir_builder()->CreateFMul(lhs, rhs, name()); - } else { - return ir_builder()->CreateMul(lhs, rhs, name()); - } -} - -llvm::Value* VectorSupportLibrary::Add(llvm::Value* lhs, llvm::Value* rhs) { - if (scalar_type_->isFloatingPointTy()) { - return ir_builder()->CreateFAdd(lhs, rhs, name()); - } else { - return ir_builder()->CreateAdd(lhs, rhs, name()); - } -} - -llvm::Value* VectorSupportLibrary::ComputeOffsetPointer( - llvm::Value* base_pointer, llvm::Value* offset_elements) { - if (base_pointer->getType() != scalar_pointer_type()) { - base_pointer = ir_builder()->CreateBitCast(base_pointer, - scalar_pointer_type(), name()); 
- } - return ir_builder()->CreateInBoundsGEP(base_pointer, {offset_elements}, - name()); -} - -llvm::Value* VectorSupportLibrary::LoadVector(llvm::Value* pointer) { - if (pointer->getType() != vector_pointer_type()) { - pointer = - ir_builder()->CreateBitCast(pointer, vector_pointer_type(), name()); - } - return ir_builder()->CreateAlignedLoad( - pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_), name()); -} - -llvm::Value* VectorSupportLibrary::LoadScalar(llvm::Value* pointer) { - if (pointer->getType() != scalar_pointer_type()) { - pointer = - ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name()); - } - return ir_builder()->CreateAlignedLoad( - pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_), name()); -} - -void VectorSupportLibrary::StoreVector(llvm::Value* value, - llvm::Value* pointer) { - if (pointer->getType() != vector_pointer_type()) { - pointer = ir_builder()->CreateBitCast(pointer, vector_pointer_type()); - } - ir_builder()->CreateAlignedStore( - value, pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_)); -} - -void VectorSupportLibrary::StoreScalar(llvm::Value* value, - llvm::Value* pointer) { - if (pointer->getType() != scalar_pointer_type()) { - pointer = - ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name()); - } - ir_builder()->CreateAlignedStore( - value, pointer, ShapeUtil::ByteSizeOfPrimitiveType(primitive_type_)); -} - -llvm::Value* VectorSupportLibrary::LoadBroadcast(llvm::Value* pointer) { - if (pointer->getType() != scalar_pointer_type()) { - pointer = - ir_builder()->CreateBitCast(pointer, scalar_pointer_type(), name()); - } - return ir_builder()->CreateVectorSplat( - vector_size(), ir_builder()->CreateLoad(pointer), name()); -} - -llvm::Value* VectorSupportLibrary::AddReduce(llvm::Value* vector) { - llvm::SmallVector mask(vector_size(), nullptr); - for (unsigned i = vector_size(); i != 1; i >>= 1) { - // On every iteration, we shuffle half of the remaining lanes to the top - // 
half of shuffle, and add two old and the new vector. - - for (unsigned j = 0; j < vector_size(); ++j) { - if (j < (i / 2)) { - mask[j] = ir_builder()->getInt32(i / 2 + j); - } else { - mask[j] = llvm::UndefValue::get(ir_builder()->getInt32Ty()); - } - } - - llvm::Value* half_remaining_lanes = ir_builder()->CreateShuffleVector( - vector, llvm::UndefValue::get(vector_type()), - llvm::ConstantVector::get(mask), ""); - vector = Add(vector, half_remaining_lanes); - } - - return ir_builder()->CreateExtractElement(vector, ir_builder()->getInt32(0), - name()); -} - -llvm::Value* VectorSupportLibrary::GetZeroVector() { - return llvm::Constant::getNullValue(vector_type()); -} - -llvm::Value* VectorSupportLibrary::GetZeroScalar() { - return llvm::Constant::getNullValue(scalar_type()); -} - -LlvmVariable::LlvmVariable(llvm::Type* type, llvm::IRBuilder<>* ir_builder) - : ir_builder_(ir_builder) { - alloca_ = llvm_ir::EmitAllocaAtFunctionEntry(type, "", ir_builder_); -} - -llvm::Value* LlvmVariable::Get() { return ir_builder_->CreateLoad(alloca_); } - -void LlvmVariable::Set(llvm::Value* new_value) { - ir_builder_->CreateStore(new_value, alloca_); -} -} // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h deleted file mode 100644 index 3072677ab05aa91c736baaa0dc3023329d810a52..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/xla/service/llvm_ir/vector_support_library.h +++ /dev/null @@ -1,174 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_VECTOR_SUPPORT_LIBRARY_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_VECTOR_SUPPORT_LIBRARY_H_ - -#include - -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/Value.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" - -namespace xla { -// A thin wrapper around llvm_util.h to make code generating vector math flow -// more readable. -class VectorSupportLibrary { - public: - // This VectorSupportLibrary instance remembers `primitive_type` and - // `vector_size`, and these are implicitly used by the methods on this - // instance (i.e. LoadVector will load a vector of type <`vector_size` x - // `primitive_type`>). 
- VectorSupportLibrary(PrimitiveType primitive_type, int64 vector_size, - llvm::IRBuilder<>* ir_builder, std::string name); - - llvm::Value* Mul(llvm::Value* lhs, llvm::Value* rhs); - llvm::Value* Mul(int64 lhs, llvm::Value* rhs) { - return Mul(ir_builder()->getInt64(lhs), rhs); - } - - llvm::Value* Add(llvm::Value* lhs, llvm::Value* rhs); - llvm::Value* Add(int64 lhs, llvm::Value* rhs) { - return Add(ir_builder()->getInt64(lhs), rhs); - } - - llvm::Value* MulAdd(llvm::Value* a, llvm::Value* b, llvm::Value* c) { - return Add(c, Mul(a, b)); - } - - llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer, - llvm::Value* offset_elements); - llvm::Value* ComputeOffsetPointer(llvm::Value* base_pointer, - int64 offset_elements) { - return ComputeOffsetPointer(base_pointer, - ir_builder()->getInt64(offset_elements)); - } - - llvm::Value* LoadVector(llvm::Value* pointer); - - llvm::Value* LoadVector(llvm::Value* base_pointer, - llvm::Value* offset_elements) { - return LoadVector(ComputeOffsetPointer(base_pointer, offset_elements)); - } - - llvm::Value* LoadVector(llvm::Value* base_pointer, int64 offset_elements) { - return LoadVector(base_pointer, ir_builder()->getInt64(offset_elements)); - } - - llvm::Value* LoadScalar(llvm::Value* pointer); - - llvm::Value* LoadScalar(llvm::Value* base_pointer, - llvm::Value* offset_elements) { - return LoadScalar(ComputeOffsetPointer(base_pointer, offset_elements)); - } - - llvm::Value* LoadScalar(llvm::Value* base_pointer, int64 offset_elements) { - return LoadScalar(base_pointer, ir_builder()->getInt64(offset_elements)); - } - - void StoreVector(llvm::Value* value, llvm::Value* pointer); - - void StoreVector(llvm::Value* value, llvm::Value* base_pointer, - llvm::Value* offset_elements) { - StoreVector(value, ComputeOffsetPointer(base_pointer, offset_elements)); - } - - void StoreVector(llvm::Value* value, llvm::Value* base_pointer, - int64 offset_elements) { - StoreVector(value, base_pointer, 
ir_builder()->getInt64(offset_elements)); - } - - void StoreScalar(llvm::Value* value, llvm::Value* pointer); - void StoreScalar(llvm::Value* value, llvm::Value* base_pointer, - llvm::Value* offset_elements) { - StoreScalar(value, ComputeOffsetPointer(base_pointer, offset_elements)); - } - - void StoreScalar(llvm::Value* value, llvm::Value* base_pointer, - int64 offset_elements) { - StoreScalar(base_pointer, ir_builder()->getInt64(offset_elements)); - } - - llvm::Value* LoadBroadcast(llvm::Value* pointer); - llvm::Value* LoadBroadcast(llvm::Value* base_pointer, - llvm::Value* offset_elements) { - return LoadBroadcast(ComputeOffsetPointer(base_pointer, offset_elements)); - } - llvm::Value* LoadBroadcast(llvm::Value* base_pointer, int64 offset_elements) { - return LoadBroadcast(base_pointer, ir_builder()->getInt64(offset_elements)); - } - - llvm::Value* AddReduce(llvm::Value* vector); - - llvm::Value* GetZeroVector(); - llvm::Value* GetZeroScalar(); - - llvm::IRBuilder<>* ir_builder() const { return ir_builder_; } - int64 vector_size() const { return vector_size_; } - llvm::Type* vector_type() const { return vector_type_; } - llvm::Type* vector_pointer_type() const { return vector_pointer_type_; } - llvm::Type* scalar_type() const { return scalar_type_; } - llvm::Type* scalar_pointer_type() const { return scalar_pointer_type_; } - - const std::string& name() const { return name_; } - - private: - int64 vector_size_; - PrimitiveType primitive_type_; - llvm::IRBuilder<>* ir_builder_; - llvm::Type* vector_type_; - llvm::Type* vector_pointer_type_; - llvm::Type* scalar_type_; - llvm::Type* scalar_pointer_type_; - std::string name_; -}; - -// This wraps an alloca-backed stack variable which LLVM's SSA construction pass -// can later convert to a SSA value. 
-class LlvmVariable { - public: - LlvmVariable(llvm::Type*, llvm::IRBuilder<>* ir_builder); - - llvm::Value* Get(); - void Set(llvm::Value* new_value); - - private: - llvm::AllocaInst* alloca_; - llvm::IRBuilder<>* ir_builder_; -}; - -class VectorVariable : public LlvmVariable { - public: - VectorVariable(VectorSupportLibrary* vector_support, - llvm::Value* initial_value) - : LlvmVariable(vector_support->vector_type(), - vector_support->ir_builder()) { - Set(initial_value); - } -}; - -class ScalarVariable : public LlvmVariable { - public: - ScalarVariable(VectorSupportLibrary* vector_support, - llvm::Value* initial_value) - : LlvmVariable(vector_support->scalar_type(), - vector_support->ir_builder()) { - Set(initial_value); - } -}; -} // namespace xla - -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_VECTOR_SUPPORT_LIBRARY_H_ diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 06f43bd3cb2376d34a3104133c868c4f4e5cc730..07f989d4faea199e812e54d2ae74d3ff9e7fa19a 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include #include +#include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/execution_options_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/backend.h" @@ -71,7 +72,7 @@ LocalService::LocalService(const ServiceOptions& options, StatusOr> LocalService::CompileExecutable( const ComputationHandle& computation, const tensorflow::gtl::ArraySlice argument_layouts, - const Shape* result_layout, int device_ordinal) { + const ExecutableBuildOptions& build_options) { TF_ASSIGN_OR_RETURN(UserComputation * user_computation, computation_tracker_.Resolve(computation)); VersionedComputationHandle versioned_handle = @@ -84,27 +85,47 @@ StatusOr> LocalService::CompileExecutable( // Validate incoming layouts. if (argument_layouts.size() != program_shape->parameters_size()) { return InvalidArgument( - "invalid number of arguments for computation: expected %d, got %zu", + "Invalid number of arguments for computation: expected %d, got %zu.", program_shape->parameters_size(), argument_layouts.size()); } for (int i = 0; i < argument_layouts.size(); ++i) { const Shape& argument_shape = *argument_layouts[i]; TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(argument_shape)); if (!ShapeUtil::Compatible(argument_shape, program_shape->parameters(i))) { + tensorflow::gtl::optional metadata = + user_computation->ParameterMetadata(i); + auto metadata_string = [&metadata]() -> string { + if (!metadata.has_value()) { + return ""; + } + CHECK(metadata.value() != nullptr); + const OpMetadata& m = *metadata.value(); + if (!m.source_file().empty()) { + return tensorflow::strings::Printf( + " (%s:%d)", m.source_file().c_str(), m.source_line()); + } + return ""; + }; return InvalidArgument( - "invalid argument shape for argument %d, expected %s, got %s", i, + "Invalid argument shape for argument %d%s, expected %s, got %s.", i, + metadata_string().c_str(), ShapeUtil::HumanString(program_shape->parameters(i)).c_str(), 
ShapeUtil::HumanString(argument_shape).c_str()); } } - if (result_layout != nullptr) { - TF_RETURN_IF_ERROR( - ValidateResultShapeWithLayout(*result_layout, program_shape->result())); + if (build_options.result_layout() != nullptr) { + TF_RETURN_IF_ERROR(ValidateResultShapeWithLayout( + *build_options.result_layout(), program_shape->result())); } ExecutionOptions execution_options = CreateDefaultExecutionOptions(); - if (result_layout != nullptr) { - *execution_options.mutable_shape_with_output_layout() = *result_layout; + if (build_options.generate_hlo_graph().has_value()) { + execution_options.mutable_debug_options()->set_xla_generate_hlo_graph( + build_options.generate_hlo_graph().value()); + } + if (build_options.result_layout() != nullptr) { + *execution_options.mutable_shape_with_output_layout() = + *build_options.result_layout(); } else { *execution_options.mutable_shape_with_output_layout() = program_shape->result(); @@ -113,15 +134,22 @@ StatusOr> LocalService::CompileExecutable( } TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, - CreateModuleConfig(*program_shape, argument_layouts, &execution_options)); + CreateModuleConfig(*program_shape, argument_layouts, &execution_options, + *user_computation)); - TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, - execute_backend_->stream_executor(device_ordinal)); + TF_ASSIGN_OR_RETURN( + se::StreamExecutor * executor, + execute_backend_->stream_executor(build_options.device_ordinal())); - std::vector argument_buffers( - argument_layouts.size()); return BuildExecutable(versioned_handle, std::move(module_config), - argument_buffers, execute_backend_.get(), executor); + execute_backend_.get(), executor, + build_options.device_allocator()); +} + +StatusOr LocalService::ReplicaNumberToDeviceOrdinal(int replica_number) { + return backend().computation_placer()->DeviceId( + replica_number, /*computation=*/0, options_.number_of_replicas(), + /*computation_count=*/1); } } // namespace xla diff --git 
a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h index 52c4346385eb663baa6e7579d7b3883ba084205b..15e120685e1be9190d49fdaf5ed6706bdf991a6c 100644 --- a/tensorflow/compiler/xla/service/local_service.h +++ b/tensorflow/compiler/xla/service/local_service.h @@ -18,6 +18,7 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/client/executable_build_options.h" #include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/compiler.h" #include "tensorflow/compiler/xla/service/device_memory_allocator.h" @@ -41,11 +42,20 @@ class LocalService : public Service { // Builds an Executable with the given argument layouts and options. If // result_layout is non-null, then the executable is compiled to produce a - // result of the given layout. + // result of the given layout. If device_allocator is non-null, then the + // compiler may use it to allocate temp space on the device. The compiler is + // responsible for freeing any memory it allocates this way. StatusOr> CompileExecutable( const ComputationHandle& computation, const tensorflow::gtl::ArraySlice argument_layouts, - const Shape* result_layout, int device_ordinal); + const ExecutableBuildOptions& options); + + // Returns the device ordinal that corresponds to the given replica number. + // + // This returns an error if there is not a one-to-one correspondence of + // replicas to device ordinals, but is useful as a short term mechanism for + // the "easy" case where a single replica is a single device. 
+ StatusOr ReplicaNumberToDeviceOrdinal(int replica_number); private: explicit LocalService(const ServiceOptions& options, diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.h b/tensorflow/compiler/xla/service/logical_buffer_analysis.h index 598d08b7203b25b194dfc3b3125ec58c96b2cd4c..f4c63dd86b4d8a6f598d46047012e4e5bc7b3d7e 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.h +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_ANALYSIS_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_ANALYSIS_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_ANALYSIS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_ANALYSIS_H_ #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -90,4 +90,4 @@ class LogicalBufferAnalysis : public DfsHloVisitorWithDefault { } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_ANALYSIS_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LOGICAL_BUFFER_ANALYSIS_H_ diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc index a0d08c288dbcc45e83a36ce7b094b04a9dbae532..7d8c05fffa4ab11d7dbf9956d2cb7ebd5bcdd3c4 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.cc +++ b/tensorflow/compiler/xla/service/name_uniquer.cc @@ -17,12 +17,44 @@ limitations under the License. 
#include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" namespace xla { +namespace { + +bool IsAllowed(char character) { + auto c = static_cast(character); + return (isalnum(c) != 0) || c == '_' || c == '.' || c == '-'; +} + +} // namespace + +NameUniquer::NameUniquer(const string& separator) { + CHECK(std::all_of(separator.begin(), separator.end(), IsAllowed)) + << "separator should comprises allowed characters only"; + separator_ = separator; +} + +/*static*/ string NameUniquer::GetSanitizedName(const string& name) { + string result = name; + CHECK(!result.empty()) << "name should not be empty"; + char c = static_cast(result[0]); + if (!isalpha(c) && c != '_') { + result[0] = '_'; + } + for (int i = 1; i < result.length(); i++) { + if (!IsAllowed(result[i])) { + result[i] = '_'; + } + } + return result; +} + string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) { string root = prefix.empty() ? "name" : prefix.ToString(); + root = GetSanitizedName(root); // Strip away numeric suffix (if any). Only recognize separator if it is in // the middle of the name. diff --git a/tensorflow/compiler/xla/service/name_uniquer.h b/tensorflow/compiler/xla/service/name_uniquer.h index ed379b52258463b960dea788721c2c4325ef0260..4139c2700b25e8600182a034a8ac6f4f041c12e6 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.h +++ b/tensorflow/compiler/xla/service/name_uniquer.h @@ -28,14 +28,21 @@ namespace xla { // Simple stateful class that helps generate "unique" names. To use it, simply // call GetUniqueName as many times as needed. The names returned by // GetUniqueName are guaranteed to be distinct for this instance of the class. +// Note that the names will be sanitized to match regexp +// "[a-zA-Z_][a-zA-Z0-9_.-]*". 
class NameUniquer { public: - explicit NameUniquer(const string& separator = "__") - : separator_(separator) {} + // The separator must contain allowed characters only: "[a-zA-Z0-9_.-]". + explicit NameUniquer(const string& separator = "__"); - // Get a unique name in a string, with an optional prefix for convenience. + // Get a sanitized unique name in a string, with an optional prefix for + // convenience. string GetUniqueName(tensorflow::StringPiece prefix = ""); + // Sanitizes and returns the name. Unallowed characters will be replaced with + // '_'. The result will match the regexp "[a-zA-Z_][a-zA-Z0-9_.-]*". + static string GetSanitizedName(const string& name); + private: // The string to use to separate the prefix of the name from the uniquing // integer value. diff --git a/tensorflow/compiler/xla/service/name_uniquer_test.cc b/tensorflow/compiler/xla/service/name_uniquer_test.cc index 9f0747a6e2175a968d8f3661ac51512009e86f29..4258cf16876ab46dce6df062ab701b1b1a4a7580 100644 --- a/tensorflow/compiler/xla/service/name_uniquer_test.cc +++ b/tensorflow/compiler/xla/service/name_uniquer_test.cc @@ -60,12 +60,30 @@ TEST_F(NameUniquerTest, NumericSuffixes) { EXPECT_EQ("bar", uniquer.GetUniqueName("bar.-1000")); EXPECT_EQ("bar.1", uniquer.GetUniqueName("bar.-2000")); EXPECT_EQ("bar.2", uniquer.GetUniqueName("bar.1")); +} + +TEST_F(NameUniquerTest, Sanitize) { + NameUniquer uniquer("_"); + + EXPECT_EQ("foo", uniquer.GetUniqueName("foo")); + EXPECT_EQ("foo_1", uniquer.GetUniqueName("foo")); + EXPECT_EQ("foo.54", uniquer.GetUniqueName("foo.54")); + EXPECT_EQ("foo_54", uniquer.GetUniqueName("foo_54")); + EXPECT_EQ("foo_54.1", uniquer.GetUniqueName("foo_54.1")); + EXPECT_EQ("foo_55", uniquer.GetUniqueName("foo")); + + // Invalid characters will be replaced with '_'. 
+ EXPECT_EQ("bar", uniquer.GetUniqueName("bar<-1000")); + EXPECT_EQ("bar_1", uniquer.GetUniqueName("bar<-2000")); + EXPECT_EQ("bar_2", uniquer.GetUniqueName("bar_1")); // Separator is only recognized in the middle of the prefix. - EXPECT_EQ(".10", uniquer.GetUniqueName(".10")); - EXPECT_EQ(".10.1", uniquer.GetUniqueName(".10")); - EXPECT_EQ("foobar.", uniquer.GetUniqueName("foobar.")); - EXPECT_EQ("foobar..1", uniquer.GetUniqueName("foobar.")); + EXPECT_EQ("_10", uniquer.GetUniqueName( + ".10")); // the leading '.' is replaced with '_'. + EXPECT_EQ("_10_1", uniquer.GetUniqueName(".10")); + EXPECT_EQ("_10_2", uniquer.GetUniqueName("_10")); + EXPECT_EQ("foobar_", uniquer.GetUniqueName("foobar_")); + EXPECT_EQ("foobar__1", uniquer.GetUniqueName("foobar_")); } } // namespace diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc index 63f3bfb36cedeb44b190e1e8a5584d334f94b585..aa974ee61a27de9c19e97d8a6eb48f9261ce4bd9 100644 --- a/tensorflow/compiler/xla/service/platform_util.cc +++ b/tensorflow/compiler/xla/service/platform_util.cc @@ -33,10 +33,32 @@ namespace se = ::perftools::gputools; namespace xla { +using tensorflow::str_util::Lowercase; + // Minimum supported CUDA compute capability is 3.5. constexpr int kMinCudaComputeCapabilityMajor = 3; constexpr int kMinCudaComputeCapabilityMinor = 5; +// The name of the interpreter platform. +constexpr char kInterpreter[] = "interpreter"; + +namespace { + +string CanonicalPlatformName(const string& name) { + string platform_str = Lowercase(name); + // "cpu" and "host" mean the same thing. + if (platform_str == "cpu") { + platform_str = "host"; + } + // "gpu" and "cuda" mean the same thing. 
+ if (platform_str == "gpu") { + platform_str = "cuda"; + } + return platform_str; +} + +} // namespace + /* static */ StatusOr> PlatformUtil::GetSupportedPlatforms() { se::MultiPlatformManager::PlatformMap platform_map; @@ -78,7 +100,7 @@ PlatformUtil::GetSupportedPlatforms() { return platforms; } -/* static */ StatusOr PlatformUtil::GetDefaultPlatform() { +/* static */ StatusOr PlatformUtil::GetSolePlatform() { TF_ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms()); if (platforms.empty()) { return NotFound("no platforms found"); @@ -87,26 +109,42 @@ PlatformUtil::GetSupportedPlatforms() { } // Multiple platforms present and we can't pick a reasonable default. - auto l = [](string* out, const se::Platform* p) { out->append(p->Name()); }; - string platforms_string = tensorflow::str_util::Join(platforms, ", ", l); + string platforms_string = tensorflow::str_util::Join( + platforms, ", ", + [](string* out, const se::Platform* p) { out->append(p->Name()); }); return InvalidArgument( "must specify platform because more than one platform found: %s", platforms_string.c_str()); } -/*static*/ StatusOr PlatformUtil::GetPlatform( - const string& platform_name) { - using tensorflow::str_util::Lowercase; - string platform_str = Lowercase(platform_name); - // "cpu" and "host" mean the same thing. - if (platform_str == "cpu") { - platform_str = "host"; - } - // "gpu" and "cuda" mean the same thing. 
- if (platform_str == "gpu") { - platform_str = "cuda"; +/* static */ StatusOr PlatformUtil::GetDefaultPlatform() { + TF_ASSIGN_OR_RETURN(auto platforms, GetSupportedPlatforms()); + if (platforms.empty()) { + return NotFound("no platforms found"); + } else if (platforms.size() == 1) { + return platforms[0]; + } else if (platforms.size() == 2) { + for (int i = 0; i < 2; i++) { + if (Lowercase(platforms[i]->Name()) == kInterpreter && + Lowercase(platforms[1 - i]->Name()) != kInterpreter) { + return platforms[1 - i]; + } + } } + // Multiple platforms present and we can't pick a reasonable default. + string platforms_string = tensorflow::str_util::Join( + platforms, ", ", + [](string* out, const se::Platform* p) { out->append(p->Name()); }); + return InvalidArgument( + "must specify platform because more than one platform (except for the " + "interpreter platform) found: %s", + platforms_string.c_str()); +} + +/*static*/ StatusOr PlatformUtil::GetPlatform( + const string& platform_name) { + string platform_str = CanonicalPlatformName(platform_name); TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms()); for (se::Platform* platform : platforms) { if (Lowercase(platform->Name()) == platform_str) { @@ -116,6 +154,32 @@ PlatformUtil::GetSupportedPlatforms() { return InvalidArgument("platform %s not found", platform_name.c_str()); } +/*static*/ StatusOr PlatformUtil::GetPlatformExceptFor( + const string& platform_name) { + string platform_str = CanonicalPlatformName(platform_name); + + TF_ASSIGN_OR_RETURN(auto platforms, PlatformUtil::GetSupportedPlatforms()); + std::vector matched; + for (se::Platform* platform : platforms) { + if (Lowercase(platform->Name()) != platform_name) { + matched.push_back(platform); + } + } + if (matched.empty()) { + return InvalidArgument("unable to find platform that is not %s", + platform_name.c_str()); + } + if (matched.size() == 1) { + return matched[0]; + } + string matched_string = tensorflow::str_util::Join( + 
matched, ", ", + [](string* out, const se::Platform* p) { out->append(p->Name()); }); + return InvalidArgument( + "found multiple platforms %s, but expected one platform except for %s", + matched_string.c_str(), platform_name.c_str()); +} + // Returns whether the device underlying the given StreamExecutor is supported // by XLA. static bool IsDeviceSupported(se::StreamExecutor* executor) { diff --git a/tensorflow/compiler/xla/service/platform_util.h b/tensorflow/compiler/xla/service/platform_util.h index a59d4ffe87f568ac786e4b2d3bf6983bc0d4695a..69188820a70707d9c9be10b20fb7de92ad4d9873 100644 --- a/tensorflow/compiler/xla/service/platform_util.h +++ b/tensorflow/compiler/xla/service/platform_util.h @@ -37,16 +37,28 @@ class PlatformUtil { static StatusOr> GetSupportedPlatforms(); - // Convenience function which returns the default supported platform. If + // Convenience function which returns the default supported platform for + // tests. If exactly one supported platform is present, then this platform is + // the default platform. If exactly two platforms are present and one of them + // is the interpreter platform, then the other platform is the default + // platform. Otherwise returns an error. + static StatusOr GetDefaultPlatform(); + + // Convenience function which returns the sole supported platform. If // exactly one supported platform is present, then this platform is the // default platform. Otherwise returns an error. - static StatusOr GetDefaultPlatform(); + static StatusOr GetSolePlatform(); // Returns the platform according to the given name. Returns error if there is // no such platform. static StatusOr GetPlatform( const string& platform_name); + // Returns exactly one platform that does not have given name. Returns error + // if there is no such platform, or there are multiple such platforms. + static StatusOr GetPlatformExceptFor( + const string& platform_name); + // Returns a vector of StreamExecutors for the given platform. 
The vector is // indexed by device ordinal (device numbering used by StreamExecutor). If an // element is nullptr, then the device is present by not supported by XLA. diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc index 0fb90230f2f39a841973361f63d17af579a1342b..e62bafc50b0e1270702621c9ea7b2ee43e001fe0 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.cc +++ b/tensorflow/compiler/xla/service/reshape_mover.cc @@ -101,8 +101,9 @@ HloInstruction* FirstNonScalarAndNonTrivialReshapeOperand( IsReshapeOrTranspose(operand) && !CanTriviallyChangeShape(operand->operand(0))) { VLOG(5) << "Found first non-scalar and non-trivial reshape operand of " - << hlo->ToStringNoMetadata() << ":\n\t" - << operand->ToStringNoMetadata(); + << hlo->ToString(HloPrintOptions().set_print_metadata(false)) + << ":\n\t" + << operand->ToString(HloPrintOptions().set_print_metadata(false)); return operand; } } @@ -133,8 +134,9 @@ bool AreEquivalentReshapes(const HloInstruction* a, const HloInstruction* b) { bool AllOperandsHaveEasyShapeChanges( const HloInstruction* instruction, const HloInstruction* first_reshape_operand) { + auto print_no_metadata = HloPrintOptions().set_print_metadata(false); VLOG(3) << "** Checking whether all operands have easy shape changes: " - << instruction->ToStringNoMetadata(); + << instruction->ToString(print_no_metadata); // Check whether all operands: // 0. 
Have the same dimensions as the output -- if not, it may be // implicitly broadcast, which can confound the movement's @@ -151,21 +153,21 @@ bool AllOperandsHaveEasyShapeChanges( VLOG(5) << "Operand shape differs from output shape; may be " "implicitly broadcast, so preventing " "movement\n\toperand: " - << operand->ToStringNoMetadata() - << "\n\tinstruction: " << instruction->ToStringNoMetadata(); + << operand->ToString(print_no_metadata) << "\n\tinstruction: " + << instruction->ToString(print_no_metadata); return false; } if (AreEquivalentReshapes(first_reshape_operand, operand)) { VLOG(5) << "Are equivalent reshapes:\n\tfirst_reshape_operand: " - << first_reshape_operand->ToStringNoMetadata() - << "\n\toperand: " << operand->ToStringNoMetadata(); + << first_reshape_operand->ToString(print_no_metadata) + << "\n\toperand: " << operand->ToString(print_no_metadata); continue; } if (CanTriviallyChangeShape(operand)) { VLOG(5) << "Operand can trivially change shape: " - << operand->ToStringNoMetadata(); + << operand->ToString(print_no_metadata); continue; } @@ -173,12 +175,12 @@ bool AllOperandsHaveEasyShapeChanges( // well. VLOG(5) << "Operand is neither equalivant to the first Reshape operand" "nor can trivially change shape: " - << operand->ToStringNoMetadata(); + << operand->ToString(print_no_metadata); return false; } VLOG(3) << "All operands have easy shape changes: " - << instruction->ToStringNoMetadata(); + << instruction->ToString(print_no_metadata); return true; } @@ -250,11 +252,13 @@ StatusOr TrySinkReshapeOrTranspose(HloComputation* computation, return false; } + auto print_no_metadata = HloPrintOptions().set_print_metadata(false); // At this point we've decided to sink reshape/transpose operands. 
const Shape& new_operand_shape = first_reshape_operand->operand(0)->shape(); VLOG(3) << "** Sinking reshape or transpose: " - << instruction->ToStringNoMetadata() << "\n\tfirst reshape operand: " - << first_reshape_operand->ToStringNoMetadata() + << instruction->ToString(print_no_metadata) + << "\n\tfirst reshape operand: " + << first_reshape_operand->ToString(print_no_metadata) << "\n\tnew operand shape: " << ShapeUtil::HumanString(new_operand_shape); @@ -267,7 +271,7 @@ StatusOr TrySinkReshapeOrTranspose(HloComputation* computation, continue; } VLOG(3) << "Updating operand #" << i << ": " - << operands[i]->ToStringNoMetadata(); + << operands[i]->ToString(print_no_metadata); operands[i] = UpdateOperand(computation, first_reshape_operand, new_operand_shape, operands[i]); } @@ -298,7 +302,7 @@ StatusOr TrySinkReshapeOrTranspose(HloComputation* computation, switch (first_reshape_operand->opcode()) { case HloOpcode::kReshape: VLOG(3) << "Creating new reshape for new elementwise op: " - << new_elementwise->ToStringNoMetadata(); + << new_elementwise->ToString(print_no_metadata); new_reshape = HloInstruction::CreateReshape(instruction->shape(), new_elementwise); break; diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index d997cab83f8c2bc74632e49f23e690ffb17b901a..98dfc89867ab33788c4cc837a66d6751a1ef2507 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -34,8 +34,10 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_module_config.h" +#include "tensorflow/compiler/xla/service/hlo_proto_util.h" #include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/service/session.pb.h" +#include "tensorflow/compiler/xla/service/source_map_util.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/shape_util.h" @@ -55,46 +57,38 @@ namespace se = ::perftools::gputools; using ::tensorflow::strings::Printf; using ::tensorflow::strings::StrCat; +using ::xla::source_map_util::InvalidParameterArgument; namespace xla { namespace { -// Copies the contents of an Allocation into a Literal proto. -tensorflow::Status LiteralFromAllocation(const Allocation* allocation, - const Shape& literal_shape, - Literal* literal) { - TF_ASSIGN_OR_RETURN( - se::StreamExecutor * executor, - allocation->backend()->stream_executor(allocation->device_ordinal())); - return allocation->backend()->transfer_manager()->TransferLiteralFromDevice( - executor, allocation->device_memory(), allocation->shape(), literal_shape, - literal); -} - // Records the arguments used to invoke a computation in a SessionModule // proto. 
tensorflow::Status RecordArguments( - const tensorflow::gtl::ArraySlice arg_allocations, + const tensorflow::gtl::ArraySlice arguments, + se::StreamExecutor* executor, TransferManager* transfer_manager, SessionModule* module) { module->clear_arguments(); - for (const Allocation* allocation : arg_allocations) { - Literal argument; - TF_RETURN_IF_ERROR( - LiteralFromAllocation(allocation, allocation->shape(), &argument)); - *module->add_arguments() = argument.ToProto(); + for (const ShapedBuffer* argument : arguments) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr literal, + transfer_manager->TransferLiteralFromDevice(executor, *argument)); + *module->add_arguments() = literal->ToProto(); } return tensorflow::Status::OK(); } // Records the result of a computation in a SessionModule proto. -tensorflow::Status RecordResult(const Allocation* result_allocation, +tensorflow::Status RecordResult(const ShapedBuffer& result, + se::StreamExecutor* executor, + TransferManager* transfer_manager, SessionModule* module) { module->clear_result(); - Literal result; - TF_RETURN_IF_ERROR(LiteralFromAllocation( - result_allocation, result_allocation->shape(), &result)); - *module->mutable_result() = result.ToProto(); + TF_ASSIGN_OR_RETURN( + std::unique_ptr literal, + transfer_manager->TransferLiteralFromDevice(executor, result)); + *module->mutable_result() = literal->ToProto(); return tensorflow::Status::OK(); } @@ -152,7 +146,9 @@ int ServiceOptions::intra_op_parallelism_threads() const { Service::Service(const ServiceOptions& options, std::unique_ptr execute_backend) - : options_(options), execute_backend_(std::move(execute_backend)) { + : options_(options), + allocation_tracker_(execute_backend.get()), + execute_backend_(std::move(execute_backend)) { CHECK_GT(options_.number_of_replicas(), 0); if (execute_backend_) { if (execute_backend_->device_count() > 0) { @@ -235,41 +231,40 @@ tensorflow::Status Service::ValidateResultShapeWithLayout( return 
ShapeUtil::ValidateShape(shape_with_layout); } -StatusOr> Service::ResolveAndValidateArguments( +StatusOr> Service::ResolveAndValidateArguments( tensorflow::gtl::ArraySlice arguments, - const Backend* backend, int device_ordinal) { - std::vector allocations; + int device_ordinal) { + std::vector shaped_buffers; for (size_t i = 0; i < arguments.size(); ++i) { - auto allocation_status = allocation_tracker_.Resolve(*arguments[i]); - if (!allocation_status.ok()) { - return Status(allocation_status.status().code(), - StrCat(allocation_status.status().error_message(), ", ", + auto buffer_status = allocation_tracker_.Resolve(*arguments[i]); + if (!buffer_status.ok()) { + return Status(buffer_status.status().code(), + StrCat(buffer_status.status().error_message(), ", ", "failed to resolve allocation for parameter ", i)); } - const Allocation* allocation = allocation_status.ValueOrDie(); + const ShapedBuffer* shaped_buffer = buffer_status.ValueOrDie(); // Verify allocation is same platform and device as the execution. 
- if (allocation->backend() != backend || - allocation->device_ordinal() != device_ordinal) { + if (shaped_buffer->platform() != execute_backend_->platform() || + shaped_buffer->device_ordinal() != device_ordinal) { return InvalidArgument( - "argument %lu is on device %s but computation will be executed " + "argument %lu is on device %s:%d but computation will be executed " "on device %s", - i, - allocation->backend() - ->device_name(allocation->device_ordinal()) - .c_str(), - backend->device_name(device_ordinal).c_str()); + i, shaped_buffer->platform()->Name().c_str(), + shaped_buffer->device_ordinal(), + execute_backend_->device_name(device_ordinal).c_str()); } - allocations.push_back(allocation); + shaped_buffers.push_back(shaped_buffer); } - return allocations; + return shaped_buffers; } StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice argument_shapes, - const ExecutionOptions* execution_options) { + const ExecutionOptions* execution_options, + const UserComputation& user_computation) { auto config = MakeUnique(program_shape); auto* computation_layout = config->mutable_entry_computation_layout(); @@ -283,8 +278,10 @@ StatusOr> Service::CreateModuleConfig( // ProgramShape. 
if (!ShapeUtil::Compatible(*argument_shapes[i], program_shape.parameters(i))) { - return InvalidArgument( - "computation expects parameter %d to have shape %s, given shape %s", + return InvalidParameterArgument( + *user_computation.ParameterMetadata(i).value(), + "Argument does not match shape of computation parameter %d: want %s, " + "got %s", i, ShapeUtil::HumanString(program_shape.parameters(i)).c_str(), ShapeUtil::HumanString(*argument_shapes[i]).c_str()); } @@ -325,20 +322,23 @@ StatusOr> Service::CreateModuleConfig( StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, - tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions& execution_options) { + tensorflow::gtl::ArraySlice arguments, + const ExecutionOptions& execution_options, + const UserComputation& user_computation) { std::vector argument_shapes; for (const auto* arg : arguments) { - argument_shapes.push_back(&arg->shape()); + argument_shapes.push_back(&arg->on_host_shape()); } - return CreateModuleConfig(program_shape, argument_shapes, &execution_options); + return CreateModuleConfig(program_shape, argument_shapes, &execution_options, + user_computation); } StatusOr>> Service::BuildExecutables( std::vector versioned_handles, std::vector> module_configs, Backend* backend, - std::vector> executors) { + std::vector> executors, + DeviceMemoryAllocator* device_allocator) { VLOG(1) << Printf("BuildExecutable on service %p", this); // Dump computation proto state if flag is set. 
@@ -384,7 +384,8 @@ StatusOr>> Service::BuildExecutables( TF_ASSIGN_OR_RETURN( std::vector> executables, - backend->compiler()->Compile(std::move(modules), std::move(executors))); + backend->compiler()->Compile(std::move(modules), std::move(executors), + device_allocator)); for (size_t i = 0; i < versioned_handles.size(); ++i) { if (!module_configs[i]->debug_options().xla_dump_executions_to().empty()) { @@ -397,10 +398,8 @@ StatusOr>> Service::BuildExecutables( StatusOr> Service::BuildExecutable( const VersionedComputationHandle& versioned_handle, - std::unique_ptr module_config, - const tensorflow::gtl::ArraySlice - arguments, - Backend* backend, se::StreamExecutor* executor) { + std::unique_ptr module_config, Backend* backend, + se::StreamExecutor* executor, DeviceMemoryAllocator* device_allocator) { VLOG(1) << Printf("BuildExecutable on service %p with handle %s", this, versioned_handle.ToString().c_str()); @@ -430,12 +429,15 @@ StatusOr> Service::BuildExecutable( /*include_unreachable_instructions=*/ true)); - TF_ASSIGN_OR_RETURN( - module, backend->compiler()->RunHloPasses(std::move(module), executor)); + TF_RETURN_IF_ERROR(MaybeDumpHloModule(*module)); TF_ASSIGN_OR_RETURN( - std::unique_ptr executable, - backend->compiler()->RunBackend(std::move(module), executor)); + module, backend->compiler()->RunHloPasses(std::move(module), executor, + device_allocator)); + + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + backend->compiler()->RunBackend( + std::move(module), executor, device_allocator)); if (!other_directory_path.empty()) { executable->set_session_module(std::move(session_module)); @@ -446,11 +448,9 @@ StatusOr> Service::BuildExecutable( StatusOr> Service::BuildAndCacheExecutable( const VersionedComputationHandle& versioned_handle, - std::unique_ptr module_config, - const tensorflow::gtl::ArraySlice - arguments, - Backend* backend, perftools::gputools::StreamExecutor* executor, - ExecutionProfile* profile) { + std::unique_ptr module_config, Backend* 
backend, + perftools::gputools::StreamExecutor* executor, ExecutionProfile* profile, + DeviceMemoryAllocator* device_allocator) { std::shared_ptr executable = compilation_cache_.LookUp(versioned_handle, *module_config); @@ -471,8 +471,8 @@ StatusOr> Service::BuildAndCacheExecutable( HloModuleConfig original_module_config = *module_config; TF_ASSIGN_OR_RETURN( std::unique_ptr executable_unique_ptr, - BuildExecutable(versioned_handle, std::move(module_config), arguments, - backend, executor)); + BuildExecutable(versioned_handle, std::move(module_config), backend, + executor, device_allocator)); if (profile != nullptr) { uint64 end_micros = tensorflow::Env::Default()->NowMicros(); @@ -489,9 +489,7 @@ StatusOr> Service::BuildAndCacheExecutable( StatusOr> Service::ExecuteParallelAndRegisterResult( tensorflow::gtl::ArraySlice executables, - tensorflow::gtl::ArraySlice< - std::vector> - arguments, + tensorflow::gtl::ArraySlice> arguments, Backend* backend, tensorflow::gtl::ArraySlice device_handles, tensorflow::gtl::ArraySlice result_tags, ExecutionProfile* profile) { @@ -547,7 +545,7 @@ Service::ExecuteParallelAndRegisterResult( // Asynchronously launch the computation. TF_ASSIGN_OR_RETURN( - perftools::gputools::DeviceMemoryBase result, + std::unique_ptr result, executables[i]->ExecuteAsyncOnStream(&run_options, arguments[i])); if (replica == 0 && profile != nullptr) { @@ -557,17 +555,20 @@ Service::ExecuteParallelAndRegisterResult( // All replicas share the same device address for the result allocation, // so only one of the replicas need to register the result handle. if (replica == 0) { - result_handles.push_back(allocation_tracker_.Register( - backend, replicas[0]->device_ordinal(), result, - executables[i]->result_shape(), result_tags[i])); + TF_ASSIGN_OR_RETURN( + GlobalDataHandle handle, + allocation_tracker_.Register(std::move(result), result_tags[i])); + result_handles.push_back(handle); } } } // Wait for all executions to complete. 
for (int64 i = 0; i < streams.size(); ++i) { - if (!streams[i]->BlockHostUntilDone()) { - return InternalError("failed to complete execution for stream %lld", i); + Status block_status = streams[i]->BlockHostUntilDone(); + if (!block_status.ok()) { + return InternalError("failed to complete execution for stream %lld: %s", + i, block_status.error_message().c_str()); } } @@ -578,7 +579,7 @@ Service::ExecuteParallelAndRegisterResult( se::Stream* stream = index_to_profiled_stream.second; Executable* executable = executables[device]; const HloModule& module = executable->module(); - HloExecutionProfile hlo_profile(&executable->hlo_profile_printer(), + HloExecutionProfile hlo_profile(&executable->hlo_profile_printer_data(), &executable->hlo_profile_index_map()); TF_RETURN_IF_ERROR( executable->PopulateExecutionProfile(&hlo_profile, stream->parent())); @@ -625,8 +626,7 @@ Service::ExecuteParallelAndRegisterResult( StatusOr Service::ExecuteAndRegisterResult( Executable* executable, - const tensorflow::gtl::ArraySlice - arguments, + const tensorflow::gtl::ArraySlice arguments, Backend* backend, perftools::gputools::StreamExecutor* executor, const string& result_tag, ExecutionProfile* profile) { // Set up streams. 
@@ -651,6 +651,7 @@ StatusOr Service::ExecuteAndRegisterResult( for (const Pool::SmartPtr& stream : streams) { ExecutableRunOptions options; options.set_stream(stream.get()); + options.set_device_ordinal(stream->parent()->device_ordinal()); options.set_allocator(backend->memory_allocator()); options.set_inter_op_thread_pool(backend->inter_op_thread_pool()); options.set_intra_op_thread_pool( @@ -660,24 +661,21 @@ StatusOr Service::ExecuteAndRegisterResult( backend->inter_op_thread_pool()); } - perftools::gputools::DeviceMemoryBase result; + std::unique_ptr result; if (options_.number_of_replicas() == 1) { - TF_ASSIGN_OR_RETURN( - result, executable->ExecuteOnStreamWrapper( - &run_options[0], profile, arguments)); + TF_ASSIGN_OR_RETURN(result, executable->ExecuteOnStreamWrapper( + &run_options[0], profile, arguments)); } else { - std::vector< - tensorflow::gtl::ArraySlice> + // TODO(b/69985541): Support profiling also on this path. + std::vector> repeated_arguments(options_.number_of_replicas(), arguments); TF_ASSIGN_OR_RETURN(auto results, executable->ExecuteOnStreams( run_options, repeated_arguments)); TF_RET_CHECK(!results.empty()); - result = results[0]; + result = std::move(results[0]); } - return allocation_tracker_.Register(backend, executor->device_ordinal(), - result, executable->result_shape(), - result_tag); + return allocation_tracker_.Register(std::move(result), result_tag); } tensorflow::Status Service::SetReturnValue(const SetReturnValueRequest* arg, @@ -691,7 +689,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, ExecuteParallelResponse* result) { VLOG(1) << "running execute-parallel request: " << arg->ShortDebugString(); - std::vector> all_arguments; + std::vector> all_arguments; std::vector> all_executors; std::vector versioned_handles; std::vector> module_configs; @@ -748,20 +746,16 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, // In the case of partitioned computations, assume 
all arguments go on the // zeroth core. TF_ASSIGN_OR_RETURN( - std::vector arg_allocations, - ResolveAndValidateArguments(request.arguments(), execute_backend_.get(), + std::vector arguments, + ResolveAndValidateArguments(request.arguments(), executors[0]->device_ordinal())); - std::vector arguments; - arguments.reserve(arg_allocations.size()); - for (const Allocation* allocation : arg_allocations) { - arguments.push_back(allocation->device_memory()); - } // Create an HloModuleConfig object for the computation, given the shape of // the program and the argument allocations. - TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, - CreateModuleConfig(*program_shape, arg_allocations, - request.execution_options())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr module_config, + CreateModuleConfig(*program_shape, arguments, + request.execution_options(), *user_computation)); VLOG(3) << "ExecuteParallel created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); @@ -780,10 +774,14 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, // Build the user computations into HloModules and compile to generate the // executables. + // + // TODO(jlebar): There's currently no way to pass a device allocator to + // ExecuteParallel, so we have to pass a null device_allocator below. 
TF_ASSIGN_OR_RETURN( std::vector> executables, BuildExecutables(versioned_handles, std::move(module_configs), - execute_backend_.get(), all_executors)); + execute_backend_.get(), all_executors, + /*device_allocator=*/nullptr)); std::vector executable_ptrs; executable_ptrs.reserve(executables.size()); for (const auto& executable : executables) { @@ -863,35 +861,31 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg, user_computation->ComputeProgramShape(versioned_handle.version)); TF_ASSIGN_OR_RETURN( - std::vector arg_allocations, - ResolveAndValidateArguments(arg->arguments(), execute_backend_.get(), + std::vector arguments, + ResolveAndValidateArguments(arg->arguments(), execute_backend_->default_device_ordinal())); - TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, - CreateModuleConfig(*program_shape, arg_allocations, - arg->execution_options())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr module_config, + CreateModuleConfig(*program_shape, arguments, arg->execution_options(), + *user_computation)); VLOG(3) << "Execute created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); - std::vector arguments; - arguments.reserve(arg_allocations.size()); - for (const Allocation* allocation : arg_allocations) { - arguments.push_back(allocation->device_memory()); - } - TF_ASSIGN_OR_RETURN( std::shared_ptr executable, BuildAndCacheExecutable(versioned_handle, std::move(module_config), - arguments, execute_backend_.get(), + execute_backend_.get(), execute_backend_->default_stream_executor(), result->mutable_profile())); if (executable->dumping()) { executable->session_module()->set_execution_platform( execute_backend_->platform()->Name()); - TF_RETURN_IF_ERROR( - RecordArguments(arg_allocations, executable->session_module())); + TF_RETURN_IF_ERROR(RecordArguments( + arguments, execute_backend_->default_stream_executor(), + execute_backend_->transfer_manager(), executable->session_module())); } TF_ASSIGN_OR_RETURN( @@ 
-902,10 +896,11 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg, "result of " + user_computation->name(), result->mutable_profile())); if (executable->dumping()) { - TF_ASSIGN_OR_RETURN(const Allocation* result_allocation, + TF_ASSIGN_OR_RETURN(const ShapedBuffer* result_buffer, allocation_tracker_.Resolve(result->output())); - TF_RETURN_IF_ERROR( - RecordResult(result_allocation, executable->session_module())); + TF_RETURN_IF_ERROR(RecordResult( + *result_buffer, execute_backend_->default_stream_executor(), + execute_backend_->transfer_manager(), executable->session_module())); TF_RETURN_IF_ERROR(executable->DumpSessionModule()); } @@ -931,31 +926,25 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, user_computation->ComputeProgramShape(versioned_handle.version)); TF_ASSIGN_OR_RETURN( - std::vector arg_allocations, - ResolveAndValidateArguments(arg->arguments(), execute_backend_.get(), + std::vector arguments, + ResolveAndValidateArguments(arg->arguments(), execute_backend_->default_device_ordinal())); - TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, - CreateModuleConfig(*program_shape, arg_allocations, - arg->execution_options())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr module_config, + CreateModuleConfig(*program_shape, arguments, arg->execution_options(), + *user_computation)); VLOG(3) << "ExecuteAsync created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); - std::vector arguments; - arguments.reserve(arg_allocations.size()); - for (const Allocation* allocation : arg_allocations) { - arguments.push_back(allocation->device_memory()); - } - ExecutionProfile profile; TF_ASSIGN_OR_RETURN( std::shared_ptr executable, - BuildAndCacheExecutable(versioned_handle, std::move(module_config), - arguments, execute_backend_.get(), - execute_backend_->default_stream_executor(), - &profile)); + BuildAndCacheExecutable( + versioned_handle, std::move(module_config), 
execute_backend_.get(), + execute_backend_->default_stream_executor(), &profile)); TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*execute_backend_, SingleComputationDeviceHandle())); @@ -970,7 +959,7 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, streams.push_back(std::move(stream)); } - perftools::gputools::DeviceMemoryBase result_data; + std::unique_ptr result_buffer; for (const Pool::SmartPtr& stream : streams) { ExecutableRunOptions options; options.set_stream(stream.get()); @@ -983,19 +972,19 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg, options, execute_backend_->StreamBorrower()); TF_ASSIGN_OR_RETURN( - perftools::gputools::DeviceMemoryBase this_result_data, + std::unique_ptr this_result_buffer, executable->ExecuteAsyncOnStream(&service_options, arguments)); // Take the first result. - if (result_data == nullptr) { - result_data = this_result_data; + if (result_buffer == nullptr) { + result_buffer = std::move(this_result_buffer); } } - auto output = allocation_tracker_.Register( - execute_backend_.get(), execute_backend_->default_device_ordinal(), - result_data, executable->result_shape(), - "result of " + user_computation->name()); + TF_ASSIGN_OR_RETURN( + GlobalDataHandle output, + allocation_tracker_.Register(std::move(result_buffer), + "result of " + user_computation->name())); *result->mutable_execution() = execution_tracker_.Register( execute_backend_.get(), std::move(streams), profile, output); @@ -1022,37 +1011,58 @@ tensorflow::Status Service::WaitForExecution(const WaitForExecutionRequest* arg, tensorflow::Status Service::TransferToClient(const TransferToClientRequest* arg, TransferToClientResponse* result) { - TF_ASSIGN_OR_RETURN(const Allocation* allocation, + TF_ASSIGN_OR_RETURN(const ShapedBuffer* shaped_buffer, allocation_tracker_.Resolve(arg->data())); - const Shape* literal_shape; + const Shape* return_shape; if (arg->has_shape_with_layout()) { if 
(!LayoutUtil::HasLayout(arg->shape_with_layout())) { return InvalidArgument("shape_with_layout must have layout if present."); } - literal_shape = &arg->shape_with_layout(); + return_shape = &arg->shape_with_layout(); } else { - literal_shape = &allocation->shape(); + return_shape = &shaped_buffer->on_host_shape(); } - Literal literal; - TF_RETURN_IF_ERROR( - LiteralFromAllocation(allocation, *literal_shape, &literal)); - *result->mutable_literal() = literal.ToProto(); + TF_ASSIGN_OR_RETURN( + se::StreamExecutor * executor, + execute_backend_->stream_executor(shaped_buffer->device_ordinal())); + + TF_ASSIGN_OR_RETURN( + std::unique_ptr result_literal, + execute_backend_->transfer_manager()->TransferLiteralFromDevice( + executor, *shaped_buffer)); + + if (LayoutUtil::LayoutsInShapesEqual(*return_shape, + result_literal->shape())) { + *result->mutable_literal() = result_literal->ToProto(); + } else { + *result->mutable_literal() = + result_literal->Relayout(*return_shape)->ToProto(); + } return tensorflow::Status::OK(); } +namespace { + +// Creates a clone of the given shaped buffer with the given device ordinal. The +// shape and DeviceMemoryBase values of the clone are identical to the original. +std::unique_ptr CloneShapedBufferOnDevice( + const ShapedBuffer& shaped_buffer, int device_ordinal) { + auto clone = MakeUnique( + shaped_buffer.on_host_shape(), shaped_buffer.on_device_shape(), + shaped_buffer.platform(), device_ordinal); + clone->buffers() = shaped_buffer.buffers(); + return clone; +} + +} // namespace + tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg, TransferToServerResponse* result) { - Literal literal = Literal(arg->literal()); - const Shape& shape = literal.shape(); - - if (ShapeUtil::IsTuple(shape) && options_.number_of_replicas() > 1) { - // TODO(b/32990684): Tuple transfers to host end up allocating further - // buffers - implement that correctly. 
- return Unimplemented( - "Tuple transfers to the device not supported with replication."); - } + TF_ASSIGN_OR_RETURN(std::unique_ptr literal, + Literal::CreateFromProto(arg->literal())); + const Shape& shape = literal->shape(); std::vector replicas; if (arg->has_device_handle()) { @@ -1063,25 +1073,38 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg, replicas, Replicas(*execute_backend_, SingleComputationDeviceHandle())); } - // Allocate memory on the device, using the stream executor. The size of the - // allocation is obtained by examining the shape of the literal passed from - // the client. An allocation handle is returned in the response. - int64 allocation_size = - execute_backend_->transfer_manager()->GetByteSizeRequirement(shape); - - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase allocation, - execute_backend_->memory_allocator()->Allocate( - replicas[0]->device_ordinal(), allocation_size)); - - *result->mutable_data() = allocation_tracker_.Register( - execute_backend_.get(), replicas[0]->device_ordinal(), allocation, shape, - StrCat("TransferToServer literal of size ", allocation_size)); + // All memory allocation is done on the first replica. The allocations in all + // other replicas mirror the firsts'. + int master_device_ordinal = replicas[0]->device_ordinal(); + TF_ASSIGN_OR_RETURN( + std::unique_ptr shaped_buffer, + execute_backend_->transfer_manager()->AllocateShapedBuffer( + shape, execute_backend_->memory_allocator(), master_device_ordinal)); + // Transfer the data to the replicas. for (se::StreamExecutor* executor : replicas) { - TF_RETURN_IF_ERROR( - execute_backend_->transfer_manager()->TransferLiteralToDevice( - executor, literal, &allocation)); + if (executor->device_ordinal() == master_device_ordinal) { + TF_RETURN_IF_ERROR( + execute_backend_->transfer_manager()->TransferLiteralToDevice( + executor, *literal, *shaped_buffer)); + } else { + // The replica is not the master. 
Create an cloned shaped buffer with + // the replica's device ordinal. This is required because + // TransferLiteralToDevice verifies that the device ordinal of the shaped + // buffer matches that of the executor. + std::unique_ptr clone = + CloneShapedBufferOnDevice(*shaped_buffer, executor->device_ordinal()); + TF_RETURN_IF_ERROR( + execute_backend_->transfer_manager()->TransferLiteralToDevice( + executor, *literal, *clone)); + } } + TF_ASSIGN_OR_RETURN( + *result->mutable_data(), + allocation_tracker_.Register(std::move(shaped_buffer), + StrCat("TransferToServer literal of shape ", + ShapeUtil::HumanString(shape)))); + return tensorflow::Status::OK(); } @@ -1109,8 +1132,10 @@ tensorflow::Status Service::TransferToInfeed(const TransferToInfeedRequest* arg, executor = replicas[arg->replica_id()]; } + TF_ASSIGN_OR_RETURN(std::unique_ptr literal, + Literal::CreateFromProto(arg->literal())); return execute_backend_->transfer_manager()->TransferLiteralToInfeed( - executor, Literal(arg->literal())); + executor, *literal); } tensorflow::Status Service::TransferFromOutfeed( @@ -1185,7 +1210,22 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg, bool is_constant, user_computation->IsConstant(arg->operand(), arg->parameters_size())); if (!is_constant) { - return InvalidArgument("Operand to ComputeConstant depends on parameter."); + StatusOr op_request_status = + user_computation->LookUpRequestForErrorReporting(arg->operand()); + string op_request_string = ""; + if (op_request_status.ok()) { + op_request_string = op_request_status.ValueOrDie()->ShortDebugString(); + } + return InvalidArgument( + "Operand to ComputeConstant depends on a parameter.\n\n" + " op requested for constant evaluation: %s\n\n" + "This is an internal error that typically happens when the XLA user " + "(e.g. TensorFlow) is attempting to determine a value that must be a " + "compile-time constant (e.g. 
an array dimension) but it is not capable " + "of being evaluated at XLA compile time.\n\n" + "Please file a usability bug with the framework being used (e.g. " + "TensorFlow).", + op_request_string.c_str()); } // We can't use ComputeProgramShape because it checks that all parameter @@ -1213,7 +1253,8 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg, } TF_ASSIGN_OR_RETURN(std::unique_ptr module_config, - CreateModuleConfig(program_shape, {}, execution_options)); + CreateModuleConfig(program_shape, {}, execution_options, + *user_computation)); // Exclude dead parameter instructions for the purpose of computing constants. TF_ASSIGN_OR_RETURN( @@ -1222,18 +1263,16 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg, /*include_unreachable_instructions=*/ false)); - std::vector parameters(arg->parameters_size()); + std::vector> parameters(arg->parameters_size()); for (int64 i = 0; i < arg->parameters_size(); ++i) { - parameters[i] = Literal(arg->parameters(i)); + TF_ASSIGN_OR_RETURN(parameters[i], + Literal::CreateFromProto(arg->parameters(i))); } - std::vector parameter_ptrs; - std::transform(parameters.begin(), parameters.end(), - std::back_inserter(parameter_ptrs), - [](const Literal& literal) { return &literal; }); - HloEvaluator evaluator; - TF_ASSIGN_OR_RETURN(auto result_literal, - evaluator.Evaluate(*module, parameter_ptrs)); + TF_ASSIGN_OR_RETURN( + auto result_literal, + evaluator.Evaluate>(*module, parameters)); + // Since the shape_with_output_layout option in ExecutionOption is // non-effective to the Evaluator results, explicit relayout here. 
if (arg->has_output_layout()) { @@ -1246,9 +1285,9 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg, tensorflow::Status Service::GetShape(const GetShapeRequest* arg, GetShapeResponse* result) { - TF_ASSIGN_OR_RETURN(const Allocation* allocation, + TF_ASSIGN_OR_RETURN(const ShapedBuffer* buffer, allocation_tracker_.Resolve(arg->data())); - *result->mutable_shape() = allocation->shape(); + *result->mutable_shape() = buffer->on_host_shape(); return tensorflow::Status::OK(); } @@ -1357,6 +1396,17 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { handle_status = computation->AddConcatenateInstruction(arg->concatenate_request()); break; + case OpRequest::kConditionalRequest: { + TF_ASSIGN_OR_RETURN(UserComputation * true_computation, + computation_tracker_.Resolve( + arg->conditional_request().true_computation())); + TF_ASSIGN_OR_RETURN(UserComputation * false_computation, + computation_tracker_.Resolve( + arg->conditional_request().false_computation())); + handle_status = computation->AddConditionalInstruction( + arg->conditional_request(), *true_computation, *false_computation); + break; + } case OpRequest::kConstantRequest: handle_status = computation->AddConstantInstruction(arg->constant_request()); @@ -1381,6 +1431,9 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { handle_status = computation->AddCustomCallInstruction(arg->custom_call_request()); break; + case OpRequest::kDotRequest: + handle_status = computation->AddDotInstruction(arg->dot_request()); + break; case OpRequest::kDynamicSliceRequest: handle_status = computation->AddDynamicSliceInstruction(arg->dynamic_slice_request()); @@ -1389,6 +1442,9 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { handle_status = computation->AddDynamicUpdateSliceInstruction( arg->dynamic_update_slice_request()); break; + case OpRequest::kFftRequest: + handle_status = 
computation->AddFftInstruction(arg->fft_request()); + break; case OpRequest::kGetTupleElementRequest: handle_status = computation->AddGetTupleElementInstruction( arg->get_tuple_element_request()); @@ -1397,9 +1453,9 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { handle_status = computation->AddInfeedInstruction(arg->infeed_request()); break; case OpRequest::kOutfeedRequest: - TF_RETURN_IF_ERROR( - computation->AddOutfeedInstruction(arg->outfeed_request())); - return tensorflow::Status::OK(); + handle_status = + computation->AddOutfeedInstruction(arg->outfeed_request()); + break; case OpRequest::kMapRequest: { TF_ASSIGN_OR_RETURN( UserComputation * to_apply, @@ -1501,8 +1557,10 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { handle_status = computation->AddRecvInstruction(arg->recv_request()); break; } + case OpRequest::OP_NOT_SET: + return InvalidArgument("XLA service received OpRequest with OP_NOT_SET"); default: - return InvalidArgument("Unsupported operation"); + return InvalidArgument("Unsupported operation in XLA service"); } TF_ASSIGN_OR_RETURN(*result->mutable_output(), handle_status); @@ -1560,4 +1618,15 @@ StatusOr> Service::Replicas( return replicas; } +Status Service::MaybeDumpHloModule(const HloModule& module) const { + const string xla_dump_unoptimized_hlo_proto_to = + module.config().debug_options().xla_dump_unoptimized_hlo_proto_to(); + if (xla_dump_unoptimized_hlo_proto_to.empty()) { + return Status::OK(); + } + HloProto proto = MakeHloProto(module); + return protobuf_util::DumpProtoToDirectory( + proto, xla_dump_unoptimized_hlo_proto_to, module.name()); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index 47f4f0ade594089aa71717ef1e122886b0a6c7ac..6ce241971156599aaa25aea1b0caac0e1bd5379c 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -250,8 +250,9 @@ class 
Service : public ServiceInterface { // class. StatusOr> CreateModuleConfig( const ProgramShape& program_shape, - tensorflow::gtl::ArraySlice arguments, - const ExecutionOptions& execution_options); + tensorflow::gtl::ArraySlice arguments, + const ExecutionOptions& execution_options, + const UserComputation& user_computation); protected: friend class LocalExecutable; @@ -265,25 +266,29 @@ class Service : public ServiceInterface { // Resolves the given argument handles in the allocation tracker and returns // the corresponding allocations. The function also verifies that each - // allocation matches the given backend and device ordinal. - StatusOr> ResolveAndValidateArguments( + // allocation matches the execution platform and device ordinal. + StatusOr> ResolveAndValidateArguments( tensorflow::gtl::ArraySlice arguments, - const Backend* backend, int device_ordinal); + int device_ordinal); // Create a Hlo module config for the given program shape and arguments. // execution_options is optional; if not given a default is used. StatusOr> CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice argument_shapes, - const ExecutionOptions* execution_options); + const ExecutionOptions* execution_options, + const UserComputation& user_computation); // Builds an Executable for the given parameters. + // + // If device_allocator is not null, the compiler may use it to allocate temp + // buffers, which the compiler is responsible for freeing. The allocator + // given here need not match the allocator used when running the executable. 
StatusOr> BuildExecutable( const VersionedComputationHandle& versioned_handle, - std::unique_ptr module_config, - const tensorflow::gtl::ArraySlice - arguments, - Backend* backend, perftools::gputools::StreamExecutor* executor); + std::unique_ptr module_config, Backend* backend, + perftools::gputools::StreamExecutor* executor, + DeviceMemoryAllocator* device_allocator = nullptr); // Same as BuildExecutable() above, but builds a list of Executables for the // given computations that may interact with each other. @@ -291,18 +296,17 @@ class Service : public ServiceInterface { std::vector versioned_handles, std::vector> module_configs, Backend* backend, - std::vector> executors); + std::vector> executors, + DeviceMemoryAllocator* device_allocator); // Similar to BuildExecutable, but look in the compilation cache for the // executable first. If the executable is not in the cache, it is built and // inserted into the cache. StatusOr> BuildAndCacheExecutable( const VersionedComputationHandle& versioned_handle, - std::unique_ptr module_config, - const tensorflow::gtl::ArraySlice - arguments, - Backend* backend, perftools::gputools::StreamExecutor* executor, - ExecutionProfile* profile); + std::unique_ptr module_config, Backend* backend, + perftools::gputools::StreamExecutor* executor, ExecutionProfile* profile, + DeviceMemoryAllocator* device_allocator = nullptr); // Runs the given executable with the given arguments and register the result // in the allocation tracker. The handle of the result from the tracker is @@ -310,8 +314,7 @@ class Service : public ServiceInterface { // ExecutionProfile object which will be filled in with profile data. 
StatusOr ExecuteAndRegisterResult( Executable* executable, - const tensorflow::gtl::ArraySlice - arguments, + const tensorflow::gtl::ArraySlice arguments, Backend* backend, perftools::gputools::StreamExecutor* executor, const string& result_tag, ExecutionProfile* profile); @@ -320,9 +323,7 @@ class Service : public ServiceInterface { // from the tracker are returned. StatusOr> ExecuteParallelAndRegisterResult( tensorflow::gtl::ArraySlice executables, - tensorflow::gtl::ArraySlice< - std::vector> - arguments, + tensorflow::gtl::ArraySlice> arguments, Backend* backend, tensorflow::gtl::ArraySlice device_handles, tensorflow::gtl::ArraySlice result_tags, @@ -347,6 +348,8 @@ class Service : public ServiceInterface { StatusOr> Replicas( const Backend& backend, const DeviceHandle& device_handle) const; + Status MaybeDumpHloModule(const HloModule& module) const; + // Returns the device handle that represents the replicated device for a // single computation that is not model-parallelized. DeviceHandle SingleComputationDeviceHandle() const; diff --git a/tensorflow/compiler/xla/service/service_executable_run_options.h b/tensorflow/compiler/xla/service/service_executable_run_options.h index 017e5ef09ed2f52b862821e9408540d188a1edf5..6c1f8feac7ed4423051cf2737be57dcfab508671 100644 --- a/tensorflow/compiler/xla/service/service_executable_run_options.h +++ b/tensorflow/compiler/xla/service/service_executable_run_options.h @@ -30,6 +30,9 @@ class ServiceExecutableRunOptions { using StreamBorrower = std::function::SmartPtr>(int)>; + ServiceExecutableRunOptions() + : ServiceExecutableRunOptions(ExecutableRunOptions()) {} + explicit ServiceExecutableRunOptions( ExecutableRunOptions run_options, StreamBorrower borrow_stream = nullptr, tensorflow::thread::ThreadPool* xla_intra_op_thread_pool = nullptr) diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 
3df1911d07cf0cd123604b1fac63923a725a37c6..004889b5f216015ee1e1308702b2bf4cb0deb344 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -29,6 +29,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -36,6 +37,9 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" +using tensorflow::str_util::Join; +using tensorflow::strings::Printf; + namespace xla { namespace { @@ -90,8 +94,6 @@ BinaryOperation OpcodeToBinaryOperation(HloOpcode opcode) { return BINOP_ATAN2; case HloOpcode::kComplex: return BINOP_COMPLEX; - case HloOpcode::kDot: - return BINOP_DOT; case HloOpcode::kMultiply: return BINOP_MUL; case HloOpcode::kAdd: @@ -207,7 +209,8 @@ tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape, } // Check that init_value's shape is suitable for reducer_shape. - if (!ShapeUtil::Compatible(accumulator_shape, init_value_shape)) { + if (!ShapeUtil::CompatibleIgnoringFpPrecision(accumulator_shape, + init_value_shape)) { return InvalidArgument( "Reduction function's accumulator shape differs from the " "init_value shape: %s vs %s", @@ -218,8 +221,8 @@ tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape, // Check that the inputs can be passed in as the second argument. 
const Shape& input_element_shape = ShapeUtil::MakeShape(input_element_type, {}); - if (!ShapeUtil::Compatible(input_element_shape, - reducer_shape.parameters(1))) { + if (!ShapeUtil::CompatibleIgnoringFpPrecision(input_element_shape, + reducer_shape.parameters(1))) { return InvalidArgument( "Reduction function's second parameter shape differs from the " "input type element type: %s vs %s", @@ -229,7 +232,8 @@ tensorflow::Status VerifyReducerShape(const ProgramShape& reducer_shape, // Currently the accumulator and inputs must be the same type, // though that restriction could be relaxed. - if (!ShapeUtil::Compatible(accumulator_shape, reducer_shape.parameters(1))) { + if (!ShapeUtil::CompatibleIgnoringFpPrecision(accumulator_shape, + reducer_shape.parameters(1))) { return InvalidArgument( "Reduction function's second parameter shape currently must " "match the result shape. Got %s vs %s", @@ -392,11 +396,13 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, dimension); } const Shape* arg_shape = nullptr; + PrimitiveType element_type = PRIMITIVE_TYPE_INVALID; for (const Shape* shape : arg_shapes) { TF_RETURN_IF_ERROR( ExpectNotTupleOrOpaque(*shape, "operand of concatenation")); if (!arg_shape) { arg_shape = shape; + element_type = arg_shape->element_type(); continue; } if (ShapeUtil::Rank(*arg_shape) != ShapeUtil::Rank(*shape)) { @@ -407,7 +413,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, ShapeUtil::HumanString(*arg_shape).c_str(), ShapeUtil::Rank(*shape), ShapeUtil::HumanString(*shape).c_str()); } - if (arg_shape->element_type() != shape->element_type()) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shape, *shape)) { return InvalidArgument( "cannot concatenate arrays with different element types: %s vs %s", PrimitiveType_Name(arg_shape->element_type()).c_str(), @@ -429,6 +435,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, ShapeUtil::HumanString(*shape).c_str(), dimension); } } + element_type = 
ShapeUtil::HigherPrecisionElementType(*shape, *arg_shape); } std::vector new_dimensions(arg_shape->dimensions().begin(), @@ -436,7 +443,7 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, for (size_t i = 1; i < arg_shapes.size(); ++i) { new_dimensions[dimension] += arg_shapes[i]->dimensions(dimension); } - return ShapeUtil::MakeShape(arg_shape->element_type(), new_dimensions); + return ShapeUtil::MakeShape(element_type, new_dimensions); } /* static */ StatusOr ShapeInference::InferConvertShape( @@ -534,7 +541,8 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, ShapeUtil::HumanString(operand_shape).c_str(), padding_config.ShortDebugString().c_str()); } - if (operand_shape.element_type() != padding_value_shape.element_type()) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape, + padding_value_shape)) { return InvalidArgument( "the element types of the operands to pad do not match"); } @@ -546,11 +554,118 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, std::max(operand_shape.dimensions(i) - 1, 0LL) * padding_config.dimensions(i).interior_padding(); } - return ShapeUtil::MakeShape(operand_shape.element_type(), dimensions); + return ShapeUtil::MakeShape( + ShapeUtil::HigherPrecisionElementType(operand_shape, padding_value_shape), + dimensions); +} + +// Current DotDimensionNumbers Requirements: +// +// Contracting Dimensions: +// *) Exactly one contracting dimension on both lhs and rhs. +// *) Contracting dimension size must be the same on both lhs and rhs. +// *) Contracting dimension numbers do not need to be the same (i.e. transposes +// are passed on to emitter implementations). +// +// Batch Dimensions: +// *) Same number of batch dimensions on both lhs and rhs. +// *) Same batch dimension numbers (and sizes) on both lhs and rhs. +// *) Batch dimension numbers must be ordered before contracting and +// non-contracting/non-batch dimension numbers. 
+// +// Non-Contracting-Non-Batch Dimensions: +// *) Can be 0 (matrix-vector) or 1 (matrix-matrix). +// + +namespace { + +Status ValidateDotDimensionNumbers( + const Shape& lhs, const Shape& rhs, + const DotDimensionNumbers& dimension_numbers) { + // Check that dimension numbers are in range. + auto dims_in_range = + [](const int64 rank, tensorflow::gtl::ArraySlice contracting_dims, + tensorflow::gtl::ArraySlice batch_dims) -> bool { + auto in_range = [&rank](int64 i) -> bool { return 0 <= i && i < rank; }; + return std::all_of(contracting_dims.begin(), contracting_dims.end(), + in_range) && + std::all_of(batch_dims.begin(), batch_dims.end(), in_range); + }; + + tensorflow::gtl::ArraySlice lhs_contracting_dimensions = + AsInt64Slice(dimension_numbers.lhs_contracting_dimensions()); + tensorflow::gtl::ArraySlice rhs_contracting_dimensions = + AsInt64Slice(dimension_numbers.rhs_contracting_dimensions()); + tensorflow::gtl::ArraySlice lhs_batch_dimensions = + AsInt64Slice(dimension_numbers.lhs_batch_dimensions()); + tensorflow::gtl::ArraySlice rhs_batch_dimensions = + AsInt64Slice(dimension_numbers.rhs_batch_dimensions()); + + if (!dims_in_range(ShapeUtil::Rank(lhs), lhs_contracting_dimensions, + lhs_batch_dimensions) || + !dims_in_range(ShapeUtil::Rank(rhs), rhs_contracting_dimensions, + rhs_batch_dimensions)) { + return InvalidArgument("A dimension number is out of range in dot: %s", + dimension_numbers.DebugString().c_str()); + } + + // Check that dimension numbers are unique. 
+ auto dims_unique = [](tensorflow::gtl::ArraySlice contracting_dims, + tensorflow::gtl::ArraySlice batch_dims) -> bool { + tensorflow::gtl::FlatSet dim_set; + auto is_unique = [&dim_set](int64 i) -> bool { + return dim_set.insert(i).second; + }; + return std::all_of(contracting_dims.begin(), contracting_dims.end(), + is_unique) && + std::all_of(batch_dims.begin(), batch_dims.end(), is_unique); + }; + + if (!dims_unique(lhs_contracting_dimensions, lhs_batch_dimensions) || + !dims_unique(rhs_contracting_dimensions, rhs_batch_dimensions)) { + return InvalidArgument("A dimension number is not unique in dot: %s", + dimension_numbers.DebugString().c_str()); + } + + // Check that the count of non-contracting-non-batch dimensions is in {0, 1}. + const int64 lhs_non_contracting_non_batch_dims = + ShapeUtil::Rank(lhs) - + dimension_numbers.lhs_contracting_dimensions_size() - + dimension_numbers.lhs_batch_dimensions_size(); + const int64 rhs_non_contracting_non_batch_dims = + ShapeUtil::Rank(rhs) - + dimension_numbers.rhs_contracting_dimensions_size() - + dimension_numbers.rhs_batch_dimensions_size(); + if (lhs_non_contracting_non_batch_dims < 0 || + lhs_non_contracting_non_batch_dims > 1 || + rhs_non_contracting_non_batch_dims < 0 || + rhs_non_contracting_non_batch_dims > 1) { + return InvalidArgument( + "batch and contracting dimension number mismatch " + "with rank "); + } + + // Check that batch dimension numbers are ordered before all others, and + // that they are monotonically increasing. 
+ std::vector batch_dim_numbers(lhs_batch_dimensions.size()); + std::iota(batch_dim_numbers.begin(), batch_dim_numbers.end(), 0); + if (!std::equal(batch_dim_numbers.begin(), batch_dim_numbers.end(), + lhs_batch_dimensions.begin()) || + !std::equal(batch_dim_numbers.begin(), batch_dim_numbers.end(), + rhs_batch_dimensions.begin())) { + return InvalidArgument( + "batch dimension numbers must precede non-batch dimensions and be" + "monotonically increasing."); + } + + return Status::OK(); } -/* static */ StatusOr ShapeInference::InferDotOpShape(const Shape& lhs, - const Shape& rhs) { +} // namespace + +/* static */ StatusOr ShapeInference::InferDotOpShape( + const Shape& lhs, const Shape& rhs, + const DotDimensionNumbers& dimension_numbers) { TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of dot")); TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of dot")); @@ -566,45 +681,71 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, }; // Check if both element types are the same. - if (lhs.element_type() != rhs.element_type()) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return fail("element types do not match"); } - if (ShapeUtil::Rank(lhs) < 1 || ShapeUtil::Rank(lhs) > 2 || - ShapeUtil::Rank(rhs) < 1 || ShapeUtil::Rank(rhs) > 2) { - return fail("dot only supports rank 1 or 2"); + if ((ShapeUtil::Rank(lhs) < 1) || (ShapeUtil::Rank(rhs) < 1)) { + return fail("dot only supports rank 1 or above."); } - // Determine the index of the contracted dimensions for input tensors. - // dimensions -1 of lhs and dimension 0 of rhs are contracted. - int64 lhs_contracted_dimension = ShapeUtil::GetDimensionNumber(lhs, -1); - int64 rhs_contracted_dimension = 0; + // Validate basic properties of dot dimension numbers. + TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(lhs, rhs, dimension_numbers)); + + // Check that there is only one contracting dimension for both lhs and rhs. 
+ if (dimension_numbers.lhs_contracting_dimensions_size() != + dimension_numbers.rhs_contracting_dimensions_size() || + dimension_numbers.lhs_contracting_dimensions_size() != 1) { + return fail("must specify one contracting dimension for both lhs and rhs."); + } - // Check if the contracted dimension sizes are the same. - if ((lhs_contracted_dimension < ShapeUtil::Rank(lhs) && - rhs_contracted_dimension < ShapeUtil::Rank(rhs)) && - lhs.dimensions(lhs_contracted_dimension) != - rhs.dimensions(rhs_contracted_dimension)) { - return fail("contracted dimensions mismatch"); + // Check that contracting dimension sizes match. + const int64 lhs_contracting_dimension = + dimension_numbers.lhs_contracting_dimensions(0); + const int64 rhs_contracting_dimension = + dimension_numbers.rhs_contracting_dimensions(0); + if (lhs.dimensions(lhs_contracting_dimension) != + rhs.dimensions(rhs_contracting_dimension)) { + return fail("contracting dimension sizes do not match."); + } + + // Check that number of batch dimensions match. + if (dimension_numbers.lhs_batch_dimensions_size() != + dimension_numbers.rhs_batch_dimensions_size()) { + return fail("must the same number of batch dimensions for lhs and rhs."); + } + + // Check that batch dimension numbers and sizes match. + for (int64 i = 0; i < dimension_numbers.lhs_batch_dimensions_size(); ++i) { + if (dimension_numbers.lhs_batch_dimensions(i) != + dimension_numbers.rhs_batch_dimensions(i) || + lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)) != + rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i))) { + return fail("batch dimension numbers and sizes must match for lhs/rhs."); + } } // The ranks of lhs and rhs are decremented by 1 respectively due to the // contraction, and added for the rank of the result. When an input tensor is // a scalar, its contribution to the rank of the result is 0. // Generate the result dimensions in order, rhs dimensions followed by lhs - // dimensions except the contracted dimensions. 
+ // dimensions except the contracted and batch dimensions. std::vector dimensions; + std::unordered_set rhs_batch_dims( + dimension_numbers.rhs_batch_dimensions().begin(), + dimension_numbers.rhs_batch_dimensions().end()); for (int64 i = 0; i < ShapeUtil::Rank(lhs); i++) { - if (i != lhs_contracted_dimension) { + if (i != lhs_contracting_dimension) { dimensions.push_back(lhs.dimensions(i)); } } for (int64 i = 0; i < ShapeUtil::Rank(rhs); i++) { - if (i != rhs_contracted_dimension) { + if (i != rhs_contracting_dimension && rhs_batch_dims.count(i) == 0) { dimensions.push_back(rhs.dimensions(i)); } } - Shape result = ShapeUtil::MakeShape(lhs.element_type(), dimensions); + Shape result = ShapeUtil::MakeShape( + ShapeUtil::HigherPrecisionElementType(lhs, rhs), dimensions); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(result)); VLOG(2) << "inferred dot shape: " << ShapeUtil::HumanString(result); @@ -635,7 +776,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( ShapeUtil::HumanString(rhs).c_str()); } } - return ShapeUtil::MakeShape(lhs.element_type(), output_dimensions); + return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs), + output_dimensions); } /* static */ StatusOr ShapeInference::InferInDimBroadcastShape( @@ -697,6 +839,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( // specified in broadcast_dimensions are then changed to match the // corresponding dimension size in smaller_shape. 
Shape output_shape(larger_shape); + output_shape.set_element_type( + ShapeUtil::HigherPrecisionElementType(larger_shape, smaller_shape)); for (int i = 0; i < smaller_shape.dimensions_size(); ++i) { int64 dimension_to_match = broadcast_dimensions.at(i); @@ -746,7 +890,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( TF_RETURN_IF_ERROR( ExpectNotTupleOrOpaque(rhs, "rhs of elementwise binary operation")); - if (!ShapeUtil::SameElementType(lhs, rhs)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return InvalidArgument( "binary op %s with different element types: %s and %s", BinaryOperation_Name(operation).c_str(), @@ -765,10 +909,11 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } } - if (ShapeUtil::Compatible(lhs, rhs)) { + if (ShapeUtil::CompatibleIgnoringFpPrecision(lhs, rhs)) { // If the shapes are the same other than layout, the output shape is the // same (elementwise op). - return lhs; + return ShapeUtil::ChangeElementType( + lhs, ShapeUtil::HigherPrecisionElementType(lhs, rhs)); } if (ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)) { @@ -805,7 +950,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( "inferring shape for <%s>(%s, %s) with broadcast_dimensions={%s}", BinaryOperation_Name(operation).c_str(), ShapeUtil::HumanString(lhs).c_str(), ShapeUtil::HumanString(rhs).c_str(), - tensorflow::str_util::Join(broadcast_dimensions, ", ").c_str()); + Join(broadcast_dimensions, ", ").c_str()); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(lhs)); TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(rhs)); @@ -816,8 +961,6 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( rhs, tensorflow::strings::StrCat("rhs of binary operation ", BinaryOperation_Name(operation)))); switch (operation) { - case BINOP_DOT: - return InferDotOpShape(lhs, rhs); case BINOP_MAX: case BINOP_MIN: case BINOP_SUB: @@ -843,7 +986,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( TF_ASSIGN_OR_RETURN(const Shape& 
shape, InferElementwiseBinaryOpShape(operation, lhs, rhs, broadcast_dimensions)); - if (lhs.element_type() == F32) { + if (lhs.element_type() == F32 && rhs.element_type() == F32) { return ShapeUtil::ChangeElementType(shape, C64); } else { return Unimplemented("complex component type not supported"); @@ -948,12 +1091,13 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( TF_RETURN_IF_ERROR( ExpectNotTupleOrOpaque(*arg_shapes[i], "operand of map")); - if (ShapeUtil::Compatible(*arg_shapes[i], *arg_shape)) { + if (ShapeUtil::CompatibleIgnoringFpPrecision(*arg_shapes[i], *arg_shape)) { continue; } if (!ShapeUtil::IsTuple(*arg_shapes[i]) && !ShapeUtil::IsTuple(*arg_shape) && - ShapeUtil::SameElementType(*arg_shapes[i], *arg_shape)) { + ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shapes[i], + *arg_shape)) { if (ShapeUtil::IsScalar(*arg_shapes[i])) { continue; } @@ -970,7 +1114,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "Map operation requires all operands to have the same shape; got: " "%s", - tensorflow::str_util::Join(pieces, ", ").c_str()); + Join(pieces, ", ").c_str()); } // Check that dimensions.size == arg_shape.dimensions_size() (we currently @@ -987,7 +1131,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( if (dimensions[i] != i) { return InvalidArgument( "Map requires monotonically increasing dimension numbers, found: %s ", - tensorflow::str_util::Join(dimensions, ", ").c_str()); + Join(dimensions, ", ").c_str()); } } @@ -1018,7 +1162,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( i, ShapeUtil::HumanString(parameter_shape).c_str()); } - if (parameter_shape.element_type() != arg_shape->element_type()) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(parameter_shape, + *arg_shape)) { return InvalidArgument( "mapped computation's parameter type has to match argument element " "type; got parameter %d shape: %s, argument shape: %s", @@ -1091,7 +1236,8 @@ 
ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(operand_shape.element_type()).c_str()); } - if (!ShapeUtil::SameElementType(offset_shape, operand_shape)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape, + operand_shape)) { return InvalidArgument( "The inputs should have the same element type for batch-norm-training, " "but the shape of offset factor is %s " @@ -1100,7 +1246,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(operand_shape.element_type()).c_str()); } - if (!ShapeUtil::SameElementType(scale_shape, operand_shape)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape, + operand_shape)) { return InvalidArgument( "The inputs should have the same element type for batch-norm-training, " "but the shape of scale factor is %s " @@ -1199,7 +1346,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(operand_shape.element_type()).c_str()); } - if (!ShapeUtil::SameElementType(offset_shape, operand_shape)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(offset_shape, + operand_shape)) { return InvalidArgument( "The inputs should have the same element type for " "batch-norm-inference, " @@ -1209,7 +1357,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(operand_shape.element_type()).c_str()); } - if (!ShapeUtil::SameElementType(scale_shape, operand_shape)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape, + operand_shape)) { return InvalidArgument( "The inputs should have the same element type for " "batch-norm-inference, " @@ -1219,7 +1368,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(operand_shape.element_type()).c_str()); } - if (!ShapeUtil::SameElementType(mean_shape, operand_shape)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape, + operand_shape)) { return InvalidArgument( "The inputs should have the same element type for " 
"batch-norm-inference, " @@ -1229,7 +1379,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(operand_shape.element_type()).c_str()); } - if (!ShapeUtil::SameElementType(variance_shape, operand_shape)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(variance_shape, + operand_shape)) { return InvalidArgument( "The inputs should have the same element type for " "batch-norm-inference, " @@ -1351,7 +1502,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(output_grad_shape.element_type()).c_str()); } - if (!ShapeUtil::SameElementType(output_grad_shape, operand_shape)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(output_grad_shape, + operand_shape)) { return InvalidArgument( "The inputs should have the same element type for batch-norm-grad, " "but the element type of output_grad is %s " @@ -1360,7 +1512,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(operand_shape.element_type()).c_str()); } - if (!ShapeUtil::SameElementType(scale_shape, operand_shape)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(scale_shape, + operand_shape)) { return InvalidArgument( "The inputs should have the same element type for batch-norm-grad, " "but the element type of scale factor is %s " @@ -1369,7 +1522,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(operand_shape.element_type()).c_str()); } - if (!ShapeUtil::SameElementType(mean_shape, operand_shape)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(mean_shape, + operand_shape)) { return InvalidArgument( "The inputs should have the same element type for batch-norm-grad, " "but the element type of mean is %s " @@ -1378,7 +1532,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( PrimitiveType_Name(operand_shape.element_type()).c_str()); } - if (!ShapeUtil::SameElementType(var_shape, operand_shape)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(var_shape, + operand_shape)) { 
return InvalidArgument( "The inputs should have the same element type for batch-norm-grad, " "but the element type of mean is %s " @@ -1439,7 +1594,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of convolution")); TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of convolution")); - if (!ShapeUtil::SameElementType(lhs, rhs)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(lhs, rhs)) { return InvalidArgument( "Convolution with different element types: %s and %s", ShapeUtil::HumanString(lhs).c_str(), @@ -1584,15 +1739,107 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( dimensions[dnums.output_spatial_dimensions(i)] = window_output_shape.dimensions(i); } + return ShapeUtil::MakeShape(ShapeUtil::HigherPrecisionElementType(lhs, rhs), + dimensions); +} - return ShapeUtil::MakeShape(lhs.element_type(), dimensions); +/* static */ StatusOr ShapeInference::InferFftShape( + const Shape& in, const FftType fft_type, + const tensorflow::gtl::ArraySlice fft_length) { + const int64 fft_rank = fft_length.size(); + if (fft_rank < 1 || fft_rank > 3) { + return InvalidArgument("FFT only supports ranks 1-3, but got %lld", + fft_rank); + } +#define RET_CHECK_RANK(x) \ + if (x.dimensions_size() < fft_rank) { \ + return InvalidArgument( \ + "FFT of rank %lld requires input of at least " \ + "same rank; got input of rank %d", \ + fft_rank, x.dimensions_size()); \ + } + switch (fft_type) { + case FFT: + case IFFT: + if (in.element_type() != C64) { + return InvalidArgument("%s requires C64 input type, found %s", + FftType_Name(fft_type).c_str(), + PrimitiveType_Name(in.element_type()).c_str()); + } + RET_CHECK_RANK(in); + return in; + case RFFT: { + if (in.element_type() != F32) { + return InvalidArgument("RFFT requires F32 input type, found %s", + PrimitiveType_Name(in.element_type()).c_str()); + } + RET_CHECK_RANK(in); + for (int i = 0; i < fft_rank; i++) { + if (in.dimensions(in.dimensions_size() - 
fft_rank + i) != + fft_length[i]) { + return InvalidArgument( + "RFFT requires innermost dimensions match fft_length but " + "dimension %lld is %lld and should be %lld", + in.dimensions_size() - fft_rank + i, + in.dimensions(in.dimensions_size() - fft_rank + i), + fft_length[i]); + } + } + Shape result = ShapeUtil::ChangeElementType(in, C64); + result.set_dimensions(result.dimensions_size() - 1, + fft_length[fft_rank - 1] / 2 + 1); + return result; + } + case IRFFT: { + if (in.element_type() != C64) { + return InvalidArgument("IRFFT requires C64 input type, found %s", + PrimitiveType_Name(in.element_type()).c_str()); + } + RET_CHECK_RANK(in); + Shape result = ShapeUtil::ComplexComponentShape(in); + for (int i = 0; i < fft_rank - 1; i++) { + if (in.dimensions(in.dimensions_size() - fft_rank + i) != + fft_length[i]) { + return InvalidArgument( + "IRFFT requires all but one innermost dimensions match " + "fft_length, but dimension %lld is %lld and should be %lld", + in.dimensions_size() - fft_rank + i, + in.dimensions(in.dimensions_size() - fft_rank + i), + fft_length[i]); + } + } + if (in.dimensions(in.dimensions_size() - 1) != + fft_length[fft_rank - 1] / 2 + 1) { + return InvalidArgument( + "IRFFT requires innermost dimension matches fft_length/2+1, but " + "dimension %d is %lld and should be %lld", + in.dimensions_size() - 1, in.dimensions(in.dimensions_size() - 1), + fft_length[fft_rank - 1] / 2 + 1); + } + result.set_dimensions(result.dimensions_size() - 1, + fft_length[fft_rank - 1]); + return result; + } + default: + LOG(FATAL) << "Unexpected fft_type: " << fft_type; + } +#undef RET_CHECK_RANK } /* static */ StatusOr ShapeInference::InferCrossReplicaSumShape( - const Shape& operand) { - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(operand, "operand of cross replica sum")); - return operand; + tensorflow::gtl::ArraySlice operand_shapes) { + for (const Shape* operand_shape : operand_shapes) { + TF_RETURN_IF_ERROR( + ExpectNotTupleOrOpaque(*operand_shape, 
"operand of cross replica sum")); + } + if (operand_shapes.size() == 1) { + return *operand_shapes[0]; + } + std::vector operand_shape_values; + for (const Shape* operand_shape : operand_shapes) { + operand_shape_values.push_back(*operand_shape); + } + return ShapeUtil::MakeTupleShape(operand_shape_values); } /* static */ StatusOr ShapeInference::InferReduceShape( @@ -1655,16 +1902,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( } const Shape& operand_element_shape = ShapeUtil::MakeShape(operand_shape.element_type(), {}); - if (!ShapeUtil::Compatible(operand_element_shape, - select_shape.parameters(0))) { + if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape, + select_shape.parameters(0))) { return InvalidArgument( "select function's first parameter shape currently must " "match the operand element shape. Got %s vs %s", ShapeUtil::HumanString(select_shape.parameters(0)).c_str(), ShapeUtil::HumanString(operand_element_shape).c_str()); } - if (!ShapeUtil::Compatible(operand_element_shape, - select_shape.parameters(1))) { + if (!ShapeUtil::CompatibleIgnoringFpPrecision(operand_element_shape, + select_shape.parameters(1))) { return InvalidArgument( "select function's second parameter shape currently must " "match the operand element shape. 
Got %s vs %s", @@ -1681,7 +1928,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( InferWindowOutputShape(operand_shape, window, operand_shape.element_type(), /*allow_negative_padding=*/false)); - if (!ShapeUtil::Compatible(source_shape, window_result_shape)) { + if (!ShapeUtil::CompatibleIgnoringFpPrecision(source_shape, + window_result_shape)) { return InvalidArgument( "source shape does not match the shape of window-reduced operand: " "source(%s), window-reduced operand(%s)", @@ -1695,21 +1943,28 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( const Shape& arg, tensorflow::gtl::ArraySlice starts, tensorflow::gtl::ArraySlice limits, tensorflow::gtl::ArraySlice strides) { + auto error = [&](const string& message) { + return InvalidArgument( + "%s in slice operation; argument shape: %s; starts: {%s}; limits: " + "{%s}; strides: {%s}", + message.c_str(), ShapeUtil::HumanString(arg).c_str(), + Join(starts, ",").c_str(), Join(limits, ",").c_str(), + Join(strides, ",").c_str()); + }; TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(arg, "operand of slice")); VLOG(2) << tensorflow::strings::Printf( "slicing shape %s starts={%s} limits={%s}", - ShapeUtil::HumanString(arg).c_str(), - tensorflow::str_util::Join(starts, ", ").c_str(), - tensorflow::str_util::Join(limits, ", ").c_str()); + ShapeUtil::HumanString(arg).c_str(), Join(starts, ", ").c_str(), + Join(limits, ", ").c_str()); if (starts.size() != limits.size()) { - return InvalidArgument("slice start and limit sizes differ: %zu vs %zu", - starts.size(), limits.size()); + return error(Printf("slice start and limit sizes differ: %zu vs %zu", + starts.size(), limits.size())); } if (starts.size() != strides.size()) { - return InvalidArgument("slice start and strides sizes differ: %zu vs %zu", - starts.size(), strides.size()); + return error(Printf("slice start and strides sizes differ: %zu vs %zu", + starts.size(), strides.size())); } if (starts.size() != ShapeUtil::Rank(arg)) { @@ -1728,20 +1983,20 @@ 
ShapeInference::InferDegenerateDimensionBroadcastShape( start_index); } if (limit_index > arg.dimensions(dimension)) { - return InvalidArgument( - "limit index (%lld) must be less than or equal to dimension " - "size (%lld)", - limit_index, arg.dimensions(dimension)); + return error( + Printf("limit index (%lld) must be less than or equal to dimension " + "size (%lld)", + limit_index, arg.dimensions(dimension))); } VLOG(2) << tensorflow::strings::Printf("starts[%lld] = %lld", dimension, start_index); VLOG(2) << tensorflow::strings::Printf("limits[%lld] = %lld", dimension, limit_index); if (start_index > limit_index) { - return InvalidArgument( - "limit index (%lld) must be greater or equal to " - "start index (%lld) in slice with positive stride", - limit_index, start_index); + return error( + Printf("limit index (%lld) must be greater or equal to " + "start index (%lld) in slice with positive stride", + limit_index, start_index)); } if (stride <= 0) { return InvalidArgument("stride (%lld) must be positive", stride); @@ -1764,7 +2019,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( "slicing shape %s at dynamic start_indices %s with slice_sizes={%s}", ShapeUtil::HumanString(operand_shape).c_str(), ShapeUtil::HumanString(start_indices_shape).c_str(), - tensorflow::str_util::Join(slice_sizes, ", ").c_str()); + Join(slice_sizes, ", ").c_str()); if (ShapeUtil::Rank(start_indices_shape) != 1) { return InvalidArgument( @@ -1857,7 +2112,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( ShapeUtil::Rank(update_shape), ShapeUtil::Rank(operand_shape)); } - if (operand_shape.element_type() != update_shape.element_type()) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(operand_shape, + update_shape)) { return InvalidArgument( "dynamic update slice update element type does not match argument. 
" "operand.element_type: %s vs update.element_type: %s", @@ -1958,6 +2214,64 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return init; } +/* static */ StatusOr ShapeInference::InferConditionalShape( + const Shape& predicate, const Shape& true_operand, + const Shape& false_operand, const ProgramShape& true_computation, + const ProgramShape& false_computation) { + if (!ShapeUtil::ShapeIs(predicate, PRED, {})) { + return InvalidArgument("predicate must be a boolean; got %s.", + ShapeUtil::HumanString(predicate).c_str()); + } + + if (true_computation.parameters_size() != 1) { + return InvalidArgument("true_computation must take 1 argument; got %d.", + true_computation.parameters_size()); + } + if (!ShapeUtil::Compatible(true_computation.parameters(0), true_operand)) { + auto true_shape_string = [&]() { + return tensorflow::strings::Printf( + "true_operand: %s; true_computation: %s", + ShapeUtil::HumanString(true_operand).c_str(), + ShapeUtil::HumanString(true_computation).c_str()); + }; + return InvalidArgument( + "true_operand must match the shape of the only parameter of " + "true_computation: got %s.", + true_shape_string().c_str()); + } + + if (false_computation.parameters_size() != 1) { + return InvalidArgument("false_computation must take 1 argument; got %d.", + false_computation.parameters_size()); + } + if (!ShapeUtil::Compatible(false_computation.parameters(0), false_operand)) { + auto false_shape_string = [&]() { + return tensorflow::strings::Printf( + "false_operand: %s; false_computation: %s", + ShapeUtil::HumanString(false_operand).c_str(), + ShapeUtil::HumanString(false_computation).c_str()); + }; + return InvalidArgument( + "false_operand must match the shape of the only parameter of " + "false_computation: got %s.", + false_shape_string().c_str()); + } + if (!ShapeUtil::Compatible(true_computation.result(), + false_computation.result())) { + auto shape_string = [&]() { + return tensorflow::strings::Printf( + "true_computation result: %s; 
false_computation result: %s.", + ShapeUtil::HumanString(true_computation.result()).c_str(), + ShapeUtil::HumanString(false_computation.result()).c_str()); + }; + return InvalidArgument( + "the result of true_computation and false_computation must have the " + "same shape: got %s.", + shape_string().c_str()); + } + return true_computation.result(); +} + /* static */ StatusOr ShapeInference::InferBroadcastShape( const Shape& operand, tensorflow::gtl::ArraySlice broadcast_sizes) { TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "operand of broadcast")); @@ -2003,8 +2317,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InvalidArgument( "Reshape dimensions [%s] are not a permutation of the operand " "dimensions (operand shape is %s).", - tensorflow::str_util::Join(dimensions, ",").c_str(), - ShapeUtil::HumanString(operand).c_str()); + Join(dimensions, ",").c_str(), ShapeUtil::HumanString(operand).c_str()); } return inferred_shape; @@ -2036,24 +2349,26 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(min, "clamp min")); TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(operand, "clamp operand")); TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(max, "clamp max")); - if (!ShapeUtil::SameElementType(min, operand) || - !ShapeUtil::SameElementType(max, operand)) { + if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(min, operand) || + !ShapeUtil::SameElementTypeIgnoringFpPrecision(max, operand)) { return InvalidArgument("clamp op with different operand types: %s, %s, %s", ShapeUtil::HumanString(min).c_str(), ShapeUtil::HumanString(operand).c_str(), ShapeUtil::HumanString(max).c_str()); } - if (((ShapeUtil::Compatible(min, operand) || ShapeUtil::IsScalar(min)) && - (ShapeUtil::Compatible(max, operand) || ShapeUtil::IsScalar(max)))) { + if (((ShapeUtil::CompatibleIgnoringFpPrecision(min, operand) || + ShapeUtil::IsScalar(min)) && + (ShapeUtil::CompatibleIgnoringFpPrecision(max, operand) || + 
ShapeUtil::IsScalar(max)))) { return operand; } if (ShapeUtil::IsScalar(operand)) { - if (ShapeUtil::Compatible(min, max)) { - return min; + if (ShapeUtil::CompatibleIgnoringFpPrecision(min, max)) { + return ShapeUtil::ChangeElementType(min, operand.element_type()); } else if (ShapeUtil::IsScalar(min)) { - return max; + return ShapeUtil::ChangeElementType(max, operand.element_type()); } else if (ShapeUtil::IsScalar(max)) { - return min; + return ShapeUtil::ChangeElementType(min, operand.element_type()); } } return Unimplemented( @@ -2066,7 +2381,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( // broadcast from all operands, not just the predicate. /* static */ StatusOr ShapeInference::InferSelectShape( const Shape& pred, const Shape& on_true, const Shape& on_false) { - if (!ShapeUtil::Compatible(on_true, on_false)) { + bool compatible; + if (ShapeUtil::IsTuple(on_true)) { + // Select only defines the top-level buffer, so if it's a tuple, the two + // inputs must match exactly. + compatible = ShapeUtil::Compatible(on_true, on_false); + } else { + compatible = ShapeUtil::CompatibleIgnoringFpPrecision(on_true, on_false); + } + if (!compatible) { return InvalidArgument( "operands to select must be the same shape; got %s and %s", ShapeUtil::HumanString(on_true).c_str(), @@ -2081,7 +2404,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( // By this stage we know that pred's element type is PRED. Therefore, this // check restricts pred to be a PRED scalar, or a PRED array with the same // dimensions as on_true and on_false. - return on_true; + return ShapeUtil::ChangeElementType( + on_true, ShapeUtil::HigherPrecisionElementType(on_true, on_false)); } else { return Unimplemented( "select operation with non-scalar predicate with dimensionality " @@ -2096,8 +2420,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( // The applied function's arity equals the number of arguments. 
if (arg_shapes.size() != to_apply.parameters_size()) { string computation_signature = ShapeUtil::HumanString(to_apply); - string argument_shapes = tensorflow::str_util::Join( - arg_shapes, ", ", [](string* out, const Shape* shape) { + string argument_shapes = + Join(arg_shapes, ", ", [](string* out, const Shape* shape) { tensorflow::strings::StrAppend(out, ShapeUtil::HumanString(*shape)); }); return InvalidArgument( diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 0aadb98a407c2160b60e686f6f3ea250bb9e838f..b39151ebbc19f5d0b702a80da5069f58c8dfb07d 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -109,8 +109,15 @@ class ShapeInference { const Shape& lhs, const Shape& rhs, const Window& window, const ConvolutionDimensionNumbers& dimension_numbers); - // Infers the shape produced a cross replica sum with the given operand shape. - static StatusOr InferCrossReplicaSumShape(const Shape& operand); + // Infers the shape produced by the given FFT type on the given operand. + static StatusOr InferFftShape( + const Shape& in, FftType fft_type, + tensorflow::gtl::ArraySlice fft_length); + + // Infers the shape produced by a cross replica sum with the given operand + // shapes. + static StatusOr InferCrossReplicaSumShape( + tensorflow::gtl::ArraySlice operand_shapes); // Infers the shape produced by applying the given reduction computation // shape to the given input operand shape. @@ -178,6 +185,12 @@ class ShapeInference { const ProgramShape& body, const Shape& init); + // Infers the shape produced by a conditional operation. + static StatusOr InferConditionalShape( + const Shape& predicate, const Shape& true_operand, + const Shape& false_operand, const ProgramShape& true_computation, + const ProgramShape& false_computation); + // Infers the shape produced by a broadcast operation. 
static StatusOr InferBroadcastShape( const Shape& operand, tensorflow::gtl::ArraySlice broadcast_sizes); @@ -229,11 +242,13 @@ class ShapeInference { tensorflow::gtl::ArraySlice arg_shapes, const ProgramShape& to_apply); - private: // Helper that infers the shape produced by performing a dot operation with // the given LHS and RHS shapes. - static StatusOr InferDotOpShape(const Shape& lhs, const Shape& rhs); + static StatusOr InferDotOpShape( + const Shape& lhs, const Shape& rhs, + const DotDimensionNumbers& dimension_numbers); + private: // Helper that infers the shape produced by performing an element-wise binary // operation with the given LHS and RHS shapes. // Note: By "element-wise" we mean operations that look at a single element in diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index be93c879c0b7fd74c3b93e28c6dc0f5c656a522a..026c021165785bd3945d6a846dae446ad45da9b7 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -898,8 +898,11 @@ TEST_F(ShapeInferenceTest, BroadcastScalar) { // scalar vector: error TEST_F(ShapeInferenceTest, ScalarDotVector) { + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_DOT, f32_, vector_32_, {}); + ShapeInference::InferDotOpShape(f32_, vector_32_, dot_dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), HasSubstr("dot only supports rank")); @@ -907,61 +910,199 @@ TEST_F(ShapeInferenceTest, ScalarDotVector) { // 3D 2D: error TEST_F(ShapeInferenceTest, DotWithRankHigherThanTwo) { - auto inferred_status = ShapeInference::InferBinaryOpShape( - BINOP_DOT, ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, {}); + DotDimensionNumbers dot_dnums; + 
dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto inferred_status = ShapeInference::InferDotOpShape( + ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, dot_dnums); ASSERT_FALSE(inferred_status.ok()); ASSERT_THAT(inferred_status.status().error_message(), - HasSubstr("dot only supports rank")); + HasSubstr("batch and contracting dimension number mismatch")); } // vector vector -> scalar TEST_F(ShapeInferenceTest, VectorDotVector) { + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); auto inferred_status = - ShapeInference::InferBinaryOpShape(BINOP_DOT, vector_64_, vector_64_, {}); + ShapeInference::InferDotOpShape(vector_64_, vector_64_, dot_dnums); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie())); auto inferred_status_mismatch = - ShapeInference::InferBinaryOpShape(BINOP_DOT, vector_64_, vector_32_, {}); + ShapeInference::InferDotOpShape(vector_64_, vector_32_, dot_dnums); ASSERT_FALSE(inferred_status_mismatch.ok()); } // matrix vector -> vector TEST_F(ShapeInferenceTest, MatrixDotVector) { - auto inferred_status = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_DOT, matrix_32_64_, vector_64_, {}); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto inferred_status = + ShapeInference::InferDotOpShape(matrix_32_64_, vector_64_, dot_dnums); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_32_)); - auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_DOT, matrix_32_64_, vector_32_, {}); + auto inferred_status_mismatch = + ShapeInference::InferDotOpShape(matrix_32_64_, vector_32_, dot_dnums); ASSERT_FALSE(inferred_status_mismatch.ok()); } // vector matrix -> vector TEST_F(ShapeInferenceTest, 
VectorDotMatrix) { - auto inferred_status = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_DOT, vector_32_, matrix_32_64_, {}); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); + auto inferred_status = + ShapeInference::InferDotOpShape(vector_32_, matrix_32_64_, dot_dnums); ASSERT_IS_OK(inferred_status.status()); ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_64_)); - auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_DOT, vector_64_, matrix_32_64_, {}); + auto inferred_status_mismatch = + ShapeInference::InferDotOpShape(vector_64_, matrix_32_64_, dot_dnums); ASSERT_FALSE(inferred_status_mismatch.ok()); } // matrix matrix -> matrix TEST_F(ShapeInferenceTest, MatrixDotMatrix) { - auto inferred_status_match = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_DOT, matrix_32_64_, matrix_64_48_, {}); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + auto inferred_status_match = + ShapeInference::InferDotOpShape(matrix_32_64_, matrix_64_48_, dot_dnums); ASSERT_IS_OK(inferred_status_match.status()); ASSERT_TRUE( ShapeUtil::Equal(inferred_status_match.ValueOrDie(), matrix_32_48_)) << "inferred: " << ShapeUtil::HumanString(inferred_status_match.ValueOrDie()) << " expected: " << ShapeUtil::HumanString(matrix_64_48_); - auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape( - BinaryOperation::BINOP_DOT, matrix_32_64_, matrix_32_64_, {}); + auto inferred_status_mismatch = + ShapeInference::InferDotOpShape(matrix_32_64_, matrix_32_64_, dot_dnums); ASSERT_FALSE(inferred_status_mismatch.ok()); } +// BatchMatMul with two batch dimensions and one contracting dimension. 
+TEST_F(ShapeInferenceTest, DotGeneral) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {5, 2, 11, 3}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {5, 2, 3, 14}); + Shape output_shape = ShapeUtil::MakeShape(F32, {5, 2, 11, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(3); + dot_dnums.add_lhs_batch_dimensions(0); + dot_dnums.add_lhs_batch_dimensions(1); + + dot_dnums.add_rhs_contracting_dimensions(2); + dot_dnums.add_rhs_batch_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(1); + + auto inferred_status_match = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ASSERT_IS_OK(inferred_status_match.status()); + ASSERT_TRUE( + ShapeUtil::Equal(inferred_status_match.ValueOrDie(), output_shape)) + << "inferred: " + << ShapeUtil::HumanString(inferred_status_match.ValueOrDie()) + << " expected: " << ShapeUtil::HumanString(output_shape); +} + +// BatchMatMul with two contracting dimensions fails. +TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsFails) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3, 2}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14}); + Shape output_shape = ShapeUtil::MakeShape(F32, {2, 11, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(2); + dot_dnums.add_lhs_contracting_dimensions(3); + dot_dnums.add_lhs_batch_dimensions(0); + + dot_dnums.add_rhs_contracting_dimensions(1); + dot_dnums.add_rhs_batch_dimensions(0); + + auto inferred_status = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("must specify one contracting dimension for both " + "lhs and rhs")); +} + +// BatchMatMul with different batch dimension sizes fails. 
+TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimSizesFails) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 3, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(2); + dot_dnums.add_lhs_batch_dimensions(0); + + dot_dnums.add_rhs_contracting_dimensions(1); + dot_dnums.add_rhs_batch_dimensions(0); + + auto inferred_status = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("batch dimension numbers and sizes must match")); +} + +// BatchMatMul with different batch dimension numbers fails. +TEST_F(ShapeInferenceTest, DotWithMisatchedBatchDimNumbersFails) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 2, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(2); + dot_dnums.add_lhs_batch_dimensions(0); + + dot_dnums.add_rhs_contracting_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(1); + + auto inferred_status = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("batch dimension numbers must precede non-batch")); +} + +// BatchMatMul with out-of-range dimension numbers fails. 
+TEST_F(ShapeInferenceTest, DotWithContractingDimNumberOutOfRange) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(3); + dot_dnums.add_lhs_batch_dimensions(0); + + dot_dnums.add_rhs_contracting_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(1); + + auto inferred_status = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("A dimension number is out of range")); +} + +// BatchMatMul with non-unique dimension numbers fails. +TEST_F(ShapeInferenceTest, DotWithContractingNonUniqueDimNumber) { + Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14}); + + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_lhs_batch_dimensions(0); + + dot_dnums.add_rhs_contracting_dimensions(0); + dot_dnums.add_rhs_batch_dimensions(1); + + auto inferred_status = + ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums); + ASSERT_FALSE(inferred_status.ok()); + ASSERT_THAT(inferred_status.status().error_message(), + HasSubstr("A dimension number is not unique")); +} + TEST_F(ShapeInferenceTest, BinOpBroadcastMatrixVector) { // Test variations of broadcasting a vector for a binary add with a // matrix. 
@@ -1296,5 +1437,95 @@ TEST_F(ShapeInferenceTest, Transpose) { ShapeUtil::MakeShape(F32, {3, 4, 5, 2}))); } +TEST_F(ShapeInferenceTest, Conditional) { + auto inferred_status0 = ShapeInference::InferConditionalShape( + pred_, vector_32_, vector_64_, + ShapeUtil::MakeProgramShape({vector_32_}, f32_), + ShapeUtil::MakeProgramShape({vector_64_}, f32_)); + EXPECT_IS_OK(inferred_status0.status()); + EXPECT_TRUE(ShapeUtil::Equal(f32_, inferred_status0.ValueOrDie())); + + auto inferred_status1 = ShapeInference::InferConditionalShape( + pred_, matrix_32_48_, vector_32_, + ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_64_), + ShapeUtil::MakeProgramShape({vector_32_}, vector_64_)); + EXPECT_IS_OK(inferred_status1.status()); + EXPECT_TRUE(ShapeUtil::Equal(vector_64_, inferred_status1.ValueOrDie())); + + auto tuple_f32_v32 = ShapeUtil::MakeTupleShape({f32_, vector_32_}); + auto inferred_status2 = ShapeInference::InferConditionalShape( + pred_, matrix_32_48_, tuple_f32_v32, + ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_), + ShapeUtil::MakeProgramShape({tuple_f32_v32}, vector_32_)); + EXPECT_IS_OK(inferred_status2.status()); + EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status2.ValueOrDie())); + + auto inferred_status_error0 = ShapeInference::InferConditionalShape( + s32_, vector_32_, vector_64_, + ShapeUtil::MakeProgramShape({vector_32_}, f32_), + ShapeUtil::MakeProgramShape({vector_64_}, f32_)); + EXPECT_FALSE(inferred_status_error0.ok()); + EXPECT_THAT(inferred_status_error0.status().error_message(), + HasSubstr("predicate must be a boolean")); + + auto inferred_status_error1 = ShapeInference::InferConditionalShape( + pred_, ShapeUtil::MakeTupleShape({f32_, vector_32_}), matrix_32_48_, + ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_), + ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_)); + EXPECT_FALSE(inferred_status_error1.ok()); + EXPECT_THAT(inferred_status_error1.status().error_message(), + HasSubstr("true_computation must take 1 
argument")); + + auto inferred_status_error2 = ShapeInference::InferConditionalShape( + pred_, vector_32_, vector_64_, + ShapeUtil::MakeProgramShape({vector_64_}, f32_), + ShapeUtil::MakeProgramShape({vector_64_}, f32_)); + EXPECT_FALSE(inferred_status_error2.ok()); + EXPECT_THAT(inferred_status_error2.status().error_message(), + HasSubstr("true_operand must match the shape of the only " + "parameter of true_computation")); + + auto inferred_status_error3 = ShapeInference::InferConditionalShape( + pred_, matrix_32_48_, ShapeUtil::MakeTupleShape({f32_, vector_32_}), + ShapeUtil::MakeProgramShape({matrix_32_48_}, vector_32_), + ShapeUtil::MakeProgramShape({f32_, vector_32_}, vector_32_)); + EXPECT_FALSE(inferred_status_error3.ok()); + EXPECT_THAT(inferred_status_error3.status().error_message(), + HasSubstr("false_computation must take 1 argument")); + + auto inferred_status_error4 = ShapeInference::InferConditionalShape( + pred_, vector_32_, vector_64_, + ShapeUtil::MakeProgramShape({vector_32_}, f32_), + ShapeUtil::MakeProgramShape({vector_32_}, f32_)); + EXPECT_FALSE(inferred_status_error4.ok()); + EXPECT_THAT(inferred_status_error4.status().error_message(), + HasSubstr("false_operand must match the shape of the only " + "parameter of false_computation")); + + auto inferred_status_error5 = ShapeInference::InferConditionalShape( + pred_, vector_32_, vector_64_, + ShapeUtil::MakeProgramShape({vector_32_}, f32_), + ShapeUtil::MakeProgramShape({vector_64_}, vector_32_)); + EXPECT_FALSE(inferred_status_error5.ok()); + EXPECT_THAT(inferred_status_error5.status().error_message(), + HasSubstr("the result of true_computation and false_computation " + "must have the same shape")); +} + +TEST_F(ShapeInferenceTest, BadSlice) { + auto arg = ShapeUtil::MakeShape(F32, {4}); + StatusOr statusor = + ShapeInference::InferSliceShape(arg, {0}, {5}, {1}); + ASSERT_FALSE(statusor.ok()); + + LOG(INFO) << statusor.status(); + + EXPECT_THAT(statusor.status().error_message(), + 
HasSubstr("less than or equal to dimension size")) + << statusor.status(); + EXPECT_THAT(statusor.status().error_message(), HasSubstr("argument shape")) + << statusor.status(); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc index a7539a1a11d2bbd62c780890c6730dbb212307c4..c679d401c3691b14a43ce77cbe953cd4c64a9e92 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer.cc @@ -34,58 +34,32 @@ namespace xla { using ::tensorflow::strings::Appendf; -/* static */ StatusOr> -ShapedBuffer::MakeArrayShapedBuffer(const Shape& shape, - const se::Platform* platform, - int device_ordinal, - const se::DeviceMemoryBase& buffer) { - if (ShapeUtil::IsTuple(shape)) { - return InvalidArgument("Shape must be an array: %s", - ShapeUtil::HumanStringWithLayout(shape).c_str()); - } - auto shaped_buffer = - MakeUnique(shape, platform, device_ordinal); - *shaped_buffer->mutable_shape_index_to_buffer_entry()->mutable_element({}) = - 0; - *shaped_buffer->mutable_buffers() = {buffer}; - return std::move(shaped_buffer); -} - -ShapedBuffer::ShapedBuffer(const Shape& shape, const se::Platform* platform, - int device_ordinal) - : shape_(shape), +ShapedBuffer::ShapedBuffer(const Shape& on_host_shape, + const Shape& on_device_shape, + const se::Platform* platform, int device_ordinal) + : on_host_shape_(on_host_shape), + on_device_shape_(on_device_shape), platform_(platform), device_ordinal_(device_ordinal), - shape_index_to_buffer_entry_(shape) {} + buffers_(on_device_shape) {} void ShapedBuffer::clear() { - for (se::DeviceMemoryBase& memory_base : buffers_) { + for (auto& pair : buffers_) { // A default constructed DeviceMemoryBase is a null pointer. 
- memory_base = se::DeviceMemoryBase(); + pair.second = se::DeviceMemoryBase(); } } -void ShapedBuffer::AddBufferAtIndex( - const perftools::gputools::DeviceMemoryBase& buffer, - const ShapeIndex& shape_index) { - *mutable_shape_index_to_buffer_entry()->mutable_element(shape_index) = - buffers().size(); - mutable_buffers()->push_back(buffer); -} - -const se::DeviceMemoryBase& ShapedBuffer::buffer( - const ShapeIndex& index) const { - return buffers_[shape_index_to_buffer_entry_.element(index)]; -} - -se::DeviceMemoryBase* ShapedBuffer::mutable_buffer(const ShapeIndex& index) { - return &buffers_[shape_index_to_buffer_entry_.element(index)]; -} - string ShapedBuffer::ToString() const { - string s = "ShapedBuffer(" + platform_->Name() + "):\n"; + string s = tensorflow::strings::StrCat( + "ShapedBuffer(", platform_->Name(), ":", device_ordinal(), + "), on-host shape=" + ShapeUtil::HumanStringWithLayout(on_host_shape()), + ", on-device shape=" + + ShapeUtil::HumanStringWithLayout(on_device_shape()), + ":\n"); ShapeUtil::ForEachSubshape( - shape(), [this, &s](const Shape& subshape, const ShapeIndex& index) { + on_device_shape(), + [this, &s](const Shape& subshape, const ShapeIndex& index) { string shape_str; if (ShapeUtil::IsTuple(subshape)) { shape_str = "tuple"; @@ -105,53 +79,24 @@ std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer) { return out; } -/* static */ StatusOr> -ScopedShapedBuffer::Allocate( - const Shape& shape, DeviceMemoryAllocator* allocator, int device_ordinal, - const std::function& shape_size_fn) { - if (!LayoutUtil::HasLayout(shape)) { - return InvalidArgument("Shape must have a layout: %s", - ShapeUtil::HumanStringWithLayout(shape).c_str()); - } - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape)); - auto shaped_buffer = - WrapUnique(new ScopedShapedBuffer(shape, allocator, device_ordinal)); - - // Allocate an appropriate sized buffer for each element in the shape - // including the tuple pointer arrays. 
- for (auto& pair : shaped_buffer->shape_index_to_buffer_entry_) { - const ShapeIndex& index = pair.first; - size_t& buffer_entry = pair.second; - TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase memory_base, - shaped_buffer->allocator_->Allocate( - shaped_buffer->device_ordinal(), - shape_size_fn(ShapeUtil::GetSubshape( - shaped_buffer->shape(), index)))); - shaped_buffer->buffers_.push_back(memory_base); - buffer_entry = shaped_buffer->buffers_.size() - 1; - } - - return std::move(shaped_buffer); -} - /* static */ StatusOr> ScopedShapedBuffer::MakeScoped( ShapedBuffer* shaped_buffer, DeviceMemoryAllocator* allocator) { auto scoped_buffer = WrapUnique(new ScopedShapedBuffer( - shaped_buffer->shape(), allocator, shaped_buffer->device_ordinal())); + shaped_buffer->on_host_shape(), shaped_buffer->on_device_shape(), + allocator, shaped_buffer->device_ordinal())); scoped_buffer->buffers_ = shaped_buffer->buffers(); - scoped_buffer->shape_index_to_buffer_entry_ = - shaped_buffer->shape_index_to_buffer_entry(); - shaped_buffer->clear(); return std::move(scoped_buffer); } -ScopedShapedBuffer::ScopedShapedBuffer(const Shape& shape, +ScopedShapedBuffer::ScopedShapedBuffer(const Shape& on_host_shape, + const Shape& on_device_shape, DeviceMemoryAllocator* allocator, int device_ordinal) - : ShapedBuffer(shape, allocator->platform(), device_ordinal), + : ShapedBuffer(on_host_shape, on_device_shape, allocator->platform(), + device_ordinal), allocator_(allocator) {} ScopedShapedBuffer::~ScopedShapedBuffer() { @@ -159,7 +104,8 @@ ScopedShapedBuffer::~ScopedShapedBuffer() { // in the shape (eg, a tuple with a repeated element) so keep track of what // has been deallocated. 
std::set deallocated_opaques; - for (se::DeviceMemoryBase& memory_base : buffers_) { + for (auto& pair : buffers_) { + se::DeviceMemoryBase& memory_base = pair.second; if (!memory_base.is_null() && deallocated_opaques.count(memory_base.opaque()) == 0) { deallocated_opaques.insert(memory_base.opaque()); @@ -170,13 +116,10 @@ ScopedShapedBuffer::~ScopedShapedBuffer() { } std::unique_ptr ScopedShapedBuffer::release() { - auto shaped_buffer = - MakeUnique(shape(), platform(), device_ordinal()); - - *shaped_buffer->mutable_buffers() = buffers(); - *shaped_buffer->mutable_shape_index_to_buffer_entry() = - shape_index_to_buffer_entry(); + auto shaped_buffer = MakeUnique( + on_host_shape(), on_device_shape(), platform(), device_ordinal()); + shaped_buffer->buffers() = buffers(); clear(); return shaped_buffer; diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h index fa88caa13ff734995e8ab0925f17d0d3c26b8fda..d397e47d2ca734458c7dc99baa5c81b16d0fd72b 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.h +++ b/tensorflow/compiler/xla/service/shaped_buffer.h @@ -31,61 +31,68 @@ limitations under the License. namespace xla { // Class which encapsulates a buffer or set of buffers containing data of a -// particular XLA shape. Used for zero-copy execution interface for a -// XLA client running in the same process as the service (LocalClient), +// particular XLA shape. class ShapedBuffer { public: - // Convenience method which creates a ShapedBuffer of array shape (not a - // tuple). Its single buffer pointer is set to the given value "buffer". The - // given buffer must be large enough to store the given shape as given by - // ShapeUtil::ByteSizeOf. 
- static StatusOr> MakeArrayShapedBuffer( - const Shape& shape, const perftools::gputools::Platform* platform, - int device_ordinal, const perftools::gputools::DeviceMemoryBase& buffer); - - ShapedBuffer(const Shape& shape, + // Construct a ShapedBuffer with null DeviceMemoryBases at each index. The + // shape of the data on the host and the device may differ because the device + // may have a different representation for different data types. Therefore, + // both the on-host and on-device shape are required. The on-device shape + // determines the number of device allocations (DeviceMemoryBase) held by the + // ShapedBuffer. + ShapedBuffer(const Shape& on_host_shape, const Shape& on_device_shape, const perftools::gputools::Platform* platform, int device_ordinal); - const Shape& shape() const { return shape_; } + // Returns the shape of the on-host representation of the data held by this + // ShapedBuffer. + const Shape& on_host_shape() const { return on_host_shape_; } + + // Returns the shape of the on-device representation of the data held by this + // ShapedBuffer. + const Shape& on_device_shape() const { return on_device_shape_; } + const perftools::gputools::Platform* platform() const { return platform_; } int device_ordinal() const { return device_ordinal_; } + // Return the root buffer of the shape (shape index {}). + const perftools::gputools::DeviceMemoryBase& root_buffer() const { + return buffer(/*index=*/{}); + } + // Returns the buffer at the given shape index where index is defined as in // ShapeUtil::GetSubshape. const perftools::gputools::DeviceMemoryBase& buffer( - const ShapeIndex& index) const; - perftools::gputools::DeviceMemoryBase* mutable_buffer( - const ShapeIndex& index); - - // Returns the underlying structure which stores the buffer pointers. 
- const std::vector& buffers() const { - return buffers_; + const ShapeIndex& index) const { + return buffers_.element(index); } - std::vector* mutable_buffers() { - return &buffers_; + + // Sets the device memory buffer at the given index. + void set_buffer(const perftools::gputools::DeviceMemoryBase& buffer, + const ShapeIndex& index) { + *buffers_.mutable_element(index) = buffer; } - // Returns the tree of indices which map to buffer pointers. - const ShapeTree& shape_index_to_buffer_entry() const { - return shape_index_to_buffer_entry_; + // Returns the underlying ShapeTree containing all the device addresses in the + // ShapedBuffer. + const ShapeTree& buffers() const { + return buffers_; } - ShapeTree* mutable_shape_index_to_buffer_entry() { - return &shape_index_to_buffer_entry_; + ShapeTree& buffers() { + return buffers_; } // Set all device memory pointers in the object to null. void clear(); - // Adds a new buffer at the given shape index. - void AddBufferAtIndex(const perftools::gputools::DeviceMemoryBase& buffer, - const ShapeIndex& shape_index); - string ToString() const; protected: - // The shape of the device buffer with layout. - const Shape shape_; + // The shape of the data when represented on the host. + const Shape on_host_shape_; + + // The shape of the data on the device. + const Shape on_device_shape_; // The platform the memory is allocated on. const perftools::gputools::Platform* platform_; @@ -93,14 +100,8 @@ class ShapedBuffer { // The device the memory is allocated on. const int device_ordinal_; - // The list of DeviceMemoryBase pointers representing this shape. - // Note that there can be a many to one relationship between tuple elements - // and buffers. To account for this, shape_index_to_buffer_entry_ allows us - // to make from a position in a shape to an index into this list. - std::vector buffers_; - - // The tree of indices into buffers_. - ShapeTree shape_index_to_buffer_entry_; + // The tree of device buffers. 
Its shape is on_device_shape(). + ShapeTree buffers_; }; std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer); @@ -110,20 +111,16 @@ std::ostream& operator<<(std::ostream& out, const ShapedBuffer& buffer); // destructed. class ScopedShapedBuffer : public ShapedBuffer { public: - // Return a newly allocated ScopedShapedBuffer of an arbitrary shape. Array - // buffers (leaves in the shape) are allocated and uninitialized. Tuple - // buffers (if any) are allocated and initialized to the backend-specific - // representation of an array of pointers to the tuple elements. - static StatusOr> Allocate( - const Shape& shape, DeviceMemoryAllocator* allocator, int device_ordinal, - const std::function& shape_size_fn); - // Takes a ShapedBuffer and returns a ScopedShapedBuffer which manages the // deallocation of the device memory held in the shaped buffer. All device // memory pointers in the given ShapedBuffer are set to null. static StatusOr> MakeScoped( ShapedBuffer* shaped_buffer, DeviceMemoryAllocator* allocator); + // Create a ScopedShapedBuffer with null DeviceMemoryBases at each index. + ScopedShapedBuffer(const Shape& on_host_shape, const Shape& on_device_shape, + DeviceMemoryAllocator* allocator, int device_ordinal); + // Return the allocator used to allocate the device memory held in this // ScopedShapedBuffer. 
DeviceMemoryAllocator* memory_allocator() const { return allocator_; } @@ -138,8 +135,6 @@ class ScopedShapedBuffer : public ShapedBuffer { virtual ~ScopedShapedBuffer(); protected: - ScopedShapedBuffer(const Shape& shape, DeviceMemoryAllocator* allocator, - int device_ordinal); ScopedShapedBuffer(const ScopedShapedBuffer&) = delete; void operator=(const ScopedShapedBuffer&) = delete; diff --git a/tensorflow/compiler/xla/service/source_map_util.cc b/tensorflow/compiler/xla/service/source_map_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..8cbaac7b3760717bcacb57adc8782a5755c0aa6d --- /dev/null +++ b/tensorflow/compiler/xla/service/source_map_util.cc @@ -0,0 +1,66 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/source_map_util.h" + +#include "tensorflow/compiler/xla/util.h" + +namespace xla { +namespace source_map_util { +namespace { + +Status InvalidParameterArgumentV(const OpMetadata& op_metadata, + const char* format, va_list args) { + string message; + tensorflow::strings::Appendv(&message, format, args); + if (!op_metadata.source_file().empty()) { + tensorflow::strings::Appendf(&message, " (%s:%d)", + op_metadata.source_file().c_str(), + op_metadata.source_line()); + } + return InvalidArgument("%s", message.c_str()); +} + +} // namespace + +Status InvalidParameterArgument(const OpMetadata& op_metadata, + const char* format, ...) { + va_list args; + va_start(args, format); + Status result = InvalidParameterArgumentV(op_metadata, format, args); + va_end(args); + return result; +} + +Status InvalidParameterArgument(Executable* executable, int parameter_number, + const char* format, ...) { + va_list args; + va_start(args, format); + if (executable != nullptr && executable->has_module()) { + const HloModule& module = executable->module(); + const HloComputation& computation = *module.entry_computation(); + HloInstruction* param = computation.parameter_instruction(parameter_number); + const OpMetadata& metadata = param->metadata(); + Status result = InvalidParameterArgumentV(metadata, format, args); + va_end(args); + return result; + } + Status result = InvalidArgumentV(format, args); + va_end(args); + return result; +} + +} // namespace source_map_util +} // namespace xla diff --git a/tensorflow/compiler/xla/service/source_map_util.h b/tensorflow/compiler/xla/service/source_map_util.h new file mode 100644 index 0000000000000000000000000000000000000000..a776d745f4e56ca4f3d2480740259832bbc85011 --- /dev/null +++ b/tensorflow/compiler/xla/service/source_map_util.h @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SOURCE_MAP_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SOURCE_MAP_UTIL_H_ + +#include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/status.h" +#include "tensorflow/core/platform/macros.h" + +namespace xla { +namespace source_map_util { + +// Creates an INVALID_ARGUMENT status with the given format string. +// +// Also, attempts to extract the OpMetadata for parameter_number on executable +// and append it to the status message for source mapping to user code. +// +// executable may be nullptr, but parameter_number should not be out of bounds +// or a CHECK-failure may occur. +Status InvalidParameterArgument(Executable* executable, int parameter_number, + const char* format, ...) + TF_PRINTF_ATTRIBUTE(3, 4); + +// As above, but takes the parameter metadata directly instead of extracting it +// from the executable. +Status InvalidParameterArgument(const OpMetadata& op_metadata, + const char* format, ...)
+ TF_PRINTF_ATTRIBUTE(2, 3); + +} // namespace source_map_util +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SOURCE_MAP_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc index d5f53ad56fb019d0ae7c27fc28706f05614ece68..2f36e2b16e0f2eed10aef811dd3cceeba6a5b8a9 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.cc +++ b/tensorflow/compiler/xla/service/transfer_manager.cc @@ -40,6 +40,45 @@ TransferManager::GetPlatformTransferManagers() { return r; } +Status TransferManager::TransferArrayToDevice( + perftools::gputools::StreamExecutor* executor, const Literal& literal, + const perftools::gputools::DeviceMemoryBase& dest) { + const Shape on_device_shape = HostShapeToDeviceShape(literal.shape()); + TF_RET_CHECK(ShapeUtil::IsArray(on_device_shape)) + << "On-device representation of " + << ShapeUtil::HumanString(literal.shape()) + << " is not an array: " << ShapeUtil::HumanString(on_device_shape); + if (dest.size() < GetByteSizeRequirement(on_device_shape)) { + return FailedPrecondition( + "Allocation on device not large enough for array: " + "%lld < %lld", + dest.size(), GetByteSizeRequirement(on_device_shape)); + } + ShapedBuffer shaped_buffer(/*on_host_shape=*/literal.shape(), on_device_shape, + executor->platform(), executor->device_ordinal()); + shaped_buffer.set_buffer(dest, /*index=*/{}); + return TransferLiteralToDevice(executor, literal, shaped_buffer); +} + +StatusOr> TransferManager::TransferArrayFromDevice( + perftools::gputools::StreamExecutor* executor, const Shape& shape, + const perftools::gputools::DeviceMemoryBase& source) { + TF_RET_CHECK(ShapeUtil::Equal(HostShapeToDeviceShape(shape), shape)) + << "Shape " << ShapeUtil::HumanString(shape) + << " has a differently shaped representation on-device: " + << ShapeUtil::HumanString(HostShapeToDeviceShape(shape)); + if (source.size() < GetByteSizeRequirement(shape)) { + return FailedPrecondition( + "Allocation on 
device not large enough for array: " + "%lld < %lld", + source.size(), GetByteSizeRequirement(shape)); + } + ShapedBuffer shaped_buffer(/*on_host_shape=*/shape, shape, + executor->platform(), executor->device_ordinal()); + shaped_buffer.set_buffer(source, /*index=*/{}); + return TransferLiteralFromDevice(executor, shaped_buffer); +} + /* static */ void TransferManager::RegisterTransferManager( se::Platform::Id platform_id, TransferManagerCreationFunction creation_function) { @@ -75,14 +114,12 @@ TransferManager::GetPlatformTransferManagers() { Status TransferManager::WriteTupleIndexTables( perftools::gputools::StreamExecutor* executor, const ShapedBuffer& device_buffer) { - VLOG(2) << "Writing tuple index tables to ShapedBuffer rooted at " - << device_buffer.buffer(/*index=*/{}).opaque() - << "; shape: " << ShapeUtil::HumanString(device_buffer.shape()); + VLOG(2) << "Writing tuple index tables for " << device_buffer; TF_RET_CHECK(executor->device_ordinal() == device_buffer.device_ordinal()); return ShapeUtil::ForEachSubshapeWithStatus( - device_buffer.shape(), + device_buffer.on_device_shape(), [&](const Shape& device_subshape, const ShapeIndex& index) -> Status { if (ShapeUtil::IsTuple(device_subshape)) { se::DeviceMemoryBase device_memory = device_buffer.buffer(index); @@ -97,7 +134,7 @@ Status TransferManager::WriteTupleIndexTables( elements.push_back(device_buffer.buffer(element_index)); element_index.pop_back(); } - return WriteTuplePointersToDevice(executor, elements, device_subshape, + return WriteSingleTupleIndexTable(executor, elements, device_subshape, &device_memory); } @@ -143,31 +180,43 @@ Status TransferManager::TransferBufferToDevice( return Status::OK(); } -StatusOr> -TransferManager::GatherBufferPointersFromTuple( - se::StreamExecutor* executor, const se::DeviceMemoryBase& source, - const Shape& shape) { - TF_RET_CHECK(ShapeUtil::IsTuple(shape)); - - std::set buffer_pointers; - buffer_pointers.insert(source); - - TF_ASSIGN_OR_RETURN(std::vector 
tuple_elements, - ShallowCopyTupleFromDevice(executor, source, shape)); - for (auto i = 0; i < tuple_elements.size(); ++i) { - const Shape& element_shape = shape.tuple_shapes(i); - if (ShapeUtil::IsTuple(element_shape)) { - TF_ASSIGN_OR_RETURN( - std::set buffer_pointers_in_element, - GatherBufferPointersFromTuple(executor, tuple_elements[i], - element_shape)); - buffer_pointers.insert(buffer_pointers_in_element.begin(), - buffer_pointers_in_element.end()); - } else { - buffer_pointers.insert(tuple_elements[i]); - } +StatusOr> TransferManager::AllocateShapedBuffer( + const Shape& on_host_shape, DeviceMemoryAllocator* allocator, + int device_ordinal) { + if (!LayoutUtil::HasLayout(on_host_shape)) { + return InvalidArgument( + "Shape must have a layout: %s", + ShapeUtil::HumanStringWithLayout(on_host_shape).c_str()); + } + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(on_host_shape)); + const Shape on_device_shape = HostShapeToDeviceShape(on_host_shape); + TF_RET_CHECK(LayoutUtil::HasLayout(on_device_shape)); + + auto shaped_buffer = WrapUnique(new ShapedBuffer( + on_host_shape, on_device_shape, allocator->platform(), device_ordinal)); + + // Allocate an appropriate sized buffer for each element in the shape + // including the tuple pointer arrays. 
+ for (auto& pair : shaped_buffer->buffers()) { + const ShapeIndex& index = pair.first; + se::DeviceMemoryBase& memory_base = pair.second; + const Shape& subshape = ShapeUtil::GetSubshape(on_device_shape, index); + TF_ASSIGN_OR_RETURN(memory_base, + allocator->Allocate(shaped_buffer->device_ordinal(), + GetByteSizeRequirement(subshape))); } - return std::move(buffer_pointers); + + return std::move(shaped_buffer); +} + +StatusOr> +TransferManager::AllocateScopedShapedBuffer(const Shape& on_host_shape, + DeviceMemoryAllocator* allocator, + int device_ordinal) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr unscoped_buffer, + AllocateShapedBuffer(on_host_shape, allocator, device_ordinal)); + return ScopedShapedBuffer::MakeScoped(unscoped_buffer.get(), allocator); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index fdc123e54eb7f754c12510bef551b98da01b585d..9f2b5c4aecf0b52f610171e0c2755de577b2bd9e 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -44,55 +44,47 @@ class TransferManager { // Returns the ID of the platform that this transfer manager acts on. virtual perftools::gputools::Platform::Id PlatformId() const = 0; - // Transfers the region into the provided literal using the provided - // executor. device_shape is the shape, including layout, of the data on the - // device, while literal_shape will be the shape for the literal. device_shape - // and literal_shape must be compatible, but need not have the same layout. - // TODO(b/66694934): Remove TransferLiteral* methods which accept bare - // DeviceMemoryBase. 
- virtual Status TransferLiteralFromDevice( - perftools::gputools::StreamExecutor* executor, - const perftools::gputools::DeviceMemoryBase& region, - const Shape& device_shape, const Shape& literal_shape, - Literal* literal) = 0; - - // Transfers the given literal into the provided region output parameter, - // using the given executor. - virtual Status TransferLiteralToDevice( - perftools::gputools::StreamExecutor* executor, const Literal& literal, - perftools::gputools::DeviceMemoryBase* region) = 0; - - // Transfers the data held in the given ShapedBuffer into the provided literal - // using the provided executor. literal_shape will be the shape for the - // literal. The shape of the ShapedBuffer and literal_shape must be - // compatible, but need not have the same layout. + // Returns the shape of the on-device representation for the given shape on + // the host. This is intended for use with ShapedBuffer where buffers are + // pre-allocated by the host, e.g. TransferLiteralToDevice, without the user + // needing to consider device-specific behaviors. + virtual Shape HostShapeToDeviceShape(const Shape& host_shape) const { + return host_shape; + } + + // Returns a literal containing the data held in the given ShapedBuffer. + // using the provided executor. The optional literal_shape will be the shape + // for the literal. The shape of the ShapedBuffer and + // DeviceShape(literal_shape) must be compatible, but need not have the same + // layout. virtual StatusOr> TransferLiteralFromDevice( perftools::gputools::StreamExecutor* executor, const ShapedBuffer& device_buffer) = 0; // Transfers the given literal into the previously allocated device memory - // represented by the given ShapedBuffer using the given executor. + // represented by the given ShapedBuffer using the given executor. 
The shape + // of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible, + // but need not have the same layout virtual Status TransferLiteralToDevice( perftools::gputools::StreamExecutor* executor, const Literal& literal, const ShapedBuffer& device_buffer) = 0; + // Convenience methods for transferring an array to or from the device at a + // known address. This avoids having to construct a ShapedBuffer just to + // transfer an array at a known address. + Status TransferArrayToDevice( + perftools::gputools::StreamExecutor* executor, const Literal& literal, + const perftools::gputools::DeviceMemoryBase& dest); + StatusOr> TransferArrayFromDevice( + perftools::gputools::StreamExecutor* executor, const Shape& shape, + const perftools::gputools::DeviceMemoryBase& source); + // Transfers the given literal into the Infeed interface of the device, // using the given executor. virtual Status TransferLiteralToInfeed( perftools::gputools::StreamExecutor* executor, const Literal& literal) = 0; - // Transfer a memory block of the given size from 'source' buffer to the - // Infeed interface of the device using the given executor. - // - // size is the size to transfer from source in bytes. - // - // source is the source data that must be in the target-dependent layout that - // the Infeed HLO used in the computation expects. - virtual Status TransferBufferToInfeed( - perftools::gputools::StreamExecutor* executor, int64 size, - const void* source) = 0; - // Transfers the given literal from the Outfeed interface of the device, // using the given executor. virtual Status TransferLiteralFromOutfeed( @@ -104,37 +96,26 @@ class TransferManager { tensorflow::gtl::ArraySlice executor) = 0; - // Shallow copy a tuple from the device and create a DeviceMemoryBase object - // for each element in the tuple. A DeviceMemoryBase object refers to the - // buffer containing the data of that element. The DeviceMemoryBase objects - // are returned as a vector. 
- virtual StatusOr> - ShallowCopyTupleFromDevice( - perftools::gputools::StreamExecutor* executor, - const perftools::gputools::DeviceMemoryBase& source, - const Shape& shape) = 0; - // Given an allocated ShapedBuffer, constructs the tuple index table(s) in // each buffer of the given ShapedBuffer corresponding to tuple shapes. If the // ShapedBuffer is array-shaped this method does nothing. Status WriteTupleIndexTables(perftools::gputools::StreamExecutor* executor, const ShapedBuffer& device_buffer); - // Returns all buffer pointers that the tuple `source` refers to. Unlike - // ShallowCopyTupleFromDevice, this function gather buffer pointers in nested - // tuples as well. Also, the returned DeviceMemoryBase objects are - // deduplicated. - StatusOr> - GatherBufferPointersFromTuple( - perftools::gputools::StreamExecutor* executor, - const perftools::gputools::DeviceMemoryBase& source, const Shape& shape); - // Determines the byte size requirement for the given shape on the underlying // architecture. This will be used to allocate an appropriately sized memory // region for a host-to-device transfer. virtual int64 GetByteSizeRequirement(const Shape& shape) const = 0; - typedef std::unique_ptr (*TransferManagerCreationFunction)(); + // Allocate a ShapedBuffer which can hold data with the given on-host + // shape. The on-device shape may be different as indicated by + // HostShapeToDeviceShape. + StatusOr> AllocateShapedBuffer( + const Shape& on_host_shape, DeviceMemoryAllocator* allocator, + int device_ordinal); + StatusOr> AllocateScopedShapedBuffer( + const Shape& on_host_shape, DeviceMemoryAllocator* allocator, + int device_ordinal); ///// // The TransferManager class also serves as a point to register objects for @@ -144,6 +125,7 @@ class TransferManager { // assumed to be a singleton, so no ownership is transferred. // // Precondition: a platform kind must not be registered more than once. 
+ typedef std::unique_ptr (*TransferManagerCreationFunction)(); static void RegisterTransferManager( perftools::gputools::Platform::Id platform_id, TransferManagerCreationFunction transfer_manager); @@ -154,6 +136,17 @@ class TransferManager { const perftools::gputools::Platform* platform); protected: + // Transfer a memory block of the given size from 'source' buffer to the + // Infeed interface of the device using the given executor. + // + // size is the size to transfer from source in bytes. + // + // source is the source data that must be in the target-dependent layout that + // the Infeed HLO used in the computation expects. + virtual Status TransferBufferToInfeed( + perftools::gputools::StreamExecutor* executor, int64 size, + const void* source) = 0; + // Transfer a memory block of the given size from the device source into the // 'destination' buffer. // @@ -172,10 +165,9 @@ class TransferManager { const void* source, perftools::gputools::DeviceMemoryBase* destination); // Writes the given device-memory pointers in 'elements' to the given region - // to construct a tuple in the platform-specific tuple representation. This - // can handle nested tuples as well. In the nested case, the element - // DeviceMemoryBase points to another array of pointers on the device. - virtual Status WriteTuplePointersToDevice( + // to construct a tuple index table in the platform-specific tuple + // representation. 
+ virtual Status WriteSingleTupleIndexTable( perftools::gputools::StreamExecutor* executor, tensorflow::gtl::ArraySlice elements, diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc index fb55d4e5433ce666a061256691ea08ee56fde396..83185ac49e9b7c386d10d1cbc4e20dcdfdfd6cae 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -42,7 +42,7 @@ TransposeFolding::OperandIndices CanFoldOperandsIntoDot( TransposeFolding::OperandIndices operand_set; for (int64 i = 0; i < dot.operand_count(); ++i) { auto& operand = *dot.operand(i); - if (operand.IsRank2Transpose() && operand.user_count() == 1) { + if (operand.IsRank2Transpose()) { operand_set.push_back(i); } } @@ -61,8 +61,7 @@ TransposeFolding::OperandIndices CanFoldOperandsIntoConvolution( TransposeFolding::OperandIndices operand_set; for (int64 i = 0; i < convolution.operand_count(); ++i) { auto& operand = *convolution.operand(i); - if (operand.opcode() == HloOpcode::kTranspose && - operand.user_count() == 1) { + if (operand.opcode() == HloOpcode::kTranspose) { operand_set.push_back(i); } } @@ -102,6 +101,10 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { auto& convolution = *pair.first; auto& operand_indices = pair.second; + if (operand_indices.empty()) { + return false; + } + const ConvolutionDimensionNumbers& dnums = convolution.convolution_dimension_numbers(); ConvolutionDimensionNumbers new_dnums = dnums; @@ -121,8 +124,9 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { transpose_dimensions[dnums.input_batch_dimension()]); new_dnums.set_input_feature_dimension( transpose_dimensions[dnums.input_feature_dimension()]); - for (const auto& spatial_dimension : dnums.input_spatial_dimensions()) { - CHECK_EQ(spatial_dimension, transpose_dimensions[spatial_dimension]); + for (auto& input_spatial_dimension : + 
*new_dnums.mutable_input_spatial_dimensions()) { + input_spatial_dimension = transpose_dimensions[input_spatial_dimension]; } new_lhs = &transpose_operand; } else { diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index 6ac32e88f1f4af4743990daecd6c1f66a4e32763..caa1a111ad880b9dee62c1c94e32e8275c196fbf 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -64,9 +64,12 @@ TEST_F(TransposeFoldingTest, FoldDotTranspose) { HloInstruction* transpose_y = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {3, 2}), y, {1, 0})); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {2, 2}), /*opcode=*/HloOpcode::kDot, - /*lhs=*/x, /*rhs=*/transpose_y)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 2}), /*lhs=*/x, + /*rhs=*/transpose_y, dot_dnums)); HloModule module("test_module"); HloComputation* entry_computation = @@ -104,9 +107,12 @@ TEST_F(TransposeFoldingTest, FoldDotTransposeConstant) { HloInstruction* transpose1 = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {2, 3}), const1, {1, 0})); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1, 3}), /*opcode=*/HloOpcode::kDot, - /*lhs=*/transpose0, /*rhs=*/transpose1)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + ShapeUtil::MakeShape(F32, {1, 3}), + /*lhs=*/transpose0, /*rhs=*/transpose1, dot_dnums)); HloModule module("test_module"); HloComputation* 
entry_computation = @@ -169,9 +175,12 @@ TEST_F(TransposeFoldingTest, FoldDotTransposeInWhile) { HloInstruction* transpose_y = builder.AddInstruction(HloInstruction::CreateTranspose( ShapeUtil::MakeShape(F32, {3, 2}), y, {1, 0})); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {2, 2}), /*opcode=*/HloOpcode::kDot, - /*lhs=*/x, /*rhs=*/transpose_y)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 2}), /*lhs=*/x, + /*rhs=*/transpose_y, dot_dnums)); HloModule module("test_module"); HloComputation* entry_computation = @@ -376,5 +385,69 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) { new_conv->convolution_dimension_numbers().output_spatial_dimensions(1)); } +// Test that a transpose of every dimension in the activations gets folded into +// convolution. 
+TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) { + auto builder = HloComputation::Builder("entry_computation"); + HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {3, 2, 1, 1}), + /*name=*/"x")); + HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), + /*name=*/"y")); + HloInstruction* transpose_x = + builder.AddInstruction(HloInstruction::CreateTranspose( + ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), x, {1, 0, 3, 2})); + auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers(); + Window window; + for (int i = 0; i < 2; ++i) { + WindowDimension* dim = window.add_dimensions(); + dim->set_padding_low(0); + dim->set_padding_high(0); + dim->set_base_dilation(1); + dim->set_window_dilation(1); + dim->set_stride(1); + dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i))); + } + StatusOr conv_shape = ShapeInference::InferConvolveShape( + transpose_x->shape(), y->shape(), window, dnums); + EXPECT_IS_OK(conv_shape); + HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve( + conv_shape.ValueOrDie(), transpose_x, y, window, dnums)); + + HloModule module("test_module"); + HloComputation* entry_computation = + module.AddEntryComputation(builder.Build(conv)); + FoldTranspose(&module); + + // Instructions after folding: x, y, and the convolution. 
+ std::unordered_set instruction_set( + entry_computation->instructions().begin(), + entry_computation->instructions().end()); + EXPECT_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; + EXPECT_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; + EXPECT_EQ(1, instruction_set.size()) + << "entry_computation should contain exactly 3 instructions."; + HloInstruction* new_conv = *instruction_set.begin(); + EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode()); + EXPECT_EQ(dnums.input_feature_dimension(), + new_conv->convolution_dimension_numbers().input_batch_dimension()); + EXPECT_EQ( + dnums.input_batch_dimension(), + new_conv->convolution_dimension_numbers().input_feature_dimension()); + EXPECT_EQ( + dnums.input_spatial_dimensions(0), + new_conv->convolution_dimension_numbers().input_spatial_dimensions(1)); + EXPECT_EQ( + dnums.input_spatial_dimensions(1), + new_conv->convolution_dimension_numbers().input_spatial_dimensions(0)); + EXPECT_EQ( + dnums.output_spatial_dimensions(0), + new_conv->convolution_dimension_numbers().output_spatial_dimensions(0)); + EXPECT_EQ( + dnums.output_spatial_dimensions(1), + new_conv->convolution_dimension_numbers().output_spatial_dimensions(1)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index 0c848566478a25d4862cb0698e029dacd71f7a6a..657a8fe09ae9df906d695f7f49df72500d611792 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -273,6 +273,16 @@ Status TuplePointsToAnalysis::HandleBitcast(HloInstruction* bitcast) { return Status::OK(); } +Status TuplePointsToAnalysis::HandleSlice(HloInstruction* slice) { + // A kSlice instruction aliases its operand if the backend lowers it to an + // in-place implementation. 
+ if (slice->IsInPlaceSlice()) { + CreateCopiedPointsToSet(slice, slice->operand(0)); + return Status::OK(); + } + return DefaultAction(slice); +} + Status TuplePointsToAnalysis::HandleRecvDone(HloInstruction* recv_done) { // RecvDone aliases its input (Recv) tuple element {0} to its output. PointsToSet& points_to_set = CreateEmptyPointsToSet(recv_done); @@ -427,10 +437,15 @@ bool TuplePointsToAnalysis::InstructionDefinesBufferAtIndex( Status TuplePointsToAnalysis::VerifyBuffer(const LogicalBuffer& buffer) const { if (!InstructionDefinesBufferAtIndex(buffer.instruction(), buffer.index())) { - return FailedPrecondition( - "LogicalBuffer %s is ill-defined: instruction %s does not define a " - "buffer at that index", - buffer.ToString().c_str(), buffer.instruction()->name().c_str()); + // kSlice ops that are lowered to an in-place version are expected to not + // define their output buffer. + if (buffer.instruction()->opcode() != HloOpcode::kSlice || + !buffer.instruction()->IsInPlaceSlice()) { + return FailedPrecondition( + "LogicalBuffer %s is ill-defined: instruction %s does not define a " + "buffer at that index", + buffer.ToString().c_str(), buffer.instruction()->name().c_str()); + } } if (buffer.id() < 0 || diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h index 8928de107eed8c40bbe2130e26fe83ca3802d2f6..c3743b150168ebcf1051050dc511e50c43108c4f 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h @@ -199,12 +199,10 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { StatusOr GetBufferDefinedAt( const HloInstruction* instruction, const ShapeIndex& index) const; - // Return a vector containing all BufferAliases of the given logical buffer - // This trivially includes the BufferAlias with same instruction and index as - // the logical buffer itself, so the returned vector is never empty. 
The - // buffer alias set is the inverse of the points-to set. That is, - // LogicalBuffer B is in the points-to set of instruction I at index N iff - // instruction I, index N is a BufferAlias of B. + // Return a (possibly empty) vector containing all BufferAliases of the given + // logical buffer. The buffer alias set is the inverse of the points-to set. + // That is, LogicalBuffer B is in the points-to set of instruction I at index + // N iff instruction I, index N is a BufferAlias of B. using BufferAliasVector = tensorflow::gtl::InlinedVector; const BufferAliasVector& GetBufferAliases(const LogicalBuffer& buffer) const; @@ -250,6 +248,7 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { Status HandleTuple(HloInstruction* tuple) override; Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleBitcast(HloInstruction* bitcast) override; + Status HandleSlice(HloInstruction* slice) override; Status HandleCopy(HloInstruction* copy) override; Status HandleRecvDone(HloInstruction* recv_done) override; Status HandleSend(HloInstruction* send) override; diff --git a/tensorflow/compiler/xla/service/tuple_util.cc b/tensorflow/compiler/xla/service/tuple_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..4a530bb0b20582b303f4af969514748b46fd5064 --- /dev/null +++ b/tensorflow/compiler/xla/service/tuple_util.cc @@ -0,0 +1,61 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/tuple_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace xla { + +/*static*/ HloInstruction* TupleUtil::ExtractPrefix(HloInstruction* input_tuple, + int64 elements) { + CHECK(ShapeUtil::IsTuple(input_tuple->shape())); + + HloComputation* computation = input_tuple->parent(); + const Shape& input_shape = input_tuple->shape(); + + std::vector tuple_elements; + tuple_elements.reserve(elements); + for (int i = 0; i < elements; i++) { + tuple_elements.push_back( + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + input_shape.tuple_shapes(i), input_tuple, i))); + } + + return computation->AddInstruction( + HloInstruction::CreateTuple(tuple_elements)); +} + +/*static*/ HloInstruction* TupleUtil::AppendSuffix( + HloInstruction* input_tuple, + tensorflow::gtl::ArraySlice trailing_values) { + CHECK(ShapeUtil::IsTuple(input_tuple->shape())); + + HloComputation* computation = input_tuple->parent(); + const Shape& input_shape = input_tuple->shape(); + std::vector tuple_elements; + tuple_elements.reserve(input_shape.tuple_shapes_size()); + for (int i = 0; i < input_shape.tuple_shapes_size(); i++) { + tuple_elements.push_back( + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + input_shape.tuple_shapes(i), input_tuple, i))); + } + tuple_elements.insert(tuple_elements.end(), trailing_values.begin(), + trailing_values.end()); + return computation->AddInstruction( + HloInstruction::CreateTuple(tuple_elements)); +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_util.h b/tensorflow/compiler/xla/service/tuple_util.h new file mode 100644 index 
0000000000000000000000000000000000000000..e5ff9aaa8357fe8e4777d6dee37bbec72e144c06 --- /dev/null +++ b/tensorflow/compiler/xla/service/tuple_util.h @@ -0,0 +1,45 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_UTIL_H_ + +#include "tensorflow/compiler/xla/service/hlo_instruction.h" + +namespace xla { +class TupleUtil { + public: + // Generates HLO instructions to get a prefix tuple from `input_tuple` (which + // must be of tuple shape) of length `elements`. Returns the root of the + // graph of instructions generated. + // + // The instructions are generated into the computation containing + // `input_tuple`. + static HloInstruction* ExtractPrefix(HloInstruction* input_tuple, + int64 elements); + + // Generates HLO instructions to create a tuple that consists of the values in + // `trailing_values` appended to `input_tuple` (which must be of tuple shape). + // Returns the root of the graph of instructions generated. + // + // The instructions are generated into the computation containing + // `input_tuple`. 
+ static HloInstruction* AppendSuffix( + HloInstruction* input_tuple, + tensorflow::gtl::ArraySlice trailing_values); +}; +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/tuple_util_test.cc b/tensorflow/compiler/xla/service/tuple_util_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..754fd8ef169231827eeb5bfd72aeb596644ca767 --- /dev/null +++ b/tensorflow/compiler/xla/service/tuple_util_test.cc @@ -0,0 +1,81 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/tuple_util.h" + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" + +namespace xla { +namespace { + +namespace op = ::xla::testing::opcode_matchers; + +StatusOr> GetParsedModule( + HloComputation** entry_computation, HloInstruction** param0, + HloInstruction** param1) { + const char* const hlo_string = R"( +HloModule Module + +ENTRY entry { + p0 = (f32[32,32]{1,0},f32[32,32]{1,0},f32[32,32]{1,0}) parameter(0) + ROOT p1 = f32[32,32]{1,0} parameter(1) +} +)"; + + TF_ASSIGN_OR_RETURN(std::unique_ptr module, + tools::Parse(hlo_string)); + + *entry_computation = module->entry_computation(); + *param0 = (*entry_computation)->parameter_instruction(0); + *param1 = (*entry_computation)->parameter_instruction(1); + + return std::move(module); +} + +TEST(TupleUtilTest, ExtractPrefix) { + HloInstruction *param0, *param1; + HloComputation* entry_computation; + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + GetParsedModule(&entry_computation, ¶m0, ¶m1)); + + HloInstruction* prefix = TupleUtil::ExtractPrefix(param0, 2); + + EXPECT_THAT(prefix, op::Tuple(op::GetTupleElement(op::Parameter(0), 0), + op::GetTupleElement(op::Parameter(0), 1))); +} + +TEST(TupleUtilTest, AppendSuffix) { + HloInstruction *param0, *param1; + HloComputation* entry_computation; + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + GetParsedModule(&entry_computation, ¶m0, ¶m1)); + + HloInstruction* with_suffix = + TupleUtil::AppendSuffix(param0, {param1, param1}); + + EXPECT_THAT(with_suffix, op::Tuple(op::GetTupleElement(op::Parameter(0), 0), + op::GetTupleElement(op::Parameter(0), 1), + op::GetTupleElement(op::Parameter(0), 2), + op::Parameter(1), op::Parameter(1))); +} + +} // namespace +} // namespace xla diff --git 
a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index 4e90491b55a5688e37cbabae0843f584578add55..fead9b92362bcd1974f2dff6e030bc47dfc5aa85 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -88,8 +88,6 @@ HloOpcode BinaryOperationToHloOpcode(BinaryOperation binop) { return HloOpcode::kAtan2; case BINOP_COMPLEX: return HloOpcode::kComplex; - case BINOP_DOT: - return HloOpcode::kDot; case BINOP_MUL: return HloOpcode::kMultiply; case BINOP_ADD: @@ -371,14 +369,6 @@ StatusOr UserComputation::AddRngInstruction( // Check the number of parameters per RNG distribution. switch (rng_request.distribution()) { - case RandomDistribution::RNG_BERNOULLI: - if (rng_request.parameter_size() != 1) { - return InvalidArgument( - "RNG distribution (%s) expects 1 parameters, but got %d", - RandomDistribution_Name(rng_request.distribution()).c_str(), - rng_request.parameter_size()); - } - break; case RandomDistribution::RNG_NORMAL: case RandomDistribution::RNG_UNIFORM: if (rng_request.parameter_size() != 2) { @@ -765,6 +755,54 @@ StatusOr UserComputation::AddWhileInstruction( return handle; } +StatusOr UserComputation::AddConditionalInstruction( + const ConditionalRequest& conditional_request, + const UserComputation& true_computation, + const UserComputation& false_computation) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* pred, + LookUpRequest(conditional_request.predicate())); + TF_ASSIGN_OR_RETURN(const OperationRequest* true_operand, + LookUpRequest(conditional_request.true_operand())); + TF_ASSIGN_OR_RETURN(const OperationRequest* false_operand, + LookUpRequest(conditional_request.false_operand())); + + VersionedComputationHandle::Version true_computation_version = + true_computation.version(); + TF_ASSIGN_OR_RETURN( + std::shared_ptr true_computation_shape, + 
true_computation.ComputeProgramShape(true_computation_version)); + + VersionedComputationHandle::Version false_computation_version = + false_computation.version(); + TF_ASSIGN_OR_RETURN( + std::shared_ptr false_computation_shape, + false_computation.ComputeProgramShape(false_computation_version)); + + TF_ASSIGN_OR_RETURN(Shape inferred_shape, + ShapeInference::InferConditionalShape( + pred->output_shape(), true_operand->output_shape(), + false_operand->output_shape(), + *true_computation_shape, *false_computation_shape)); + + ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = inferred_shape; + request.add_embedded_computation_versions(true_computation_version); + request.add_embedded_computation_versions(false_computation_version); + *request.mutable_request()->mutable_conditional_request() = + conditional_request; + + VLOG(1) << "AddConditionalInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << conditional_request.ShortDebugString(); + return handle; +} + StatusOr UserComputation::AddBroadcastInstruction( const BroadcastRequest& broadcast_request) { tensorflow::mutex_lock lock(mutex_); @@ -1075,6 +1113,31 @@ StatusOr UserComputation::AddConvolveInstruction( return handle; } +StatusOr UserComputation::AddFftInstruction( + const FftRequest& fft_request) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* operand, + LookUpRequest(fft_request.operand())); + TF_ASSIGN_OR_RETURN(Shape shape, + ShapeInference::InferFftShape( + operand->output_shape(), fft_request.fft_type(), + AsInt64Slice(fft_request.fft_length()))); + + const ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + 
*request.mutable_output_handle() = handle; + *request.mutable_output_shape() = shape; + *request.mutable_request()->mutable_fft_request() = fft_request; + + VLOG(1) << "AddFftInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << fft_request.ShortDebugString(); + return handle; +} + StatusOr UserComputation::AddCrossReplicaSumInstruction( const CrossReplicaSumRequest& cross_replica_sum_request) { tensorflow::mutex_lock lock(mutex_); @@ -1082,7 +1145,7 @@ StatusOr UserComputation::AddCrossReplicaSumInstruction( TF_ASSIGN_OR_RETURN(const OperationRequest* operand, LookUpRequest(cross_replica_sum_request.operand())); TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferCrossReplicaSumShape( - operand->output_shape())); + {&operand->output_shape()})); ComputationDataHandle handle = CreateComputationDataHandle(); @@ -1122,7 +1185,7 @@ StatusOr UserComputation::AddInfeedInstruction( return handle; } -Status UserComputation::AddOutfeedInstruction( +StatusOr UserComputation::AddOutfeedInstruction( const OutfeedRequest& outfeed_request) { tensorflow::mutex_lock lock(mutex_); @@ -1134,8 +1197,6 @@ Status UserComputation::AddOutfeedInstruction( // Verify that operand is valid. TF_RETURN_IF_ERROR(LookUpRequest(outfeed_request.operand()).status()); - // No handle is returned, but a handle must be assigned to this instruction - // for computation versioning. 
ComputationDataHandle handle = CreateComputationDataHandle(); OperationRequest& request = (*session_computation_.mutable_requests())[handle.handle()]; @@ -1146,7 +1207,7 @@ Status UserComputation::AddOutfeedInstruction( VLOG(1) << "AddOutfeedInstruction (" << GetVersionedHandleInternal() << "), data handle " << handle.handle() << ": " << outfeed_request.ShortDebugString(); - return Status::OK(); + return handle; } StatusOr UserComputation::AddCallInstruction( @@ -1192,6 +1253,14 @@ StatusOr UserComputation::AddCustomCallInstruction( TF_RETURN_IF_ERROR(LookUpRequest(handle).status()); } + if (tensorflow::StringPiece(custom_call_request.call_target_name()) + .starts_with("$")) { + return InvalidArgument( + "Invalid custom_call_target \"%s\": Call targets that start with '$' " + "are reserved for internal use.", + custom_call_request.call_target_name().c_str()); + } + const ComputationDataHandle handle = CreateComputationDataHandle(); OperationRequest& request = @@ -1207,6 +1276,33 @@ StatusOr UserComputation::AddCustomCallInstruction( return handle; } +StatusOr UserComputation::AddDotInstruction( + const DotRequest& dot_request) { + tensorflow::mutex_lock lock(mutex_); + + TF_ASSIGN_OR_RETURN(const OperationRequest* lhs, + LookUpRequest(dot_request.lhs())); + TF_ASSIGN_OR_RETURN(const OperationRequest* rhs, + LookUpRequest(dot_request.rhs())); + + TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferDotOpShape( + lhs->output_shape(), rhs->output_shape(), + dot_request.dimension_numbers())); + + const ComputationDataHandle handle = CreateComputationDataHandle(); + + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = shape; + *request.mutable_request()->mutable_dot_request() = dot_request; + + VLOG(1) << "AddDotInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << dot_request.ShortDebugString(); + 
return handle; +} + StatusOr UserComputation::AddUnaryInstruction( const UnaryOpRequest& unary_request) { tensorflow::mutex_lock lock(mutex_); @@ -1433,7 +1529,7 @@ StatusOr LookUpRequest( return &session_computation.requests().at(handle_value); } -// Returns the OperationRequestion corresponding to the root (result) of the +// Returns the OperationRequest corresponding to the root (result) of the // session computation. StatusOr GetRoot( VersionedComputationHandle::Version version, @@ -1479,8 +1575,8 @@ UserComputation::ComputeProgramShape( request.request().parameter_request(); int64 param_no = parameter_request.parameter(); // Parameters may be out of order so expand ProgramShape parameters - // until - // it is at least large enough to hold the current parameter number. + // until it is at least large enough to hold the current parameter + // number. while (program_shape->parameters_size() <= param_no) { program_shape->add_parameters(); program_shape->add_parameter_names(); @@ -1594,6 +1690,13 @@ void PureFunctionalVisitor(const SessionComputation& session_computation, break; } + case OpRequest::kFftRequest: { + const FftRequest& fft_request = request.request().fft_request(); + PureFunctionalVisitor(session_computation, fft_request.operand(), + num_parameters, visited, is_functional); + break; + } + case OpRequest::kCrossReplicaSumRequest: { // TODO(b/33009255): Implmement constant folding for cross replica sum. 
*is_functional = false; @@ -1629,6 +1732,15 @@ void PureFunctionalVisitor(const SessionComputation& session_computation, break; } + case OpRequest::kDotRequest: { + const DotRequest& dot_request = request.request().dot_request(); + PureFunctionalVisitor(session_computation, dot_request.lhs(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, dot_request.rhs(), + num_parameters, visited, is_functional); + break; + } + case OpRequest::kSendRequest: { *is_functional = false; break; @@ -1757,6 +1869,23 @@ void PureFunctionalVisitor(const SessionComputation& session_computation, break; } + case OpRequest::kConditionalRequest: { + const ConditionalRequest& conditional_request = + request.request().conditional_request(); + PureFunctionalVisitor(session_computation, + conditional_request.predicate(), num_parameters, + visited, is_functional); + PureFunctionalVisitor(session_computation, + conditional_request.true_operand(), num_parameters, + visited, is_functional); + PureFunctionalVisitor(session_computation, + conditional_request.false_operand(), num_parameters, + visited, is_functional); + // TODO(b/32495713): We aren't checking the true and false computations + // themselves. 
+ break; + } + case OpRequest::kTernaryOpRequest: { const TernaryOpRequest& ternary_op_request = request.request().ternary_op_request(); @@ -1868,6 +1997,9 @@ void PureFunctionalVisitor(const SessionComputation& session_computation, default: LOG(FATAL) << "Unexpected request type: " << request.request().op_case(); } + if (!*is_functional) { + VLOG(1) << "Non-functional: " << request.request().DebugString(); + } visited->insert(handle.handle()); } @@ -1985,6 +2117,21 @@ UserComputation::GetEmbeddedComputations( break; } + case OpRequest::kConditionalRequest: { + CHECK_EQ(2, request.embedded_computation_versions_size()); + const ConditionalRequest& conditional_request = + request.request().conditional_request(); + const VersionedComputationHandle true_computation_versioned_handle = { + conditional_request.true_computation(), + request.embedded_computation_versions(0)}; + computations.push_back(true_computation_versioned_handle); + const VersionedComputationHandle false_computation_versioned_handle = + {conditional_request.false_computation(), + request.embedded_computation_versions(1)}; + computations.push_back(false_computation_versioned_handle); + break; + } + default: // No embedded computation. 
break; @@ -2000,6 +2147,24 @@ UserComputation::GetEmbeddedComputations( return computations; } +StatusOr +UserComputation::LookUpRequestForErrorReporting( + const ComputationDataHandle& handle) const { + tensorflow::mutex_lock lock(mutex_); + return LookUpRequest(handle); +} + +tensorflow::gtl::optional UserComputation::ParameterMetadata( + int parameter_number) const { + tensorflow::mutex_lock lock(mutex_); + auto it = parameters_.find(parameter_number); + if (it == parameters_.end()) { + return tensorflow::gtl::nullopt; + } + OperationRequest* op = it->second; + return &op->request().metadata(); +} + Status UserComputation::RemapEmbeddedComputations( const std::map& old_to_new) { auto update = [&old_to_new](ComputationHandle* to_update) -> Status { @@ -2071,6 +2236,16 @@ Status UserComputation::RemapEmbeddedComputations( TF_RETURN_IF_ERROR(update(while_request->mutable_body())); break; } + case OpRequest::kConditionalRequest: { + TF_RET_CHECK(2 == request.embedded_computation_versions_size()); + ConditionalRequest* conditional_request = + request.mutable_request()->mutable_conditional_request(); + TF_RETURN_IF_ERROR( + update(conditional_request->mutable_true_computation())); + TF_RETURN_IF_ERROR( + update(conditional_request->mutable_false_computation())); + break; + } default: // No embedded computation. 
TF_RET_CHECK(0 == request.embedded_computation_versions_size()); @@ -2274,6 +2449,12 @@ static void ForEachOperand( break; } + case OpRequest::kFftRequest: { + const FftRequest& fft_request = request.request().fft_request(); + apply(fft_request.operand()); + break; + } + case OpRequest::kBatchNormTrainingRequest: { const BatchNormTrainingRequest& batch_norm_training_request = request.request().batch_norm_training_request(); @@ -2417,6 +2598,15 @@ static void ForEachOperand( break; } + case OpRequest::kConditionalRequest: { + const ConditionalRequest& conditional_request = + request.request().conditional_request(); + apply(conditional_request.predicate()); + apply(conditional_request.true_operand()); + apply(conditional_request.false_operand()); + break; + } + case OpRequest::kTernaryOpRequest: { const TernaryOpRequest& ternary_op_request = request.request().ternary_op_request(); @@ -2453,6 +2643,13 @@ static void ForEachOperand( break; } + case OpRequest::kDotRequest: { + const DotRequest& dot_request = request.request().dot_request(); + apply(dot_request.rhs()); + apply(dot_request.lhs()); + break; + } + case OpRequest::kUnaryOpRequest: { const UnaryOpRequest& unary_op_request = request.request().unary_op_request(); @@ -2571,48 +2768,11 @@ HloComputation* ComputationLowerer::ResolveComputation( HloInstruction* ComputationLowerer::ImplicitBroadcastToExplicitBroadcast( HloInstruction* operand, const Shape& output_shape) { - CHECK(ShapeUtil::IsScalar(operand->shape()) || - ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(output_shape)); - Shape broadcast_shape = ShapeUtil::MakeShape( - operand->shape().element_type(), AsInt64Slice(output_shape.dimensions())); - // Do explicit broadcast for scalar. 
- if (ShapeUtil::IsScalar(operand->shape())) { - HloInstruction* broadcast = hlo_builder_.AddInstruction( - HloInstruction::CreateBroadcast(broadcast_shape, operand, {})); - broadcast->set_metadata(operand->metadata()); - if (operand->has_sharding()) { - broadcast->set_sharding(operand->sharding()); - } - return broadcast; - } - // Do explicit broadcast for degenerate broadcast. - std::vector broadcast_dimensions; - std::vector reshaped_dimensions; - for (int i = 0; i < ShapeUtil::Rank(operand->shape()); i++) { - if (operand->shape().dimensions(i) == output_shape.dimensions(i)) { - broadcast_dimensions.push_back(i); - reshaped_dimensions.push_back(operand->shape().dimensions(i)); - } - } - // Eliminate the size one dimensions. - HloInstruction* reshaped_operand = - hlo_builder_.AddInstruction(HloInstruction::CreateReshape( - ShapeUtil::MakeShape(operand->shape().element_type(), - reshaped_dimensions), - operand)); - reshaped_operand->set_metadata(operand->metadata()); - if (operand->has_sharding()) { - reshaped_operand->set_sharding(operand->sharding()); - } - // Broadcast 'reshape' up to the larger size. 
- HloInstruction* broadcast = - hlo_builder_.AddInstruction(HloInstruction::CreateBroadcast( - broadcast_shape, reshaped_operand, broadcast_dimensions)); - broadcast->set_metadata(operand->metadata()); - if (operand->has_sharding()) { - broadcast->set_sharding(operand->sharding()); - } - return broadcast; + auto fadd = [this](std::unique_ptr x) { + return hlo_builder_.AddInstruction(std::move(x)); + }; + return fadd( + HloInstruction::CreateBroadcastSequence(output_shape, operand, fadd)); } void ComputationLowerer::Visit( @@ -2653,7 +2813,8 @@ void ComputationLowerer::Visit( const ConstantRequest& constant_request = request.request().constant_request(); hlo_instruction = add_instruction(HloInstruction::CreateConstant( - Literal(constant_request.literal()).CloneToUnique())); + Literal::CreateFromProto(constant_request.literal()) + .ConsumeValueOrDie())); break; } @@ -2732,13 +2893,31 @@ void ComputationLowerer::Visit( break; } + case OpRequest::kFftRequest: { + const FftRequest& fft_request = request.request().fft_request(); + HloInstruction* operand = lookup_instruction(fft_request.operand()); + hlo_instruction = add_instruction(HloInstruction::CreateFft( + request.output_shape(), operand, fft_request.fft_type(), + AsInt64Slice(fft_request.fft_length()))); + break; + } + + case OpRequest::kDotRequest: { + const DotRequest& dot_request = request.request().dot_request(); + HloInstruction* lhs = lookup_instruction(dot_request.lhs()); + HloInstruction* rhs = lookup_instruction(dot_request.rhs()); + hlo_instruction = add_instruction(HloInstruction::CreateDot( + request.output_shape(), lhs, rhs, dot_request.dimension_numbers())); + break; + } + case OpRequest::kCrossReplicaSumRequest: { const CrossReplicaSumRequest& cross_replica_sum_request = request.request().cross_replica_sum_request(); HloInstruction* operand = lookup_instruction(cross_replica_sum_request.operand()); hlo_instruction = add_instruction(HloInstruction::CreateCrossReplicaSum( - request.output_shape(), 
operand)); + request.output_shape(), {operand})); break; } @@ -3021,6 +3200,30 @@ void ComputationLowerer::Visit( break; } + case OpRequest::kConditionalRequest: { + const ConditionalRequest& conditional_request = + request.request().conditional_request(); + CHECK_EQ(2, request.embedded_computation_versions_size()); + VersionedComputationHandle::Version true_computation_version = + request.embedded_computation_versions(0); + HloComputation* true_computation = ResolveComputation( + conditional_request.true_computation(), true_computation_version); + VersionedComputationHandle::Version false_computation_version = + request.embedded_computation_versions(1); + HloComputation* false_computation = ResolveComputation( + conditional_request.false_computation(), false_computation_version); + HloInstruction* predicate = + lookup_instruction(conditional_request.predicate()); + HloInstruction* true_operand = + lookup_instruction(conditional_request.true_operand()); + HloInstruction* false_operand = + lookup_instruction(conditional_request.false_operand()); + hlo_instruction = add_instruction(HloInstruction::CreateConditional( + request.output_shape(), predicate, true_operand, true_computation, + false_operand, false_computation)); + break; + } + case OpRequest::kTernaryOpRequest: { const TernaryOpRequest& ternary_op_request = request.request().ternary_op_request(); @@ -3151,8 +3354,7 @@ void ComputationLowerer::Visit( lhs = (lhs == operand_to_broadcast) ? broadcasted_operand : lhs; rhs = (rhs == operand_to_broadcast) ? broadcasted_operand : rhs; } - if (debug_options_.xla_eliminate_hlo_implicit_broadcast() && - binary_op_request.binop() != BINOP_DOT) { + if (debug_options_.xla_eliminate_hlo_implicit_broadcast()) { if (!ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) { // lhs side is being implicitly broadcast. Change to explicit. 
lhs = diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h index 317c631dca2e1ebe6f3c8fbaf1a3e94106034f79..54bb24d6d7fe7aa8cc7c684795e40464e4eb6614 100644 --- a/tensorflow/compiler/xla/service/user_computation.h +++ b/tensorflow/compiler/xla/service/user_computation.h @@ -133,6 +133,10 @@ class UserComputation { StatusOr AddConvolveInstruction( const ConvolveRequest& convolve_request); + // Enqueues an FFT instruction onto this user computation. + StatusOr AddFftInstruction( + const FftRequest& fft_request); + // Enqueues a cross replica sum instruction onto this user computation. StatusOr AddCrossReplicaSumInstruction( const CrossReplicaSumRequest& cross_replica_sum_request); @@ -142,7 +146,8 @@ class UserComputation { const InfeedRequest& infeed_request); // Enqueues an outfeed instruction onto this user computation. - Status AddOutfeedInstruction(const OutfeedRequest& outfeed_request); + StatusOr AddOutfeedInstruction( + const OutfeedRequest& outfeed_request); // Enqueues a call instruction onto this user computation. StatusOr AddCallInstruction( @@ -153,6 +158,10 @@ class UserComputation { StatusOr AddCustomCallInstruction( const CustomCallRequest& custom_call_request); + // Enqueues a dot instruction onto this user computation. + StatusOr AddDotInstruction( + const DotRequest& dot_request); + // Enqueues a broadcast instruction onto this user computation. StatusOr AddBroadcastInstruction( const BroadcastRequest& broadcast_request); @@ -216,6 +225,12 @@ class UserComputation { const UserComputation& condition_computation, const UserComputation& body_computation); + // Enqueues a conditional instruction on this user computation. + StatusOr AddConditionalInstruction( + const ConditionalRequest& conditional_request, + const UserComputation& true_computation, + const UserComputation& false_computation); + // Enqueues a Send instruction onto this user computation. 
Status AddSendInstruction(const SendRequest& send_request); @@ -307,6 +322,23 @@ class UserComputation { SessionComputation CloneSessionComputation( VersionedComputationHandle::Version version) const; + // Warning: typically we don't want to look up computation data handles until + // the computation is finished being built, for consistency purposes. We + // expose this routine for error reporting purposes so that we can provide + // more meaningful error messages from the XLA service layer. + // + // Returns the operation request that the handle comes from. + StatusOr LookUpRequestForErrorReporting( + const ComputationDataHandle& handle) const; + + // Retrieves the parameter metadata for the given parameter number. + // + // If the parameter number is invalid for this computation, nullopt is + // returned. When the return value has_value(), nullptr will never be + // the held value. + tensorflow::gtl::optional ParameterMetadata( + int parameter_number) const; + private: // Warning: dangerous mutating operation that doesn't respect versioning. 
// This is only used at initialization time when constructing from a diff --git a/tensorflow/compiler/xla/service/user_computation_test.cc b/tensorflow/compiler/xla/service/user_computation_test.cc index 5afaf226ae0cce7e9afc966c6b4adf838aeebc91..2fa163953f638c0038e9f6bb11ce2a3742e0558c 100644 --- a/tensorflow/compiler/xla/service/user_computation_test.cc +++ b/tensorflow/compiler/xla/service/user_computation_test.cc @@ -65,8 +65,10 @@ TEST_F(UserComputationTest, SimpleComputation) { OutfeedRequest outfeed_request; *outfeed_request.mutable_operand() = constant_handle; + *outfeed_request.mutable_shape() = kVectorShape; outfeed_request.set_outfeed_config("abc"); - TF_ASSERT_OK(computation.AddOutfeedInstruction(outfeed_request)); + TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle outfeed_handle, + computation.AddOutfeedInstruction(outfeed_request)); auto hlo_resolver = [](const VersionedComputationHandle& handle) { return nullptr; @@ -334,50 +336,5 @@ TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) { operands[1]->opcode() == HloOpcode::kBroadcast); } -TEST_F(UserComputationTest, SkipDotInEliminatingImplicitBroadcast) { - auto debug_options = DebugOptions(); - debug_options.set_xla_eliminate_hlo_implicit_broadcast(true); - - // %a = Param({1, 3}); - // %b = Param({3, 1}); - // %dot = Dot(%a, %b); - ComputationHandle handle; - handle.set_handle(123); - UserComputation computation("TheComputation", handle); - - ParameterRequest a_request; - *a_request.mutable_shape() = ShapeUtil::MakeShape(F32, {1, 3}); - a_request.set_name("a"); - a_request.set_parameter(0); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle, - computation.AddParameterInstruction(a_request)); - - ParameterRequest b_request; - *b_request.mutable_shape() = ShapeUtil::MakeShape(F32, {3, 1}); - b_request.set_name("b"); - b_request.set_parameter(1); - TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle, - computation.AddParameterInstruction(b_request)); - - BinaryOpRequest 
dot; - dot.set_binop(BINOP_DOT); - *dot.mutable_lhs() = a_handle; - *dot.mutable_rhs() = b_handle; - TF_ASSERT_OK(computation.AddBinaryInstruction(dot).status()); - - auto hlo_resolver = [](const VersionedComputationHandle& handle) { - return nullptr; - }; - VersionedComputationHandle latest_version = computation.GetVersionedHandle(); - - // Build the HLO computation. - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr hlo_computation, - computation.BuildHloComputation(latest_version.version, hlo_resolver, - debug_options)); - - EXPECT_EQ(3, hlo_computation->instruction_count()); -} - } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc new file mode 100644 index 0000000000000000000000000000000000000000..a5f9b01f011ce04f1114c74391a967c62f015221 --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc @@ -0,0 +1,296 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h"
+#include "tensorflow/compiler/xla/service/tuple_util.h"
+#include "tensorflow/compiler/xla/service/while_util.h"
+#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+#include "tensorflow/core/lib/gtl/inlined_vector.h"
+
+namespace xla {
+
+using tensorflow::gtl::FlatMap;
+using tensorflow::gtl::FlatSet;
+using tensorflow::gtl::InlinedVector;
+
+// Copies `to_hoist` to the computation containing `while_instr`, hoisting its
+// operands as needed. All of its transitive operands are expected to be either
+// in `hoisted_instructions` or `unhoisted_invariant_instructions`. This
+// function hoists the operands in `unhoisted_invariant_instructions` and moves
+// them into `hoisted_instructions`.
+static void CreateLoopInvariantCopy(
+    FlatMap<HloInstruction*, HloInstruction*>* hoisted_instructions,
+    FlatSet<HloInstruction*>* unhoisted_invariant_instructions,
+    HloInstruction* while_instr, HloInstruction* to_hoist) {
+  HloComputation* parent_of_while = while_instr->parent();
+  HloComputation* while_body = while_instr->while_body();
+
+  struct DFSFrame {
+    HloInstruction* instruction;
+    int64 operand_index;
+  };
+
+  InlinedVector<DFSFrame, 8> dfs_stack;
+  dfs_stack.push_back({to_hoist, 0});
+
+  HloInstruction* while_body_param = while_body->parameter_instruction(0);
+  HloInstruction* while_operand = while_instr->mutable_operand(0);
+
+  do {
+    DFSFrame* frame = &dfs_stack.back();
+    if (frame->operand_index == frame->instruction->operand_count()) {
+      HloInstruction* old_instruction = frame->instruction;
+
+      // All of the operands for old_instruction have been cloned, so it is
+      // time to clone old_instruction itself.
+
+      auto get_new_operand = [&](HloInstruction* old_operand) {
+        return old_operand == while_body_param
+                   ? while_operand
+                   : FindOrDie(*hoisted_instructions, old_operand);
+      };
+
+      InlinedVector<HloInstruction*, 4> new_operands;
+      c_transform(old_instruction->operands(), std::back_inserter(new_operands),
+                  get_new_operand);
+
+      HloInstruction* new_instruction =
+          parent_of_while->AddInstruction(old_instruction->CloneWithNewOperands(
+              old_instruction->shape(), new_operands));
+
+      InsertOrDie(hoisted_instructions, old_instruction, new_instruction);
+
+      // Approximately half of the instructions that would normally be present
+      // in unhoisted_invariant_instructions are constants. We save a bit of
+      // compile time by not putting these in the hashtable.
+      CHECK_EQ(unhoisted_invariant_instructions->erase(old_instruction),
+               to_hoist != old_instruction &&
+                   old_instruction->opcode() != HloOpcode::kConstant);
+      dfs_stack.pop_back();
+      continue;
+    }
+
+    HloInstruction* next_operand =
+        frame->instruction->mutable_operand(frame->operand_index++);
+    if (hoisted_instructions->count(next_operand) ||
+        next_operand == while_body_param) {
+      continue;
+    }
+
+    dfs_stack.push_back({next_operand, 0});
+  } while (!dfs_stack.empty());
+}
+
+// Returns true if `instruction` is worth hoisting only if it lets us hoist some
+// instruction using it. The rationale is that hoisting these instructions will
+// prevent simplification and fusion in the while body.
+static bool NotWorthHoistingIndividually(const HloInstruction& instruction) {
+  switch (instruction.opcode()) {
+    default:
+      return false;
+
+    case HloOpcode::kBitcast:
+    case HloOpcode::kBroadcast:
+    case HloOpcode::kConstant:
+    case HloOpcode::kReverse:
+    case HloOpcode::kSlice:
+    case HloOpcode::kTuple:
+      return true;
+
+    case HloOpcode::kTranspose:
+      return ShapeUtil::TransposeIsBitcast(
+          /*input_shape=*/instruction.operand(0)->shape(),
+          /*output_shape=*/instruction.shape(), instruction.dimensions());
+
+    case HloOpcode::kReshape:
+      return ShapeUtil::ReshapeIsBitcast(
+          /*input_shape=*/instruction.operand(0)->shape(),
+          /*output_shape=*/instruction.shape());
+  }
+}
+
+// Populates `gte_set` with the GetTupleElement instructions in `while_body`
+// that access elements in the parameter tuple that don't change across
+// iterations. Assumes `while_body` is the body computation of the while loop
+// in question.
+static void GatherInvariantGTEs(HloComputation* while_body,
+                                FlatSet<HloInstruction*>* gte_set) {
+  const HloInstruction::InstructionVector root_operands =
+      while_body->root_instruction()->operands();
+  for (int i = 0; i < root_operands.size(); i++) {
+    HloInstruction* instr = root_operands[i];
+    if (instr->opcode() == HloOpcode::kGetTupleElement &&
+        instr->tuple_index() == i &&
+        instr->operand(0) == while_body->parameter_instruction(0) &&
+        ShapeUtil::IsArray(instr->shape())) {
+      InsertOrDie(gte_set, instr);
+    }
+  }
+}
+
+static StatusOr<bool> TryHoistingInvariantInstructionsFromWhileBody(
+    HloInstruction* while_instr) {
+  auto print_no_metadata = HloPrintOptions{}.set_print_metadata(false);
+
+  if (!ShapeUtil::IsTuple(while_instr->shape())) {
+    // This restriction leaves one interesting pattern on the table:
+    //
+    //  while_body(f32[1024, 1024] %param) {
+    //    %value = expensive_op(%param)
+    //    outfeed(%value)
+    //    ROOT = %param
+    //  }
+    //
+    // If we see that pattern in the while, instead of generalizing this
+    // algorithm to work with non-tuples, we should instead add a pass that
+    // canonicalizes while loops like the above to use a tuple state.
+    return false;
+  }
+
+  string while_instr_name = while_instr->ToString(print_no_metadata);
+  VLOG(2) << "Trying to hoist from " << while_instr_name;
+
+  HloComputation* while_body = while_instr->while_body();
+
+  // Maps instructions in the while body to instructions hoisted outside the
+  // while that compute the same value.
+  FlatMap<HloInstruction*, HloInstruction*> hoisted_instructions;
+
+  // Contains instructions that can be legally hoisted, but were deemed to be
+  // unprofitable to be hoisted alone by NotWorthHoistingIndividually. When we
+  // hoist an instruction in this set, we move it from
+  // unhoisted_invariant_instructions to hoisted_instructions.
+  FlatSet<HloInstruction*> unhoisted_invariant_instructions;
+
+  // Invariant GTE's axiomatically satisfy the constraints for
+  // unhoisted_invariant_instructions -- they can be legally hoisted, but there
+  // is no benefit to hoisting them unless something that uses it is also
+  // hoisted.
+  GatherInvariantGTEs(while_body, &unhoisted_invariant_instructions);
+
+  if (unhoisted_invariant_instructions.empty()) {
+    // There are no obviously loop invariant elements in the state being
+    // threaded through the while loop so give up. In theory this precondition
+    // is too strong -- we could have code that e.g. permutes the elements in
+    // the while state but uses a select to pick the same value on every
+    // iteration.
+    return false;
+  }
+
+  // instructions_to_replace[i] is hoisted into a loop invariant instruction
+  // replacement_instructions[i].
+  std::vector<HloInstruction*> instructions_to_replace;
+  std::vector<HloInstruction*> replacement_instructions;
+
+  for (auto* instruction : while_body->MakeInstructionPostOrder()) {
+    if (instruction->HasSideEffect() ||
+        instruction->opcode() == HloOpcode::kParameter ||
+        !instruction->control_predecessors().empty() ||
+        !instruction->control_successors().empty()) {
+      continue;
+    }
+
+    auto is_invariant = [&](HloInstruction* op) {
+      return hoisted_instructions.find(op) != hoisted_instructions.end() ||
+             unhoisted_invariant_instructions.count(op) ||
+             op->opcode() == HloOpcode::kConstant;
+    };
+
+    if (!c_all_of(instruction->operands(), is_invariant)) {
+      continue;
+    }
+
+    if (NotWorthHoistingIndividually(*instruction)) {
+      VLOG(2) << "Adding " << instruction->ToString(print_no_metadata)
+              << " to unhoisted invariant set.";
+      // Approximately half of the instructions that reach this point are
+      // constants. We save a bit of compile time by not putting these in the
+      // hashtable.
+      if (instruction->opcode() != HloOpcode::kConstant) {
+        InsertOrDie(&unhoisted_invariant_instructions, instruction);
+      }
+      continue;
+    }
+
+    VLOG(2) << "Hoisting " << instruction->ToString(print_no_metadata);
+
+    CreateLoopInvariantCopy(&hoisted_instructions,
+                            &unhoisted_invariant_instructions, while_instr,
+                            instruction);
+
+    instructions_to_replace.push_back(instruction);
+    replacement_instructions.push_back(
+        FindOrDie(hoisted_instructions, instruction));
+  }
+
+  if (instructions_to_replace.empty()) {
+    return false;
+  }
+
+  TF_ASSIGN_OR_RETURN(
+      WhileUtil::MakeInstructionsLiveInResult live_in_instructions_result,
+      WhileUtil::MakeInstructionsLiveIn(while_instr, replacement_instructions));
+
+  HloComputation* new_while_body =
+      live_in_instructions_result.new_while_instr->while_body();
+
+  for (int i = 0; i < instructions_to_replace.size(); i++) {
+    HloInstruction* instruction_to_replace_in_new_while =
+        FindOrDie(live_in_instructions_result.while_body_instruction_map,
+                  instructions_to_replace[i]);
+    TF_RETURN_IF_ERROR(new_while_body->ReplaceInstruction(
+        instruction_to_replace_in_new_while,
+        live_in_instructions_result.while_body_live_in_values[i]));
+  }
+
+  VLOG(1) << "Hoisted " << instructions_to_replace.size()
+          << " instructions from " << while_instr_name;
+
+  return true;
+}
+
+StatusOr<bool> WhileLoopInvariantCodeMotion::Run(HloModule* module) {
+  bool changed = false;
+  std::vector<HloInstruction*> while_instrs;
+  for (auto* comp : module->computations()) {
+    c_copy_if(comp->instructions(), std::back_inserter(while_instrs),
+              [](const HloInstruction* instr) {
+                return instr->opcode() == HloOpcode::kWhile;
+              });
+  }
+
+  for (HloInstruction* while_instr : while_instrs) {
+    // Right now we only hoist computations from the while body, but
+    // TryHoistingInvariantInstructionsFromWhileBody can be generalized to
+    // optimize the condition computation too, if needed.
+    //
+    // The transform we do here is a pessimization for while loops that execute
+    // zero times*, but at this time we expect those to be rare. If this
+    // becomes a problem we can consider using the conditional HLO to avoid
+    // doing extra work for while loops with zero trip count.
+    //
+    // * We delete while loops that have a zero trip count, so this would have
+    //   to be a while loop with a somewhat opaque condition expression.
+
+    TF_ASSIGN_OR_RETURN(
+        bool result,
+        TryHoistingInvariantInstructionsFromWhileBody(while_instr));
+    changed |= result;
+  }
+  return changed;
+}
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
new file mode 100644
index 0000000000000000000000000000000000000000..8c4b765b0003c48cfacb9d28e7c8259ac0927d66
--- /dev/null
+++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h
@@ -0,0 +1,39 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_INVARIANT_CODE_MOTION_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_INVARIANT_CODE_MOTION_H_
+
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+#include "tensorflow/compiler/xla/statusor.h"
+
+namespace xla {
+
+// HLO pass that rewrites while loops to hoist loop invariant instructions in
+// the while body into the computation that contains the while instruction.
+
+class WhileLoopInvariantCodeMotion : public HloPassInterface {
+ public:
+  ~WhileLoopInvariantCodeMotion() override = default;
+
+  tensorflow::StringPiece name() const override {
+    return "while-loop-invariant-code-motion";
+  }
+  StatusOr<bool> Run(HloModule* module) override;
+};
+}  // namespace xla
+
+#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_INVARIANT_CODE_MOTION_H_
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..799340fda905fb7d40b19b4cb79bb0fcb5629fd3
--- /dev/null
+++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion_test.cc
@@ -0,0 +1,442 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h"
+
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+
+namespace xla {
+namespace {
+
+namespace op = xla::testing::opcode_matchers;
+
+class WhileLoopInvariantCodeMotionTest : public HloVerifiedTestBase {
+ public:
+  // Makes a computation which has one parameter, of the given shape, and always
+  // returns PRED[]{true}. This is useful as a dummy loop condition.
+  HloComputation* MakeAlwaysTrueComputation(const Shape& param_shape,
+                                            HloModule* module);
+};
+
+static void FindOnlyWhileInstruction(HloComputation* computation,
+                                     HloInstruction** while_instruction) {
+  *while_instruction = nullptr;
+  for (auto* instr : computation->instructions()) {
+    if (instr->opcode() == HloOpcode::kWhile) {
+      ASSERT_EQ(*while_instruction, nullptr);
+      *while_instruction = instr;
+    }
+  }
+
+  ASSERT_NE(*while_instruction, nullptr);
+}
+
+HloComputation* WhileLoopInvariantCodeMotionTest::MakeAlwaysTrueComputation(
+    const Shape& param_shape, HloModule* module) {
+  HloComputation::Builder builder(TestName() + ".always_true");
+  builder.AddInstruction(
+      HloInstruction::CreateParameter(0, param_shape, "param"));
+  builder.AddInstruction(
+      HloInstruction::CreateConstant(Literal::CreateR0<bool>(true)));
+  return module->AddEmbeddedComputation(builder.Build());
+}
+
+TEST_F(WhileLoopInvariantCodeMotionTest, HoistOneInvariantOperation) {
+  auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
+  Shape while_shape =
+      ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32});
+
+  HloComputation* while_body = [&]() {
+    HloComputation::Builder builder(TestName() + ".while_body");
+    HloInstruction* param = builder.AddInstruction(
+        HloInstruction::CreateParameter(0, while_shape, "param"));
+    HloInstruction* gte_0 = builder.AddInstruction(
+        HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
+    HloInstruction* gte_1 = builder.AddInstruction(
+        HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
+    HloInstruction* add_result =
+        builder.AddInstruction(HloInstruction::CreateBinary(
+            scalar_s32, HloOpcode::kAdd, gte_0, gte_1));
+    builder.AddInstruction(
+        HloInstruction::CreateTuple({gte_0, gte_1, add_result}));
+
+    return module().AddEmbeddedComputation(builder.Build());
+  }();
+
+  HloComputation::Builder builder(TestName());
+  auto* init_value = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, while_shape, "init_value"));
+  builder.AddInstruction(HloInstruction::CreateWhile(
+      while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
+      while_body, init_value));
+  HloComputation* entry_computation =
+      module().AddEntryComputation(builder.Build());
+  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
+                          WhileLoopInvariantCodeMotion{}.Run(&module()));
+  EXPECT_TRUE(simplified_loop);
+
+  HloInstruction* transformed_while;
+  FindOnlyWhileInstruction(entry_computation, &transformed_while);
+
+  EXPECT_THAT(entry_computation->instructions(), Contains(op::Add()));
+  EXPECT_THAT(transformed_while->while_body()->instructions(),
+              Each(Not(op::Add())));
+}
+
+TEST_F(WhileLoopInvariantCodeMotionTest, HoistInvariantOperationTree) {
+  auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
+  Shape while_shape =
+      ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32});
+
+  HloComputation* while_body = [&]() {
+    HloComputation::Builder builder(TestName() + ".while_body");
+    HloInstruction* param = builder.AddInstruction(
+        HloInstruction::CreateParameter(0, while_shape, "param"));
+    HloInstruction* gte_0 = builder.AddInstruction(
+        HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
+    HloInstruction* gte_1 = builder.AddInstruction(
+        HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
+    HloInstruction* gte_2_loop_variant = builder.AddInstruction(
+        HloInstruction::CreateGetTupleElement(scalar_s32, param, 2));
+
+    HloInstruction* add_result =
+        builder.AddInstruction(HloInstruction::CreateBinary(
+            scalar_s32, HloOpcode::kAdd, gte_0, gte_1));
+    HloInstruction* mul_result =
+        builder.AddInstruction(HloInstruction::CreateBinary(
+            scalar_s32, HloOpcode::kMultiply, add_result, gte_1));
+    HloInstruction* negate_result =
+        builder.AddInstruction(HloInstruction::CreateUnary(
+            scalar_s32, HloOpcode::kNegate, mul_result));
+    HloInstruction* constant = builder.AddInstruction(
+        HloInstruction::CreateConstant(Literal::CreateR0<int32>(4)));
+    HloInstruction* sub_result =
+        builder.AddInstruction(HloInstruction::CreateBinary(
+            scalar_s32, HloOpcode::kSubtract, negate_result, constant));
+    HloInstruction* divide_result =
+        builder.AddInstruction(HloInstruction::CreateBinary(
+            scalar_s32, HloOpcode::kDivide, sub_result, gte_2_loop_variant));
+    builder.AddInstruction(
+        HloInstruction::CreateTuple({gte_0, gte_1, divide_result}));
+
+    return module().AddEmbeddedComputation(builder.Build());
+  }();
+
+  HloComputation::Builder builder(TestName());
+  auto* init_value = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, while_shape, "init_value"));
+  builder.AddInstruction(HloInstruction::CreateWhile(
+      while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
+      while_body, init_value));
+  HloComputation* entry_computation =
+      module().AddEntryComputation(builder.Build());
+  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
+                          WhileLoopInvariantCodeMotion{}.Run(&module()));
+  EXPECT_TRUE(simplified_loop);
+
+  HloInstruction* transformed_while;
+  FindOnlyWhileInstruction(entry_computation, &transformed_while);
+
+  EXPECT_THAT(entry_computation->instructions(),
+              AllOf(Contains(op::Add()), Contains(op::Multiply()),
+                    Contains(op::Negate()), Contains(op::Subtract()),
+                    Contains(op::Constant()),
+
+                    // The division had a loop varying operand so that better
+                    // not be hoisted.
+                    Not(Contains(op::Divide()))));
+
+  EXPECT_THAT(transformed_while->while_body()->instructions(),
+              Each(Not(AnyOf(op::Add(), op::Multiply(), op::Negate(),
+                             op::Subtract(), op::Constant()))));
+
+  EXPECT_THAT(transformed_while->while_body()->instructions(),
+              Contains(op::Divide()));
+}
+
+TEST_F(WhileLoopInvariantCodeMotionTest,
+       DontHoistTriviallyLoopVaryingComputation) {
+  // Basic negative test: the add expression is not loop invariant.
+  auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
+  Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32});
+
+  HloComputation* while_body = [&]() {
+    HloComputation::Builder builder(TestName() + ".while_body");
+    HloInstruction* param = builder.AddInstruction(
+        HloInstruction::CreateParameter(0, while_shape, "param"));
+    HloInstruction* gte_0 = builder.AddInstruction(
+        HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
+    HloInstruction* gte_1 = builder.AddInstruction(
+        HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
+    HloInstruction* add_result =
+        builder.AddInstruction(HloInstruction::CreateBinary(
+            scalar_s32, HloOpcode::kAdd, gte_0, gte_1));
+    builder.AddInstruction(HloInstruction::CreateTuple({gte_0, add_result}));
+
+    return module().AddEmbeddedComputation(builder.Build());
+  }();
+
+  HloComputation::Builder builder(TestName());
+  auto* init_value = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, while_shape, "init_value"));
+  auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile(
+      while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
+      while_body, init_value));
+
+  module().AddEntryComputation(builder.Build());
+
+  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
+                          WhileLoopInvariantCodeMotion{}.Run(&module()));
+  EXPECT_FALSE(simplified_loop);
+
+  EXPECT_THAT(while_inst->while_body()->instructions(), Contains(op::Add()));
+}
+
+TEST_F(WhileLoopInvariantCodeMotionTest,
+       DontHoistLoopVaryingComputationWithAlternatingTuples) {
+  auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
+  Shape while_shape =
+      ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32});
+
+  HloComputation* while_body = [&]() {
+    HloComputation::Builder builder(TestName() + ".while_body");
+    HloInstruction* param = builder.AddInstruction(
+        HloInstruction::CreateParameter(0, while_shape, "param"));
+    HloInstruction* gte_0 = builder.AddInstruction(
+        HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
+    HloInstruction* gte_1 = builder.AddInstruction(
+        HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
+    HloInstruction* add_result =
+        builder.AddInstruction(HloInstruction::CreateBinary(
+            scalar_s32, HloOpcode::kAdd, gte_0, gte_1));
+    builder.AddInstruction(
+        HloInstruction::CreateTuple({gte_1, gte_0, add_result}));
+
+    return module().AddEmbeddedComputation(builder.Build());
+  }();
+
+  HloComputation::Builder builder(TestName());
+  auto* init_value = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, while_shape, "init_value"));
+  auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile(
+      while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
+      while_body, init_value));
+
+  module().AddEntryComputation(builder.Build());
+  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
+                          WhileLoopInvariantCodeMotion{}.Run(&module()));
+  EXPECT_FALSE(simplified_loop);
+
+  EXPECT_THAT(while_inst->while_body()->instructions(), Contains(op::Add()));
+}
+
+TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistInstructionWithSideEffects) {
+  auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
+  Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32});
+
+  HloComputation* while_body = [&]() {
+    HloComputation::Builder builder(TestName() + ".while_body");
+    HloInstruction* param = builder.AddInstruction(
+        HloInstruction::CreateParameter(0, while_shape, "param"));
+    HloInstruction* gte_0 = builder.AddInstruction(
+        HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
+    HloInstruction* gte_1 = builder.AddInstruction(
+        HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
+    builder.AddInstruction(
+        HloInstruction::CreateOutfeed(scalar_s32, gte_0, ""));
+    builder.AddInstruction(HloInstruction::CreateTuple({gte_0, gte_1}));
+
+    return module().AddEmbeddedComputation(builder.Build());
+  }();
+
+  HloComputation::Builder builder(TestName());
+  auto* init_value = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, while_shape, "init_value"));
+  auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile(
+      while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
+      while_body, init_value));
+
+  module().AddEntryComputation(builder.Build());
+
+  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
+                          WhileLoopInvariantCodeMotion{}.Run(&module()));
+  EXPECT_FALSE(simplified_loop);
+
+  EXPECT_THAT(while_inst->while_body()->instructions(),
+              Contains(op::Outfeed()));
+}
+
+TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistBitcastAlone) {
+  // The bitcast's user, an outfeed, can't be hoisted, so don't hoist the
+  // bitcast either.
+  auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
+  auto scalar_f32 = ShapeUtil::MakeShape(F32, {});
+  Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32});
+
+  HloComputation* while_body = [&]() {
+    HloComputation::Builder builder(TestName() + ".while_body");
+    HloInstruction* param = builder.AddInstruction(
+        HloInstruction::CreateParameter(0, while_shape, "param"));
+    HloInstruction* gte_0 = builder.AddInstruction(
+        HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
+    HloInstruction* gte_1 = builder.AddInstruction(
+        HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
+    HloInstruction* bitcast_inst = builder.AddInstruction(
+        HloInstruction::CreateUnary(scalar_f32, HloOpcode::kBitcast, gte_0));
+    builder.AddInstruction(
+        HloInstruction::CreateOutfeed(scalar_f32, bitcast_inst, ""));
+    builder.AddInstruction(HloInstruction::CreateTuple({gte_0, gte_1}));
+
+    return module().AddEmbeddedComputation(builder.Build());
+  }();
+
+  HloComputation::Builder builder(TestName());
+  auto* init_value = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, while_shape, "init_value"));
+  auto* while_inst = builder.AddInstruction(HloInstruction::CreateWhile(
+      while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
+      while_body, init_value));
+
+  module().AddEntryComputation(builder.Build());
+
+  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
+                          WhileLoopInvariantCodeMotion{}.Run(&module()));
+  EXPECT_FALSE(simplified_loop);
+
+  EXPECT_THAT(while_inst->while_body()->instructions(),
+              Contains(op::Outfeed()));
+  EXPECT_THAT(while_inst->while_body()->instructions(),
+              Contains(op::Bitcast()));
+}
+
+TEST_F(WhileLoopInvariantCodeMotionTest, HoistBitcastIfNeeded) {
+  // The bitcast's user can be hoisted, so hoist the bitcast too.
+  auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
+  auto scalar_f32 = ShapeUtil::MakeShape(F32, {});
+  Shape while_shape =
+      ShapeUtil::MakeTupleShape({scalar_s32, scalar_f32, scalar_f32});
+
+  HloComputation* while_body = [&]() {
+    HloComputation::Builder builder(TestName() + ".while_body");
+    HloInstruction* param = builder.AddInstruction(
+        HloInstruction::CreateParameter(0, while_shape, "param"));
+    HloInstruction* gte_0 = builder.AddInstruction(
+        HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
+    HloInstruction* gte_1 = builder.AddInstruction(
+        HloInstruction::CreateGetTupleElement(scalar_f32, param, 1));
+    HloInstruction* bitcast_inst = builder.AddInstruction(
+        HloInstruction::CreateUnary(scalar_f32, HloOpcode::kBitcast, gte_0));
+    HloInstruction* add_inst =
+        builder.AddInstruction(HloInstruction::CreateBinary(
+            scalar_f32, HloOpcode::kAdd, bitcast_inst, gte_1));
+    builder.AddInstruction(
+        HloInstruction::CreateTuple({gte_0, gte_1, add_inst}));
+
+    return module().AddEmbeddedComputation(builder.Build());
+  }();
+
+  HloComputation::Builder builder(TestName());
+  auto* init_value = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, while_shape, "init_value"));
+  builder.AddInstruction(HloInstruction::CreateWhile(
+      while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
+      while_body, init_value));
+
+  HloComputation* entry_computation =
+      module().AddEntryComputation(builder.Build());
+
+  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
+                          WhileLoopInvariantCodeMotion{}.Run(&module()));
+  EXPECT_TRUE(simplified_loop);
+
+  HloInstruction* transformed_while;
+  FindOnlyWhileInstruction(entry_computation, &transformed_while);
+
+  EXPECT_THAT(transformed_while->while_body()->instructions(),
+              Each(Not(op::Add())));
+  EXPECT_THAT(transformed_while->while_body()->instructions(),
+              Each(Not(op::Bitcast())));
+  EXPECT_THAT(entry_computation->instructions(), Contains(op::Add()));
+  EXPECT_THAT(entry_computation->instructions(), Contains(op::Bitcast()));
+}
+
+TEST_F(WhileLoopInvariantCodeMotionTest, DontHoistControlDependencies) {
+  auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
+  Shape while_shape =
+      ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32, scalar_s32});
+
+  HloComputation* while_body;
+  {
+    HloComputation::Builder builder(TestName() + ".while_body");
+    HloInstruction* param = builder.AddInstruction(
+        HloInstruction::CreateParameter(0, while_shape, "param"));
+    HloInstruction* gte_0 = builder.AddInstruction(
+        HloInstruction::CreateGetTupleElement(scalar_s32, param, 0));
+    HloInstruction* gte_1 = builder.AddInstruction(
+        HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
+    HloInstruction* add_result =
+        builder.AddInstruction(HloInstruction::CreateBinary(
+            scalar_s32, HloOpcode::kAdd, gte_0, gte_1));
+    TF_ASSERT_OK(param->AddControlDependencyTo(add_result));
+    builder.AddInstruction(
+        HloInstruction::CreateTuple({gte_0, gte_1, add_result}));
+
+    while_body = module().AddEmbeddedComputation(builder.Build());
+  }
+
+  HloComputation::Builder builder(TestName());
+  auto* init_value = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, while_shape, "init_value"));
+  builder.AddInstruction(HloInstruction::CreateWhile(
+      while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
+      while_body, init_value));
+  module().AddEntryComputation(builder.Build());
+  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
+                          WhileLoopInvariantCodeMotion{}.Run(&module()));
+  EXPECT_FALSE(simplified_loop);
+}
+
+TEST_F(WhileLoopInvariantCodeMotionTest, BodyHasNonTupleRoot) {
+  auto scalar_s32 = ShapeUtil::MakeShape(S32, {});
+  Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32});
+
+  HloComputation* while_body = [&]() {
+    HloComputation::Builder builder(TestName() + ".passthrough");
+    HloInstruction* param = builder.AddInstruction(
+        HloInstruction::CreateParameter(0, while_shape, "param"));
+    HloComputation* result = module().AddEmbeddedComputation(builder.Build());
+
+    result->AddInstruction(
+        HloInstruction::CreateGetTupleElement(scalar_s32, param, 1));
+    return result;
+  }();
+
+  HloComputation::Builder builder(TestName());
+  auto* init_value = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, while_shape, "init_value"));
+  builder.AddInstruction(HloInstruction::CreateWhile(
+      while_shape, MakeAlwaysTrueComputation(while_shape, &module()),
+      while_body, init_value));
+  module().AddEntryComputation(builder.Build());
+  TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop,
+                          WhileLoopInvariantCodeMotion{}.Run(&module()));
+  EXPECT_FALSE(simplified_loop);
+}
+
+}  // namespace
+}  // namespace xla
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc
index b38ee907d70e29093c5cef718e1432663015728b..981de9b2200a9ae8938db21299580f510834d2f0 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc
@@ -236,7 +236,7 @@ static optional<int64> GetLoopTripCount(HloInstruction* while_op) {
       VLOG(2) << "Couldn't evaluate while cond: " << result.status();
       return nullopt;
     }
-    return result.ValueOrDie()->GetArraySlice<bool>() ==
+    return result.ValueOrDie()->data<bool>() ==
            tensorflow::gtl::ArraySlice<bool>{true};
   };
 
@@ -289,7 +289,7 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
   // Don't try this transformation if the while loop isn't removable, since if
   // it succeeds ultimately
we're going to have to replace the old while loop // with a new one. - if (!while_op->parent()->IsRemovable(while_op)) { + if (!while_op->parent()->IsRemovable(while_op) || while_op->HasSideEffect()) { VLOG(2) << "Can't remove dead parameters from non-removable while op."; return false; } @@ -306,6 +306,13 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { return false; } + if (while_body_root->opcode() != HloOpcode::kTuple) { + VLOG(2) << "While body's root is not a tuple(...) instruction."; + return false; + } + + auto print_no_metadata = HloPrintOptions().set_print_metadata(false); + // Bail if param0 of while_cond or while_body has users which aren't of type // get-tuple-element. for (const HloInstruction* instr : {while_body->parameter_instruction(0), @@ -313,9 +320,10 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { for (const HloInstruction* user : instr->users()) { if (user->opcode() != HloOpcode::kGetTupleElement) { VLOG(2) << "Cowardly refusing to analyze while loop with " - << instr->ToStringNoMetadata() - << " used by non-GTE instruction " << user->ToStringNoMetadata() - << " in computation " << instr->parent()->name(); + << instr->ToString(print_no_metadata) + << " used by non-GTE instruction " + << user->ToString(print_no_metadata) << " in computation " + << instr->parent()->name(); return false; } } @@ -351,7 +359,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { used_tuple_indices.insert(user->tuple_index()); if (used_tuple_indices.size() == tuple_size) { - VLOG(2) << "Loop " << while_op->ToStringNoMetadata() + VLOG(2) << "Loop " << while_op->ToString(print_no_metadata) << " uses all of its inputs; no simplification possible."; return false; } @@ -375,7 +383,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { used_tuple_indices.insert(i); if (used_tuple_indices.size() == tuple_size) { - VLOG(2) << "Loop " << while_op->ToStringNoMetadata() + VLOG(2) << "Loop " 
<< while_op->ToString(print_no_metadata) << " uses all of its inputs; no simplification possible."; return false; } @@ -387,7 +395,8 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { CHECK_LT(used_tuple_indices.size(), tuple_size); VLOG(1) << "Eliminating " << tuple_size - used_tuple_indices.size() - << " elements from tuple of " << while_op->ToStringNoMetadata(); + << " elements from tuple of " + << while_op->ToString(print_no_metadata); // Build up maps from the old/new to the new/old tuple indices. std::vector new_to_old_tuple_idx(used_tuple_indices.begin(), @@ -431,7 +440,7 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { continue; } CHECK_EQ(user->opcode(), HloOpcode::kGetTupleElement) - << user->ToStringNoMetadata(); + << user->ToString(print_no_metadata); int64 old_idx = user->tuple_index(); auto new_idx_iter = old_to_new_tuple_idx.find(old_idx); @@ -446,14 +455,14 @@ static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { CHECK(user->user_count() == 0 || user->user_count() == 1 && user->users().front() == while_body_root) - << "Instruction " << user->ToStringNoMetadata() + << "Instruction " << user->ToString(print_no_metadata) << " should be unused (except by root of while body), but has " "users: {" << tensorflow::str_util::Join( user->users(), ", ", - [](string* out, const HloInstruction* instr) { + [&](string* out, const HloInstruction* instr) { tensorflow::strings::StrAppend( - out, instr->ToStringNoMetadata()); + out, instr->ToString(print_no_metadata)); }) << "}"; @@ -555,10 +564,12 @@ static StatusOr TryRemoveWhileLoop(HloInstruction* while_op) { // // This is not a fundamental limitation. The control operands can be moved // onto the new HLOs after simplification, and any side-effecting ops inside - // the loop aren't removed, just cloned and added back to the loop. - // Nevertheless our infrastructure sees loop simplification as removal of - // these nodes and currently doesn't allow it. 
- if (!while_op->parent()->IsRemovable(while_op)) { + // the loop aren't removed, just cloned and added back to the loop. But + // moving an op out of the loop also removes implicit control dependencies + // between the op and the ops outside the loop, so we'd have to add those back + // for things like infeed/outfeed. It gets complicated. So for now we just + // avoid it. + if (!while_op->parent()->IsRemovable(while_op) || while_op->HasSideEffect()) { VLOG(2) << "Not attempting to remove while loop it is not removable: " << while_op->ToShortString(); return false; @@ -586,7 +597,9 @@ static StatusOr TryRemoveWhileLoop(HloInstruction* while_op) { auto call_op = computation->AddInstruction(HloInstruction::CreateCall( while_op->shape(), while_op->operands(), while_op->while_body())); TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, call_op)); - TF_RETURN_IF_ERROR(CallInliner::Inline(call_op)); + TF_ASSIGN_OR_RETURN(auto inlined_instructions_map, + CallInliner::Inline(call_op)); + (void)inlined_instructions_map; return true; } return false; diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.h b/tensorflow/compiler/xla/service/while_loop_simplifier.h index 50dac32a4ab0a5de756c1ddf5e62c3560e54a079..d3d55634c97bbdf3f81321d8089bb808c411340b 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier.h +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_SIMPLIFIER_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_SIMPLIFIER_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_SIMPLIFIER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_SIMPLIFIER_H_ #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" @@ -41,4 +41,4 @@ class WhileLoopSimplifier : public HloPassInterface { } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_SIMPLIFIER_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_SIMPLIFIER_H_ diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc index d99b31dc0037968bc88d5f22d53309a6a4546963..c5183f8d3aee99696ed4114c3f7e451888222137 100644 --- a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -418,5 +418,32 @@ TEST_F(WhileLoopSimplifierTest, RemoveUnusedOperand) { op::GetTupleElement(op::Parameter(0), /*tuple_index=*/1))); } +TEST_F(WhileLoopSimplifierTest, BodyHasNonTupleRoot) { + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + Shape while_shape = ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}); + + HloComputation* while_body = [&]() { + HloComputation::Builder builder(TestName() + ".passthrough"); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "param")); + HloComputation* result = module().AddEmbeddedComputation(builder.Build()); + + result->AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)); + return result; + }(); + + HloComputation::Builder builder(TestName()); + auto* init_value = builder.AddInstruction( + HloInstruction::CreateParameter(0, while_shape, "init_value")); + 
builder.AddInstruction(HloInstruction::CreateWhile( + while_shape, MakeAlwaysTrueComputation(while_shape, &module()), + while_body, init_value)); + module().AddEntryComputation(builder.Build()); + TF_ASSERT_OK_AND_ASSIGN(bool simplified_loop, + WhileLoopSimplifier{}.Run(&module())); + EXPECT_FALSE(simplified_loop); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/while_util.cc b/tensorflow/compiler/xla/service/while_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..e20b25e4a08a946f6b58575a4d4e557744f8035c --- /dev/null +++ b/tensorflow/compiler/xla/service/while_util.cc @@ -0,0 +1,140 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/while_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/tuple_util.h" + +namespace xla { + +static StatusOr WidenWhileCondition( + HloComputation* narrow_condition, const Shape& wide_shape) { + const Shape& narrow_shape = + narrow_condition->parameter_instruction(0)->shape(); + + HloComputation* wide_while_cond = [&]() { + HloComputation::Builder builder( + tensorflow::strings::StrCat("wide.", narrow_condition->name())); + builder.AddInstruction( + HloInstruction::CreateParameter(0, wide_shape, "wide_param")); + + // This is needed so that the root instruction is shaped as a PRED[] -- we + // need to get this right to begin with since we can't mutate the type of + // the root instruction later. We later change the root instruction to + // something more appropriate. + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + return narrow_condition->parent()->AddEmbeddedComputation(builder.Build()); + }(); + + HloInstruction* truncated_parameter = + TupleUtil::ExtractPrefix(wide_while_cond->parameter_instruction(0), + narrow_shape.tuple_shapes_size()); + HloInstruction* call_narrow_cond = wide_while_cond->AddInstruction( + HloInstruction::CreateCall(ShapeUtil::MakeShape(PRED, {}), + {truncated_parameter}, narrow_condition)); + + wide_while_cond->set_root_instruction(call_narrow_cond); + + TF_RETURN_IF_ERROR(CallInliner::Inline(call_narrow_cond).status()); + return wide_while_cond; +} + +static StatusOr> +WidenWhileBody(HloComputation* narrow_body, const Shape& wide_shape) { + const Shape& narrow_shape = narrow_body->parameter_instruction(0)->shape(); + + HloComputation* wide_while_body = [&]() { + HloComputation::Builder builder( + tensorflow::strings::StrCat("wide.", narrow_body->name())); + builder.AddInstruction( + 
HloInstruction::CreateParameter(0, wide_shape, "wide_param")); + return narrow_body->parent()->AddEmbeddedComputation(builder.Build()); + }(); + + HloInstruction* wide_parameter = wide_while_body->parameter_instruction(0); + HloInstruction* truncated_parameter = TupleUtil::ExtractPrefix( + wide_parameter, narrow_shape.tuple_shapes_size()); + HloInstruction* call_narrow_body = + wide_while_body->AddInstruction(HloInstruction::CreateCall( + narrow_shape, {truncated_parameter}, narrow_body)); + + std::vector live_through_values; + for (int i = narrow_shape.tuple_shapes_size(); + i < wide_shape.tuple_shapes_size(); i++) { + live_through_values.push_back( + wide_while_body->AddInstruction(HloInstruction::CreateGetTupleElement( + wide_shape.tuple_shapes(i), wide_parameter, i))); + } + + wide_while_body->set_root_instruction( + TupleUtil::AppendSuffix(call_narrow_body, live_through_values)); + + TF_ASSIGN_OR_RETURN(auto inlined_instructions_map, + CallInliner::Inline(call_narrow_body)); + return {{wide_while_body, std::move(inlined_instructions_map)}}; +} + +/*static*/ StatusOr +WhileUtil::MakeInstructionsLiveIn( + HloInstruction* while_instr, + tensorflow::gtl::ArraySlice instructions) { + CHECK(ShapeUtil::IsTuple(while_instr->shape())); + + int64 elements_in_old_while_shape = while_instr->shape().tuple_shapes_size(); + Shape new_while_shape = while_instr->shape(); + for (auto* instruction : instructions) { + *new_while_shape.add_tuple_shapes() = instruction->shape(); + } + + TF_ASSIGN_OR_RETURN( + HloComputation * new_while_condition, + WidenWhileCondition(while_instr->while_condition(), new_while_shape)); + + HloComputation* new_while_body; + CallInliner::InlinedInstructionMap inlined_instructions_map; + TF_ASSIGN_OR_RETURN( + std::tie(new_while_body, inlined_instructions_map), + WidenWhileBody(while_instr->while_body(), new_while_shape)); + + HloInstruction* new_while_init = + TupleUtil::AppendSuffix(while_instr->mutable_operand(0), instructions); + HloComputation* 
containing_computation = while_instr->parent(); + HloInstruction* new_while = containing_computation->AddInstruction( + HloInstruction::CreateWhile(new_while_shape, new_while_condition, + new_while_body, new_while_init)); + TF_RETURN_IF_ERROR(containing_computation->ReplaceInstruction( + while_instr, TupleUtil::ExtractPrefix( + new_while, while_instr->shape().tuple_shapes_size()))); + + HloInstruction* while_body_param = new_while_body->parameter_instruction(0); + std::vector live_in_instructions; + for (int64 i = elements_in_old_while_shape; + i < new_while_shape.tuple_shapes_size(); i++) { + live_in_instructions.push_back( + new_while_body->AddInstruction(HloInstruction::CreateGetTupleElement( + instructions[i - elements_in_old_while_shape]->shape(), + while_body_param, i))); + } + + WhileUtil::MakeInstructionsLiveInResult result; + + result.new_while_instr = new_while; + result.while_body_live_in_values = std::move(live_in_instructions); + result.while_body_instruction_map = std::move(inlined_instructions_map); + + return std::move(result); +} +} // namespace xla diff --git a/tensorflow/compiler/xla/service/while_util.h b/tensorflow/compiler/xla/service/while_util.h new file mode 100644 index 0000000000000000000000000000000000000000..3600b5a80d26e37fdb7d5173c3b8743734306390 --- /dev/null +++ b/tensorflow/compiler/xla/service/while_util.h @@ -0,0 +1,58 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_UTIL_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_UTIL_H_ + +#include "tensorflow/compiler/xla/service/call_inliner.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" + +namespace xla { +class WhileUtil { + public: + // Holds a return value from MakeInstructionsLiveIn. + struct MakeInstructionsLiveInResult { + // The new while operation that has the requested values live in. + HloInstruction* new_while_instr; + + // The i'th element of `while_body_live_in_values` is an instruction in the + // while body that holds the i'th *newly added* live in value at runtime. + std::vector while_body_live_in_values; + + // `while_body_instruction_map` maps instructions in the original while body + // to the corresponding instructions in the body for the newly created while + // operation. + CallInliner::InlinedInstructionMap while_body_instruction_map; + }; + + // Replaces `while_instr` with a new while instruction that is equivalent to + // `while_instr`, except that it has all of the HLO instructions in + // `instructions` as live-in, loop invariant values. These new live in values + // are represented as new elements appended to the parameter of the while + // loop, which must be of tuple shape. GetTupleElement instructions computing + // each new live in value is returned in the `while_body_live_in_values` + // vector. + // + // Precondition: `while_instr` must have a tuple shaped state. + // + // Every instruction in `instructions` must be contained in the computation + // that contains `while_instr`. 
+ static StatusOr MakeInstructionsLiveIn( + HloInstruction* while_instr, + tensorflow::gtl::ArraySlice instructions); +}; +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_UTIL_H_ diff --git a/tensorflow/compiler/xla/service/while_util_test.cc b/tensorflow/compiler/xla/service/while_util_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..cf0d0db99bd92b6b364b4e28e56a0902d4065963 --- /dev/null +++ b/tensorflow/compiler/xla/service/while_util_test.cc @@ -0,0 +1,130 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/while_util.h" + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" + +namespace xla { +namespace { + +namespace op = ::xla::testing::opcode_matchers; + +StatusOr<std::unique_ptr<HloModule>> GetParsedModule( + HloComputation** entry_computation, HloInstruction** param0, + HloInstruction** param1, HloInstruction** param2) { + const char* const hlo_string = R"( +HloModule ModuleWithWhile + +while_body { + ROOT p_body = (f32[32,32]{1,0}, f32[32,32]{1,0}) parameter(0) +} + +while_condition { + p_cond = f32[32,32]{1,0} parameter(0) + ROOT result = pred[] constant(true) +} + +ENTRY entry { + p_entry_0 = f32[32,32]{1,0} parameter(0) + p_entry_1 = s32[32,32]{1,0} parameter(1) + p_entry_2 = s64[32,32]{1,0} parameter(2) + while_init = (f32[32,32]{1,0}, f32[32,32]{1,0}) tuple(p_entry_0, p_entry_0) + ROOT while = (f32[32,32]{1,0}, f32[32,32]{1,0}) while(while_init), condition=while_condition, body=while_body +} +)"; + + TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module, + tools::Parse(hlo_string)); + + *entry_computation = module->entry_computation(); + *param0 = (*entry_computation)->parameter_instruction(0); + *param1 = (*entry_computation)->parameter_instruction(1); + *param2 = (*entry_computation)->parameter_instruction(2); + + return std::move(module); +} + +TEST(WhileUtil, MakeZeroInstructionsLiveOp) { + HloInstruction *param0, *param1, *param2; + HloComputation* entry_computation; + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<HloModule> module, + GetParsedModule(&entry_computation, &param0, &param1, &param2)); + + HloInstruction* while_instr = entry_computation->root_instruction(); + ASSERT_EQ(while_instr->opcode(), HloOpcode::kWhile); + + TF_ASSERT_OK_AND_ASSIGN( + WhileUtil::MakeInstructionsLiveInResult make_live_in_result, + WhileUtil::MakeInstructionsLiveIn(while_instr, 
+ /*instructions=*/{})); + + HloInstruction* new_while_instr = make_live_in_result.new_while_instr; + + EXPECT_THAT( + entry_computation->root_instruction(), + op::Tuple(op::GetTupleElement(::testing::Eq(new_while_instr), 0), + op::GetTupleElement(::testing::Eq(new_while_instr), 1))); + + auto param_reconstructed = + op::Tuple(op::GetTupleElement(op::Parameter(0), 0), + op::GetTupleElement(op::Parameter(0), 1)); + + EXPECT_THAT(new_while_instr->while_body()->root_instruction(), + op::Tuple(op::GetTupleElement(param_reconstructed, 0), + op::GetTupleElement(param_reconstructed, 1))); +} + +TEST(WhileUtilTest, MakeTwoInstructionsLive) { + HloInstruction *param0, *param1, *param2; + HloComputation* entry_computation; + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr<HloModule> module, + GetParsedModule(&entry_computation, &param0, &param1, &param2)); + + HloInstruction* while_instr = entry_computation->root_instruction(); + ASSERT_EQ(while_instr->opcode(), HloOpcode::kWhile); + + TF_ASSERT_OK_AND_ASSIGN( + WhileUtil::MakeInstructionsLiveInResult make_live_in_result, + WhileUtil::MakeInstructionsLiveIn(while_instr, + /*instructions=*/{param0, param1})); + + HloInstruction* new_while_instr = make_live_in_result.new_while_instr; + + XLA_VLOG_LINES(3, module->ToString()); + + EXPECT_THAT( + entry_computation->root_instruction(), + op::Tuple(op::GetTupleElement(::testing::Eq(new_while_instr), 0), + op::GetTupleElement(::testing::Eq(new_while_instr), 1))); + + auto first_half_param_reconstructed = + op::Tuple(op::GetTupleElement(op::Parameter(0), 0), + op::GetTupleElement(op::Parameter(0), 1)); + + EXPECT_THAT(new_while_instr->while_body()->root_instruction(), + op::Tuple(op::GetTupleElement(first_half_param_reconstructed, 0), + op::GetTupleElement(first_half_param_reconstructed, 1), + op::GetTupleElement(op::Parameter(0), 2), + op::GetTupleElement(op::Parameter(0), 3))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc 
b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc new file mode 100644 index 0000000000000000000000000000000000000000..aa40b5cb264803097f52966d6f61f1f41b6b3017 --- /dev/null +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.cc @@ -0,0 +1,50 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { + +StatusOr ZeroSizedHloElimination::Run(HloModule* module) { + bool changed = false; + for (HloComputation* comp : module->MakeNonfusionComputations()) { + for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) { + if (instruction->HasSideEffect() || + ShapeUtil::IsTuple(instruction->shape())) { + continue; + } + if (comp->IsRemovable(instruction) && + ShapeUtil::HasZeroElements(instruction->shape())) { + TF_RETURN_IF_ERROR(comp->ReplaceWithNewInstruction( + instruction, 
HloInstruction::CreateConstant( + Literal::CreateFromShape(instruction->shape())))); + changed = true; + } + } + } + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding.h b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h similarity index 70% rename from tensorflow/compiler/xla/service/gpu/convolution_folding.h rename to tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h index f9c898721f8dd6b8b7e74c82bb2085cc437eaad5..063e312df66ce9cba0fa9f49c2fc6026ba6b74aa 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_folding.h +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h @@ -13,25 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONVOLUTION_FOLDING_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONVOLUTION_FOLDING_H_ +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_ZERO_SIZED_HLO_ELIMINATION_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_ZERO_SIZED_HLO_ELIMINATION_H_ #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +// HLO pass that replaces zero sized Hlos with a zero sized constant literal. 
namespace xla { -namespace gpu { - -class ConvolutionFolding : public HloPassInterface { +class ZeroSizedHloElimination : public HloPassInterface { public: + StatusOr Run(HloModule* module) override; tensorflow::StringPiece name() const override { - return "convolution-folding"; + return "zero_sized_hlo_elimination"; } - - StatusOr Run(HloModule* module) override; }; - -} // namespace gpu } // namespace xla - -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CONVOLUTION_FOLDING_H_ +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_ZERO_SIZED_HLO_ELIMINATION_H_ diff --git a/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4f8cdc1e0e73cdaa8675fc945ba3dbe19ce3da7d --- /dev/null +++ b/tensorflow/compiler/xla/service/zero_sized_hlo_elimination_test.cc @@ -0,0 +1,77 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/zero_sized_hlo_elimination.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/shape_inference.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { +namespace { +class ZeroSizedHloEliminationTest : public HloTestBase { + protected: + ZeroSizedHloEliminationTest() + : HloTestBase(), + builder_("zero_sized_computation"), + zero_sized_param_( + builder_.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {3, 0}), "zero sized param"))) {} + + StatusOr RunZeroSizedElimination() { + HloModule module("zero_sized_elimination_test_module"); + module.AddEntryComputation(builder_.Build()); + return ZeroSizedHloElimination{}.Run(&module); + } + + HloComputation::Builder builder_; + HloInstruction* zero_sized_param_; +}; + +TEST_F(ZeroSizedHloEliminationTest, EliminatedZeroSizedOp) { + builder_.AddInstruction(HloInstruction::CreateUnary( + zero_sized_param_->shape(), HloOpcode::kTanh, zero_sized_param_)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunZeroSizedElimination()); + EXPECT_TRUE(changed); +} + +TEST_F(ZeroSizedHloEliminationTest, DoesNotEliminateParameter) { + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunZeroSizedElimination()); + 
EXPECT_FALSE(changed); +} + +TEST_F(ZeroSizedHloEliminationTest, DoesNotEliminateSideEffects) { + builder_.AddInstruction(HloInstruction::CreateSend(zero_sized_param_, 0)); + TF_ASSERT_OK_AND_ASSIGN(bool changed, RunZeroSizedElimination()); + EXPECT_FALSE(changed); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/shape_layout.cc b/tensorflow/compiler/xla/shape_layout.cc index 5bf9842a6ce7be747f58c10f302f85c6f82ac6f9..789eba5780d37e1fd4d80ec881855951c8bba0eb 100644 --- a/tensorflow/compiler/xla/shape_layout.cc +++ b/tensorflow/compiler/xla/shape_layout.cc @@ -32,13 +32,13 @@ tensorflow::Status ShapeLayout::CopyLayoutFromShape(const Shape& other_shape) { return tensorflow::Status::OK(); } -tensorflow::Status ShapeLayout::AssignLayoutToShape(Shape* other_shape) const { - if (!ShapeUtil::Compatible(*other_shape, shape_)) { +tensorflow::Status ShapeLayout::AssignLayoutToShape(Shape* to_shape) const { + if (!ShapeUtil::Compatible(*to_shape, shape_)) { return InvalidArgument("Shape %s is not compatible with shape %s", - ShapeUtil::HumanString(*other_shape).c_str(), + ShapeUtil::HumanString(*to_shape).c_str(), ShapeUtil::HumanString(shape()).c_str()); } - *other_shape = shape_; + *to_shape = shape_; return tensorflow::Status::OK(); } diff --git a/tensorflow/compiler/xla/shape_layout.h b/tensorflow/compiler/xla/shape_layout.h index 92564660f21bf1b596c4b9ca04c07eaca27ed192..4c83750f3e6f3c735db66d8e0b86ae3f43e5ca11 100644 --- a/tensorflow/compiler/xla/shape_layout.h +++ b/tensorflow/compiler/xla/shape_layout.h @@ -38,18 +38,19 @@ class ShapeLayout { explicit ShapeLayout(const Shape& shape) : shape_(shape) {} // Assigns the layouts in this ShapeLayout to the Layout fields of the given - // shape. 'shape' and the shape of the ShapeLayout object must be compatible. - tensorflow::Status AssignLayoutToShape(Shape* shape) const; + // shape. 'to_shape' and the shape of the ShapeLayout object must be + // compatible. 
+ tensorflow::Status AssignLayoutToShape(Shape* to_shape) const; // Returns true if the Layouts in this ShapeLayout match the layouts in the // given shape. Returns false otherwise. If the given shape is not compatible // with the ShapeLayout's shape, then false is returned. bool MatchesLayoutInShape(const Shape& shape) const; - // Copies the layout from the given shape into this ShapeLayout. 'shape' must - // be compatible with the ShapeLayout's shape, and 'shape' must have a layout - // (LayoutUtil::HasLayout). - tensorflow::Status CopyLayoutFromShape(const Shape& shape); + // Copies the layout from the given shape into this ShapeLayout. 'other_shape' + // must be compatible with the ShapeLayout's shape, and 'other_shape' must + // have a layout (LayoutUtil::HasLayout). + tensorflow::Status CopyLayoutFromShape(const Shape& other_shape); // Clears (Layout::Clear) all the Layouts stored in this object. void Clear(); diff --git a/tensorflow/compiler/xla/shape_tree.h b/tensorflow/compiler/xla/shape_tree.h index bf8d19015079f2ce0bd450594040ed818f94b66b..d752619bd65751779c24f061e44e206d66b01465 100644 --- a/tensorflow/compiler/xla/shape_tree.h +++ b/tensorflow/compiler/xla/shape_tree.h @@ -238,7 +238,7 @@ class ShapeTree { // (or compatible). // index : the index of the element in the shape. See ShapeUtil::GetSubshape // for definition of index. - // data : The data value at this elemnt. + // data : The data value at this element. template void ForEachElement(const Fn& func) const; diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 74fa0b2f2e740310be23661caef3f19e24e4087b..604e0173e789348923316174873f58058eaf2815 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include #include #include +#include #include #include @@ -58,36 +59,47 @@ std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index) { return out; } +std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index) { + out << shape_index.ToString(); + return out; +} + namespace { // Recursive helper for comparing the equality of two shapes. Returns true if // the shapes are the same. If compare_layouts is true, then layouts must also // match. bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { - if (ShapeUtil::IsTuple(lhs)) { - return ShapeUtil::IsTuple(rhs) && + if (ShapeUtil::IsTuple(lhs) || ShapeUtil::IsTuple(rhs)) { + return ShapeUtil::IsTuple(lhs) && ShapeUtil::IsTuple(rhs) && ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), [=](const Shape& l, const Shape& r) { return CompareShapes(l, r, compare_layouts); }); + } else if (ShapeUtil::IsOpaque(lhs) || ShapeUtil::IsOpaque(rhs)) { + return ShapeUtil::IsOpaque(lhs) && ShapeUtil::IsOpaque(rhs); } - // Explicitly compare the fields rather than using MessageDifferencer because - // we want empty layouts to be treated identically to missing layouts. 
+ if (compare_layouts) { - if (!ContainersEqual(lhs.layout().minor_to_major(), - rhs.layout().minor_to_major())) { - VLOG(3) << "CompareShapes: lhs layout != rhs layout"; + if (lhs.layout().format() != rhs.layout().format()) { return false; } - if (!ContainersEqual(lhs.layout().padded_dimensions(), - rhs.layout().padded_dimensions())) { - VLOG(3) - << "CompareShapes: lhs padded_dimensions != rhs padded_dimensions"; - return false; - } - if (lhs.layout().padding_value() != rhs.layout().padding_value()) { - VLOG(3) << "CompareShapes: lhs padding value != rhs padding_value"; - return false; + if (LayoutUtil::IsDenseArray(lhs)) { + if (!ContainersEqual(LayoutUtil::MinorToMajor(lhs), + LayoutUtil::MinorToMajor(rhs))) { + VLOG(3) << "CompareShapes: lhs layout != rhs layout"; + return false; + } + if (!ContainersEqual(lhs.layout().padded_dimensions(), + rhs.layout().padded_dimensions())) { + VLOG(3) + << "CompareShapes: lhs padded_dimensions != rhs padded_dimensions"; + return false; + } + if (lhs.layout().padding_value() != rhs.layout().padding_value()) { + VLOG(3) << "CompareShapes: lhs padding value != rhs padding_value"; + return false; + } } } @@ -141,7 +153,8 @@ StatusOr MakeShapeWithLayoutInternal( } /* static */ int64 ShapeUtil::Rank(const Shape& shape) { - CHECK(!ShapeUtil::IsTuple(shape)) << "Tuples do not have a rank"; + CHECK(!ShapeUtil::IsTuple(shape)) + << "Tuples do not have a rank, shape: " << shape; return shape.dimensions_size(); } @@ -182,20 +195,32 @@ StatusOr MakeShapeWithLayoutInternal( .ValueOrDie(); } -/* static */ Shape ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( +/* static */ Shape ShapeUtil::MakeShapeWithDescendingLayout( PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions) { std::vector layout(dimensions.size()); std::iota(layout.rbegin(), layout.rend(), static_cast(0)); return MakeShapeWithLayout(element_type, dimensions, layout); } -/* static */ Shape ShapeUtil::NormalizeShapeToMonotonicDim0MajorLayout( +/* static */ 
Shape ShapeUtil::MakeShapeWithSparseLayout( + PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, + int64 max_sparse_elements) { + DCHECK_NE(TUPLE, element_type); + DCHECK_NE(OPAQUE, element_type); + Shape shape = ShapeUtil::MakeShape(element_type, dimensions); + *shape.mutable_layout() = LayoutUtil::MakeSparseLayout(max_sparse_elements); + TF_DCHECK_OK(ShapeUtil::ValidateShape(shape)); + return shape; +} + +/* static */ Shape +ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout( const Shape& shape) { std::vector dims(shape.dimensions_size()); for (int i = 0; i < shape.dimensions_size(); ++i) { dims[i] = shape.dimensions(LayoutUtil::Major(shape.layout(), i)); } - return MakeShapeWithMonotonicDim0MajorLayout(shape.element_type(), dims); + return MakeShapeWithDescendingLayout(shape.element_type(), dims); } /* static */ void ShapeUtil::PopulateShape( @@ -235,6 +260,7 @@ StatusOr MakeShapeWithLayoutInternal( } /* static */ void ShapeUtil::AppendMajorDimension(int bound, Shape* shape) { + CHECK(LayoutUtil::IsDenseArray(*shape)); shape->mutable_layout()->add_minor_to_major(Rank(*shape)); shape->add_dimensions(bound); TF_DCHECK_OK(ValidateShape(*shape)); @@ -329,6 +355,14 @@ StatusOr MakeShapeWithLayoutInternal( return MakeTupleShape(new_elements); } +// Returns the shape of a real or imaginary component. 
+/* static */ Shape ShapeUtil::ComplexComponentShape( + const Shape& complex_shape) { + CHECK(ElementIsComplex(complex_shape)) << HumanString(complex_shape); + return ChangeElementType(complex_shape, primitive_util::ComplexComponentType( + complex_shape.element_type())); +} + /* static */ bool ShapeUtil::ShapeIs(const Shape& shape, PrimitiveType element_type, std::initializer_list dimensions) { @@ -336,7 +370,7 @@ StatusOr MakeShapeWithLayoutInternal( } /* static */ int64 ShapeUtil::ElementsIn(const Shape& shape) { - CHECK(!IsTuple(shape)); + CHECK(!IsTuple(shape)) << ShapeUtil::HumanString(shape); CHECK_EQ(shape.dimensions_size(), Rank(shape)); return std::accumulate( shape.dimensions().begin(), shape.dimensions().end(), 1LL, @@ -352,7 +386,7 @@ StatusOr MakeShapeWithLayoutInternal( } /* static */ string ShapeUtil::HumanString(const Shape& shape) { - if (shape.element_type() == TUPLE) { + if (IsTuple(shape)) { string text = "("; const char* prefix = ""; for (const Shape& elem_shape : shape.tuple_shapes()) { @@ -396,10 +430,30 @@ const string& LowercasePrimitiveTypeName(PrimitiveType s) { static PrimitiveTypeNameGenerator* gen = new PrimitiveTypeNameGenerator(); return gen->LowercaseName(s); } + +StatusOr StringToPrimitiveType(const string& name) { + static std::unordered_map* name_to_type = [] { + static auto* map = new std::unordered_map; + for (int i = 0; i < PrimitiveType_ARRAYSIZE; i++) { + if (PrimitiveType_IsValid(i)) { + auto value = static_cast(i); + (*map)[LowercasePrimitiveTypeName(value)] = value; + } + } + return map; + }(); + auto found = name_to_type->find(name); + if (found == name_to_type->end()) { + return InvalidArgument("Invalid element type string: \"%s\".", + name.c_str()); + } + return found->second; +} + } // namespace /* static */ string ShapeUtil::HumanStringWithLayout(const Shape& shape) { - if (shape.element_type() == TUPLE) { + if (IsTuple(shape)) { string text = "("; const char* prefix = ""; for (const Shape& elem_shape : 
shape.tuple_shapes()) { @@ -421,8 +475,6 @@ const string& LowercasePrimitiveTypeName(PrimitiveType s) { if (LayoutUtil::HasLayout(shape)) { tensorflow::strings::StrAppend(&result, LayoutUtil::HumanString(shape.layout())); - } else { - tensorflow::strings::StrAppend(&result, "{no layout}"); } } return result; @@ -470,26 +522,35 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { string element_type_string; string dimensions_string; + string format_string; string layout_string; // tensorflow::StringPiece is not compatible with internal RE2 StringPiece, so // we convert in to the RE2-consumable type and then consume the corresponding // amount from our StringPiece type. tensorflow::RegexpStringPiece s_consumable(s->data(), s->size()); - if (RE2::Consume(&s_consumable, - "^(\\w*\\d*)\\[([\\d,]*)\\](?:\\s*{([\\d,]*)})?", - &element_type_string, &dimensions_string, &layout_string)) { + if (RE2::Consume( + &s_consumable, + "^(\\w*\\d*)\\[([\\d,]*)\\](?:\\s*(dense|sparse)?\\s*{([\\d,]+)})?", + &element_type_string, &dimensions_string, &format_string, + &layout_string)) { size_t consumed = s->size() - s_consumable.size(); s->remove_prefix(consumed); + auto string_to_int64 = [&s](const string& input) -> StatusOr { + int64 element; + if (!tensorflow::strings::safe_strto64(input.c_str(), &element)) { + return InvalidArgument( + "Invalid s64 value in parsed shape string: \"%s\" in \"%s\"", + input.c_str(), s->ToString().c_str()); + } + return element; + }; + auto comma_list_to_int64s = - [&s](const string& input) -> StatusOr> { + [&s, + string_to_int64](const string& input) -> StatusOr> { std::vector results; for (const string& piece : tensorflow::str_util::Split(input, ',')) { - int64 element; - if (!tensorflow::strings::safe_strto64(piece.c_str(), &element)) { - return InvalidArgument( - "Invalid s64 value in parsed shape string: \"%s\" in \"%s\"", - piece.c_str(), s->ToString().c_str()); - } + TF_ASSIGN_OR_RETURN(int64 element, string_to_int64(piece)); 
results.push_back(element); } return results; @@ -500,31 +561,32 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { comma_list_to_int64s(dimensions_string)); // Extract the primitive element type. - PrimitiveType primitive_type = PRIMITIVE_TYPE_INVALID; - for (PrimitiveType i = - static_cast(PRIMITIVE_TYPE_INVALID + 1); - i < TUPLE; i = static_cast(i + 1)) { - if (tensorflow::str_util::Lowercase(PrimitiveType_Name(i)) == - element_type_string) { - primitive_type = i; - break; - } - } - if (primitive_type == PRIMITIVE_TYPE_INVALID) { + TF_ASSIGN_OR_RETURN(const PrimitiveType primitive_type, + StringToPrimitiveType(element_type_string)); + if (primitive_type == PRIMITIVE_TYPE_INVALID || primitive_type == TUPLE || + primitive_type == OPAQUE) { return InvalidArgument("Invalid element type string: \"%s\".", element_type_string.c_str()); } Shape result; - if (layout_string.empty()) { + if (format_string.empty() && layout_string.empty()) { // Create a shape without a layout set. result = ShapeUtil::MakeShape(primitive_type, dimensions); - } else { + } else if (format_string == "sparse") { + TF_ASSIGN_OR_RETURN(int64 max_elements, string_to_int64(layout_string)); + result = ShapeUtil::MakeShapeWithSparseLayout(primitive_type, dimensions, + max_elements); + } else if (format_string.empty() || format_string == "dense") { // Extract the layout minor-to-major and set it. TF_ASSIGN_OR_RETURN(std::vector min2maj, comma_list_to_int64s(layout_string)); TF_ASSIGN_OR_RETURN(result, MakeShapeWithLayoutInternal( primitive_type, dimensions, min2maj)); + } else { + // This should not be reached. 
+ LOG(FATAL) << "Unhandled condition when parsing shape; format: \"" + << format_string << "\", layout: \"" << layout_string << "\""; } TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(result)); return std::move(result); @@ -537,7 +599,12 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { /* static */ StatusOr ShapeUtil::ParseShapeString( tensorflow::StringPiece s) { - return ParseShapeStringInternal(&s); + TF_ASSIGN_OR_RETURN(Shape shape, ParseShapeStringInternal(&s)); + if (!s.empty()) { + return InvalidArgument("Invalid shape string to parse: \"%s\"", + s.ToString().c_str()); + } + return shape; } /* static */ bool ShapeUtil::SameDimensions(const Shape& lhs, @@ -563,6 +630,19 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { return SameDimensions(lhs, rhs); } +/* static */ bool ShapeUtil::CompatibleIgnoringFpPrecision(const Shape& lhs, + const Shape& rhs) { + if (lhs.element_type() == TUPLE) { + return rhs.element_type() == TUPLE && + ContainersEqual(lhs.tuple_shapes(), rhs.tuple_shapes(), + CompatibleIgnoringFpPrecision); + } + if (SameElementTypeIgnoringFpPrecision(lhs, rhs)) { + return CompatibleIgnoringElementType(lhs, rhs); + } + return false; +} + /* static */ int64 ShapeUtil::GetDimension(const Shape& shape, int64 dimension_number) { return shape.dimensions(GetDimensionNumber(shape, dimension_number)); @@ -622,23 +702,55 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { TF_DCHECK_OK(ValidateShape(shape)); DCHECK_NE(OPAQUE, shape.element_type()); if (shape.element_type() == TUPLE) { - CHECK_GT(pointer_size, 0); - return pointer_size * shape.tuple_shapes_size(); + return ByteSizeOfTupleIndexTable(shape, pointer_size); } + int64 byte_size = ByteSizeOfElements(shape); + if (LayoutUtil::IsSparseArray(shape)) { + byte_size += ByteSizeOfSparseIndices(shape); + } + return byte_size; +} + +/* static */ int64 ShapeUtil::ByteSizeOfTupleIndexTable(const Shape& shape, + int64 pointer_size) { + 
TF_DCHECK_OK(ValidateShape(shape)); + DCHECK_EQ(TUPLE, shape.element_type()); + CHECK_GT(pointer_size, 0); + return pointer_size * shape.tuple_shapes_size(); +} + +/* static */ int64 ShapeUtil::ByteSizeOfElements(const Shape& shape) { + TF_DCHECK_OK(ValidateShape(shape)); + DCHECK(ShapeUtil::IsArray(shape)); int64 allocated_element_count; - if (shape.layout().padded_dimensions_size() > 0) { - CHECK_EQ(Rank(shape), shape.layout().padded_dimensions_size()); - allocated_element_count = 1; - for (int64 dimension_size : shape.layout().padded_dimensions()) { - allocated_element_count *= dimension_size; - } + + if (LayoutUtil::IsSparseArray(shape)) { + allocated_element_count = LayoutUtil::MaxSparseElements(shape.layout()); } else { - allocated_element_count = ElementsIn(shape); + CHECK(LayoutUtil::IsDenseArray(shape)); + tensorflow::gtl::ArraySlice padded_dimensions = + LayoutUtil::PaddedDimensions(shape); + if (!padded_dimensions.empty()) { + CHECK_EQ(Rank(shape), padded_dimensions.size()); + allocated_element_count = 1; + for (int64 dimension_size : padded_dimensions) { + allocated_element_count *= dimension_size; + } + } else { + allocated_element_count = ElementsIn(shape); + } } return allocated_element_count * ByteSizeOfPrimitiveType(shape.element_type()); } +/* static */ int64 ShapeUtil::ByteSizeOfSparseIndices(const Shape& shape) { + TF_DCHECK_OK(ValidateShape(shape)); + DCHECK(LayoutUtil::IsSparseArray(shape)); + return LayoutUtil::MaxSparseElements(shape.layout()) * + ShapeUtil::Rank(shape) * sizeof(int64); +} + /* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal( const Shape& shape) { if (shape.element_type() == TUPLE) { @@ -694,9 +806,9 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { return LayoutUtil::ValidateLayoutInShape(shape); } -/* static */ Shape ShapeUtil::ChangeElementType(const Shape& shape, +/* static */ Shape ShapeUtil::ChangeElementType(const Shape& original, PrimitiveType type) { - Shape new_shape = shape; + 
Shape new_shape = original; new_shape.set_element_type(type); return new_shape; } @@ -705,7 +817,8 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { ShapeIndexView index) { const Shape* return_shape = &shape; for (auto i : index) { - CHECK(IsTuple(*return_shape)); + CHECK(IsTuple(*return_shape)) + << "Invalid index " << index << " for shape " << shape; return_shape = &return_shape->tuple_shapes(i); } return *return_shape; @@ -863,7 +976,9 @@ Status ForEachMutableSubshapeHelper( new_shape.add_dimensions(dim); } if (shape.has_layout()) { + CHECK(LayoutUtil::IsDenseArray(shape)); Layout* new_layout = new_shape.mutable_layout(); + new_layout->set_format(DENSE); new_layout->clear_minor_to_major(); for (auto index : Permute(permutation, shape.layout().minor_to_major())) { new_layout->add_minor_to_major(index); @@ -1117,9 +1232,9 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, // as input_shape/output_shape and the dimension-0-major layout. These two // shapes are used for conversion between logical linear indices and // multi-dimensional indices. 
- Shape input_shape_dim0_major = MakeShapeWithMonotonicDim0MajorLayout( + Shape input_shape_dim0_major = MakeShapeWithDescendingLayout( input_shape.element_type(), AsInt64Slice(input_shape.dimensions())); - Shape output_shape_dim0_major = MakeShapeWithMonotonicDim0MajorLayout( + Shape output_shape_dim0_major = MakeShapeWithDescendingLayout( output_shape.element_type(), AsInt64Slice(output_shape.dimensions())); for (int64 input_dim = 0; input_dim < Rank(input_shape); ++input_dim) { @@ -1290,6 +1405,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, shape.mutable_dimensions()->erase(shape.dimensions().begin() + dim_to_delete); if (LayoutUtil::HasLayout(shape)) { Layout* layout = shape.mutable_layout(); + layout->set_format(DENSE); for (size_t i = 0; i < layout->minor_to_major().size();) { if (layout->minor_to_major(i) == dim_to_delete) { layout->mutable_minor_to_major()->erase( @@ -1319,4 +1435,9 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape, return shape; } +std::ostream& operator<<(std::ostream& out, const Shape& shape) { + out << ShapeUtil::HumanString(shape); + return out; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 2ea1bd95cb571134ab1e1dda37fbc887a1fa06b2..19b1aa93bd373ebd5f502d0dca56c9b31ab4fd7f 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -22,6 +22,8 @@ limitations under the License. #include <initializer_list> #include <string> +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -61,6 +63,9 @@ class ShapeIndex { void push_back(int64 value) { indices_.push_back(value); } void pop_back() { indices_.pop_back(); } + // push_front is O(n), but shapes don't usually have a ton of dimensions.
+ void push_front(int64 value) { indices_.insert(indices_.begin(), value); } + std::vector<int64>::const_iterator begin() const { return indices_.begin(); } std::vector<int64>::const_iterator end() const { return indices_.end(); } std::vector<int64>::iterator begin() { return indices_.begin(); } @@ -133,6 +138,7 @@ class ShapeIndexView { }; std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index); +std::ostream& operator<<(std::ostream& out, const ShapeIndexView& shape_index); // Namespaced collection of (static) shape utilities. // @@ -141,7 +147,10 @@ std::ostream& operator<<(std::ostream& out, const ShapeIndex& shape_index); class ShapeUtil { public: // Returns the number of elements are contained within the provided shape; - // e.g. for rank 0 (scalars) the result is always 1. + // e.g. for rank 0 (scalars) the result is always 1. Note that sparse shapes + // may not actually be able to store this number of elements. See + // LayoutUtil::MaxSparseElements(shape) to obtain the maximum number of + // elements that can be stored in a sparse shape. // Precondition: !IsTuple(shape) static int64 ElementsIn(const Shape& shape); @@ -162,6 +171,27 @@ class ShapeUtil { // Precondition: !ShapeUtil::IsOpaque(shape) && !ShapeUtil::IsTuple(shape) static int64 ByteSizeOfPrimitiveType(PrimitiveType primitive_type); + // Returns the number of bytes required to store the tuple member pointers for + // an allocation of shape. The `shape` must be a TUPLE shape, and + // `pointer_size` must be larger than zero. + static int64 ByteSizeOfTupleIndexTable(const Shape& shape, + int64 pointer_size); + + // Returns the number of bytes required for the elements in an allocation of + // `shape`, which must be an array shape. The return value does not include + // the bytes needed to store sparse indices. Dense shapes use a separate + // memory location for each element, and so for these shapes, + // `ByteSizeOf(shape) == ByteSizeOfElements(shape)`.
For dense shapes, this + // size also includes padding if present in the layout. For sparse shapes, + // `ByteSizeOf(shape) == ByteSizeOfElements(shape) + + // ByteSizeOfSparseIndices(shape)`. + static int64 ByteSizeOfElements(const Shape& shape); + + // Returns the number of bytes required for the sparse indices in an + // allocation of shape. The shape must be a sparse array shape. The return + // value does not include the bytes needed to store the element values. + static int64 ByteSizeOfSparseIndices(const Shape& shape); + // Returns a human-readable string that represents the given shape, with or // without layout. e.g. "f32[42x12] {0, 1}" or "f32[64]". static string HumanString(const Shape& shape); @@ -170,7 +200,7 @@ class ShapeUtil { // As above, but for program shapes, returns a string for the form: // // (param_name: f32[42x12], ...) -> f32[24x42] - static string HumanString(const ProgramShape& shape); + static string HumanString(const ProgramShape& program_shape); // Parses a ShapeUtil::HumanString-format shape string back into a shape // object. @@ -185,6 +215,31 @@ class ShapeUtil { return lhs.element_type() == rhs.element_type(); } + // As SameElementType, but allows floating point types to have different + // precisions. + static bool SameElementTypeIgnoringFpPrecision(const Shape& a, + const Shape& b) { + if (ElementIsFloating(a) && ElementIsFloating(b)) { + return true; + } + return ShapeUtil::SameElementType(a, b); + } + + // Returns the higher-precision element type if a and b are both floating + // point types; otherwise, checks that they have the same element type + // and returns it. + static PrimitiveType HigherPrecisionElementType(const Shape& a, + const Shape& b) { + if (SameElementType(a, b)) { + return a.element_type(); + } + CHECK(SameElementTypeIgnoringFpPrecision(a, b)); + return primitive_util::BitWidth(a.element_type()) < + primitive_util::BitWidth(b.element_type()) + ?
b.element_type() + : a.element_type(); + } + // Returns true if the rank, dimension sizes, and element type are // identical. Layout is ignored. Tuple elements are compared recursively for // compatibility. @@ -195,6 +250,10 @@ class ShapeUtil { // compatibility. static bool CompatibleIgnoringElementType(const Shape& lhs, const Shape& rhs); + // As Compatible, but allow one of lhs and rhs to be BF16 while the other + // being F32. Tuple elements are compared recursively for compatibility. + static bool CompatibleIgnoringFpPrecision(const Shape& lhs, const Shape& rhs); + // Returns whether the lhs and rhs shapes are identical protobufs. static bool Equal(const Shape& lhs, const Shape& rhs); @@ -267,14 +326,22 @@ class ShapeUtil { PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, tensorflow::gtl::ArraySlice minor_to_major); - // Constructs a new shape with major-first layout. - static Shape MakeShapeWithMonotonicDim0MajorLayout( + static Shape MakeShapeWithSparseLayout( + PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, + int64 max_sparse_elements); + + // Constructs a new shape with major-first layout (i.e. {n, n-1, ..., 0}). + static Shape MakeShapeWithDescendingLayout( PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions); - // Returns a new shape with major-first layout that has the same layout of - // elements with a different shape. - static Shape NormalizeShapeToMonotonicDim0MajorLayout(const Shape& shape); + // Returns a new Shape based on the given Shape with low-dimension-major + // layout (i.e. {n, n-1, ..., 0}, like Fortran), and with the dimensions + // rearranged so that it has the same in-memory layout as the given shape. + // + // For example, transforms f32[B,H,W,C]{0,3,2,1} to f32[H,W,C,B]{3,2,1,0}. + static Shape MakeShapeWithDescendingLayoutAndSamePhysicalLayout( + const Shape& shape); // As MakeShape, but the object to write to is passed in. 
static void PopulateShape(PrimitiveType element_type, @@ -324,7 +391,8 @@ class ShapeUtil { return shape.element_type() == OPAQUE; } - // Returns whether the shape is an array. + // Returns whether the shape is an array. Note that scalars are considered + // arrays. static bool IsArray(const Shape& shape) { return !IsTuple(shape) && !IsOpaque(shape); } @@ -351,6 +419,10 @@ class ShapeUtil { // shape. E.g. a tuple like (f32, s32, u32) would slice via 1,3 to (s32, u32). static Shape SliceTuple(const Shape& tuple, int64 start, int64 limit); + // Returns the shape of the real/imaginary components of the given complex + // shape. + static Shape ComplexComponentShape(const Shape& complex_shape); + // Shorthand for testing whether a shape is of a given element type and // sequence of dimensions. // @@ -502,8 +574,7 @@ class ShapeUtil { CHECK_EQ(Rank(shape), base.size()); CHECK_EQ(incr.size(), base.size()); CHECK_EQ(count.size(), base.size()); - const Layout& layout = shape.layout(); - const int64 rank = layout.minor_to_major_size(); + const int64 rank = LayoutUtil::MinorToMajor(shape).size(); // Allows handling R0 arrays, such that the visitor function will be called // once with the proper empty indexes. int64 n = -1; @@ -511,7 +582,7 @@ class ShapeUtil { while (n < rank && visitor_function(indexes)) { // Increments dimensions in minor to major order. 
for (n = 0; n < rank; ++n) { - int64 dim = layout.minor_to_major(n); + int64 dim = LayoutUtil::Minor(shape.layout(), n); indexes[dim] += incr[dim]; if (indexes[dim] < base[dim] + count[dim]) { break; @@ -529,6 +600,8 @@ class ShapeUtil { TF_DISALLOW_COPY_AND_ASSIGN(ShapeUtil); }; +std::ostream& operator<<(std::ostream& out, const Shape& shape); + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SHAPE_UTIL_H_ diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 4bce7ca51d0534cbcad6faac12818c5f3e94b29e..4db97d45b20b86dc60531845c6e28a223203ff7f 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" #include "tensorflow/compiler/xla/types.h" @@ -71,7 +72,8 @@ TEST(ShapeUtilTest, Rank4DimensionIndexing) { TEST(ShapeUtilTest, ParseShapeStringR2F32) { string shape_string = "f32[123,456]"; - Shape actual = ShapeUtil::ParseShapeString(shape_string).ValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN(Shape actual, + ShapeUtil::ParseShapeString(shape_string)); Shape expected = ShapeUtil::MakeShape(F32, {123, 456}); ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) << "expected: " << ShapeUtil::HumanString(expected) @@ -80,7 +82,8 @@ TEST(ShapeUtilTest, ParseShapeStringR2F32) { TEST(ShapeUtilTest, ParseShapeStringTupleOfArrays) { string shape_string = "(f32[1572864],s8[5120,1024])"; - Shape actual = ShapeUtil::ParseShapeString(shape_string).ValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN(Shape actual, + ShapeUtil::ParseShapeString(shape_string)); Shape expected = ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {1572864}), ShapeUtil::MakeShape(S8, {5120, 1024})}); @@ -91,7 +94,8 @@ TEST(ShapeUtilTest, 
ParseShapeStringTupleOfArrays) { TEST(ShapeUtilTest, ParseShapeStringNestedTuple) { string shape_string = "(f32[1],(f32[2]), f32[3])"; - Shape actual = ShapeUtil::ParseShapeString(shape_string).ValueOrDie(); + TF_ASSERT_OK_AND_ASSIGN(Shape actual, + ShapeUtil::ParseShapeString(shape_string)); Shape expected = ShapeUtil::MakeTupleShape({ ShapeUtil::MakeShape(F32, {1}), ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {2})}), @@ -102,6 +106,47 @@ TEST(ShapeUtilTest, ParseShapeStringNestedTuple) { << "actual: " << ShapeUtil::HumanString(actual); } +TEST(ShapeUtilTest, ParseShapeStringWithLayout) { + string shape_string = "f32[123,456]{0,1}"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, + ShapeUtil::ParseShapeString(shape_string)); + Shape expected = ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1}); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST(ShapeUtilTest, ParseShapeStringWithExplicitDenseLayout) { + string shape_string = "f32[123,456]dense{0,1}"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, + ShapeUtil::ParseShapeString(shape_string)); + Shape expected = ShapeUtil::MakeShapeWithLayout(F32, {123, 456}, {0, 1}); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST(ShapeUtilTest, ParseShapeStringWithSparseLayout) { + string shape_string = "f32[123,456]sparse{10}"; + TF_ASSERT_OK_AND_ASSIGN(Shape actual, + ShapeUtil::ParseShapeString(shape_string)); + Shape expected = ShapeUtil::MakeShapeWithSparseLayout(F32, {123, 456}, 10); + ASSERT_TRUE(ShapeUtil::Equal(expected, actual)) + << "expected: " << ShapeUtil::HumanString(expected) + << "actual: " << ShapeUtil::HumanString(actual); +} + +TEST(ShapeUtilTest, ParseInvalidShapeString) { + string shape_strings[] = { + "f32[123,456]foobar{0,1}", "f32[123,456]sparse{0,1}", "f32[123,456]{foo}", + 
"f32[123,456]dense{foo}", "f32[123,456]sparse{foo}", + }; + for (const string& shape_string : shape_strings) { + StatusOr result = ShapeUtil::ParseShapeString(shape_string); + ASSERT_FALSE(result.ok()) << "shape: " << shape_string; + } +} + TEST(ShapeUtilTest, CompatibleIdenticalShapes) { Shape shape1 = ShapeUtil::MakeShape(F32, {3, 2}); Shape shape2 = ShapeUtil::MakeShape(F32, {3, 2}); @@ -125,6 +170,18 @@ TEST(ShapeUtilTest, CompatibleNotIdenticalShapes) { EXPECT_TRUE(ShapeUtil::Compatible(shape_1, shape_2)); } +TEST(ShapeUtilTest, CompatibleIgnoringFpPrecision) { + Shape shape1 = ShapeUtil::MakeShape(BF16, {3, 2}); + Shape shape2 = ShapeUtil::MakeShape(F32, {3, 2}); + ASSERT_TRUE(ShapeUtil::CompatibleIgnoringFpPrecision(shape1, shape2)); +} + +TEST(ShapeUtilTest, IncompatibleIgnoringFpPrecision) { + Shape shape1 = ShapeUtil::MakeShape(BF16, {3, 2}); + Shape shape2 = ShapeUtil::MakeShape(F32, {2, 2}); + ASSERT_FALSE(ShapeUtil::CompatibleIgnoringFpPrecision(shape1, shape2)); +} + TEST(ShapeUtilTest, IncompatibleDifferentElementShapes) { Shape shape_1 = ShapeUtil::MakeShape(F32, {3, 2}); Shape shape_2 = ShapeUtil::MakeShape(PRED, {3, 2}); @@ -139,6 +196,14 @@ TEST(ShapeUtilTest, CompatibleTuples) { EXPECT_TRUE(ShapeUtil::Compatible(tuple1, tuple2)); } +TEST(ShapeUtilTest, CompatibleTuplesIgnoringFpPrecision) { + Shape tuple1 = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(BF16, {3, 2}), ShapeUtil::MakeShape(F32, {4, 5})}); + Shape tuple2 = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F64, {3, 2}), ShapeUtil::MakeShape(BF16, {4, 5})}); + EXPECT_TRUE(ShapeUtil::CompatibleIgnoringFpPrecision(tuple1, tuple2)); +} + TEST(ShapeUtilTest, IncompatibleTuplesWithSwappedElements) { Shape tuple1 = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(F32, {3, 2})}); @@ -148,6 +213,14 @@ TEST(ShapeUtilTest, IncompatibleTuplesWithSwappedElements) { EXPECT_FALSE(ShapeUtil::CompatibleIgnoringElementType(tuple1, tuple2)); } 
+TEST(ShapeUtilTest, IncompatibleTuplesIgnoringFpPrecision) { + Shape tuple1 = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(BF16, {4, 5}), ShapeUtil::MakeShape(F32, {3, 2})}); + Shape tuple2 = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {3, 2}), ShapeUtil::MakeShape(BF16, {4, 5})}); + EXPECT_FALSE(ShapeUtil::CompatibleIgnoringFpPrecision(tuple1, tuple2)); +} + TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentPrimitiveType) { Shape tuple1 = ShapeUtil::MakeTupleShape( {ShapeUtil::MakeShape(PRED, {4, 5}), ShapeUtil::MakeShape(F32, {3, 2})}); @@ -165,20 +238,6 @@ TEST(ShapeUtilTest, IncompatibleTuplesWithDifferentDimensions) { EXPECT_FALSE(ShapeUtil::Compatible(tuple1, tuple2)); } -TEST(ShapeUtilTest, EmptyLayoutEqualsMissingLayout) { - // A shape with a missing layout should be equal to a shape with an empty - // layout. - Shape scalar1 = ShapeUtil::MakeShape(F32, {}); - Shape scalar2 = ShapeUtil::MakeShape(F32, {}); - - EXPECT_TRUE(ShapeUtil::Equal(scalar1, scalar2)); - - scalar1.clear_layout(); // Remove layout field. - scalar2.mutable_layout(); // Create empty layout field. 
- - EXPECT_TRUE(ShapeUtil::Equal(scalar1, scalar2)); -} - TEST(ShapeUtilTest, CompareShapesWithPaddedDimensionsMismatch) { Shape shape1 = ShapeUtil::MakeShape(F32, {20, 30}); shape1.mutable_layout()->add_padded_dimensions(10); @@ -199,17 +258,17 @@ TEST(ShapeUtilTest, CompareShapesWithPaddingValueMismatch) { EXPECT_FALSE(ShapeUtil::Equal(shape1, shape2)); } -TEST(ShapeUtilTest, ScalarUnpopulatedLayoutEqualsScalarLayout) { - Shape scalar_unpopulated = ShapeUtil::MakeShape(F32, {}); - scalar_unpopulated.clear_layout(); - ASSERT_FALSE(scalar_unpopulated.has_layout()) - << ShapeUtil::HumanStringWithLayout(scalar_unpopulated); +TEST(ShapeUtilTest, ScalarDefaultLayoutEqualsScalarEmptyMin2Maj) { + Shape scalar_default_layout = ShapeUtil::MakeShape(F32, {}); + ASSERT_TRUE(scalar_default_layout.has_layout()) + << ShapeUtil::HumanStringWithLayout(scalar_default_layout); - const Shape scalar_populated = ShapeUtil::MakeShapeWithLayout(F32, {}, {}); - ASSERT_TRUE(scalar_populated.has_layout()) - << ShapeUtil::HumanStringWithLayout(scalar_populated); + const Shape scalar_empty_min2maj = + ShapeUtil::MakeShapeWithLayout(F32, {}, {}); + ASSERT_TRUE(scalar_empty_min2maj.has_layout()) + << ShapeUtil::HumanStringWithLayout(scalar_empty_min2maj); - EXPECT_TRUE(ShapeUtil::Equal(scalar_unpopulated, scalar_populated)); + EXPECT_TRUE(ShapeUtil::Equal(scalar_default_layout, scalar_empty_min2maj)); } TEST(ShapeUtilTest, ByteSizeOfWithoutPadding) { diff --git a/tensorflow/compiler/xla/sparse_index_array.cc b/tensorflow/compiler/xla/sparse_index_array.cc new file mode 100644 index 0000000000000000000000000000000000000000..31844abd89a020c87c403353374a80fb639a3244 --- /dev/null +++ b/tensorflow/compiler/xla/sparse_index_array.cc @@ -0,0 +1,110 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/sparse_index_array.h" + +#include "tensorflow/compiler/xla/index_util.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/shape_util.h" + +namespace xla { + +SparseIndexArray::SparseIndexArray() : rank_(0), max_indices_(0) {} + +SparseIndexArray::SparseIndexArray(int64 max_indices, int64 rank, + std::vector indices) + : indices_(std::move(indices)), rank_(rank), max_indices_(max_indices) { + CHECK_GT(rank_, 0); + CHECK_EQ(indices_.size() % rank_, 0) + << "indices_.size(): " << indices_.size() << ", rank_: " << rank_; + CHECK_LT(index_count(), max_indices_); +} + +SparseIndexArray::SparseIndexArray(int64 max_indices, int64 rank, + tensorflow::gtl::ArraySlice indices) + : SparseIndexArray(max_indices, rank, + std::vector(indices.begin(), indices.end())) {} + +SparseIndexArray::SparseIndexArray(int64 max_indices, + const Array2D& indices) + : SparseIndexArray(max_indices, indices.n2(), + std::vector(indices.begin(), indices.end())) {} + +int64 SparseIndexArray::index_count() const { + CHECK_GT(rank_, 0); + CHECK_EQ(indices_.size() % rank_, 0); + return indices_.size() / rank_; +} + +tensorflow::gtl::ArraySlice SparseIndexArray::At( + int64 sparse_element_number) const { + CHECK_GT(rank_, 0); + CHECK_GE(sparse_element_number, 0); + CHECK_LE(rank_ * sparse_element_number + rank_, indices_.size()); + return tensorflow::gtl::ArraySlice( + indices_.data() + rank_ * sparse_element_number, rank_); +} + 
+tensorflow::gtl::MutableArraySlice SparseIndexArray::At( + int64 sparse_element_number) { + CHECK_GT(rank_, 0); + CHECK_GE(sparse_element_number, 0); + CHECK_LE(rank_ * sparse_element_number + rank_, indices_.size()); + return tensorflow::gtl::MutableArraySlice( + indices_.data() + rank_ * sparse_element_number, rank_); +} + +void SparseIndexArray::Append(tensorflow::gtl::ArraySlice index) { + CHECK_GT(rank_, 0); + CHECK_EQ(index.size(), rank_); + indices_.insert(indices_.end(), index.begin(), index.end()); +} + +void SparseIndexArray::Clear() { indices_.clear(); } + +void SparseIndexArray::Resize(int64 num_indices) { + CHECK_GT(rank_, 0); + indices_.resize(rank_ * num_indices); +} + +bool SparseIndexArray::Validate(const Shape& shape) const { + if (rank_ == 0 || rank_ != ShapeUtil::Rank(shape)) { + return false; + } + int64 num_indices = index_count(); + if (num_indices > LayoutUtil::MaxSparseElements(shape.layout())) { + return false; + } + if (num_indices < 2) { + return true; + } + tensorflow::gtl::ArraySlice last = At(0); + if (!IndexUtil::IndexInBounds(shape, last)) { + return false; + } + for (int64 n = 1; n < num_indices; ++n) { + tensorflow::gtl::ArraySlice next = At(n); + if (!IndexUtil::IndexInBounds(shape, next)) { + return false; + } + if (IndexUtil::CompareIndices(last, next) >= 0) { + return false; + } + last = next; + } + return true; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/sparse_index_array.h b/tensorflow/compiler/xla/sparse_index_array.h new file mode 100644 index 0000000000000000000000000000000000000000..f2ce22d6721ff8da46f741ccedc2a63dea5994c8 --- /dev/null +++ b/tensorflow/compiler/xla/sparse_index_array.h @@ -0,0 +1,176 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Utility class for managing sparse array indices. + +#ifndef TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_ +#define TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_ + +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/index_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" + +namespace xla { + +// Encapsulates the array of indices for a sparse array. A SparseIndexArray +// contain indices for up to `max_indices` elements of a sparse array. Each +// sparse index is an array of `rank` int64 value that gives the location of a +// value within a sparse array. Note that the dimensions of the array are not +// checked (except for the rank). To avoid confusion, we refer to the position +// of an index within a SparseIndexArray as a sparse index number. +class SparseIndexArray { + public: + SparseIndexArray(); + SparseIndexArray(const SparseIndexArray&) = default; + SparseIndexArray(SparseIndexArray&&) = default; + SparseIndexArray& operator=(const SparseIndexArray&) = default; + SparseIndexArray& operator=(SparseIndexArray&&) = default; + + // Constructs a SparseIndexArray that can hold up to `max_indices` sparse + // indices, with an initial contents obtained from the given array. The rank + // is taken from the minor dimension of the array. The major dimension of the + // array must not exceed `max_indices`. 
+ SparseIndexArray(int64 max_indices, const Array2D& indices); + + // Like above, but the array is flattened. For example, the following are + // equivalent: + // + // SparseIndexArray(10, 3, + // Array2D{ + // {0, 1, 2}, + // {3, 4, 5}, + // {6, 7, 8}, + // {9, 10, 11}, + // }) + // + // SparseIndexArray(10, 3, + // {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}) + // + SparseIndexArray(int64 max_indices, int64 rank, + std::vector indices = {}); + SparseIndexArray(int64 max_indices, int64 rank, + tensorflow::gtl::ArraySlice indices); + + // Returns the number of elements represented by the indices stored in the + // array. + int64 index_count() const; + + // Returns a slice that refers to the given sparse index number. The argument + // must be in the range [0, element_count()). + tensorflow::gtl::ArraySlice At(int64 sparse_element_number) const; + tensorflow::gtl::MutableArraySlice At(int64 sparse_element_number); + + // Adds the given index at the end of the array. The new size of the + // SparseIndexArray must not exceed `max_indices`. + void Append(tensorflow::gtl::ArraySlice index); + + // Removes all indices from the array. + void Clear(); + + // Resizes the array to contain the given number of sparse indices. The new + // size must be smaller than `max_indices`. If the new size is larger than + // the old size, the value of the new indices is not specified. + void Resize(int64 num_indices); + + // Returns true iff all indices are unique and occur in sorted order, and are + // valid for the given shape. + bool Validate(const Shape& shape) const; + + int64 rank() const { return rank_; } + int64 max_indices() const { return max_indices_; } + + // Returns a pointer to the int64 array that holds the sparse indices. + tensorflow::gtl::MutableArraySlice mutable_data() { return &indices_; } + tensorflow::gtl::ArraySlice data() const { return indices_; } + + // Sorts this sparse index array along with the set of corresponding values. 
+ // The indices and values are sorted in the lexicographic order of the + // indices, from smallest to largest. + // + // For example: + // + // std::vector v{10.0, 11.0, 12.0}; + // SparseIndexArray a(10, 3, + // {{3, 4, 5}, + // {1, 2, 3}, + // {2, 3, 4}}); + // a.SortWithValues(&v); + // // Prints "11.0, 12.0, 10.0": + // std::cout << v[0] << ", " << v[1] << ", " << v[2] << std::endl; + // + template + void SortWithValues(tensorflow::gtl::MutableArraySlice values); + + private: + std::vector indices_; + int64 rank_; + int64 max_indices_; +}; + +template +void SparseIndexArray::SortWithValues( + tensorflow::gtl::MutableArraySlice values) { + int64 num_elements = index_count(); + CHECK_EQ(values.size(), num_elements); + std::vector sort_order; + sort_order.reserve(num_elements); + for (int64 i = 0; i < num_elements; ++i) { + sort_order.push_back(i); + } + auto sort_order_less = [this](int64 lhs, int64 rhs) { + return IndexUtil::CompareIndices(At(lhs), At(rhs)) < 0; + }; + std::sort(sort_order.begin(), sort_order.end(), sort_order_less); + + // Reorder the array elements according to sort_order. Work through the array + // and follow cycles so we can do the reorder in-place. + tensorflow::gtl::InlinedVector saved_index(rank()); + for (int64 i = 0; i < num_elements; ++i) { + // sort_order[i] == -1 indicates the element has already been copied. + if (sort_order[i] < 0) { + continue; + } else if (i == sort_order[i]) { + // The element is already in sorted order. 
+ sort_order[i] = -1; + continue; + } + + std::copy_n(At(i).begin(), rank(), saved_index.begin()); + NativeT saved_value = values[i]; + int64 j = i; + for (;;) { + if (sort_order[j] == i) { + std::copy_n(saved_index.begin(), rank(), At(j).begin()); + values[j] = saved_value; + sort_order[j] = -1; + break; + } + + std::copy_n(At(sort_order[j]).begin(), rank(), At(j).begin()); + values[j] = values[sort_order[j]]; + + int64 k = sort_order[j]; + sort_order[j] = -1; + j = k; + } + } +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SPARSE_INDEX_ARRAY_H_ diff --git a/tensorflow/compiler/xla/sparse_index_array_test.cc b/tensorflow/compiler/xla/sparse_index_array_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..7377f88958dcb7daf3d3f4f0e07966fdc9294580 --- /dev/null +++ b/tensorflow/compiler/xla/sparse_index_array_test.cc @@ -0,0 +1,43 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/sparse_index_array.h" + +#include + +#include "tensorflow/compiler/xla/test.h" + +namespace xla { +namespace { + +TEST(SparseIndexArrayTest, Sort) { + SparseIndexArray a(10, 3); + a.Append({2, 3, 4}); + a.Append({3, 4, 5}); + a.Append({1, 2, 3}); + a.Append({5, 6, 7}); + a.Append({4, 5, 6}); + a.Append({6, 7, 8}); + std::vector values = { + 12.0, 13.0, 11.0, 15.0, 14.0, 16.0, + }; + a.SortWithValues(&values); + ASSERT_EQ(a.data(), std::vector({1, 2, 3, 2, 3, 4, 3, 4, 5, 4, 5, 6, 5, + 6, 7, 6, 7, 8})); + ASSERT_EQ(values, std::vector({11.0, 12.0, 13.0, 14.0, 15.0, 16.0})); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/status_macros.h b/tensorflow/compiler/xla/status_macros.h index 5e5550563d02de99ddefbeb8ee8e1bf98afdcdbf..e51dd64e2a3dc7c359918cb08c6c94b2b4d9e91b 100644 --- a/tensorflow/compiler/xla/status_macros.h +++ b/tensorflow/compiler/xla/status_macros.h @@ -196,18 +196,8 @@ class StatusAdaptorForMacros { #define TF_STATUS_MACROS_CONCAT_NAME(x, y) TF_STATUS_MACROS_CONCAT_IMPL(x, y) #define TF_STATUS_MACROS_CONCAT_IMPL(x, y) x##y -#define TF_ASSIGN_OR_RETURN(...) \ - TF_STATUS_MACRO_GET_VARIADIC_IMPL(__VA_ARGS__, TF_ASSIGN_OR_RETURN_IMPL_3, \ - TF_ASSIGN_OR_RETURN_IMPL_2) \ - (__VA_ARGS__) - -#define TF_STATUS_MACRO_GET_VARIADIC_IMPL(_1, _2, _3, NAME, ...) 
NAME - -#define TF_ASSIGN_OR_RETURN_IMPL_2(lhs, rexpr) \ - TF_ASSIGN_OR_RETURN_IMPL_3(lhs, rexpr) - -#define TF_ASSIGN_OR_RETURN_IMPL_3(lhs, rexpr) \ - TF_ASSIGN_OR_RETURN_IMPL( \ +#define TF_ASSIGN_OR_RETURN(lhs, rexpr) \ + TF_ASSIGN_OR_RETURN_IMPL( \ TF_STATUS_MACROS_CONCAT_NAME(_status_or_value, __COUNTER__), lhs, rexpr) #define TF_ASSIGN_OR_RETURN_IMPL(statusor, lhs, rexpr) \ diff --git a/tensorflow/compiler/xla/statusor_internals.h b/tensorflow/compiler/xla/statusor_internals.h index a2fda5bb3c6f11c20fc45c57885b1ce7523db81d..14636bd144bc0a155fc96c5a350c658fd2dadfe6 100644 --- a/tensorflow/compiler/xla/statusor_internals.h +++ b/tensorflow/compiler/xla/statusor_internals.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_STATUSOR_INTERNALS_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_STATUSOR_INTERNALS_H_ +#ifndef TENSORFLOW_COMPILER_XLA_STATUSOR_INTERNALS_H_ +#define TENSORFLOW_COMPILER_XLA_STATUSOR_INTERNALS_H_ #include "tensorflow/compiler/xla/status.h" #include "tensorflow/core/platform/macros.h" @@ -242,4 +242,4 @@ struct TraitsBase { } // namespace internal_statusor } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_STATUSOR_INTERNALS_H_ +#endif // TENSORFLOW_COMPILER_XLA_STATUSOR_INTERNALS_H_ diff --git a/tensorflow/compiler/xla/statusor_test.cc b/tensorflow/compiler/xla/statusor_test.cc index 5fa2211ac66177514ac8ecabfa8791e7c8c014a2..f9d25945bc617507735fb6c4d011c39723497f69 100644 --- a/tensorflow/compiler/xla/statusor_test.cc +++ b/tensorflow/compiler/xla/statusor_test.cc @@ -32,26 +32,26 @@ namespace { class Base1 { public: virtual ~Base1() {} - int pad; + int pad_; }; class Base2 { public: virtual ~Base2() {} - int yetotherpad; + int yetotherpad_; }; class Derived : public Base1, public Base2 { public: ~Derived() override {} - int 
evenmorepad; + int evenmorepad_; }; class CopyNoAssign { public: - explicit CopyNoAssign(int value) : foo(value) {} - CopyNoAssign(const CopyNoAssign& other) : foo(other.foo) {} - int foo; + explicit CopyNoAssign(int value) : foo_(value) {} + CopyNoAssign(const CopyNoAssign& other) : foo_(other.foo_) {} + int foo_; private: const CopyNoAssign& operator=(const CopyNoAssign&); @@ -253,7 +253,7 @@ TEST(StatusOr, TestCopyCtorNonAssignable) { StatusOr original(value); StatusOr copy(original); EXPECT_EQ(copy.status(), original.status()); - EXPECT_EQ(original.ValueOrDie().foo, copy.ValueOrDie().foo); + EXPECT_EQ(original.ValueOrDie().foo_, copy.ValueOrDie().foo_); } TEST(StatusOr, TestCopyCtorStatusOKConverting) { diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index addce9019b340f9489a25dbdd2437f4d71740b95..5ff774075259e819718bcb91af4092129a6df582 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -69,6 +69,7 @@ cc_library( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_dataflow_analysis", "//tensorflow/compiler/xla/service:hlo_verifier", "//tensorflow/compiler/xla/service:transfer_manager", "//tensorflow/core:lib", @@ -104,7 +105,9 @@ cc_library( hdrs = ["hlo_test_base.h"], deps = [ ":literal_test_util", + ":test_utils", "//tensorflow/compiler/xla:shape_layout", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", @@ -114,6 +117,10 @@ cc_library( "//tensorflow/compiler/xla/service:computation_layout", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_runner", + "//tensorflow/compiler/xla/service:hlo_verifier", + "//tensorflow/compiler/xla/service:interpreter_plugin", # reference backend + "//tensorflow/compiler/xla/service:platform_util", + 
"//tensorflow/compiler/xla/tools/parser:hlo_parser", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", @@ -338,6 +345,24 @@ xla_test( ], ) +xla_test( + name = "xla_hlo_profile_test", + srcs = ["xla_hlo_profile_test.cc"], + deps = [ + "//tensorflow/compiler/xla:array2d", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:test_utils", + "//tensorflow/core:lib", + "//tensorflow/core:regexp_internal", + "//tensorflow/core:test", + ], +) + xla_test( name = "axpy_simple_test", srcs = ["axpy_simple_test.cc"], @@ -354,6 +379,7 @@ xla_test( xla_test( name = "map_test", srcs = ["map_test.cc"], + tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal_util", @@ -382,6 +408,7 @@ xla_test( name = "params_test", srcs = ["params_test.cc"], shard_count = 30, + tags = ["optonly"], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal_util", @@ -430,6 +457,22 @@ xla_test( ], ) +xla_test( + name = "conditional_test", + srcs = ["conditional_test.cc"], + deps = [ + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + xla_test( name = "unary_op_test", srcs = ["unary_op_test.cc"], @@ -532,9 +575,30 @@ xla_test( ], ) +xla_test( + name = "exhaustive_f32_elementwise_op_test", + srcs = 
["exhaustive_f32_elementwise_op_test.cc"], + backends = [ + "cpu", + "gpu", + ], + shard_count = 48, + tags = [ + "enormous", + "manual", + ], + deps = [ + ":client_library_test_base", + ":literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + ], +) + xla_test( name = "reduce_precision_test", srcs = ["reduce_precision_test.cc"], + tags = ["enable_for_xla_interpreter"], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal_util", @@ -557,6 +621,9 @@ xla_test( xla_test( name = "dot_operation_test", srcs = ["dot_operation_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", @@ -579,6 +646,9 @@ xla_test( xla_test( name = "dot_operation_runtime_test", srcs = ["dot_operation_test.cc"], + tags = [ + "enable_for_xla_interpreter", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", @@ -773,11 +843,6 @@ xla_test( xla_test( name = "bfloat16_test", srcs = ["bfloat16_test.cc"], - blacklisted_backends = [ - "cpu", - "cpu_parallel", - "gpu", - ], shard_count = 40, deps = [ ":test_utils", @@ -807,6 +872,31 @@ xla_test( ], ) +xla_test( + name = "half_test", + srcs = ["half_test.cc"], + backends = [ + # TODO(b/72509305): Flaky (fails with SEGV) as of 2018-01-25 + # "cpu", + "gpu", + ], + deps = [ + ":test_utils", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/client:computation", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/tests:client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + xla_test( 
name = "slice_test", srcs = ["slice_test.cc"], @@ -961,7 +1051,10 @@ xla_test( name = "reduce_window_test", timeout = "long", srcs = [], - tags = ["optonly"], + tags = [ + "enable_for_xla_interpreter", + "optonly", + ], xla_test_library_deps = [":reduce_window_test_library"], deps = [], ) @@ -970,6 +1063,10 @@ xla_test( name = "select_and_scatter_test", timeout = "long", srcs = ["select_and_scatter_test.cc"], + tags = [ + "enable_for_xla_interpreter", + "optonly", + ], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal_util", @@ -1008,6 +1105,19 @@ xla_test( ], ) +xla_test( + name = "reduce_hlo_test", + srcs = ["reduce_hlo_test.cc"], + deps = [ + ":client_library_test_base", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + xla_test( name = "call_test", srcs = ["call_test.cc"], @@ -1036,9 +1146,10 @@ xla_test( "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", + "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", "//tensorflow/core:test", ], @@ -1364,6 +1475,31 @@ xla_test( ], ) +xla_test( + name = "execution_profile_test", + srcs = ["execution_profile_test.cc"], + deps = [ + ":client_library_test_base", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + +xla_test( + name = "execution_profile_test_with_xla_hlo_profile", + srcs = 
["execution_profile_test.cc"], + args = ["--xla_hlo_profile"], + deps = [ + ":client_library_test_base", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/compiler/xla/client:global_data", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + xla_test( name = "replay_test", srcs = ["replay_test.cc"], @@ -1456,6 +1592,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_runner", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -1482,6 +1619,7 @@ xla_test( "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_runner", "//tensorflow/compiler/xla/service:platform_util", "//tensorflow/compiler/xla/tests:client_library_test_base", "//tensorflow/compiler/xla/tests:hlo_test_base", @@ -1676,6 +1814,45 @@ xla_test( ], ) +# A demo of textual IR based test. +xla_test( + name = "sample_text_test", + srcs = ["sample_text_test.cc"], + # You can leave this empty if you want to test all supported backends. + backends = [ + "cpu", + "gpu", + ], + deps = [ + ":hlo_test_base", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:lib", + ], +) + +# A demo of test that loads an hlo module from a file and compares results on gpu and cpu. 
+tf_cc_test( + name = "sample_file_test", + srcs = ["sample_file_test.cc"], + data = ["isolated_convolution.hlo"], + tags = ["requires-gpu-sm35"], + deps = [ + ":hlo_test_base", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/service:cpu_plugin", # reference backend + "//tensorflow/compiler/xla/service:gpu_plugin", # test backend + "//tensorflow/compiler/xla/service:platform_util", + "//tensorflow/compiler/xla/tests:literal_test_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index c6e8b24d1211743d07878d388522feacf9c0e7f1..7e9005001db34d403ea923eb9c152d114bf32803 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -1879,20 +1879,73 @@ XLA_TEST_F(ArrayElementwiseOpTest, ClampF32ScalarVector) { auto min_scalar = builder.ConstantR0(0.0f); auto min_vector = builder.ConstantR1({1.0f, -6.5f, 1.0f, 2.25f, 0.0f}); auto arg_vector = builder.ConstantR1({2.0f, 10.0f, -5.0f, 1.0f, 4.0f}); - auto arg_scalar = builder.ConstantR1({2.0f, 10.0f, -5.0f, 1.0f, 4.0f}); auto max_scalar = builder.ConstantR0(3.0f); auto max_vector = builder.ConstantR1({3.0f, 0.5f, 25.5f, 5.0f, 123.0}); // Perform clamp with broadcasted scalar and vector. 
auto clamp = builder.Add( builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), builder.Clamp(min_scalar, arg_vector, max_vector)), - builder.Add(builder.Clamp(min_vector, arg_scalar, max_vector), - builder.Clamp(min_scalar, arg_scalar, max_vector))); + builder.Add(builder.Clamp(min_vector, arg_vector, max_vector), + builder.Clamp(min_scalar, arg_vector, max_scalar))); - ComputeAndCompareR1(&builder, {8.0f, 4.5f, 2.0f, 6.5f, 15.0f}, {}, + ComputeAndCompareR1(&builder, {8.0f, 7.0f, 2.0f, 6.5f, 14.0f}, {}, error_spec_); } +XLA_TEST_F(ArrayElementwiseOpTest, ClampS32Vector) { + ComputationBuilder builder(client_, TestName()); + auto min_vector = builder.ConstantR1({1, -6, 1, 2, 0, -5}); + auto arg_vector = builder.ConstantR1({2, 10, -5, 1, 4, 10}); + auto max_vector = builder.ConstantR1({3, 0, 25, 5, 123, -1}); + auto clamp = builder.Clamp(min_vector, arg_vector, max_vector); + + ComputeAndCompareR1(&builder, {2, 0, 1, 2, 4, -1}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, ClampS32ScalarVector) { + ComputationBuilder builder(client_, TestName()); + auto min_scalar = builder.ConstantR0(0); + auto min_vector = builder.ConstantR1({1, -6, 1, 2, 0}); + auto arg_vector = builder.ConstantR1({2, 10, -5, 1, 4}); + auto max_scalar = builder.ConstantR0(3); + auto max_vector = builder.ConstantR1({3, 1, 25, 5, 123}); + // Perform clamp with broadcasted scalar and vector. 
+ auto clamp = builder.Add( + builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), + builder.Clamp(min_scalar, arg_vector, max_vector)), + builder.Add(builder.Clamp(min_vector, arg_vector, max_vector), + builder.Clamp(min_scalar, arg_vector, max_scalar))); + + ComputeAndCompareR1(&builder, {8, 8, 2, 6, 14}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, ClampU32Vector) { + ComputationBuilder builder(client_, TestName()); + auto min_vector = builder.ConstantR1({1, 2, 1, 2, 0, ~0u - 4}); + auto arg_vector = builder.ConstantR1({2, 10, 5, 1, 4, 10}); + auto max_vector = builder.ConstantR1({3, 5, 25, 5, 123, ~0u}); + auto clamp = builder.Clamp(min_vector, arg_vector, max_vector); + + ComputeAndCompareR1(&builder, {2, 5, 5, 2, 4, ~0u - 4}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, ClampU32ScalarVector) { + ComputationBuilder builder(client_, TestName()); + auto min_scalar = builder.ConstantR0(0); + auto min_vector = builder.ConstantR1({1, 0, 1, 2, 0}); + auto arg_vector = builder.ConstantR1({2, 10, 0, 1, 4}); + auto max_scalar = builder.ConstantR0(3); + auto max_vector = builder.ConstantR1({3, 1, 25, 5, 123}); + // Perform clamp with broadcasted scalar and vector. 
+ auto clamp = builder.Add( + builder.Add(builder.Clamp(min_vector, arg_vector, max_scalar), + builder.Clamp(min_scalar, arg_vector, max_vector)), + builder.Add(builder.Clamp(min_vector, arg_vector, max_vector), + builder.Clamp(min_scalar, arg_vector, max_scalar))); + + ComputeAndCompareR1(&builder, {8, 8, 2, 6, 14}, {}); +} + XLA_TEST_F(ArrayElementwiseOpTest, AddTwoParametersF32s) { ComputationBuilder builder(client_, TestName()); @@ -1971,6 +2024,18 @@ XLA_TEST_F(ArrayElementwiseOpTest, SinF32s) { error_spec_); } +XLA_TEST_F(ArrayElementwiseOpTest, Atan2F32s) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({0.0f, 5.0f, 0.0f, -3.0f, 2.0f, -8.0f}); + auto b = builder.ConstantR1({6.0f, 0.0f, -4.0f, 0.0f, 2.0f, 8.0f}); + auto atan = builder.Atan2(a, b); + + ComputeAndCompareR1( + &builder, + {0.0f, 1.57079633f, 3.14159265f, -1.57079633f, 0.78539816f, -0.78539816f}, + {}, error_spec_); +} + XLA_TEST_F(ArrayElementwiseOpTest, TanhF32s) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1({-2.5f, 3.14f, 2.25f}); @@ -1983,47 +2048,117 @@ XLA_TEST_F(ArrayElementwiseOpTest, TanhF32s) { XLA_TEST_F(ArrayElementwiseOpTest, TanhF32sVector) { // This is like the test ArrayElementwiseOpTest.TanhF32s above, except that // the input tensor is large enough to exercise the vectorized tanh - // implementation. 
- ComputationBuilder builder(client_, TestName()); - auto input_literal = Literal::CreateR2( - {{1.02, -0.32, 0.85, 0.90, 1.23, -0.91, -0.49, 0.80}, - {-0.67, 0.16, -0.07, 0.39, -0.41, 0.04, 1.36, 1.25}, - {0.41, 0.65, -1.08, 0.32, -1.45, -0.77, -1.09, 0.91}, - {-1.03, -0.30, -1.11, -1.17, 1.50, -0.85, 0.04, 1.02}, - {0.34, -0.61, 0.41, 0.07, -0.02, 1.42, -0.62, 0.81}, - {0.08, 0.81, -0.30, 1.17, -0.65, -0.44, 0.92, 1.26}, - {-1.29, 1.35, 0.08, -1.24, -0.92, 0.49, 1.17, -0.45}, - {-1.31, -1.44, -0.13, -1.31, -0.79, 1.41, 1.21, 1.05}}); - auto input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + // implementation on XLA CPU. + ComputationBuilder builder(client_, TestName()); + auto input_literal = Literal::CreateR1( + {1.02, -0.32, 0.85, 0.90, 1.23, -0.91, -0.49, 0.80, -0.67, 0.16, + -0.07, 0.39, -0.41, 0.04, 1.36, 1.25, 0.41, 0.65, -1.08, 0.32, + -1.45, -0.77, -1.09, 0.91, -1.03, -0.30, -1.11, -1.17, 1.50, -0.85, + 0.04, 1.02, 0.34, -0.61, 0.41, 0.07, -0.02, 1.42, -0.62, 0.81, + 0.08, 0.81, -0.30, 1.17, -0.65, -0.44, 0.92, 1.26, -1.29, 1.35, + 0.08, -1.24, -0.92, 0.49, 1.17, -0.45, -1.31, -1.44, -0.13, -1.31, + -0.79, 1.41, 1.21, 1.05}); + TF_ASSERT_OK_AND_ASSIGN(auto input_data, + client_->TransferToServer(*input_literal)); auto input = builder.Parameter(0, input_literal->shape(), "input"); builder.Tanh(input); - ComputeAndCompareR2( + ComputeAndCompareR1( &builder, - {{0.77009583, -0.30665702, 0.69070244, 0.71401149, 0.84400684, - -0.71985596, -0.45764771, 0.66664988}, - {-0.58278900, 0.16050975, -0.06770509, 0.36843640, -0.38476998, - 0.04018109, 0.87562293, 0.84788644}, - {0.38603750, 0.57294142, -0.79140943, 0.31032649, -0.89590985, - -0.64770776, -0.79625875, 0.72234446}, - {-0.77389336, -0.28871772, -0.80428445, -0.82541436, 0.90456349, - -0.68856895, 0.03877772, 0.76877952}, - {0.32561871, -0.54546672, 0.39072621, 0.07273290, -0.01924866, - 0.88924897, -0.55283129, 0.67183107}, - {0.08006320, 0.66944766, -0.29068485, 0.82573754, 
-0.57170743, - -0.41581789, 0.72739530, 0.85025692}, - {-0.85931867, 0.87357593, 0.07782833, -0.84597743, -0.72748238, - 0.45396307, 0.82449573, -0.42462519}, - {-0.86363792, -0.89368379, -0.12621804, -0.86445558, -0.65565848, - 0.88789743, 0.83566397, 0.78287679}}, + {0.77009583, -0.30665702, 0.69070244, 0.71401149, 0.84400684, + -0.71985596, -0.45764771, 0.66664988, -0.58278900, 0.16050975, + -0.06770509, 0.36843640, -0.38476998, 0.04018109, 0.87562293, + 0.84788644, 0.38603750, 0.57294142, -0.79140943, 0.31032649, + -0.89590985, -0.64770776, -0.79625875, 0.72234446, -0.77389336, + -0.28871772, -0.80428445, -0.82541436, 0.90456349, -0.68856895, + 0.03877772, 0.76877952, 0.32561871, -0.54546672, 0.39072621, + 0.07273290, -0.01924866, 0.88924897, -0.55283129, 0.67183107, + 0.08006320, 0.66944766, -0.29068485, 0.82573754, -0.57170743, + -0.41581789, 0.72739530, 0.85025692, -0.85931867, 0.87357593, + 0.07782833, -0.84597743, -0.72748238, 0.45396307, 0.82449573, + -0.42462519, -0.86363792, -0.89368379, -0.12621804, -0.86445558, + -0.65565848, 0.88789743, 0.83566397, 0.78287679}, {input_data.get()}, // The error spec is unusually high here to account for the fact that we // use a rational interpolant to approximate tanh. ErrorSpec(0.004, 0.004)); } +XLA_TEST_F(ArrayElementwiseOpTest, ExpF32sVector) { + // The input tensor is large enough to exercise the vectorized exp + // implementation on XLA CPU. + ComputationBuilder builder(client_, TestName()); + + // Just to help make sense of the scales here -- exp(89) saturates float32 and + // exp(-10) is smaller than our error spec. 
+ std::unique_ptr input_literal = Literal::CreateR1( + {1.02, -0.32, 0.85, 0.9, 1.23, -0.91, -0.49, 0.8, -1.31, + -1.44, -0.13, -1.31, -0.79, 1.41, 1.21, 1.05, -195.6, -194.5, + -193.4, -192.3, -191.2, -190.1, -189.0, -187.9, -19.6, -18.5, -17.4, + -16.3, -15.2, -14.1, -13.0, -11.9, -10.8, -9.7, -8.6, -7.5, + -6.4, -5.3, -4.2, -3.1, -2.0, -0.9, 0.2, 1.3, 2.4, + 3.5, 4.6, 5.7, 6.8, 7.9, 9.0, 10.1, 11.2, 12.3, + 13.4, 14.5, 15.6, 16.7, 17.8, 18.9, 20.0, 21.1, 22.2, + 23.3, 24.4, 25.5, 26.6, 27.7, 28.8, 29.9, 31.0, 32.1, + 68.4, 69.5, 70.6, 71.7, 72.8, 73.9, 75.0, 76.1, 77.2, + 78.3, 79.4, 80.5, 81.6, 82.7, 83.8, 84.9, 85.2, 86.3, + 86.4, 86.5, 87.6, 87.7, 87.8, 87.9}); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, + client_->TransferToServer(*input_literal)); + + auto input = builder.Parameter(0, input_literal->shape(), "input"); + builder.Exp(input); + + std::vector expected_result; + int64 input_size = input_literal->shape().dimensions(0); + expected_result.reserve(input_size); + for (int64 i = 0; i < input_size; i++) { + expected_result.push_back(std::exp(input_literal->Get({i}))); + } + + ComputeAndCompareR1(&builder, expected_result, {input_data.get()}, + error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, LogF32sVector) { + // The input tensor is large enough to exercise the vectorized log + // implementation on XLA CPU. 
+ ComputationBuilder builder(client_, TestName()); + + std::unique_ptr input_literal = Literal::CreateR1( + {-1.29, -1.41, -1.25, -13.5, -11.7, -17.9, -198, + -167, 1.29, 1.41, 1.25, 13.5, 11.7, 17.9, + 198, 167, 1.27e+03, 1.33e+03, 1.74e+03, 1.6e+04, 1.84e+04, + 1.74e+04, 1.89e+05, 1.9e+05, 1.93e+06, 1.98e+06, 1.65e+06, 1.97e+07, + 1.66e+07, 1e+07, 1.98e+08, 1.96e+08, 1.64e+09, 1.58e+09, 1.64e+09, + 1.44e+10, 1.5e+10, 1.99e+10, 1.17e+11, 1.08e+11, 1.08e+12, 1.38e+12, + 1.4e+12, 1.03e+13, 1.6e+13, 1.99e+13, 1.26e+14, 1.51e+14, 1.33e+15, + 1.41e+15, 1.63e+15, 1.39e+16, 1.21e+16, 1.27e+16, 1.28e+17, 1.62e+17, + 2e+18, 1.96e+18, 1.81e+18, 1.99e+19, 1.86e+19, 1.61e+19, 1.71e+20, + 1.47e+20, 1.83e+21, 1.33e+21, 1.3e+21, 1.35e+22, 1.84e+22, 1.02e+22, + 1.81e+23, 1.02e+23, 1.89e+24, 1.49e+24, 1.08e+24, 1.95e+25, 1.1e+25, + 1.62e+25, 1.2e+26, 1.41e+26, 1.93e+27, 1.66e+27, 1.62e+27, 1.05e+28, + 1.5e+28, 1.79e+28, 1.36e+29, 1.95e+29, 1.5e+30, 1.81e+30, 1.34e+30, + 1.7e+31, 1.44e+31, 1.1e+31, 1.4e+32, 1.67e+32, 1.96e+33, 1.11e+33, + 1.19e+33, 1.61e+34, 1.05e+34, 1.88e+34, 1.67e+35, 1.7e+35}); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, + client_->TransferToServer(*input_literal)); + + auto input = builder.Parameter(0, input_literal->shape(), "input"); + builder.Log(input); + + std::vector expected_result; + int64 input_size = input_literal->shape().dimensions(0); + expected_result.reserve(input_size); + for (int64 i = 0; i < input_size; i++) { + expected_result.push_back(std::log(input_literal->Get({i}))); + } + + ComputeAndCompareR1(&builder, expected_result, {input_data.get()}, + error_spec_); +} + XLA_TEST_F(ArrayElementwiseOpTest, AddChainFoldLeft) { // a ------ (add) --------- (add) // / / @@ -2520,9 +2655,8 @@ XLA_TEST_F(ArrayElementwiseOpTest, R4_16x16x2x2_Plus_R1_16) { std::iota(r1.begin(), r1.end(), 1.0); ComputationBuilder builder(client_, TestName()); - std::unique_ptr a_literal = Literal::CreateR4FromArray4D(r4); - 
*a_literal->mutable_shape()->mutable_layout() = - LayoutUtil::MakeLayout({0, 1, 2, 3}); + std::unique_ptr a_literal = Literal::CreateR4FromArray4DWithLayout( + r4, LayoutUtil::MakeLayout({0, 1, 2, 3})); auto a = builder.ConstantLiteral(*a_literal); auto b = builder.ConstantR1(r1); builder.Add(a, b, {1}); diff --git a/tensorflow/compiler/xla/tests/axpy_simple_test.cc b/tensorflow/compiler/xla/tests/axpy_simple_test.cc index 627a9c3e7d9f6eb8d360228362ea5adf12c6c798..3f6fd7c65d3360a622dbf754833009fb20410535 100644 --- a/tensorflow/compiler/xla/tests/axpy_simple_test.cc +++ b/tensorflow/compiler/xla/tests/axpy_simple_test.cc @@ -62,6 +62,10 @@ TEST_F(AxpySimpleTest, AxpyTenValues) { auto ax = builder.Mul(alpha, x); auto axpy = builder.Add(ax, y); + TF_ASSERT_OK_AND_ASSIGN(ProgramShape shape, builder.GetProgramShape()); + + EXPECT_EQ("() -> f32[10]", ShapeUtil::HumanString(shape)); + std::vector expected = { 1.85840735, -1.85840735, 2.28318531, -2.28318531, -6.42477796, 6.42477796, 10.56637061, -10.56637061, -14.70796327, 14.70796327}; diff --git a/tensorflow/compiler/xla/tests/batch_normalization_test.cc b/tensorflow/compiler/xla/tests/batch_normalization_test.cc index 028d1251b455b82a291c236f7866e52e27d3590e..28ab9654997728fbafd6610af840e721e72cce5a 100644 --- a/tensorflow/compiler/xla/tests/batch_normalization_test.cc +++ b/tensorflow/compiler/xla/tests/batch_normalization_test.cc @@ -39,6 +39,8 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/math/math_util.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -46,9 +48,13 @@ limitations under the License. 
namespace xla { namespace { -class BatchNormalizationTest : public ClientLibraryTestBase { +class BatchNormalizationTest + : public ClientLibraryTestBase, + public ::testing::WithParamInterface { protected: BatchNormalizationTest() : input_array_(kSamples, kZ, kY, kX) { + mutable_debug_options()->set_xla_gpu_use_cudnn_batchnorm(GetParam()); + Array2D pz({ // z0 z1 {-1.0f, 4.1f}, // p0 @@ -56,7 +62,7 @@ class BatchNormalizationTest : public ClientLibraryTestBase { {5.0f, 4.4f}, // p2 }); input_array_.FillWithPZ(pz); - input_literal_ = *Literal::CreateR4FromArray4D(input_array_); + input_literal_ = std::move(*Literal::CreateR4FromArray4D(input_array_)); CHECK_EQ(kSamples, input_array_.planes()); CHECK_EQ(kZ, input_array_.depth()); CHECK_EQ(kY, input_array_.height()); @@ -73,7 +79,18 @@ class BatchNormalizationTest : public ClientLibraryTestBase { const ErrorSpec error_spec_{0.001, 0.001}; }; -TEST_F(BatchNormalizationTest, SubtractInZ) { +// If testing the GPU backend, run the tests twice, with and without cudnn +// batchnorm. Otherwise, just run the tests once -- the value of this flag +// doesn't matter. 
+#ifdef XLA_TEST_BACKEND_GPU +INSTANTIATE_TEST_CASE_P(BatchNormalizationTestInstance, BatchNormalizationTest, + ::testing::Bool()); +#else +INSTANTIATE_TEST_CASE_P(BatchNormalizationTestInstance, BatchNormalizationTest, + ::testing::Values(false)); +#endif + +XLA_TEST_P(BatchNormalizationTest, SubtractInZ) { ComputationBuilder builder(client_, "subtract_in_z_one_sample"); auto x = builder.ConstantLiteral(input_literal_); auto y = builder.ConstantR1({3.14, 4.25}); @@ -89,22 +106,24 @@ TEST_F(BatchNormalizationTest, SubtractInZ) { ComputeAndCompareR4(&builder, expected, {}, error_spec_); } -TEST_F(BatchNormalizationTest, SquareTesseractElementwise) { +XLA_TEST_P(BatchNormalizationTest, SquareTesseractElementwise) { ComputationBuilder builder(client_, "square_tesseract_elementwise"); auto x = builder.ConstantLiteral(input_literal_); builder.SquareF32(x); + using tensorflow::MathUtil; + Array4D expected(kSamples, kZ, kY, kX); Array2D expected_pz({ - {std::pow(-1.0f, 2.0f), std::pow(4.1f, 2.0f)}, - {std::pow(2.0f, 2.0f), std::pow(4.1f, 2.0f)}, - {std::pow(5.0f, 2.0f), std::pow(4.4f, 2.0f)}, + {MathUtil::IPow(-1.0f, 2), MathUtil::IPow(4.1f, 2)}, + {MathUtil::IPow(2.0f, 2), MathUtil::IPow(4.1f, 2)}, + {MathUtil::IPow(5.0f, 2), MathUtil::IPow(4.4f, 2)}, }); expected.FillWithPZ(expected_pz); ComputeAndCompareR4(&builder, expected, {}, error_spec_); } -TEST_F(BatchNormalizationTest, SumToZ) { +XLA_TEST_P(BatchNormalizationTest, SumToZ) { ComputationBuilder builder(client_, "sum_to_z"); auto input_activations = builder.ConstantLiteral(input_literal_); Computation add = CreateScalarAddComputation(F32, &builder); @@ -116,7 +135,7 @@ TEST_F(BatchNormalizationTest, SumToZ) { ComputeAndCompareR1(&builder, expected, {}, error_spec_); } -TEST_F(BatchNormalizationTest, SquareAndReduce) { +XLA_TEST_P(BatchNormalizationTest, SquareAndReduce) { ComputationBuilder builder(client_, "square_and_reduce"); auto input_activations = builder.ConstantLiteral(input_literal_); auto set_means = 
builder.ConstantR1({2.f, 4.2f}); @@ -131,7 +150,7 @@ TEST_F(BatchNormalizationTest, SquareAndReduce) { ComputeAndCompareR1(&builder, expected, {}, error_spec_); } -TEST_F(BatchNormalizationTest, VarianceToStddev) { +XLA_TEST_P(BatchNormalizationTest, VarianceToStddev) { ComputationBuilder builder(client_, "variance_to_stddev"); auto variance = builder.ConstantR1({6.f, .02f}); auto sqrt = builder.SqrtF32(variance); @@ -142,7 +161,7 @@ TEST_F(BatchNormalizationTest, VarianceToStddev) { // Compare against a forward batch normalization example in the NN spec // reference. -TEST_F(BatchNormalizationTest, SpecComparisonForward) { +XLA_TEST_P(BatchNormalizationTest, SpecComparisonForward) { ComputationBuilder builder(client_, "batch_normalize_per_spec"); auto input_activations = builder.CheckShape(builder.ConstantLiteral(input_literal_), @@ -198,19 +217,227 @@ TEST_F(BatchNormalizationTest, SpecComparisonForward) { ComputeAndCompareR4(&builder, expected, {}, error_spec_); } +XLA_TEST_P(BatchNormalizationTest, BasicTraining) { + const int kFeatureIndex = 3; + ComputationBuilder builder(client_, TestName()); + + auto operand = builder.ConstantR4FromArray4D( + {{{{1.f, 2.f}}, {{3.f, 4.f}}}, {{{5.f, 6.f}}, {{7.f, 8.f}}}}); + + auto scale = builder.ConstantR1({2.0f, 3.0f}); + + auto offset = builder.ConstantR1({1.0f, 2.0f}); + + auto tuple = builder.BatchNormTraining(operand, scale, offset, + /*epsilon=*/0.001, kFeatureIndex); + + auto expected = Literal::MakeTuple( + {Literal::CreateR4({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}}, + {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}}) + .get(), + Literal::CreateR1({4, 5}).get(), + Literal::CreateR1({5, 5}).get()}); + + ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1)); +} + +XLA_TEST_P(BatchNormalizationTest, BasicTrainingOnSublane) { + const int kFeatureIndex = 2; + ComputationBuilder builder(client_, TestName()); + + auto operand = builder.ConstantR4FromArray4D( + {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}}); + 
+ auto scale = builder.ConstantR1({2.0f, 3.0f}); + + auto offset = builder.ConstantR1({1.0f, 2.0f}); + + auto tuple = builder.BatchNormTraining(operand, scale, offset, + /*epsilon=*/0.001, kFeatureIndex); + + auto expected = Literal::MakeTuple( + {Literal::CreateR4({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}}, + {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}}) + .get(), + Literal::CreateR1({4, 5}).get(), + Literal::CreateR1({5, 5}).get()}); + + ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1)); +} + +XLA_TEST_P(BatchNormalizationTest, TrainingWithFeatureOnLowDimension) { + // Use 0 dimension as feature, tests layout analyzer. + const int kFeatureIndex = 0; + ComputationBuilder builder(client_, TestName()); + + ComputationDataHandle h0; + auto operand = CreateR3Parameter(Array3D(260, 2, 2, 1.0f), + /*parameter_number=*/0, "operand", + &builder, &h0); + ComputationDataHandle h1; + auto scale = + CreateR1Parameter(std::vector(260, 1.0f), + /*parameter_number=*/1, "scale", &builder, &h1); + ComputationDataHandle h2; + auto offset = + CreateR1Parameter(std::vector(260, 1.0f), + /*parameter_number=*/2, "offset", &builder, &h2); + + auto tuple = builder.BatchNormTraining(h0, h1, h2, + /*epsilon=*/1, kFeatureIndex); + + auto expected = Literal::MakeTuple( + {Literal::CreateR3FromArray3D(Array3D(260, 2, 2, 1.0f)) + .get(), + Literal::CreateR1(std::vector(260, 1.0f)).get(), + Literal::CreateR1(std::vector(260, 0.0f)).get()}); + + ComputeAndCompareTuple(&builder, *expected, + {operand.get(), scale.get(), offset.get()}, + ErrorSpec(0.1)); +} + +XLA_TEST_P(BatchNormalizationTest, LargeEpsilonTest) { + // Test the correctness of choosing a large epsilon value. 
+ const int kFeatureIndex = 2; + ComputationBuilder builder(client_, TestName()); + + ComputationDataHandle h0; + auto operand = CreateR3Parameter({{{0.0f}, {10.0f}, {20.0f}, {30.0f}}}, + /*parameter_number=*/0, "operand", + &builder, &h0); + ComputationDataHandle h1; + auto scale = + CreateR1Parameter(std::vector(1, 1.0f), + /*parameter_number=*/1, "scale", &builder, &h1); + ComputationDataHandle h2; + auto offset = + CreateR1Parameter(std::vector(1, 0.0f), + /*parameter_number=*/2, "offset", &builder, &h2); + + // var = 125, mean = 15, epsilon = -100 + auto tuple = builder.BatchNormTraining(h0, h1, h2, + /*epsilon=*/-100, kFeatureIndex); + + auto expected = Literal::MakeTuple( + {Literal::CreateR3FromArray3D({{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}) + .get(), + Literal::CreateR1(std::vector(1, 15.0f)).get(), + Literal::CreateR1(std::vector(1, 125.0f)).get()}); + + ComputeAndCompareTuple(&builder, *expected, + {operand.get(), scale.get(), offset.get()}, + ErrorSpec(0.1)); +} + +XLA_TEST_P(BatchNormalizationTest, BatchNormGradBasic) { + const int kFeatureIndex = 2; + ComputationBuilder builder(client_, TestName()); + + auto operand = + builder.ConstantR4FromArray4D(Array4D(2, 2, 2, 1, 0.0f)); + + auto scale = builder.ConstantR1({1.0f, 1.0f}); + + auto mean = builder.ConstantR1({0.0f, 0.0f}); + + auto var = builder.ConstantR1({1.0f, 1.0f}); + + auto grad_output = builder.ConstantR4FromArray4D( + {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}}); + + builder.BatchNormGrad(operand, scale, mean, var, grad_output, + /*epsilon=*/0.0, kFeatureIndex); + + auto expected = Literal::MakeTuple( + {Literal::CreateR4({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}}, + {{{1.f}, {1.f}}, {{3.f}, {3.f}}}}) + .get(), + Literal::CreateR1({0, 0}).get(), + Literal::CreateR1({16, 20}).get()}); + + ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.1)); +} + struct BatchNormTestParam { std::vector bounds; int64 feature_index; float random_value_mean; float random_value_var; 
+ bool use_cudnn_batchnorm; + + friend ::std::ostream& operator<<(::std::ostream& os, + const BatchNormTestParam& p) { + os << "bounds={" << tensorflow::str_util::Join(p.bounds, ", ") << "}, "; + os << "feature_index=" << p.feature_index << ", "; + os << "random_value_mean=" << p.random_value_mean << ", "; + os << "random_value_var=" << p.random_value_var; + + // Don't print use_cudnn_batchnorm when it's false, because most backends + // never set it to true. + if (p.use_cudnn_batchnorm) { + os << ", use_cudnn_batchnorm=true"; + } + return os; + } }; // Tests to test the fused operation of BatchNorm. -class BatchNormTest : public ClientLibraryTestBase, - public ::testing::WithParamInterface { +class BatchNormTestManySizes + : public ClientLibraryTestBase, + public ::testing::WithParamInterface { + public: + BatchNormTestManySizes() { + mutable_debug_options()->set_xla_gpu_use_cudnn_batchnorm( + GetParam().use_cudnn_batchnorm); + } }; -XLA_TEST_P(BatchNormTest, RandomizedTests) { +std::vector BuildBatchNormTestParams() { + std::vector params; + + auto add_testcase = [&](std::vector bounds, int64 feature_index, + float random_value_mean, float random_value_var) { + BatchNormTestParam p{bounds, feature_index, random_value_mean, + random_value_var, /*use_cudnn_batchnorm=*/false}; + params.push_back(p); + + // If testing the GPU backend, also run with cudnn batchnorm enabled. 
+#ifdef XLA_TEST_BACKEND_GPU + p.use_cudnn_batchnorm = true; + params.push_back(p); +#endif + }; + + add_testcase({2, 2, 2, 2}, 0, 100.2f, 200.0f); + add_testcase({2, 2, 2, 2}, 3, 300.f, 400.0f); + + add_testcase({1, 10, 1, 1}, 0, 10.1f, 20.1f); + add_testcase({10, 10, 10, 10}, 1, 3.14f, 314.15f); + add_testcase({10, 10, 10, 10}, 2, 666.6f, 777.7f); + add_testcase({10, 10, 10, 10}, 1, -666.6f, 777.7f); + add_testcase({10, 10, 10, 10}, 2, 0.f, 777.7f); + add_testcase({1, 1, 10, 130}, 2, 0.f, 777.7f); + add_testcase({1, 1, 130, 11}, 2, 0.f, 777.7f); + add_testcase({1, 1, 10, 1}, 3, 888.8f, 9.9f); + + add_testcase({24, 129, 1, 2}, 2, 10000, 10000); + add_testcase({24, 129, 1, 2}, 3, 10000, 10000); + + // Feature on low dimension to trigger relayout, check that internal logical + // to physical dimension calculation is correct after relayout. + add_testcase({1, 2, 3, 4}, 0, 100, 100); + + // Zero-sized tensor. + add_testcase({1, 0, 100, 42}, 0, 100, 100); + + return params; +} + +INSTANTIATE_TEST_CASE_P(BatchNormTest_Instantiation, BatchNormTestManySizes, + ::testing::ValuesIn(BuildBatchNormTestParams())); + +XLA_TEST_P(BatchNormTestManySizes, RandomizedTrainingTests) { float epsilon = 0.001; ComputationBuilder builder(client_, TestName()); const std::vector& bounds = GetParam().bounds; @@ -286,9 +513,9 @@ XLA_TEST_P(BatchNormTest, RandomizedTests) { auto offset_activations = builder.Parameter(2, offset_literal->shape(), "scale"); - auto expected = *Literal::MakeTuple({expected_normalized.get(), - Literal::CreateR1(mean).get(), - Literal::CreateR1(var).get()}); + auto expected = Literal::MakeTuple({expected_normalized.get(), + Literal::CreateR1(mean).get(), + Literal::CreateR1(var).get()}); std::unique_ptr input_data = client_->TransferToServer(*input_literal).ConsumeValueOrDie(); @@ -300,13 +527,17 @@ XLA_TEST_P(BatchNormTest, RandomizedTests) { builder.BatchNormTraining(input_activations, scale_activations, offset_activations, epsilon, feature_index); + // Run all 
HLO passes during this test. In particular, ClientLibraryTestBase + // disables constant folding, but we want it enabled for our zero-sized tensor + // testcase. + execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes(); ComputeAndCompareTuple( - &builder, expected, + &builder, *expected, {input_data.get(), scale_data.get(), offset_data.get()}, ErrorSpec(0.01, 1)); } -XLA_TEST_P(BatchNormTest, RandomizedInferencingTests) { +XLA_TEST_P(BatchNormTestManySizes, RandomizedInferencingTests) { float epsilon = 0.001; ComputationBuilder builder(client_, TestName()); const std::vector& bounds = GetParam().bounds; @@ -402,6 +633,11 @@ XLA_TEST_P(BatchNormTest, RandomizedInferencingTests) { offset_activations, mean_activations, variance_activations, epsilon, feature_index); + // Run all HLO passes during this test. In particular, ClientLibraryTestBase + // disables constant folding, but we want it enabled for our zero-sized tensor + // testcase. + execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes(); + ComputeAndCompareR4( &builder, expected, {input_data.get(), scale_data.get(), offset_data.get(), mean_data.get(), @@ -409,7 +645,7 @@ XLA_TEST_P(BatchNormTest, RandomizedInferencingTests) { ErrorSpec(0.01, 1)); } -XLA_TEST_P(BatchNormTest, RandomizedGradTests) { +XLA_TEST_P(BatchNormTestManySizes, RandomizedGradTests) { float epsilon = 0.001; ComputationBuilder builder(client_, TestName()); const std::vector& bounds = GetParam().bounds; @@ -447,7 +683,11 @@ XLA_TEST_P(BatchNormTest, RandomizedGradTests) { std::vector mean(feature_bound); for (int64 i = 0; i < feature_bound; ++i) { - mean[i] = sum[i] / num_elements_per_feature; + if (num_elements_per_feature > 0) { + mean[i] = sum[i] / num_elements_per_feature; + } else { + mean[i] = 0; + } } std::vector mean_square(feature_bound); @@ -457,7 +697,11 @@ XLA_TEST_P(BatchNormTest, RandomizedGradTests) { std::vector square_mean(feature_bound); for (int64 i = 0; i < feature_bound; ++i) { - 
square_mean[i] = sum_squared[i] / num_elements_per_feature; + if (num_elements_per_feature > 0) { + square_mean[i] = sum_squared[i] / num_elements_per_feature; + } else { + square_mean[i] = 0; + } } std::vector var(feature_bound); @@ -535,8 +779,12 @@ XLA_TEST_P(BatchNormTest, RandomizedGradTests) { grad_activation, scale4D, [](float a, float b) { return a * b; }); grad_activation = *ReferenceUtil::MapArray4D( - grad_activation, rsqrt_var_add_epsilon, - [=](float a, float b) { return a * b / num_elements_per_feature; }); + grad_activation, rsqrt_var_add_epsilon, [=](float a, float b) { + if (num_elements_per_feature > 0) { + return a * b / num_elements_per_feature; + } + return 0.f; + }); auto expected_grad_activation = Literal::CreateR4FromArray4D(grad_activation); @@ -571,179 +819,20 @@ XLA_TEST_P(BatchNormTest, RandomizedGradTests) { grad_output_parameter, epsilon, feature_index); auto expected = - *Literal::MakeTuple({expected_grad_activation.get(), - Literal::CreateR1(grad_scale).get(), - Literal::CreateR1(grad_offset).get()}); + Literal::MakeTuple({expected_grad_activation.get(), + Literal::CreateR1(grad_scale).get(), + Literal::CreateR1(grad_offset).get()}); + + // Run all HLO passes during this test. In particular, ClientLibraryTestBase + // disables constant folding, but we want it enabled for our zero-sized tensor + // testcase. 
+ execution_options_.mutable_debug_options()->clear_xla_disable_hlo_passes(); - ComputeAndCompareTuple(&builder, expected, + ComputeAndCompareTuple(&builder, *expected, {input_data.get(), scale_data.get(), mean_data.get(), var_data.get(), grad_output_data.get()}, ErrorSpec(0.01, 1)); } -INSTANTIATE_TEST_CASE_P( - BatchNormTest_Instantiation, BatchNormTest, - ::testing::Values(BatchNormTestParam{{2, 2, 2, 2}, 0, 100.2f, 200.0f}, - BatchNormTestParam{{2, 2, 2, 2}, 3, 300.f, 400.0f}, - - BatchNormTestParam{{1, 10, 1, 1}, 0, 10.1f, 20.1f}, - BatchNormTestParam{{10, 10, 10, 10}, 1, 3.14f, 314.15f}, - BatchNormTestParam{{10, 10, 10, 10}, 2, 666.6f, 777.7f}, - BatchNormTestParam{{10, 10, 10, 10}, 1, -666.6f, 777.7f}, - BatchNormTestParam{{10, 10, 10, 10}, 2, 0.f, 777.7f}, - BatchNormTestParam{{1, 1, 10, 130}, 2, 0.f, 777.7f}, - BatchNormTestParam{{1, 1, 130, 11}, 2, 0.f, 777.7f}, - BatchNormTestParam{{1, 1, 10, 1}, 3, 888.8f, 9.9f}, - - BatchNormTestParam{{24, 129, 1, 2}, 2, 10000, 10000}, - BatchNormTestParam{{24, 129, 1, 2}, 3, 10000, 10000}, - - // Feature on low dimension to trigger relayout, test - // internal logical to physical dimension calculation - // is correct after relayout. 
- BatchNormTestParam{{1, 2, 3, 4}, 0, 100, 100})); - -XLA_TEST_F(BatchNormTest, BasicTraining) { - const int kFeatureIndex = 3; - ComputationBuilder builder(client_, TestName()); - - auto operand = builder.ConstantR4FromArray4D( - {{{{1.f, 2.f}}, {{3.f, 4.f}}}, {{{5.f, 6.f}}, {{7.f, 8.f}}}}); - - auto scale = builder.ConstantR1({2.0f, 3.0f}); - - auto offset = builder.ConstantR1({1.0f, 2.0f}); - - auto tuple = builder.BatchNormTraining(operand, scale, offset, - /*epsilon=*/0.001, kFeatureIndex); - - auto expected = *Literal::MakeTuple( - {Literal::CreateR4({{{{-1.6f, -2.0f}}, {{0.1f, 0.6f}}}, - {{{1.9f, 3.3f}}, {{3.7f, 6.0f}}}}) - .get(), - Literal::CreateR1({4, 5}).get(), - Literal::CreateR1({5, 5}).get()}); - - ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1)); -} - -XLA_TEST_F(BatchNormTest, BasicTrainingOnSublane) { - const int kFeatureIndex = 2; - ComputationBuilder builder(client_, TestName()); - - auto operand = builder.ConstantR4FromArray4D( - {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}}); - - auto scale = builder.ConstantR1({2.0f, 3.0f}); - - auto offset = builder.ConstantR1({1.0f, 2.0f}); - - auto tuple = builder.BatchNormTraining(operand, scale, offset, - /*epsilon=*/0.001, kFeatureIndex); - - auto expected = *Literal::MakeTuple( - {Literal::CreateR4({{{{-1.6f}, {-2.0f}}, {{0.1f}, {0.6f}}}, - {{{1.9f}, {3.3f}}, {{3.7f}, {6.0f}}}}) - .get(), - Literal::CreateR1({4, 5}).get(), - Literal::CreateR1({5, 5}).get()}); - - ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1)); -} - -XLA_TEST_F(BatchNormTest, DISABLED_ON_GPU(TrainingWithFeatureOnLowDimension)) { - // Use 0 dimension as feature, tests layout analyzer. 
- const int kFeatureIndex = 0; - ComputationBuilder builder(client_, TestName()); - - ComputationDataHandle h0; - auto operand = CreateR3Parameter(Array3D(260, 2, 2, 1.0f), - /*parameter_number=*/0, "operand", - &builder, &h0); - ComputationDataHandle h1; - auto scale = - CreateR1Parameter(std::vector(260, 1.0f), - /*parameter_number=*/1, "scale", &builder, &h1); - ComputationDataHandle h2; - auto offset = - CreateR1Parameter(std::vector(260, 1.0f), - /*parameter_number=*/2, "offset", &builder, &h2); - - auto tuple = builder.BatchNormTraining(h0, h1, h2, - /*epsilon=*/1, kFeatureIndex); - - auto expected = *Literal::MakeTuple( - {Literal::CreateR3FromArray3D(Array3D(260, 2, 2, 1.0f)) - .get(), - Literal::CreateR1(std::vector(260, 1.0f)).get(), - Literal::CreateR1(std::vector(260, 0.0f)).get()}); - - ComputeAndCompareTuple(&builder, expected, - {operand.get(), scale.get(), offset.get()}, - ErrorSpec(0.1)); -} - -XLA_TEST_F(BatchNormTest, LargeEpsilonTest) { - // Test the correctness of choosing a large epsilon value. 
- const int kFeatureIndex = 2; - ComputationBuilder builder(client_, TestName()); - - ComputationDataHandle h0; - auto operand = CreateR3Parameter({{{0.0f}, {10.0f}, {20.0f}, {30.0f}}}, - /*parameter_number=*/0, "operand", - &builder, &h0); - ComputationDataHandle h1; - auto scale = - CreateR1Parameter(std::vector(1, 1.0f), - /*parameter_number=*/1, "scale", &builder, &h1); - ComputationDataHandle h2; - auto offset = - CreateR1Parameter(std::vector(1, 0.0f), - /*parameter_number=*/2, "offset", &builder, &h2); - - // var = 125, mean = 15, epsilon = -100 - auto tuple = builder.BatchNormTraining(h0, h1, h2, - /*epsilon=*/-100, kFeatureIndex); - - auto expected = *Literal::MakeTuple( - {Literal::CreateR3FromArray3D({{{-3.0f}, {-1.0f}, {1.0f}, {3.0f}}}) - .get(), - Literal::CreateR1(std::vector(1, 15.0f)).get(), - Literal::CreateR1(std::vector(1, 125.0f)).get()}); - - ComputeAndCompareTuple(&builder, expected, - {operand.get(), scale.get(), offset.get()}, - ErrorSpec(0.1)); -} - -XLA_TEST_F(BatchNormTest, BatchNormGradBasic) { - const int kFeatureIndex = 2; - ComputationBuilder builder(client_, TestName()); - - auto operand = - builder.ConstantR4FromArray4D(Array4D(2, 2, 2, 1, 0.0f)); - - auto scale = builder.ConstantR1({1.0f, 1.0f}); - - auto mean = builder.ConstantR1({0.0f, 0.0f}); - - auto var = builder.ConstantR1({1.0f, 1.0f}); - - auto grad_output = builder.ConstantR4FromArray4D( - {{{{1.f}, {2.f}}, {{3.f}, {4.f}}}, {{{5.f}, {6.f}}, {{7.f}, {8.f}}}}); - - builder.BatchNormGrad(operand, scale, mean, var, grad_output, - /*epsilon=*/0.0, kFeatureIndex); - - auto expected = *Literal::MakeTuple( - {Literal::CreateR4({{{{-3.f}, {-3.f}}, {{-1.f}, {-1.f}}}, - {{{1.f}, {1.f}}, {{3.f}, {3.f}}}}) - .get(), - Literal::CreateR1({0, 0}).get(), - Literal::CreateR1({16, 20}).get()}); - - ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.1)); -} - } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc 
b/tensorflow/compiler/xla/tests/bfloat16_test.cc index a1c53ef2aa95c7d2a9d46483dfda22a05ff0cf1a..b853dfaa15d7ff2e21048a5a6a486d22c5a05416 100644 --- a/tensorflow/compiler/xla/tests/bfloat16_test.cc +++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc @@ -61,6 +61,15 @@ XLA_TEST_F(Bfloat16Test, ScalarOperation) { error_spec_); } +XLA_TEST_F(Bfloat16Test, LogOperation) { + ComputationBuilder builder(client_, TestName()); + auto x = builder.ConstantR0(static_cast(4.0f)); + builder.Log(x); + + ComputeAndCompareR0(&builder, static_cast(1.387f), {}, + error_spec_); +} + XLA_TEST_F(Bfloat16Test, NegateScalarF16) { ComputationBuilder builder(client_, TestName()); builder.Neg(builder.ConstantR0(static_cast(2.1f))); @@ -88,10 +97,11 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) { auto tuple = builder.BatchNormTraining(operand, scale, offset, /*epsilon=*/0.001, kFeatureIndex); - auto expected = *Literal::MakeTuple( + auto expected = Literal::MakeTuple( {Literal::CreateR4( - {{{{static_cast(-1.7f)}, {static_cast(-2.04f)}}, - {{static_cast(0.105f)}, {static_cast(0.65f)}}}, + {{{{static_cast(-1.6875f)}, + {static_cast(-2.04f)}}, + {{static_cast(0.105f)}, {static_cast(0.66f)}}}, {{{static_cast(1.89f)}, {static_cast(3.35f)}}, {{static_cast(3.7f)}, {static_cast(6.04f)}}}}) .get(), @@ -102,7 +112,7 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) { {static_cast(5), static_cast(5)}) .get()}); - ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01)); + ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01)); } XLA_TEST_F(Bfloat16Test, BatchNormGrad) { @@ -130,7 +140,7 @@ XLA_TEST_F(Bfloat16Test, BatchNormGrad) { builder.BatchNormGrad(operand, scale, mean, var, grad_output, /*epsilon=*/0.0, kFeatureIndex); - auto expected = *Literal::MakeTuple( + auto expected = Literal::MakeTuple( {Literal::CreateR4( {{{{static_cast(-3.f)}, {static_cast(-3.f)}}, {{static_cast(-1.f)}, {static_cast(-1.f)}}}, @@ -144,7 +154,7 @@ XLA_TEST_F(Bfloat16Test, BatchNormGrad) { {static_cast(16), 
static_cast(20)}) .get()}); - ComputeAndCompareTuple(&builder, expected, {}, ErrorSpec(0.01)); + ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.01)); } } // namespace diff --git a/tensorflow/compiler/xla/tests/broadcast_test.cc b/tensorflow/compiler/xla/tests/broadcast_test.cc index 0294628a127c9d506e6387d0b80f3da583c5a174..6ebbf7191833ef85ee4a48cc96c0a3be38c71228 100644 --- a/tensorflow/compiler/xla/tests/broadcast_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_test.cc @@ -87,11 +87,11 @@ XLA_TEST_F(BroadcastTest, BroadcastVectorTo2D) { LiteralTestUtil::ExpectNear( *Literal::CreateR2({{1.0, 1.0}, {2.0, 2.0}, {3.0, 3.0}}), - result->tuple_literals(0), error_spec_); + LiteralView::Create(*result, {0}), error_spec_); LiteralTestUtil::ExpectNear( *Literal::CreateR2({{1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}}), - result->tuple_literals(1), error_spec_); + LiteralView::Create(*result, {1}), error_spec_); } XLA_TEST_F(BroadcastTest, Broadcast2DTo2D) { diff --git a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc index 659660d91e519b428d28ced8591d05b4e4d45f53..f594cc10ac6496f710d03f0b0b134e6dd3b6d38f 100644 --- a/tensorflow/compiler/xla/tests/check_execution_arity_test.cc +++ b/tensorflow/compiler/xla/tests/check_execution_arity_test.cc @@ -104,7 +104,8 @@ XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) { ASSERT_FALSE(status.ok()); ASSERT_EQ(status.status().code(), tensorflow::error::INVALID_ARGUMENT); ASSERT_THAT(status.status().error_message(), - ContainsRegex("expects parameter 0")); + ContainsRegex( + "Argument does not match shape of computation parameter 0")); // Shape mismatch in parameter 1 (rank) status = client_->Execute(computation, {f32_data.get(), f32_data.get()}, @@ -112,7 +113,8 @@ XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) { ASSERT_FALSE(status.ok()); ASSERT_EQ(status.status().code(), tensorflow::error::INVALID_ARGUMENT); 
ASSERT_THAT(status.status().error_message(), - ContainsRegex("expects parameter 1")); + ContainsRegex( + "Argument does not match shape of computation parameter 1")); // Shape mismatch in parameter 1 (element type) status = client_->Execute(computation, {f32_data.get(), u8_4_data.get()}, @@ -120,7 +122,8 @@ XLA_TEST_F(CheckExecutionArityTest, CheckArgumentShapes) { ASSERT_FALSE(status.ok()); ASSERT_EQ(status.status().code(), tensorflow::error::INVALID_ARGUMENT); ASSERT_THAT(status.status().error_message(), - ContainsRegex("expects parameter 1")); + ContainsRegex( + "Argument does not match shape of computation parameter 1")); } } // namespace diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index 15bd273e9b69f9c177a4ec6b5c9f0e1dccee7fc1..a677986cd926cc0054d8f36abc98ccac33dc043d 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -251,8 +251,17 @@ ClientLibraryTestBase::ComputeAndCompareLiteralWithAllInputLayouts( tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( ComputationBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, + tensorflow::gtl::ArraySlice arguments_passed_in, const Shape* shape_with_layout) { + std::vector arguments(arguments_passed_in.begin(), + arguments_passed_in.end()); + if (!arguments_.empty()) { + CHECK(arguments.empty()); + for (const auto& argument : arguments_) { + arguments.push_back(argument.get()); + } + } + TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); if (ShapeUtil::ElementIsFloating(expected.shape()) || ShapeUtil::ElementIsComplex(expected.shape())) { @@ -267,12 +276,17 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( const Literal* expected_ptr = &expected; std::unique_ptr converted_expected; Shape layout_shape; - if (expected.shape().element_type() == F32 && 
use_bfloat16_) { + if (use_bfloat16_) { converted_expected = LiteralTestUtil::ConvertF32ToBF16(expected); expected_ptr = converted_expected.get(); if (shape_with_layout != nullptr) { layout_shape = *shape_with_layout; - layout_shape.set_element_type(BF16); + ShapeUtil::ForEachMutableSubshape( + &layout_shape, [&](Shape* subshape, const ShapeIndex& /*index*/) { + if (subshape->element_type() == F32) { + subshape->set_element_type(BF16); + } + }); shape_with_layout = &layout_shape; } } @@ -295,8 +309,17 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( ComputationBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec error, - const Shape* shape_with_layout) { + tensorflow::gtl::ArraySlice arguments_passed_in, + ErrorSpec error, const Shape* shape_with_layout) { + std::vector arguments(arguments_passed_in.begin(), + arguments_passed_in.end()); + if (!arguments_.empty()) { + CHECK(arguments.empty()); + for (const auto& argument : arguments_) { + arguments.push_back(argument.get()); + } + } + TF_RET_CHECK(ShapeUtil::ElementIsFloating(expected.shape()) || ShapeUtil::ElementIsComplex(expected.shape())); TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); @@ -305,13 +328,17 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( const Literal* expected_ptr = &expected; std::unique_ptr converted_expected; Shape layout_shape; - if (expected.shape().element_type() == F32 && use_bfloat16_) { + if (use_bfloat16_) { converted_expected = LiteralTestUtil::ConvertF32ToBF16(expected); expected_ptr = converted_expected.get(); - layout_shape.set_element_type(BF16); if (shape_with_layout != nullptr) { layout_shape = *shape_with_layout; - layout_shape.set_element_type(BF16); + ShapeUtil::ForEachMutableSubshape( + &layout_shape, [&](Shape* subshape, const ShapeIndex& /*index*/) { + if (subshape->element_type() == 
F32) { + subshape->set_element_type(BF16); + } + }); shape_with_layout = &layout_shape; } } @@ -348,7 +375,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1U8( VLOG(1) << "expected: " << expected_literal->ToString(); VLOG(1) << "actual: " << actual->ToString(); - EXPECT_EQ(expected, actual->u8s_string()); + EXPECT_EQ(expected, actual->GetR1U8AsString()); } void ClientLibraryTestBase::ComputeAndCompareTuple( @@ -360,7 +387,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple( return; } auto actual = actual_status.ConsumeValueOrDie(); - LiteralTestUtil::ExpectEqualTuple(expected, *actual); + LiteralTestUtil::ExpectEqual(expected, *actual); } void ClientLibraryTestBase::ComputeAndCompareTuple( @@ -372,7 +399,7 @@ void ClientLibraryTestBase::ComputeAndCompareTuple( return; } auto actual = actual_status.ConsumeValueOrDie(); - LiteralTestUtil::ExpectNearTuple(expected, *actual, error); + LiteralTestUtil::ExpectNear(expected, *actual, error); } void ClientLibraryTestBase::ComputeAndCompare( @@ -499,17 +526,41 @@ std::unique_ptr ClientLibraryTestBase::CreateParameterAndTransferLiteral( int64 parameter_number, const Literal& literal, const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle) { + return CreateParameterAndTransferLiteral(parameter_number, literal, name, + nullptr, builder, data_handle); +} + +std::unique_ptr +ClientLibraryTestBase::CreateParameterAndTransferLiteral( + int64 parameter_number, const Literal& literal, const string& name, + const DeviceHandle* device_handle, ComputationBuilder* builder, + ComputationDataHandle* data_handle) { const Literal* param_literal = &literal; std::unique_ptr converted_literal; - if (use_bfloat16_ && literal.shape().element_type() == F32) { + if (use_bfloat16_) { converted_literal = LiteralTestUtil::ConvertF32ToBF16(literal); param_literal = converted_literal.get(); } std::unique_ptr data = - client_->TransferToServer(*param_literal).ConsumeValueOrDie(); + 
client_->TransferToServer(*param_literal, device_handle) + .ConsumeValueOrDie(); *data_handle = builder->Parameter(parameter_number, param_literal->shape(), name); return data; } +ComputationDataHandle ClientLibraryTestBase::AddParam( + const Literal& argument, ComputationBuilder* builder) { + ComputationDataHandle data_handle; + arguments_.push_back(CreateParameterAndTransferLiteral( + arguments_.size(), argument, "", builder, &data_handle)); + return data_handle; +} + +ComputationDataHandle ClientLibraryTestBase::CreateConstantFromLiteral( + const Literal& literal, ComputationBuilder* builder) { + return builder->ConstantLiteral( + use_bfloat16_ ? *LiteralTestUtil::ConvertF32ToBF16(literal) : literal); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 1d27880fb1413adbbe691b5d12cadcd85fbe5d92..ba0319990bc04196386e6812b0a03671676698ec 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -43,6 +43,23 @@ limitations under the License. namespace xla { +// Sets the use_bfloat16 on a container of test cases according to the values in +// use_bfloat16_params. Generates one set of test cases for each values in +// use_bfloat16_params with that value. Returns the result. +template +std::vector ExpandUseBfloat16( + tensorflow::gtl::ArraySlice use_bfloat16_params, + tensorflow::gtl::ArraySlice specs) { + std::vector expanded; + for (bool use_bfloat16 : use_bfloat16_params) { + for (const auto& spec : specs) { + expanded.push_back(spec); + expanded.back().use_bfloat16 = use_bfloat16; + } + } + return expanded; +} + // A client library test establishes an in-process XLA client connection. 
class ClientLibraryTestBase : public ::testing::Test { protected: @@ -194,7 +211,7 @@ class ClientLibraryTestBase : public ::testing::Test { tensorflow::gtl::ArraySlice arguments); void ComputeAndCompareTuple( ComputationBuilder* builder, const Literal& expected, - tensorflow::gtl::ArraySlice arguments, ErrorSpec abs_error); + tensorflow::gtl::ArraySlice arguments, ErrorSpec error); // Convenience method for running a built computation and comparing the result // with the HloEvaluator. @@ -253,6 +270,51 @@ class ClientLibraryTestBase : public ::testing::Test { int64 parameter_number, const Literal& literal, const string& name, ComputationBuilder* builder, ComputationDataHandle* data_handle); + // As above, but the caller can specify the device that the literal is + // transferred to. If device_handle is nullptr, the literal will be + // transferred to the default device. + std::unique_ptr CreateParameterAndTransferLiteral( + int64 parameter_number, const Literal& literal, const string& name, + const DeviceHandle* device_handle, ComputationBuilder* builder, + ComputationDataHandle* data_handle); + + // Creates a parameter instruction and sets the value that will be passed to + // the computation as specified. This function must be used for all parameters + // or none and no parameters must be passed when invoking the computation if + // using this mechanism. If using this mechanism, then each parameter must be + // set exactly once. The first added parameter gets index 0, then 1 and so on. + ComputationDataHandle AddParam(const Literal& argument, + ComputationBuilder* builder); + + template + ComputationDataHandle AddParam(const Array& argument, + ComputationBuilder* builder) { + return AddParam(*Literal::CreateFromArray(argument), builder); + } + + // Creates a constant instruction with the given literal. When the + // use_bfloat16 flag is set but the literal has F32 elements, the elements + // will be converted to BF16s. 
+ ComputationDataHandle CreateConstantFromLiteral(const Literal& literal, + ComputationBuilder* builder); + + // Creates a constant instruction with the given array. When the use_bfloat16 + // flag is set but the array has float elements, the elements will be + // converted to bfloat16s. + template + ComputationDataHandle CreateConstantFromArray(const Array& array, + ComputationBuilder* builder) { + return CreateConstantFromLiteral(*Literal::CreateFromArray(array), builder); + } + + // Same as CreateConstantFromArray, but for scalars. + template + ComputationDataHandle CreateConstantFromScalar(NativeT value, + ComputationBuilder* builder) { + return CreateConstantFromLiteral(*Literal::CreateR0(value), + builder); + } + // Creates a parameter instruction that wraps a given value and then stores // into "data_handle" the global handle for that parameter. // @@ -315,6 +377,9 @@ class ClientLibraryTestBase : public ::testing::Test { bool use_bfloat16() const { return use_bfloat16_; } void set_use_bfloat16(bool value) { use_bfloat16_ = value; } + // The float type used in this test, BF16 or F32 according to use_bfloat16. + PrimitiveType FloatType() const { return use_bfloat16_ ? BF16 : F32; } + Client* client_; ExecutionOptions execution_options_; @@ -344,6 +409,9 @@ class ClientLibraryTestBase : public ::testing::Test { // Whether to run tests with all float-type input/output converted to // bfloat16. bool use_bfloat16_ = false; + + // Arguments to be passed to the computation when it runs. 
+ std::vector> arguments_; }; template @@ -363,6 +431,7 @@ void ClientLibraryTestBase::ComputeAndCompareR0( static_assert(std::is_same::value || std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = @@ -388,6 +457,7 @@ void ClientLibraryTestBase::ComputeAndCompareR1( static_assert(std::is_same::value || std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = @@ -413,6 +483,7 @@ void ClientLibraryTestBase::ComputeAndCompareR2( static_assert(std::is_same::value || std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = @@ -438,6 +509,7 @@ void ClientLibraryTestBase::ComputeAndCompareR3( static_assert(std::is_same::value || std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = @@ -463,6 +535,7 @@ void ClientLibraryTestBase::ComputeAndCompareR4( static_assert(std::is_same::value || std::is_same::value || std::is_same::value || + std::is_same::value || std::is_same::value, "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc index 8853ed9e5780672d4006c326291767b8b5253f56..045148cdd11da94ae4789a753efca95c6aaa1f27 100644 --- a/tensorflow/compiler/xla/tests/client_test.cc +++ b/tensorflow/compiler/xla/tests/client_test.cc @@ -36,7 +36,7 @@ namespace { class ClientTest : public ClientLibraryTestBase {}; -TEST_F(ClientTest, ExecuteWithLayout) { +XLA_TEST_F(ClientTest, ExecuteWithLayout) { 
ComputationBuilder b(client_, TestName()); std::vector> layouts = {{0, 1}, {1, 0}}; @@ -68,7 +68,7 @@ TEST_F(ClientTest, ExecuteWithLayout) { } } -TEST_F(ClientTest, ExecuteWithTupleLayout) { +XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) { ComputationBuilder b(client_, TestName()); b.Tuple({b.ConstantR2({{1, 2}, {3, 4}}), @@ -90,9 +90,9 @@ TEST_F(ClientTest, ExecuteWithTupleLayout) { auto result, client_->ExecuteAndTransfer(computation, {}, &execution_options)); LiteralTestUtil::ExpectR2Equal({{1, 2}, {3, 4}}, - result->tuple_literals(0)); + LiteralView::Create(*result, {0})); LiteralTestUtil::ExpectR2Equal({{10, 20}, {30, 40}}, - result->tuple_literals(1)); + LiteralView::Create(*result, {1})); EXPECT_TRUE(ShapeUtil::IsTuple(result->shape())); EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->shape())); @@ -107,7 +107,8 @@ TEST_F(ClientTest, ExecuteWithTupleLayout) { /*minor_to_major=*/{1, 0}))); } -TEST_F(ClientTest, DISABLED_ON_CPU_PARALLEL(DISABLED_ON_GPU(ExecuteParallel))) { +XLA_TEST_F(ClientTest, + DISABLED_ON_CPU_PARALLEL(DISABLED_ON_GPU(ExecuteParallel))) { Computation add_with_one_arg, mul_with_two_args, dot_with_one_arg; Shape shape = ShapeUtil::MakeShape(S32, {2, 2}); diff --git a/tensorflow/compiler/xla/tests/codegen_test_base.cc b/tensorflow/compiler/xla/tests/codegen_test_base.cc index e472408dcf7ed5fec74e886fd0092ce47ee2e7eb..022641394f113ef28e7c53058385d77572822213 100644 --- a/tensorflow/compiler/xla/tests/codegen_test_base.cc +++ b/tensorflow/compiler/xla/tests/codegen_test_base.cc @@ -21,9 +21,11 @@ StatusOr> CodegenTestBase::CompileToExecutable( std::unique_ptr hlo_module) { TF_ASSIGN_OR_RETURN(hlo_module, backend().compiler()->RunHloPasses( std::move(hlo_module), - backend().default_stream_executor())); + backend().default_stream_executor(), + /*device_allocator=*/nullptr)); return backend().compiler()->RunBackend(std::move(hlo_module), - backend().default_stream_executor()); + backend().default_stream_executor(), + 
/*device_allocator=*/nullptr); } StatusOr> diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc index 5226a78386824a94572d3e5cc3329677108a910a..ec2c580670cfac14ba42e8c9a836c86551af4b89 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -149,7 +149,7 @@ TEST_F(ComputeConstantTest, Param) { auto computation = b.Add(param, b.ConstantR0(1.5f)); std::vector arguments; - arguments.emplace_back(*Literal::CreateR0(42.5f)); + arguments.push_back(std::move(*Literal::CreateR0(42.5f))); EXPECT_TRUE(IsConstant(computation, &b, arguments.size())); auto value = @@ -168,7 +168,7 @@ TEST_F(ComputeConstantTest, DirectParamMissing) { auto value = ComputeConstantScalar(client, computation, &b); EXPECT_TRUE(tensorflow::StringPiece(value.status().ToString()) - .contains("depends on parameter")) + .contains("depends on a parameter")) << value.status(); } } @@ -184,7 +184,7 @@ TEST_F(ComputeConstantTest, IndirectParamMissing) { auto value = ComputeConstantScalar(client, computation, &b); EXPECT_TRUE(tensorflow::StringPiece(value.status().ToString()) - .contains("depends on parameter")) + .contains("depends on a parameter")) << value.status(); } } diff --git a/tensorflow/compiler/xla/tests/conditional_test.cc b/tensorflow/compiler/xla/tests/conditional_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..bc821674820fb128823786d7149037fc59b22ab6 --- /dev/null +++ b/tensorflow/compiler/xla/tests/conditional_test.cc @@ -0,0 +1,575 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" + +namespace xla { +namespace { + +class ConditionalOpTest : public ClientLibraryTestBase { + protected: + Computation CreateR0ConstantComputation(float value) { + ComputationBuilder builder(client_, "Constant"); + builder.Parameter(0, empty_tuple_, "tuple"); + builder.ConstantR0(value); + auto build_status = builder.Build(); + EXPECT_IS_OK(build_status.status()); + return build_status.ConsumeValueOrDie(); + } + + Computation CreateR0IdentityComputation() { + ComputationBuilder builder(client_, "Identity"); + builder.Parameter(0, r0f32_, "x"); + auto build_status = builder.Build(); + EXPECT_IS_OK(build_status.status()); + return build_status.ConsumeValueOrDie(); + } + + Computation CreateCeilComputation(const Shape& shape) { + ComputationBuilder builder(client_, "Ceil"); + auto param = builder.Parameter(0, shape, "param"); + builder.Ceil(param); + auto build_status = builder.Build(); + EXPECT_IS_OK(build_status.status()); + return build_status.ConsumeValueOrDie(); + } + + Computation CreateR0CeilComputation() { + return CreateCeilComputation(r0f32_); + } + + Computation CreateR1CeilComputation() { + return CreateCeilComputation(r1s2f32_); + } + + Computation CreateFloorComputation(const Shape& shape) { + ComputationBuilder builder(client_, 
"Floor"); + auto param = builder.Parameter(0, shape, "param"); + builder.Floor(param); + auto build_status = builder.Build(); + EXPECT_IS_OK(build_status.status()); + return build_status.ConsumeValueOrDie(); + } + + Computation CreateR0FloorComputation() { + return CreateFloorComputation(r0f32_); + } + + Computation CreateR1FloorComputation() { + return CreateFloorComputation(r1s2f32_); + } + + Computation CreateTupleCeilComputation(const string& computation_name, + const Shape& tuple_shape) { + ComputationBuilder builder(client_, computation_name); + auto tuple = builder.Parameter(0, tuple_shape, "tuple"); + auto x = builder.GetTupleElement(tuple, 0); + auto y = builder.GetTupleElement(tuple, 1); + auto x_ceil = builder.Ceil(x); + auto y_ceil = builder.Ceil(y); + builder.Tuple({x_ceil, y_ceil}); + auto build_status = builder.Build(); + EXPECT_IS_OK(build_status.status()); + return build_status.ConsumeValueOrDie(); + } + + Computation CreateR0TupleCeilComputation() { + return CreateTupleCeilComputation("CeilR0", tuple_2_r0f32_); + } + + Computation CreateR1TupleCeilComputation() { + return CreateTupleCeilComputation("CeilR1", tuple_2_r1s2f32_); + } + + Computation CreateTupleFloorComputation(const string& computation_name, + const Shape& tuple_shape) { + ComputationBuilder builder(client_, computation_name); + auto tuple = builder.Parameter(0, tuple_shape, "tuple"); + auto x = builder.GetTupleElement(tuple, 0); + auto y = builder.GetTupleElement(tuple, 1); + auto x_floor = builder.Floor(x); + auto y_floor = builder.Floor(y); + builder.Tuple({x_floor, y_floor}); + auto build_status = builder.Build(); + EXPECT_IS_OK(build_status.status()); + return build_status.ConsumeValueOrDie(); + } + + Computation CreateR0TupleFloorComputation() { + return CreateTupleFloorComputation("FloorR0", tuple_2_r0f32_); + } + + Computation CreateR1TupleFloorComputation() { + return CreateTupleFloorComputation("FloorR1", tuple_2_r1s2f32_); + } + + Computation 
CreateTupleAddComputation(const string& computation_name, + const Shape& tuple_shape) { + ComputationBuilder builder(client_, computation_name); + auto tuple = builder.Parameter(0, tuple_shape, "tuple"); + auto x = builder.GetTupleElement(tuple, 0); + auto y = builder.GetTupleElement(tuple, 1); + builder.Add(x, y); + auto build_status = builder.Build(); + EXPECT_IS_OK(build_status.status()); + return build_status.ConsumeValueOrDie(); + } + + Computation CreateR0TupleAddComputation() { + return CreateTupleAddComputation("AddR0", tuple_2_r0f32_); + } + + Computation CreateR1TupleAddComputation() { + return CreateTupleAddComputation("AddR1", tuple_2_r1s2f32_); + } + + Computation CreateTupleSubComputation(const string& computation_name, + const Shape& tuple_shape) { + ComputationBuilder builder(client_, computation_name); + auto tuple = builder.Parameter(0, tuple_shape, "tuple"); + auto x = builder.GetTupleElement(tuple, 0); + auto y = builder.GetTupleElement(tuple, 1); + builder.Sub(x, y); + auto build_status = builder.Build(); + EXPECT_IS_OK(build_status.status()); + return build_status.ConsumeValueOrDie(); + } + + Computation CreateR0TupleSubComputation() { + return CreateTupleSubComputation("SubR0", tuple_2_r0f32_); + } + + Computation CreateR1TupleSubComputation() { + return CreateTupleSubComputation("SubR1", tuple_2_r1s2f32_); + } + + Shape r0f32_ = ShapeUtil::MakeShape(F32, {}); + Shape r1s2f32_ = ShapeUtil::MakeShape(F32, {2}); + Shape tuple_2_r0f32_ = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {})}); + Shape tuple_2_r1s2f32_ = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {2}), ShapeUtil::MakeShape(F32, {2})}); + Shape empty_tuple_ = ShapeUtil::MakeTupleShape({}); + ErrorSpec error_spec_{0.001}; +}; + +// Test true and false computations that do not take any parameters. 
+XLA_TEST_F(ConditionalOpTest, Parameters0) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(true); + auto operands = builder.Tuple({}); + auto true_computation = CreateR0ConstantComputation(56.0f); + auto false_computation = CreateR0ConstantComputation(12.0f); + auto result = builder.Conditional(pred, operands, true_computation, operands, + false_computation); + + ComputeAndCompareR0(&builder, 56.0f, {}, error_spec_); +} + +// Test true and false computations that take in 1 parameter. +XLA_TEST_F(ConditionalOpTest, Parameters1) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(false); + auto operand1 = builder.ConstantR0(56.0f); + auto operand2 = builder.ConstantR0(12.0f); + auto identity = CreateR0IdentityComputation(); + auto result = + builder.Conditional(pred, operand1, identity, operand2, identity); + + ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); +} + +// Test conditional with two different computations in the true and false cases +// that take in different arguments. +XLA_TEST_F(ConditionalOpTest, DiffComputationsDiffArgs) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(false); + auto operand1 = builder.ConstantR0(56.4f); + auto operand2 = builder.ConstantR0(12.6f); + auto result = builder.Conditional(pred, operand1, CreateR0CeilComputation(), + operand2, CreateR0FloorComputation()); + + ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); +} + +// Test conditional with two different computations in the true and false cases +// that take in the same arguments. 
+XLA_TEST_F(ConditionalOpTest, DiffComputationsSameArg) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(false); + auto operand = builder.ConstantR0(12.6f); + auto result = builder.Conditional(pred, operand, CreateR0CeilComputation(), + operand, CreateR0FloorComputation()); + + ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); +} + +// Test conditional with the same computation in the true and false cases but +// take in different arguments. +XLA_TEST_F(ConditionalOpTest, SameComputationDiffArgs) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(false); + auto operand1 = builder.ConstantR0(56.4f); + auto operand2 = builder.ConstantR0(12.6f); + auto floor = CreateR0FloorComputation(); + auto result = builder.Conditional(pred, operand1, floor, operand2, floor); + + ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); +} + +// Test conditional with the same computation in the true and false cases that +// take in the same arguments. +XLA_TEST_F(ConditionalOpTest, SameComputationSameArg) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(false); + auto operand = builder.ConstantR0(12.6f); + auto floor = CreateR0FloorComputation(); + auto result = builder.Conditional(pred, operand, floor, operand, floor); + + ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); +} + +// Test conditional with different instances of the same computation in the true +// and false cases. 
+XLA_TEST_F(ConditionalOpTest, SameComputationDiffInstances) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(false); + auto operand1 = builder.ConstantR0(56.4f); + auto operand2 = builder.ConstantR0(12.6f); + auto result = builder.Conditional(pred, operand1, CreateR0FloorComputation(), + operand2, CreateR0FloorComputation()); + + ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); +} + +// Test the case when a call invokes a computation that contains a conditional. +XLA_TEST_F(ConditionalOpTest, ConditionalWithCall) { + Shape r0bool = ShapeUtil::MakeShape(PRED, {}); + ComputationBuilder inner_builder(client_, TestName() + ".inner_conditional"); + auto pred_cond = inner_builder.Parameter(0, r0bool, "param0"); + auto true_operand = inner_builder.Parameter(1, r0f32_, "param1"); + auto false_operand = inner_builder.Parameter(2, r0f32_, "param2"); + inner_builder.Conditional(pred_cond, true_operand, CreateR0CeilComputation(), + false_operand, CreateR0FloorComputation()); + auto inner_builder_result = inner_builder.Build(); + + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(false); + auto operand1 = builder.ConstantR0(56.4f); + auto operand2 = builder.ConstantR0(12.6f); + builder.Call(inner_builder_result.ConsumeValueOrDie(), + {pred, operand1, operand2}); + + ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); +} + +// Test true and false computations that take in 2 parameters and predicate is +// true. 
+XLA_TEST_F(ConditionalOpTest, Parameters2TrueBranch) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(true); + auto operand1 = builder.ConstantR0(56.0f); + auto operand2 = builder.ConstantR0(12.0f); + auto operands = builder.Tuple({operand1, operand2}); + auto result = + builder.Conditional(pred, operands, CreateR0TupleAddComputation(), + operands, CreateR0TupleSubComputation()); + + ComputeAndCompareR0(&builder, 68.0f, {}, error_spec_); +} + +// Test true and false computations that take in 2 parameters and predicate is +// false. +XLA_TEST_F(ConditionalOpTest, Parameters2FalseBranch) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(false); + auto operand1 = builder.ConstantR0(56.0f); + auto operand2 = builder.ConstantR0(12.0f); + auto operands = builder.Tuple({operand1, operand2}); + auto result = + builder.Conditional(pred, operands, CreateR0TupleAddComputation(), + operands, CreateR0TupleSubComputation()); + + ComputeAndCompareR0(&builder, 44.0f, {}, error_spec_); +} + +// Test true and false computations that take in 2 array parameters and +// predicate is true. +XLA_TEST_F(ConditionalOpTest, Parameters2ArrayTrueBranch) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(true); + auto operand1 = builder.ConstantR1({24.0f, 56.0f}); + auto operand2 = builder.ConstantR1({10.0f, 11.0f}); + auto operands = builder.Tuple({operand1, operand2}); + auto result = + builder.Conditional(pred, operands, CreateR1TupleAddComputation(), + operands, CreateR1TupleSubComputation()); + + ComputeAndCompareR1(&builder, {34.0f, 67.0f}, {}, error_spec_); +} + +// Test true and false computations that take in 2 array parameters and +// predicate is false. 
+XLA_TEST_F(ConditionalOpTest, Parameters2ArrayFalseBranch) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(false); + auto operand1 = builder.ConstantR1({24.0f, 56.0f}); + auto operand2 = builder.ConstantR1({10.0f, 11.0f}); + auto operands = builder.Tuple({operand1, operand2}); + auto result = + builder.Conditional(pred, operands, CreateR1TupleAddComputation(), + operands, CreateR1TupleSubComputation()); + + ComputeAndCompareR1(&builder, {14.0f, 45.0f}, {}, error_spec_); +} + +// Test true and false computations that return a tuple of scalars. +XLA_TEST_F(ConditionalOpTest, ReturnTupleOfScalars) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(false); + auto operands = builder.Tuple( + {builder.ConstantR0(12.2f), builder.ConstantR0(25.6f)}); + builder.Conditional(pred, operands, CreateR0TupleCeilComputation(), operands, + CreateR0TupleFloorComputation()); + + ComputeAndCompareTuple( + &builder, + *Literal::MakeTuple({Literal::CreateR0(12.0f).get(), + Literal::CreateR0(25.0f).get()}), + {}, error_spec_); +} + +// Test true and false computations that return a tuple of arrays. +XLA_TEST_F(ConditionalOpTest, ReturnTupleOfArrays) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(true); + auto operands = builder.Tuple({builder.ConstantR1({12.2f, 15.8f}), + builder.ConstantR1({25.6f, 29.2f})}); + builder.Conditional(pred, operands, CreateR1TupleCeilComputation(), operands, + CreateR1TupleFloorComputation()); + + ComputeAndCompareTuple( + &builder, + *Literal::MakeTuple({Literal::CreateR1({13.0f, 16.0f}).get(), + Literal::CreateR1({26.0f, 30.0f}).get()}), + {}, error_spec_); +} + +// Test true and false computations that return a tuple of a predicate, a +// scalar, and an array. 
+XLA_TEST_F(ConditionalOpTest, ReturnTupleofPredicateScalarArray) { + ComputationBuilder true_builder(client_, TestName() + ".true"); + { + true_builder.Parameter(0, empty_tuple_, "tuple"); + auto true_pred = true_builder.ConstantR0(true); + auto true_scalar = true_builder.ConstantR0(12.2f); + auto true_array = true_builder.ConstantR1({12.8f, 14.6f}); + true_builder.Tuple({true_pred, true_scalar, true_array}); + } + auto true_builder_result = true_builder.Build(); + EXPECT_IS_OK(true_builder_result.status()); + + ComputationBuilder false_builder(client_, TestName() + ".false"); + { + false_builder.Parameter(0, empty_tuple_, "tuple"); + auto false_pred = false_builder.ConstantR0(false); + auto false_scalar = false_builder.ConstantR0(25.6f); + auto false_array = false_builder.ConstantR1({26.4f, 32.6f}); + false_builder.Tuple({false_pred, false_scalar, false_array}); + } + auto false_builder_result = false_builder.Build(); + EXPECT_IS_OK(false_builder_result.status()); + + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(true); + auto operands = builder.Tuple({}); + builder.Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), + operands, false_builder_result.ConsumeValueOrDie()); + + ComputeAndCompareTuple( + &builder, + *Literal::MakeTuple({Literal::CreateR0(true).get(), + Literal::CreateR0(12.2f).get(), + Literal::CreateR1({12.8f, 14.6f}).get()}), + {}, error_spec_); +} + +// Test true and false computations that return a nested tuple. 
+XLA_TEST_F(ConditionalOpTest, ReturnNestedTuple) { + ComputationBuilder true_builder(client_, TestName() + ".true"); + { + true_builder.Parameter(0, empty_tuple_, "tuple"); + auto true_constant1 = true_builder.ConstantR0(12.2f); + auto true_constant2 = true_builder.ConstantR1({12.8f, 14.6f}); + auto true_constant3 = true_builder.ConstantR1({25.4f, 29.8f}); + auto true_constant4 = true_builder.ConstantR0(35.6f); + true_builder.Tuple({true_builder.Tuple({true_constant1, true_constant2}), + true_builder.Tuple({true_constant3, true_constant4})}); + } + auto true_builder_result = true_builder.Build(); + EXPECT_IS_OK(true_builder_result.status()); + + ComputationBuilder false_builder(client_, TestName() + ".false"); + { + false_builder.Parameter(0, empty_tuple_, "tuple"); + auto false_constant1 = false_builder.ConstantR0(46.6f); + auto false_constant2 = false_builder.ConstantR1({54.4f, 58.4f}); + auto false_constant3 = false_builder.ConstantR1({62.1f, 67.4f}); + auto false_constant4 = false_builder.ConstantR0(9.3f); + false_builder.Tuple( + {false_builder.Tuple({false_constant1, false_constant2}), + false_builder.Tuple({false_constant3, false_constant4})}); + } + auto false_builder_result = false_builder.Build(); + EXPECT_IS_OK(false_builder_result.status()); + + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(false); + auto operands = builder.Tuple({}); + builder.Conditional(pred, operands, true_builder_result.ConsumeValueOrDie(), + operands, false_builder_result.ConsumeValueOrDie()); + + ComputeAndCompareTuple( + &builder, + *Literal::MakeTuple( + {Literal::MakeTuple({Literal::CreateR0(46.6f).get(), + Literal::CreateR1({54.4f, 58.4f}).get()}) + .get(), + Literal::MakeTuple({Literal::CreateR1({62.1f, 67.4f}).get(), + Literal::CreateR0(9.3f).get()}) + .get()}), + {}, error_spec_); +} + +// Test conditional that takes in scalar operands in the form of external +// params. 
+XLA_TEST_F(ConditionalOpTest, ScalarOperandsFromExternalParams) { + Shape r0bool = ShapeUtil::MakeShape(PRED, {}); + ComputationBuilder builder(client_, TestName()); + + ComputationDataHandle pred, operand1, operand2; + auto pred_arg = CreateR0Parameter(true, 0, "pred", &builder, &pred); + auto operand1_param = + CreateR0Parameter(56.3f, 1, "operand1", &builder, &operand1); + auto operand2_param = + CreateR0Parameter(12.7f, 2, "operand2", &builder, &operand2); + auto result = builder.Conditional(pred, operand1, CreateR0CeilComputation(), + operand2, CreateR0FloorComputation()); + + ComputeAndCompareR0( + &builder, 57.0f, + {pred_arg.get(), operand1_param.get(), operand2_param.get()}, + error_spec_); +} + +// Test conditional that takes in array operands in the form of external params. +XLA_TEST_F(ConditionalOpTest, ArrayOperandsFromExternalParams) { + Shape r0bool = ShapeUtil::MakeShape(PRED, {}); + ComputationBuilder builder(client_, TestName()); + + ComputationDataHandle pred, operand1, operand2; + auto pred_arg = CreateR0Parameter(false, 0, "pred", &builder, &pred); + auto operand1_param = CreateR1Parameter({24.3f, 56.7f}, 1, "operand1", + &builder, &operand1); + auto operand2_param = CreateR1Parameter({10.2f, 11.6f}, 2, "operand2", + &builder, &operand2); + auto result = builder.Conditional(pred, operand1, CreateR1CeilComputation(), + operand2, CreateR1FloorComputation()); + + ComputeAndCompareR1( + &builder, {10.0f, 11.0f}, + {pred_arg.get(), operand1_param.get(), operand2_param.get()}, + error_spec_); +} + +// Test the case where one conditional is nested within another. 
+XLA_TEST_F(ConditionalOpTest, NestedConditionals) { + ComputationBuilder inner_builder(client_, TestName() + ".inner_conditional"); + { + Shape r0bool = ShapeUtil::MakeShape(PRED, {}); + Shape tuple_shape = ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_}); + auto param0 = inner_builder.Parameter(0, tuple_shape, "param0"); + auto pred_cond = inner_builder.GetTupleElement(param0, 0); + auto true_operand = inner_builder.GetTupleElement(param0, 1); + auto false_operand = inner_builder.GetTupleElement(param0, 2); + inner_builder.Conditional(pred_cond, true_operand, + CreateR0CeilComputation(), false_operand, + CreateR0FloorComputation()); + } + auto inner_builder_result = inner_builder.Build(); + EXPECT_IS_OK(inner_builder_result.status()); + + ComputationBuilder builder(client_, TestName()); + auto pred1 = builder.ConstantR0(true); + auto pred2 = builder.ConstantR0(false); + auto operand1 = builder.ConstantR0(1.1f); + auto operand2 = builder.ConstantR0(12.2f); + auto operand3 = builder.ConstantR0(43.3f); + auto tuple_operand = builder.Tuple({pred2, operand1, operand2}); + builder.Conditional(pred1, tuple_operand, + inner_builder_result.ConsumeValueOrDie(), operand3, + CreateR0IdentityComputation()); + + ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); +} + +XLA_TEST_F(ConditionalOpTest, ConditionalInNestedComputation) { + ComputationBuilder inner_builder(client_, TestName() + ".inner_conditional"); + { + Shape r0bool = ShapeUtil::MakeShape(PRED, {}); + Shape tuple_shape = ShapeUtil::MakeTupleShape({r0bool, r0f32_, r0f32_}); + auto param0 = inner_builder.Parameter(0, tuple_shape, "param0"); + auto pred_cond = inner_builder.GetTupleElement(param0, 0); + auto true_operand = inner_builder.GetTupleElement(param0, 1); + auto false_operand = inner_builder.GetTupleElement(param0, 2); + inner_builder.Conditional(pred_cond, true_operand, + CreateR0CeilComputation(), false_operand, + CreateR0FloorComputation()); + } + auto inner_builder_result = inner_builder.Build(); + 
EXPECT_IS_OK(inner_builder_result.status()); + + ComputationBuilder builder(client_, TestName()); + auto pred2 = builder.ConstantR0(false); + auto operand1 = builder.ConstantR0(1.1f); + auto operand2 = builder.ConstantR0(12.2f); + auto tuple_operand = builder.Tuple({pred2, operand1, operand2}); + builder.Call(inner_builder_result.ConsumeValueOrDie(), {tuple_operand}); + + ComputeAndCompareR0(&builder, 12.0f, {}, error_spec_); +} + +// Test a mismatch in the shape of the true operand and true computation. +XLA_TEST_F(ConditionalOpTest, ShapeMismatch) { + ComputationBuilder builder(client_, TestName()); + auto pred = builder.ConstantR0(true); + auto operand1 = builder.ConstantR0(56.0f); + auto operand2 = builder.ConstantR0(12.0f); + auto operands = builder.Tuple({operand1, operand2}); + builder.Conditional(pred, operands, CreateR1TupleAddComputation(), operands, + CreateR0TupleSubComputation()); + + auto result = builder.Build(); + EXPECT_FALSE(result.ok()); + EXPECT_THAT(result.status().error_message(), + ::testing::HasSubstr("true_operand must match the shape of the " + "only parameter of true_computation")); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/constants_test.cc b/tensorflow/compiler/xla/tests/constants_test.cc index 97bd1553664a6c0fcb097b441ec42efb4eaa9cc2..35aa3f6d696297efb7d95d826ed75a504a24529d 100644 --- a/tensorflow/compiler/xla/tests/constants_test.cc +++ b/tensorflow/compiler/xla/tests/constants_test.cc @@ -141,11 +141,12 @@ TEST_F(ConstantsTest, Small_3x2x1x1) { {5.0f, 4.4f}, // p2 }); input_array.FillWithPZ(pz); - Literal input_literal = *Literal::CreateR4FromArray4D(input_array); + std::unique_ptr input_literal = + Literal::CreateR4FromArray4D(input_array); { ComputationBuilder builder(client_, TestName()); - builder.ConstantLiteral(input_literal); + builder.ConstantLiteral(*input_literal); ComputeAndCompareR4(&builder, input_array, {}, error_spec_); } @@ -165,10 +166,10 @@ TEST_F(ConstantsTest, 
DISABLED_TupleConstant) { std::unique_ptr result = ExecuteAndTransferOrDie(&builder, {}); - LiteralTestUtil::ExpectR2Near({{1.0}, {2.0}}, - result->tuple_literals(0), error_spec_); - LiteralTestUtil::ExpectR1Near({2.0, 42.0}, result->tuple_literals(1), - error_spec_); + LiteralTestUtil::ExpectR2Near( + {{1.0}, {2.0}}, LiteralView::Create(*result, {0}), error_spec_); + LiteralTestUtil::ExpectR1Near( + {2.0, 42.0}, LiteralView::Create(*result, {1}), error_spec_); } } // namespace diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 2924c08615fa706bb19addf04bf58e1d5dd5a659..0ceb9aff378ae8aa8098be9360310b1d78d31ab2 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -105,8 +105,8 @@ TEST_F(ConvolutionTest, Convolve_1x1x1x2_1x1x1x2_Valid) { })); ComputeAndCompare(&builder, conv, - {*Literal::CreateFromArray(input_data), - *Literal::CreateFromArray(filter_data)}, + {std::move(*Literal::CreateFromArray(input_data)), + std::move(*Literal::CreateFromArray(filter_data))}, error_spec_); } @@ -136,8 +136,8 @@ TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Valid) { })); // clang-format on ComputeAndCompare(&builder, conv, - {*Literal::CreateFromArray(input_data), - *Literal::CreateFromArray(filter_data)}, + {std::move(*Literal::CreateFromArray(input_data)), + std::move(*Literal::CreateFromArray(filter_data))}, error_spec_); } @@ -167,8 +167,8 @@ TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Same) { })); // clang-format on ComputeAndCompare(&builder, conv, - {*Literal::CreateFromArray(input_data), - *Literal::CreateFromArray(filter_data)}, + {std::move(*Literal::CreateFromArray(input_data)), + std::move(*Literal::CreateFromArray(filter_data))}, error_spec_); } @@ -200,8 +200,8 @@ TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x3x3_Same) { })); // clang-format on ComputeAndCompare(&builder, conv, - {*Literal::CreateFromArray(input_data), - 
*Literal::CreateFromArray(filter_data)}, + {std::move(*Literal::CreateFromArray(input_data)), + std::move(*Literal::CreateFromArray(filter_data))}, error_spec_); } @@ -501,10 +501,10 @@ XLA_TEST_P(ConvolveWithAndWithoutCanonicalization, Array2D expected_result(29, 10); expected_result.Fill(0); - ComputeAndCompare( - &builder, conv, - {*Literal::CreateFromArray(param0), *Literal::CreateFromArray(param1)}, - error_spec_); + ComputeAndCompare(&builder, conv, + {std::move(*Literal::CreateFromArray(param0)), + std::move(*Literal::CreateFromArray(param1))}, + error_spec_); } INSTANTIATE_TEST_CASE_P(ConvolveWithAndWithoutCanonicalization_Instantiation, @@ -608,5 +608,28 @@ INSTANTIATE_TEST_CASE_P( ); +TEST_F(ConvolutionTest, Convolve_bf16_1x1x1x2_1x1x1x2_Valid) { + ComputationBuilder builder(client_, TestName()); + Shape input_shape = ShapeUtil::MakeShape(BF16, {1, 1, 1, 2}); + Shape filter_shape = ShapeUtil::MakeShape(BF16, {1, 1, 1, 2}); + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid); + + Array4D input_data(1, 1, 1, 2); + input_data.FillWithYX(Array2D({ + {bfloat16(1), bfloat16(2)}, + })); + Array4D filter_data(1, 1, 1, 2); + filter_data.FillWithYX(Array2D({ + {bfloat16(5), bfloat16(6)}, + })); + + ComputeAndCompare(&builder, conv, + {std::move(*Literal::CreateFromArray(input_data)), + std::move(*Literal::CreateFromArray(filter_data))}, + error_spec_); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/copy_test.cc b/tensorflow/compiler/xla/tests/copy_test.cc index bcb85b04eefa349df1c055e010d584b85b55a4a8..ece7c3b05e7fafa299db7f9cbf50610c8204f95e 100644 --- a/tensorflow/compiler/xla/tests/copy_test.cc +++ b/tensorflow/compiler/xla/tests/copy_test.cc @@ -40,7 +40,7 @@ class CopyOpTest : public HloTestBase { void TestCopyOp(const Literal& literal) { auto builder = HloComputation::Builder(TestName()); 
auto constant = builder.AddInstruction( - HloInstruction::CreateConstant(MakeUnique(literal))); + HloInstruction::CreateConstant(literal.CloneToUnique())); builder.AddInstruction(HloInstruction::CreateUnary( constant->shape(), HloOpcode::kCopy, constant)); auto computation = builder.Build(); @@ -56,9 +56,13 @@ class CopyOpTest : public HloTestBase { tensorflow::gtl::ArraySlice permutation); }; -XLA_TEST_F(CopyOpTest, CopyR0Bool) { TestCopyOp(*Literal::CreateR0(true)); } +XLA_TEST_F(CopyOpTest, CopyR0Bool) { + TestCopyOp(*Literal::CreateR0(true)); +} -XLA_TEST_F(CopyOpTest, CopyR1S0U32) { TestCopyOp(*Literal::CreateR1({})); } +XLA_TEST_F(CopyOpTest, CopyR1S0U32) { + TestCopyOp(*Literal::CreateR1({})); +} XLA_TEST_F(CopyOpTest, CopyR1S3U32) { TestCopyOp(*Literal::CreateR1({1, 2, 3})); @@ -85,7 +89,6 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) { // Copy literal to device to use as parameter. auto literal = Literal::CreateR0(42.0); Shape shape = literal->shape(); - auto constant_device_base = TransferToDevice(*literal); auto param0 = builder.AddInstruction( HloInstruction::CreateParameter(0, shape, "param0")); @@ -98,7 +101,7 @@ XLA_TEST_F(CopyOpTest, CopyParameterScalar) { module->AddEntryComputation(std::move(computation)); std::unique_ptr result = - ExecuteAndTransfer(std::move(module), {constant_device_base}); + ExecuteAndTransfer(std::move(module), {literal.get()}); LiteralTestUtil::ExpectR0Near(42.0f, *result, error_spec_); } @@ -129,7 +132,8 @@ XLA_TEST_F(CopyOpTest, CopyConstantR2DifferentLayouts) { std::unique_ptr literal = Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}); // Reverse the minor-to-major order of the literal. 
- Layout* literal_layout = literal->mutable_shape()->mutable_layout(); + Layout* literal_layout = + literal->mutable_shape_do_not_use()->mutable_layout(); ASSERT_EQ(2, literal_layout->minor_to_major_size()); literal_layout->mutable_minor_to_major()->SwapElements(0, 1); diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index 74f73a1ddc15be033e52b0b45f9961e5dc3a1ecb..2d847a66b0ae7c8f09fa0cb181a4c84ea99be5b1 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" @@ -128,5 +129,19 @@ XLA_TEST_F(CustomCallTest, Array3D{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, *result); } +class CustomCallClientAPITest : public ClientLibraryTestBase {}; + +// When using the client API, CustomCall targets can't begin with '$' -- these +// are reserved for internal use. 
+XLA_TEST_F(CustomCallClientAPITest, IllegalCustomCallTarget) { + ComputationBuilder builder(client_, TestName()); + auto call = builder.CustomCall("$illegal", /*operands=*/{}, + ShapeUtil::MakeShape(F32, {1})); + + StatusOr> result = + Execute(&builder, /*arguments=*/{}); + EXPECT_FALSE(result.ok()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index bfb04fd9f9bf6887c4462cb00fee00250517f5c4..6b0c04c2c083bbfce267dd92d24ef15c06186d26 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -51,8 +51,6 @@ class DotOperationTest : public ClientLibraryTestBase { template void TestNonsquareMatrixDot(bool lhs_row_major = false, bool rhs_row_major = false); - void TestMatrixDot(int M, int K, int N, bool lhs_row_major = false, - bool rhs_row_major = false); }; XLA_TEST_F(DotOperationTest, ZeroElementVectorDotF32) { @@ -199,158 +197,182 @@ void DotOperationTest::TestSquareMatrixDot(bool lhs_row_major, &builder, expected, {lhs_handle.get(), rhs_handle.get()}, error_spec_); } -void DotOperationTest::TestMatrixDot(int M, int K, int N, bool lhs_row_major, - bool rhs_row_major) { - std::unique_ptr> lhs_data = - MakeLinspaceArray2D(0.0, 1.0, M, K); - std::unique_ptr lhs_lit = Literal::CreateR2FromArray2DWithLayout( - *lhs_data, - LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(lhs_row_major))); - auto lhs_handle = client_->TransferToServer(*lhs_lit).ConsumeValueOrDie(); +struct DotTestParam { + int m; + int k; + int n; + bool dot_lhs_row_major; + bool dot_rhs_row_major; + bool has_addend; + bool addend_row_major; +}; + +string PrintDotTestParam( + const ::testing::TestParamInfo& test_param) { + const DotTestParam& param = test_param.param; + if (param.has_addend) { + return tensorflow::strings::StrCat(param.m, "x", param.k, "x", param.n, + "_MajorToMinor", + param.dot_lhs_row_major ? 
"T" : "F", + param.dot_rhs_row_major ? "T" : "F", + param.addend_row_major ? "T" : "F"); + } else { + return tensorflow::strings::StrCat(param.m, "x", param.k, "x", param.n, + "_MajorToMinor", + param.dot_lhs_row_major ? "T" : "F", + param.dot_rhs_row_major ? "T" : "F"); + } +} - std::unique_ptr> rhs_data = - MakeLinspaceArray2D(0.0, 1.0, K, N); - std::unique_ptr rhs_lit = Literal::CreateR2FromArray2DWithLayout( - *rhs_data, - LayoutUtil::MakeLayout(MinorToMajorForIsRowMajor(rhs_row_major))); - auto rhs_handle = client_->TransferToServer(*rhs_lit).ConsumeValueOrDie(); +class ParametricDotTest : public DotOperationTest, + public ::testing::WithParamInterface {}; + +XLA_TEST_P(ParametricDotTest, TestF32) { + DotTestParam param = GetParam(); + + std::unique_ptr> dot_lhs_data = + MakeLinspaceArray2D(0.0, 1.0, param.m, param.k); + std::unique_ptr dot_lhs_lit = Literal::CreateR2FromArray2DWithLayout( + *dot_lhs_data, LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(param.dot_lhs_row_major))); + std::unique_ptr dot_lhs_handle = + client_->TransferToServer(*dot_lhs_lit).ConsumeValueOrDie(); + + std::unique_ptr> dot_rhs_data = + MakeLinspaceArray2D(0.0, 1.0, param.k, param.n); + std::unique_ptr dot_rhs_lit = Literal::CreateR2FromArray2DWithLayout( + *dot_rhs_data, LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(param.dot_rhs_row_major))); + std::unique_ptr dot_rhs_handle = + client_->TransferToServer(*dot_rhs_lit).ConsumeValueOrDie(); + + std::unique_ptr> addend_data; + std::unique_ptr addend_lit; + std::unique_ptr addend_handle; + + if (param.has_addend) { + addend_data = MakeLinspaceArray2D(0.0, 1.0, param.m, param.n); + addend_lit = Literal::CreateR2FromArray2DWithLayout( + *addend_data, LayoutUtil::MakeLayout( + MinorToMajorForIsRowMajor(param.addend_row_major))); + addend_handle = client_->TransferToServer(*addend_lit).ConsumeValueOrDie(); + } ComputationBuilder builder(client_, TestName()); auto prim_type = primitive_util::NativeToPrimitiveType(); auto result = 
builder.Dot( - builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {M, K}), "lhs"), - builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {K, N}), "rhs")); - - std::unique_ptr> expected = - ReferenceUtil::MatmulArray2D(*lhs_data, *rhs_data); - - ComputeAndCompareR2(&builder, *expected, - {lhs_handle.get(), rhs_handle.get()}, - ErrorSpec(0.3, 3e-3)); -} - -XLA_TEST_F(DotOperationTest, MatrixDotF32_12_117_7_MinorToMajorTF) { - TestMatrixDot(12, 117, 7, true, false); -} - -XLA_TEST_F(DotOperationTest, MatrixDotF32_12_117_7_MinorToMajorFT) { - TestMatrixDot(12, 117, 7, false, true); -} - -XLA_TEST_F(DotOperationTest, MatrixDotF32_12_117_7_MinorToMajorTT) { - TestMatrixDot(12, 117, 7, true, true); -} - -XLA_TEST_F(DotOperationTest, MatrixDotF32_12_117_7_MinorToMajorFF) { - TestMatrixDot(12, 117, 7, false, false); -} - -XLA_TEST_F(DotOperationTest, MatrixDotF32_270_270_520_MinorToMajorTT) { - TestMatrixDot(270, 270, 520, true, true); -} - -XLA_TEST_F(DotOperationTest, MatrixDotF32_270_270_520_MinorToMajorTF) { - TestMatrixDot(270, 270, 520, true, false); -} - -XLA_TEST_F(DotOperationTest, MatrixDotF32_270_270_520_MinorToMajorFT) { - TestMatrixDot(270, 270, 520, false, true); -} - -XLA_TEST_F(DotOperationTest, MatrixDotF32_270_270_520_MinorToMajorFF) { - TestMatrixDot(270, 270, 520, false, false); -} - -XLA_TEST_F(DotOperationTest, MatrixDotF32_260_3_520_MinorToMajorTT) { - TestMatrixDot(269, 3, 520, true, true); -} - -XLA_TEST_F(DotOperationTest, MatrixDotF32_260_3_520_MinorToMajorTF) { - TestMatrixDot(260, 3, 520, true, false); -} - -XLA_TEST_F(DotOperationTest, MatrixDotF32_260_3_520_MinorToMajorFT) { - TestMatrixDot(260, 3, 520, false, true); -} - -XLA_TEST_F(DotOperationTest, MatrixDotF32_260_3_520_MinorToMajorFF) { - TestMatrixDot(260, 3, 520, false, false); -} - -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x8x8) { - TestMatrixDot(1, 8, 8, true, true); -} + builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {param.m, param.k}), + "dot_lhs"), + 
builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {param.k, param.n}), + "dot_rhs")); + + if (param.has_addend) { + result = builder.Add( + result, + builder.Parameter( + 2, ShapeUtil::MakeShape(prim_type, {param.m, param.n}), "addend")); + } -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x130x8) { - TestMatrixDot(1, 130, 8, true, true); -} + std::unique_ptr> expected; + if (param.has_addend) { + expected = ReferenceUtil::ApplyElementwise2D( + std::plus(), + *ReferenceUtil::MatmulArray2D(*dot_lhs_data, *dot_rhs_data), + *addend_data); + } else { + expected = ReferenceUtil::MatmulArray2D(*dot_lhs_data, *dot_rhs_data); + } -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x8x130) { - TestMatrixDot(1, 8, 130, true, true); -} + std::vector args = {dot_lhs_handle.get(), dot_rhs_handle.get()}; + if (param.has_addend) { + args.push_back(addend_handle.get()); + } -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x290x130) { - TestMatrixDot(1, 290, 130, true, true); + ComputeAndCompareR2(&builder, *expected, args, ErrorSpec(0.3, 3e-3)); } -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_2x1x1) { - TestMatrixDot(2, 1, 1, true, true); -} +std::vector CreateDotTestParameters() { + std::vector params; -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_8x8x1) { - TestMatrixDot(8, 8, 1, true, true); -} + auto add_matrix_matrix_dot_test = [&](int m, int k, int n) { + for (bool lhs_row_major : {true, false}) { + for (bool rhs_row_major : {true, false}) { + params.push_back({/*m=*/m, /*k=*/k, /*n=*/n, + /*dot_lhs_row_major=*/lhs_row_major, + /*dot_rhs_row_major=*/rhs_row_major, + /*has_addend=*/false, /*addend_row_major=*/true}); + } + } + }; + + auto add_matrix_vector_dot_test = [&](int k, int n) { + for (bool has_addend : {false, true}) { + params.push_back({/*m=*/1, /*k=*/k, /*n=*/n, + /*dot_lhs_row_major=*/true, /*dot_rhs_row_major=*/true, + /*has_addend=*/has_addend, /*addend_row_major=*/true}); + if (n != 1) { + params.push_back( + {/*m=*/n, /*k=*/k, /*n=*/1, + 
/*dot_lhs_row_major=*/true, /*dot_rhs_row_major=*/true, + /*has_addend=*/has_addend, /*addend_row_major=*/true}); + } + } + }; -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_16x1x1) { - TestMatrixDot(16, 1, 1, true, true); -} + add_matrix_matrix_dot_test(/*m=*/12, /*k=*/117, /*n=*/7); + add_matrix_matrix_dot_test(/*m=*/270, /*k=*/270, /*n=*/520); + add_matrix_matrix_dot_test(/*m=*/260, /*k=*/3, /*n=*/520); -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_16x3x1) { - TestMatrixDot(16, 3, 1, true, true); -} + add_matrix_vector_dot_test(/*k=*/8, /*n=*/8); + add_matrix_vector_dot_test(/*k=*/130, /*n=*/8); + add_matrix_vector_dot_test(/*k=*/8, /*n=*/130); + add_matrix_vector_dot_test(/*k=*/290, /*n=*/130); + add_matrix_vector_dot_test(/*k=*/1, /*n=*/1); + add_matrix_vector_dot_test(/*k=*/1, /*n=*/16); + add_matrix_vector_dot_test(/*k=*/3, /*n=*/16); + add_matrix_vector_dot_test(/*k=*/3, /*n=*/3); + add_matrix_vector_dot_test(/*k=*/29, /*n=*/29); + add_matrix_vector_dot_test(/*k=*/8, /*n=*/2); + add_matrix_vector_dot_test(/*k=*/2, /*n=*/8); + add_matrix_vector_dot_test(/*k=*/259, /*n=*/258); -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_3x3x1) { - TestMatrixDot(3, 3, 1, true, true); + return params; } -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_29x29x1) { - TestMatrixDot(29, 29, 1, true, true); -} +INSTANTIATE_TEST_CASE_P(DotTests, ParametricDotTest, + ::testing::ValuesIn(CreateDotTestParameters()), + PrintDotTestParam); -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x8x2) { - TestMatrixDot(1, 8, 2, true, true); +XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFF) { + TestSquareMatrixDot(false, false); } -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_1x2x8) { - TestMatrixDot(1, 2, 8, true, true); +XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFT) { + TestSquareMatrixDot(false, true); } -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_259x258x1) { - TestMatrixDot(259, 258, 1, true, true); +XLA_TEST_F(DotOperationTest, 
SquareMatrixDotF32MinorToMajorTF) { + TestSquareMatrixDot(true, false); } -XLA_TEST_F(DotOperationTest, MatrixVectorDotF32_259x258x1_FT) { - TestMatrixDot(259, 258, 1, false, true); +XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorTT) { + TestSquareMatrixDot(true, true); } -XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFF) { - constexpr bool kLhsRowMajor = false; - constexpr bool kRhsRowMajor = false; - TestSquareMatrixDot(kLhsRowMajor, kRhsRowMajor); +XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorFF) { + TestSquareMatrixDot(false, false); } -XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorFT) { - TestSquareMatrixDot(false, true); +XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorFT) { + TestSquareMatrixDot(false, true); } -XLA_TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorTF) { - TestSquareMatrixDot(true, false); +XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorTF) { + TestSquareMatrixDot(true, false); } -TEST_F(DotOperationTest, SquareMatrixDotF32MinorToMajorTT) { - constexpr bool kLhsRowMajor = true; - constexpr bool kRhsRowMajor = true; - TestSquareMatrixDot(kLhsRowMajor, kRhsRowMajor); +XLA_TEST_F(DotOperationTest, SquareMatrixDotC64MinorToMajorTT) { + TestSquareMatrixDot(true, true); } XLA_TEST_F(DotOperationTest, SquareMatrixDotF64) { @@ -498,9 +520,39 @@ XLA_TEST_F(DotOperationTest, BatchMatMul) { ComputeAndCompareR4( &builder, - /*expected=*/{{{{1300, 2400}, {13, 24}}, {{11400, 13600}, {114, 136}}}, - {{{42900, 79200}, {429, 792}}, - {{250800, 299200}, {2508, 2992}}}}, + /*expected=*/ + {{{{1300, 2400}, {13, 24}}, {{11400, 13600}, {114, 136}}}, + {{{42900, 79200}, {429, 792}}, {{250800, 299200}, {2508, 2992}}}}, + {x_data.get(), y_data.get()}, error_spec_); +} + +XLA_TEST_F(DotOperationTest, GeneralMatMul) { + ComputationBuilder builder(client_, TestName()); + auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2, 2}), "x"); + auto y = builder.Parameter(1, 
ShapeUtil::MakeShape(F32, {2, 2, 2}), "y"); + + DotDimensionNumbers dnums; + dnums.add_lhs_contracting_dimensions(2); + dnums.add_rhs_contracting_dimensions(1); + dnums.add_lhs_batch_dimensions(0); + dnums.add_rhs_batch_dimensions(0); + + auto out = builder.DotGeneral(x, y, dnums); + + auto x_data = client_ + ->TransferToServer(*Literal::CreateR3( + {{{1.0, 2.0}, {3.0, 4.0}}, {{5.0, 6.0}, {7.0, 8.0}}})) + .ConsumeValueOrDie(); + + auto y_data = client_ + ->TransferToServer(*Literal::CreateR3( + {{{1.0, 0.0}, {0.0, 1.0}}, {{1.0, 0.0}, {0.0, 1.0}}})) + .ConsumeValueOrDie(); + + ComputeAndCompareR3( + &builder, + /*expected=*/ + {{{1.0, 2.0}, {3.0, 4.0}}, {{5.0, 6.0}, {7.0, 8.0}}}, {x_data.get(), y_data.get()}, error_spec_); } @@ -561,5 +613,95 @@ TEST_F(DotOperationTest, TransposeFolding) { } } +TEST_F(DotOperationTest, DotOfConcatOptimizationWithConstLHS) { + auto prim_type = primitive_util::NativeToPrimitiveType(); + + std::unique_ptr> constant_lhs_array(new Array2D( + {{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}, {6.0, 5.0, 4.0, 3.0, 2.0, 1.0}})); + + ComputationBuilder builder(client_, TestName()); + auto lhs_constant = builder.ConstantR2FromArray2D(*constant_lhs_array); + auto rhs_arg_0 = builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), + "rhs_arg_0"); + auto rhs_arg_1 = builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {3, 2}), + "rhs_arg_1"); + auto rhs_arg_2 = builder.Parameter(2, ShapeUtil::MakeShape(prim_type, {1, 2}), + "rhs_arg_2"); + auto result = builder.Dot( + lhs_constant, builder.ConcatInDim({rhs_arg_0, rhs_arg_1, rhs_arg_2}, 0)); + + std::unique_ptr> arg_0_value_array( + new Array2D({{1.0, 2.0}, {3.0, 4.0}})); + std::unique_ptr> arg_1_value_array( + new Array2D({{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}})); + std::unique_ptr> arg_2_value_array( + new Array2D({{1.0, 2.0}})); + + TF_ASSERT_OK_AND_ASSIGN( + auto arg_0_value, + client_->TransferToServer( + *Literal::CreateR2FromArray2D(*arg_0_value_array))); + TF_ASSERT_OK_AND_ASSIGN( + auto arg_1_value, + 
client_->TransferToServer( + *Literal::CreateR2FromArray2D(*arg_1_value_array))); + TF_ASSERT_OK_AND_ASSIGN( + auto arg_2_value, + client_->TransferToServer( + *Literal::CreateR2FromArray2D(*arg_2_value_array))); + + Array2D expected({{53.0, 74.0}, {45.0, 66.0}}); + ComputeAndCompareR2( + &builder, expected, + {arg_0_value.get(), arg_1_value.get(), arg_2_value.get()}, error_spec_); +} + +TEST_F(DotOperationTest, DotOfConcatOptimizationWithConstRHS) { + auto prim_type = primitive_util::NativeToPrimitiveType(); + + std::unique_ptr> constant_rhs_array( + new Array2D({{1.0, 2.0}, + {3.0, 4.0}, + {5.0, 6.0}, + {6.0, 5.0}, + {4.0, 3.0}, + {2.0, 1.0}})); + + ComputationBuilder builder(client_, TestName()); + auto rhs_constant = builder.ConstantR2FromArray2D(*constant_rhs_array); + auto lhs_arg_0 = builder.Parameter(0, ShapeUtil::MakeShape(prim_type, {2, 2}), + "lhs_arg_0"); + auto lhs_arg_1 = builder.Parameter(1, ShapeUtil::MakeShape(prim_type, {2, 3}), + "lhs_arg_1"); + auto lhs_arg_2 = builder.Parameter(2, ShapeUtil::MakeShape(prim_type, {2, 1}), + "lhs_arg_2"); + auto result = builder.Dot( + builder.ConcatInDim({lhs_arg_0, lhs_arg_1, lhs_arg_2}, 1), rhs_constant); + + std::unique_ptr> arg_0_value_array( + new Array2D({{1.0, 2.0}, {3.0, 4.0}})); + std::unique_ptr> arg_1_value_array( + new Array2D({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}})); + std::unique_ptr> arg_2_value_array( + new Array2D({{1.0}, {2.0}})); + + TF_ASSERT_OK_AND_ASSIGN( + auto arg_0_value, + client_->TransferToServer( + *Literal::CreateR2FromArray2D(*arg_0_value_array))); + TF_ASSERT_OK_AND_ASSIGN( + auto arg_1_value, + client_->TransferToServer( + *Literal::CreateR2FromArray2D(*arg_1_value_array))); + TF_ASSERT_OK_AND_ASSIGN( + auto arg_2_value, + client_->TransferToServer( + *Literal::CreateR2FromArray2D(*arg_2_value_array))); + + Array2D expected({{38.0, 36.0}, {93.0, 91.0}}); + ComputeAndCompareR2( + &builder, expected, + {arg_0_value.get(), arg_1_value.get(), arg_2_value.get()}, error_spec_); +} } // 
namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index 8baaf39e3cf8fa7f6fa4a0224c1297f82e0d92aa..877dc7db0eec229a7119b3627f177a33ed0d971b 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -51,12 +51,16 @@ class DynamicSliceTest : public ClientLibraryTestBase { RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {2}, {3}, {2, 3, 4}); // Slice at dimension boundaries. RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {5}, {3}, {5, 6, 7}); - // Slice at dimension boundaries, but with sizes that cause indices to wrap. - RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {6}, {4}, {6, 7, 0, 1}); // Zero element slice. RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {2}, {0}, {}); } + template + void TestR1Wrap() { + // Slice at dimension boundaries, but with sizes that cause indices to wrap. + RunR1({0, 1, 2, 3, 4, 5, 6, 7}, {6}, {4}, {6, 7, 0, 1}); + } + template void TestR2() { // Slice at dimension start. @@ -68,15 +72,19 @@ class DynamicSliceTest : public ClientLibraryTestBase { // Slice at dimension boundaries. RunR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {1, 1}, {2, 1}, {{5}, {8}}); - // Slice at dimension boundaries, but with sizes that cause indices to wrap. - RunR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {1, 1}, {3, 3}, - {{5, 6, 4}, {8, 9, 7}, {2, 3, 1}}); // Zero element slice: 2x0. RunR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {0, 0}, {2, 0}, {{}, {}}); // Zero element slice: 0x2. RunR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {0, 0}, {0, 2}, - Array2D(0, 2)); + Array2D(0, 2)); + } + + template + void TestR2Wrap() { + // Slice at dimension boundaries, but with sizes that cause indices to wrap. 
+ RunR2({{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}, {1, 1}, {3, 3}, + {{5, 6, 4}, {8, 9, 7}, {2, 3, 1}}); } template @@ -97,85 +105,119 @@ class DynamicSliceTest : public ClientLibraryTestBase { {{7, 8}, {9, 10}, {11, 12}}}, {0, 1, 1}, {2, 2, 1}, {{{4}, {6}}, {{10}, {12}}}); + // clang-format on + } + template + void TestR3Wrap() { // Slice at dimension boundaries, but with sizes that cause indices to wrap. RunR3( {{{1, 2}, {3, 4}, {5, 6}}, {{7, 8}, {9, 10}, {11, 12}}}, {0, 2, 1}, {2, 1, 2}, {{{6, 5}}, {{12, 11}}}); - - // clang-format on } template - void RunR1(tensorflow::gtl::ArraySlice input_values, + void RunR1(tensorflow::gtl::ArraySlice input_values_int, const std::vector slice_starts, const std::vector& slice_sizes, - tensorflow::gtl::ArraySlice expected_values) { + tensorflow::gtl::ArraySlice expected_values_int) { + // bfloat16 has explicit constructors, so it does not implicitly convert the + // way built-in types do, which is why we can't take the parameter as an + // ArraySlice. We also can't convert it to a vector, because + // vector is special so that it cannot be an ArraySlice, which + // is what the code below wants. So instead we do this. + Literal input_values = + std::move(*Literal::CreateR1(input_values_int) + ->Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); + Literal expected_values = + std::move(*Literal::CreateR1(expected_values_int) + ->Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); + ComputationBuilder builder(client_, TestName()); // Initialize and transfer dynamic slice start indices parameter. ComputationDataHandle starts; std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. - auto input = builder.ConstantR1(input_values); + auto input = builder.ConstantLiteral(input_values); builder.DynamicSlice(input, starts, slice_sizes); // Run computation and compare against expected values. 
- ComputeAndCompareR1(&builder, expected_values, {start_data.get()}); + ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); } template - void RunR2(const Array2D& input_values, + void RunR2(const Array2D& input_values_int, const std::vector slice_starts, const std::vector& slice_sizes, - const Array2D& expected_values) { + const Array2D& expected_values_int) { + Literal input_values = + std::move(*Literal::CreateR2FromArray2D(input_values_int) + ->Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); + Literal expected_values = + std::move(*Literal::CreateR2FromArray2D(expected_values_int) + ->Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); + ComputationBuilder builder(client_, TestName()); // Initialize and transfer dynamic slice start indices parameter. ComputationDataHandle starts; std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. - auto input = builder.ConstantR2FromArray2D(input_values); + auto input = builder.ConstantLiteral(input_values); builder.DynamicSlice(input, starts, slice_sizes); // Run computation and compare against expected values. 
- ComputeAndCompareR2(&builder, expected_values, {start_data.get()}); + ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); } template - void RunR3(const Array3D& input_values, + void RunR3(const Array3D& input_values_int, const std::vector slice_starts, const std::vector& slice_sizes, - const Array3D& expected_values) { + const Array3D& expected_values_int) { + Literal input_values = + std::move(*Literal::CreateR3FromArray3D(input_values_int) + ->Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); + Literal expected_values = + std::move(*Literal::CreateR3FromArray3D(expected_values_int) + ->Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); + ComputationBuilder builder(client_, TestName()); // Initialize and transfer dynamic slice start indices parameter. ComputationDataHandle starts; std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. - auto input = builder.ConstantR3FromArray3D(input_values); + auto input = builder.ConstantLiteral(input_values); builder.DynamicSlice(input, starts, slice_sizes); // Run computation and compare against expected values. 
- ComputeAndCompareR3(&builder, expected_values, {start_data.get()}); + ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); } }; +XLA_TEST_F(DynamicSliceTest, Int32R1BF16) { TestR1(); } XLA_TEST_F(DynamicSliceTest, Int32R1) { TestR1(); } - +XLA_TEST_F(DynamicSliceTest, Int32R1Wrap) { TestR1Wrap(); } XLA_TEST_F(DynamicSliceTest, Int64R1) { TestR1(); } - XLA_TEST_F(DynamicSliceTest, UInt64R1) { TestR1(); } -XLA_TEST_F(DynamicSliceTest, Int32R2) { TestR2(); } - +XLA_TEST_F(DynamicSliceTest, Int32R2BF16) { TestR2(); } +XLA_TEST_F(DynamicSliceTest, Int32R2) { TestR2(); } +XLA_TEST_F(DynamicSliceTest, Int32R2Wrap) { TestR2Wrap(); } XLA_TEST_F(DynamicSliceTest, Int64R2) { TestR2(); } - XLA_TEST_F(DynamicSliceTest, UInt64R2) { TestR2(); } -XLA_TEST_F(DynamicSliceTest, Int32R3) { TestR3(); } - +XLA_TEST_F(DynamicSliceTest, Int32R3BF16) { TestR3(); } +XLA_TEST_F(DynamicSliceTest, Int32R3) { TestR3(); } +XLA_TEST_F(DynamicSliceTest, Int32R3Wrap) { TestR3Wrap(); } XLA_TEST_F(DynamicSliceTest, Int64R3) { TestR3(); } - XLA_TEST_F(DynamicSliceTest, UInt64R3) { TestR3(); } XLA_TEST_F(DynamicSliceTest, Int32R1Pred) { @@ -213,7 +255,7 @@ XLA_TEST_F(DynamicSliceTest, Int32R2Pred) { // Zero element slice: 0x2. 
RunR2( {{true, false, true}, {false, false, true}, {true, true, false}}, {0, 0}, - {0, 2}, Array2D(0, 2)); + {0, 2}, Array2D(0, 2)); } XLA_TEST_F(DynamicSliceTest, Int32R3Pred) { @@ -300,107 +342,154 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { } template - void RunR1(tensorflow::gtl::ArraySlice input_values, - tensorflow::gtl::ArraySlice update_values, + void RunR1(tensorflow::gtl::ArraySlice input_values_int, + tensorflow::gtl::ArraySlice update_values_int, const std::vector slice_starts, - tensorflow::gtl::ArraySlice expected_values) { + tensorflow::gtl::ArraySlice expected_values_int) { + Literal input_values = + std::move(*Literal::CreateR1(input_values_int) + ->Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); + Literal update_values = + std::move(*Literal::CreateR1(update_values_int) + ->Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); + Literal expected_values = + std::move(*Literal::CreateR1(expected_values_int) + ->Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); + ComputationBuilder builder(client_, TestName()); // Initialize and transfer dynamic slice start indices parameter. ComputationDataHandle starts; std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. - auto input = builder.ConstantR1(input_values); - auto update = builder.ConstantR1(update_values); + auto input = builder.ConstantLiteral(input_values); + auto update = builder.ConstantLiteral(update_values); builder.DynamicUpdateSlice(input, update, starts); // Run computation and compare against expected values. 
- ComputeAndCompareR1(&builder, expected_values, {start_data.get()}); + ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); } template - void RunR2(const Array2D& input_values, - const Array2D& update_values, + void RunR2(const Array2D& input_values_int, + const Array2D& update_values_int, const std::vector slice_starts, - const Array2D& expected_values) { + const Array2D& expected_values_int) { + Literal input_values = + std::move(*Literal::CreateR2FromArray2D(input_values_int) + ->Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); + Literal update_values = + std::move(*Literal::CreateR2FromArray2D(update_values_int) + ->Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); + Literal expected_values = + std::move(*Literal::CreateR2FromArray2D(expected_values_int) + ->Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); + ComputationBuilder builder(client_, TestName()); // Initialize and transfer dynamic slice start indices parameter. ComputationDataHandle starts; std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. - auto input = builder.ConstantR2FromArray2D(input_values); - auto update = builder.ConstantR2FromArray2D(update_values); + auto input = builder.ConstantLiteral(input_values); + auto update = builder.ConstantLiteral(update_values); builder.DynamicUpdateSlice(input, update, starts); // Run computation and compare against expected values. 
- ComputeAndCompareR2(&builder, expected_values, {start_data.get()}); + ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); } template - void RunR3(const Array3D& input_values, - const Array3D& update_values, + void RunR3(const Array3D& input_values_int, + const Array3D& update_values_int, const std::vector slice_starts, - const Array3D& expected_values) { + const Array3D& expected_values_int) { + Literal input_values = + std::move(*Literal::CreateR3FromArray3D(input_values_int) + ->Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); + Literal update_values = + std::move(*Literal::CreateR3FromArray3D(update_values_int) + ->Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); + Literal expected_values = + std::move(*Literal::CreateR3FromArray3D(expected_values_int) + ->Convert(primitive_util::NativeToPrimitiveType()) + .ValueOrDie()); + ComputationBuilder builder(client_, TestName()); // Initialize and transfer dynamic slice start indices parameter. ComputationDataHandle starts; std::unique_ptr start_data = CreateR1Parameter( slice_starts, 0, "slice_starts", &builder, &starts); // Build dynamic slice computation. - auto input = builder.ConstantR3FromArray3D(input_values); - auto update = builder.ConstantR3FromArray3D(update_values); + auto input = builder.ConstantLiteral(input_values); + auto update = builder.ConstantLiteral(update_values); builder.DynamicUpdateSlice(input, update, starts); // Run computation and compare against expected values. - ComputeAndCompareR3(&builder, expected_values, {start_data.get()}); + ComputeAndCompareLiteral(&builder, expected_values, {start_data.get()}); } + template void RunR3Contiguous(std::vector operand_shape, int32 index, int32 size) { +#ifdef XLA_TEST_BACKEND_CPU_PARALLEL + // TODO(b/71820067): The CPU parallel backend failed for this on 2018-01-10. 
+ if (std::is_same::value) { + return; + } +#endif + const int32 kSeq = operand_shape[0]; const int32 kBatch = operand_shape[1]; const int32 kDim = operand_shape[2]; - Array3D input_values(kSeq, kBatch, kDim); - Array3D update_values(size, kBatch, kDim); - Array3D expected_values(kSeq, kBatch, kDim); + Array3D input_values(kSeq, kBatch, kDim); + Array3D update_values(size, kBatch, kDim); + Array3D expected_values(kSeq, kBatch, kDim); - input_values.FillIota(0); - float val = 1000; - update_values.FillIota(val); + input_values.FillIota(static_cast(0)); + T value = static_cast(10); + update_values.FillIota(static_cast(value)); // TODO(b/34128753) Expected values may vary depending on backend when // the update wraps. According to documentation, the results are technically // implementation specific where the update is out of bounds, and hence // we don't really know what to pass into ComputeAndCompareR3. - expected_values.FillIota(0); + expected_values.FillIota(static_cast(0)); for (int i = 0; i < size; i++) { for (int j = 0; j < kBatch; j++) { for (int k = 0; k < kDim; k++) { - expected_values((index + i) % kSeq, j, k) = val++; + expected_values((index + i) % kSeq, j, k) = value++; } } } if (VLOG_IS_ON(1)) { - DumpArray("input", input_values); - DumpArray("update", update_values); - DumpArray("expected", expected_values); + DumpArray("input", input_values); + DumpArray("update", update_values); + DumpArray("expected", expected_values); } // Build dynamic slice computation. ComputationBuilder builder(client_, TestName()); // Initialize and transfer input parameter. ComputationDataHandle input; - std::unique_ptr input_data = CreateR3Parameter( - input_values, 0, "input_values", &builder, &input); + std::unique_ptr input_data = + CreateR3Parameter(input_values, 0, "input_values", &builder, &input); // Initialize and transfer update parameter. 
ComputationDataHandle update; - std::unique_ptr update_data = CreateR3Parameter( + std::unique_ptr update_data = CreateR3Parameter( update_values, 1, "update_values", &builder, &update); auto starts = builder.ConstantR1({index, 0, 0}); builder.DynamicUpdateSlice(input, update, starts); // Run computation and compare against expected values. - ComputeAndCompareR3(&builder, expected_values, - {input_data.get(), update_data.get()}, - ErrorSpec(0.000001)); + ComputeAndCompareR3(&builder, expected_values, + {input_data.get(), update_data.get()}, + ErrorSpec(0.000001)); } template @@ -411,28 +500,35 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase { } }; +// TODO(b/71820067): The CPU parallel backend failed for this on 2018-01-10. +XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_CPU_PARALLEL(Int32R1BF16)) { + TestR1(); +} XLA_TEST_F(DynamicUpdateSliceTest, Int32R1) { TestR1(); } - XLA_TEST_F(DynamicUpdateSliceTest, Int64R1) { TestR1(); } - XLA_TEST_F(DynamicUpdateSliceTest, UInt64R1) { TestR1(); } +// TODO(b/71820067): The CPU parallel backend failed for this on 2018-01-10. +XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_CPU_PARALLEL(Int32R2BF16)) { + TestR2(); +} XLA_TEST_F(DynamicUpdateSliceTest, Int32R2) { TestR2(); } - XLA_TEST_F(DynamicUpdateSliceTest, Int64R2) { TestR2(); } - XLA_TEST_F(DynamicUpdateSliceTest, UInt64R2) { TestR2(); } +// TODO(b/71820067): The CPU parallel backend failed for this on 2018-01-10. 
+XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_CPU_PARALLEL(Int32R3BF16)) { + TestR3(); +} XLA_TEST_F(DynamicUpdateSliceTest, Int32R3) { TestR3(); } - XLA_TEST_F(DynamicUpdateSliceTest, Int64R3) { TestR3(); } - XLA_TEST_F(DynamicUpdateSliceTest, UInt64R3) { TestR3(); } +XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_CPU_PARALLEL(Int32WrapBF16)) { + TestWrap(); +} XLA_TEST_F(DynamicUpdateSliceTest, Int32Wrap) { TestWrap(); } - XLA_TEST_F(DynamicUpdateSliceTest, Int64Wrap) { TestWrap(); } - XLA_TEST_F(DynamicUpdateSliceTest, UInt64Wrap) { TestWrap(); } XLA_TEST_F(DynamicUpdateSliceTest, Int32R1Pred) { @@ -498,36 +594,70 @@ XLA_TEST_F(DynamicUpdateSliceTest, Int32R3Pred) { XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousSingleElement) { // Single element, no wrap. std::vector operand_shape({4, 5, 2}); - RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/1); + RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/1); +} + +XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousSingleElementBF16) { + // Single element, no wrap. + std::vector operand_shape({4, 5, 2}); + RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/1); } XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleElements) { // Multiple element, no wrap. std::vector operand_shape({4, 5, 2}); - RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/2); + RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/2); +} + +XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleElementsBF16) { + // Multiple element, no wrap. + std::vector operand_shape({4, 5, 2}); + RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/2); } XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleWrapping) { // Multiple element, wrapping. std::vector operand_shape({4, 5, 2}); - RunR3Contiguous(operand_shape, /*index=*/3, /*size=*/2); + RunR3Contiguous(operand_shape, /*index=*/3, /*size=*/2); +} + +XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousMultipleWrappingBF16) { + // Multiple element, wrapping. 
+ std::vector operand_shape({4, 5, 2}); + RunR3Contiguous(operand_shape, /*index=*/3, /*size=*/2); } XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousTooLarge) { // Multiple element, update size larger than operand. std::vector operand_shape({4, 5, 2}); - RunR3Contiguous(operand_shape, /*index=*/5, /*size=*/2); + RunR3Contiguous(operand_shape, /*index=*/5, /*size=*/2); +} + +XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousTooLargeBF16) { + // Multiple element, update size larger than operand. + std::vector operand_shape({4, 5, 2}); + RunR3Contiguous(operand_shape, /*index=*/5, /*size=*/2); } XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousUnaligned) { std::vector operand_shape({3, 123, 247}); - RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/1); + RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/1); +} + +XLA_TEST_F(DynamicUpdateSliceTest, R3ContiguousUnalignedBF16) { + std::vector operand_shape({3, 123, 247}); + RunR3Contiguous(operand_shape, /*index=*/1, /*size=*/1); } // TODO(b/34134076) Disabled on GPU 2016-01-06 due to out-of-memory error. XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_GPU(R3ContiguousLarger)) { std::vector operand_shape({32, 128, 1024}); - RunR3Contiguous(operand_shape, /*index=*/7, /*size=*/1); + RunR3Contiguous(operand_shape, /*index=*/7, /*size=*/1); +} + +XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_GPU(R3ContiguousLargerBF16)) { + std::vector operand_shape({32, 128, 1024}); + RunR3Contiguous(operand_shape, /*index=*/7, /*size=*/1); } void BM_DynamicSlice(int num_iters) { @@ -559,20 +689,20 @@ void BM_DynamicSlice(int num_iters) { auto computation = builder.Build().ConsumeValueOrDie(); // Initialize and transfer parameter buffer. 
- auto shape_size_fn = [client](const Shape& shape) { - return client->backend().transfer_manager()->GetByteSizeRequirement(shape); - }; - auto buffer = ScopedShapedBuffer::Allocate(start_indices_shape, &allocator, 0, - shape_size_fn) + auto buffer = client->backend() + .transfer_manager() + ->AllocateScopedShapedBuffer( + start_indices_shape, &allocator, /*device_ordinal=*/0) .ConsumeValueOrDie(); auto start_indices_literal = Literal::CreateR1({0, 1, 2, 3}); ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( - executors[device_ordinal], *start_indices_literal, - buffer->mutable_buffer({}))); + executors[device_ordinal], *start_indices_literal, *buffer)); std::unique_ptr executable = - client->Compile(computation, {&buffer->shape()}, ExecutableBuildOptions()) + client + ->Compile(computation, {&buffer->on_host_shape()}, + ExecutableBuildOptions()) .ConsumeValueOrDie(); // Run some warm-up executions. diff --git a/tensorflow/compiler/xla/tests/execution_profile_test.cc b/tensorflow/compiler/xla/tests/execution_profile_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..644cbbf40f296eb2a574ae568b4f32aa3d0bd12f --- /dev/null +++ b/tensorflow/compiler/xla/tests/execution_profile_test.cc @@ -0,0 +1,71 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/global_data.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class ExecutionProfileTest : public ClientLibraryTestBase {}; + +XLA_TEST_F(ExecutionProfileTest, + DISABLED_ON_CPU_PARALLEL(ExecuteWithExecutionProfile)) { + Shape shape = ShapeUtil::MakeShape(F32, {256, 256}); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr input, + client_->TransferToServer( + *Literal::CreateR2F32Linspace(1e0, 1e5, 256, 256))); + + ComputationBuilder b(client_, TestName() + ".add"); + b.Dot(b.Parameter(0, shape, "param_0"), b.Parameter(1, shape, "param_1")); + TF_ASSERT_OK_AND_ASSIGN(Computation dot_product, b.Build()); + + ExecutionProfile execution_profile; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr data, + client_->Execute(dot_product, {input.get(), input.get()}, + &execution_options_, &execution_profile)); + + VLOG(3) << "execution_profile.compute_cycle_count() = " + << execution_profile.compute_cycle_count(); + VLOG(3) << "execution_profile.compute_and_transfer_time_ns() = " + << execution_profile.compute_and_transfer_time_ns(); + VLOG(3) << "execution_profile.compute_time_ns() = " + << execution_profile.compute_time_ns(); + + bool hlo_profiling_enabled = + execution_options_.debug_options().xla_hlo_profile(); + + // If HLO profiling is enabled we always expect cycle count to be populated. + // If HLO profiling is disabled then depending on the backend the cycle count + // may or may not be populated. 
+ if (hlo_profiling_enabled) { + EXPECT_GT(execution_profile.compute_cycle_count(), 0); + } + + EXPECT_GT(execution_profile.compute_and_transfer_time_ns(), 0); + EXPECT_GT(execution_profile.compute_time_ns(), 0); + + TF_ASSERT_OK_AND_ASSIGN(auto computed, client_->Transfer(*data, &shape)); + (void)computed; +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..6fe7737de7af349dca2931b52d62dbc03b14e0b3 --- /dev/null +++ b/tensorflow/compiler/xla/tests/exhaustive_f32_elementwise_op_test.cc @@ -0,0 +1,128 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/core/lib/core/casts.h" + +namespace xla { +namespace { +class ExhaustiveF32ElementwiseOpTest + : public ClientLibraryTestBase, + public ::testing::WithParamInterface> { + protected: + ErrorSpec error_spec_{0.0001, 0.0001, /*relaxed_nans=*/true}; + + template + void ExhaustivelyTestF32Op(EnqueueOpTy enqueue_op, + float (*evaluate_op)(float), + std::pair known_incorrect_range) { + int64 begin, end; + std::tie(begin, end) = GetParam(); + int64 input_size = end - begin; + LOG(INFO) << "Checking range [" << begin << ", " << end << ")"; + + ComputationBuilder builder(client_, TestName()); + + std::unique_ptr input_literal = + Literal::CreateFromDimensions(F32, {input_size}); + for (int64 i = begin; i < end; i++) { + if (i >= known_incorrect_range.first && + i < known_incorrect_range.second) { + // If the operation is known to be buggy on a specific input clamp that + // input to 0 under the assumption that the op is at least correct on 0. 
+ input_literal->Set({i - begin}, 0.0f); + } else { + input_literal->Set({i - begin}, tensorflow::bit_cast(i)); + } + } + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, + client_->TransferToServer(*input_literal)); + + auto input = builder.Parameter(0, input_literal->shape(), "input"); + enqueue_op(&builder, input); + + std::vector expected_result; + expected_result.reserve(input_size); + for (int64 i = 0; i < input_size; i++) { + expected_result.push_back(evaluate_op(input_literal->Get({i}))); + } + + ComputeAndCompareR1(&builder, expected_result, {input_data.get()}, + error_spec_); + } +}; + +XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, LogF32) { +#ifdef XLA_TEST_BACKEND_CPU + // TODO(b/73141998): The vectorized Log implementation gives results outside + // our error spec in this range (these numbers are bitwise representations of + // floats expressed as a zero extended int64): + std::pair known_incorrect_range = {1, 8315654}; +#else + std::pair known_incorrect_range = {0, 0}; +#endif + + ExhaustivelyTestF32Op( + [](ComputationBuilder* builder, const ComputationDataHandle& input) { + builder->Log(input); + }, + std::log, known_incorrect_range); +} + +XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, ExpF32) { +#ifdef XLA_TEST_BACKEND_CPU + // TODO(b/73142289): The vectorized Exp implementation gives results outside + // our error spec in this range (these numbers are bitwise representations of + // floats expressed as a zero extended int64): + std::pair known_incorrect_range = {1107296256 + 11583654, + 1107296256 + 11629080}; +#else + std::pair known_incorrect_range = {0, 0}; +#endif + + ExhaustivelyTestF32Op( + [](ComputationBuilder* builder, const ComputationDataHandle& input) { + builder->Exp(input); + }, + std::exp, known_incorrect_range); +} + +XLA_TEST_P(ExhaustiveF32ElementwiseOpTest, TanhF32) { + ExhaustivelyTestF32Op( + [](ComputationBuilder* builder, const ComputationDataHandle& input) { + builder->Tanh(input); + }, + std::tanh, 
/*known_incorrect_range=*/{0, 0}); +} + +std::vector> CreateExhaustiveParameters() { + // We break up the 2^32-element space into small'ish chunks to keep peak + // memory usage low. + std::vector> result; + const int64 step = 1 << 25; + for (int64 i = 0; i < (1l << 32); i += step) { + result.push_back({i, i + step}); + } + return result; +} + +INSTANTIATE_TEST_CASE_P(ExhaustiveF32ElementwiseOpTestInstance, + ExhaustiveF32ElementwiseOpTest, + ::testing::ValuesIn(CreateExhaustiveParameters())); +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/filecheck.h b/tensorflow/compiler/xla/tests/filecheck.h index 493ff7414bde31b18a39a5098925d9c991529b00..3830d5a44d2ca483fbe839231b0136d13033b48b 100644 --- a/tensorflow/compiler/xla/tests/filecheck.h +++ b/tensorflow/compiler/xla/tests/filecheck.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_TESTS_FILECHECK_H_ -#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_TESTS_FILECHECK_H_ +#ifndef TENSORFLOW_COMPILER_XLA_TESTS_FILECHECK_H_ +#define TENSORFLOW_COMPILER_XLA_TESTS_FILECHECK_H_ #include @@ -30,4 +30,4 @@ StatusOr RunFileCheck(const string& input, const string& pattern); } // namespace xla -#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_TESTS_FILECHECK_H_ +#endif // TENSORFLOW_COMPILER_XLA_TESTS_FILECHECK_H_ diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index 2686afccc216095345dbb7b43e916fbbe7c8ea39..a292eab1d198fbf69c6dc81c780487ea46756f72 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -816,7 +816,8 @@ void BM_ParallelFusion(int num_iters) { std::unique_ptr executable = client ->Compile(computation, - {&buffer0->shape(), &buffer1->shape(), &buffer2->shape()}, + 
{&buffer0->on_host_shape(), &buffer1->on_host_shape(), + &buffer2->on_host_shape()}, ExecutableBuildOptions()) .ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/half_test.cc b/tensorflow/compiler/xla/tests/half_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ec2f49d43bd8cee84c6b0abe1892e8b2278eefeb --- /dev/null +++ b/tensorflow/compiler/xla/tests/half_test.cc @@ -0,0 +1,257 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/client/computation.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" + +// Tests the handling of the basic mathematics operations with F16 operands. + +namespace xla { +namespace { + +class HalfTestBase : public ClientLibraryTestBase { + protected: + const ErrorSpec error_spec_{0.001, 0.001}; + // Number of elements in the input buffers. 
+ static const int kNumElements = 4; +}; + +using UnaryBuildFuncTy = + std::function; + +struct UnaryOpTestParam { + std::function compute_func; + UnaryBuildFuncTy build_func; +}; + +class UnaryOpTest : public HalfTestBase, + public ::testing::WithParamInterface {}; + +XLA_TEST_P(UnaryOpTest, Ops) { + std::vector x({half(1.4), half(-2.3), half(3.2), half(-4.1)}); + ComputationBuilder builder(client_, TestName()); + ComputationDataHandle x_opnd; + auto x_data = CreateR1Parameter(x, /*parameter_number=*/0, "x", + &builder, &x_opnd); + + std::function compute_func = GetParam().compute_func; + std::vector expected; + for (int64 i = 0; i < x.size(); ++i) { + expected.push_back(compute_func(x[i])); + } + + UnaryBuildFuncTy build_func = GetParam().build_func; + build_func(&builder, x_opnd); + + ComputeAndCompareR1(&builder, expected, {x_data.get()}, error_spec_); +} + +half sign_imp(half value) { + const float x(std::move(value)); + return half((x < .0) ? -1 : (x > .0)); +} + +half round_imp(half value) { + return half(round(static_cast(std::move(value)))); +} + +INSTANTIATE_TEST_CASE_P( + half, UnaryOpTest, + ::testing::Values(UnaryOpTestParam{[](half x) { return abs(x); }, + &ComputationBuilder::Abs}, + UnaryOpTestParam{[](half x) { return round_imp(x); }, + &ComputationBuilder::Round}, + UnaryOpTestParam{[](half x) { return ceil(x); }, + &ComputationBuilder::Ceil}, + UnaryOpTestParam{[](half x) { return cos(x); }, + &ComputationBuilder::Cos}, + UnaryOpTestParam{[](half x) { return exp(x); }, + &ComputationBuilder::Exp}, + UnaryOpTestParam{[](half x) { return floor(x); }, + &ComputationBuilder::Floor}, + UnaryOpTestParam{[](half x) { return log(x); }, + &ComputationBuilder::Log}, + UnaryOpTestParam{[](half x) { return -x; }, + &ComputationBuilder::Neg}, + UnaryOpTestParam{[](half x) { return sign_imp(x); }, + &ComputationBuilder::Sign}, + UnaryOpTestParam{[](half x) { return sin(x); }, + &ComputationBuilder::Sin}, + UnaryOpTestParam{[](half x) { return tanh(x); }, + 
&ComputationBuilder::Tanh} + + )); + +struct UnaryPredTestParam { + std::function compute_func; + UnaryBuildFuncTy build_func; +}; + +class UnaryPredTest : public HalfTestBase, + public ::testing::WithParamInterface { +}; + +XLA_TEST_P(UnaryPredTest, Ops) { + std::vector x({half(1.4), half(-2.3), half(3.2), half(-4.1)}); + ComputationBuilder builder(client_, TestName()); + ComputationDataHandle x_opnd; + auto x_data = CreateR1Parameter(x, /*parameter_number=*/0, "x", + &builder, &x_opnd); + + std::function compute_func = GetParam().compute_func; + CHECK_EQ(kNumElements, x.size()); + bool expected[kNumElements]; + for (int64 i = 0; i < x.size(); ++i) { + expected[i] = compute_func(x[i]); + } + + UnaryBuildFuncTy build_func = GetParam().build_func; + build_func(&builder, x_opnd); + + ComputeAndCompareR1(&builder, expected, {x_data.get()}); +} + +INSTANTIATE_TEST_CASE_P(half, UnaryPredTest, + ::testing::Values(UnaryPredTestParam{ + [](half x) { return isfinite(x); }, + &ComputationBuilder::IsFinite})); + +using BinaryBuildFuncTy = std::function)>; + +struct BinaryOpTestParam { + std::function compute_func; + BinaryBuildFuncTy build_func; +}; + +class BinaryOpTest : public HalfTestBase, + public ::testing::WithParamInterface {}; + +XLA_TEST_P(BinaryOpTest, Ops) { + std::vector x({half(1.0), half(2.0), half(3.0), half(-4.0)}); + std::vector y({half(0.4), half(-0.3), half(0.2), half(0.1)}); + ComputationBuilder builder(client_, TestName()); + ComputationDataHandle x_opnd; + auto x_data = CreateR1Parameter(x, /*parameter_number=*/0, "x", + &builder, &x_opnd); + + ComputationDataHandle y_opnd; + auto y_data = CreateR1Parameter(y, /*parameter_number=*/1, "y", + &builder, &y_opnd); + + std::function compute_func = GetParam().compute_func; + std::vector expected; + for (int64 i = 0; i < x.size(); ++i) { + expected.push_back(compute_func(x[i], y[i])); + } + + BinaryBuildFuncTy build_func = GetParam().build_func; + build_func(&builder, x_opnd, y_opnd, {}); + + 
ComputeAndCompareR1(&builder, expected, {x_data.get(), y_data.get()}, + error_spec_); +} + +half atan2_imp(half x, half y) { + return half(atan2(static_cast(std::move(x)), + static_cast(std::move(y)))); +} + +INSTANTIATE_TEST_CASE_P( + half, BinaryOpTest, + ::testing::Values( + BinaryOpTestParam{[](half x, half y) { return x + y; }, + &ComputationBuilder::Add}, + BinaryOpTestParam{[](half x, half y) { return atan2_imp(x, y); }, + &ComputationBuilder::Atan2}, + BinaryOpTestParam{[](half x, half y) { return x / y; }, + &ComputationBuilder::Div}, + BinaryOpTestParam{[](half x, half y) { return max(x, y); }, + &ComputationBuilder::Max}, + BinaryOpTestParam{[](half x, half y) { return min(x, y); }, + &ComputationBuilder::Min}, + BinaryOpTestParam{[](half x, half y) { return x * y; }, + &ComputationBuilder::Mul}, + BinaryOpTestParam{[](half x, half y) { return pow(x, y); }, + &ComputationBuilder::Pow}, + BinaryOpTestParam{[](half x, half y) { return x - y; }, + &ComputationBuilder::Sub} + + )); + +struct BinaryPredTestParam { + std::function compute_func; + BinaryBuildFuncTy build_func; +}; + +class BinaryPredTest + : public HalfTestBase, + public ::testing::WithParamInterface {}; + +XLA_TEST_P(BinaryPredTest, Ops) { + std::vector x({half(1.0), half(2.0), half(0.2), half(-4.0)}); + std::vector y({half(0.4), half(-0.3), half(0.2), half(0.1)}); + ComputationBuilder builder(client_, TestName()); + ComputationDataHandle x_opnd; + auto x_data = CreateR1Parameter(x, /*parameter_number=*/0, "x", + &builder, &x_opnd); + + ComputationDataHandle y_opnd; + auto y_data = CreateR1Parameter(y, /*parameter_number=*/1, "y", + &builder, &y_opnd); + + std::function compute_func = GetParam().compute_func; + CHECK_EQ(kNumElements, x.size()); + bool expected[kNumElements]; + for (int64 i = 0; i < x.size(); ++i) { + expected[i] = compute_func(x[i], y[i]); + } + + BinaryBuildFuncTy build_func = GetParam().build_func; + build_func(&builder, x_opnd, y_opnd, {}); + + ComputeAndCompareR1(&builder, 
expected, {x_data.get(), y_data.get()}); +} + +INSTANTIATE_TEST_CASE_P( + half, BinaryPredTest, + ::testing::Values(BinaryPredTestParam{[](half x, half y) { return x == y; }, + &ComputationBuilder::Eq}, + BinaryPredTestParam{[](half x, half y) { return x != y; }, + &ComputationBuilder::Ne}, + BinaryPredTestParam{[](half x, half y) { return x >= y; }, + &ComputationBuilder::Ge}, + BinaryPredTestParam{[](half x, half y) { return x > y; }, + &ComputationBuilder::Gt}, + BinaryPredTestParam{[](half x, half y) { return x <= y; }, + &ComputationBuilder::Le}, + BinaryPredTestParam{[](half x, half y) { return x < y; }, + &ComputationBuilder::Lt} + + )); + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index d73c05ff92578209143e0679558848160cae99bd..9f5806c5e16c30cf198027cffab5f78c315cb957 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -15,13 +15,22 @@ limitations under the License. 
#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include #include #include #include +#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -30,44 +39,235 @@ namespace se = ::perftools::gputools; namespace xla { +namespace { + +using tensorflow::StringPiece; +using tensorflow::gtl::ArraySlice; +using tensorflow::gtl::optional; + +constexpr char kInterpreter[] = "interpreter"; + +// Helper functions to get test and reference platforms. 
+se::Platform* GetReferencePlatform() { + auto result = PlatformUtil::GetPlatform(kInterpreter); + TF_CHECK_OK(result.status()) << "could not get interpreter platform"; + return result.ValueOrDie(); +} + +se::Platform* GetTestPlatform() { + auto result = PlatformUtil::GetDefaultPlatform(); + TF_CHECK_OK(result.status()) << "could not get test platform"; + return result.ValueOrDie(); +} + +bool ProgramShapesEqual(const ProgramShape& lhs, const ProgramShape& rhs) { + if (lhs.parameters_size() != rhs.parameters_size()) { + return false; + } + for (int i = 0; i < lhs.parameters_size(); i++) { + if (!ShapeUtil::Equal(lhs.parameters(i), rhs.parameters(i))) { + return false; + } + } + return ShapeUtil::Equal(lhs.result(), rhs.result()); +} + +ProgramShape GetProgramShapeWithLayout(const HloModule& module) { + ProgramShape program_shape; + const auto* entry = module.entry_computation(); + for (const auto* param : entry->parameter_instructions()) { + *program_shape.add_parameters() = param->shape(); + *program_shape.add_parameter_names() = param->name(); + } + *program_shape.mutable_result() = entry->root_instruction()->shape(); + return program_shape; +} + +} // namespace + +HloTestBase::HloTestBase() + : HloTestBase(GetTestPlatform(), GetReferencePlatform()) {} + +HloTestBase::HloTestBase(se::Platform* test_platform, + se::Platform* reference_platform) + : test_runner_(test_platform), reference_runner_(reference_platform) { + hlo_verifier_ = MakeUnique(); +} + /* static */ std::unique_ptr HloTestBase::CreateNewModule() { HloModuleConfig config; + config.set_debug_options(GetDebugOptionsForTest()); + return MakeUnique(TestName(), VersionedComputationHandle(), + config); +} +/*static*/ DebugOptions HloTestBase::GetDebugOptionsForTest() { auto debug_options = legacy_flags::GetDebugOptionsFromFlags(); // TODO(b/38354253): Change tests to use Parameters instead of Constants. 
debug_options.add_xla_disable_hlo_passes("constant_folding"); + return debug_options; +} - config.set_debug_options(debug_options); - - return MakeUnique(TestName(), VersionedComputationHandle(), - config); +StatusOr> HloTestBase::Execute( + std::unique_ptr module, + tensorflow::gtl::ArraySlice arguments) { + return test_runner_.Execute(std::move(module), arguments); } -StatusOr HloTestBase::Execute( +std::unique_ptr HloTestBase::ExecuteAndTransfer( std::unique_ptr module, - tensorflow::gtl::ArraySlice - arguments, - Shape* result_shape) { - return runner_.Execute(std::move(module), arguments, result_shape); + tensorflow::gtl::ArraySlice arguments) { + return test_runner_.Execute(std::move(module), arguments).ValueOrDie(); } -se::DeviceMemoryBase HloTestBase::TransferToDevice(const Literal& literal) { - return runner_.TransferToDevice(literal).ValueOrDie(); +StatusOr> HloTestBase::MakeReferenceModule( + const HloModule& test_module, + const std::function& reference_preprocessor) { + std::unique_ptr reference_module = test_module.Clone(); + const auto& program_shape = GetProgramShapeWithLayout(test_module); + + if (reference_preprocessor != nullptr) { + reference_preprocessor(reference_module.get()); + if (!ProgramShapesEqual(program_shape, + GetProgramShapeWithLayout(*reference_module))) { + return InvalidArgument( + "reference preprocessor must not modify the program shape"); + } + } + TF_RETURN_IF_ERROR(VerifyHloModule(*reference_runner_.backend().platform(), + reference_module.get())); + return std::move(reference_module); } -std::unique_ptr HloTestBase::TransferFromDevice( - const Shape& shape, se::DeviceMemoryBase device_base) { - return runner_.TransferFromDevice(shape, device_base).ValueOrDie(); +template +StatusOr<::testing::AssertionResult> HloTestBase::RunAndCompareInternal( + std::unique_ptr module, const ArraySlice arguments, + const optional& error, bool run_hlo_passes, + const std::function& reference_preprocessor) { + static_assert( + 
std::is_same::value || + std::is_same, LiteralPtr>::value, + "The LiteralPtr type only accepts Literal* or std::unique_ptr."); + TF_RETURN_IF_ERROR( + VerifyHloModule(*test_runner_.backend().platform(), module.get())); + TF_ASSIGN_OR_RETURN(auto reference_module, + MakeReferenceModule(*module, reference_preprocessor)); + + // Execute on two backends. + TF_ASSIGN_OR_RETURN( + auto test, + test_runner_.Execute(std::move(module), arguments, run_hlo_passes)); + TF_ASSIGN_OR_RETURN(auto reference, + reference_runner_.Execute(std::move(reference_module), + arguments, run_hlo_passes)); + return LiteralTestUtil::NearOrEqual(/*expected=*/*reference, /*actual=*/*test, + error); } -std::unique_ptr HloTestBase::ExecuteAndTransfer( - std::unique_ptr module, - tensorflow::gtl::ArraySlice arguments) { - return runner_.ExecuteAndTransfer(std::move(module), arguments).ValueOrDie(); +template +::testing::AssertionResult HloTestBase::RunAndCompare( + std::unique_ptr module, const ArraySlice arguments, + const optional& error, + const std::function& reference_preprocessor) { + auto result = + RunAndCompareInternal(std::move(module), arguments, error, + /*run_hlo_passes=*/true, reference_preprocessor); + if (!result.ok()) { + return ::testing::AssertionFailure() << result.status(); + } + return result.ValueOrDie(); +} + +template +::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses( + std::unique_ptr module, const ArraySlice arguments, + const optional& error, + const std::function& reference_preprocessor) { + auto result = + RunAndCompareInternal(std::move(module), arguments, error, + /*run_hlo_passes=*/false, reference_preprocessor); + if (!result.ok()) { + return ::testing::AssertionFailure() << result.status(); + } + return result.ValueOrDie(); +} + +::testing::AssertionResult HloTestBase::RunAndCompare( + std::unique_ptr module, const optional& error, + const std::function& reference_preprocessor) { + const auto& fake_arguments = + 
MakeFakeArguments(module.get()).ConsumeValueOrDie(); + return RunAndCompare>( + std::move(module), fake_arguments, error, reference_preprocessor); +} + +::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses( + std::unique_ptr module, const optional& error, + const std::function& reference_preprocessor) { + const auto& fake_arguments = + MakeFakeArguments(module.get()).ConsumeValueOrDie(); + return RunAndCompareNoHloPasses>( + std::move(module), fake_arguments, error, reference_preprocessor); +} + +::testing::AssertionResult HloTestBase::RunAndCompare( + const StringPiece hlo_string, + const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor) { + auto module_or_status = + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); + if (!module_or_status.ok()) { + return ::testing::AssertionFailure() + << "Error while parsing HLO text format: " + << module_or_status.status().ToString(); + } + return RunAndCompare(module_or_status.ConsumeValueOrDie(), error, + reference_preprocessor); +} + +::testing::AssertionResult HloTestBase::RunAndCompareFromFile( + const string& filename, const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor) { + auto module_or_status = + HloRunner::ReadModuleFromHloTextFile(filename, GetDebugOptionsForTest()); + if (!module_or_status.ok()) { + return ::testing::AssertionFailure() + << "failed reading hlo module from file"; + } + return RunAndCompare(module_or_status.ConsumeValueOrDie(), error, + reference_preprocessor); +} + +::testing::AssertionResult HloTestBase::RunAndCompareNoHloPasses( + const StringPiece hlo_string, + const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor) { + auto module_or_status = + HloRunner::CreateModuleFromString(hlo_string, GetDebugOptionsForTest()); + if (!module_or_status.ok()) { + return ::testing::AssertionFailure() + << "Error while parsing HLO text format: " + << 
module_or_status.status().ToString(); + } + return RunAndCompareNoHloPasses(module_or_status.ConsumeValueOrDie(), error, + reference_preprocessor); +} + +::testing::AssertionResult HloTestBase::RunAndCompareNoHloPassesFromFile( + const string& filename, const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor) { + auto module_or_status = + HloRunner::ReadModuleFromHloTextFile(filename, GetDebugOptionsForTest()); + if (!module_or_status.ok()) { + return ::testing::AssertionFailure() + << "failed reading hlo module from file"; + } + return RunAndCompareNoHloPasses(module_or_status.ConsumeValueOrDie(), error, + reference_preprocessor); } -Backend& HloTestBase::backend() { return runner_.backend(); } +Backend& HloTestBase::backend() { return test_runner_.backend(); } /* static */ string HloTestBase::TestName() { diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 7f068dce36be3546298de2f06bf6d33446d07ca2..4aea9fc9fd027231106e529eb16bcd43f23fbe1c 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -24,52 +24,150 @@ limitations under the License. 
#include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_runner.h" +#include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/test.h" namespace xla { -// A base class for tests which build and run HLO code. This is a lower level of -// abstraction than using the client interface and enables, for one, explicitly -// building a graph of HLO instructions to run. +// A base class for tests which build and/or run HLO code. The class includes +// support for running an HLO module on two platforms and comparing the results. +// This is a lower level of abstraction than using the client interface and +// enables, for one, explicitly building a graph of HLO instructions to run. +// +// This can also be used to write text/file-based test cases. Note that the test +// target is responsible for linking the needed backends. A convenient way to do +// this is to make it an xla_test: it will generate test targets linking with +// the respective backends, which will be used as the test backend; the +// interpreter backend is already linked with hlo_test_base so it will be the +// default reference backend. For example, if you want to compare both cpu vs. +// interpreter, and gpu vs. 
interpreter, you can: +// +// xla_test ( +// name = "sample_text_test", +// srcs = ["sample_text_test.cc"], +// backends = [ +// "cpu", +// "gpu", +// ], +// deps = [ +// "//third_party/tensorflow/compiler/xla/tests:hlo_test_base", +// ... +// ], +// ) +// +// For a more detailed example, see "../tests/sample_text_test.cc". class HloTestBase : public ::testing::Test { protected: - HloTestBase() {} + // This uses the interpreter backend as the reference backend and + // automatically finds another supported backend as the test backend. If the + // interpreter is the only supported backend, it will be both the test backend + // and the reference backend. + HloTestBase(); + + // If your test doesn't use interpreter as the reference backend, you can use + // this constructor. Note that your test target is responsible for linking in + // both needed backends. + HloTestBase(::perftools::gputools::Platform* test_platform, + ::perftools::gputools::Platform* reference_platform); ~HloTestBase() override {} // Creates a new HLO module for a test. The module created will have // TestName() for its name; it will also automatically populate its debug - // options from command-line flags. It's recommended to use this method to - // create all HloModules for tests. + // options from command-line flags. If you want a fresh HloModule object and + // then add HloComputations to it, it's recommended to use this method in your + // tests. static std::unique_ptr CreateNewModule(); - // Executes the given module and returns a global data handle. - StatusOr Execute( + // Populates debug options from command-line flags and adjusts the options for + // testing. It is recommended to use this when you need to pass in + // DebugOptions, e.g. when creating a module from a string or a file. + static DebugOptions GetDebugOptionsForTest(); + + // Executes the given module and returns the result as a Literal. 
+ StatusOr> Execute( std::unique_ptr module, - tensorflow::gtl::ArraySlice - arguments, - Shape* result_shape); + tensorflow::gtl::ArraySlice arguments); - // Transfers the given literal to the device and returns the data handle. - perftools::gputools::DeviceMemoryBase TransferToDevice( - const Literal& literal); + std::unique_ptr ExecuteAndTransfer( + std::unique_ptr module, + tensorflow::gtl::ArraySlice arguments); + + // Executes the given hlo module on two backends and compares results. + // + // 'arguments': the input of the hlo module. The LiteralPtr type accepts + // Literal* or std::unique_ptr. + // + // 'error': if has value, expects the results to be near (within the error + // bound). Otherwise, expects the results to be equal. + // + // 'reference_preprocessor': the module should be ready to run on the test + // backend, but it might need to be tailored so that it is able to run on the + // reference backend. Note that the program shape of the module must not be + // modified. + template + ::testing::AssertionResult RunAndCompare( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor = nullptr) + TF_MUST_USE_RESULT; + + // Same as above, except that the module will be executed without Hlo + // optimization. + template + ::testing::AssertionResult RunAndCompareNoHloPasses( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor = nullptr) + TF_MUST_USE_RESULT; - // Transfers the array referred to by the given handle from the device and - // returns as a Literal. - std::unique_ptr TransferFromDevice( - const Shape& shape, perftools::gputools::DeviceMemoryBase device_base); + // Executes an hlo module with fake inputs and compares the results. 
+ ::testing::AssertionResult RunAndCompare( + std::unique_ptr module, + const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor = nullptr) + TF_MUST_USE_RESULT; - // Executes the given module and return the result as a Literal. - std::unique_ptr ExecuteAndTransfer( + // Same as above, except that the module will be executed without Hlo + // optimization. + ::testing::AssertionResult RunAndCompareNoHloPasses( std::unique_ptr module, - tensorflow::gtl::ArraySlice - arguments); + const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor = nullptr) + TF_MUST_USE_RESULT; + + // Convenient wrappers for executing and comparing an hlo module with fake + // input. Module can be passed in directly, or parsed from an hlo_string, + // or loaded from a file. + ::testing::AssertionResult RunAndCompare( + const tensorflow::StringPiece hlo_string, + const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor = nullptr) + TF_MUST_USE_RESULT; + ::testing::AssertionResult RunAndCompareFromFile( + const string& filename, const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor = nullptr) + TF_MUST_USE_RESULT; + ::testing::AssertionResult RunAndCompareNoHloPasses( + const tensorflow::StringPiece hlo_string, + const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor = nullptr) + TF_MUST_USE_RESULT; + ::testing::AssertionResult RunAndCompareNoHloPassesFromFile( + const string& filename, const tensorflow::gtl::optional& error, + const std::function& reference_preprocessor = nullptr) + TF_MUST_USE_RESULT; // Convenience method to force the layout of a given parameter in a module. // The layout of parameter number 'param_no' in the 'module' is set to @@ -99,14 +197,38 @@ class HloTestBase : public ::testing::Test { ->Clear(); } + // Return an HLO verifier constructed for the test backend. 
+ HloVerifier& verifier() const { return *hlo_verifier_; } + static string TestName(); - // Returns the backend owned by the HloRunner. + // Returns the backend owned by the test runner. Backend& backend(); - HloRunner runner_; + HloRunner test_runner_; + HloRunner reference_runner_; + + std::unique_ptr hlo_verifier_; ErrorSpec error_spec_{0.0001}; + + private: + // Given the test module, makes a reference module that is ready to run on the + // reference platform. This assumes that the given module is ready to run on + // the test platform. + StatusOr> MakeReferenceModule( + const HloModule& test_module, + const std::function& reference_preprocessor); + + // Runs the module on two platforms with or without running hlo passes and + // compares the results. Returns whether the results are near or equal. If any + // error happens before the results are computed, returns the error status. + template + StatusOr<::testing::AssertionResult> RunAndCompareInternal( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice arguments, + const tensorflow::gtl::optional& error, bool run_hlo_passes, + const std::function& reference_preprocessor); }; } // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc index 31060b9e80fcd50aefdedca27c70ec8a9b8be743..506091ddd8d1d8e6519525bb7031f4e8b296b5fb 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc @@ -23,15 +23,8 @@ limitations under the License. 
namespace xla { -/*static*/ int64 HloVerifiedTestBase::DefaultShapeSize(const Shape& shape) { - constexpr int64 kPointerSize = sizeof(void*); - if (ShapeUtil::IsOpaque(shape)) { - return kPointerSize; - } - return ShapeUtil::ByteSizeOf(shape, kPointerSize); -} - -HloVerifiedTestBase::HloVerifiedTestBase() : shape_size_fn_(DefaultShapeSize) {} +HloVerifiedTestBase::HloVerifiedTestBase() + : shape_verifier_(MakeUnique()) {} HloVerifiedTestBase::~HloVerifiedTestBase() { // We can't call the ASSERT or EXPECT test macros in destructors, so we @@ -47,7 +40,7 @@ void HloVerifiedTestBase::TearDown() { << "TearDown called more than once; it should be called exactly once."; tear_down_called_ = true; if (module_) { - HloVerifier verifier(shape_size_fn_); + HloVerifier verifier; xla::StatusOr mutated = verifier.Run(module_.get()); if (!mutated.ok()) { ADD_FAILURE() << "HloVerifier failed: " << mutated.status(); diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h index b3d6b5af3b46f932707abf309669d23c327d1334..492688bf7d682cf991cb8c09399492a0437f651b 100644 --- a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h @@ -28,14 +28,13 @@ namespace xla { // A base class for HLO tests that stores a default HloModule, and automatically // performs verification on that module on tear-down. class HloVerifiedTestBase : public HloTestBase { - public: - // Returns the size in bytes of the given shape, using a default pointer size. - static int64 DefaultShapeSize(const Shape& shape); - protected: HloVerifiedTestBase(); ~HloVerifiedTestBase() override; + // Constructs a default shape verifier. + std::unique_ptr MakeShapeVerifier(); + // Performs verification on the default HloModule returned by module(). // Automatically called by the testing framework for each test. 
// @@ -47,14 +46,14 @@ class HloVerifiedTestBase : public HloTestBase { HloModule& module(); // Sets the shape-size function used during hlo verification. If this isn't - // called, DefaultShapeSize is used instead. - void SetShapeSizeFn(std::function shape_size_fn) { - shape_size_fn_ = std::move(shape_size_fn); + // called, a default ShapeVerifier is used instead. + void SetShapeVerifier(std::unique_ptr shape_verifier) { + shape_verifier_ = std::move(shape_verifier); } private: std::unique_ptr module_; // Lazily populated. Access via module(). - std::function shape_size_fn_; + std::unique_ptr shape_verifier_; bool tear_down_called_ = false; }; diff --git a/tensorflow/compiler/xla/tests/isolated_convolution.hlo b/tensorflow/compiler/xla/tests/isolated_convolution.hlo new file mode 100644 index 0000000000000000000000000000000000000000..9452780930efbb1ecc13b35cd4ab53678d36c37f --- /dev/null +++ b/tensorflow/compiler/xla/tests/isolated_convolution.hlo @@ -0,0 +1,8 @@ +HloModule convolution.167: + +ENTRY %convolution.167 (parameter.0: f32[16,28,28,128], parameter.1: f32[3,3,128,128]) -> f32[16,28,28,128] { + %parameter.0 = f32[16,28,28,128]{3,0,2,1} parameter(0) + %parameter.1 = f32[3,3,128,128]{3,2,1,0} parameter(1) + ROOT %convolution.167 = f32[16,28,28,128]{3,0,2,1} convolution(f32[16,28,28,128]{3,0,2,1} %parameter.0, f32[3,3,128,128]{3,2,1,0} %parameter.1), window={size=3x3 pad=1_1x1_1}, dim_labels=b01f_01oi->b01f +} + diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index 6aa27e5470d22a8c6698389a720a38e9ea254617..5aa71a9261dbd414d1499f15c9b83cd63b634b49 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -57,7 +57,8 @@ namespace xla { } for (int i = 0; i < expected.tuple_shapes_size(); ++i) { ::testing::AssertionResult result = - EqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i)); + 
EqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i)) + << "mismatch in tuple index " << i; if (!result) { return result; } @@ -100,36 +101,57 @@ namespace xla { ASSERT_EQ(expected.ShortDebugString(), actual.ShortDebugString()); } +namespace { + +// Return a literal with all arrays of type FromNativeT converted to type +// ToNativeT in the given literal. +template +std::unique_ptr ConvertType(const Literal& literal) { + // First construct shape of the result. + Shape result_shape(literal.shape()); + ShapeUtil::ForEachMutableSubshape( + &result_shape, [](Shape* subshape, const ShapeIndex&) { + if (subshape->element_type() == + primitive_util::NativeToPrimitiveType()) { + subshape->set_element_type( + primitive_util::NativeToPrimitiveType()); + } + }); + auto result = MakeUnique(result_shape); + + // Then copy over the data from 'literal' converting FromNativeT values to + // ToNativeT values as necessary. + ShapeUtil::ForEachSubshape( + literal.shape(), + [&](const Shape& subshape, const ShapeIndex& shape_index) { + if (ShapeUtil::IsArray(subshape)) { + if (subshape.element_type() == + primitive_util::NativeToPrimitiveType()) { + auto src = literal.data(shape_index); + auto dest = result->data(shape_index); + for (int64 i = 0; i < src.size(); ++i) { + dest[i] = static_cast(src[i]); + } + } else { + TF_CHECK_OK(result->CopyFrom(literal, + /*dest_shape_index=*/shape_index, + /*src_shape_index=*/shape_index)); + } + } + }); + return result; +} + +} // namespace + /* static */ std::unique_ptr LiteralTestUtil::ConvertBF16ToF32( - const Literal& bf16_literal) { - CHECK_EQ(bf16_literal.shape().element_type(), BF16); - Shape converted_shape = bf16_literal.shape(); - converted_shape.set_element_type(F32); - auto converted = Literal::CreateFromShape(converted_shape); - if (!ShapeUtil::HasZeroElements(converted_shape)) { - std::vector index(converted_shape.dimensions_size(), 0); - do { - converted->Set( - index, static_cast(bf16_literal.Get(index))); - } while 
(IndexUtil::BumpIndices(converted_shape, &index)); - } - return converted; + const Literal& literal) { + return ConvertType(literal); } /* static */ std::unique_ptr LiteralTestUtil::ConvertF32ToBF16( - const Literal& f32_literal) { - CHECK_EQ(f32_literal.shape().element_type(), F32); - Shape converted_shape = f32_literal.shape(); - converted_shape.set_element_type(BF16); - auto converted = Literal::CreateFromShape(converted_shape); - if (!ShapeUtil::HasZeroElements(converted_shape)) { - std::vector index(converted_shape.dimensions_size(), 0); - do { - converted->Set( - index, static_cast(f32_literal.Get(index))); - } while (IndexUtil::BumpIndices(converted_shape, &index)); - } - return converted; + const Literal& literal) { + return ConvertType(literal); } namespace { @@ -279,6 +301,9 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, case BF16: match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); break; + case F16: + match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); + break; case F32: match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); break; @@ -290,9 +315,14 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, break; case TUPLE: { bool tuple_match = true; - for (int i = 0; i < actual.tuple_literals_size(); ++i) { - auto result = - Equal(expected.tuple_literals(i), actual.tuple_literals(i)); + for (int i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { + SCOPED_TRACE(tensorflow::strings::StrCat( + "Tuple index ", i, " in ", + ShapeUtil::HumanString(expected.shape()))); + + // Create LiteralViews of the expected and actual elements. + auto result = Equal(LiteralView::Create(expected, {i}), + LiteralView::Create(actual, {i})); tuple_match = tuple_match ? 
!!result : false; } match = tuple_match; @@ -313,25 +343,6 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, return result; } -/* static */ void LiteralTestUtil::ExpectEqualTuple(const Literal& expected, - const Literal& actual) { - VLOG(1) << "expected: " << expected.ToString(); - VLOG(1) << "actual: " << actual.ToString(); - - ASSERT_TRUE(ShapeUtil::IsTuple(expected.shape())); - ASSERT_TRUE(ShapeUtil::IsTuple(actual.shape())); - AssertEqualShapes(expected.shape(), actual.shape()); - for (uint64 i = 0; i < expected.tuple_literals_size(); ++i) { - const auto& expected_element = expected.tuple_literals(i); - const auto& actual_element = actual.tuple_literals(i); - if (ShapeUtil::IsTuple(expected_element.shape())) { - ExpectEqualTuple(expected_element, actual_element); - } else { - ExpectEqual(expected_element, actual_element); - } - } -} - namespace { // Helper class for comparing floating-point literals within an error bound. @@ -344,9 +355,9 @@ class NearComparator { // temporary files on failure. Returns true if literals match. bool ExpectNear(const Literal& expected, const Literal& actual) { VLOG(1) << "expected:"; - XLA_VLOG_LINES(1, expected.ToString()); + XLA_VLOG_LINES(1, TruncateHugeLiteral(expected)); VLOG(1) << "actual:"; - XLA_VLOG_LINES(1, actual.ToString()); + XLA_VLOG_LINES(1, TruncateHugeLiteral(actual)); // If the shapes mismatch, we simply fail the expectation instead of // printing out data, as it's a type error rather than a value error. 
@@ -365,16 +376,21 @@ class NearComparator { abs_expected_miscompare_sum_ = 0.0; max_rel_err_ = 0.0; max_abs_err_ = 0.0; - *miscompares_.mutable_shape() = - ShapeUtil::ChangeElementType(actual.shape(), PRED); - miscompares_.mutable_preds()->resize( - ShapeUtil::ElementsIn(miscompares_.shape()), false); + first_linear_index_ = -1; + last_linear_index_ = -1; + max_rel_linear_index_ = -1; + max_abs_linear_index_ = -1; + miscompares_ = Literal(ShapeUtil::ChangeElementType(actual.shape(), PRED)); + miscompares_.PopulateWithValue(false); multi_index_.resize(expected.shape().dimensions_size(), 0); switch (expected.shape().element_type()) { case BF16: ExpectLiteralsNear(expected, actual, 0); break; + case F16: + ExpectLiteralsNear(expected, actual, 0); + break; case F32: ExpectLiteralsNear(expected, actual, 0); break; @@ -393,21 +409,33 @@ class NearComparator { if (num_miscompares_ > 0) { if (!VLOG_IS_ON(1)) { LOG(INFO) << "expected: " << ShapeUtil::HumanString(expected.shape()) - << " " << expected.ToString(); + << " " << TruncateHugeLiteral(expected); LOG(INFO) << "actual: " << ShapeUtil::HumanString(actual.shape()) - << " " << actual.ToString(); + << " " << TruncateHugeLiteral(actual); + LOG(INFO) << "Dumping literals to temp files..."; + WriteLiteralToTempFile(expected, "expected"); + WriteLiteralToTempFile(actual, "actual"); + WriteLiteralToTempFile(miscompares_, "miscompares"); } EXPECT_TRUE(num_miscompares_ == 0) << "\nmax relative mismatch at index " - << LiteralTestUtil::MultiIndexAsString(max_rel_multi_index_) + << LiteralTestUtil::MultiIndexAsString( + IndexUtil::LinearIndexToMultidimensionalIndex( + actual.shape(), max_rel_linear_index_)) << "\nmaximum relative error " << max_rel_err_ << "\nmax absolute mismatch at index " - << LiteralTestUtil::MultiIndexAsString(max_abs_multi_index_) + << LiteralTestUtil::MultiIndexAsString( + IndexUtil::LinearIndexToMultidimensionalIndex( + actual.shape(), max_abs_linear_index_)) << "\nmaximum absolute error " << 
max_abs_err_ << "\nfirst mismatch at index " - << LiteralTestUtil::MultiIndexAsString(first_multi_index_) + << LiteralTestUtil::MultiIndexAsString( + IndexUtil::LinearIndexToMultidimensionalIndex( + actual.shape(), first_linear_index_)) << "\nlast mismatch at index " - << LiteralTestUtil::MultiIndexAsString(last_multi_index_) + << LiteralTestUtil::MultiIndexAsString( + IndexUtil::LinearIndexToMultidimensionalIndex( + actual.shape(), last_linear_index_)) << "\ntotal absolute error " << abs_diff_sum_ << "\ntotal absolute error of miscompares " << abs_diff_miscompare_sum_ << "\ntotal relative error " @@ -415,18 +443,18 @@ class NearComparator { << "\ntotal relative error of miscompares " << (abs_diff_miscompare_sum_ / abs_expected_miscompare_sum_) << "\nfailure count " << num_miscompares_; - - WriteLiteralToTempFile(expected, "expected"); - WriteLiteralToTempFile(actual, "actual"); - WriteLiteralToTempFile(miscompares_, "miscompares"); } return num_miscompares_ == 0; } private: template - bool NanMismatch(NativeT lhs, NativeT rhs) { - return std::isnan(lhs) != std::isnan(rhs); + bool NanMismatch(NativeT expected, NativeT actual, bool relaxed_nans) { + if (relaxed_nans) { + return !std::isnan(expected) && std::isnan(actual); + } else { + return std::isnan(expected) != std::isnan(actual); + } } template @@ -446,57 +474,94 @@ class NearComparator { return true; } - float abs_diff = std::abs(actual - expected); - float rel_err = abs_diff / std::abs(expected); + const float abs_diff = std::abs(actual - expected); + const float rel_err = abs_diff / std::abs(expected); + const bool nan_mismatch = + NanMismatch(expected, actual, error_.relaxed_nans); + const bool mismatch = + (nan_mismatch || (abs_diff >= error_.abs && rel_err >= error_.rel)); + return !mismatch; + } + + // Assumes that expected vs actual fail ExpectValuesNear. 
+ template + void UpdateAndLogMiscompares(const NativeT expected, const NativeT actual, + const Shape& shape, const int64 linear_index) { + const float abs_diff = std::abs(actual - expected); + const float rel_err = abs_diff / std::abs(expected); abs_diff_sum_ += abs_diff; abs_expected_sum_ += std::abs(expected); - if (rel_err > max_rel_err_) { + if (rel_err > max_rel_err_ || std::isnan(rel_err)) { max_rel_err_ = rel_err; - max_rel_multi_index_ = multi_index_; + max_rel_linear_index_ = linear_index; } - if (abs_diff > max_abs_err_) { + if (abs_diff > max_abs_err_ || std::isnan(abs_diff)) { max_abs_err_ = abs_diff; - max_abs_multi_index_ = multi_index_; + max_abs_linear_index_ = linear_index; } - VLOG(10) << tensorflow::strings::Printf( - "index %s abs_diff %f rel_err %f", - LiteralTestUtil::MultiIndexAsString(multi_index_).c_str(), abs_diff, - rel_err); - bool nan_mismatch = NanMismatch(expected, actual); - bool mismatch = - (nan_mismatch || (abs_diff >= error_.abs && rel_err >= error_.rel)); - if (mismatch) { - abs_diff_miscompare_sum_ += abs_diff; - abs_expected_miscompare_sum_ += std::abs(expected); - const int64 kMaxFailures = 2; - if (num_miscompares_ < kMaxFailures) { - ::testing::Message msg; - msg << "mismatch at index " - << LiteralTestUtil::MultiIndexAsString(multi_index_) << " abs diff " - << abs_diff << " rel err " << rel_err << " failure #" - << num_miscompares_; - ExpectNear(expected, actual, msg); - } else if (num_miscompares_ == kMaxFailures) { - LOG(ERROR) - << "reached max 'loud' failure count; silently proceeding..."; - } - if (num_miscompares_ == 0) { - first_multi_index_ = multi_index_; - } - num_miscompares_++; - last_multi_index_ = multi_index_; + if (VLOG_IS_ON(10)) { + VLOG(10) << tensorflow::strings::Printf( + "index %s abs_diff %f rel_err %f", + LiteralTestUtil::MultiIndexAsString( + IndexUtil::LinearIndexToMultidimensionalIndex(shape, + linear_index)) + .c_str(), + abs_diff, rel_err); } - return !mismatch; + abs_diff_miscompare_sum_ += 
abs_diff; + abs_expected_miscompare_sum_ += std::abs(expected); + const int64 kMaxFailures = 2; + if (num_miscompares_ < kMaxFailures) { + const auto multi_index = + IndexUtil::LinearIndexToMultidimensionalIndex(shape, linear_index); + ::testing::Message msg; + msg << "mismatch at index " + << LiteralTestUtil::MultiIndexAsString(multi_index) << " abs diff " + << abs_diff << " rel err " << rel_err << " failure #" + << num_miscompares_; + ExpectNear(expected, actual, msg); + } else if (num_miscompares_ == kMaxFailures) { + LOG(ERROR) << "reached max 'loud' failure count; silently proceeding..."; + } + if (num_miscompares_ == 0) { + first_linear_index_ = linear_index; + } + num_miscompares_++; + last_linear_index_ = linear_index; + miscompares_.data()[linear_index] = true; } // Recursive function which compares the two given literals elementwise. template void ExpectLiteralsNear(const Literal& expected, const Literal& actual, int64 dimension) { + // Fast path optimization for the case were layouts match. 
+ if (LayoutUtil::Equal(actual.shape().layout(), expected.shape().layout())) { + tensorflow::gtl::ArraySlice expected_data = + expected.data(); + tensorflow::gtl::ArraySlice actual_data = + actual.data(); + const int64 len = expected_data.size(); + for (int64 i = 0; i < len; ++i) { + const bool near = ExpectValuesNear(expected_data[i], actual_data[i]); + if (!near) { + UpdateAndLogMiscompares(expected_data[i], actual_data[i], + actual.shape(), i); + } + } + return; + } + if (dimension == expected.shape().dimensions_size()) { bool near = ExpectValuesNear(expected.Get(multi_index_), actual.Get(multi_index_)); - miscompares_.Set(multi_index_, !near); + if (!near) { + UpdateAndLogMiscompares( + expected.Get(multi_index_), + actual.Get(multi_index_), actual.shape(), + IndexUtil::MultidimensionalIndexToLinearIndex(actual.shape(), + multi_index_)); + } } else { for (int64 i = 0; i < expected.shape().dimensions(dimension); ++i) { multi_index_[dimension] = i; @@ -517,6 +582,32 @@ class NearComparator { LOG(ERROR) << "wrote to " << name << " file: " << filename; } + // Gets the total element count. For tuples, this is not the count of tuple + // elements, but the sum of elements of each tuple element. + int64 RecursiveElementCount(const Shape& shape) { + if (ShapeUtil::IsTuple(shape)) { + const int64 tuple_elements = ShapeUtil::TupleElementCount(shape); + int64 total = 0; + for (int64 i = 0; i < tuple_elements; ++i) { + total += + RecursiveElementCount(ShapeUtil::GetTupleElementShape(shape, i)); + } + return total; + } else { + return ShapeUtil::ElementsIn(shape); + } + } + + // Calling ToString on a literal with over 100 million elements takes around + // 3 minutes. The utility of printing a literal with >1000 elements is + // questionable, especially when writing the Literal proto to disk is orders + // of magnitude faster. + string TruncateHugeLiteral(const Literal& literal) { + return RecursiveElementCount(literal.shape()) < 1000 + ? 
literal.ToString() + : "[TRUNCATED, Literal with more than 1000 values]"; + } + ErrorSpec error_; // Number of element miscomparisons encountered so far. @@ -537,16 +628,18 @@ class NearComparator { double abs_expected_miscompare_sum_; float max_rel_err_; float max_abs_err_; - std::vector first_multi_index_; - std::vector last_multi_index_; - std::vector max_rel_multi_index_; - std::vector max_abs_multi_index_; + int64 first_linear_index_; + int64 last_linear_index_; + int64 max_rel_linear_index_; + int64 max_abs_linear_index_; }; template <> -bool NearComparator::NanMismatch(complex64 lhs, complex64 rhs) { - return std::isnan(lhs.real()) != std::isnan(rhs.real()) || - std::isnan(lhs.imag()) != std::isnan(rhs.imag()); +bool NearComparator::NanMismatch(complex64 expected, + complex64 actual, + bool relaxed_nans) { + return NanMismatch(expected.real(), actual.real(), relaxed_nans) || + NanMismatch(expected.imag(), actual.imag(), relaxed_nans); } template <> @@ -567,14 +660,64 @@ bool NearComparator::ExpectValuesNear(bfloat16 expected, static_cast(actual)); } +template <> +bool NearComparator::ExpectValuesNear(half expected, half actual) { + return ExpectValuesNear(static_cast(std::move(expected)), + static_cast(std::move(actual))); +} + +template <> +void NearComparator::UpdateAndLogMiscompares( + const bfloat16 expected, const bfloat16 actual, const Shape& shape, + const int64 linear_index) { + UpdateAndLogMiscompares(static_cast(expected), + static_cast(actual), shape, linear_index); +} + +template <> +void NearComparator::UpdateAndLogMiscompares(half expected, half actual, + const Shape& shape, + const int64 linear_index) { + UpdateAndLogMiscompares(static_cast(std::move(expected)), + static_cast(std::move(actual)), shape, + linear_index); +} + } // namespace /* static */ ::testing::AssertionResult LiteralTestUtil::Near( const Literal& expected, const Literal& actual, const ErrorSpec& error) { - NearComparator comparator(error); - return 
comparator.ExpectNear(expected, actual) - ? ::testing::AssertionSuccess() - : ::testing::AssertionFailure() << "values were not near"; + ::testing::AssertionResult err = + EqualShapes(expected.shape(), actual.shape()); + if (!err) { + return err; + } + + if (ShapeUtil::IsTuple(expected.shape())) { + for (int64 i = 0; i < ShapeUtil::TupleElementCount(expected.shape()); ++i) { + SCOPED_TRACE(tensorflow::strings::StrCat( + "Tuple index ", i, " in ", ShapeUtil::HumanString(expected.shape()))); + const auto expected_element = LiteralView::Create(expected, {i}); + const auto actual_element = LiteralView::Create(actual, {i}); + + ::testing::AssertionResult res = + Near(expected_element, actual_element, error); + if (err && !res) { + err = res; + } + } + return err; + } + + if (ShapeUtil::ElementIsFloating(expected.shape()) || + ShapeUtil::ElementIsComplex(expected.shape())) { + NearComparator comparator(error); + return comparator.ExpectNear(expected, actual) + ? ::testing::AssertionSuccess() + : ::testing::AssertionFailure() << "values were not near"; + } + + return Equal(expected, actual); } /* static */ void LiteralTestUtil::ExpectNear(const Literal& expected, @@ -587,47 +730,21 @@ bool NearComparator::ExpectValuesNear(bfloat16 expected, : tensorflow::strings::StrCat("\nmessage: ", message)); } -/* static */ ::testing::AssertionResult LiteralTestUtil::NearTuple( - const Literal& expected, const Literal& actual, const ErrorSpec& error) { - VLOG(1) << "expected: " << expected.ToString(); - VLOG(1) << "actual: " << actual.ToString(); - - if (!ShapeUtil::IsTuple(expected.shape()) || - !ShapeUtil::IsTuple(actual.shape())) { - return ::testing::AssertionFailure() - << "tuples expected expected shape = " - << expected.shape().ShortDebugString() - << " actual shape = " << actual.shape().ShortDebugString(); - } - AssertEqualShapes(expected.shape(), actual.shape()); - for (uint64 i = 0; i < expected.tuple_literals_size(); ++i) { - const auto& expected_element = 
expected.tuple_literals(i); - const auto& actual_element = actual.tuple_literals(i); - if (ShapeUtil::IsTuple(expected_element.shape())) { - auto ret = NearTuple(expected_element, actual_element, error); - if (!ret) { - return ret; - } - } else if (ShapeUtil::ElementIsFloating(expected_element.shape())) { - auto ret = Near(expected_element, actual_element, error); - if (!ret) { - return ret; - } - } else { - auto ret = Equal(expected_element, actual_element); - if (!ret) { - return ret; - } - } +/*static*/ ::testing::AssertionResult LiteralTestUtil::NearOrEqual( + const Literal& expected, const Literal& actual, + const tensorflow::gtl::optional& error) { + if (error.has_value()) { + VLOG(1) << "Expects near"; + return Near(expected, actual, *error); } - - return ::testing::AssertionSuccess(); + VLOG(1) << "Expects equal"; + return Equal(expected, actual); } -/* static */ void LiteralTestUtil::ExpectNearTuple(const Literal& expected, - const Literal& actual, - const ErrorSpec& error) { - EXPECT_TRUE(NearTuple(expected, actual, error)); +/*static*/ void LiteralTestUtil::ExpectNearOrEqual( + const Literal& expected, const Literal& actual, + const tensorflow::gtl::optional& error) { + EXPECT_TRUE(NearOrEqual(expected, actual, error)); } /* static */ string LiteralTestUtil::MultiIndexAsString( @@ -644,10 +761,10 @@ bool NearComparator::ExpectValuesNear(bfloat16 expected, new_num_elements *= new_dimensions[i]; } CHECK_EQ(ShapeUtil::ElementsIn(literal.shape()), new_num_elements); + CHECK_EQ(new_dimensions.size(), minor_to_major.size()); - auto new_literal = MakeUnique(); - *new_literal->mutable_shape() = - ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions); + auto new_literal = MakeUnique( + ShapeUtil::MakeShape(literal.shape().element_type(), new_dimensions)); // Create a new shape with the given minor-to-major layout. 
This shape is used // solely for converting linear address to multi-dimensional addresses when @@ -655,9 +772,6 @@ bool NearComparator::ExpectValuesNear(bfloat16 expected, Shape shape_with_layout = new_literal->shape(); *shape_with_layout.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major); - // Allocate space in the new literal. - new_literal->Reserve(ShapeUtil::ElementsIn(literal.shape())); - // Copy data into new literal, element-by-element. for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) { std::vector from_multi_index = @@ -697,6 +811,10 @@ bool NearComparator::ExpectValuesNear(bfloat16 expected, new_literal->Set(to_multi_index, literal.Get(from_multi_index)); break; + case C64: + new_literal->Set(to_multi_index, + literal.Get(from_multi_index)); + break; default: LOG(FATAL) << "Unhandled primitive element type: " << PrimitiveType_Name(literal.shape().element_type()); diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index 6e4add2690fd958d555eab3cef51cdbbd01819c9..7b757a4bd7e7592583b7596b4305ddb7e6c52d75 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -39,10 +40,16 @@ namespace xla { // Structure describing permissible absolute and relative error bounds. struct ErrorSpec { - explicit ErrorSpec(float aabs, float arel = 0) : abs(aabs), rel(arel) {} + explicit ErrorSpec(float aabs, float arel = 0, bool relaxed_nans = false) + : abs(aabs), rel(arel), relaxed_nans(relaxed_nans) {} float abs; // Absolute error bound. 
float rel; // Relative error bound. + + // If relaxed_nans is true then any result is valid if we are expecting NaNs. + // In effect, this allows the tested operation to produce incorrect results + // for inputs outside its mathematical domain. + bool relaxed_nans; }; // Utility class for making expectations/assertions related to XLA literals. @@ -59,10 +66,14 @@ class LiteralTestUtil { static void AssertEqualShapesAndLayouts(const Shape& expected, const Shape& actual); - // Converts a bfloat16 literal to a float literal. + // If the given literal's data type is bfloat16, converts it to a float + // literal; otherwise, returns a copy of it. If the literal is a tuple, + // recursively converts its elements. static std::unique_ptr ConvertBF16ToF32(const Literal& bf16_literal); - // Converts a float literal to a bfloat16 literal. + // If the given literal's data type is float, converts it to a bfloat16 + // literal; otherwise, returns a copy of it. If the literal is a tuple, + // recursively converts its elements. static std::unique_ptr ConvertF32ToBF16(const Literal& f32_literal); // Asserts that the expected and actual literals are (bitwise) equal for all @@ -106,13 +117,18 @@ class LiteralTestUtil { static void ExpectR4EqualArray4D(const Array4D& expected, const Literal& actual); - // Expects that the values of the elements in the expected and actual tuples - // are equal. Tuples are matched recursively. - static void ExpectEqualTuple(const Literal& expected, const Literal& actual); - // Asserts that the expected and actual literals are within the given error // bound for all elements. Also, asserts that the rank, dimensions sizes, and - // bounds are equivalent. Only supported for floating point values. + // bounds are equivalent. + // + // Tuples are matched recursively. When comparing tensors of + // non-floating-point type, checks for exact equality, ignoring the ErroSpec. 
+ // + // If the shape of the literals is neither a complex/floating-point tensor nor + // a tuple which contains a complex/floating-point tensor, Near() is + // equivalent to Equal(). We don't raise an error in this case, because we + // want to allow callers to call Near() even if they have no preconceptions + // about the shapes being compared. static ::testing::AssertionResult Near( const Literal& expected, const Literal& actual, const ErrorSpec& error) TF_MUST_USE_RESULT; @@ -161,17 +177,18 @@ class LiteralTestUtil { const Literal& actual, const ErrorSpec& error); - // Returns whether the values of the elements in the expected and actual - // tuples are within the given error bound. Tuples are matched recursively. - // If the elements of the tuple are not floating-point types, the error spec - // is ignored and exact equality is checked. - static ::testing::AssertionResult NearTuple( + // If the error spec is given, returns whether the expected and the actual are + // within the error bound; otherwise, returns whether they are equal. Tuples + // will be compared recursively. + static ::testing::AssertionResult NearOrEqual( const Literal& expected, const Literal& actual, - const ErrorSpec& error) TF_MUST_USE_RESULT; + const tensorflow::gtl::optional& error) TF_MUST_USE_RESULT; - // Expects that the expected and actual values are near. - static void ExpectNearTuple(const Literal& expected, const Literal& actual, - const ErrorSpec& error); + // If the error spec is given, expects the expected and the actual to be near; + // otherwise, expects them to be equal. Tuples will be compared recursively. + static void ExpectNearOrEqual( + const Literal& expected, const Literal& actual, + const tensorflow::gtl::optional& error); // Returns a multi-dimensional index as a string. 
For example: '{7, 8}' will // be returned for a 2-dimensional index with dimension 0 index equal to 7, diff --git a/tensorflow/compiler/xla/tests/literal_test_util_test.cc b/tensorflow/compiler/xla/tests/literal_test_util_test.cc index 2acf27ed390b0732ba40fcf505c746bd7d8b651e..3a421f8458268a14dcdd84889bcae4990c095ea4 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util_test.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util_test.cc @@ -83,18 +83,43 @@ TEST(LiteralTestUtilTest, ExpectNearFailurePlacesResultsInTemporaryDirectory) { LiteralProto literal_proto; TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), result, &literal_proto)); - Literal literal(literal_proto); + std::unique_ptr literal = + Literal::CreateFromProto(literal_proto).ConsumeValueOrDie(); if (result.find("expected") != string::npos) { - EXPECT_EQ("2", literal.ToString()); + EXPECT_EQ("2", literal->ToString()); } else if (result.find("actual") != string::npos) { - EXPECT_EQ("4", literal.ToString()); + EXPECT_EQ("4", literal->ToString()); } else if (result.find("miscompares") != string::npos) { - EXPECT_EQ("true", literal.ToString()); + EXPECT_EQ("true", literal->ToString()); } else { FAIL() << "unknown file in temporary directory: " << result; } } } +TEST(LiteralTestUtilTest, NearComparatorR1) { + auto a = + Literal::CreateR1({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); + auto b = + Literal::CreateR1({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); + EXPECT_TRUE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001})); +} + +TEST(LiteralTestUtilTest, NearComparatorR1Nan) { + auto a = + Literal::CreateR1({0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8}); + auto b = + Literal::CreateR1({0.0, 0.1, 0.2, 0.3, NAN, 0.5, 0.6, 0.7, 0.8}); + EXPECT_TRUE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001})); +} + +TEST(LiteralTestUtil, NearComparatorDifferentLengths) { + auto a = + Literal::CreateR1({0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}); + auto b = Literal::CreateR1({0.0, 0.1, 0.2, 
0.3, 0.4, 0.5, 0.6, 0.7}); + EXPECT_FALSE(LiteralTestUtil::Near(*a, *b, ErrorSpec{0.0001})); + EXPECT_FALSE(LiteralTestUtil::Near(*b, *a, ErrorSpec{0.0001})); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc index b5b95967ff9162301a092f3a57996e0f3f78658f..7e92439c494b677f718a63c71c20828d65bebef4 100644 --- a/tensorflow/compiler/xla/tests/llvm_compiler_test.cc +++ b/tensorflow/compiler/xla/tests/llvm_compiler_test.cc @@ -74,7 +74,8 @@ class LLVMCompilerTest : public ::testing::Test { ASSERT_TRUE(compiler ->RunBackend(std::move(hlo_module), - backend_->default_stream_executor()) + backend_->default_stream_executor(), + /*device_allocator=*/nullptr) .ok()); // Test that hooks were called. @@ -98,7 +99,8 @@ class LLVMCompilerTest : public ::testing::Test { executors.push_back({backend_->default_stream_executor()}); executors.push_back({backend_->default_stream_executor()}); - EXPECT_IS_OK(compiler->Compile(std::move(modules), std::move(executors))); + EXPECT_IS_OK(compiler->Compile(std::move(modules), std::move(executors), + /*device_allocator=*/nullptr)); } private: diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test.cc b/tensorflow/compiler/xla/tests/local_client_aot_test.cc index 569d5944cab0ae8f6a7b58a651285d20d4f9d019..47cab796041e9669affaebd7866d0d80100730f1 100644 --- a/tensorflow/compiler/xla/tests/local_client_aot_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_aot_test.cc @@ -44,8 +44,7 @@ TEST_F(LocalClientAotTest, Constant) { OpaqueData opaque_data{100, 20, 3}; void* parameters[] = {&opaque_data}; float out = 0; - char tmp[4] = {0}; - void* temporary_buffers[] = {nullptr, &out, &tmp}; + void* temporary_buffers[] = {nullptr, &out}; SumAndDouble(&out, &run_options, parameters, temporary_buffers); EXPECT_EQ(out, 246.0f); diff --git a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc 
b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc index 0cd44a72b5818c1bf66fd4cd1929572038596b47..3704ddd8010bf727b75ff81b63605e8b7ffe2ca8 100644 --- a/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc +++ b/tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc @@ -63,8 +63,6 @@ int main(int argc, char** argv) { triple_string = "x86_64-apple-macosx"; } else if (target_cpu == "arm") { triple_string = "aarch64-none-linux-gnu"; - } else if (target_cpu == "ppc") { - triple_string = "powerpc64le-unknown-linux-gnu"; } else if (target_cpu == "local") { triple_string = xla::llvm_ir::AsString(llvm::sys::getDefaultTargetTriple()); } else { @@ -89,10 +87,9 @@ int main(int argc, char** argv) { // It's lame to hard-code the buffer assignments, but we need // local_client_aot_test.cc to be able to easily invoke the function. CHECK_EQ(result->result_buffer_index(), 1); - CHECK_EQ(result->buffer_sizes().size(), 3); + CHECK_EQ(result->buffer_sizes().size(), 2); CHECK_EQ(result->buffer_sizes()[0], -1); // param buffer CHECK_EQ(result->buffer_sizes()[1], sizeof(float)); // result buffer - CHECK_EQ(result->buffer_sizes()[2], sizeof(float)); // temp buffer if (triple.isOSBinFormatELF()) { // Check the ELF magic. CHECK_EQ(result->object_file_data()[0], 0x7F); diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc index ad71d40197fe48b4343ee5f5f7f71b282a05cbf5..2462ea39f914b1dbb525ea777a48d9ce66035638 100644 --- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc @@ -138,13 +138,13 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) { // Create x as a col-major array. 
auto x_array = LiteralToShapedBuffer(*Literal::CreateR2WithLayout( {{1.0f, 2.0f}, {3.0f, 4.0f}}, LayoutUtil::MakeLayout({0, 1}))); - EXPECT_TRUE(LayoutUtil::Equal(x_array->shape().layout(), + EXPECT_TRUE(LayoutUtil::Equal(x_array->on_device_shape().layout(), LayoutUtil::MakeLayout({0, 1}))); // Create y as a row-major array. auto y_array = LiteralToShapedBuffer(*Literal::CreateR2WithLayout( {{10.0f, 20.0f}, {30.0f, 40.0f}}, LayoutUtil::MakeLayout({1, 0}))); - EXPECT_TRUE(LayoutUtil::Equal(y_array->shape().layout(), + EXPECT_TRUE(LayoutUtil::Equal(y_array->on_device_shape().layout(), LayoutUtil::MakeLayout({1, 0}))); std::unique_ptr result_colmaj = @@ -179,7 +179,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { DefaultExecutableBuildOptions().set_result_layout( ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2}, {0, 1})), DefaultExecutableRunOptions()); - EXPECT_TRUE(LayoutUtil::Equal(result_colmaj->shape().layout(), + EXPECT_TRUE(LayoutUtil::Equal(result_colmaj->on_device_shape().layout(), LayoutUtil::MakeLayout({0, 1}))); LiteralTestUtil::ExpectR2Near({{11.0f, 22.0f}, {33.0f, 44.0f}}, *ShapedBufferToLiteral(*result_colmaj), @@ -191,7 +191,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { DefaultExecutableBuildOptions().set_result_layout( ShapeUtil::MakeShapeWithLayout(F32, /*dimensions=*/{2, 2}, {1, 0})), DefaultExecutableRunOptions()); - EXPECT_TRUE(LayoutUtil::Equal(result_rowmaj->shape().layout(), + EXPECT_TRUE(LayoutUtil::Equal(result_rowmaj->on_device_shape().layout(), LayoutUtil::MakeLayout({1, 0}))); LiteralTestUtil::ExpectR2Near({{11.0f, 22.0f}, {33.0f, 44.0f}}, *ShapedBufferToLiteral(*result_rowmaj), @@ -213,16 +213,17 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) { std::unique_ptr result = ExecuteLocallyOrDie(computation, {x_array.get(), y_array.get()}); - EXPECT_TRUE(ShapeUtil::IsTuple(result->shape())); - EXPECT_EQ(3, ShapeUtil::TupleElementCount(result->shape())); + 
EXPECT_TRUE(ShapeUtil::IsTuple(result->on_host_shape())); + EXPECT_EQ(3, ShapeUtil::TupleElementCount(result->on_host_shape())); std::unique_ptr result_literal = ShapedBufferToLiteral(*result); - LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - result_literal->tuple_literals(0)); - LiteralTestUtil::ExpectR2Equal({{10.0f, 20.0f}, {30.0f, 40.0f}}, - result_literal->tuple_literals(1)); - LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - result_literal->tuple_literals(2)); + LiteralTestUtil::ExpectR2Equal( + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {0})); + LiteralTestUtil::ExpectR2Equal( + {{10.0f, 20.0f}, {30.0f, 40.0f}}, + LiteralView::Create(*result_literal, {1})); + LiteralTestUtil::ExpectR2Equal( + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {2})); } XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { @@ -241,19 +242,21 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { std::unique_ptr result = ExecuteLocallyOrDie(computation, {x_array.get(), y_array.get()}); - EXPECT_TRUE(ShapeUtil::IsTuple(result->shape())); - EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->shape())); + EXPECT_TRUE(ShapeUtil::IsTuple(result->on_host_shape())); + EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->on_host_shape())); std::unique_ptr result_literal = ShapedBufferToLiteral(*result); - LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - result_literal->tuple_literals(1)); - const Literal& inner_tuple_literal = result_literal->tuple_literals(0); - LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - inner_tuple_literal.tuple_literals(0)); - LiteralTestUtil::ExpectR2Equal({{10.0f, 20.0f}, {30.0f, 40.0f}}, - inner_tuple_literal.tuple_literals(1)); - LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - inner_tuple_literal.tuple_literals(2)); + LiteralTestUtil::ExpectR2Equal( + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {1})); + 
LiteralTestUtil::ExpectR2Equal( + {{1.0f, 2.0f}, {3.0f, 4.0f}}, + LiteralView::Create(*result_literal, {0, 0})); + LiteralTestUtil::ExpectR2Equal( + {{10.0f, 20.0f}, {30.0f, 40.0f}}, + LiteralView::Create(*result_literal, {0, 1})); + LiteralTestUtil::ExpectR2Equal( + {{1.0f, 2.0f}, {3.0f, 4.0f}}, + LiteralView::Create(*result_literal, {0, 2})); } XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { @@ -278,10 +281,10 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { DefaultExecutableRunOptions()); std::unique_ptr result_literal = ShapedBufferToLiteral(*result); - LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - result_literal->tuple_literals(0)); - LiteralTestUtil::ExpectR2Equal({{1.0f, 2.0f}, {3.0f, 4.0f}}, - result_literal->tuple_literals(1)); + LiteralTestUtil::ExpectR2Equal( + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {0})); + LiteralTestUtil::ExpectR2Equal( + {{1.0f, 2.0f}, {3.0f, 4.0f}}, LiteralView::Create(*result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { @@ -320,14 +323,15 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { std::unique_ptr result = ExecuteLocallyOrDie(computation, {x_buffer.get(), y_buffer.get()}); - EXPECT_TRUE(ShapeUtil::IsTuple(result->shape())); - EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->shape())); + EXPECT_TRUE(ShapeUtil::IsTuple(result->on_host_shape())); + EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->on_host_shape())); std::unique_ptr result_literal = ShapedBufferToLiteral(*result); - LiteralTestUtil::ExpectR2Equal({{56.0f, 46.0f}, {36.0f, 26.0f}}, - result_literal->tuple_literals(0)); - LiteralTestUtil::ExpectR1Equal({40.0f, 71.0f, 117.0f}, - result_literal->tuple_literals(1)); + LiteralTestUtil::ExpectR2Equal( + {{56.0f, 46.0f}, {36.0f, 26.0f}}, + LiteralView::Create(*result_literal, {0})); + LiteralTestUtil::ExpectR1Equal( + {40.0f, 71.0f, 117.0f}, LiteralView::Create(*result_literal, {1})); } 
XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) { @@ -365,10 +369,10 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) { ExecuteLocallyOrDie(computation, {arg_buffer.get()}); std::unique_ptr result_literal = ShapedBufferToLiteral(*result); - LiteralTestUtil::ExpectR2Equal({{-1.0, -2.0}, {-3.0, -4}}, - result_literal->tuple_literals(0)); - LiteralTestUtil::ExpectR1Equal({264.0, 73.0, 133.0}, - result_literal->tuple_literals(1)); + LiteralTestUtil::ExpectR2Equal( + {{-1.0, -2.0}, {-3.0, -4}}, LiteralView::Create(*result_literal, {0})); + LiteralTestUtil::ExpectR1Equal( + {264.0, 73.0, 133.0}, LiteralView::Create(*result_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) { @@ -395,18 +399,19 @@ XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) { std::unique_ptr result_0 = ExecuteLocallyOrDie(computation, {arg_buffer.get()}); std::unique_ptr result_0_literal = ShapedBufferToLiteral(*result_0); - LiteralTestUtil::ExpectR2Equal({{-1.0, -2.0}, {-3.0, -4.0}}, - result_0_literal->tuple_literals(0)); - LiteralTestUtil::ExpectR2Equal({{22.0, 6.0}, {8.0, 10}}, - result_0_literal->tuple_literals(1)); + LiteralTestUtil::ExpectR2Equal( + {{-1.0, -2.0}, {-3.0, -4.0}}, + LiteralView::Create(*result_0_literal, {0})); + LiteralTestUtil::ExpectR2Equal( + {{22.0, 6.0}, {8.0, 10}}, LiteralView::Create(*result_0_literal, {1})); std::unique_ptr result_1 = ExecuteLocallyOrDie(computation, {result_0.get()}); std::unique_ptr result_1_literal = ShapedBufferToLiteral(*result_1); - LiteralTestUtil::ExpectR2Equal({{1.0, 2.0}, {3.0, 4.0}}, - result_1_literal->tuple_literals(0)); - LiteralTestUtil::ExpectR2Equal({{44.0, 12.0}, {16.0, 20}}, - result_1_literal->tuple_literals(1)); + LiteralTestUtil::ExpectR2Equal( + {{1.0, 2.0}, {3.0, 4.0}}, LiteralView::Create(*result_1_literal, {0})); + LiteralTestUtil::ExpectR2Equal( + {{44.0, 12.0}, {16.0, 20}}, LiteralView::Create(*result_1_literal, {1})); } 
XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { @@ -455,7 +460,8 @@ XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { for (int i = 0; i < kElementCount; ++i) { LiteralTestUtil::ExpectR1Near( - {2.0f * i, 0.0f}, result_literal->tuple_literals(i), error_spec_); + {2.0f * i, 0.0f}, LiteralView::Create(*result_literal, {i}), + error_spec_); } } @@ -512,8 +518,8 @@ XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_CPU_PARALLEL(LargeNestedTuple)) { for (int i = 0; i < kFanout; ++i) { for (int j = 0; j < kFanout; ++j) { LiteralTestUtil::ExpectR0Near( - i + j + i * kFanout + j, - result_literal->tuple_literals(i).tuple_literals(j), error_spec_); + i + j + i * kFanout + j, LiteralView::Create(*result_literal, {i, j}), + error_spec_); } } } @@ -554,11 +560,12 @@ XLA_TEST_F(LocalClientExecuteTest, DeepTuple) { ExecuteLocallyOrDie(computation, {arg_buffer.get()}); std::unique_ptr result_literal = ShapedBufferToLiteral(*result); - const Literal* result_element = result_literal.get(); + ShapeIndex index; for (int i = 0; i < kTupleDepth; ++i) { - result_element = &result_element->tuple_literals(0); + index.push_back(0); } - LiteralTestUtil::ExpectR0Equal(165.0, *result_element); + LiteralTestUtil::ExpectR0Equal( + 165.0, LiteralView::Create(*result_literal, index)); } XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) { @@ -575,7 +582,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) { EXPECT_FALSE(execute_status.ok()); EXPECT_THAT(execute_status.status().error_message(), - ContainsRegex("invalid number of arguments")); + ContainsRegex("Invalid number of arguments")); } XLA_TEST_F(LocalClientExecuteTest, IncorrectArgumentShape) { @@ -591,7 +598,7 @@ XLA_TEST_F(LocalClientExecuteTest, IncorrectArgumentShape) { EXPECT_FALSE(execute_status.ok()); EXPECT_THAT(execute_status.status().error_message(), - ContainsRegex("invalid argument shape")) + ContainsRegex("Invalid argument shape")) << execute_status.status(); } @@ -763,10 +770,10 @@ 
XLA_TEST_F(LocalClientExecuteTest, SelectBetweenTuples) { std::unique_ptr result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); std::unique_ptr tuple_literal = ShapedBufferToLiteral(*result); - LiteralTestUtil::ExpectR1Equal({2.0f, 4.0f, 6.0f}, - tuple_literal->tuple_literals(0)); - LiteralTestUtil::ExpectR1Equal({1.0f, 2.0f, 3.0f}, - tuple_literal->tuple_literals(1)); + LiteralTestUtil::ExpectR1Equal( + {2.0f, 4.0f, 6.0f}, LiteralView::Create(*tuple_literal, {0})); + LiteralTestUtil::ExpectR1Equal( + {1.0f, 2.0f, 3.0f}, LiteralView::Create(*tuple_literal, {1})); } XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) { @@ -906,20 +913,18 @@ void BM_LocalClientOverhead(int num_iters) { builder.Add(x, x); auto computation = builder.Build().ConsumeValueOrDie(); - auto shape_size_fn = [client](const Shape& shape) { - return client->backend().transfer_manager()->GetByteSizeRequirement(shape); - }; - auto buffer = ScopedShapedBuffer::Allocate( - shape, &allocator, /*device_ordinal=*/0, shape_size_fn) - .ConsumeValueOrDie(); + auto buffer = + transfer_manager + ->AllocateScopedShapedBuffer(shape, &allocator, /*device_ordinal=*/0) + .ConsumeValueOrDie(); auto literal = Literal::CreateR2({{0, 0, 0}, {0, 0, 0}}); ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( - executors[device_ordinal], *literal, buffer->mutable_buffer({}))); + executors[device_ordinal], *literal, *buffer)); const int kWarmups = 2; - auto executable_status = client->Compile(computation, {&buffer->shape()}, - ExecutableBuildOptions()); + auto executable_status = client->Compile( + computation, {&buffer->on_host_shape()}, ExecutableBuildOptions()); ASSERT_IS_OK(executable_status); std::unique_ptr executable = executable_status.ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc index 062a9246e49598d5d03dce8c1f437138923449bf..96b976d25d75d35f46adfd104a03aceb363661eb 100644 --- 
a/tensorflow/compiler/xla/tests/local_client_test_base.cc +++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc @@ -188,7 +188,7 @@ LocalClientTestBase::ExecuteLocally( const ExecutableRunOptions& run_options) { std::vector argument_layouts(arguments.size()); for (int i = 0; i < arguments.size(); ++i) { - argument_layouts[i] = &arguments[i]->shape(); + argument_layouts[i] = &arguments[i]->on_host_shape(); } TF_ASSIGN_OR_RETURN( std::unique_ptr executable, diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc index 0fb87c3c2ccbad387d46016cfad4e7d3cc537dcc..6c86dd5b9ef673c9facffafa37e00a859ce82010 100644 --- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc @@ -221,5 +221,77 @@ INSTANTIATE_TEST_CASE_P(MatOpsDotAddTestInstances, MatOpsDotAddTest, ::testing::Combine(::testing::Bool(), ::testing::Bool(), ::testing::Bool())); +class MatOpsDotAddTest_bf16 + : public ClientLibraryTestBase, + public ::testing::WithParamInterface> {}; + +TEST_P(MatOpsDotAddTest_bf16, Dot_Add_2x2_2x2) { + bool row_major = std::get<0>(GetParam()); + bool add_lhs = std::get<1>(GetParam()); + bool transpose = std::get<2>(GetParam()); + Array2D lhs( + {{bfloat16(1.0f), bfloat16(2.0f)}, {bfloat16(3.0), bfloat16(4.0)}}); + Array2D rhs( + {{bfloat16(10.0f), bfloat16(11.0f)}, {bfloat16(12.0f), bfloat16(13.0f)}}); + + auto minor_to_major = [](bool row_major) -> std::vector { + return {row_major ? 1 : 0, row_major ? 
0 : 1}; + }; + + auto prim_type = primitive_util::NativeToPrimitiveType(); + Shape lhs_shape = + ShapeUtil::MakeShape(prim_type, {lhs.height(), lhs.width()}); + Shape rhs_shape = + ShapeUtil::MakeShape(prim_type, {rhs.height(), rhs.width()}); + + TF_ASSERT_OK_AND_ASSIGN( + auto lhs_handle, + client_->TransferToServer( + *Literal::CreateR2FromArray2DWithLayout( + lhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); + TF_ASSERT_OK_AND_ASSIGN( + auto rhs_handle, + client_->TransferToServer( + *Literal::CreateR2FromArray2DWithLayout( + rhs, LayoutUtil::MakeLayout(minor_to_major(row_major))))); + + ComputationBuilder builder(client_, TestName()); + auto lhs_arg = builder.Parameter(0, lhs_shape, "lhs"); + auto lhs_mat_arg = lhs_arg; + if (transpose) { + lhs_mat_arg = builder.Transpose(lhs_mat_arg, {1, 0}); + } + auto rhs_arg = builder.Parameter(1, rhs_shape, "rhs"); + auto result = builder.Dot(lhs_mat_arg, rhs_arg); + Array2D expected; + if (add_lhs) { + result = builder.Add(result, lhs_arg); + if (transpose) { + expected = Array2D( + {{bfloat16(47), bfloat16(52)}, {bfloat16(71), bfloat16(78)}}); + } else { + expected = Array2D( + {{bfloat16(35), bfloat16(39)}, {bfloat16(81), bfloat16(89)}}); + } + } else { + result = builder.Add(result, rhs_arg); + if (transpose) { + expected = Array2D( + {{bfloat16(56), bfloat16(61)}, {bfloat16(80), bfloat16(87)}}); + } else { + expected = Array2D( + {{bfloat16(44), bfloat16(48)}, {bfloat16(90), bfloat16(98)}}); + } + } + + ComputeAndCompareR2(&builder, expected, + {lhs_handle.get(), rhs_handle.get()}, + ErrorSpec(1e-6)); +} + +INSTANTIATE_TEST_CASE_P(MatOpsDotAddTestInstances, MatOpsDotAddTest_bf16, + ::testing::Combine(::testing::Bool(), ::testing::Bool(), + ::testing::Bool())); + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc index 22d2b917a1d55f4f453e21c2d8fea38e32ff796b..0a603f4954badd12adf3144320789a5edd0d9c6c 
100644 --- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc +++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_runner.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" #include "tensorflow/compiler/xla/tests/hlo_test_base.h" @@ -35,6 +36,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" @@ -76,8 +78,11 @@ class MultiOutputFusionTest : public HloTestBase { elem_shape2, HloOpcode::kAdd, broadcast, param1)); HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary( elem_shape2, HloOpcode::kSubtract, param1, broadcast)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(1); + dot_dnums.add_rhs_contracting_dimensions(0); HloInstruction* dot = builder.AddInstruction( - HloInstruction::CreateBinary(elem_shape2, HloOpcode::kDot, sub, add2)); + HloInstruction::CreateDot(elem_shape2, sub, add2, dot_dnums)); auto computation = hlo_module->AddEntryComputation(builder.Build(dot)); if (manual_fusion) { @@ -96,14 +101,13 @@ class MultiOutputFusionTest : public HloTestBase { nullptr); } - Literal input; - input.PopulateWithValue(2.5f, {size, size}); - auto p1 = TransferToDevice(input); - auto p0 = TransferToDevice(*Literal::CreateR0(-9.0f)); + Literal arg1(ShapeUtil::MakeShape(F32, {size, size})); + arg1.PopulateWithValue(2.5f); - Literal expect; - 
expect.PopulateWithValue(size * 1.5f * 3.5f, {size, size}); - auto actual = ExecuteAndTransfer(std::move(hlo_module), {p0, p1}); + Literal expect(ShapeUtil::MakeShape(F32, {size, size})); + expect.PopulateWithValue(size * 1.5f * 3.5f); + auto actual = ExecuteAndTransfer( + std::move(hlo_module), {Literal::CreateR0(-9.0f).get(), &arg1}); LiteralTestUtil::ExpectNear(expect, *actual, error_spec_); } @@ -133,8 +137,11 @@ class MultiOutputFusionTest : public HloTestBase { HloInstruction* reshape = builder.AddInstruction(HloInstruction::CreateReshape( ShapeUtil::MakeShape(F32, {size, 1}), add)); - HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary( - ShapeUtil::MakeShape(F32, {1}), HloOpcode::kDot, sub, reshape)); + DotDimensionNumbers dot_dnums; + dot_dnums.add_lhs_contracting_dimensions(0); + dot_dnums.add_rhs_contracting_dimensions(0); + HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot( + ShapeUtil::MakeShape(F32, {1}), sub, reshape, dot_dnums)); auto computation = hlo_module->AddEntryComputation(builder.Build(dot)); if (manual_fusion) { @@ -154,14 +161,13 @@ class MultiOutputFusionTest : public HloTestBase { nullptr); } - Literal input0, input1; - input0.PopulateWithValue(2.5f, {size}); - input1.PopulateWithValue(1, {size}); - auto p0 = TransferToDevice(input0); - auto p1 = TransferToDevice(input1); + Literal input0(ShapeUtil::MakeShape(F32, {size})); + input0.PopulateWithValue(2.5f); + Literal input1(ShapeUtil::MakeShape(F64, {size})); + input1.PopulateWithValue(1.); - Literal expect = *Literal::CreateR1({size * 1.5f * 3.5f}); - auto actual = ExecuteAndTransfer(std::move(hlo_module), {p0, p1}); + Literal expect = std::move(*Literal::CreateR1({size * 1.5f * 3.5f})); + auto actual = ExecuteAndTransfer(std::move(hlo_module), {&input0, &input1}); LiteralTestUtil::ExpectNear(expect, *actual, error_spec_); } }; @@ -172,5 +178,38 @@ XLA_TEST_F(MultiOutputFusionTest, 2DFusionSize129) { RunTest2D(true, 129); } 
XLA_TEST_F(MultiOutputFusionTest, DiffentTypesNoFusion) { RunTest1D(false, 8); } XLA_TEST_F(MultiOutputFusionTest, DiffentTypesFusion) { RunTest1D(true, 8); } +XLA_TEST_F(MultiOutputFusionTest, FusionNodeIsRoot) { + const char* testcase = R"( + HloModule m + + fused_computation { + x.param_0 = (((s32[]), f32[]), (f32[], s32[])) parameter(0) + gte.3 = ((s32[]), f32[]) get-tuple-element(x.param_0), index=0 + gte.2 = (s32[]) get-tuple-element(gte.3), index=0 + gte.4 = s32[] get-tuple-element(gte.2), index=0 + copy = s32[] copy(gte.4) + ROOT tuple = (s32[]) tuple(copy) + } + + ENTRY thing.v3 { + x = (((s32[]), f32[]), (f32[], s32[])) parameter(0) + ROOT fusion = (s32[]) fusion(x), kind=kLoop, calls=fused_computation + } + )"; + auto module = + HloRunner::CreateModuleFromString(testcase, GetDebugOptionsForTest()) + .ValueOrDie(); + auto param = Literal::MakeTupleOwned( + Literal::MakeTupleOwned( + Literal::MakeTupleOwned(Literal::CreateR0(42)), + Literal::CreateR0(1.0)), + Literal::MakeTupleOwned(Literal::CreateR0(3.0), + Literal::CreateR0(4))); + TF_ASSERT_OK_AND_ASSIGN(auto result, + Execute(std::move(module), {param.get()})); + EXPECT_TRUE(LiteralTestUtil::Equal( + *result, *Literal::MakeTupleOwned(Literal::CreateR0(42)))); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/pad_test.cc b/tensorflow/compiler/xla/tests/pad_test.cc index 3fd83a4c3b104831f03366339fb7b8b5d816a3f7..8cef8dd34dc7b16b1e58ded67d6b6a4ba79f20db 100644 --- a/tensorflow/compiler/xla/tests/pad_test.cc +++ b/tensorflow/compiler/xla/tests/pad_test.cc @@ -33,6 +33,14 @@ limitations under the License. namespace xla { namespace { +#ifdef XLA_BACKEND_SUPPORTS_BFLOAT16 +// Tests both F32 and BF16. +static std::array use_bfloat16_params{false, true}; +#else +// Only tests F32. 
+static std::array use_bfloat16_params{false}; +#endif + class PadTest : public ClientLibraryTestBase { protected: PadTest() { @@ -61,8 +69,22 @@ class PadTest : public ClientLibraryTestBase { PaddingConfig r4_padding_on_dim0_dim1_; }; +class PadTestFloat : public PadTest, + public ::testing::WithParamInterface { + protected: + PadTestFloat() { set_use_bfloat16(GetParam()); } + + ErrorSpec DefaultErrorSpec() const { + if (use_bfloat16()) { + return ErrorSpec(1e-3, 1e-3); + } else { + return ErrorSpec(1e-5, 1e-5); + } + } +}; + // Tests a Pad() with a zero-element input and output. -XLA_TEST_F(PadTest, Pad1DS0ToS0Array) { +XLA_TEST_P(PadTestFloat, Pad1DS0ToS0Array) { ComputationBuilder b(client_, TestName()); // Set up the padding configuration {low: 0, high: 0, interior: 0}. PaddingConfig padding_config; @@ -71,12 +93,13 @@ XLA_TEST_F(PadTest, Pad1DS0ToS0Array) { dimension->set_edge_padding_high(0); dimension->set_interior_padding(0); - b.Pad(b.ConstantR1({}), b.ConstantR0(0.1), padding_config); - ComputeAndCompareR1(&b, {}, {}, ErrorSpec(0.0001)); + b.Pad(AddParam(*Literal::CreateR1({}), &b), + AddParam(*Literal::CreateR0(0.1), &b), padding_config); + ComputeAndCompareR1(&b, {}, {}, DefaultErrorSpec()); } // Tests a Pad() with a zero-element input but a non-zero-element output. -XLA_TEST_F(PadTest, Pad1DS0ToS5Array) { +XLA_TEST_P(PadTestFloat, Pad1DS0ToS5Array) { ComputationBuilder b(client_, TestName()); // Set up the padding configuration {low: 3, high: 0, interior: 1}. 
PaddingConfig padding_config; @@ -85,12 +108,13 @@ XLA_TEST_F(PadTest, Pad1DS0ToS5Array) { dimension->set_edge_padding_high(4); dimension->set_interior_padding(7); - b.Pad(b.ConstantR1({}), b.ConstantR0(0.1), padding_config); + b.Pad(AddParam(*Literal::CreateR1({}), &b), + AddParam(*Literal::CreateR0(0.1), &b), padding_config); ComputeAndCompareR1(&b, std::vector(5, 0.1), {}, - ErrorSpec(0.0001)); + DefaultErrorSpec()); } -XLA_TEST_F(PadTest, Pad1DS3Array) { +XLA_TEST_P(PadTestFloat, Pad1DS3Array) { ComputationBuilder b(client_, TestName()); // Set up the padding configuration {low: 3, high: 0, interior: 1}. PaddingConfig padding_config; @@ -99,21 +123,21 @@ XLA_TEST_F(PadTest, Pad1DS3Array) { dimension->set_edge_padding_high(0); dimension->set_interior_padding(1); - b.Pad(b.ConstantR1({1, 2, 3}), b.ConstantR0(0.1), - padding_config); + b.Pad(AddParam(*Literal::CreateR1({1, 2, 3}), &b), + AddParam(*Literal::CreateR0(0.1), &b), padding_config); std::vector expected({0.1, 0.1, 0.1, 1, 0.1, 2, 0.1, 3}); - ComputeAndCompareR1(&b, expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareR1(&b, expected, {}, DefaultErrorSpec()); } -XLA_TEST_F(PadTest, Pad4D_2x0x3x2_FloatArray) { +XLA_TEST_P(PadTestFloat, Pad4D_2x0x3x2_FloatArray) { ComputationBuilder b(client_, TestName()); - b.Pad(b.ConstantR4FromArray4D(Array4D(2, 0, 3, 2)), - b.ConstantR0(1.5), r4_padding_on_dim0_dim1_); + b.Pad(AddParam(Array4D(2, 0, 3, 2), &b), + AddParam(*Literal::CreateR0(1.5), &b), r4_padding_on_dim0_dim1_); ComputeAndCompareR4(&b, Array4D(5, 2, 3, 2, 1.5f), {}, - ErrorSpec(0.0001)); + DefaultErrorSpec()); } -TEST_F(PadTest, Pad4DFloat_1x1x3x2_Array) { +TEST_P(PadTestFloat, Pad4DFloat_1x1x3x2_Array) { ComputationBuilder b(client_, TestName()); auto input = MakeUnique>(1, 1, 3, 2); Array2D input_xy({ @@ -123,7 +147,7 @@ TEST_F(PadTest, Pad4DFloat_1x1x3x2_Array) { }); input->FillWithYX(input_xy); - b.Pad(b.ConstantR4FromArray4D(*input), b.ConstantR0(1.5), + b.Pad(AddParam(*input, &b), 
AddParam(*Literal::CreateR0(1.5), &b), r4_padding_on_dim0_dim1_); auto expected = MakeUnique>(2, 3, 3, 2); @@ -134,15 +158,15 @@ TEST_F(PadTest, Pad4DFloat_1x1x3x2_Array) { (*expected)(1, 0, 1, 1) = 4.0f; (*expected)(1, 0, 2, 0) = 5.0f; (*expected)(1, 0, 2, 1) = 6.0f; - ComputeAndCompareR4(&b, *expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareR4(&b, *expected, {}, DefaultErrorSpec()); } -TEST_F(PadTest, Pad4DFloatArrayWithInteriorPadding) { +TEST_P(PadTestFloat, Pad4DFloatArrayWithInteriorPadding) { ComputationBuilder b(client_, TestName()); const float pad_value = 1.5f; Array4D input(3, 2, 1, 1, {1, 2, 3, 4, 5, 6}); - b.Pad(b.ConstantR4FromArray4D(input), b.ConstantR0(pad_value), + b.Pad(AddParam(input, &b), AddParam(*Literal::CreateR0(pad_value), &b), r4_padding_on_dim0_dim1_); auto expected = MakeUnique>(8, 5, 1, 1); @@ -156,7 +180,7 @@ TEST_F(PadTest, Pad4DFloatArrayWithInteriorPadding) { ComputeAndCompareR4(&b, *expected, {}, ErrorSpec(0.0001)); } -TEST_F(PadTest, Pad4DFloatArrayMinorFirstSmall) { +TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstSmall) { ComputationBuilder b(client_, TestName()); PaddingConfig padding_config; @@ -184,7 +208,8 @@ TEST_F(PadTest, Pad4DFloatArrayMinorFirstSmall) { auto input = Literal::CreateR4FromArray4D(input_array); input = input->Relayout(layout); - b.Pad(b.ConstantLiteral(*input), b.ConstantR0(pad_value), padding_config); + b.Pad(AddParam(*input, &b), + AddParam(*Literal::CreateR0(pad_value), &b), padding_config); Array4D expected_array(1, 1, 5, 8); expected_array.Fill(pad_value); @@ -197,7 +222,7 @@ TEST_F(PadTest, Pad4DFloatArrayMinorFirstSmall) { ComputeAndCompareR4(&b, expected_array, {}, ErrorSpec(0.0001)); } -XLA_TEST_F(PadTest, Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) { +XLA_TEST_P(PadTestFloat, Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) { ComputationBuilder b(client_, TestName()); PaddingConfig padding_config; @@ -229,7 +254,8 @@ XLA_TEST_F(PadTest, 
Pad4DFloatArrayMinorFirstNonTrivialMinorDimensions) { auto input = Literal::CreateR4FromArray4D(input_array); input = input->Relayout(layout); - b.Pad(b.ConstantLiteral(*input), b.ConstantR0(pad_value), padding_config); + b.Pad(AddParam(*input, &b), + AddParam(*Literal::CreateR0(pad_value), &b), padding_config); Array4D expected_array(1, 25, 17, 11); expected_array.Fill(pad_value); @@ -249,7 +275,7 @@ XLA_TEST_F(PadTest, Pad4DU8Array) { }); input->FillWithYX(input_xy); - b.Pad(b.ConstantR4FromArray4D(*input), b.ConstantR0(35), + b.Pad(AddParam(*input, &b), b.ConstantR0(35), r4_padding_on_dim0_dim1_); auto expected = MakeUnique>(2, 3, 3, 2); @@ -277,8 +303,7 @@ XLA_TEST_F(PadTest, Pad4DPredArray) { auto ones = MakeUnique>(2, 3, 3, 2); zeros->Fill(0); ones->Fill(1); - b.Select(padded, b.ConstantR4FromArray4D(*ones), - b.ConstantR4FromArray4D(*zeros)); + b.Select(padded, AddParam(*ones, &b), AddParam(*zeros, &b)); auto expected = MakeUnique>(2, 3, 3, 2); expected->Fill(0); @@ -291,10 +316,12 @@ XLA_TEST_F(PadTest, Pad4DPredArray) { ComputeAndCompareR4(&b, *expected, {}); } -XLA_TEST_F(PadTest, Large2DPad) { +XLA_TEST_P(PadTestFloat, Large2DPad) { ComputationBuilder b(client_, TestName()); - auto input = b.Parameter(0, ShapeUtil::MakeShape(F32, {4, 4}), "input"); + auto ones = MakeUnique>(4, 4); + ones->Fill(1.0f); + auto input = AddParam(*ones, &b); PaddingConfig padding_config = MakeNoPaddingConfig(2); for (int dim : {0, 1}) { padding_config.mutable_dimensions(dim)->set_edge_padding_low( @@ -302,25 +329,22 @@ XLA_TEST_F(PadTest, Large2DPad) { padding_config.mutable_dimensions(dim)->set_edge_padding_high(58 + 100 * dim); } - auto padded = b.Pad(input, b.ConstantR0(0.0f), padding_config); - - auto ones = MakeUnique>(4, 4); - ones->Fill(1.0f); - auto input_literal = Literal::CreateR2FromArray2D(*ones); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + auto padded = b.Pad(input, AddParam(*Literal::CreateR0(0.0f), &b), + 
padding_config); auto expected = ReferenceUtil::PadArray2D(*ones, padding_config, 0.0f); - ComputeAndCompareR2(&b, *expected, {input_data.get()}); + ComputeAndCompareR2(&b, *expected, {}, DefaultErrorSpec()); } -XLA_TEST_F(PadTest, AllTypes2DPad) { +XLA_TEST_P(PadTestFloat, AllTypes2DPad) { ComputationBuilder b(client_, TestName()); constexpr int64 in_rows = 35; constexpr int64 in_cols = 35; - auto input = - b.Parameter(0, ShapeUtil::MakeShape(F32, {in_rows, in_cols}), "input"); + auto operand = MakeUnique>(in_rows, in_cols); + operand->FillUnique(0.0f); + auto input = AddParam(*operand, &b); + PaddingConfig padding_config = MakeNoPaddingConfig(2); padding_config.mutable_dimensions(0)->set_edge_padding_low(7); padding_config.mutable_dimensions(0)->set_edge_padding_high(5); @@ -328,20 +352,14 @@ XLA_TEST_F(PadTest, AllTypes2DPad) { padding_config.mutable_dimensions(1)->set_edge_padding_low(6); padding_config.mutable_dimensions(1)->set_edge_padding_high(4); padding_config.mutable_dimensions(1)->set_interior_padding(2); - auto padded = b.Pad(input, b.ConstantR0(3.14f), padding_config); - - auto operand = MakeUnique>(in_rows, in_cols); - operand->FillUnique(0.0f); - auto input_literal = Literal::CreateR2FromArray2D(*operand); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + auto padded = b.Pad(input, AddParam(*Literal::CreateR0(3.14f), &b), + padding_config); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 3.14f); - ComputeAndCompareR2(&b, *expected, {input_data.get()}, - ErrorSpec{0.0001}); + ComputeAndCompareR2(&b, *expected, {}, DefaultErrorSpec()); } -XLA_TEST_F(PadTest, High2DPad) { +XLA_TEST_P(PadTestFloat, High2DPad) { ComputationBuilder b(client_, TestName()); constexpr int64 in_rows = 129; @@ -349,8 +367,9 @@ XLA_TEST_F(PadTest, High2DPad) { constexpr int64 low_padding = 0; int64 high_padding[2] = {5, 7}; constexpr int64 interior_padding = 0; - auto input = - b.Parameter(0, 
ShapeUtil::MakeShape(F32, {in_rows, in_cols}), "input"); + auto operand = MakeUnique>(in_rows, in_cols); + operand->FillUnique(1.0f); + auto input = AddParam(*operand, &b); PaddingConfig padding_config = MakeNoPaddingConfig(2); for (int dim : {0, 1}) { padding_config.mutable_dimensions(dim)->set_edge_padding_low(low_padding); @@ -359,20 +378,15 @@ XLA_TEST_F(PadTest, High2DPad) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding); } - auto padded = b.Pad(input, b.ConstantR0(2.718f), padding_config); + auto padded = b.Pad(input, AddParam(*Literal::CreateR0(2.718f), &b), + padding_config); - auto operand = MakeUnique>(in_rows, in_cols); - operand->FillUnique(1.0f); - auto input_literal = Literal::CreateR2FromArray2D(*operand); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - ComputeAndCompareR2(&b, *expected, {input_data.get()}, - ErrorSpec(0.0001)); + ComputeAndCompareR2(&b, *expected, {}, DefaultErrorSpec()); } -XLA_TEST_F(PadTest, NegativePadding2D) { +XLA_TEST_P(PadTestFloat, NegativePadding2D) { ComputationBuilder b(client_, TestName()); constexpr int64 in_rows = 129; @@ -380,8 +394,9 @@ XLA_TEST_F(PadTest, NegativePadding2D) { int64 low_padding[2] = {-1, -2}; int64 high_padding[2] = {-3, 4}; constexpr int64 interior_padding = 0; - auto input = - b.Parameter(0, ShapeUtil::MakeShape(F32, {in_rows, in_cols}), "input"); + auto operand = MakeUnique>(in_rows, in_cols); + operand->FillUnique(1.0f); + auto input = AddParam(*operand, &b); PaddingConfig padding_config = MakeNoPaddingConfig(2); for (int dim : {0, 1}) { padding_config.mutable_dimensions(dim)->set_edge_padding_low( @@ -391,20 +406,15 @@ XLA_TEST_F(PadTest, NegativePadding2D) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding); } - auto padded = b.Pad(input, b.ConstantR0(2.718f), padding_config); + auto padded = b.Pad(input, 
AddParam(*Literal::CreateR0(2.718f), &b), + padding_config); - auto operand = MakeUnique>(in_rows, in_cols); - operand->FillUnique(1.0f); - auto input_literal = Literal::CreateR2FromArray2D(*operand); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - ComputeAndCompareR2(&b, *expected, {input_data.get()}, - ErrorSpec(0.0001)); + ComputeAndCompareR2(&b, *expected, {}, DefaultErrorSpec()); } -XLA_TEST_F(PadTest, NegativeAndInteriorPadding2D) { +XLA_TEST_P(PadTestFloat, NegativeAndInteriorPadding2D) { ComputationBuilder b(client_, TestName()); constexpr int64 in_rows = 8; @@ -412,8 +422,9 @@ XLA_TEST_F(PadTest, NegativeAndInteriorPadding2D) { int64 low_padding[2] = {4, -1}; int64 high_padding[2] = {-2, -4}; int64 interior_padding[2] = {1, 2}; - auto input = - b.Parameter(0, ShapeUtil::MakeShape(F32, {in_rows, in_cols}), "input"); + auto operand = MakeUnique>(in_rows, in_cols); + operand->FillUnique(1.0f); + auto input = AddParam(*operand, &b); PaddingConfig padding_config = MakeNoPaddingConfig(2); for (int dim : {0, 1}) { padding_config.mutable_dimensions(dim)->set_edge_padding_low( @@ -423,44 +434,40 @@ XLA_TEST_F(PadTest, NegativeAndInteriorPadding2D) { padding_config.mutable_dimensions(dim)->set_interior_padding( interior_padding[dim]); } - auto padded = b.Pad(input, b.ConstantR0(2.718f), padding_config); + auto padded = b.Pad(input, AddParam(*Literal::CreateR0(2.718f), &b), + padding_config); - auto operand = MakeUnique>(in_rows, in_cols); - operand->FillUnique(1.0f); - auto input_literal = Literal::CreateR2FromArray2D(*operand); auto expected = ReferenceUtil::PadArray2D(*operand, padding_config, 2.718f); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - ComputeAndCompareR2(&b, *expected, {input_data.get()}, - ErrorSpec(0.0001)); + ComputeAndCompareR2(&b, *expected, {}, DefaultErrorSpec()); } // 
Regression test for b/31827337. -XLA_TEST_F(PadTest, ReducePad) { +XLA_TEST_P(PadTestFloat, ReducePad) { ComputationBuilder b(client_, TestName()); - auto input = b.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2, 2, 2}), "input"); + auto ones = MakeUnique>(2, 2, 2, 2); + ones->Fill(1.0); + auto input = AddParam(*ones, &b); - Computation add_f32 = CreateScalarAddComputation(F32, &b); - auto reduce = b.Reduce(input, b.ConstantR0(0.0), add_f32, {0}); + Computation add = CreateScalarAddComputation(FloatType(), &b); + auto reduce = + b.Reduce(input, AddParam(*Literal::CreateR0(0.0), &b), add, {0}); PaddingConfig padding_config = MakeNoPaddingConfig(3); padding_config.mutable_dimensions(0)->set_edge_padding_low(1); padding_config.mutable_dimensions(0)->set_edge_padding_high(1); - auto pad = b.Pad(reduce, b.ConstantR0(0.0), padding_config); - - auto ones = MakeUnique>(2, 2, 2, 2); - ones->Fill(1.0); - auto input_literal = Literal::CreateR4FromArray4D(*ones); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + auto padded = b.Pad(reduce, AddParam(*Literal::CreateR0(0.0f), &b), + padding_config); Array3D expected({{{0.0, 0.0}, {0.0, 0.0}}, {{2.0, 2.0}, {2.0, 2.0}}, {{2.0, 2.0}, {2.0, 2.0}}, {{0.0, 0.0}, {0.0, 0.0}}}); - ComputeAndCompareR3(&b, expected, {input_data.get()}); + ComputeAndCompareR3(&b, expected, {}, DefaultErrorSpec()); } +INSTANTIATE_TEST_CASE_P(PadTestFloatInstantiation, PadTestFloat, + ::testing::ValuesIn(use_bfloat16_params)); + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc index b7f62b8aa167b2d9ef1bb2fa83af5aaeda1d6652..bb7e800df84121f2045141bc366c34b94ba694ea 100644 --- a/tensorflow/compiler/xla/tests/params_test.cc +++ b/tensorflow/compiler/xla/tests/params_test.cc @@ -334,10 +334,109 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU( ComputeAndCompareTuple(&builder, *Literal::MakeTuple(ptrs), param_data); } 
+// Test large number of parameters flowing into a while-loop. +// Construct conceptually the following HLO graph: +// +// p0 = parameter(0) +// p1 = parameter(1) +// ... +// pN = parameter(N) +// result = while (false) { +// p0 += (1, 1); +// p1 += (1, 1); +// ... +// pN += (1, 1) +// } +// result = {p0, p1, ..., pN} +// +// TODO(b/70173746): Times out during compilation on GPU and CPU backends as of +// 2017-12-12. +XLA_TEST_F(ParamsTest, + DISABLED_ON_CPU(DISABLED_ON_GPU(ManyParametersIntoWhileLoop))) { + ComputationBuilder builder(client_, TestName()); + + std::vector> param_data_owner; + constexpr int kParamCount = 1900; + std::vector params; + std::vector parameter_shapes; + for (int i = 0; i < kParamCount; ++i) { + std::unique_ptr literal = Literal::CreateR1({i, i}); + param_data_owner.push_back( + std::move(client_->TransferToServer(*literal)).ValueOrDie()); + ComputationDataHandle param = + builder.Parameter(i, literal->shape(), "param"); + params.push_back(param); + parameter_shapes.push_back(literal->shape()); + } + + // Add bool parameter for the loop condition. Use a parameter HLO instead of a + // constant because DCE may eliminate the while-body otherwise. + std::unique_ptr bool_literal = Literal::CreateR0(false); + param_data_owner.push_back( + std::move(client_->TransferToServer(*bool_literal)).ValueOrDie()); + ComputationDataHandle bool_param = + builder.Parameter(kParamCount, bool_literal->shape(), "bool_param"); + params.push_back(bool_param); + parameter_shapes.push_back(bool_literal->shape()); + + auto init = builder.Tuple(params); + + // Create a computation for the condition: while(bool_param). 
+ Shape while_shape = ShapeUtil::MakeTupleShape(parameter_shapes); + Computation condition; + { + ComputationBuilder builder(client_, "condition"); + auto condition_parameter = + builder.Parameter(0, while_shape, "condition_parameter"); + builder.GetTupleElement(condition_parameter, kParamCount); + condition = builder.Build().ConsumeValueOrDie(); + } + + // Create a computation for the body. + // Add {1, 1} to the each tuple element. + Computation body; + { + ComputationBuilder builder(client_, "body"); + auto body_parameter = builder.Parameter(0, while_shape, "body_parameter"); + std::vector updates; + for (int i = 0; i < kParamCount; ++i) { + auto add = builder.Add(builder.GetTupleElement(body_parameter, i), + builder.ConstantR1({1, 1})); + updates.push_back(add); + } + // Add bool parameter. + updates.push_back(builder.GetTupleElement(body_parameter, kParamCount)); + + builder.Tuple(updates); + body = builder.Build().ConsumeValueOrDie(); + } + + auto loop = builder.While(condition, body, init); + + std::vector outputs; + for (int i = 0; i < kParamCount; ++i) { + outputs.push_back(builder.GetTupleElement(loop, i)); + } + builder.Tuple(outputs); + + std::vector param_data; + param_data.reserve(param_data_owner.size()); + for (const std::unique_ptr& data : param_data_owner) { + param_data.push_back(data.get()); + } + + std::vector> elements; + std::vector ptrs; + for (int i = 0; i < kParamCount; ++i) { + elements.push_back(Literal::CreateR1({i, i})); + ptrs.push_back(elements.back().get()); + } + ComputeAndCompareTuple(&builder, *Literal::MakeTuple(ptrs), param_data); +} + #endif -XLA_TEST_F(ParamsTest, - DISABLED_ON_CPU_PARALLEL(TupleOfR1ParametersAddedTogether)) { +XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) { ComputationBuilder builder(client_, TestName()); Shape r1f32_3 = ShapeUtil::MakeShape(F32, {3}); @@ -363,10 +462,8 @@ XLA_TEST_F(ParamsTest, // Verifies that passing a 2x2 with {0, 1} layout returns the same value back // when (transferred to 
the server and) passed through a parameter. XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) { - std::unique_ptr literal = Literal::CreateR2({ - {1, 2}, {3, 4}, - }); - *literal->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); + std::unique_ptr literal = Literal::CreateR2WithLayout( + {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({0, 1})); ComputationBuilder builder(client_, TestName()); builder.Parameter(0, literal->shape(), "input"); @@ -377,10 +474,8 @@ XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) { // As above, but for {1, 0} layout. XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) { - std::unique_ptr literal = Literal::CreateR2({ - {1, 3}, {2, 4}, - }); - *literal->mutable_shape()->mutable_layout() = LayoutUtil::MakeLayout({1, 0}); + std::unique_ptr literal = Literal::CreateR2WithLayout( + {{1, 3}, {2, 4}}, LayoutUtil::MakeLayout({1, 0})); ComputationBuilder builder(client_, TestName()); builder.Parameter(0, literal->shape(), "input"); @@ -401,7 +496,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) { original.layout().minor_to_major().begin(), original.layout().minor_to_major().end()); std::reverse(original_layout.begin(), original_layout.end()); - *literal->mutable_shape()->mutable_layout() = + *literal->mutable_shape_do_not_use()->mutable_layout() = LayoutUtil::MakeLayout(original_layout); ASSERT_EQ(2, literal->Get({0, 1})); } diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index 209f063cc5a34648453d12deae79f261b95dc3b4..6aafb9fa6cb2175c478f0e9a5e16f5808cbea590 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include #include #include "tensorflow/compiler/xla/client/computation_builder.h" @@ -36,65 +37,42 @@ namespace { class PrngTest : public ClientLibraryTestBase { protected: template - void UniformTest(T a, T b, tensorflow::gtl::ArraySlice dims); - void BernoulliTest(float p, tensorflow::gtl::ArraySlice dims); + std::unique_ptr UniformTest(T a, T b, + tensorflow::gtl::ArraySlice dims, + int64 seed = 42); // Computes the χ² statistic of a sample of the discrete uniform distribution // of the given range size. `expected_count` is the number of times each // possible value is expected to be generated. Thus, the sample size is // `range_size * expected_count`. - double UniformChiSquared(int32 range_size, int32 expected_count); + double UniformChiSquared(int32 range_size, int32 expected_count, + int64 seed = 42); }; template -void PrngTest::UniformTest(T a, T b, tensorflow::gtl::ArraySlice dims) { +std::unique_ptr PrngTest::UniformTest( + T a, T b, tensorflow::gtl::ArraySlice dims, int64 seed) { ComputationBuilder builder(client_, TestName()); builder.RngUniform( builder.ConstantR0(a), builder.ConstantR0(b), ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType(), dims)); - SetSeed(42); + SetSeed(seed); auto actual = ExecuteAndTransferOrDie(&builder, /*arguments=*/{}); EXPECT_THAT(dims, ::testing::ElementsAreArray(actual->shape().dimensions())); actual->EachCell([=](tensorflow::gtl::ArraySlice, T value) { EXPECT_LE(a, value); EXPECT_LT(value, b); }); -} - -void PrngTest::BernoulliTest(float p, tensorflow::gtl::ArraySlice dims) { - ComputationBuilder builder(client_, TestName()); - auto shape = ShapeUtil::MakeShape(U32, dims); - builder.RngBernoulli(builder.ConstantR0(p), shape); - - TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build()); - ExecutionOptions execution_options = execution_options_; - execution_options.set_seed(42); - TF_ASSERT_OK_AND_ASSIGN( - auto actual, 
client_->ExecuteAndTransfer(computation, /*arguments=*/{}, - &execution_options)); - EXPECT_THAT(dims, ::testing::ElementsAreArray(actual->shape().dimensions())); - int32 sum = 0; - actual->EachCell( - [&sum](tensorflow::gtl::ArraySlice, uint32 value) { - EXPECT_TRUE(value == 0 || value == 1); - sum += value; - }); - int32 total = ShapeUtil::ElementsIn(shape); - float p_tilde = sum / static_cast(total); - - // Test within expected range using normal approximation. The test uses a - // fixed seed and has a fixed output per p and backend. Using the normal - // approximation as this test is invoked for different `p` and the different - // backends could use different random number generators and produce different - // values. Choose 95% confidence level, so that z_{1-\alpha/2} = 1.96. - float normal_approximation_term = 1.96 * sqrt(p * (1 - p) / total); - EXPECT_GE(p_tilde, p - normal_approximation_term); - EXPECT_LE(p_tilde, p + normal_approximation_term); + return actual; } // Uniform random number generation tests XLA_TEST_F(PrngTest, ScalarU01) { UniformTest(0, 1, {}); } +XLA_TEST_F(PrngTest, ScalarU01limits) { + UniformTest(std::numeric_limits::min(), + std::numeric_limits::max(), {}); +} XLA_TEST_F(PrngTest, ZeroValuesU01) { UniformTest(0, 1, {0}); } XLA_TEST_F(PrngTest, TenValuesU01) { UniformTest(0, 1, {10}); } XLA_TEST_F(PrngTest, TenValuesU37) { UniformTest(3, 7, {10}); } @@ -102,6 +80,56 @@ XLA_TEST_F(PrngTest, ZeroValuesR2) { UniformTest(0, 1, {0, 20}); } XLA_TEST_F(PrngTest, LargeU01) { UniformTest(0, 1, {0x100, 0x100}); } XLA_TEST_F(PrngTest, TwelveValuesU524) { UniformTest(5, 24, {12}); } +// TODO(b/71543667): Fix Rng ops on LLVM backends. +XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL( + DISABLED_ON_CPU(ScalarBF16Tests)))) { + for (int64 seed = 0; seed < 100; ++seed) { + // The largest negative number smaller than zero in bf16 that's not + // denormalized. 
+ int32 low_raw = 0x80800000; + const float low = reinterpret_cast(low_raw); + float high = 0.0f; + UniformTest(static_cast(low), + static_cast(high), {}, /*seed=*/seed); + + // Test odd and even values. + UniformTest(static_cast(32.75), + static_cast(33), {}, /*seed=*/seed); + UniformTest(static_cast(32.50), + static_cast(32.75), {}, /*seed=*/seed); + UniformTest(static_cast(-33.00), + static_cast(-32.75), {}, /*seed=*/seed); + UniformTest(static_cast(-32.75), + static_cast(-32.50), {}, /*seed=*/seed); + } +} + +// TODO(b/71543667): Fix Rng ops on LLVM backends. +XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU( + DISABLED_ON_CPU_PARALLEL(ScalarBF16CountTests)))) { + // There are 3 BF16 values in the range of [32.25, 33): 32.25, 32.5, 32.75, + // they should get similar counts. + bfloat16 low = static_cast(32.25); + bfloat16 high = static_cast(33); + bfloat16 interval = static_cast(0.25); + std::vector counts(static_cast((high - low) / interval), 0); + + constexpr int64 count = 100; + for (int64 seed = 0; seed < count; ++seed) { + auto result = UniformTest(low, high, {}, /*seed=*/seed); + result->Literal::EachCell( + [&](tensorflow::gtl::ArraySlice, bfloat16 value) { + int64 index = static_cast((value - low) / interval); + counts[index]++; + }); + } + // Each bucket should have similar amount of counts. That is, not more than + // 10% of total counts. This mostly tests that we don't fall into a 1:2:2 + // distribution, which yields 20% expected difference. 
+ EXPECT_LT(std::abs(counts[0] - counts[1]), count * 0.1); + EXPECT_LT(std::abs(counts[1] - counts[2]), count * 0.1); +} + namespace { template T Square(T x) { @@ -109,7 +137,8 @@ T Square(T x) { } } // namespace -double PrngTest::UniformChiSquared(int32 range_size, int32 expected_count) { +double PrngTest::UniformChiSquared(int32 range_size, int32 expected_count, + int64 seed) { int32 sample_size = range_size * expected_count; ComputationBuilder builder(client_, TestName()); @@ -117,7 +146,7 @@ double PrngTest::UniformChiSquared(int32 range_size, int32 expected_count) { builder.ConstantR0(range_size), ShapeUtil::MakeShape(S32, {sample_size})); - SetSeed(42); + SetSeed(seed); auto actual = ExecuteAndTransferOrDie(&builder, /*arguments=*/{}); std::vector counts(range_size, 0); actual->EachCell([&counts](tensorflow::gtl::ArraySlice, @@ -181,10 +210,12 @@ XLA_TEST_F(PrngTest, MapUsingRng) { computation, /*arguments=*/{param0_data.get()}, &execution_options)); - EXPECT_EQ(actual->f32s_size(), param0_literal->f32s_size()); - for (int i = 0; i < param0_literal->f32s_size(); ++i) { - EXPECT_GE(actual->f32s(i), param0_literal->f32s(i)); - EXPECT_LT(actual->f32s(i), param0_literal->f32s(i) + 1.0f); + EXPECT_EQ(ShapeUtil::ElementsIn(actual->shape()), + ShapeUtil::ElementsIn(param0_literal->shape())); + for (int i = 0; i < ShapeUtil::ElementsIn(actual->shape()); ++i) { + EXPECT_GE(actual->data()[i], param0_literal->data()[i]); + EXPECT_LT(actual->data()[i], + param0_literal->data()[i] + 1.0f); } } @@ -250,10 +281,6 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) { LiteralTestUtil::ExpectNotEqual(*result5, *result6); } -// Bernoulli random number generation tests -XLA_TEST_F(PrngTest, HundredValuesB10p5) { BernoulliTest(0.5, {100}); } -XLA_TEST_F(PrngTest, HundredValuesB10p1) { BernoulliTest(0.1, {100}); } - XLA_TEST_F(PrngTest, TenValuesN01) { ComputationBuilder builder(client_, TestName()); builder.RngNormal(builder.ConstantR0(0), builder.ConstantR0(1), diff --git 
a/tensorflow/compiler/xla/tests/reduce_hlo_test.cc b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..c0a2c0ca4cb8414e0771a541b9f963f9aedc8376 --- /dev/null +++ b/tensorflow/compiler/xla/tests/reduce_hlo_test.cc @@ -0,0 +1,132 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +// Tests the Reduce HLO in ways that can't be done using the ComputationBuilder +// API. 
+ +namespace xla { +namespace { + +namespace str_util = tensorflow::str_util; +namespace strings = tensorflow::strings; + +struct ReduceLayout { + std::array input_minor_to_major; + std::array output_minor_to_major; + + string ToString() const { + return strings::StrCat(str_util::Join(input_minor_to_major, "x"), "_", + str_util::Join(output_minor_to_major, "x")); + } +}; + +string PrintReduceLayout( + ::testing::TestParamInfo reduce_layout_param) { + return reduce_layout_param.param.ToString(); +} + +void PrintTo(const ReduceLayout& reduce_layout, ::std::ostream* os) { + *os << reduce_layout.ToString(); +} + +class ReduceWithLayoutTest + : public HloTestBase, + public ::testing::WithParamInterface {}; + +StatusOr> GetParsedModule() { + const char* const hlo_string = R"( +HloModule BadReduce + +Sum { + x.1 = f32[] parameter(0) + y.1 = f32[] parameter(1) + ROOT add.1 = f32[] add(x.1, y.1) +} + +ENTRY reduce.1 { + parameter = f32[2,2,2,3]{3,2,1,0} parameter(0) + init_value = f32[] constant(0) + reduce = f32[2,2,3]{2,1,0} reduce(parameter, init_value), dimensions={1}, to_apply=Sum + ROOT copy = f32[2,2,3]{2,1,0} copy(reduce) +} +)"; + + return tools::Parse(hlo_string); +} + +// TODO(b/72454718): XLA:GPU does not support executing code compiled without +// optimizations. 
+XLA_TEST_P(ReduceWithLayoutTest, DISABLED_ON_GPU(Reduce)) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, GetParsedModule()); + HloInstruction* reduce_instruction = + module->entry_computation()->root_instruction()->mutable_operand(0); + ASSERT_EQ(reduce_instruction->opcode(), HloOpcode::kReduce); + + const ReduceLayout& reduce_layout = GetParam(); + + Shape* reduce_output_shape = reduce_instruction->mutable_shape(); + *reduce_output_shape->mutable_layout() = + LayoutUtil::MakeLayout(reduce_layout.output_minor_to_major); + + Shape* reduce_input_shape = + reduce_instruction->mutable_operand(0)->mutable_shape(); + *reduce_input_shape->mutable_layout() = + LayoutUtil::MakeLayout(reduce_layout.input_minor_to_major); + + std::unique_ptr reduce_input = + Literal::CreateR4({{ /*i0=0*/ + {/*i1=0*/ + {-0.246092796, -0.179497838, -0.161181688}, + {-0.151643038, -0.240213156, -0.198156}}, + {/*i1=1*/ + {-0.14222312, -0.162200093, -0.193907976}, + {-0.239411, -0.198166847, -0.172471642}}}, + { /*i0=1*/ + {/*i1=0*/ + {-0.22965157, -0.218723893, -0.129257083}, + {-0.188762426, -0.16123569, -0.181166649}}, + {/*i1=1*/ + {-0.241772294, -0.245131493, -0.160247207}, + {-0.179881215, -0.23383224, -0.121976733}}}}); + + EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec(1e-5))); +} + +INSTANTIATE_TEST_CASE_P(ReduceWithLayoutTest_Instantiation, + ReduceWithLayoutTest, + ::testing::Values( // + ReduceLayout{{3, 2, 1, 0}, {0, 1, 2}}, // + ReduceLayout{{3, 2, 1, 0}, {0, 2, 1}}, // + ReduceLayout{{3, 2, 1, 0}, {1, 2, 0}}, // + ReduceLayout{{3, 2, 1, 0}, {1, 0, 2}}, // + ReduceLayout{{3, 2, 1, 0}, {2, 0, 1}}, // + ReduceLayout{{3, 2, 1, 0}, {2, 1, 0}}, // + ReduceLayout{{3, 1, 2, 0}, {1, 2, 0}}, // + ReduceLayout{{1, 2, 3, 0}, {1, 0, 2}}, // + ReduceLayout{{0, 2, 1, 3}, {2, 0, 1}}), // + PrintReduceLayout); + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/reduce_precision_test.cc b/tensorflow/compiler/xla/tests/reduce_precision_test.cc 
index 4756ba096896806ece8fe35d18c4eaef041b8830..dc7ce3253cee255a7949326fa5b49fc8917432b8 100644 --- a/tensorflow/compiler/xla/tests/reduce_precision_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_precision_test.cc @@ -249,7 +249,9 @@ INSTANTIATE_TEST_CASE_P(ReducePrecisionAccuracyTest, // ReducePrecisionInsertion passes. class ReducePrecisionInsertionTest : public ClientLibraryTestBase {}; -XLA_TEST_F(ReducePrecisionInsertionTest, ReducePrecisionBeforeFusion) { +// The interpreter has no fusion pass, so skip this test. +XLA_TEST_F(ReducePrecisionInsertionTest, + DISABLED_ON_INTERPRETER(ReducePrecisionBeforeFusion)) { ComputationBuilder builder(client_, TestName()); std::unique_ptr a_literal = Literal::CreateR1({1.00001}); @@ -276,7 +278,9 @@ XLA_TEST_F(ReducePrecisionInsertionTest, ReducePrecisionBeforeFusion) { ComputeAndCompareR1(&builder, {0.0f}, {a_data.get()}); } -XLA_TEST_F(ReducePrecisionInsertionTest, ReducePrecisionSkippedAfterFusion) { +// The interpreter has no fusion pass, so skip this test. +XLA_TEST_F(ReducePrecisionInsertionTest, + DISABLED_ON_INTERPRETER(ReducePrecisionSkippedAfterFusion)) { ComputationBuilder builder(client_, TestName()); std::unique_ptr a_literal = Literal::CreateR1({1.00001}); @@ -300,7 +304,9 @@ XLA_TEST_F(ReducePrecisionInsertionTest, ReducePrecisionSkippedAfterFusion) { ComputeAndCompareR1(&builder, {-1.00001f}, {a_data.get()}); } -XLA_TEST_F(ReducePrecisionInsertionTest, ReducePrecisionAddedAfterFusion) { +// The interpreter has no fusion pass, so skip this test. 
+XLA_TEST_F(ReducePrecisionInsertionTest, + DISABLED_ON_INTERPRETER(ReducePrecisionAddedAfterFusion)) { ComputationBuilder builder(client_, TestName()); std::unique_ptr a_literal = Literal::CreateR1({1.00001}); @@ -322,7 +328,9 @@ XLA_TEST_F(ReducePrecisionInsertionTest, ReducePrecisionAddedAfterFusion) { ComputeAndCompareR1(&builder, {-1.0f}, {a_data.get()}); } -XLA_TEST_F(ReducePrecisionInsertionTest, ReducePrecisionSkippedFusionContains) { +// The interpreter has no fusion pass, so skip this test. +XLA_TEST_F(ReducePrecisionInsertionTest, + DISABLED_ON_INTERPRETER(ReducePrecisionSkippedFusionContains)) { ComputationBuilder builder(client_, TestName()); std::unique_ptr a_literal = Literal::CreateR1({1.00001}); @@ -345,7 +353,9 @@ XLA_TEST_F(ReducePrecisionInsertionTest, ReducePrecisionSkippedFusionContains) { ComputeAndCompareR1(&builder, {-1.00001f}, {a_data.get()}); } -XLA_TEST_F(ReducePrecisionInsertionTest, ReducePrecisionAddedFusionContains) { +// The interpreter has no fusion pass, so skip this test. +XLA_TEST_F(ReducePrecisionInsertionTest, + DISABLED_ON_INTERPRETER(ReducePrecisionAddedFusionContains)) { ComputationBuilder builder(client_, TestName()); std::unique_ptr a_literal = Literal::CreateR1({1.00001}); diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index 7bc3185c367f076c9a7d211c9799557e1a91d92f..50d7b5074d201d2292cf90224ef4cd37efdbb8d3 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -143,6 +143,55 @@ class ReduceTest : public ClientLibraryTestBase { ComputeAndCompareR0(&builder, expected, {input_global_data.get()}); } + // Reduce predicate tensor with dimension rows * cols to dimension cols, to + // test the implementation of atomic operations on misaligned small data + // types. 
+ template + void RunR2ToR1PredTest(bool and_reduce, int64 rows, int64 minor = 1, + int64 major = 0) { + ComputationBuilder builder(client_, TestName()); + const Shape input_shape = ShapeUtil::MakeShape(U8, {rows, cols}); + auto input = builder.Parameter(0, input_shape, "input"); + auto input_pred = builder.Eq(input, builder.ConstantR0(1)); + + ComputationDataHandle init_value; + Computation reduce_op; + if (and_reduce) { + init_value = builder.ConstantR0(true); + reduce_op = CreateScalarAndComputation(&builder); + } else { + init_value = builder.ConstantR0(false); + reduce_op = CreateScalarOrComputation(&builder); + } + + builder.Reduce(input_pred, init_value, reduce_op, + /*dimensions_to_reduce=*/{0}); + + Array2D input_data(rows, cols); + input_data.FillRandom(0, 1); + std::unique_ptr input_literal = + Literal::CreateR2FromArray2D(input_data); + input_literal = + input_literal->Relayout(LayoutUtil::MakeLayout({minor, major})); + std::unique_ptr input_global_data = + client_->TransferToServer(*input_literal).ConsumeValueOrDie(); + + std::array expected; + for (int64 colno = 0; colno < cols; ++colno) { + bool column_sum = and_reduce ? true : false; + for (int64 rowno = 0; rowno < rows; ++rowno) { + if (and_reduce) { + column_sum = column_sum && input_data(rowno, colno); + } else { + column_sum = column_sum || input_data(rowno, colno); + } + } + expected[colno] = column_sum; + } + + ComputeAndCompareR1(&builder, expected, {input_global_data.get()}); + } + // Runs an R2 => R0 reduction test with the given number of (rows, cols). void RunR2ToR0Test(int64 rows, int64 cols, int64 minor = 1, int64 major = 0) { ComputationBuilder builder(client_, TestName()); @@ -352,15 +401,13 @@ XLA_TEST_F(ReduceTest, ReduceR2_111x50_01_To_R1) { XLA_TEST_F(ReduceTest, ReduceR2_1024x1024_To_R1) { RunR2ToR1Test(1024, 1024); } XLA_TEST_F(ReduceTest, ReduceR2_1000x1500_To_R1) { RunR2ToR1Test(1000, 1500); } -// TODO(b/34969189): Invalid CAS generated on GPU. 
-XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(AndReduceAllOnesR1_10_Pred)) { +XLA_TEST_F(ReduceTest, AndReduceAllOnesR1_10_Pred) { constexpr int element_count = 10; std::vector input(element_count, 1); RunR1ToR0PredTest(/*and_reduce=*/true, input); } -// TODO(b/34969189): Invalid CAS generated on GPU. -XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(AndReduceOnesAndZerosR1_10_Pred)) { +XLA_TEST_F(ReduceTest, AndReduceOnesAndZerosR1_10_Pred) { constexpr int element_count = 10; std::vector input(element_count); for (int i = 0; i < element_count; ++i) { @@ -369,15 +416,13 @@ XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(AndReduceOnesAndZerosR1_10_Pred)) { RunR1ToR0PredTest(/*and_reduce=*/true, input); } -// TODO(b/34969189): Invalid CAS generated on GPU. -XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(OrReduceAllOnesR1_10_Pred)) { +XLA_TEST_F(ReduceTest, OrReduceAllOnesR1_10_Pred) { constexpr int element_count = 10; std::vector input(element_count, 1); RunR1ToR0PredTest(/*and_reduce=*/false, input); } -// TODO(b/34969189): Invalid CAS generated on GPU. -XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(OrReduceOnesAndZerosR1_10_Pred)) { +XLA_TEST_F(ReduceTest, OrReduceOnesAndZerosR1_10_Pred) { constexpr int element_count = 10; std::vector input(element_count); for (int i = 0; i < element_count; ++i) { @@ -449,6 +494,26 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceElementwiseR2_111x50_To_R1) { ErrorSpec(0.01, 1e-4)); } +// Test that algebraic simplifier does not incorrectly fold a transpose into a +// reduction operation. 
+XLA_TEST_F(ReduceTest, TransposeAndReduceR3_12x111x50_To_R2) { + ComputationBuilder builder(client_, TestName()); + Computation add_f32 = CreateScalarAddComputation(F32, &builder); + const Shape input_shape = ShapeUtil::MakeShape(F32, {12, 111, 50}); + ComputationDataHandle input = builder.Parameter(0, input_shape, "input"); + ComputationDataHandle zero = builder.ConstantR0(0.0); + ComputationDataHandle transpose = + builder.Transpose(input, /*permutation=*/{1, 0, 2}); + ComputationDataHandle reduce = + builder.Reduce(transpose, zero, add_f32, /*dimensions_to_reduce=*/{0}); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, + MakeFakeLiteral(input_shape)); + + ComputeAndCompare(&builder, reduce, {std::move(*input_data)}, + ErrorSpec(0.01, 1e-4)); +} + XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) { const int64 rows = 111, cols = 50; @@ -812,5 +877,12 @@ XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(OperationOnConstantAsInitValue)) { ComputeAndCompareR0(&builder, 4.0f, {b_data.get()}); } +XLA_TEST_F(ReduceTest, ReduceAndPredR2_128x64_To_R1) { + RunR2ToR1PredTest(/*and_reduce=true*/ true, /*rows=128*/ 128); +} +XLA_TEST_F(ReduceTest, ReduceOrPredR2_64x32_To_R1) { + RunR2ToR1PredTest(/*and_reduce=false*/ false, /*rows=64*/ 64); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 0601a1466bd87ab721443e0da725006e2d73e392..b11b64e40a582150d6adf29e915cd70b4bcb982b 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -41,16 +41,40 @@ limitations under the License. namespace xla { namespace { -class ReduceWindowTest : public ClientLibraryTestBase { +#ifdef XLA_BACKEND_SUPPORTS_BFLOAT16 +// Tests both F32 and BF16. +static std::array use_bfloat16_params{false, true}; +#else +// Only tests F32. 
+static std::array use_bfloat16_params{false}; +#endif + +class ReduceWindowTestBase : public ClientLibraryTestBase { public: - ReduceWindowTest() : builder_(client_, TestName()) {} + ErrorSpec DefaultErrorSpec() const { + if (use_bfloat16()) { + return ErrorSpec(1e-1, 5e-2); + } else { + return ErrorSpec(1e-3, 1e-3); + } + } +}; + +class ReduceWindowTest : public ::testing::WithParamInterface, + public ReduceWindowTestBase { + public: + ReduceWindowTest() : builder_(client_, TestName()) { + set_use_bfloat16(GetParam()); + } void ReduceWindowAdd(const ComputationDataHandle& input, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { - builder_.ReduceWindow(input, builder_.ConstantR0(0.0f), - CreateScalarAddComputation(F32, &builder_), + auto init = + CreateConstantFromLiteral(*Literal::CreateR0(0.0f), &builder_); + builder_.ReduceWindow(input, init, + CreateScalarAddComputation(FloatType(), &builder_), window_dimensions, window_strides, padding); } @@ -58,30 +82,32 @@ class ReduceWindowTest : public ClientLibraryTestBase { tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { - builder_.ReduceWindow( - input, builder_.ConstantLiteral(Literal::MinValue(F32)), - CreateScalarMax(), window_dimensions, window_strides, padding); + auto init = CreateConstantFromLiteral(Literal::MinValue(F32), &builder_); + builder_.ReduceWindow(input, init, CreateScalarMax(), window_dimensions, + window_strides, padding); } void ReduceWindowMin(const ComputationDataHandle& input, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { - builder_.ReduceWindow(input, - builder_.ConstantLiteral(Literal::MaxValue(F32)), - CreateScalarMinComputation(F32, &builder_), + auto init = CreateConstantFromLiteral(Literal::MaxValue(F32), &builder_); + builder_.ReduceWindow(input, init, + CreateScalarMinComputation(FloatType(), &builder_), 
window_dimensions, window_strides, padding); } ComputationBuilder builder_; }; -TEST_F(ReduceWindowTest, MismatchedRanksGivesErrorStatus) { - const auto input = builder_.ConstantR1({1, 1, 1, 1}); - const auto init_value = builder_.ConstantR0(0); +TEST_P(ReduceWindowTest, MismatchedRanksGivesErrorStatus) { + const auto input = CreateConstantFromLiteral( + *Literal::CreateR1({1, 1, 1, 1}), &builder_); + const auto init_value = + CreateConstantFromLiteral(*Literal::CreateR0(0), &builder_); TF_ASSERT_OK(builder_.first_error()); builder_.ReduceWindow(input, init_value, - CreateScalarAddComputation(F32, &builder_), + CreateScalarAddComputation(FloatType(), &builder_), /*window_dimensions=*/{1, 2}, /*window_strides=*/{1}, Padding::kValid); ASSERT_EQ(builder_.first_error().code(), tensorflow::error::INVALID_ARGUMENT) @@ -91,88 +117,106 @@ TEST_F(ReduceWindowTest, MismatchedRanksGivesErrorStatus) { } // Regression test for b/68964348. -TEST_F(ReduceWindowTest, R0ReduceWindow) { - auto input = builder_.ConstantR0(42); - auto init = builder_.ConstantR0(1.0); - builder_.ReduceWindow(input, init, CreateScalarAddComputation(F32, &builder_), +TEST_P(ReduceWindowTest, R0ReduceWindow) { + const auto input = + CreateConstantFromLiteral(*Literal::CreateR0(42.0), &builder_); + const auto init = + CreateConstantFromLiteral(*Literal::CreateR0(1.0), &builder_); + builder_.ReduceWindow(input, init, + CreateScalarAddComputation(FloatType(), &builder_), /*window_dimensions=*/{}, /*window_strides=*/{}, Padding::kSame); - ComputeAndCompareR0(&builder_, 43, {}, ErrorSpec(0.00001)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateR0(43.0), {}, + ErrorSpec(0.00001)); } -TEST_F(ReduceWindowTest, Min3In5Stride2) { - const auto input = builder_.ConstantR1({10000, 1000, 100, 10, 1}); +TEST_P(ReduceWindowTest, Min3In5Stride2) { + const auto input = CreateConstantFromLiteral( + *Literal::CreateR1({10000, 1000, 100, 10, 1}), &builder_); ReduceWindowMin(input, {3}, {2}, Padding::kValid); - 
ComputeAndCompareR1(&builder_, {100, 1}, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateR1({100, 1}), {}, + ErrorSpec(0.00001)); } -XLA_TEST_F(ReduceWindowTest, ZeroElementSmall) { - Array4D input_array(1, 0, 2, 1); +TEST_P(ReduceWindowTest, Min3In5Stride1WithSamePadding) { + const auto input = CreateConstantFromLiteral( + *Literal::CreateR1({10000, 1000, 100, 10, 1}), &builder_); + ReduceWindowMin(input, /*window_dimensions=*/{3}, /*window_strides=*/{1}, + Padding::kSame); + ComputeAndCompareLiteral(&builder_, + *Literal::CreateR1({1000, 100, 10, 1, 1}), {}, + ErrorSpec(0.00001)); +} - const auto input = builder_.ConstantR4FromArray4D(input_array); +XLA_TEST_P(ReduceWindowTest, ZeroElementSmall) { + Array4D input_array(1, 0, 2, 1); + const auto input = CreateConstantFromArray(input_array, &builder_); Padding padding = Padding::kSame; ReduceWindowAdd(input, {1, 1, 2, 1}, {1, 1, 1, 1}, padding); auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1}, {1, 1, 1, 1}, padding); - ComputeAndCompareR4(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {}, + DefaultErrorSpec()); } -TEST_F(ReduceWindowTest, NonSquareSmall) { +TEST_P(ReduceWindowTest, NonSquareSmall) { Array4D input_array(1, 2, 2, 1); - input_array.FillRandom(2.f); + input_array.FillRandom(2.f, 2.f); + const auto input = CreateConstantFromArray(input_array, &builder_); - const auto input = builder_.ConstantR4FromArray4D(input_array); Padding padding = Padding::kSame; ReduceWindowAdd(input, {1, 1, 2, 1}, {1, 1, 1, 1}, padding); auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 2, 1}, {1, 1, 1, 1}, padding); - ComputeAndCompareR4(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {}, + DefaultErrorSpec()); } -TEST_F(ReduceWindowTest, MiddleDimsSmall) { +TEST_P(ReduceWindowTest, MiddleDimsSmall) { Array4D 
input_array(1, 3, 3, 1); - input_array.FillRandom(2.f); - - const auto input = builder_.ConstantR4FromArray4D(input_array); + input_array.FillRandom(2.f, 2.f); + const auto input = CreateConstantFromArray(input_array, &builder_); Padding padding = Padding::kSame; ReduceWindowAdd(input, {1, 1, 1, 1}, {1, 2, 2, 1}, padding); auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 1, 1}, {1, 2, 2, 1}, padding); - ComputeAndCompareR4(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {}, + DefaultErrorSpec()); } -TEST_F(ReduceWindowTest, Along2ndMinorDim) { +TEST_P(ReduceWindowTest, Along2ndMinorDim) { Array4D input_array(3, 6, 7, 32); - input_array.FillRandom(2.f); + input_array.FillRandom(2.f, 2.f); + const auto input = CreateConstantFromArray(input_array, &builder_); // The parameters of this reduction mimic feature norm (e.g. LRN). int lrn_diameter = 7; // diameter = 2*radius + 1 --> must be odd - const auto input = builder_.ConstantR4FromArray4D(input_array); Padding padding = Padding::kSame; ReduceWindowAdd(input, {1, 1, lrn_diameter, 1}, {1, 1, 1, 1}, padding); auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, lrn_diameter, 1}, {1, 1, 1, 1}, padding); - ComputeAndCompareR4(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), {}, + DefaultErrorSpec()); } -TEST_F(ReduceWindowTest, AmongMajor2Dims) { +TEST_P(ReduceWindowTest, AmongMajor2Dims) { Array4D input_array(4, 4, 6, 8); input_array.FillWithMinorDimNum(); + const auto input_data_handle = + CreateConstantFromArray(input_array, &builder_); int win_len = 3; int win_stride = 1; Padding padding = Padding::kSame; - const auto input_data_handle = - builder_.ConstantR4FromArray4D(input_array); // Reduce only along the x and y dimensions, according to the win_len. 
ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); @@ -180,18 +224,20 @@ TEST_F(ReduceWindowTest, AmongMajor2Dims) { auto result = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareR4(&builder_, *result, {}, ErrorSpec(1e-3, 1e-3)); + + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, + DefaultErrorSpec()); } -TEST_F(ReduceWindowTest, AmongMajor2DimsMediumSize) { +TEST_P(ReduceWindowTest, AmongMajor2DimsMediumSize) { Array4D input_array(9, 12, 4, 89); - input_array.FillRandom(2.0f); + input_array.FillRandom(2.f, 2.f); int win_len = 3; int win_stride = 2; const auto input_data_handle = - builder_.ConstantR4FromArray4D(input_array); + CreateConstantFromArray(input_array, &builder_); Padding padding = Padding::kSame; // Reduce only along the x and y dimensions, according to the win_len. @@ -202,137 +248,57 @@ TEST_F(ReduceWindowTest, AmongMajor2DimsMediumSize) { input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareR4(&builder_, *result, {}, ErrorSpec(1e-3, 1e-3)); -} - -// TODO(b/32173947): Test support for arbitrary-sized padding. -TEST_F(ReduceWindowTest, DISABLED_AmongMajor2DimsMediumSizeLargePadding) { - Array4D input_array(9, 12, 4, 89); // simulate Dim0IsMinor layout - input_array.FillRandom(2.0f); - - int64 rank = 4; - int win_len = 3; - int win_stride = 2; - - const auto input_data_handle = - builder_.ConstantR4FromArray4D(input_array); - - Padding padding = Padding::kSame; - // Reduce only along the x and y dimensions, according to the win_len. - // Create padding vector with large padding values in the reduction dims. 
- std::vector> low_high_padding; - low_high_padding.resize(rank, {4, 4}); - - builder_.ReduceWindowWithGeneralPadding( - input_data_handle, builder_.ConstantR0(0.0f), - CreateScalarAddComputation(F32, &builder_), {win_len, win_len, 1, 1}, - {win_stride, win_stride, 1, 1}, low_high_padding); - - auto result = ReferenceUtil::ReduceWindow4DAdd( - input_array, 0.0f, {win_len, win_len, 1, 1}, - {win_stride, win_stride, 1, 1}, padding); - - ComputeAndCompareR4(&builder_, *result, {}, ErrorSpec(1e-3, 1e-3)); -} - -XLA_TEST_F(ReduceWindowTest, Add1x1x2In2x1x2) { - Array3D input_array(2, 1, 2); - input_array(0, 0, 0) = 1000; - input_array(0, 0, 1) = 100; - input_array(1, 0, 0) = 10; - input_array(1, 0, 1) = 1; - auto input = builder_.ConstantR3FromArray3D(input_array); - - ReduceWindowAdd(input, {1, 1, 2}, {1, 1, 1}, Padding::kValid); - - Array3D expected(2, 1, 1); - expected(0, 0, 0) = 1100; - expected(1, 0, 0) = 11; - ComputeAndCompareR3(&builder_, expected, {}, ErrorSpec(0.0001)); -} - -XLA_TEST_F(ReduceWindowTest, Add1x1x2In2x1x3Stride1x1x2) { - Array3D input_array(2, 1, 3); - input_array(0, 0, 0) = 100; - input_array(0, 0, 1) = 10; - input_array(0, 0, 2) = 1; - input_array(1, 0, 0) = 500; - input_array(1, 0, 1) = 50; - input_array(1, 0, 2) = 5; - auto input = builder_.ConstantR3FromArray3D(input_array); - - ReduceWindowAdd(input, {1, 1, 2}, {1, 1, 2}, Padding::kValid); - - Array3D expected(2, 1, 1); - expected(0, 0, 0) = 110; - expected(1, 0, 0) = 550; - ComputeAndCompareR3(&builder_, expected, {}, ErrorSpec(0.0001)); -} - -XLA_TEST_F(ReduceWindowTest, Add1x1x2In2x1x3SamePad) { - Array3D input_array(2, 1, 3); - input_array(0, 0, 0) = 100; - input_array(0, 0, 1) = 10; - input_array(0, 0, 2) = 1; - input_array(1, 0, 0) = 500; - input_array(1, 0, 1) = 50; - input_array(1, 0, 2) = 5; - auto input = builder_.ConstantR3FromArray3D(input_array); - - ReduceWindowAdd(input, {1, 1, 2}, {1, 1, 1}, Padding::kSame); - - Array3D expected(2, 1, 3); - expected(0, 0, 0) = 110; - 
expected(0, 0, 1) = 11; - expected(0, 0, 2) = 1; - expected(1, 0, 0) = 550; - expected(1, 0, 1) = 55; - expected(1, 0, 2) = 5; - ComputeAndCompareR3(&builder_, expected, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, + DefaultErrorSpec()); } // Tests a reduction function that is not a simple add/min/max/etc. -XLA_TEST_F(ReduceWindowTest, NonstandardReduceFunction) { +XLA_TEST_P(ReduceWindowTest, NonstandardReduceFunction) { Array4D input_array(1, 2, 2, 1); input_array(0, 0, 0, 0) = 1; input_array(0, 0, 1, 0) = 2; input_array(0, 1, 0, 0) = 3; input_array(0, 1, 1, 0) = 4; + const auto input = CreateConstantFromArray(input_array, &builder_); - const auto input = builder_.ConstantR4FromArray4D(input_array); Padding padding = Padding::kValid; - - const Shape scalar = ShapeUtil::MakeShape(F32, {}); + const Shape scalar = ShapeUtil::MakeShape(FloatType(), {}); auto b = builder_.CreateSubBuilder("unusual"); auto lhs = b->Parameter(0, scalar, "lhs"); auto rhs = b->Parameter(1, scalar, "rhs"); - b->Min(b->Add(lhs, rhs), b->ConstantR0(8.0f)); + b->Min(b->Add(lhs, rhs), + CreateConstantFromLiteral(*Literal::CreateR0(8.0f), b.get())); Computation reduce_fn = b->BuildAndNoteError(); - builder_.ReduceWindow(input, builder_.ConstantR0(3.0f), reduce_fn, - /*window_dimensions=*/{1, 1, 2, 1}, - /*window_strides=*/{1, 1, 1, 1}, padding); + builder_.ReduceWindow( + input, + CreateConstantFromLiteral(*Literal::CreateR0(0.0f), &builder_), + reduce_fn, + /*window_dimensions=*/{1, 1, 2, 1}, + /*window_strides=*/{1, 1, 1, 1}, padding); const auto reduce_func = [](float arg1, float arg2) { return std::min(arg1 + arg2, 8.0f); }; auto expected = - ReferenceUtil::ReduceWindow4DGeneric(input_array, 3.0f, reduce_func, + ReferenceUtil::ReduceWindow4DGeneric(input_array, 0.0f, reduce_func, /*window=*/{1, 1, 2, 1}, /*stride=*/{1, 1, 1, 1}, padding); - ComputeAndCompareR4(&builder_, *expected, {}, ErrorSpec(1e-3, 1e-3)); + 
ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*expected), {}, + DefaultErrorSpec()); } -TEST_F(ReduceWindowTest, R4UnitWindow) { +TEST_P(ReduceWindowTest, R4UnitWindow) { Array4D input_array(13, 12, 8, 15); - input_array.Fill(1.0f); + input_array.FillRandom(2.f, 2.f); std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({0, 3, 2, 1})); - ComputationDataHandle input = - builder_.Parameter(0, input_literal->shape(), "operand"); + ComputationDataHandle input; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "parameter", &builder_, &input); Padding padding = Padding::kSame; ReduceWindowAdd(input, {1, 1, 7, 1}, {1, 4, 1, 1}, padding); @@ -340,15 +306,11 @@ TEST_F(ReduceWindowTest, R4UnitWindow) { auto res = ReferenceUtil::ReduceWindow4DAdd(input_array, 0.0f, {1, 1, 7, 1}, {1, 4, 1, 1}, padding); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - client_->TransferToServer(*input_literal)); - ComputeAndCompareR4(&builder_, *res, {input_data.get()}, - ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + {input_data.get()}, DefaultErrorSpec()); } -XLA_TEST_F(HloTestBase, R6AddMultipleStrides) { - auto b = HloComputation::Builder(TestName()); - +XLA_TEST_P(ReduceWindowTest, R6AddMultipleStrides) { std::vector input_dims(6, 8); auto shape = ShapeUtil::MakeShape(F32, input_dims); @@ -358,56 +320,15 @@ XLA_TEST_F(HloTestBase, R6AddMultipleStrides) { }; TF_EXPECT_OK(arg_literal->Populate(generator)); - auto input = - b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal))); - - auto init_value = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.f))); - - HloComputation::Builder add_computation("add"); - Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); - auto param_lhs = add_computation.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape, "lhs")); - auto param_rhs = 
add_computation.AddInstruction( - HloInstruction::CreateParameter(1, scalar_shape, "rhs")); - add_computation.AddInstruction(HloInstruction::CreateBinary( - scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); - - auto module = CreateNewModule(); - auto add_func = module->AddEmbeddedComputation(add_computation.Build()); - - WindowDimension trivial_dim; - trivial_dim.set_size(1); - trivial_dim.set_stride(1); - trivial_dim.set_padding_low(0); - trivial_dim.set_padding_high(0); - trivial_dim.set_window_dilation(1); - trivial_dim.set_base_dilation(1); - - WindowDimension active_dim; - active_dim.set_size(3); - active_dim.set_stride(1); - active_dim.set_padding_low(0); - active_dim.set_padding_high(0); - active_dim.set_window_dilation(1); - active_dim.set_base_dilation(1); - - Window window; - *window.add_dimensions() = active_dim; - *window.add_dimensions() = trivial_dim; - *window.add_dimensions() = active_dim; - *window.add_dimensions() = active_dim; - *window.add_dimensions() = trivial_dim; - *window.add_dimensions() = trivial_dim; - - // Non-monotonic output layout with minor dims trivial. 
+ const auto input = CreateConstantFromLiteral(*arg_literal, &builder_); + + Padding padding = Padding::kValid; + ReduceWindowAdd(input, {3, 1, 3, 3, 1, 1}, {1, 1, 1, 1, 1, 1}, padding); + std::vector output_layout = {1, 5, 3, 2, 0, 4}; std::vector output_dims = {6, 8, 6, 6, 8, 8}; Shape result_shape = ShapeUtil::MakeShapeWithLayout(F32, output_dims, output_layout); - b.AddInstruction(HloInstruction::CreateReduceWindow( - result_shape, input, init_value, window, add_func)); - std::unique_ptr expected = Literal::CreateFromShape(result_shape); auto out_generator = [&](tensorflow::gtl::ArraySlice indexes) -> float { @@ -415,82 +336,37 @@ XLA_TEST_F(HloTestBase, R6AddMultipleStrides) { }; TF_EXPECT_OK(expected->Populate(out_generator)); - module->AddEntryComputation(b.Build()); - auto actual = ExecuteAndTransfer(std::move(module), {}); - - LiteralTestUtil::ExpectNear(*actual, *expected, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec()); } -XLA_TEST_F(HloTestBase, R6Add) { - auto b = HloComputation::Builder(TestName()); - +XLA_TEST_P(ReduceWindowTest, R6Add) { std::vector input_dims(6, 8); + auto shape = ShapeUtil::MakeShape(F32, input_dims); + std::unique_ptr arg_literal = - Literal::CreateFullWithMonotonicDim0MajorLayout(input_dims, 1.0f); - auto input = - b.AddInstruction(HloInstruction::CreateConstant(std::move(arg_literal))); - - auto init_value = b.AddInstruction( - HloInstruction::CreateConstant(Literal::CreateR0(0.f))); - - HloComputation::Builder add_computation("add"); - Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); - auto param_lhs = add_computation.AddInstruction( - HloInstruction::CreateParameter(0, scalar_shape, "lhs")); - auto param_rhs = add_computation.AddInstruction( - HloInstruction::CreateParameter(1, scalar_shape, "rhs")); - add_computation.AddInstruction(HloInstruction::CreateBinary( - scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); - - auto module = CreateNewModule(); - auto add_func = 
module->AddEmbeddedComputation(add_computation.Build()); - - WindowDimension trivial_dim; - trivial_dim.set_size(1); - trivial_dim.set_stride(1); - trivial_dim.set_padding_low(0); - trivial_dim.set_padding_high(0); - trivial_dim.set_window_dilation(1); - trivial_dim.set_base_dilation(1); - - WindowDimension active_dim; - active_dim.set_size(3); - active_dim.set_stride(1); - active_dim.set_padding_low(0); - active_dim.set_padding_high(0); - active_dim.set_window_dilation(1); - active_dim.set_base_dilation(1); - - Window window; - *window.add_dimensions() = trivial_dim; - *window.add_dimensions() = trivial_dim; - *window.add_dimensions() = active_dim; - *window.add_dimensions() = active_dim; - *window.add_dimensions() = trivial_dim; - *window.add_dimensions() = trivial_dim; - - Shape shape = ShapeUtil::MakeShape(F32, {8, 8, 6, 6, 8, 8}); - b.AddInstruction(HloInstruction::CreateReduceWindow(shape, input, init_value, - window, add_func)); + Literal::CreateFullWithDescendingLayout(input_dims, 1.0f); + + const auto input = CreateConstantFromLiteral(*arg_literal, &builder_); + + Padding padding = Padding::kValid; + ReduceWindowAdd(input, {1, 1, 3, 3, 1, 1}, {1, 1, 1, 1, 1, 1}, padding); std::vector output_dims = {8, 8, 6, 6, 8, 8}; std::unique_ptr expected = - Literal::CreateFullWithMonotonicDim0MajorLayout(output_dims, 9.0f); - - module->AddEntryComputation(b.Build()); - auto actual = ExecuteAndTransfer(std::move(module), {}); + Literal::CreateFullWithDescendingLayout(output_dims, 9.0f); - LiteralTestUtil::ExpectNear(*actual, *expected, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *expected, {}, DefaultErrorSpec()); } -XLA_TEST_F(ReduceWindowTest, R4SecondMinorStride) { +XLA_TEST_P(ReduceWindowTest, R4SecondMinorStride) { Array4D input_array(2, 1, 27, 119); input_array.FillRandom(2.0f); std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); - ComputationDataHandle input = - 
builder_.Parameter(0, input_literal->shape(), "operand"); + ComputationDataHandle input; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "parameter", &builder_, &input); int win_len = 1; int stride = 8; @@ -500,20 +376,19 @@ XLA_TEST_F(ReduceWindowTest, R4SecondMinorStride) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - client_->TransferToServer(*input_literal)); - ComputeAndCompareR4(&builder_, *res, {input_data.get()}, - ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + {input_data.get()}, DefaultErrorSpec()); } -XLA_TEST_F(ReduceWindowTest, R4SecondMinorUnitStride) { +XLA_TEST_P(ReduceWindowTest, R4SecondMinorUnitStride) { Array4D input_array(3, 2, 4, 64); input_array.FillRandom(2.0f); std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); - ComputationDataHandle input = - builder_.Parameter(0, input_literal->shape(), "operand"); + ComputationDataHandle input; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "parameter", &builder_, &input); int win_len = 3; int stride = 1; @@ -523,20 +398,19 @@ XLA_TEST_F(ReduceWindowTest, R4SecondMinorUnitStride) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - client_->TransferToServer(*input_literal)); - ComputeAndCompareR4(&builder_, *res, {input_data.get()}, - ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + {input_data.get()}, DefaultErrorSpec()); } -XLA_TEST_F(ReduceWindowTest, R4SecondMinorWin) { +XLA_TEST_P(ReduceWindowTest, R4SecondMinorWin) { Array4D input_array(1, 3, 12, 200); input_array.FillRandom(2.0f); std::unique_ptr input_literal = 
Literal::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({3, 2, 1, 0})); - ComputationDataHandle input = - builder_.Parameter(0, input_literal->shape(), "operand"); + ComputationDataHandle input; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "parameter", &builder_, &input); int win_len = 8; int stride = 5; @@ -546,13 +420,11 @@ XLA_TEST_F(ReduceWindowTest, R4SecondMinorWin) { auto res = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {1, 1, win_len, 1}, {1, 1, stride, 1}, padding); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_data, - client_->TransferToServer(*input_literal)); - ComputeAndCompareR4(&builder_, *res, {input_data.get()}, - ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + {input_data.get()}, DefaultErrorSpec()); } -TEST_F(ReduceWindowTest, AmongMajor2DimsMultipleMinor) { +TEST_P(ReduceWindowTest, AmongMajor2DimsMultipleMinor) { Array4D input_array(6, 4, 10, 130); input_array.FillRandom(2.0f); @@ -561,7 +433,7 @@ TEST_F(ReduceWindowTest, AmongMajor2DimsMultipleMinor) { Padding padding = Padding::kSame; const auto input_data_handle = - builder_.ConstantR4FromArray4D(input_array); + CreateConstantFromArray(input_array, &builder_); // Reduce only along the x and y dimensions, according to the win_len. 
ReduceWindowAdd(input_data_handle, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); @@ -569,36 +441,59 @@ TEST_F(ReduceWindowTest, AmongMajor2DimsMultipleMinor) { auto result = ReferenceUtil::ReduceWindow4DAdd( input_array, 0.0f, {win_len, win_len, 1, 1}, {win_stride, win_stride, 1, 1}, padding); - ComputeAndCompareR4(&builder_, *result, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*result), {}, + DefaultErrorSpec()); } -XLA_TEST_F(ReduceWindowTest, Add24In1152_NoOverlap) { +XLA_TEST_P(ReduceWindowTest, Add24In1152_NoOverlap) { std::vector input_vector(128 * 9, 1); - const auto input = builder_.ConstantR1(input_vector); + const auto input = CreateConstantFromLiteral( + *Literal::CreateR1(input_vector), &builder_); ReduceWindowAdd(input, {32}, {128}, Padding::kValid); - ComputeAndCompareR1(&builder_, {32, 32, 32, 32, 32, 32, 32, 32, 32}, - {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral( + &builder_, + *Literal::CreateR1({32, 32, 32, 32, 32, 32, 32, 32, 32}), {}, + DefaultErrorSpec()); } -XLA_TEST_F(ReduceWindowTest, Add128In128Stride128) { - const auto input = builder_.ConstantR1( - {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, - 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); +XLA_TEST_P(ReduceWindowTest, Add128In128Stride128) { + std::vector input_vector{ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 
2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + const auto input = CreateConstantFromLiteral( + *Literal::CreateR1(input_vector), &builder_); ReduceWindowAdd(input, {128}, {128}, Padding::kValid); - ComputeAndCompareR1(&builder_, {1088}, {}, ErrorSpec(0.0001)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateR1({1088}), {}, + DefaultErrorSpec()); +} + +XLA_TEST_P(ReduceWindowTest, Add128In128) { + std::vector input_vector{ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}; + const auto input = CreateConstantFromLiteral( + *Literal::CreateR1(input_vector), &builder_); + ReduceWindowAdd(input, {128}, {1}, Padding::kValid); + ComputeAndCompareLiteral(&builder_, *Literal::CreateR1({1088}), {}, + DefaultErrorSpec()); } // Regression test for a bug that appeared in Inception (b/34784899). 
-TEST_F(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) { +TEST_P(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) { Array2D input_array(14, 14, 1.0f); - ComputationDataHandle input = - builder_.Broadcast(builder_.ConstantLiteral(Literal::One(F32)), {14, 14}); + const auto input = CreateConstantFromArray(input_array, &builder_); int win_len = 3; int stride = 1; @@ -608,13 +503,14 @@ TEST_F(ReduceWindowTest, R2ReduceWindowInceptionFromBroadcast) { auto res = ReferenceUtil::ReduceWindow2DAdd( input_array, 0.0f, {win_len, win_len}, {stride, stride}, padding); - ComputeAndCompareR2(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + {}, DefaultErrorSpec()); } -TEST_F(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) { +TEST_P(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) { Array2D input_array(6, 4, 1.0f); - ComputationDataHandle input = - builder_.Broadcast(builder_.ConstantLiteral(Literal::One(F32)), {6, 4}); + ComputationDataHandle input = builder_.Broadcast( + CreateConstantFromLiteral(Literal::One(F32), &builder_), {6, 4}); Padding padding = Padding::kSame; ReduceWindowAdd(input, {4, 2}, {3, 3}, padding); @@ -622,9 +518,13 @@ TEST_F(ReduceWindowTest, R2ReduceWindowNonOverlappingFromBroadcast) { auto res = ReferenceUtil::ReduceWindow2DAdd(input_array, 0.0f, {4, 2}, {3, 3}, padding); - ComputeAndCompareR2(&builder_, *res, {}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&builder_, *Literal::CreateFromArray(*res), + {}, DefaultErrorSpec()); } +INSTANTIATE_TEST_CASE_P(ReduceWindowTestInstance, ReduceWindowTest, + ::testing::ValuesIn(use_bfloat16_params)); + enum Reducer { kAdd, kMax }; struct R4ReduceWindowTestData { @@ -633,35 +533,43 @@ struct R4ReduceWindowTestData { int64 strides[4]; int64 pad_low[4]; int64 pad_high[4]; + int64 layout[4]; Reducer reducer; }; string R4ReduceWindowTestDataToString( - const ::testing::TestParamInfo& data) { + const 
::testing::TestParamInfo< + ::testing::tuple>& data) { + const auto& param = ::testing::get<0>(data.param); string str = tensorflow::strings::StrCat( - "base_bounds_", - tensorflow::str_util::Join(data.param.base_bounds, "x"), // + "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"), // "__window_bounds_", - tensorflow::str_util::Join(data.param.window_bounds, "x"), // - "__strides_", tensorflow::str_util::Join(data.param.strides, "x"), // - "__pad_low_", tensorflow::str_util::Join(data.param.pad_low, "x"), // - "__pad_high_", tensorflow::str_util::Join(data.param.pad_high, "x"), // - (data.param.reducer == kAdd) ? "add" : "max"); - CHECK(data.param.reducer == kAdd || data.param.reducer == kMax); + tensorflow::str_util::Join(param.window_bounds, "x"), // + "__strides_", tensorflow::str_util::Join(param.strides, "x"), // + "__pad_low_", tensorflow::str_util::Join(param.pad_low, "x"), // + "__pad_high_", tensorflow::str_util::Join(param.pad_high, "x"), // + "__layout_", tensorflow::str_util::Join(param.layout, "_"), // + (param.reducer == kAdd) ? "_add" : "_max"); + CHECK(param.reducer == kAdd || param.reducer == kMax); // Test names are not allowed to contain the '-' character. 
std::replace(str.begin(), str.end(), '-', 'n'); + if (::testing::get<1>(data.param)) { + str = tensorflow::strings::StrCat(str, "_bfloat16"); + } return str; } -class R4ReduceWindowTest - : public ClientLibraryTestBase, - public ::testing::WithParamInterface { +class R4ReduceWindowTest : public ReduceWindowTestBase, + public ::testing::WithParamInterface< + ::testing::tuple> { protected: + R4ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); } + void DoIt() { ComputationBuilder b(client_, TestName()); - const auto& param = GetParam(); + const auto& param = ::testing::get<0>(GetParam()); const float kInitValue = 0.0f; @@ -669,24 +577,26 @@ class R4ReduceWindowTest param.base_bounds[2], param.base_bounds[3]); input.FillIota(1); std::unique_ptr input_literal = - Literal::CreateR4FromArray4D(input); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_arg, - client_->TransferToServer(*input_literal)); + Literal::CreateR4FromArray4DWithLayout( + input, LayoutUtil::MakeLayout(param.layout)); + ComputationDataHandle parameter; + auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", + &b, ¶meter); std::vector> padding(4); for (int i = 0; i < 4; ++i) { padding[i] = {param.pad_low[i], param.pad_high[i]}; } - auto parameter = b.Parameter(0, input_literal->shape(), "p0"); - auto pad_value = b.ConstantR0(kInitValue); + auto init_value = + CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); CHECK(param.reducer == kAdd || param.reducer == kMax); auto computation = param.reducer == kAdd - ? CreateScalarAddComputation(F32, &b) - : CreateScalarMaxComputation(F32, &b); + ? 
CreateScalarAddComputation(FloatType(), &b) + : CreateScalarMaxComputation(FloatType(), &b); b.ReduceWindowWithGeneralPadding( /*operand=*/parameter, - /*init_value=*/pad_value, + /*init_value=*/init_value, /*computation=*/computation, /*window_dimensions=*/param.window_bounds, /*window_strides=*/param.strides, @@ -704,8 +614,13 @@ class R4ReduceWindowTest /*window=*/param.window_bounds, /*stride=*/param.strides, /*padding=*/padding); - ComputeAndCompareR4(&b, *expected, {input_arg.get()}, - ErrorSpec(1e-3, 1e-3)); + std::unique_ptr expected_literal = + Literal::CreateFromArray(*expected); + const Shape& expected_shape_with_layout = ShapeUtil::MakeShapeWithLayout( + input_literal->shape().element_type(), + AsInt64Slice(expected_literal->shape().dimensions()), param.layout); + ComputeAndCompareLiteral(&b, *expected_literal, {input_arg.get()}, + DefaultErrorSpec(), &expected_shape_with_layout); } }; @@ -719,6 +634,16 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*strides=*/{1, 1, 1, 1}, /*pad_low=*/{0, 0, 0, 0}, /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 2, 1, 0}, + /*reducer=*/kAdd}, + + // Arbitrary padding (not kSame or kValid). + R4ReduceWindowTestData{/*base_bounds=*/{9, 12, 4, 89}, + /*window_bounds=*/{3, 3, 1, 1}, + /*strides=*/{2, 2, 1, 1}, + /*pad_low=*/{4, 4, 0, 0}, + /*pad_high=*/{4, 4, 0, 0}, + /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, // Zero base bound edge case. @@ -727,6 +652,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*strides=*/{1, 1, 1, 1}, /*pad_low=*/{0, 0, 0, 0}, /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, // With non-1x1 window. @@ -735,6 +661,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*strides=*/{1, 1, 1, 1}, /*pad_low=*/{0, 0, 0, 0}, /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, // With max instead of add. 
@@ -743,6 +670,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*strides=*/{1, 1, 1, 1}, /*pad_low=*/{0, 0, 0, 0}, /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 2, 1, 0}, /*reducer=*/kMax}, // With stride. @@ -751,6 +679,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*strides=*/{2, 4, 1, 1}, /*pad_low=*/{0, 0, 0, 0}, /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, // With low padding. @@ -759,6 +688,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*strides=*/{2, 2, 1, 1}, /*pad_low=*/{3, 2, 0, 0}, /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, // With high padding. @@ -767,6 +697,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*strides=*/{2, 2, 1, 1}, /*pad_low=*/{0, 0, 0, 0}, /*pad_high=*/{2, 3, 0, 0}, + /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, // Window touches both sides of the padding simultaneously. @@ -775,6 +706,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*strides=*/{1, 1, 1, 1}, /*pad_low=*/{1, 1, 0, 0}, /*pad_high=*/{1, 1, 0, 0}, + /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, // Window is entirely in the padding for some positions. @@ -783,6 +715,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*strides=*/{1, 1, 1, 1}, /*pad_low=*/{4, 4, 0, 0}, /*pad_high=*/{4, 4, 0, 0}, + /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, // Zero base bound with padding edge case. @@ -791,6 +724,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*strides=*/{1, 1, 1, 1}, /*pad_low=*/{0, 1, 0, 0}, /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, // With stride, low padding and high padding. @@ -799,6 +733,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*strides=*/{3, 1, 1, 1}, /*pad_low=*/{10, 1, 0, 0}, /*pad_high=*/{2, 3, 0, 0}, + /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, // With second minor dimension == 9. 
@@ -807,6 +742,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*strides=*/{1, 1, 1, 1}, /*pad_low=*/{0, 0, 0, 0}, /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, // With minor dimension == 129. @@ -815,6 +751,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*strides=*/{1, 1, 1, 1}, /*pad_low=*/{0, 0, 0, 0}, /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, // With minor dims reduction and non-overlapped stride. @@ -823,6 +760,7 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*strides=*/{1, 1, 2, 2}, /*pad_low=*/{0, 0, 0, 0}, /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, // With minor dims reduction and overlapped stride. @@ -830,25 +768,29 @@ const R4ReduceWindowTestData kR4ReduceWindowTestValues[] = { /*window_bounds=*/{1, 1, 4, 4}, /*strides=*/{1, 1, 2, 2}, /*pad_low=*/{0, 0, 0, 0}, - /*pad_high=*/{0, 0, 0, 0}, + /*pad_high=*/{1, 0, 0, 0}, + /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, }; -INSTANTIATE_TEST_CASE_P(R4ReduceWindowTestInstantiation, R4ReduceWindowTest, - ::testing::ValuesIn(kR4ReduceWindowTestValues), - R4ReduceWindowTestDataToString); +INSTANTIATE_TEST_CASE_P( + R4ReduceWindowTestInstantiation, R4ReduceWindowTest, + ::testing::Combine(::testing::ValuesIn(kR4ReduceWindowTestValues), + ::testing::ValuesIn(use_bfloat16_params)), + R4ReduceWindowTestDataToString); class R4ReduceWindowLargeTest : public R4ReduceWindowTest {}; -XLA_TEST_P(R4ReduceWindowLargeTest, DoIt) { DoIt(); } +XLA_TEST_P(R4ReduceWindowLargeTest, DISABLED_ON_INTERPRETER(DoIt)) { DoIt(); } // Test cases that are large/slow/failed. 
const R4ReduceWindowTestData kR4ReduceWindowLargeTestValues[] = { R4ReduceWindowTestData{/*base_bounds=*/{28, 28, 256, 128}, - /*window_bounds=*/{3, 3, 1, 1}, - /*strides=*/{1, 1, 1, 1}, + /*window_bounds=*/{3, 3, 1, 5}, + /*strides=*/{1, 1, 1, 5}, /*pad_low=*/{1, 1, 0, 0}, /*pad_high=*/{1, 1, 0, 0}, + /*layout=*/{3, 2, 1, 0}, /*reducer=*/kMax}, R4ReduceWindowTestData{/*base_bounds=*/{112, 112, 64, 128}, @@ -856,13 +798,163 @@ const R4ReduceWindowTestData kR4ReduceWindowLargeTestValues[] = { /*strides=*/{2, 2, 1, 1}, /*pad_low=*/{0, 0, 0, 0}, /*pad_high=*/{1, 1, 0, 0}, + /*layout=*/{3, 2, 1, 0}, /*reducer=*/kAdd}, + + R4ReduceWindowTestData{/*base_bounds=*/{1, 1, 32768 - 3, 2}, + /*window_bounds=*/{1, 1, 4, 1}, + /*strides=*/{1, 1, 4, 1}, + /*pad_low=*/{0, 0, 1, 0}, + /*pad_high=*/{0, 0, 2, 0}, + /*layout=*/{3, 2, 1, 0}, + /*reducer=*/kMax}, }; -INSTANTIATE_TEST_CASE_P(R4ReduceWindowLargeTestInstantiation, - R4ReduceWindowLargeTest, - ::testing::ValuesIn(kR4ReduceWindowLargeTestValues), - R4ReduceWindowTestDataToString); +INSTANTIATE_TEST_CASE_P( + R4ReduceWindowLargeTestInstantiation, R4ReduceWindowLargeTest, + ::testing::Combine(::testing::ValuesIn(kR4ReduceWindowLargeTestValues), + ::testing::ValuesIn(use_bfloat16_params)), + R4ReduceWindowTestDataToString); + +class R4ReduceWindowAnyDimsTest : public R4ReduceWindowTest {}; + +// TODO(b/72234705): Fix the test cases failed on CPU and GPU. 
+XLA_TEST_P(R4ReduceWindowAnyDimsTest, + DISABLED_ON_CPU_PARALLEL(DISABLED_ON_CPU(DISABLED_ON_GPU(DoIt)))) { + DoIt(); +} + +const R4ReduceWindowTestData kR4ReduceWindowAnyDimsTestValues[] = { + R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140}, + /*window_bounds=*/{2, 3, 4, 5}, + /*strides=*/{1, 1, 1, 1}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 2, 1, 0}, + /*reducer=*/kAdd}, + R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140}, + /*window_bounds=*/{2, 3, 1, 1}, + /*strides=*/{1, 1, 1, 1}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{3, 2, 1, 0}, + /*reducer=*/kMax}, + // With 0321 layout. + R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 140}, + /*window_bounds=*/{2, 3, 4, 5}, + /*strides=*/{1, 2, 3, 4}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{0, 3, 2, 1}, + /*reducer=*/kAdd}, + + // With 0123 layout. + R4ReduceWindowTestData{/*base_bounds=*/{4, 6, 17, 23}, + /*window_bounds=*/{2, 3, 7, 9}, + /*strides=*/{1, 2, 5, 8}, + /*pad_low=*/{0, 0, 0, 0}, + /*pad_high=*/{0, 0, 0, 0}, + /*layout=*/{0, 1, 2, 3}, + /*reducer=*/kAdd}, +}; + +INSTANTIATE_TEST_CASE_P( + R4ReduceWindowAnyDimsTestInstantiation, R4ReduceWindowAnyDimsTest, + ::testing::Combine(::testing::ValuesIn(kR4ReduceWindowAnyDimsTestValues), + ::testing::ValuesIn(use_bfloat16_params)), + R4ReduceWindowTestDataToString); + +struct R3ReduceWindowTestData { + int64 base_bounds[3]; + int64 window_bounds[3]; + int64 strides[3]; + int64 layout[3]; + Padding padding; + Reducer reducer; +} kR3TestCases[] = { + {/*base_bounds=*/{2, 1, 2}, /*window_bounds=*/{1, 1, 2}, + /*strides=*/{1, 1, 1}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{4, 3, 3}, /*window_bounds=*/{2, 2, 2}, + /*strides=*/{2, 2, 2}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{4, 3, 3}, /*window_bounds=*/{2, 2, 2}, + /*strides=*/{2, 2, 2}, 
/*layout=*/{2, 1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{6, 21, 3}, /*window_bounds=*/{2, 3, 2}, + /*strides=*/{1, 2, 2}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{10, 21, 129}, /*window_bounds=*/{2, 9, 1}, + /*strides=*/{5, 2, 1}, /*layout=*/{2, 1, 0}, + /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{6, 21, 3}, /*window_bounds=*/{2, 3, 2}, + /*strides=*/{1, 2, 2}, /*layout=*/{0, 1, 2}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + {/*base_bounds=*/{6, 21, 3}, /*window_bounds=*/{2, 3, 2}, + /*strides=*/{1, 2, 2}, /*layout=*/{1, 0, 2}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, +}; + +string R3ReduceWindowTestDataToString( + const ::testing::TestParamInfo< + ::testing::tuple>& data) { + const auto& param = ::testing::get<0>(data.param); + string str = tensorflow::strings::StrCat( + "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"), + "__window_bounds_", tensorflow::str_util::Join(param.window_bounds, "x"), + "__strides_", tensorflow::str_util::Join(param.strides, "x"), + "__padding_", param.padding == Padding::kSame ? "same" : "valid", + "__layout_", param.layout[0], "_", param.layout[1], "_", param.layout[2], + "__reducer_", param.reducer == kAdd ? 
"add" : "max"); + if (::testing::get<1>(data.param)) { + str = tensorflow::strings::StrCat(str, "_bfloat16"); + } + return str; +} + +class R3ReduceWindowTest : public ReduceWindowTestBase, + public ::testing::WithParamInterface< + ::testing::tuple> { + protected: + R3ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); } +}; + +TEST_P(R3ReduceWindowTest, Add) { + ComputationBuilder b(client_, TestName()); + const auto& param = ::testing::get<0>(GetParam()); + CHECK(param.reducer == kAdd); + + const float kInitValue = 0.0f; + Array3D input(param.base_bounds[0], param.base_bounds[1], + param.base_bounds[2], 1.0f); + std::unique_ptr input_literal = + Literal::CreateR3FromArray3DWithLayout( + input, LayoutUtil::MakeLayout(param.layout)); + + ComputationDataHandle parameter; + auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", + &b, ¶meter); + auto init_value = + CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); + b.ReduceWindow(/*operand=*/parameter, + /*init_value=*/init_value, + /*computation=*/CreateScalarAddComputation(FloatType(), &b), + /*window_dimensions=*/param.window_bounds, + /*window_strides=*/param.strides, /*padding=*/param.padding); + + auto expected = ReferenceUtil::ReduceWindow3DAdd( + /*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds, + /*stride=*/param.strides, /*padding=*/param.padding); + + ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected), + {input_arg.get()}, DefaultErrorSpec()); +} + +INSTANTIATE_TEST_CASE_P( + R3ReduceWindowTestInstantiation, R3ReduceWindowTest, + ::testing::Combine(::testing::ValuesIn(kR3TestCases), + ::testing::ValuesIn(use_bfloat16_params)), + R3ReduceWindowTestDataToString); struct R2ReduceWindowTestData { int64 base_bounds[2]; @@ -910,130 +1002,217 @@ struct R2ReduceWindowTestData { }; string R2ReduceWindowTestDataToString( - const ::testing::TestParamInfo& data) { + const ::testing::TestParamInfo< + ::testing::tuple>& data) { + 
const auto& param = ::testing::get<0>(data.param); string str = tensorflow::strings::StrCat( - "base_bounds_", - tensorflow::str_util::Join(data.param.base_bounds, "x"), // + "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"), // "__window_bounds_", - tensorflow::str_util::Join(data.param.window_bounds, "x"), // - "__strides_", tensorflow::str_util::Join(data.param.strides, "x"), // - "__padding_", data.param.padding == Padding::kSame ? "same" : "valid", // - "__layout_", data.param.layout[0], "_", data.param.layout[1], // - "__reducer_", data.param.reducer == kAdd ? "add" : "max"); + tensorflow::str_util::Join(param.window_bounds, "x"), // + "__strides_", tensorflow::str_util::Join(param.strides, "x"), // + "__padding_", param.padding == Padding::kSame ? "same" : "valid", // + "__layout_", param.layout[0], "_", param.layout[1], // + "__reducer_", param.reducer == kAdd ? "add" : "max"); + if (::testing::get<1>(data.param)) { + str = tensorflow::strings::StrCat(str, "_bfloat16"); + } return str; } -class R2ReduceWindowTest - : public ClientLibraryTestBase, - public ::testing::WithParamInterface {}; +class R2ReduceWindowTest : public ReduceWindowTestBase, + public ::testing::WithParamInterface< + ::testing::tuple> { + protected: + R2ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); } -TEST_P(R2ReduceWindowTest, Add) { - ComputationBuilder b(client_, TestName()); - const auto& param = GetParam(); - CHECK(param.reducer == kAdd); + void DoIt() { + ComputationBuilder b(client_, TestName()); + const auto& param = ::testing::get<0>(GetParam()); + CHECK(param.reducer == kAdd); - const float kInitValue = 0.0f; - Array2D input(param.base_bounds[0], param.base_bounds[1], 1.0f); - std::unique_ptr input_literal = - Literal::CreateR2FromArray2DWithLayout( - input, LayoutUtil::MakeLayout(param.layout)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_arg, - client_->TransferToServer(*input_literal)); - b.ReduceWindow(/*operand=*/ - 
b.Parameter(0, input_literal->shape(), "p0"), - /*init_value=*/b.ConstantR0(kInitValue), - /*computation=*/CreateScalarAddComputation(F32, &b), - /*window_dimensions=*/param.window_bounds, - /*window_strides=*/param.strides, /*padding=*/param.padding); + const float kInitValue = 0.0f; + Array2D input(param.base_bounds[0], param.base_bounds[1], 1.0f); + std::unique_ptr input_literal = + Literal::CreateR2FromArray2DWithLayout( + input, LayoutUtil::MakeLayout(param.layout)); + + ComputationDataHandle parameter; + auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", + &b, ¶meter); + auto init_value = + CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); + b.ReduceWindow(/*operand=*/parameter, + /*init_value=*/init_value, + /*computation=*/CreateScalarAddComputation(FloatType(), &b), + /*window_dimensions=*/param.window_bounds, + /*window_strides=*/param.strides, /*padding=*/param.padding); + + auto expected = ReferenceUtil::ReduceWindow2DAdd( + /*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds, + /*stride=*/param.strides, /*padding=*/param.padding); + + ComputeAndCompareLiteral(&b, *Literal::CreateFromArray(*expected), + {input_arg.get()}, DefaultErrorSpec()); + } +}; - auto expected = ReferenceUtil::ReduceWindow2DAdd( - /*operand=*/input, /*init=*/kInitValue, /*window=*/param.window_bounds, - /*stride=*/param.strides, /*padding=*/param.padding); +TEST_P(R2ReduceWindowTest, DoIt) { DoIt(); } + +INSTANTIATE_TEST_CASE_P( + R2ReduceWindowTestInstantiation, R2ReduceWindowTest, + ::testing::Combine(::testing::ValuesIn(kR2TestCases), + ::testing::ValuesIn(use_bfloat16_params)), + R2ReduceWindowTestDataToString); + +class R2ReduceWindowFailingCpuGpuBf16Test : public R2ReduceWindowTest {}; - ComputeAndCompareR2(&b, *expected, {input_arg.get()}, - ErrorSpec(1e-3, 1e-3)); +// TODO(b/72234705): Fix the test cases failed on CPU and GPU. 
+XLA_TEST_P(R2ReduceWindowFailingCpuGpuBf16Test, + DISABLED_ON_CPU_PARALLEL(DISABLED_ON_CPU(DISABLED_ON_GPU(DoIt)))) { + DoIt(); } -INSTANTIATE_TEST_CASE_P(R2ReduceWindowTestInstantiation, R2ReduceWindowTest, - ::testing::ValuesIn(kR2TestCases), - R2ReduceWindowTestDataToString); +const R2ReduceWindowTestData kR2FailingValuesCpuGpuBf16Test[] = { + {/*base_bounds=*/{8, 128}, /*window_bounds=*/{8, 128}, + /*strides=*/{1, 1}, /*layout=*/{1, 0}, + /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, +}; + +INSTANTIATE_TEST_CASE_P( + R2ReduceWindowFailingInstantiation, R2ReduceWindowFailingCpuGpuBf16Test, + ::testing::Combine(::testing::ValuesIn(kR2FailingValuesCpuGpuBf16Test), + ::testing::ValuesIn(use_bfloat16_params)), + R2ReduceWindowTestDataToString); struct R1ReduceWindowTestData { int64 base_bounds[1]; int64 window_bounds[1]; int64 strides[1]; - Padding padding; + int64 pad_low[1]; + int64 pad_high[1]; Reducer reducer; } kR1TestCases[] = { {/*base_bounds=*/{1}, /*window_bounds=*/{1}, /*strides=*/{1}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/{xla::MakePadding({1}, {1}, {1}, Padding::kValid)[0].first}, + /*pad_high=*/{xla::MakePadding({1}, {1}, {1}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{3}, /*window_bounds=*/{3}, /*strides=*/{1}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/{xla::MakePadding({3}, {3}, {1}, Padding::kValid)[0].first}, + /*pad_high=*/{xla::MakePadding({3}, {3}, {1}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{3}, /*window_bounds=*/{2}, /*strides=*/{1}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/{xla::MakePadding({3}, {2}, {1}, Padding::kValid)[0].first}, + /*pad_high=*/{xla::MakePadding({3}, {2}, {1}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{5}, /*window_bounds=*/{1}, /*strides=*/{1}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax}, + 
/*pad_low=*/{xla::MakePadding({5}, {1}, {1}, Padding::kValid)[0].first}, + /*pad_high=*/{xla::MakePadding({5}, {1}, {1}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kMax}, {/*base_bounds=*/{16}, /*window_bounds=*/{4}, /*strides=*/{4}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax}, + /*pad_low=*/{xla::MakePadding({16}, {4}, {4}, Padding::kValid)[0].first}, + /*pad_high=*/{xla::MakePadding({16}, {4}, {4}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kMax}, {/*base_bounds=*/{16}, /*window_bounds=*/{4}, /*strides=*/{3}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/{xla::MakePadding({16}, {4}, {3}, Padding::kValid)[0].first}, + /*pad_high=*/{xla::MakePadding({16}, {4}, {3}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kAdd}, - {/*base_bounds=*/{128 * 2}, /*window_bounds=*/{30}, + {/*base_bounds=*/{128 * 2}, + /*window_bounds=*/{30}, /*strides=*/{27}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, - - {/*base_bounds=*/{128 * 17}, /*window_bounds=*/{7}, + /*pad_low=*/ + {xla::MakePadding({128 * 2}, {30}, {27}, Padding::kValid)[0].first}, + /*pad_high=*/ + {xla::MakePadding({128 * 2}, {30}, {27}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{128 * 17}, + /*window_bounds=*/{7}, /*strides=*/{64}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, - - {/*base_bounds=*/{128 * 2}, /*window_bounds=*/{32}, + /*pad_low=*/ + {xla::MakePadding({128 * 17}, {7}, {64}, Padding::kValid)[0].first}, + /*pad_high=*/ + {xla::MakePadding({128 * 17}, {7}, {64}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{128 * 2}, + /*window_bounds=*/{32}, /*strides=*/{56}, - /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/ + {xla::MakePadding({128 * 2}, {32}, {56}, Padding::kValid)[0].first}, + /*pad_high=*/ + {xla::MakePadding({128 * 2}, {32}, {56}, Padding::kValid)[0].second}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{3}, 
/*window_bounds=*/{2}, /*strides=*/{1}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/{xla::MakePadding({3}, {2}, {1}, Padding::kSame)[0].first}, + /*pad_high=*/{xla::MakePadding({3}, {2}, {1}, Padding::kSame)[0].second}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{5}, /*window_bounds=*/{3}, /*strides=*/{2}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/{xla::MakePadding({5}, {3}, {2}, Padding::kSame)[0].first}, + /*pad_high=*/{xla::MakePadding({5}, {3}, {2}, Padding::kSame)[0].second}, + /*reducer=*/Reducer::kAdd}, {/*base_bounds=*/{16}, /*window_bounds=*/{4}, /*strides=*/{3}, - /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd}, + /*pad_low=*/{xla::MakePadding({16}, {4}, {3}, Padding::kSame)[0].first}, + /*pad_high=*/{xla::MakePadding({16}, {4}, {3}, Padding::kSame)[0].second}, + /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{5}, /*window_bounds=*/{5}, + /*strides=*/{1}, + /*pad_low=*/{0}, + /*pad_high=*/{5}, + /*reducer=*/Reducer::kAdd}, + + {/*base_bounds=*/{5}, /*window_bounds=*/{5}, + /*strides=*/{1}, + /*pad_low=*/{5}, + /*pad_high=*/{0}, + /*reducer=*/Reducer::kAdd}, }; string R1ReduceWindowTestDataToString( - const ::testing::TestParamInfo& data) { + const ::testing::TestParamInfo< + ::testing::tuple>& data) { + const auto& param = ::testing::get<0>(data.param); string str = tensorflow::strings::StrCat( - "base_bounds_", - tensorflow::str_util::Join(data.param.base_bounds, "x"), // - "__window_bounds_", - tensorflow::str_util::Join(data.param.window_bounds, "x"), // - "__strides_", tensorflow::str_util::Join(data.param.strides, "x"), // - "__padding_", data.param.padding == Padding::kSame ? "same" : "valid", // - "__reducer_", data.param.reducer == kAdd ? 
"add" : "max"); + "base_bounds_", tensorflow::str_util::Join(param.base_bounds, "x"), + "__window_bounds_", tensorflow::str_util::Join(param.window_bounds, "x"), + "__strides_", tensorflow::str_util::Join(param.strides, "x"), + "__pad_low_", tensorflow::str_util::Join(param.pad_low, "x"), + "__pad_high_", tensorflow::str_util::Join(param.pad_high, "x"), + "__reducer_", param.reducer == kAdd ? "add" : "max"); + if (::testing::get<1>(data.param)) { + str = tensorflow::strings::StrCat(str, "_bfloat16"); + } return str; } -class R1ReduceWindowTest - : public ClientLibraryTestBase, - public ::testing::WithParamInterface {}; +class R1ReduceWindowTest : public ReduceWindowTestBase, + public ::testing::WithParamInterface< + ::testing::tuple> { + protected: + R1ReduceWindowTest() { set_use_bfloat16(::testing::get<1>(GetParam())); } +}; TEST_P(R1ReduceWindowTest, DoIt) { ComputationBuilder b(client_, TestName()); - const auto& param = GetParam(); + const auto& param = ::testing::get<0>(GetParam()); CHECK(param.reducer == kAdd || param.reducer == kMax); const float kInitValue = 0.0f; @@ -1041,18 +1220,24 @@ TEST_P(R1ReduceWindowTest, DoIt) { std::iota(std::begin(input_vector), std::end(input_vector), 0); std::unique_ptr input_literal = Literal::CreateR1(tensorflow::gtl::ArraySlice(input_vector)); - TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr input_arg, - client_->TransferToServer(*input_literal)); + ComputationDataHandle parameter; + auto input_arg = CreateParameterAndTransferLiteral(0, *input_literal, "p0", + &b, ¶meter); + + std::vector> padding(1); + padding[0] = {param.pad_low[0], param.pad_high[0]}; auto computation = param.reducer == kAdd - ? 
CreateScalarAddComputation(F32, &b) - : CreateScalarMaxComputation(F32, &b); - b.ReduceWindow(/*operand=*/ - b.Parameter(0, input_literal->shape(), "p0"), - /*init_value=*/b.ConstantR0(kInitValue), - /*computation=*/computation, - /*window_dimensions=*/param.window_bounds, - /*window_strides=*/param.strides, /*padding=*/param.padding); + ? CreateScalarAddComputation(FloatType(), &b) + : CreateScalarMaxComputation(FloatType(), &b); + auto init_value = + CreateConstantFromLiteral(*Literal::CreateR0(kInitValue), &b); + b.ReduceWindowWithGeneralPadding( + /*operand=*/parameter, + /*init_value=*/init_value, + /*computation=*/computation, + /*window_dimensions=*/param.window_bounds, + /*window_strides=*/param.strides, /*padding=*/padding); auto reduce_func = param.reducer == kAdd ? +[](float a, float b) { return a + b; } @@ -1062,14 +1247,73 @@ TEST_P(R1ReduceWindowTest, DoIt) { /*init=*/kInitValue, /*reduce_func=*/reduce_func, /*window=*/param.window_bounds, - /*stride=*/param.strides, /*padding=*/param.padding); + /*stride=*/param.strides, + /*padding=*/padding); - ComputeAndCompareR1(&b, tensorflow::gtl::ArraySlice(*expected), - {input_arg.get()}, ErrorSpec(1e-3, 1e-3)); + ComputeAndCompareLiteral(&b, *Literal::CreateR1(*expected), + {input_arg.get()}, DefaultErrorSpec()); +} + +INSTANTIATE_TEST_CASE_P( + R1ReduceWindowTestInstantiation, R1ReduceWindowTest, + ::testing::Combine(::testing::ValuesIn(kR1TestCases), + ::testing::ValuesIn(use_bfloat16_params)), + R1ReduceWindowTestDataToString); + +// Test class for text-based test cases. Note that this compares with the +// results on the interpreter backend. 
+class ReduceWindowTextTest : public HloTestBase {}; + +TEST_F(ReduceWindowTextTest, R2General256x384) { + const string& hlo_string = R"( +HloModule R2Window +mul { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT mul = f32[] multiply(lhs, rhs) +} +ENTRY R2Window { + operand = f32[256,384]{1,0} parameter(0) + constant = f32[] constant(1) + ROOT reduce-window = f32[256,384]{1,0} reduce-window(operand, constant), window={size=2x3 pad=0_1x1_1}, to_apply=mul +} +)"; + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001})); +} + +TEST_F(ReduceWindowTextTest, R2General256x384Layout01) { + const string& hlo_string = R"( +HloModule R2Window +mul { +lhs = f32[] parameter(0) +rhs = f32[] parameter(1) +ROOT mul = f32[] multiply(lhs, rhs) +} +ENTRY R2Window { +operand = f32[256,384]{0,1} parameter(0) +constant = f32[] constant(1) +ROOT reduce-window = f32[256,384]{0,1} reduce-window(operand, constant), window={size=2x3 pad=0_1x1_1}, to_apply=mul +} +)"; + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001})); +} + +TEST_F(ReduceWindowTextTest, R2General2x5) { + const string& hlo_string = R"( +HloModule R2Window +mul { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT mul = f32[] multiply(lhs, rhs) +} +ENTRY R2Window { + operand = f32[2,5]{1,0} parameter(0) + constant = f32[] constant(1) + ROOT reduce-window = f32[3,5]{1,0} reduce-window(operand, constant), window={size=2x1 pad=0_2x0_0}, to_apply=mul +} +)"; + EXPECT_TRUE(RunAndCompare(hlo_string, ErrorSpec{0.001})); } -INSTANTIATE_TEST_CASE_P(R1ReduceWindowTestInstantiation, R1ReduceWindowTest, - ::testing::ValuesIn(kR1TestCases), - R1ReduceWindowTestDataToString); } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index d235b9a1580ecbd6b82a69fca53d259912ff375e..f7b04debd4f5c40a904e32c832b6fc384a03c33b 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc 
@@ -41,326 +41,467 @@ limitations under the License. namespace xla { namespace { -class ReshapeTest : public ClientLibraryTestBase { +// Use a bool parameter to indicate whether to use bfloat16. +class ReshapeTest : public ::testing::WithParamInterface, + public ClientLibraryTestBase { public: + ReshapeTest() { set_use_bfloat16(GetParam()); } + ErrorSpec zero_error_spec_{0.0}; }; // Collapses 2-dimensional pseudo-scalar (single-element array) to 1 dimension. -XLA_TEST_F(ReshapeTest, CollapseTrivial1x1) { +XLA_TEST_P(ReshapeTest, CollapseTrivial1x1) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2({{1.0}}); - builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); - - ComputeAndCompareR1(&builder, {1.0f}, {}, zero_error_spec_); + Array2D input_array(1, 1); + input_array.Fill(1.0f); + auto input_literal = Literal::CreateR2FromArray2D(input_array); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + + auto expected_literal = Literal::CreateR1({1.0f}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } -XLA_TEST_F(ReshapeTest, CollapseTrivialR1EmptyDims) { +XLA_TEST_P(ReshapeTest, CollapseTrivialR1EmptyDims) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR1({1.0}); - builder.Collapse(/*operand=*/a, /*dimensions=*/{}); - - ComputeAndCompareR1(&builder, {1.0f}, {}, zero_error_spec_); + auto input_literal = Literal::CreateR1({1.0f}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{}); + + auto expected_literal = Literal::CreateR1({1.0f}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } -XLA_TEST_F(ReshapeTest, 
CollapseTrivialR1OnlyDim) { +XLA_TEST_P(ReshapeTest, CollapseTrivialR1OnlyDim) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR1({1.0}); - builder.Collapse(/*operand=*/a, /*dimensions=*/{0}); - - ComputeAndCompareR1(&builder, {1.0f}, {}, zero_error_spec_); + auto input_literal = Literal::CreateR1({1.0f}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0}); + + auto expected_literal = Literal::CreateR1({1.0f}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Collapses 2-dimensional pseudo-scalar (single-element array) to scalar. -XLA_TEST_F(ReshapeTest, SingleElementArrayToScalar) { +XLA_TEST_P(ReshapeTest, SingleElementArrayToScalar) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2({{1.0}}); - auto reshape = - builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1}, /*new_sizes=*/{}); + Array2D input_array(1, 1); + input_array.Fill(1.0f); + auto input_literal = Literal::CreateR2FromArray2D(input_array); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "parameter", + &builder, ¶meter); + auto reshape = builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, + /*new_sizes=*/{}); auto new_shape = builder.GetShape(reshape).ConsumeValueOrDie(); - ComputeAndCompareR0(&builder, 1.0f, {}, zero_error_spec_); + auto expected_literal = Literal::CreateR0(1.0f); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } -XLA_TEST_F(ReshapeTest, ScalarToSingleElementArray) { +XLA_TEST_P(ReshapeTest, ScalarToSingleElementArray) { ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = Literal::CreateR0(1.0f); - std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - - 
auto a = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param0"); - a = builder.Neg(a); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0", + &builder, ¶meter); + auto a = builder.Neg(parameter); auto reshape = builder.Reshape(/*operand=*/a, /*dimensions=*/{}, /*new_sizes=*/{1}); - ComputeAndCompareR1(&builder, {-1.0f}, {param0_data.get()}, - zero_error_spec_); + auto expected_literal = Literal::CreateR1({-1.0f}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } -XLA_TEST_F(ReshapeTest, Trivial0x3) { +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 +// with an incorrect result rank. +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3)) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2FromArray2D(Array2D(0, 3)); - auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); - - ComputeAndCompareR1(&builder, {}, {}, zero_error_spec_); + Array2D input_array(0, 3); + auto input_literal = Literal::CreateR2FromArray2D(input_array); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + auto expected_literal = Literal::CreateR1({}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // TODO(b/29185393): Make this work with the GPU backend. The GPU backend // does not handle zero-sized shapes correctly. Failed last on 2017-05-15 // with an incorrect result rank. 
-XLA_TEST_F(ReshapeTest, DISABLED_ON_GPU(Trivial0x3WithParameter)) { +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial0x3WithParameter)) { ComputationBuilder builder(client_, TestName()); std::unique_ptr param0_literal = Literal::CreateR2FromArray2D(Array2D(0, 3)); - std::unique_ptr param0_data = - client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); - - auto a = builder.Parameter(0, ShapeUtil::MakeShape(F32, {0, 3}), "param0"); - auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); - - ComputeAndCompareR1(&builder, {}, {param0_data.get()}, - zero_error_spec_); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *param0_literal, "param0", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + auto expected_literal = Literal::CreateR1({}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } -XLA_TEST_F(ReshapeTest, Trivial3x0) { +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 +// with an incorrect result rank. +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Trivial3x0)) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2FromArray2D(Array2D(3, 0)); - auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); - - ComputeAndCompareR1(&builder, {}, {}, zero_error_spec_); + Array2D input_array(3, 0); + auto input_literal = Literal::CreateR2FromArray2D(input_array); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + auto expected_literal = Literal::CreateR1({}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Collapses a 2-dimensional row vector to 1 dimension. 
-XLA_TEST_F(ReshapeTest, Trivial1x3) { +XLA_TEST_P(ReshapeTest, Trivial1x3) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2({{1.0f, 2.0f, 3.0f}}); - auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); - - ComputeAndCompareR1(&builder, {1.0f, 2.0f, 3.0f}, {}, - zero_error_spec_); + auto input_literal = Literal::CreateR2({{1.0f, 2.0f, 3.0f}}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + auto expected_literal = Literal::CreateR1({1.0f, 2.0f, 3.0f}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Collapses a 2-dimensional column vector to 1 dimension. -XLA_TEST_F(ReshapeTest, Trivial3x1) { +XLA_TEST_P(ReshapeTest, Trivial3x1) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2({{1.0f}, {2.0f}, {3.0f}}); - auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); - - ComputeAndCompareR1(&builder, {1.0f, 2.0f, 3.0f}, {}, - zero_error_spec_); + auto input_literal = Literal::CreateR2({{1.0f}, {2.0f}, {3.0f}}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{0, 1}); + auto expected_literal = Literal::CreateR1({1.0f, 2.0f, 3.0f}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 +// with an incorrect result rank. +// // Splits an empty vector into an empty matrix. 
-XLA_TEST_F(ReshapeTest, R1ToR2_0_To_2x0) { +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(R1ToR2_0_To_2x0)) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR1({}); - auto result = - builder.Reshape(/*operand=*/a, /*dimensions=*/{0}, /*new_sizes=*/{2, 0}); - ComputeAndCompareR2(&builder, Array2D(2, 0), {}, - zero_error_spec_); + auto input_literal = Literal::CreateR1({}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0}, + /*new_sizes=*/{2, 0}); + auto expected_literal = Literal::CreateR2({{}, {}}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Splits a vector into a matrix. -XLA_TEST_F(ReshapeTest, R1ToR2_6_To_2x3) { +XLA_TEST_P(ReshapeTest, R1ToR2_6_To_2x3) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR1({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); - auto result = - builder.Reshape(/*operand=*/a, /*dimensions=*/{0}, /*new_sizes=*/{2, 3}); - Array2D expected_2x3({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}); - ComputeAndCompareR2(&builder, expected_2x3, {}, zero_error_spec_); + auto input_literal = + Literal::CreateR1({1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0}, + /*new_sizes=*/{2, 3}); + auto expected_literal = + Literal::CreateR2({{1.0f, 2.0f, 3.0f}, {4.0f, 5.0f, 6.0f}}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 +// with an incorrect result rank. +// // Transposes a 2x0 array to a 0x2 array. 
-XLA_TEST_F(ReshapeTest, Reshape0x2To2x0) { +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Reshape0x2To2x0)) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2FromArray2D(Array2D(0, 2)); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1}, - /*new_sizes=*/{2, 0}); - - ComputeAndCompareR2(&builder, Array2D(2, 0), {}, - zero_error_spec_); + auto input_literal = Literal::CreateFromArray(Array2D(0, 2)); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, + /*new_sizes=*/{2, 0}); + auto expected_literal = Literal::CreateR2({{}, {}}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Transposes a 2-dimensional row vector to a column vector. -XLA_TEST_F(ReshapeTest, ReshapeRowToCol) { +XLA_TEST_P(ReshapeTest, ReshapeRowToCol) { ComputationBuilder builder(client_, TestName()); auto simple = MakeLinspaceArray2D(1.0f, 3.0f, 1, 3); - auto a = builder.ConstantR2FromArray2D(*simple); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1}, - /*new_sizes=*/{3, 1}); + auto input_literal = Literal::CreateFromArray(*simple); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, + /*new_sizes=*/{3, 1}); auto expected = ReferenceUtil::TransposeArray2D(*simple); - ComputeAndCompareR2(&builder, *expected, {}, zero_error_spec_); + auto expected_literal = Literal::CreateFromArray(*expected); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Transposes a 2-dimensional array. 
-XLA_TEST_F(ReshapeTest, TransposeAsReshape) { +XLA_TEST_P(ReshapeTest, TransposeAsReshape) { ComputationBuilder builder(client_, TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); - auto a = builder.ConstantR2FromArray2D(*a4x3); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{1, 0}, - /*new_sizes=*/{3, 4}); - - auto expected3x4 = ReferenceUtil::TransposeArray2D(*a4x3); - ComputeAndCompareR2(&builder, *expected3x4, {}, zero_error_spec_); + auto input_literal = Literal::CreateFromArray(*a4x3); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, + /*new_sizes=*/{3, 4}); + + auto expected = ReferenceUtil::TransposeArray2D(*a4x3); + auto expected_literal = Literal::CreateFromArray(*expected); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 +// with an incorrect result rank. +// // Transposes a 0x4 array with ComputationBuilder::Trans. -XLA_TEST_F(ReshapeTest, Transpose0x4) { +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(Transpose0x4)) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2FromArray2D(Array2D(0, 4)); - auto result = builder.Transpose(a, {1, 0}); - - ComputeAndCompareR2(&builder, Array2D(4, 0), {}, - zero_error_spec_); + auto input_literal = Literal::CreateFromArray(Array2D(0, 4)); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Transpose(parameter, {1, 0}); + auto expected_literal = Literal::CreateR2({{}, {}, {}, {}}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Transposes a 2-dimensional array with ComputationBuilder::Trans. 
-XLA_TEST_F(ReshapeTest, Transpose4x3) { +XLA_TEST_P(ReshapeTest, Transpose4x3) { ComputationBuilder builder(client_, TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); - auto a = builder.ConstantR2FromArray2D(*a4x3); - auto result = builder.Transpose(a, {1, 0}); - - auto expected3x4 = ReferenceUtil::TransposeArray2D(*a4x3); - ComputeAndCompareR2(&builder, *expected3x4, {}, zero_error_spec_); + auto input_literal = Literal::CreateFromArray(*a4x3); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Transpose(parameter, {1, 0}); + + auto expected = ReferenceUtil::TransposeArray2D(*a4x3); + auto expected_literal = Literal::CreateFromArray(*expected); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 +// with an incorrect result rank. +// // Reshapes an empty 2-dimensional array with dimensions that are not just a // rearrangement of the originals (split), but no reordering (no shuffle). 
-XLA_TEST_F(ReshapeTest, ReshapeSplitNoShuffleZeroElements) { +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitNoShuffleZeroElements)) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2FromArray2D(Array2D(6, 0)); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1}, - /*new_sizes=*/{2, 3, 0, 0}); - - ComputeAndCompareR4(&builder, Array4D(2, 3, 0, 0), {}, - zero_error_spec_); + auto input_literal = Literal::CreateFromArray(Array2D(6, 0)); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, + /*new_sizes=*/{2, 3, 0, 0}); + auto expected_literal = Literal::CreateFromArray(Array4D(2, 3, 0, 0)); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } -XLA_TEST_F(ReshapeTest, ReshapeR4ToR2ZeroElements) { +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 +// with an incorrect result rank. 
+XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeR4ToR2ZeroElements)) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR4FromArray4D(Array4D(2, 3, 4, 0)); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1, 2, 3}, - /*new_sizes=*/{24, 0}); - - ComputeAndCompareR2(&builder, Array2D(24, 0), {}, - zero_error_spec_); + auto input_literal = Literal::CreateFromArray(Array4D(2, 3, 4, 0)); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3}, + /*new_sizes=*/{24, 0}); + auto expected_literal = Literal::CreateFromArray(Array2D(24, 0)); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Reshapes a 2-dimensional array with dimensions that are not just a // rearrangement of the originals (split), but no reordering (no shuffle). -XLA_TEST_F(ReshapeTest, ReshapeSplitNoShuffle) { +XLA_TEST_P(ReshapeTest, ReshapeSplitNoShuffle) { ComputationBuilder builder(client_, TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); - auto a = builder.ConstantR2FromArray2D(*a4x3); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1}, - /*new_sizes=*/{2, 6}); - - auto expected2x6 = MakeLinspaceArray2D(1.0f, 12.0f, 2, 6); - ComputeAndCompareR2(&builder, *expected2x6, {}, zero_error_spec_); + auto input_literal = Literal::CreateFromArray(*a4x3); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1}, + /*new_sizes=*/{2, 6}); + + auto expected = MakeLinspaceArray2D(1.0f, 12.0f, 2, 6); + auto expected_literal = Literal::CreateFromArray(*expected); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } -// Reshapes a 2-dimensional array with dimensions that are 
not just a -// rearrangement of the originals (split), and reorder the input (shuffle). -XLA_TEST_F(ReshapeTest, ReshapeSplitAndShuffleZeroElements) { +// TODO(b/29185393): Make this work with the GPU backend. The GPU backend +// does not handle zero-sized shapes correctly. Failed last on 2017-11-30 +// with an incorrect result rank. +// +XLA_TEST_P(ReshapeTest, DISABLED_ON_GPU(ReshapeSplitAndShuffleZeroElements)) { ComputationBuilder builder(client_, TestName()); - auto a = builder.ConstantR2FromArray2D(Array2D(0, 6)); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{1, 0}, - /*new_sizes=*/{3, 0}); - - ComputeAndCompareR2(&builder, Array2D(3, 0), {}, - zero_error_spec_); + auto input_literal = Literal::CreateFromArray(Array2D(0, 6)); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, + /*new_sizes=*/{3, 0}); + auto expected_literal = Literal::CreateFromArray(Array2D(3, 0)); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Reshapes a 2-dimensional array with dimensions that are not just a // rearrangement of the originals (split), and reorder the input (shuffle). 
-XLA_TEST_F(ReshapeTest, ReshapeSplitAndShuffle) { +XLA_TEST_P(ReshapeTest, ReshapeSplitAndShuffle) { ComputationBuilder builder(client_, TestName()); auto a4x3 = MakeLinspaceArray2D(1.0f, 12.0f, 4, 3); - auto a = builder.ConstantR2FromArray2D(*a4x3); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{1, 0}, - /*new_sizes=*/{2, 6}); - - Array2D expected2x6({{1.0f, 4.0f, 7.0f, 10.0f, 2.0f, 5.0f}, - {8.0f, 11.0f, 3.0f, 6.0f, 9.0f, 12.0f}}); - ComputeAndCompareR2(&builder, expected2x6, {}, zero_error_spec_); + auto input_literal = Literal::CreateFromArray(*a4x3); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 0}, + /*new_sizes=*/{2, 6}); + Array2D expected({{1.0f, 4.0f, 7.0f, 10.0f, 2.0f, 5.0f}, + {8.0f, 11.0f, 3.0f, 6.0f, 9.0f, 12.0f}}); + auto expected_literal = Literal::CreateFromArray(expected); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // The following tests use the same input 3D array; they test the examples we // show for the Reshape operation in the operation_semantics document. // TODO(b/34503277): find a way to show this code in the documentation without // duplication on the TF documentation server. 
-Array3D v_array_for_doc_R3_tests({{{10, 11, 12}, {15, 16, 17}}, - {{20, 21, 22}, {25, 26, 27}}, - {{30, 31, 32}, {35, 36, 37}}, - {{40, 41, 42}, {45, 46, 47}}}); - -XLA_TEST_F(ReshapeTest, DocR3_R1_Collapse_012) { - ComputationBuilder builder(client_, TestName()); - auto v = builder.ConstantR3FromArray3D(v_array_for_doc_R3_tests); - auto result = builder.Reshape(/*operand=*/v, /*dimensions=*/{0, 1, 2}, - /*new_sizes=*/{24}); - ComputeAndCompareR1(&builder, - {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27, - 30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47}, - {}); -} - -XLA_TEST_F(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) { - ComputationBuilder builder(client_, TestName()); - auto v = builder.ConstantR3FromArray3D(v_array_for_doc_R3_tests); - auto result = builder.Reshape(/*operand=*/v, /*dimensions=*/{0, 1, 2}, - /*new_sizes=*/{8, 3}); - Array2D expected({{10, 11, 12}, - {15, 16, 17}, - {20, 21, 22}, - {25, 26, 27}, - {30, 31, 32}, - {35, 36, 37}, - {40, 41, 42}, - {45, 46, 47}}); - ComputeAndCompareR2(&builder, expected, {}); -} - -XLA_TEST_F(ReshapeTest, DocR3_R1_Collapse_120) { - ComputationBuilder builder(client_, TestName()); - auto v = builder.ConstantR3FromArray3D(v_array_for_doc_R3_tests); - auto result = builder.Reshape(/*operand=*/v, /*dimensions=*/{1, 2, 0}, - /*new_sizes=*/{24}); - ComputeAndCompareR1(&builder, - {10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42, - 15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47}, - {}); -} - -XLA_TEST_F(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) { - ComputationBuilder builder(client_, TestName()); - auto v = builder.ConstantR3FromArray3D(v_array_for_doc_R3_tests); - auto result = builder.Reshape(/*operand=*/v, /*dimensions=*/{1, 2, 0}, - /*new_sizes=*/{8, 3}); - Array2D expected({{10, 20, 30}, - {40, 11, 21}, - {31, 41, 12}, - {22, 32, 42}, - {15, 25, 35}, - {45, 16, 26}, - {36, 46, 17}, - {27, 37, 47}}); - ComputeAndCompareR2(&builder, expected, {}); -} - -XLA_TEST_F(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) 
{ - ComputationBuilder builder(client_, TestName()); - auto v = builder.ConstantR3FromArray3D(v_array_for_doc_R3_tests); - auto result = builder.Reshape(/*operand=*/v, /*dimensions=*/{1, 2, 0}, - /*new_sizes=*/{2, 6, 2}); - Array3D expected( +static Array3D ArrayForDocR3Tests() { + return Array3D({{{10, 11, 12}, {15, 16, 17}}, + {{20, 21, 22}, {25, 26, 27}}, + {{30, 31, 32}, {35, 36, 37}}, + {{40, 41, 42}, {45, 46, 47}}}); +} + +XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_012) { + ComputationBuilder builder(client_, TestName()); + auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, + /*new_sizes=*/{24}); + auto expected_literal = Literal::CreateR1( + {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27, + 30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); +} + +XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_012_Refine_83) { + ComputationBuilder builder(client_, TestName()); + auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2}, + /*new_sizes=*/{8, 3}); + auto expected_literal = Literal::CreateR2({{10, 11, 12}, + {15, 16, 17}, + {20, 21, 22}, + {25, 26, 27}, + {30, 31, 32}, + {35, 36, 37}, + {40, 41, 42}, + {45, 46, 47}}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); +} + +XLA_TEST_P(ReshapeTest, DocR3_R1_Collapse_120) { + ComputationBuilder builder(client_, TestName()); + auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); + ComputationDataHandle parameter; + auto input = 
CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, + /*new_sizes=*/{24}); + auto expected_literal = Literal::CreateR1( + {10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42, + 15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); +} + +XLA_TEST_P(ReshapeTest, DocR3_R2_Collapse_120_Refine_83) { + ComputationBuilder builder(client_, TestName()); + auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, + /*new_sizes=*/{8, 3}); + auto expected_literal = Literal::CreateR2({{10, 20, 30}, + {40, 11, 21}, + {31, 41, 12}, + {22, 32, 42}, + {15, 25, 35}, + {45, 16, 26}, + {36, 46, 17}, + {27, 37, 47}}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); +} + +XLA_TEST_P(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) { + ComputationBuilder builder(client_, TestName()); + auto input_literal = Literal::CreateFromArray(ArrayForDocR3Tests()); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{1, 2, 0}, + /*new_sizes=*/{2, 6, 2}); + auto expected_literal = Literal::CreateR3( {{{10, 20}, {30, 40}, {11, 21}, {31, 41}, {12, 22}, {32, 42}}, {{15, 25}, {35, 45}, {16, 26}, {36, 46}, {17, 27}, {37, 47}}}); - ComputeAndCompareR3(&builder, expected, {}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Collapses the low dimensions of a 4D tensor to get a 2D matrix, without @@ -378,23 +519,26 @@ XLA_TEST_F(ReshapeTest, DocR3_R3_Collapse_120_Refine_262) { // Then we collapse Z 
be collapsed so we just end up with planes: // // 1 2 3 4 5 6 1 2 3 4 5 6 -XLA_TEST_F(ReshapeTest, FullyConnectedCollapse) { +XLA_TEST_P(ReshapeTest, FullyConnectedCollapse) { ComputationBuilder builder(client_, TestName()); Array4D t2x2x2x3(2, 2, 2, 3); auto filler2x3 = MakeLinspaceArray2D(1.0f, 6.0f, 2, 3); t2x2x2x3.FillWithYX(*filler2x3); - auto a = builder.ConstantR4FromArray4D(t2x2x2x3); - auto result = builder.Collapse(/*operand=*/a, /*dimensions=*/{1, 2, 3}); - - Array2D expected2x12( + auto input_literal = Literal::CreateFromArray(t2x2x2x3); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Collapse(/*operand=*/parameter, /*dimensions=*/{1, 2, 3}); + auto expected_literal = Literal::CreateR2( {{1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}}); - ComputeAndCompareR2(&builder, expected2x12, {}, zero_error_spec_); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // As above, but uses reshape directly. 
-XLA_TEST_F(ReshapeTest, FullyConnectedCollapseDesugared) { +XLA_TEST_P(ReshapeTest, FullyConnectedCollapseDesugared) { ComputationBuilder builder(client_, TestName()); Array4D t(2, 1, 2, 2); t(0, 0, 0, 0) = 0; @@ -405,52 +549,68 @@ XLA_TEST_F(ReshapeTest, FullyConnectedCollapseDesugared) { t(1, 0, 0, 1) = 5; t(1, 0, 1, 0) = 6; t(1, 0, 1, 1) = 7; - auto a = builder.ConstantR4FromArray4D(t); - auto result = builder.Reshape(/*operand=*/a, /*dimensions=*/{0, 1, 2, 3}, - /*new_sizes=*/{2, 4}); - - Array2D expected({{0, 1, 2, 3}, {4, 5, 6, 7}}); - ComputeAndCompareR2(&builder, expected, {}, zero_error_spec_); + auto input_literal = Literal::CreateFromArray(t); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(/*operand=*/parameter, /*dimensions=*/{0, 1, 2, 3}, + /*new_sizes=*/{2, 4}); + + auto expected_literal = + Literal::CreateR2({{0, 1, 2, 3}, {4, 5, 6, 7}}); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Reshape various ranks to a scalar. -XLA_TEST_F(ReshapeTest, ToScalar) { +XLA_TEST_P(ReshapeTest, ToScalar) { for (int rank = 0; rank < 8; ++rank) { ComputationBuilder b(client_, TestName()); - auto input = Literal::CreateR1({83.0f}); std::vector ones(rank, 1); // this is {1, ..., 1}. std::vector dimensions(rank); std::iota(dimensions.begin(), dimensions.end(), 0); - *input->mutable_shape() = ShapeUtil::MakeShape(F32, ones); - b.Reshape(b.ConstantLiteral(*input), dimensions, {}); + Literal input_literal(ShapeUtil::MakeShape(F32, ones)); + std::vector zeros(rank, 0); // this is {0, ..., 0}. 
+ input_literal.Set(zeros, 83.0f); - ComputeAndCompareR0(&b, 83.0f, {}, zero_error_spec_); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, input_literal, "input", + &b, ¶meter); + b.Reshape(parameter, dimensions, {}); + + auto expected_literal = Literal::CreateR0(83.0f); + ComputeAndCompareLiteral(&b, *expected_literal, {input.get()}, + zero_error_spec_); } } -XLA_TEST_F(ReshapeTest, BadDimensions) { +XLA_TEST_P(ReshapeTest, BadDimensions) { ComputationBuilder b(client_, TestName()); - b.Reshape(b.ConstantR1({1}), {}, {}); + auto input_literal = Literal::CreateR1({1.0f}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b, + ¶meter); + b.Reshape(parameter, {}, {}); EXPECT_THAT( ExecuteToString(&b, {}), ::testing::HasSubstr("not a permutation of the operand dimensions")); } -XLA_TEST_F(ReshapeTest, BadNewSizes) { +XLA_TEST_P(ReshapeTest, BadNewSizes) { ComputationBuilder b(client_, TestName()); - b.Reshape(b.ConstantR1({1, 2}), {1}, {}); + auto input_literal = Literal::CreateR1({1.0f, 2.0f}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", &b, + ¶meter); + b.Reshape(parameter, {1}, {}); EXPECT_THAT(ExecuteToString(&b, {}), ::testing::HasSubstr("mismatched element counts")); } -XLA_TEST_F(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { - const Shape parameter_shape = ShapeUtil::MakeShape(F32, {2, 2, 2, 2}); +XLA_TEST_P(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, parameter_shape, "a"); - builder.Reshape(a, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 8}); - // clang-format off - auto literal = Literal::CreateR4FromArray4DWithLayout(Array4D{ + auto input_literal = Literal::CreateR4FromArray4DWithLayout(Array4D{ { { {0, 1}, @@ -474,8 +634,12 @@ XLA_TEST_F(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { }, 
LayoutUtil::MakeLayout({0, 1, 2, 3})); // clang-format on - std::unique_ptr input = - client_->TransferToServer(*literal).ConsumeValueOrDie(); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + + builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 8}); + Array2D expected_array({ {0, 1, 2, 3, 100, 101, 102, 103}, {222, 333, 444, 555, 666, 777, 888, 999}, @@ -484,72 +648,75 @@ XLA_TEST_F(ReshapeTest, R4Dim0MinorLayoutToR2Dim0MajorLayout) { Computation computation = builder.Build().ConsumeValueOrDie(); ExecutionOptions execution_options = execution_options_; *execution_options.mutable_shape_with_output_layout() = - ShapeUtil::MakeShapeWithLayout(F32, {2, 8}, {1, 0}); + ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? BF16 : F32, {2, 8}, + {1, 0}); std::unique_ptr actual = client_ ->ExecuteAndTransfer(computation, {input.get()}, &execution_options) .ConsumeValueOrDie(); std::unique_ptr expected = Literal::CreateR2FromArray2D(expected_array); + if (use_bfloat16()) { + expected = LiteralTestUtil::ConvertF32ToBF16(*expected); + } LiteralTestUtil::ExpectEqual(*expected, *actual); } -XLA_TEST_F(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { - std::unique_ptr input = Literal::CreateR2({ +XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4) { + ComputationBuilder builder(client_, TestName()); + std::unique_ptr input_literal = Literal::CreateR2({ {0, 1, 2, 3, 4, 5, 6, 7}, {100, 101, 102, 103, 104, 105, 106, 107}, {200, 201, 202, 203, 204, 205, 206, 207}, }); - std::unique_ptr input_data = - client_->TransferToServer(*input).ConsumeValueOrDie(); - - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{0, 1}, /*new_sizes=*/{3, 2, 1, 4}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(parameter, 
/*dimensions=*/{0, 1}, /*new_sizes=*/{3, 2, 1, 4}); // clang-format off - Array4D expected = { + auto expected_literal = Literal::CreateR4({ {{{0, 1, 2, 3}}, {{4, 5, 6, 7}}}, {{{100, 101, 102, 103}}, {{104, 105, 106, 107}}}, {{{200, 201, 202, 203}}, {{204, 205, 206, 207}}} - }; + }); // clang-format on - ComputeAndCompareR4(&builder, expected, {input_data.get()}, - zero_error_spec_); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } // Tests R2->R4 reshape with the reshape dimensions {1, 0}. -XLA_TEST_F(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) { - std::unique_ptr input = Literal::CreateR2({ +XLA_TEST_P(ReshapeTest, R2ToR4_3x8_To_3x2x1x4_Dimensions_10) { + ComputationBuilder builder(client_, TestName()); + std::unique_ptr input_literal = Literal::CreateR2({ {0, 1, 2, 3, 4, 5, 6, 7}, {100, 101, 102, 103, 104, 105, 106, 107}, {200, 201, 202, 203, 204, 205, 206, 207}, }); - std::unique_ptr input_data = - client_->TransferToServer(*input).ConsumeValueOrDie(); - - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 2, 1, 4}); + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *input_literal, "input", + &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{1, 0}, /*new_sizes=*/{3, 2, 1, 4}); // clang-format off - Array4D expected = { + auto expected_literal = Literal::CreateR4({ {{{0, 100, 200, 1}}, {{101, 201, 2, 102}}}, {{{202, 3, 103, 203}}, {{4, 104, 204, 5}}}, {{{105, 205, 6, 106}}, {{206, 7, 107, 207}}} - }; + }); // clang-format on - ComputeAndCompareR4(&builder, expected, {input_data.get()}, - zero_error_spec_); + ComputeAndCompareLiteral(&builder, *expected_literal, {input.get()}, + zero_error_spec_); } -XLA_TEST_F(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { +XLA_TEST_P(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { + ComputationBuilder builder(client_, TestName()); 
std::mt19937 rng; std::uniform_real_distribution distribution; Array4D input(2, 1, 1, 1); @@ -559,12 +726,10 @@ XLA_TEST_F(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input_literal->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1}); + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{2, 1}); std::unique_ptr expected = LiteralTestUtil::Reshape({2, 1}, {1, 0}, *input_literal); @@ -572,7 +737,8 @@ XLA_TEST_F(ReshapeTest, R4ToR2_2x1x1x1_To_2x1) { zero_error_spec_); } -XLA_TEST_F(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { +XLA_TEST_P(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { + ComputationBuilder builder(client_, TestName()); std::mt19937 rng; std::uniform_real_distribution distribution; Array4D input(2, 1, 4, 1); @@ -582,12 +748,10 @@ XLA_TEST_F(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input_literal->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2}); + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{4, 2}); std::unique_ptr expected = LiteralTestUtil::Reshape({4, 2}, {1, 0}, *input_literal); @@ -596,7 +760,8 @@ 
XLA_TEST_F(ReshapeTest, R4ToR2_2x1x4x1_To_4x2) { } // Tests R4->R2 reshape with the reshape dimensions {0, 2, 1, 3}. -XLA_TEST_F(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { +XLA_TEST_P(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { + ComputationBuilder builder(client_, TestName()); std::mt19937 rng; std::uniform_real_distribution distribution; Array4D input(5, 10, 2, 3); @@ -606,12 +771,11 @@ XLA_TEST_F(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input_literal->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{0, 2, 1, 3}, /*new_sizes=*/{5, 60}); + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 2, 1, 3}, + /*new_sizes=*/{5, 60}); Array2D expected_array(5, 60); input.Each([&](tensorflow::gtl::ArraySlice indices, float* cell) { @@ -619,10 +783,12 @@ XLA_TEST_F(ReshapeTest, R4ToR2_5x10x2x3_To_5x60_Dimensions_0213) { *cell; }); auto expected = Literal::CreateR2FromArray2D(expected_array); - ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}); + ComputeAndCompareLiteral(&builder, *expected, {input_data.get()}, + zero_error_spec_); } -XLA_TEST_F(ReshapeTest, NoopReshape) { +XLA_TEST_P(ReshapeTest, NoopReshape) { + ComputationBuilder builder(client_, TestName()); std::mt19937 rng; std::uniform_real_distribution distribution; Array4D input_array(2, 3, 5, 7); @@ -632,18 +798,17 @@ XLA_TEST_F(ReshapeTest, NoopReshape) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input_array, LayoutUtil::MakeLayout({1, 2, 3, 0})); - std::unique_ptr input_data = - 
client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - - ComputationBuilder builder(client_, TestName()); - auto input = builder.Parameter(0, input_literal->shape(), "input"); - builder.Reshape(input, /*dimensions=*/{3, 0, 1, 2}, + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{3, 0, 1, 2}, /*new_sizes=*/{7, 2, 3, 5}); Computation computation = builder.Build().ConsumeValueOrDie(); ExecutionOptions execution_options = execution_options_; *execution_options.mutable_shape_with_output_layout() = - ShapeUtil::MakeShapeWithLayout(F32, {7, 2, 3, 5}, {2, 3, 0, 1}); + ShapeUtil::MakeShapeWithLayout(use_bfloat16() ? BF16 : F32, {7, 2, 3, 5}, + {2, 3, 0, 1}); std::unique_ptr output_literal = client_ ->ExecuteAndTransfer(computation, {input_data.get()}, @@ -652,35 +817,43 @@ XLA_TEST_F(ReshapeTest, NoopReshape) { // Since the reshape is a no-op, verify that it does not change the underlying // data. 
- EXPECT_EQ(tensorflow::gtl::ArraySlice(input_literal->f32s()), - tensorflow::gtl::ArraySlice(output_literal->f32s())); + if (use_bfloat16()) { + auto expected = LiteralTestUtil::ConvertF32ToBF16(*input_literal); + EXPECT_EQ(expected->data(), output_literal->data()); + } else { + EXPECT_EQ(input_literal->data(), output_literal->data()); + } } -XLA_TEST_F(ReshapeTest, R4ToR4Reshape_Trivial) { - auto literal_1x2x3x4 = Literal::CreateR4( +XLA_TEST_P(ReshapeTest, R4ToR4Reshape_Trivial) { + ComputationBuilder builder(client_, TestName()); + auto literal_1x2x3x4 = Literal::CreateR4( {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); - ComputationBuilder builder(client_, TestName()); - auto input = builder.ConstantLiteral(*literal_1x2x3x4); - builder.Reshape(input, /*dimensions=*/{0, 1, 2, 3}, + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input", + &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 1, 2, 3}, /*new_sizes=*/{1, 2, 3, 4}); - ComputeAndCompareLiteral(&builder, *literal_1x2x3x4, {}); + ComputeAndCompareLiteral(&builder, *literal_1x2x3x4, {input.get()}); } -XLA_TEST_F(ReshapeTest, R4ToR4Reshape) { - auto literal_1x2x3x4 = Literal::CreateR4( +XLA_TEST_P(ReshapeTest, R4ToR4Reshape) { + auto literal_1x2x3x4 = Literal::CreateR4( {{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}}); ComputationBuilder builder(client_, TestName()); - auto input = builder.ConstantLiteral(*literal_1x2x3x4); - builder.Reshape(input, /*dimensions=*/{1, 3, 2, 0}, + ComputationDataHandle parameter; + auto input = CreateParameterAndTransferLiteral(0, *literal_1x2x3x4, "input", + &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{1, 3, 2, 0}, /*new_sizes=*/{2, 4, 3, 1}); // clang-format off - auto expected_2x4x3x1 = Literal::CreateR4( + auto expected_2x4x3x1 = Literal::CreateR4( {{{{1}, {5}, 
{9}}, {{2}, {6}, {10}}, {{3}, {7}, {11}}, @@ -691,10 +864,10 @@ XLA_TEST_F(ReshapeTest, R4ToR4Reshape) { {{16}, {20}, {24}}}}); // clang-format on - ComputeAndCompareLiteral(&builder, *expected_2x4x3x1, {}); + ComputeAndCompareLiteral(&builder, *expected_2x4x3x1, {input.get()}); } -XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeSimple) { +XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeSimple) { std::mt19937 rng; std::uniform_real_distribution distribution; std::vector bounds = {2, 2, 2, 2}; @@ -706,12 +879,12 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeSimple) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input_literal->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, + /*new_sizes=*/new_bounds); std::unique_ptr expected = LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) @@ -723,7 +896,7 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeSimple) { zero_error_spec_, &expected->shape()); } -XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { +XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { std::mt19937 rng; std::uniform_real_distribution distribution; std::vector bounds = {1, 1, 250, 300}; @@ -735,12 +908,12 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); - 
auto a = builder.Parameter(0, input_literal->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, + /*new_sizes=*/new_bounds); std::unique_ptr expected = LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) @@ -752,7 +925,7 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstEffectiveR2) { zero_error_spec_, &expected->shape()); } -XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { +XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { std::mt19937 rng; std::uniform_real_distribution distribution; std::vector bounds = {5, 5, 1, 10}; @@ -764,12 +937,12 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input_literal->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, + /*new_sizes=*/new_bounds); std::unique_ptr expected = LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) @@ -781,7 +954,7 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1) { zero_error_spec_, &expected->shape()); } -XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { +XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { std::mt19937 rng; std::uniform_real_distribution 
distribution; // This happens in NN-Builder MNIST. @@ -794,12 +967,12 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({3, 2, 1, 0})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input_literal->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{0, 1, 3, 2}, /*new_sizes=*/new_bounds); + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{0, 1, 3, 2}, + /*new_sizes=*/new_bounds); std::unique_ptr expected = LiteralTestUtil::Reshape(new_bounds, {2, 3, 1, 0}, *input_literal) @@ -811,7 +984,7 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeMajorFirstMinorEffectiveR1InR2) { zero_error_spec_, &expected->shape()); } -XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeTrivialR2) { +XLA_TEST_P(ReshapeTest, R4TwoMinorTransposeTrivialR2) { std::mt19937 rng; std::uniform_real_distribution distribution; std::vector bounds = {3, 3, 1, 3}; @@ -823,12 +996,12 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeTrivialR2) { std::unique_ptr input_literal = Literal::CreateR4FromArray4DWithLayout( input, LayoutUtil::MakeLayout({0, 1, 2, 3})); - std::unique_ptr input_data = - client_->TransferToServer(*input_literal).ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); - auto a = builder.Parameter(0, input_literal->shape(), "a"); - builder.Reshape(a, /*dimensions=*/{1, 0, 2, 3}, /*new_sizes=*/new_bounds); + ComputationDataHandle parameter; + auto input_data = CreateParameterAndTransferLiteral( + 0, *input_literal, "input", &builder, ¶meter); + builder.Reshape(parameter, /*dimensions=*/{1, 0, 2, 3}, + /*new_sizes=*/new_bounds); std::unique_ptr expected = LiteralTestUtil::Reshape(new_bounds, {1, 
0, 2, 3}, *input_literal) @@ -840,5 +1013,12 @@ XLA_TEST_F(ReshapeTest, R4TwoMinorTransposeTrivialR2) { zero_error_spec_, &expected->shape()); } +#ifdef XLA_BACKEND_SUPPORTS_BFLOAT16 +INSTANTIATE_TEST_CASE_P(ReshapeTestInstance, ReshapeTest, ::testing::Bool()); +#else +INSTANTIATE_TEST_CASE_P(ReshapeTestInstance, ReshapeTest, + ::testing::ValuesIn(std::vector{false})); +#endif + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/reverse_test.cc b/tensorflow/compiler/xla/tests/reverse_test.cc index 1f6cfc85ccd25bb22db51411f7376489c14c3603..8fc841f14087cdea02fe44cdaea521ff92122aec 100644 --- a/tensorflow/compiler/xla/tests/reverse_test.cc +++ b/tensorflow/compiler/xla/tests/reverse_test.cc @@ -28,56 +28,89 @@ limitations under the License. namespace xla { namespace { -class ReverseTest : public ClientLibraryTestBase {}; - -// Tests the reverse operation on a scalar. -XLA_TEST_F(ReverseTest, ReverseScalar) { - ComputationBuilder b(client_, TestName()); - float input = 3.5f; - b.Rev(b.ConstantR0(input), {}); - ComputeAndCompareR0(&b, input, {}); -} - -// Tests the reverse operation on a 0x0 float array on both dimensions. -XLA_TEST_F(ReverseTest, Reverse0x0FloatArray) { - ComputationBuilder b(client_, TestName()); - b.Rev(b.ConstantR2FromArray2D(Array2D(0, 0)), {0, 1}); - ComputeAndCompareR2(&b, Array2D(0, 0), {}); -} - -// Tests the reverse operation on a 0x1 float array on both dimensions. -XLA_TEST_F(ReverseTest, Reverse0x1FloatArray) { - ComputationBuilder b(client_, TestName()); - b.Rev(b.ConstantR2FromArray2D(Array2D(0, 1)), {0, 1}); - ComputeAndCompareR2(&b, Array2D(0, 1), {}); +#ifdef XLA_BACKEND_SUPPORTS_BFLOAT16 +// Tests both F32 and BF16. +static std::array use_bfloat16_params{false, true}; +#else +// Only tests F32. 
+static std::array use_bfloat16_params{false}; +#endif + +struct ReverseSpec { + tensorflow::gtl::ArraySlice input_dims; + tensorflow::gtl::ArraySlice reversal; + bool use_bfloat16; + + string ToTestCaseName() const { + return tensorflow::strings::Printf( + "reverse_%s_in_dims_%s_%s", + tensorflow::str_util::Join(input_dims, "x").c_str(), + tensorflow::str_util::Join(reversal, "x").c_str(), + use_bfloat16 ? "bf16" : "f32"); + } +}; + +static std::vector GetTestCases() { + // clang-format off + return ExpandUseBfloat16( + use_bfloat16_params, + {{{}, {}}, + {{0, 0}, {0, 1}}, + {{0, 1}, {0, 1}}, + {{1, 0}, {0, 1}}, + {{1, 1}, {0, 1}}, + {{2, 0, 4, 3}, {0, 2}}, + {{2, 0, 4, 3}, {1, 3}}, + {{1, 2, 3, 4}, {0, 3}}, + {{4, 3, 2, 1}, {0, 1}}, + }); + // clang-format on } -// Tests the reverse operation on a 1x0 float array on both dimensions. -XLA_TEST_F(ReverseTest, Reverse1x0FloatArray) { - ComputationBuilder b(client_, TestName()); - b.Rev(b.ConstantR2FromArray2D(Array2D(1, 0)), {0, 1}); - ComputeAndCompareR2(&b, Array2D(1, 0), {}); +void PrintTo(const ReverseSpec& spec, std::ostream* os) { + *os << spec.ToTestCaseName(); } -// Tests the reverse operation on a 1x1 float array on both dimensions. 
-XLA_TEST_F(ReverseTest, Reverse1x1FloatArray) { - ComputationBuilder b(client_, TestName()); - Array2D input({{3.5f}}); - b.Rev(b.ConstantR2FromArray2D(input), {0, 1}); - ComputeAndCompareR2(&b, input, {}); +class FloatReverseTest : public ClientLibraryTestBase, + public ::testing::WithParamInterface { + public: + FloatReverseTest() { set_use_bfloat16(GetParam().use_bfloat16); } +}; + +TEST_P(FloatReverseTest, Reverses) { + const ReverseSpec& spec = GetParam(); + std::vector input_vector( + ShapeUtil::ElementsIn(ShapeUtil::MakeShape(F32, spec.input_dims))); + std::iota(input_vector.begin(), input_vector.end(), 0.0); + auto r1_literal = Literal::CreateR1(input_vector); + auto input_literal = r1_literal->Reshape(spec.input_dims).ConsumeValueOrDie(); + + ComputationBuilder builder(client_, TestName()); + auto a = AddParam(*input_literal, &builder); + builder.Rev(a, spec.reversal); + + std::unique_ptr expected = input_literal->CloneToUnique(); + std::vector output_indices(spec.input_dims.size()); + expected->EachCell( + [&](tensorflow::gtl::ArraySlice indices, float) { + for (int64 i = 0; i < indices.size(); ++i) { + output_indices[i] = indices[i]; + } + float value = input_literal->Get(indices); + for (int64 dim : spec.reversal) { + output_indices[dim] = (spec.input_dims[dim] - 1) - indices[dim]; + } + expected->Set(output_indices, value); + }); + ComputeAndCompareLiteral(&builder, *expected, {}); } -XLA_TEST_F(ReverseTest, Reverse2x0x4x3FloatArrayDim02) { - ComputationBuilder b(client_, TestName()); - b.Rev(b.ConstantR4FromArray4D(Array4D(2, 0, 4, 3)), {0, 2}); - ComputeAndCompareR4(&b, Array4D(2, 0, 4, 3), {}); -} +INSTANTIATE_TEST_CASE_P(FloatReverseInstance, FloatReverseTest, + ::testing::ValuesIn(GetTestCases()), + ::testing::PrintToStringParamName()); -XLA_TEST_F(ReverseTest, Reverse2x0x4x3FloatArrayDim13) { - ComputationBuilder b(client_, TestName()); - b.Rev(b.ConstantR4FromArray4D(Array4D(2, 0, 4, 3)), {1, 3}); - ComputeAndCompareR4(&b, Array4D(2, 0, 4, 3), 
{}); -} +// A simple test class which not templated by float precision. +class ReverseTest : public ClientLibraryTestBase {}; // Tests the reverse operation on a 4D U8 array on dimension 0 and 3. XLA_TEST_F(ReverseTest, Reverse4DU8ArrayOnDim23) { diff --git a/tensorflow/compiler/xla/tests/sample_file_test.cc b/tensorflow/compiler/xla/tests/sample_file_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..31b104f4e37f77d47f56ff8183ee1de1cc22e44d --- /dev/null +++ b/tensorflow/compiler/xla/tests/sample_file_test.cc @@ -0,0 +1,51 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This demonstrates how to use hlo_test_base to create a file based testcase +// and compare results on gpu and cpu. 
+ +#include +#include + +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +class SampleFileTest : public HloTestBase { + protected: + SampleFileTest() + : HloTestBase( + /*test_platform=*/PlatformUtil::GetPlatform("gpu").ValueOrDie(), + /*reference_platform=*/PlatformUtil::GetPlatform("cpu") + .ValueOrDie()) {} +}; + +TEST_F(SampleFileTest, Convolution) { + const string& filename = "compiler/xla/tests/isolated_convolution.hlo"; + string test_srcdir = tensorflow::testing::TensorFlowSrcRoot(); + EXPECT_TRUE(RunAndCompareFromFile( + tensorflow::io::JoinPath(test_srcdir, filename), ErrorSpec{0.01})); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/sample_text_test.cc b/tensorflow/compiler/xla/tests/sample_text_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..b4f2b74e3dc9e80f50454b28eb6f2502cef3e681 --- /dev/null +++ b/tensorflow/compiler/xla/tests/sample_text_test.cc @@ -0,0 +1,66 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +// This demonstrates how to use hlo_test_base to create textual IR based +// testcases. + +#include +#include + +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/gtl/optional.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { + +using tensorflow::gtl::nullopt; + +class SampleTextTest : public HloTestBase {}; + +TEST_F(SampleTextTest, Axpy) { + const string& hlo_string = R"( +HloModule axpy_module: +ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[] parameter(0) + %broadcast = f32[2,4]{1,0} broadcast(f32[] %alpha), dimensions={} + %x = f32[2,4]{1,0} parameter(1) + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x) + %y = f32[2,4]{1,0} parameter(2) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} +)"; + EXPECT_TRUE(RunAndCompareNoHloPasses(hlo_string, ErrorSpec{0.0001})); +} + +TEST_F(SampleTextTest, Tuple) { + const string& hlo_string = R"( +HloModule TupleCreate_module: +ENTRY %TupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) { + %v1 = f32[] parameter(0) + %v2 = f32[3]{0} parameter(1) + %v3 = f32[2,3]{1,0} parameter(2) + ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3) +} +)"; + EXPECT_TRUE(RunAndCompare(hlo_string, nullopt)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index b5e7570778ffeca66cc15d7cd2b153639637a647..4da6ee91607941b395b00befc98a10e7c17746ed 100644 --- 
a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -69,6 +69,13 @@ class ScalarComputationsTest : public ClientLibraryTestBase { } }; +XLA_TEST_F(ScalarComputationsTest, ReturnScalarF32) { + ComputationBuilder builder(client_, TestName()); + builder.ConstantR0(2.1f); + + ComputeAndCompareR0(&builder, 2.1f, {}, error_spec_); +} + XLA_TEST_F(ScalarComputationsTest, NegateScalarF32) { ComputationBuilder builder(client_, TestName()); builder.Neg(builder.ConstantR0(2.1f)); @@ -730,7 +737,61 @@ XLA_TEST_F(ScalarComputationsTest, PowScalar) { ComputeAndCompareR0(&builder, 8.0, {}, error_spec_); } -XLA_TEST_F(ScalarComputationsTest, ClampScalarHigh) { +XLA_TEST_F(ScalarComputationsTest, ClampScalarHighS32) { + ComputationBuilder builder(client_, TestName()); + builder.Clamp(builder.ConstantR0(-1), // The lower bound. + builder.ConstantR0(5), // The operand to be clamped. + builder.ConstantR0(3)); // The upper bound. + + ComputeAndCompareR0(&builder, 3, {}); +} + +XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleS32) { + ComputationBuilder builder(client_, TestName()); + builder.Clamp(builder.ConstantR0(-1), // The lower bound. + builder.ConstantR0(2), // The operand to be clamped. + builder.ConstantR0(3)); // The upper bound. + + ComputeAndCompareR0(&builder, 2, {}); +} + +XLA_TEST_F(ScalarComputationsTest, ClampScalarLowS32) { + ComputationBuilder builder(client_, TestName()); + builder.Clamp(builder.ConstantR0(-1), // The lower bound. + builder.ConstantR0(-5), // The operand to be clamped. + builder.ConstantR0(3)); // The upper bound. + + ComputeAndCompareR0(&builder, -1, {}); +} + +XLA_TEST_F(ScalarComputationsTest, ClampScalarHighU32) { + ComputationBuilder builder(client_, TestName()); + builder.Clamp(builder.ConstantR0(1), // The lower bound. + builder.ConstantR0(5), // The operand to be clamped. + builder.ConstantR0(3)); // The upper bound. 
+ + ComputeAndCompareR0(&builder, 3, {}); +} + +XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleU32) { + ComputationBuilder builder(client_, TestName()); + builder.Clamp(builder.ConstantR0(1), // The lower bound. + builder.ConstantR0(2), // The operand to be clamped. + builder.ConstantR0(3)); // The upper bound. + + ComputeAndCompareR0(&builder, 2, {}); +} + +XLA_TEST_F(ScalarComputationsTest, ClampScalarLowU32) { + ComputationBuilder builder(client_, TestName()); + builder.Clamp(builder.ConstantR0(1), // The lower bound. + builder.ConstantR0(0), // The operand to be clamped. + builder.ConstantR0(3)); // The upper bound. + + ComputeAndCompareR0(&builder, 1, {}); +} + +XLA_TEST_F(ScalarComputationsTest, ClampScalarHighF32) { ComputationBuilder builder(client_, TestName()); builder.Clamp(builder.ConstantR0(2.0f), // The lower bound. builder.ConstantR0(5.0f), // The operand to be clamped. @@ -739,7 +800,7 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarHigh) { ComputeAndCompareR0(&builder, 3.0, {}, error_spec_); } -XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddle) { +XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddleF32) { ComputationBuilder builder(client_, TestName()); builder.Clamp(builder.ConstantR0(2.0f), // The lower bound. builder.ConstantR0(2.5f), // The operand to be clamped. @@ -748,7 +809,7 @@ XLA_TEST_F(ScalarComputationsTest, ClampScalarMiddle) { ComputeAndCompareR0(&builder, 2.5, {}, error_spec_); } -XLA_TEST_F(ScalarComputationsTest, ClampScalarLow) { +XLA_TEST_F(ScalarComputationsTest, ClampScalarLowF32) { ComputationBuilder builder(client_, TestName()); builder.Clamp(builder.ConstantR0(2.0f), // The lower bound. builder.ConstantR0(-5.0f), // The operand to be clamped. 
@@ -845,5 +906,12 @@ XLA_TEST_F(ScalarComputationsTest, SqrtF320) { ComputeAndCompareR0(&builder, 0.0f, {zero_data.get()}, error_spec_); } +XLA_TEST_F(ScalarComputationsTest, RoundScalar) { + ComputationBuilder builder(client_, TestName()); + builder.Round(builder.ConstantR0(1.4f)); + + ComputeAndCompareR0(&builder, 1.0f, {}, error_spec_); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc index 62ff349e9c011e0eb845192013a74aeb0956b791..9ee94b8571e5fc8789b60501462986967ce909a0 100644 --- a/tensorflow/compiler/xla/tests/select_and_scatter_test.cc +++ b/tensorflow/compiler/xla/tests/select_and_scatter_test.cc @@ -39,8 +39,8 @@ namespace xla { namespace { struct SelectAndScatterTestParam { - Array4D operand_shape; - Array4D source_shape; + std::vector operand_shape; + std::vector source_shape; Padding padding_type; tensorflow::gtl::ArraySlice window_dimensions; tensorflow::gtl::ArraySlice window_strides; @@ -69,83 +69,132 @@ class SelectAndScatterTest Computation min_f32_; }; -XLA_TEST_P(SelectAndScatterTest, R4Randomized) { - Array4D o(GetParam().operand_shape); +XLA_TEST_P(SelectAndScatterTest, ParamTest) { + auto operand_shape = GetParam().operand_shape; + Array o(operand_shape); o.FillRandom(1.5f); - auto operand = builder_.ConstantR4FromArray4D(o); + auto operand = builder_.ConstantFromArray(o); - Array4D s(GetParam().source_shape); + auto source_shape = GetParam().source_shape; + Array s(source_shape); s.FillRandom(12.0f); - auto source = builder_.ConstantR4FromArray4D(s); - - builder_.SelectAndScatter(operand, ge_f32_, GetParam().window_dimensions, - GetParam().window_strides, GetParam().padding_type, - source, builder_.ConstantR0(0.0f), add_f32_); + auto source = builder_.ConstantFromArray(s); - auto e = ReferenceUtil::SelectAndScatter4DGePlus( - o, s, 0.0f, GetParam().window_dimensions, GetParam().window_strides, - GetParam().padding_type == 
Padding::kSame); + auto select_and_scatter = builder_.SelectAndScatter( + operand, ge_f32_, GetParam().window_dimensions, GetParam().window_strides, + GetParam().padding_type, source, builder_.ConstantR0(0.0f), + add_f32_); - ComputeAndCompareR4(&builder_, *e, {}, ErrorSpec(1e-5)); + ComputeAndCompare(&builder_, select_and_scatter, {}, ErrorSpec(1e-5)); } INSTANTIATE_TEST_CASE_P( SelectAndScatterTest_Instantiation, SelectAndScatterTest, - ::testing::Values(SelectAndScatterTestParam{{6, 6, 256, 128}, - {3, 3, 256, 128}, - Padding::kSame, - {3, 3, 1, 1}, - {2, 2, 1, 1}}, - SelectAndScatterTestParam{{7, 7, 256, 128}, - {3, 3, 256, 128}, - Padding::kValid, - {3, 3, 1, 1}, - {2, 2, 1, 1}}, - SelectAndScatterTestParam{{6, 7, 256, 128}, - {3, 3, 256, 128}, - Padding::kValid, - {2, 3, 1, 1}, - {2, 2, 1, 1}}, - SelectAndScatterTestParam{{6, 7, 256, 128}, - {2, 3, 256, 128}, - Padding::kValid, - {2, 3, 1, 1}, - {3, 2, 1, 1}}, - SelectAndScatterTestParam{{9, 9, 16, 128}, - {3, 3, 16, 128}, - Padding::kValid, - {3, 3, 1, 1}, - {3, 3, 1, 1}}, - SelectAndScatterTestParam{{3, 3, 4, 4}, - {1, 1, 4, 4}, - Padding::kValid, - {3, 3, 1, 1}, - {3, 3, 1, 1}}, - SelectAndScatterTestParam{{3, 3, 4, 4}, - {1, 1, 4, 4}, - Padding::kValid, - {3, 3, 1, 1}, - {3, 3, 1, 1}}, - SelectAndScatterTestParam{{9, 3, 4, 4}, - {3, 1, 4, 4}, - Padding::kValid, - {3, 3, 1, 1}, - {3, 3, 1, 1}}, - SelectAndScatterTestParam{{7, 3, 4, 4}, - {3, 1, 4, 4}, - Padding::kValid, - {3, 3, 1, 1}, - {2, 3, 1, 1}}, - SelectAndScatterTestParam{{1, 1, 5, 5}, - {1, 1, 5, 5}, - Padding::kSame, - {3, 3, 1, 1}, - {3, 3, 1, 1}}, - SelectAndScatterTestParam{{7, 7, 8, 256}, - {4, 4, 8, 256}, - Padding::kSame, - {2, 2, 1, 1}, - {2, 2, 1, 1}})); + ::testing::Values( + SelectAndScatterTestParam{{6, 6, 6, 4, 4}, + {3, 3, 3, 4, 4}, + Padding::kSame, + {3, 3, 3, 1, 1}, + {2, 2, 2, 1, 1}}, + SelectAndScatterTestParam{{7, 7, 7, 4, 4}, + {3, 3, 3, 4, 4}, + Padding::kValid, + {3, 3, 3, 1, 1}, + {2, 2, 2, 1, 1}}, + + 
SelectAndScatterTestParam{{8, 8, 8, 4, 4}, + {1, 3, 3, 4, 4}, + Padding::kValid, + {8, 4, 4, 1, 1}, + {1, 2, 2, 1, 1}}, + SelectAndScatterTestParam{{6, 6, 256, 128}, + {3, 3, 256, 128}, + Padding::kSame, + {3, 3, 1, 1}, + {2, 2, 1, 1}}, + SelectAndScatterTestParam{{7, 7, 256, 128}, + {3, 3, 256, 128}, + Padding::kValid, + {3, 3, 1, 1}, + {2, 2, 1, 1}}, + SelectAndScatterTestParam{{6, 7, 256, 128}, + {3, 3, 256, 128}, + Padding::kValid, + {2, 3, 1, 1}, + {2, 2, 1, 1}}, + SelectAndScatterTestParam{{6, 7, 256, 128}, + {2, 3, 256, 128}, + Padding::kValid, + {2, 3, 1, 1}, + {3, 2, 1, 1}}, + SelectAndScatterTestParam{{9, 9, 16, 128}, + {3, 3, 16, 128}, + Padding::kValid, + {3, 3, 1, 1}, + {3, 3, 1, 1}}, + SelectAndScatterTestParam{{3, 3, 4, 4}, + {1, 1, 4, 4}, + Padding::kValid, + {3, 3, 1, 1}, + {3, 3, 1, 1}}, + SelectAndScatterTestParam{{3, 3, 4, 4}, + {1, 1, 4, 4}, + Padding::kValid, + {3, 3, 1, 1}, + {3, 3, 1, 1}}, + SelectAndScatterTestParam{{9, 3, 4, 4}, + {3, 1, 4, 4}, + Padding::kValid, + {3, 3, 1, 1}, + {3, 3, 1, 1}}, + SelectAndScatterTestParam{{7, 3, 4, 4}, + {3, 1, 4, 4}, + Padding::kValid, + {3, 3, 1, 1}, + {2, 3, 1, 1}}, + SelectAndScatterTestParam{{1, 1, 5, 5}, + {1, 1, 5, 5}, + Padding::kSame, + {3, 3, 1, 1}, + {3, 3, 1, 1}}, + SelectAndScatterTestParam{{7, 7, 8, 256}, + {4, 4, 8, 256}, + Padding::kSame, + {2, 2, 1, 1}, + {2, 2, 1, 1}}, + SelectAndScatterTestParam{ + {6, 4, 4}, {3, 4, 4}, Padding::kSame, {3, 1, 1}, {2, 1, 1}}, + SelectAndScatterTestParam{ + {6, 256, 128}, {3, 256, 128}, Padding::kSame, {3, 1, 1}, {2, 1, 1}}, + SelectAndScatterTestParam{{7, 256, 128}, + {3, 256, 128}, + Padding::kValid, + {3, 1, 1}, + {2, 1, 1}}, + SelectAndScatterTestParam{{6, 256, 128}, + {3, 256, 128}, + Padding::kValid, + {2, 1, 1}, + {2, 1, 1}}, + SelectAndScatterTestParam{{6, 256, 128}, + {2, 256, 128}, + Padding::kValid, + {2, 1, 1}, + {3, 1, 1}}, + SelectAndScatterTestParam{ + {9, 16, 128}, {3, 16, 128}, Padding::kValid, {3, 1, 1}, {3, 1, 1}}, + 
SelectAndScatterTestParam{ + {3, 4, 4}, {1, 4, 4}, Padding::kValid, {3, 1, 1}, {3, 1, 1}}, + SelectAndScatterTestParam{ + {3, 4, 4}, {1, 4, 4}, Padding::kValid, {3, 1, 1}, {3, 1, 1}}, + SelectAndScatterTestParam{ + {9, 4, 4}, {3, 4, 4}, Padding::kValid, {3, 1, 1}, {3, 1, 1}}, + SelectAndScatterTestParam{ + {7, 4, 4}, {3, 4, 4}, Padding::kValid, {3, 1, 1}, {2, 1, 1}}, + SelectAndScatterTestParam{ + {1, 5, 5}, {1, 5, 5}, Padding::kSame, {3, 1, 1}, {3, 1, 1}}, + SelectAndScatterTestParam{ + {7, 8, 256}, {4, 8, 256}, Padding::kSame, {2, 1, 1}, {2, 1, 1}})); // Test for F32 1D array, with a zero-element input. XLA_TEST_F(SelectAndScatterTest, R1S0F32) { diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc index c21124750ad512cad69b1483e708613ee2857ac0..ac163df127e0087c02777fa3d5ce7970c51b97b9 100644 --- a/tensorflow/compiler/xla/tests/slice_test.cc +++ b/tensorflow/compiler/xla/tests/slice_test.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -33,7 +34,6 @@ namespace xla { namespace { using ::tensorflow::str_util::Join; -using ::tensorflow::strings::StrCat; class SliceTest : public ClientLibraryTestBase {}; @@ -211,6 +211,13 @@ class SliceR1Test : public ClientLibraryTestBase, } }; +string SliceR1TestDataToString(const ::testing::TestParamInfo& data) { + const R1Spec& spec = data.param; + return ::tensorflow::strings::Printf("%lld_%lld_%lld_%lld", spec.input_dim0, + spec.slice_start, spec.slice_limit, + spec.slice_stride); +} + XLA_TEST_P(SliceR1Test, DoIt_F32) { Run(GetParam()); } XLA_TEST_P(SliceR1Test, DoIt_F64) { Run(GetParam()); } @@ -223,30 +230,66 @@ XLA_TEST_P(SliceR1Test, DoIt_U64) { Run(GetParam()); } XLA_TEST_P(SliceR1Test, DoIt_S64) { Run(GetParam()); } -INSTANTIATE_TEST_CASE_P( // - SliceR1TestInstantiation, // - SliceR1Test, // - ::testing::Values( // - R1Spec{10, 0, 0, 1}, // - R1Spec{10, 7, 7, 1}, // - R1Spec{10, 2, 4, 1}, // - R1Spec{10, 2, 4, 2}, // - R1Spec{10, 0, 10, 1}, // - R1Spec{1024, 1024 - 4, 1024, 1}, // - R1Spec{4096, 7, 7 + 1024, 1}, // - R1Spec{10, 0, 10, 2}, // - R1Spec{10, 0, 10, 3}, // - R1Spec{10, 0, 10, 4}, // - R1Spec{10, 0, 10, 5}, // - R1Spec{10, 0, 10, 10}, // - R1Spec{500, 200, 400, 7}, // - R1Spec{4096, 1, 4095, 3}, // - R1Spec{2047, 1024 - 24, 1024 + 160, 31}, // - R1Spec{2047, 1, 2046, 3 * 128}, // - R1Spec{4096, 1024 + 3, 4095, 500}, // - R1Spec{8192, 0, 8192, 1024 * 3 + 400} // - ) // +// Tests for R1 slice ops. +// The format for each testcase is {input size, start, limit, stride}. 
+// clang-format off +INSTANTIATE_TEST_CASE_P( + SliceR1TestInstantiation, + SliceR1Test, + ::testing::Values( + R1Spec{10, 0, 0, 1}, + R1Spec{10, 7, 7, 1}, + R1Spec{10, 0, 5, 1}, + R1Spec{10, 3, 5, 1}, + R1Spec{10, 0, 10, 1}, + R1Spec{1024, 0, 5, 1}, + R1Spec{1024, 3, 5, 1}, + R1Spec{1024 + 17, 0, 5, 1}, + R1Spec{1024 + 17, 3, 5, 1}, + R1Spec{1024 + 17, 1024, 1024 + 6, 1}, + R1Spec{1024 + 17, 1024 + 1, 1024 + 6, 1}, + R1Spec{1024, 1024 - 4, 1024, 1}, + R1Spec{4 * 1024, 7, 7 + 1024, 1}, + R1Spec{4 * 1024, 0, 4 * 1024, 1}, + R1Spec{4 * 1024, 1, 4 * 1024 - 1, 1}, + R1Spec{4 * 1024, 1024, 3 * 1024, 1}, + R1Spec{4 * 1024, 1024 + 1, 3 * 1024 - 1, 1}, + R1Spec{16 * 1024, 0, 5, 1}, + R1Spec{16 * 1024, 3, 5, 1}, + R1Spec{16 * 1024 + 17, 0, 5, 1}, + R1Spec{16 * 1024 + 17, 3, 5, 1}, + R1Spec{16 * 1024 + 17, 16 * 1024, 16 * 1024 + 6, 1}, + R1Spec{16 * 1024 + 17, 16 * 1024 + 1, 16 * 1024 + 6, 1}, + R1Spec{16 * 1024, 4 * 1024 - 17, 8 * 1024 - 18, 1}, + R1Spec{64 * 1024, 0, 64 * 1024, 1}, + R1Spec{64 * 1024, 1, 64 * 1024 - 1, 1}, + R1Spec{64 * 1024, 1024, 63 * 1024, 1}, + R1Spec{64 * 1024, 1024 + 1, 63 * 1024 - 1, 1}, + R1Spec{64 * 1024, 32 * 1024, 33 * 1024, 1}, + R1Spec{64 * 1024, 32 * 1024 + 1, 33 * 1024 - 1, 1}, + R1Spec{64 * 1024, 32 * 1024 - 17, 36 * 1024 - 18, 1}, +// TODO(b/69425338): This uses too much memory on GPU. 
+#ifndef XLA_TEST_BACKEND_GPU + R1Spec{16 * 1024 * 1024, 4 * 1024 * 1024, 12 * 1024 * 1024, 1}, + R1Spec{16 * 1024 * 1024, 4 * 1024 * 1024 + 1, 12 * 1024 * 1024 - 1, 1}, + R1Spec{16 * 1024 * 1024, 4 * 1024 * 1024 - 1, 12 * 1024 * 1024 + 1, 1}, +#endif + R1Spec{10, 2, 4, 2}, + R1Spec{10, 0, 10, 2}, + R1Spec{10, 0, 10, 3}, + R1Spec{10, 0, 10, 4}, + R1Spec{10, 0, 10, 5}, + R1Spec{10, 0, 10, 10}, + R1Spec{500, 200, 400, 7}, + R1Spec{4096, 1, 4095, 3}, + R1Spec{2047, 1024 - 24, 1024 + 160, 31}, + R1Spec{2047, 1, 2046, 3 * 128}, + R1Spec{4096, 1024 + 3, 4095, 500}, + R1Spec{8192, 0, 8192, 1024 * 3 + 400} + ), + SliceR1TestDataToString ); +// clang-format on struct R2Spec { int64 input_dim0; @@ -339,7 +382,7 @@ struct R4Spec { string R4SpecToString(const ::testing::TestParamInfo& data) { const R4Spec& spec = data.param; - return StrCat( // + return tensorflow::strings::StrCat( // "input_", Join(spec.input_dims, "x"), // "__layout_", Join(spec.input_layout, ""), // "__starts_", Join(spec.slice_starts, "x"), // diff --git a/tensorflow/compiler/xla/tests/test_macros.h b/tensorflow/compiler/xla/tests/test_macros.h index 28a2d0198a707cec1aa5e0fbed341ee9b2a927f7..cc4eaf62f50d1fa622c705fab810fe1e1b0fbf08 100644 --- a/tensorflow/compiler/xla/tests/test_macros.h +++ b/tensorflow/compiler/xla/tests/test_macros.h @@ -36,6 +36,7 @@ limitations under the License. #define DISABLED_ON_CPU(X) X #define DISABLED_ON_CPU_PARALLEL(X) X #define DISABLED_ON_GPU(X) X +#define DISABLED_ON_INTERPRETER(X) X // We need this macro instead of pasting directly to support nesting // the DISABLED_ON_FOO macros, as in the definition of DISABLED_ON_CPU. @@ -62,6 +63,11 @@ limitations under the License. 
# define DISABLED_ON_GPU(X) XLA_TEST_PASTE(DISABLED_, X) #endif // XLA_TEST_BACKEND_GPU +#ifdef XLA_TEST_BACKEND_INTERPRETER +# undef DISABLED_ON_INTERPRETER +# define DISABLED_ON_INTERPRETER(X) XLA_TEST_PASTE(DISABLED_, X) +#endif // XLA_TEST_BACKEND_INTERPRETER + // clang-format on namespace xla { diff --git a/tensorflow/compiler/xla/tests/test_utils.cc b/tensorflow/compiler/xla/tests/test_utils.cc index 0d56c9f48363d0569921d7c76050dcc66208931b..b060fb13b1451aab30cfca73bea0a4a598a9fa3a 100644 --- a/tensorflow/compiler/xla/tests/test_utils.cc +++ b/tensorflow/compiler/xla/tests/test_utils.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_verifier.h" #include "tensorflow/compiler/xla/service/transfer_manager.h" @@ -23,118 +24,117 @@ namespace xla { namespace { template -void PopulateWithRandomFloatingPointData(Literal* literal) { +void PopulateWithRandomFloatingPointData(Literal* literal, + std::minstd_rand0* engine) { CHECK_EQ(literal->shape().element_type(), primitive_util::NativeToPrimitiveType()); - std::minstd_rand0 engine; - std::uniform_real_distribution generator(0.0f, 1.0f); + // Create uniform numbers between 1 and 1.125 to avoid creating denormal + // numbers. + std::uniform_real_distribution generator(1.0f, 1.125f); + const bool should_index_bias = ShapeUtil::ElementsIn(literal->shape()) > 1000; TF_CHECK_OK(literal->Populate( + [&](tensorflow::gtl::ArraySlice indices) { + // Generate a random uniform number between -0.0625 and 0.0625 and bias it + // with a position dependent number with mean 0.037109375. These numbers + // should allow for long chains of accumulation without being too close + // to zero or too large to accumulate all numbers accurately. 
Only do + // this for large literals where the number of elements is much greater + // than 47 otherwise only negative values are produced. + // + // The value is positionally biased using a product of the indices. Add + // one to each index value to avoid collapsing to zero if any of the + // indices are zero. + int64 index_product = 1; + for (int64 i : indices) { + index_product *= (1 + i); + } + const int64 negative_bias = should_index_bias ? 47 : 0; + FloatT index_bias = + static_cast(index_product % 113 - negative_bias) / + static_cast(256.0f); + return (generator(*engine) - 1.0625) + index_bias; + })); +} + +// The standard library does not have a case for bfloat16, unsurprisingly, so we +// handle that one specially. +template <> +void PopulateWithRandomFloatingPointData(Literal* literal, + std::minstd_rand0* engine) { + CHECK_EQ(literal->shape().element_type(), BF16); + std::uniform_real_distribution generator(-0.9f, 1.0f); + TF_CHECK_OK(literal->Populate( [&](tensorflow::gtl::ArraySlice /*indices*/) { - return generator(engine); + return static_cast(generator(*engine)); })); } template -void PopulateWithRandomIntegralData(Literal* literal) { +void PopulateWithRandomIntegralData(Literal* literal, + std::minstd_rand0* engine) { CHECK_EQ(literal->shape().element_type(), primitive_util::NativeToPrimitiveType()); - std::minstd_rand0 engine; std::uniform_int_distribution generator( std::numeric_limits::lowest(), std::numeric_limits::max()); TF_CHECK_OK(literal->Populate( [&](tensorflow::gtl::ArraySlice /*indices*/) { - return generator(engine); + return generator(*engine); })); } -bool LooksLikeSum(const HloInstruction& instruction) { - return instruction.opcode() == HloOpcode::kAdd && - instruction.operand(0)->opcode() == HloOpcode::kParameter && - instruction.operand(1)->opcode() == HloOpcode::kParameter && - instruction.operand(0) != instruction.operand(1); -} - -// Given an instruction and operand number, replace the given operand with -// a Literal Constant 
Zero. Handle the case of a fusion instruction by -// replacing the fusion's parent's parameter with a Literal Constant Zero, -// unless the fusion's parent is itself a fusion. -Status MaybeReplaceParameterInputWithZero(HloInstruction* const instruction, - const int64 operand_number) { - CHECK_LT(operand_number, instruction->operand_count()); - if (instruction->operand(operand_number)->opcode() != HloOpcode::kParameter) { - return Status::OK(); - } - - HloComputation* const computation = instruction->parent(); - std::unique_ptr zero = HloInstruction::CreateConstant( - MakeUnique(Literal::Zero(instruction->shape().element_type()))); - - if (computation->IsFusionComputation()) { - HloInstruction* const fusion_instruction = computation->FusionInstruction(); - if (fusion_instruction->IsFused()) { - return Unimplemented( - "Unable to replace fused parameter of fusion instruction"); - } - TF_RETURN_IF_ERROR(fusion_instruction->ReplaceOperandWith( - instruction->operand(operand_number)->parameter_number(), - fusion_instruction->parent()->AddInstruction(std::move(zero)))); - } else { - TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith( - operand_number, computation->AddInstruction(std::move(zero)))); - } - return Status::OK(); -} - -} // namespace - -StatusOr> MakeFakeLiteral(const Shape& shape) { +// Similar to MakeFakeLiteral but takes a random number generator engine to +// enable reusing the engine across randomly generated literals. 
+StatusOr> MakeFakeLiteralInternal( + const Shape& shape, std::minstd_rand0* engine) { if (ShapeUtil::IsTuple(shape)) { std::vector> elements; for (const Shape& element_shape : shape.tuple_shapes()) { TF_ASSIGN_OR_RETURN(std::unique_ptr element, - MakeFakeLiteral(element_shape)); + MakeFakeLiteralInternal(element_shape, engine)); elements.push_back(std::move(element)); } return Literal::MakeTupleOwned(std::move(elements)); } std::unique_ptr literal = Literal::CreateFromShape(shape); switch (shape.element_type()) { + case BF16: + PopulateWithRandomFloatingPointData(literal.get(), engine); + break; case F32: - PopulateWithRandomFloatingPointData(literal.get()); + PopulateWithRandomFloatingPointData(literal.get(), engine); break; case F64: - PopulateWithRandomFloatingPointData(literal.get()); + PopulateWithRandomFloatingPointData(literal.get(), engine); break; case S8: - PopulateWithRandomIntegralData(literal.get()); + PopulateWithRandomIntegralData(literal.get(), engine); break; case U8: - PopulateWithRandomIntegralData(literal.get()); + PopulateWithRandomIntegralData(literal.get(), engine); break; case S16: - PopulateWithRandomIntegralData(literal.get()); + PopulateWithRandomIntegralData(literal.get(), engine); break; case U16: - PopulateWithRandomIntegralData(literal.get()); + PopulateWithRandomIntegralData(literal.get(), engine); break; case S32: - PopulateWithRandomIntegralData(literal.get()); + PopulateWithRandomIntegralData(literal.get(), engine); break; case U32: - PopulateWithRandomIntegralData(literal.get()); + PopulateWithRandomIntegralData(literal.get(), engine); break; case S64: - PopulateWithRandomIntegralData(literal.get()); + PopulateWithRandomIntegralData(literal.get(), engine); break; case U64: - PopulateWithRandomIntegralData(literal.get()); + PopulateWithRandomIntegralData(literal.get(), engine); break; case PRED: { std::uniform_int_distribution generator(0, 1); - std::minstd_rand0 engine; TF_CHECK_OK(literal->Populate( 
[&](tensorflow::gtl::ArraySlice /*indices*/) { - return generator(engine); + return generator(*engine); })); break; } @@ -145,43 +145,162 @@ StatusOr> MakeFakeLiteral(const Shape& shape) { return std::move(literal); } -StatusOr>> MakeFakeArguments( - const HloModule& module) { - std::vector> arguments; - for (const ShapeLayout& shape_layout : - module.config().entry_computation_layout().parameter_layouts()) { - TF_ASSIGN_OR_RETURN(auto literal, MakeFakeLiteral(shape_layout.shape())); - arguments.push_back(std::move(literal)); +// Matches binary addition computations. +bool LooksLikeSum(const HloComputation& computation) { + const HloInstruction* const root = computation.root_instruction(); + return root->opcode() == HloOpcode::kAdd && + computation.num_parameters() == 2 && + root->operand(0)->opcode() == HloOpcode::kParameter && + root->operand(1)->opcode() == HloOpcode::kParameter && + root->operand(0) != root->operand(1); +} + +// Reduce, ReduceWindow, and SelectAndScatter ops may use binary addition, +// which requires an init_value of 0 rather than a random value. +bool NeedsZeroInitValue(const HloUse& use) { + const HloInstruction* const instruction = use.instruction; + const HloOpcode opcode = instruction->opcode(); + const int64 op_num = use.operand_number; + return ( + ((opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow) && + op_num == 1 && LooksLikeSum(*instruction->to_apply())) || + (opcode == HloOpcode::kSelectAndScatter && op_num == 2 && + LooksLikeSum(*instruction->scatter()))); +} + +// Generate random values that are constrained to the input_shape minus the +// output_shape so as not to produce wrapping slices, for instance. 
+std::unique_ptr MakeRandomNonwrappingSliceIndex( + const Shape& input_shape, const Shape& slice_shape, + std::minstd_rand0* engine) { + const int64 rank = ShapeUtil::Rank(input_shape); + std::vector start_indices(rank); + for (int i = 0; i < rank; ++i) { + const int32 upper_bound = ShapeUtil::GetDimension(input_shape, i) - + ShapeUtil::GetDimension(slice_shape, i); + std::uniform_int_distribution generator(0, upper_bound); + start_indices[i] = generator(*engine); } - return std::move(arguments); + return Literal::CreateR1(start_indices); } -Status ReplaceInitsWithConstants(HloModule* const module) { - for (HloComputation* const computation : module->computations()) { - for (HloInstruction* const instruction : computation->instructions()) { +// Use dataflow analysis on each parameter to see if there are uses that would +// be problematic when generating input data. Returns the list of instructions +// that correspond to their uses. +// +// Should be paired with the CreateLiteralForConstrainedUses() function below. 
+std::vector FindConstrainedUses( + const HloDataflowAnalysis& dataflow, const HloInstruction& param) { + std::vector constrained_uses; + for (const auto& pair : dataflow.GetInstructionValueSet(&param)) { + const HloValue& value = dataflow.GetUniqueValueAt(&param, pair.first); + for (const HloUse& use : value.uses()) { + HloInstruction* instruction = use.instruction; const HloOpcode opcode = instruction->opcode(); - if ((opcode == HloOpcode::kReduce || - opcode == HloOpcode::kReduceWindow) && - LooksLikeSum(*instruction->to_apply()->root_instruction())) { - TF_RETURN_IF_ERROR(MaybeReplaceParameterInputWithZero(instruction, 1)); - } else if (opcode == HloOpcode::kSelectAndScatter && - LooksLikeSum(*instruction->scatter()->root_instruction())) { - TF_RETURN_IF_ERROR(MaybeReplaceParameterInputWithZero(instruction, 2)); + const int64 op_num = use.operand_number; + if ((opcode == HloOpcode::kDynamicSlice && op_num == 1) || + (opcode == HloOpcode::kDynamicUpdateSlice && op_num == 2)) { + constrained_uses.push_back(instruction); + } else if (opcode == HloOpcode::kFusion) { + const HloInstruction* const to_analyze = + instruction->fused_parameter(op_num); + auto fused_uses = FindConstrainedUses(dataflow, *to_analyze); + constrained_uses.insert(constrained_uses.end(), fused_uses.begin(), + fused_uses.end()); + } else if (NeedsZeroInitValue(use)) { + constrained_uses.push_back(instruction); + } else if (opcode == HloOpcode::kConvert || + opcode == HloOpcode::kReducePrecision) { + auto converted_uses = FindConstrainedUses(dataflow, *instruction); + constrained_uses.insert(constrained_uses.end(), converted_uses.begin(), + converted_uses.end()); } } } - return Status::OK(); + return constrained_uses; +} + +// Given a parameter, generate a random Literal to use as input if there exist +// no constrained uses in the dataflow graph. If such constraints exist, +// generate a constrained literal (either bounded in the case of indices, or +// zero in the case of init_values for reductions). 
+StatusOr> CreateLiteralForConstrainedUses( + const tensorflow::gtl::ArraySlice constrained_uses, + const HloInstruction& param, std::minstd_rand0* engine) { + HloInstruction* needs_index = nullptr; + HloInstruction* needs_zero = nullptr; + for (HloInstruction* use : constrained_uses) { + switch (use->opcode()) { + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: + if (needs_index != nullptr && + !ShapeUtil::Equal(needs_index->shape(), use->shape())) { + return Unimplemented( + "Conflicting operand generation slice index constraints\n"); + } + needs_index = use; + break; + + case HloOpcode::kReduce: + case HloOpcode::kReduceWindow: + case HloOpcode::kSelectAndScatter: + needs_zero = use; + break; + + default: + return Unimplemented( + "Constrained operand generation not implemented for %s.", + use->ToString().c_str()); + } + } + if (needs_index != nullptr && needs_zero != nullptr) { + return Unimplemented( + "Conflicting operand generation constraints.\nNeeds index: %s\nNeeds " + "zero: %s\n", + needs_index->ToString().c_str(), needs_zero->ToString().c_str()); + } + if (needs_index != nullptr) { + return MakeRandomNonwrappingSliceIndex(needs_index->operand(0)->shape(), + needs_index->shape(), engine); + } else if (needs_zero != nullptr) { + return Literal::CreateFromShape(param.shape()); + } else { + return MakeFakeLiteralInternal(param.shape(), engine); + } +} + +// Given a module entry parameter, use the dataflow analysis to see if a +// special case literal must be created, or if we can generate fake data. 
+StatusOr> MakeConstrainedArgument( + const HloDataflowAnalysis& dataflow, const HloInstruction& param, + std::minstd_rand0* engine) { + const auto constrained_uses = FindConstrainedUses(dataflow, param); + return CreateLiteralForConstrainedUses(constrained_uses, param, engine); +} + +} // namespace + +StatusOr> MakeFakeLiteral(const Shape& shape) { + std::minstd_rand0 engine; + return MakeFakeLiteralInternal(shape, &engine); +} + +StatusOr>> MakeFakeArguments( + HloModule* const module) { + TF_ASSIGN_OR_RETURN(auto dataflow, HloDataflowAnalysis::Run(module)); + const auto params = module->entry_computation()->parameter_instructions(); + std::minstd_rand0 engine; + std::vector> arguments(params.size()); + for (int i = 0; i < params.size(); ++i) { + TF_ASSIGN_OR_RETURN( + arguments[i], MakeConstrainedArgument(*dataflow, *params[i], &engine)); + } + return std::move(arguments); } Status VerifyHloModule(const perftools::gputools::Platform& platform, HloModule* const module) { - return HloVerifier( - std::bind( - &TransferManager::GetByteSizeRequirement, - TransferManager::GetForPlatform(&platform).ConsumeValueOrDie(), - std::placeholders::_1)) - .Run(module) - .status(); + return HloVerifier().Run(module).status(); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/test_utils.h b/tensorflow/compiler/xla/tests/test_utils.h index 9aca162a185e5b22888229555b7bce88769c79a6..0fb024ffb074f1c90b75022bc7f5a8b58b03c0c2 100644 --- a/tensorflow/compiler/xla/tests/test_utils.h +++ b/tensorflow/compiler/xla/tests/test_utils.h @@ -60,13 +60,11 @@ StatusOr> MakeFakeLiteral(const Shape& shape); // Generates a vector of arguments containing fake data. The number, shape and // layout of the arguments is appropriate for given HLO module. +// +// Will handle special cases such as making sure that indices used for dynamic +// slices are bounded, reduces that call adds use 0 as an init value, etc. 
StatusOr>> MakeFakeArguments( - const HloModule& module); - -// Reductions using Adds, ReduceWindow, and SelectAndScatter, require their -// init_value to be replaced with the constant 0.0f when testing, otherwise we -// may generate a bad init_value when looking at the op in isolation. -Status ReplaceInitsWithConstants(HloModule* const module); + HloModule* const module); // Check that a given module satisfies various constraints before trying to // execute it. diff --git a/tensorflow/compiler/xla/tests/transfer_manager_test.cc b/tensorflow/compiler/xla/tests/transfer_manager_test.cc index f2a64749482e5f5a8c5d72034fb7a4eee07baf48..268ba338f2e6740a1d1a046d5a85494f3cf2e9f8 100644 --- a/tensorflow/compiler/xla/tests/transfer_manager_test.cc +++ b/tensorflow/compiler/xla/tests/transfer_manager_test.cc @@ -46,9 +46,10 @@ class TransferManagerTest : public LocalClientTestBase { ~TransferManagerTest() override = default; std::unique_ptr AllocateDeviceBuffer(const Shape& shape) { - return ScopedShapedBuffer::Allocate( - shape, GetOrCreateAllocator(local_client_->platform()), - /*device_ordinal=*/0, shape_size_fn_) + return transfer_manager_ + ->AllocateScopedShapedBuffer( + shape, GetOrCreateAllocator(local_client_->platform()), + /*device_ordinal=*/0) .ValueOrDie(); } @@ -118,7 +119,7 @@ XLA_TEST_F(TransferManagerTest, TransferR1U8) { transfer_manager_->TransferLiteralFromDevice( stream_executor_, *device_buffer)); - EXPECT_EQ(result->u8s_string(), test_string); + EXPECT_EQ(result->GetR1U8AsString(), test_string); } XLA_TEST_F(TransferManagerTest, TransferR2F32) { @@ -211,5 +212,39 @@ XLA_TEST_F(TransferManagerTest, TransferNestedTuple) { LiteralTestUtil::ExpectEqual(*literal, *result); } +XLA_TEST_F(TransferManagerTest, TransferComplexValue) { + std::unique_ptr literal = Literal::CreateR1( + {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)}); + auto device_buffer = AllocateDeviceBuffer(literal->shape()); + + // Round trip literal through device. 
+ ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( + stream_executor_, *literal, *device_buffer)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice( + stream_executor_, *device_buffer)); + + LiteralTestUtil::ExpectEqual(*literal, *result); +} + +XLA_TEST_F(TransferManagerTest, TransferComplexValueInTuple) { + std::unique_ptr literal = Literal::MakeTuple( + {Literal::CreateR1( + {complex64(1.0f, 2.0f), complex64(42.0f, -123.4f)}) + .get(), + Literal::CreateR1({1, 2, 3, 4, 5, 6}).get(), + Literal::CreateR0(complex64(0.3f, -0.4f)).get()}); + auto device_buffer = AllocateDeviceBuffer(literal->shape()); + + // Round trip literal through device. + ASSERT_IS_OK(transfer_manager_->TransferLiteralToDevice( + stream_executor_, *literal, *device_buffer)); + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + transfer_manager_->TransferLiteralFromDevice( + stream_executor_, *device_buffer)); + + LiteralTestUtil::ExpectEqual(*literal, *result); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index 5a012c93d64f6a6fca73aa422e20cf238c945ce9..2029312f94a14bc81706368b9ecfc2727fd9fe4c 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -57,6 +57,20 @@ XLA_TEST_F(TupleTest, TupleConstant) { ComputeAndCompareTuple(&builder, *value, {}, error_spec_); } +// Tests a tuple made of scalar constants. +XLA_TEST_F(TupleTest, TupleScalarConstant) { + ComputationBuilder builder(client_, TestName()); + + const float constant_scalar1 = 7.3f; + const float constant_scalar2 = 1.2f; + auto value = + Literal::MakeTuple({Literal::CreateR0(constant_scalar1).get(), + Literal::CreateR0(constant_scalar2).get()}); + + auto result = builder.ConstantLiteral(*value); + ComputeAndCompareTuple(&builder, *value, {}, error_spec_); +} + // Tests the creation of tuple data. 
XLA_TEST_F(TupleTest, TupleCreate) { ComputationBuilder builder(client_, TestName()); @@ -180,8 +194,7 @@ XLA_TEST_F(TupleTest, TupleGTEToTuple) { ComputeAndCompareTuple(&builder, *expected, {}, error_spec_); } -// TODO(b/68395210): GPU does not tolerate ambiguous top-level buffers. -XLA_TEST_F(TupleTest, DISABLED_ON_GPU(SelectBetweenPredTuples)) { +XLA_TEST_F(TupleTest, SelectBetweenPredTuples) { ComputationBuilder b(client_, TestName()); ComputationDataHandle v1, v2; @@ -445,5 +458,61 @@ XLA_TEST_F(TupleTest, GetTupleElementOfNestedTuple) { ComputeAndCompareR1(&builder, expected, arguments, ErrorSpec(1e-5)); } +XLA_TEST_F(TupleTest, ComplexTuples) { + ComputationBuilder builder(client_, TestName()); + { + Shape c64r0 = ShapeUtil::MakeShape(C64, {}); + Shape c64r1 = ShapeUtil::MakeShape(C64, {2}); + Shape c64r2 = ShapeUtil::MakeShape(C64, {3, 2}); + Shape arg0_shape = ShapeUtil::MakeTupleShape( + {c64r0, ShapeUtil::MakeTupleShape({c64r1, c64r2})}); + auto input0 = builder.Parameter(0, arg0_shape, "input0"); + auto t0 = builder.GetTupleElement(input0, 0); + auto t1 = builder.GetTupleElement(input0, 1); + auto t10 = builder.GetTupleElement(t1, 0); + auto t11 = builder.GetTupleElement(t1, 1); + auto sum = builder.Add(builder.Add(t10, t11, {1}), t0); + auto input1 = builder.Parameter(1, c64r1, "input1"); + auto prod = builder.Mul(input1, sum, {1}); + builder.Tuple({builder.Tuple({prod, sum}), + builder.ConstantR0({123, 456})}); + } + + std::unique_ptr arg0 = + client_ + ->TransferToServer(*Literal::MakeTuple( + {Literal::CreateR0({1, 2}).get(), + Literal::MakeTuple( + {Literal::CreateR1({{10, 20}, {30, 40}}).get(), + Literal::CreateR2( + {{{100, 200}, {300, 400}}, + {{1000, 2000}, {3000, 4000}}, + {{10000, 20000}, {30000, 40000}}}) + .get()}) + .get()})) + .ConsumeValueOrDie(); + std::unique_ptr arg1 = + client_ + ->TransferToServer(*Literal::CreateR1({{1, 2}, {1, -2}})) + .ConsumeValueOrDie(); + auto sum = Literal::CreateR2({{{111, 222}, {331, 442}}, + {{1011, 2022}, 
{3031, 4042}}, + {{10011, 20022}, {30031, 40042}}}); + auto prod = Literal::CreateFromShape(sum->shape()); + ASSERT_TRUE(prod->Populate( + [&sum](tensorflow::gtl::ArraySlice indexes) { + return sum->Get(indexes) * + (indexes[indexes.size() - 1] == 0 + ? complex64(1, 2) + : complex64(1, -2)); + }) + .ok()); + auto expected = + Literal::MakeTuple({Literal::MakeTuple({prod.get(), sum.get()}).get(), + Literal::CreateR0({123, 456}).get()}); + ComputeAndCompareTuple(&builder, *expected, {arg0.get(), arg1.get()}, + error_spec_); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc index fa4192e9281784a4a3063601afe89fba6a9dac18..835e2d7e5594d7c8c6e523f9806e32dce23a87e9 100644 --- a/tensorflow/compiler/xla/tests/unary_op_test.cc +++ b/tensorflow/compiler/xla/tests/unary_op_test.cc @@ -215,5 +215,23 @@ XLA_TEST_F(UnaryOpTest, SignAbsTestR2) { ComputeAndCompareR2(&builder, {{0, 0}, {0, 0}}, {}); } +XLA_TEST_F(UnaryOpTest, ConvertElementTypePredToS32) { + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1({0, 1}); + auto rhs = builder.ConstantR1({1, 1}); + builder.ConvertElementType(builder.Eq(lhs, rhs), S32); + + ComputeAndCompareR1(&builder, {0, 1}, {}); +} + +XLA_TEST_F(UnaryOpTest, ConvertElementTypePredToF32) { + ComputationBuilder builder(client_, TestName()); + auto lhs = builder.ConstantR1({0, 1}); + auto rhs = builder.ConstantR1({1, 1}); + builder.ConvertElementType(builder.Eq(lhs, rhs), F32); + + ComputeAndCompareR1(&builder, {0.0, 1.0}, {}); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index 49f673f5f0bf9b844ab4030383784208b4e2c58a..52157b837c383205f77a030ef98b2fd03a41aff5 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -357,8 +357,7 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { 
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } -// TODO(b/63003356): 11-06-2017: fails on all back-ends with incorrect result. -TEST_F(WhileTest, DISABLED_WhileWithPermutationAndTupleResult) { +TEST_F(WhileTest, WhileWithPermutationAndTupleResult) { std::vector shape_elements = { ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3})}; @@ -411,8 +410,7 @@ TEST_F(WhileTest, DISABLED_WhileWithPermutationAndTupleResult) { ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } -// TODO(b/63003356): 11-06-2017: fails on all back-ends with incorrect result. -TEST_F(WhileTest, DISABLED_WhileWithPermutationAndVectorResult) { +TEST_F(WhileTest, WhileWithPermutationAndVectorResult) { std::vector shape_elements = { ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3})}; @@ -565,6 +563,53 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) { ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0)); } +TEST_F(WhileTest, WhileWithTupleConstantScalarResult) { + std::vector shape_elements = {ShapeUtil::MakeShape(S32, {}), + ShapeUtil::MakeShape(S32, {})}; + Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); + + // Create a computation for the condition. + // Repeat for 5 iterations. + Computation condition; + { + ComputationBuilder builder(client_, "condition"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + builder.Gt(builder.ConstantR0(5), iteration); + condition = builder.Build().ConsumeValueOrDie(); + } + + // Create a computation for the body. + // Add 1 to the iteration variable and set the other tuple element to a + // constant. 
+ Computation body; + { + ComputationBuilder builder(client_, "body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + auto result = + builder.Tuple({builder.Add(iteration, builder.ConstantR0(1)), + builder.ConstantR0(7)}); + body = builder.Build().ConsumeValueOrDie(); + } + + // Create a While node with computations for the condition and the body. + ComputationBuilder builder(client_, "while"); + auto init = builder.Tuple( + {builder.ConstantR0(0), builder.ConstantR0(7)}); + auto result = builder.While(condition, body, init); + VLOG(2) << "while = " + << ShapeUtil::HumanString( + *builder.GetShape(result).ConsumeValueOrDie()); + + auto expected_counter = Literal::CreateR0(5); + auto expected_data = Literal::CreateR0(7); + auto expected = + Literal::MakeTuple({expected_counter.get(), expected_data.get()}); + VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); + ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); +} + // Tests two while nodes when the result type T is a Tuple and the second // while node uses the result of the first while node which is used in two // nodes. @@ -913,8 +958,7 @@ TEST_F(WhileTest, WhileWithPrngScalarResult) { } } -// TODO(b/34969189) Fails with bad AtomicCmpSwap on GPU on 2017-09-11. -TEST_F(WhileTest, DISABLED_ON_GPU(WhileThatSwapsParameterWithTupleElement)) { +TEST_F(WhileTest, WhileThatSwapsParameterWithTupleElement) { auto element_shape = ShapeUtil::MakeShape(F32, {2}); ComputationBuilder outer(client_, "outer"); @@ -950,8 +994,7 @@ TEST_F(WhileTest, DISABLED_ON_GPU(WhileThatSwapsParameterWithTupleElement)) { ErrorSpec(1e-6)); } -// TODO(b/34969189) Fails with bad AtomicCmpSwap on GPU on 2017-09-11. 
-TEST_F(WhileTest, DISABLED_ON_GPU(WhileThatSwapsParameterWithBroadcast)) { +TEST_F(WhileTest, WhileThatSwapsParameterWithBroadcast) { auto element_shape = ShapeUtil::MakeShape(F32, {2}); ComputationBuilder outer(client_, "outer"); @@ -1164,6 +1207,50 @@ TEST_F(WhileTest, WhileWithCallInsideCondition) { ComputeAndCompareR0(&builder, 5, {}); } +TEST_F(WhileTest, WhileWithLoopInvariantOperation) { + auto matrix_shape = ShapeUtil::MakeShape(F32, {2, 2}); + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + auto while_shape = ShapeUtil::MakeTupleShape( + {scalar_s32, matrix_shape, matrix_shape, matrix_shape}); + + // Create a computation for the condition: repeat for 5 iterations. + Computation condition; + { + ComputationBuilder builder(client_, "condition"); + auto state = builder.Parameter(0, while_shape, "state"); + builder.Gt(builder.ConstantR0(5), builder.GetTupleElement(state, 0)); + TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); + } + + Computation body; + { + ComputationBuilder builder(client_, "body"); + auto state = builder.Parameter(0, while_shape, "state"); + auto indvar = builder.GetTupleElement(state, 0); + auto input_0 = builder.GetTupleElement(state, 1); + auto input_1 = builder.GetTupleElement(state, 2); + auto output = builder.Tanh(builder.Dot(input_0, input_1)); + auto indvar_next = builder.Add(indvar, builder.ConstantR0(1)); + auto tuple_result = builder.Tuple({indvar_next, input_0, input_1, output}); + TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); + } + + ComputationBuilder builder(client_, TestName()); + auto matrix_input = builder.Parameter(0, matrix_shape, "matrix"); + auto init = builder.Tuple( + {builder.ConstantR0(0), matrix_input, matrix_input, matrix_input}); + auto while_instruction = builder.While(condition, body, init); + builder.GetTupleElement(while_instruction, 3); + + TF_ASSERT_OK_AND_ASSIGN(auto param_value, + client_->TransferToServer(*Literal::CreateR2( + {{1.0, 2.0}, {-1.0, -2.0}}))); + + ComputeAndCompareR2( + &builder, 
{{-0.76159416, -0.96402758}, {0.76159416, 0.96402758}}, + {param_value.get()}, ErrorSpec(4e-5)); +} + void BM_WhileLoop(int num_iters) { // Benchmark a simple kernel to measure while loop overheads. tensorflow::testing::StopTiming(); diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..9ad2a1985331b80625dd0687ea052300bc99e440 --- /dev/null +++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc @@ -0,0 +1,362 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "tensorflow/compiler/xla/array2d.h" +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/service/platform_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/client_library_test_base.h" +#include "tensorflow/compiler/xla/tests/test_macros.h" +#include "tensorflow/compiler/xla/tests/test_utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/gtl/flatmap.h" +#include "tensorflow/core/platform/regexp.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace { +namespace se = ::perftools::gputools; +namespace gtl = ::tensorflow::gtl; + +class HloProfileTest : public ClientLibraryTestBase {}; + +struct ParsedProfileOutputLine { + int64 cycles; + string cycles_percentage; + double usec; + string flops; + string trops; + string bytes_per_sec; + string bytes_per_cycle; + string opcode; +}; + +::testing::AssertionResult HasFlops( + const ParsedProfileOutputLine& parsed_line) { + if (RE2::FullMatch(parsed_line.flops, "[0-9.TGMk]+FLOP/s")) { + return ::testing::AssertionSuccess() + << "'flops' field present in " << parsed_line.opcode << ": '" + << parsed_line.flops << "'"; + } + + return ::testing::AssertionFailure() + << "'flops' field absent in " << parsed_line.opcode << ": '" + << parsed_line.flops << "'"; +} + +::testing::AssertionResult HasTrops( + const ParsedProfileOutputLine& parsed_line) { + if (RE2::FullMatch(parsed_line.trops, "[0-9.TGMk]+TROP/s")) { + return ::testing::AssertionSuccess() + << "'trops' field present in " << parsed_line.opcode << ": '" + << parsed_line.trops << "'"; + } + + return ::testing::AssertionFailure() + << "'trops' field absent 
in " << parsed_line.opcode << ": '" + << parsed_line.trops << "'"; +} + +Status ParseOneProfileOutputLine( + const string& line, bool expect_hlo, + gtl::FlatMap* parsed_results) { + string separator = "[^:]*:: +"; + string match_percentage = "\\d+\\.\\d\\d%"; + string match_cycles = "(\\d+) cycles +\\( *(" + match_percentage + ")\\)"; + string match_usecs = "([0-9.]+) usec"; + string match_flops = "([^ ]+)"; + string match_trops = "([^ ]+)"; + string match_bytes_per_sec = "([0-9.TGMKi]+)B/s"; + string match_bytes_per_cycle = "([0-9.TGMKi]+)B/cycle"; + + // The underlined part is what we're trying to match with match_opcode: + // + // %dot33 = f32[256,256]{1,0} dot(...) + // ^^^ + + string match_opcode = + expect_hlo ? "%[^=]+= [^ ]+ ([^(]+)\\(.*" : "(\\[total\\])"; + string regexp_pattern = tensorflow::strings::StrCat( + " +", match_cycles, separator, match_usecs, separator, match_flops, + separator, match_trops, separator, match_bytes_per_sec, separator, + match_bytes_per_cycle, separator, match_opcode); + + ParsedProfileOutputLine parsed_line; + bool matched = RE2::FullMatch( + line, regexp_pattern, &parsed_line.cycles, &parsed_line.cycles_percentage, + &parsed_line.usec, &parsed_line.flops, &parsed_line.trops, + &parsed_line.bytes_per_sec, &parsed_line.bytes_per_cycle, + &parsed_line.opcode); + if (!matched) { + return tensorflow::errors::InvalidArgument( + "Input did not match regexp. Input: ", line, + ", Regexp: ", regexp_pattern); + } + + InsertOrDie(parsed_results, parsed_line.opcode, parsed_line); + + return Status::OK(); +} + +// Returns void so that we can ASSERT. 
+void ExecuteAndFetchProfile(string* profile_output, LocalClient* client, + const Computation& computation, + const Shape& lhs_arg_shape, + const Shape& rhs_arg_shape) { + LocalService* service = ClientLibrary::GetXlaService(client->platform()); + Backend* backend = service->mutable_backend(); + se::StreamExecutor* executor = backend->default_stream_executor(); + DeviceMemoryAllocator* allocator = backend->memory_allocator(); + auto* transfer_manager = backend->transfer_manager(); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr lhs_arg, + transfer_manager->AllocateScopedShapedBuffer( + lhs_arg_shape, allocator, backend->default_device_ordinal())); + TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice( + executor, *Literal::CreateFromShape(lhs_arg_shape), *lhs_arg)); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr rhs_arg, + transfer_manager->AllocateScopedShapedBuffer( + rhs_arg_shape, allocator, backend->default_device_ordinal())); + TF_ASSERT_OK(transfer_manager->TransferLiteralToDevice( + executor, *Literal::CreateFromShape(rhs_arg_shape), *rhs_arg)); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr local_executable, + client->Compile(computation, {&lhs_arg_shape, &rhs_arg_shape}, + ExecutableBuildOptions())); + + Executable* executable = local_executable->executable(); + HloExecutionProfile hlo_execution_profile( + &executable->hlo_profile_printer_data(), + &executable->hlo_profile_index_map()); + + TF_ASSERT_OK_AND_ASSIGN( + Backend::StreamPtr stream_ptr, + backend->BorrowStream(backend->default_device_ordinal())); + ExecutableRunOptions exec_run_options; + exec_run_options.set_stream(stream_ptr.get()); + exec_run_options.set_allocator(backend->memory_allocator()); + exec_run_options.set_intra_op_thread_pool( + backend->eigen_intra_op_thread_pool_device()); + ServiceExecutableRunOptions run_options( + exec_run_options, /*borrow_stream=*/nullptr, + backend->eigen_intra_op_thread_pool()); + TF_ASSERT_OK_AND_ASSIGN( + auto execution_result, + 
executable->ExecuteOnStream(&run_options, {lhs_arg.get(), rhs_arg.get()}, + &hlo_execution_profile)); + (void)execution_result; + + *profile_output = + hlo_execution_profile.ToString(executor->GetDeviceDescription()); + + XLA_VLOG_LINES(4, *profile_output); +} + +// TODO(b/71364943): This test exposes a bug in the parallel CPU backend. +XLA_TEST_F(HloProfileTest, DISABLED_ON_CPU_PARALLEL(ProfileSingleComputation)) { + const int64 m = 256, k = 256, n = 256; + Shape lhs_shape = ShapeUtil::MakeShape(F32, {m, k}); + Shape rhs_shape = ShapeUtil::MakeShape(F32, {m, k}); + + TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform, + PlatformUtil::GetDefaultPlatform()); + TF_ASSERT_OK_AND_ASSIGN(LocalClient * client, + ClientLibrary::GetOrCreateLocalClient(platform)); + + ComputationBuilder builder(client, TestName()); + auto result = builder.Tanh(builder.Add( + builder.Parameter(0, ShapeUtil::MakeShape(F32, {m, k}), "dot_lhs"), + builder.Parameter(1, ShapeUtil::MakeShape(F32, {k, n}), "dot_rhs"))); + + TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build()); + + string profile_output; + ExecuteAndFetchProfile(&profile_output, client, computation, lhs_shape, + rhs_shape); + + std::vector profile_output_lines = + tensorflow::str_util::Split(profile_output, '\n'); + + gtl::FlatMap parsed_profile_lines; + + TF_ASSERT_OK(ParseOneProfileOutputLine( + profile_output_lines[1], /*expect_hlo=*/false, &parsed_profile_lines)); + + TF_ASSERT_OK(ParseOneProfileOutputLine( + profile_output_lines[2], /*expect_hlo=*/true, &parsed_profile_lines)); + + TF_ASSERT_OK(ParseOneProfileOutputLine( + profile_output_lines[3], /*expect_hlo=*/true, &parsed_profile_lines)); + + TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine total_profile, + MaybeFind(parsed_profile_lines, "[total]")); + TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine dot_profile, + MaybeFind(parsed_profile_lines, "add")); + TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine tanh_profile, + MaybeFind(parsed_profile_lines, "tanh")); + + 
EXPECT_GT(total_profile.cycles, 0); + EXPECT_EQ(total_profile.cycles_percentage, "100.00%"); + + EXPECT_TRUE(HasFlops(total_profile)); + EXPECT_TRUE(HasTrops(total_profile)); + + EXPECT_GT(total_profile.cycles, dot_profile.cycles); + EXPECT_NE(dot_profile.cycles_percentage, "0.00%"); + EXPECT_NE(dot_profile.cycles_percentage, "100.00%"); + + EXPECT_TRUE(HasFlops(dot_profile)); + EXPECT_FALSE(HasTrops(dot_profile)); + + EXPECT_GT(total_profile.cycles, tanh_profile.cycles); + EXPECT_NE(tanh_profile.cycles_percentage, "0.00%"); + EXPECT_NE(tanh_profile.cycles_percentage, "100.00%"); + + EXPECT_FALSE(HasFlops(tanh_profile)); + EXPECT_TRUE(HasTrops(tanh_profile)); +} + +// TODO(b/71364943): This test exposes a bug in the parallel CPU backend. +// +// TODO(b/71544591): The GPU backend does not record cycles spent in on Hlo +// instructions "interior" to while nodes. +XLA_TEST_F(HloProfileTest, + DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL(ProfileWhileComputation))) { + const int64 size = 256; + Shape matrix_shape = ShapeUtil::MakeShape(F32, {size, size}); + Shape while_result_shape = + ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(S32, {}), matrix_shape}); + + TF_ASSERT_OK_AND_ASSIGN(se::Platform * platform, + PlatformUtil::GetDefaultPlatform()); + TF_ASSERT_OK_AND_ASSIGN(LocalClient * client, + ClientLibrary::GetOrCreateLocalClient(platform)); + + Computation condition; + { + ComputationBuilder builder(client, "condition"); + auto state = builder.Parameter(0, while_result_shape, "state"); + auto iteration = builder.GetTupleElement(state, 0); + builder.Gt(builder.ConstantR0(5), iteration); + TF_ASSERT_OK_AND_ASSIGN(condition, builder.Build()); + } + + Computation body; + { + ComputationBuilder builder(client, "body"); + auto state = builder.Parameter(0, while_result_shape, "state"); + auto matrix = builder.GetTupleElement(state, 1); + auto next_iteration = builder.Add(builder.GetTupleElement(state, 0), + builder.ConstantR0(1)); + builder.Tuple({next_iteration, 
builder.Add(matrix, matrix)}); + TF_ASSERT_OK_AND_ASSIGN(body, builder.Build()); + } + + ComputationBuilder builder(client, TestName()); + auto initial_while_state = + builder.Tuple({builder.ConstantR0(0), + builder.Parameter(0, matrix_shape, "initial_value")}); + auto while_result = builder.While(condition, body, initial_while_state); + builder.Add(builder.GetTupleElement(while_result, 1), + builder.Parameter(1, matrix_shape, "other_value")); + + TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build()); + + string profile_output; + ExecuteAndFetchProfile(&profile_output, client, computation, matrix_shape, + matrix_shape); + + std::vector profile_output_lines = + tensorflow::str_util::Split(profile_output, '\n'); + + auto while_body_profile_start = + std::find_if(profile_output_lines.begin(), profile_output_lines.end(), + [](tensorflow::StringPiece s) { + return s.starts_with("Execution profile for body"); + }); + + ASSERT_NE(while_body_profile_start, profile_output_lines.end()); + + gtl::FlatMap parsed_profile_lines; + + TF_ASSERT_OK( + ParseOneProfileOutputLine(*std::next(while_body_profile_start, 1), + /*expect_hlo=*/false, &parsed_profile_lines)); + + TF_ASSERT_OK( + ParseOneProfileOutputLine(*std::next(while_body_profile_start, 2), + /*expect_hlo=*/true, &parsed_profile_lines)); + + TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine total_while_body_profile, + MaybeFind(parsed_profile_lines, "[total]")); + TF_ASSERT_OK_AND_ASSIGN(ParsedProfileOutputLine dot_profile, + MaybeFind(parsed_profile_lines, "add")); + + EXPECT_GT(total_while_body_profile.cycles, 0); + EXPECT_EQ(total_while_body_profile.opcode, "[total]"); + EXPECT_EQ(total_while_body_profile.cycles_percentage, "100.00%"); + + EXPECT_GT(total_while_body_profile.cycles, dot_profile.cycles); + EXPECT_NE(dot_profile.cycles_percentage, "0.00%"); + EXPECT_NE(dot_profile.cycles_percentage, "100.00%"); +} +} // namespace +} // namespace xla + +static std::pair AddXlaHloProfileFlag(int argc, char** argv) { + // 
Intentional "leak". + char** new_argv = new char*[argc + 2]; + for (int i = 0; i < argc; i++) { + new_argv[i] = argv[i]; + } + + // We do it this way (as opposed to piping in a modified DebugOptions + // instance) for better end-to-end integration testing. + new_argv[argc] = strdup("--xla_hlo_profile"); + + // Fusion can change the Hlo instructions that show up in the final Hlo + // executable, so block it here. + new_argv[argc + 1] = strdup("--xla_disable_hlo_passes=fusion"); + return {argc + 2, new_argv}; +} + +GTEST_API_ int main(int argc, char** argv) { + std::vector flag_list; + xla::legacy_flags::AppendDebugOptionsFlags(&flag_list); + std::tie(argc, argv) = AddXlaHloProfileFlag(argc, argv); + + auto usage = tensorflow::Flags::Usage(argv[0], flag_list); + if (!tensorflow::Flags::Parse(&argc, argv, flag_list)) { + LOG(ERROR) << "\n" << usage; + return 2; + } + + testing::InitGoogleTest(&argc, argv); + if (argc > 1) { + LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage; + return 2; + } + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/compiler/xla/text_literal_reader.cc b/tensorflow/compiler/xla/text_literal_reader.cc index 4d060895d357493327ec50b38016478c65fef94d..6fa4c48e11d1102367b21bc21d4734466495ef0e 100644 --- a/tensorflow/compiler/xla/text_literal_reader.cc +++ b/tensorflow/compiler/xla/text_literal_reader.cc @@ -102,9 +102,9 @@ StatusOr> TextLiteralReader::ReadAllLines() { ShapeUtil::HumanString(shape).c_str()); } - auto result = MakeUnique(); + auto result = MakeUnique(shape); const float fill = std::numeric_limits::quiet_NaN(); - result->PopulateWithValue(fill, AsInt64Slice(shape.dimensions())); + result->PopulateWithValue(fill); std::vector pieces; std::vector coordinates; std::vector coordinate_values; diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc index 
5ede37b8737bd4fa6235464ddeb6382af17c8a80..b82f1c81c84b487c1661af5267b9123da97bb107 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc @@ -85,10 +85,12 @@ void RealMain(tensorflow::gtl::ArraySlice args) { for (int i = 0; i < program_shape->parameters_size(); ++i) { layouts.push_back(&program_shape->parameters(i)); } + ExecutableBuildOptions build_options; + build_options.set_device_ordinal(0); + build_options.set_result_layout(program_shape->result()); StatusOr> executable = local_service->CompileExecutable(computation.handle(), layouts, - &program_shape->result(), - /*device_ordinal=*/0); + build_options); const HloModule& module = executable.ValueOrDie()->module(); diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc index 78d8fb1f4330aed899ca917e66fae819a002b3a9..05c0fdf97d27c09eb2bbb0f265b5b2a5982ca7b1 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc @@ -60,16 +60,19 @@ void RealMain(tensorflow::gtl::ArraySlice args, bool compile) { for (int i = 0; i < program_shape->parameters_size(); ++i) { layouts.push_back(&program_shape->parameters(i)); } + + ExecutableBuildOptions build_options; + build_options.set_device_ordinal(0); + build_options.set_result_layout(program_shape->result()); StatusOr> executable = local_service->CompileExecutable(computation.handle(), layouts, - &program_shape->result(), - /*device_ordinal=*/0); + build_options); const HloModule& module = executable.ValueOrDie()->module(); fprintf(stdout, "HLO compiled for %s backend:\n%s\n", local_service->backend().platform()->Name().c_str(), - module.ToString().c_str()); + module.ToString(HloPrintOptions::ShortParsable()).c_str()); } else { const ComputationTracker& tracker = local_service->computation_tracker(); 
UserComputation* user_computation = @@ -80,7 +83,8 @@ void RealMain(tensorflow::gtl::ArraySlice args, bool compile) { tracker.BuildHloModule(versioned_handle, HloModuleConfig()) .ConsumeValueOrDie(); - fprintf(stdout, "%s\n", module->ToString().c_str()); + fprintf(stdout, "%s\n", + module->ToString(HloPrintOptions::ShortParsable()).c_str()); } } } diff --git a/tensorflow/compiler/xla/tools/hlo_proto_to_json.cc b/tensorflow/compiler/xla/tools/hlo_proto_to_json.cc index 4e02e17db65c0a4220672733be8319e1a0cc4f0f..8460ae3e4991ee091af72d2553a8491f627c722e 100644 --- a/tensorflow/compiler/xla/tools/hlo_proto_to_json.cc +++ b/tensorflow/compiler/xla/tools/hlo_proto_to_json.cc @@ -19,7 +19,7 @@ limitations under the License. // // Reads one serilized Hlo module, convert it into JSON format and dump into // some output directory. some_binaray_proto is obtained by serializing Hlo -// module to disk using --xla_dump_hlo_proto_to debug optoin. +// module to disk using --xla_dump_optimized_hlo_proto_to debug option. 
#include #include diff --git a/tensorflow/compiler/xla/tools/parser/BUILD b/tensorflow/compiler/xla/tools/parser/BUILD index ce936af6c3376387c1ed9fa48da23b8af537f6e5..97aacf6b39f83978e732060817cd93ede81ca782 100644 --- a/tensorflow/compiler/xla/tools/parser/BUILD +++ b/tensorflow/compiler/xla/tools/parser/BUILD @@ -34,9 +34,9 @@ cc_library( deps = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:lib", "//tensorflow/core:regexp_internal", ], diff --git a/tensorflow/compiler/xla/tools/parser/README.md b/tensorflow/compiler/xla/tools/parser/README.md index 6232967f5f04cbf316d985357ae84c28335531e2..f0f3dd7785c13e505e1eb6d4c8cd4bad157c4993 100644 --- a/tensorflow/compiler/xla/tools/parser/README.md +++ b/tensorflow/compiler/xla/tools/parser/README.md @@ -1,24 +1,26 @@ -# HloModule string syntax - -TODO: Support all subcomputations (for fusion, reduce, ...). - -TODO: Support all extra attributes, e.g. dimensions, strides. +# HLO Text Syntax ```yacc hlo_module : 'HloModule' name computations ; +/* If no computation is marked as ENTRY, the last computation will be the entry +computation of the module.*/ computations : computation | computation computations ; computation - : 'ENTRY' name param_list '->' shape instruction_list - | name param_list '->' shape instruction_list + : 'ENTRY' name param_list_to_shape instruction_list + | name param_list_to_shape instruction_list + | 'ENTRY' name instruction_list + | name instruction_list ; +/* If no instruction is marked as ROOT, the last instruction will be the root of +its computation. 
*/ instruction_list : '{' instruction_list1 '}' ; @@ -41,6 +43,7 @@ operands1 ; operand : shape name + | name ; attributes @@ -60,6 +63,10 @@ attribute_value | '{' sub_attributes '}' ; +param_list_to_shape + : param_list '->' shape + ; + param_list : '(' param_list1 ')' ; @@ -84,6 +91,7 @@ tuple_elements name : identifier ':' | '%' identifier + | identifier ; identifier @@ -108,7 +116,29 @@ non_tuple | rank2345 ; rank2345 - : shape nested_array + : shape sparse_or_nested_array + ; +sparse_or_nested_array + : sparse_array + | nested_array + ; +sparse_array + : '{' sparse_array1 '}' + ; +sparse_array1 + : sparse_array_item + | sparse_array1 ',' sparse_array_item + ; +sparse_array_item + : multi_index ':' scalar + ; +multi_index + : kInt + | '[' multi_index1 ']' + ; +multi_index1 + : kInt + | multi_index1 ',' kInt ; ``` diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc index 56744440db1b17aa1cc8823feb1bad279f8f4f75..fc0e4444521247734fc240a03da669244fe1a6a4 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc @@ -17,7 +17,6 @@ limitations under the License. #include -#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" @@ -153,21 +152,21 @@ TokKind HloLexer::LexToken() { } } -// Lex a shape, name, keyword, opcode, attribute name, or the dim labels -// pattern. +// Lex a shape, name, keyword, attribute name, the dim labels pattern, and +// other identifiers. // // shape ::= ([a-zA-Z0-9_]*[0-9]*)\[([0-9,]*)\](?:\s*{([0-9,]*)})? // name ::= [a-zA-Z_][a-zA-Z0-9_.-]*: // keyword ::= HloModule, ENTRY, ... -// opcode ::= add, greater-than, ... // attribute_name ::= condition, body, dimensions, ... 
// dim_labels_pattern ::= [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,} +// identifiers ::= other cases that match [a-zA-Z_][a-zA-Z0-9_.-]* TokKind HloLexer::LexIdentifier() { { auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); // 'consumable' will be advanced iff its prefix matches the pattern. static LazyRE2 shape_pattern = { - R"(^(\w*\d*)\[([\d,]*)\](?:{([\d,]*)})?)"}; + R"(^(\w*\d*)\[([\d,]*)\](?:(dense|sparse)?{([\d,]+)})?)"}; if (RE2::Consume(&consumable, *shape_pattern)) { auto status_or_shape = ShapeUtil::ParseShapeString( StringPieceFromPointers(token_start_, consumable.begin())); @@ -220,20 +219,6 @@ TokKind HloLexer::LexIdentifier() { #undef KEYWORD - // See if this is an opcode. - auto opcode = StringToHloOpcode(identifier.ToString()); - if (opcode.ok()) { - opcode_val_ = opcode.ValueOrDie(); - return TokKind::kOpcode; - } - - // See if this is an fusion kind. - auto kind = xla::StringToFusionKind(identifier.ToString()); - if (kind.ok()) { - fusion_kind_val_ = kind.ValueOrDie(); - return TokKind::kFusionKind; - } - { auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); static LazyRE2 dim_labels_pattern = { @@ -244,8 +229,9 @@ TokKind HloLexer::LexIdentifier() { return TokKind::kDimLabels; } } - current_ptr_ = token_start_ + 1; - return TokKind::kError; + + str_val_ = identifier.ToString(); + return TokKind::kIdent; } // Lex names after a % character. 
@@ -271,7 +257,8 @@ TokKind HloLexer::LexPercent() { // fp without exp ::= [-]?([0-9]+[.][0-9]*|[0-9]*[.][0-9]+) // dim_labels_pattern ::= [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,} // dxd_pattern ::= [0-9]+(x[0-9]+)+ -// pad_pattern ::= [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)* +// pad_pattern ::= +// [-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?(x[-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?)* // int ::= [-]?[0-9]+ // negative inf ::= '-inf' TokKind HloLexer::LexNumberOrPattern() { @@ -289,7 +276,7 @@ TokKind HloLexer::LexNumberOrPattern() { R"([0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,})"}; static LazyRE2 dxd_pattern = {R"([0-9]+(x[0-9]+)+)"}; static LazyRE2 pad_pattern = { - R"([0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)*)"}; + R"([-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?(x[-]?[0-9]+_[-]?[0-9]+(_[0-9]+)?)*)"}; if (RE2::Consume(&consumable, *dim_labels_pattern)) { current_ptr_ = consumable.begin(); @@ -326,18 +313,43 @@ TokKind HloLexer::LexNumberOrPattern() { return TokKind::kError; } -StringPiece HloLexer::GetCurrentLine() const { - const char* start = token_start_; - const char* end = current_ptr_; - if (!CanDereference(start) || !CanDereference(end)) { - return "LINE OUT OF RANGE"; +std::pair HloLexer::GetLineAndColumn(LocTy location) const { + unsigned line_no = 1; + const char* start = buf_.begin(); + const char* ptr = start; + if (line_no_cache_.last_query && CanDereference(line_no_cache_.last_query) && + line_no_cache_.last_query <= location) { + ptr = line_no_cache_.last_query; + line_no = line_no_cache_.line_no_of_query; } - while (start > buf_.begin() && *start != '\n') { - start--; + for (; ptr != location; ptr++) { + if (*ptr == '\n') { + line_no++; + } } - while (end < buf_.end() && *end != '\n') { - end++; + + // Update the line number cache. 
+ line_no_cache_.last_query = ptr; + line_no_cache_.line_no_of_query = line_no; + size_t line_offset = StringPieceFromPointers(start, ptr).rfind('\n'); + if (line_offset == StringPiece::npos) { + line_offset = 0; } + return {line_no, ptr - start - line_offset}; +} + +StringPiece HloLexer::GetLine(LocTy loc) const { + if (!CanDereference(loc)) { + return "LINE OUT OF RANGE"; + } + size_t line_start = + StringPieceFromPointers(buf_.begin(), loc + 1).rfind('\n'); + const char* start = line_start == StringPiece::npos + ? buf_.begin() + : buf_.begin() + line_start + 1; + size_t line_end = StringPieceFromPointers(loc, buf_.end()).find('\n'); + const char* end = line_end == StringPiece::npos ? buf_.end() : loc + line_end; + return StringPieceFromPointers(start, end); } @@ -428,14 +440,12 @@ string TokKindToString(TokKind kind) { return "kDxD"; case TokKind::kPad: return "kPad"; + case TokKind::kIdent: + return "kIdent"; case TokKind::kString: return "kString"; case TokKind::kShape: return "kShape"; - case TokKind::kOpcode: - return "kOpcode"; - case TokKind::kFusionKind: - return "kFusionKind"; case TokKind::kInt: return "kInt"; case TokKind::kDecimal: diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h b/tensorflow/compiler/xla/tools/parser/hlo_lexer.h index 5c9d1bf3912584040dc5260cc6730247d439fd60..27880b9b8afbfa58abfedc3b2cecd5236b78a6d6 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h +++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.h @@ -18,9 +18,8 @@ limitations under the License. 
#include -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/tools/parser/hlo_token.h" +#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/logging.h" @@ -48,6 +47,7 @@ class HloLexer { case TokKind::kDxD: case TokKind::kPad: case TokKind::kString: + case TokKind::kIdent: return str_val_; default: LOG(FATAL) << "This token does not have string value"; @@ -57,14 +57,6 @@ class HloLexer { CHECK(GetKind() == TokKind::kShape); return shape_val_; } - HloOpcode GetOpcodeVal() const { - CHECK(GetKind() == TokKind::kOpcode); - return opcode_val_; - } - HloInstruction::FusionKind GetFusionKindVal() const { - CHECK(GetKind() == TokKind::kFusionKind); - return fusion_kind_val_; - } int64 GetInt64Val() const { CHECK(GetKind() == TokKind::kInt); return int64_val_; @@ -74,8 +66,16 @@ class HloLexer { return decimal_val_; } - // Returns the line of text that is currently being lexed. - tensorflow::StringPiece GetCurrentLine() const; + typedef const char* LocTy; + + // Returns the location of the current token. + LocTy GetLoc() const { return token_start_; } + + // Returns the line and column of a location in the buffer. + std::pair GetLineAndColumn(LocTy location) const; + + // Returns the whole line given the location. + tensorflow::StringPiece GetLine(LocTy loc) const; private: // Returns the current character. If it's neither the end of input buffer nor @@ -114,10 +114,15 @@ class HloLexer { TokKind current_kind_; string str_val_; Shape shape_val_; - HloOpcode opcode_val_; - HloInstruction::FusionKind fusion_kind_val_; int64 int64_val_; double decimal_val_; + + struct LineNoCacheTy { + const char* last_query; + unsigned line_no_of_query; + }; + // This caches the line number of the previous query. 
+ mutable LineNoCacheTy line_no_cache_{nullptr, 0}; }; } // namespace tools diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index 47979ec6f361789f29e8f7ff47793747330551fc..89def5d5610cb9522a69297668b443b8c4e03fb5 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -40,6 +41,8 @@ const double kF16max = 65504; // Parser for the HloModule::ToString() format text. class HloParser { public: + using LocTy = HloLexer::LocTy; + explicit HloParser(StringPiece str, const HloModuleConfig& config) : lexer_(str), config_(config) {} @@ -56,7 +59,7 @@ class HloParser { // ParseXXX returns false if an error occurred. bool ParseHloModule(); bool ParseComputations(); - bool ParseComputation(); + bool ParseComputation(HloComputation** entry_computation); bool ParseInstructionList(HloComputation::Builder* builder, string* root_name); bool ParseInstruction(HloComputation::Builder* builder, string* root_name); @@ -65,6 +68,13 @@ class HloParser { bool ParseTupleLiteral(std::unique_ptr* literal, const Shape& shape); bool ParseNonTupleLiteral(std::unique_ptr* literal, const Shape& shape); + bool ParseDenseLiteral(std::unique_ptr* literal, const Shape& shape); + bool ParseSparseLiteral(std::unique_ptr* literal, + const Shape& shape); + template + bool ParseSparseLiteralHelper(std::unique_ptr* literal, + const Shape& shape); + // Sets the sub-value of literal at the given index to the given value. The // literal's shape must have the default layout. 
bool SetValueInLiteral(int64 value, int64 linear_index, Literal* literal); @@ -96,6 +106,7 @@ class HloParser { kString, kBracedInt64List, kHloComputation, + kFftType, kWindow, kConvolutionDimensionNumbers, kSharding, @@ -104,6 +115,7 @@ class HloParser { kPaddingConfig, kMetadata, kFusionKind, + kDistribution, }; struct AttrConfig { @@ -167,20 +179,30 @@ class HloParser { bool ParseInt64List(const TokKind start, const TokKind end, const TokKind delim, std::vector* result); + bool ParseParamListToShape(Shape* shape, LocTy* shape_loc); bool ParseParamList(); bool ParseName(string* result); bool ParseAttributeName(string* result); bool ParseString(string* result); bool ParseShape(Shape* result); bool ParseOpcode(HloOpcode* result); + bool ParseFftType(FftType* result); bool ParseFusionKind(HloInstruction::FusionKind* result); + bool ParseRandomDistribution(RandomDistribution* result); bool ParseInt64(int64* result); bool ParseDouble(double* result); bool ParseBool(bool* result); bool ParseToken(TokKind kind, const string& msg); + // Returns true if the current token is the beginning of a shape. + bool CanBeShape(); + // Returns true if the current token is the beginning of a + // param_list_to_shape. + bool CanBeParamListToShape(); + // Logs the current parsing line and the given message. Always returns false. bool TokenError(StringPiece msg); + bool Error(LocTy loc, StringPiece msg); // If the current token is 'kind', eats it (i.e. lexes the next token) and // returns true. @@ -191,31 +213,47 @@ class HloParser { // Adds the instruction to the pool. Returns false and emits an error if the // instruction already exists. - bool AddInstruction(const string& name, HloInstruction* instruction); + bool AddInstruction(const string& name, HloInstruction* instruction, + LocTy name_loc); // Adds the computation to the pool. Returns false and emits an error if the // computation already exists. 
- bool AddComputation(const string& name, HloComputation* computation); + bool AddComputation(const string& name, HloComputation* computation, + LocTy name_loc); - // The map from the instruction name to the instruction. This does not own the - // instructions. - std::unordered_map instruction_pool_; - std::unordered_map computation_pool_; + // The map from the instruction/computation name to the + // instruction/computation itself and it's location. This does not own the + // pointers. + std::unordered_map> + instruction_pool_; + std::unordered_map> + computation_pool_; HloLexer lexer_; std::unique_ptr module_; + std::vector> computations_; const HloModuleConfig config_; std::vector error_; }; -bool HloParser::TokenError(StringPiece msg) { - const string error = - StrCat("was parsing \"", lexer_.GetCurrentLine(), "\"; token ", - TokKindToString(lexer_.GetKind()), "; ", msg); - VLOG(1) << "TokenError: " << error; - error_.push_back(error); +bool HloParser::Error(LocTy loc, StringPiece msg) { + auto line_col = lexer_.GetLineAndColumn(loc); + const unsigned line = line_col.first; + const unsigned col = line_col.second; + std::vector error_lines; + error_lines.push_back( + StrCat("was parsing ", line, ":", col, ": error: ", msg)); + error_lines.push_back(lexer_.GetLine(loc).ToString()); + error_lines.push_back(col == 0 ? 
"" : StrCat(string(col - 1, ' '), "^")); + + error_.push_back(tensorflow::str_util::Join(error_lines, "\n")); + VLOG(1) << "Error: " << error_.back(); return false; } +bool HloParser::TokenError(StringPiece msg) { + return Error(lexer_.GetLoc(), msg); +} + bool HloParser::Run() { lexer_.Lex(); return ParseHloModule(); @@ -241,46 +279,110 @@ bool HloParser::ParseHloModule() { // computations ::= (computation)+ bool HloParser::ParseComputations() { + HloComputation* entry_computation = nullptr; do { - if (!ParseComputation()) { + if (!ParseComputation(&entry_computation)) { return false; } } while (lexer_.GetKind() != TokKind::kEof); + + for (int i = 0; i < computations_.size(); i++) { + // If entry_computation is not nullptr, it means the computation it pointed + // to is marked with "ENTRY"; otherwise, no computation is marked with + // "ENTRY", and we use the last computation as the entry computation. We + // add the non-entry computations as embedded computations to the module. + if ((entry_computation != nullptr && + computations_[i].get() != entry_computation) || + (entry_computation == nullptr && i != computations_.size() - 1)) { + module_->AddEmbeddedComputation(std::move(computations_[i])); + continue; + } + auto computation = + module_->AddEntryComputation(std::move(computations_[i])); + // The parameters and result layouts were set to default layout. Here we + // set the layouts to what the hlo text says. 
+ for (int p = 0; p < computation->num_parameters(); p++) { + const Shape& param_shape = computation->parameter_instruction(p)->shape(); + if (param_shape.has_layout()) { + module_->mutable_entry_computation_layout() + ->mutable_parameter_layout(p) + ->ResetLayout(param_shape.layout()); + } + } + const Shape& result_shape = computation->root_instruction()->shape(); + if (result_shape.has_layout()) { + module_->mutable_entry_computation_layout() + ->mutable_result_layout() + ->ResetLayout(result_shape.layout()); + } + } + return true; } -// computation ::= ('ENTRY')? name param_list '->' shape instruction_list -bool HloParser::ParseComputation() { +// computation ::= ('ENTRY')? name (param_list_to_shape)? instruction_list +bool HloParser::ParseComputation(HloComputation** entry_computation) { + LocTy maybe_entry_loc = lexer_.GetLoc(); const bool is_entry_computation = EatIfPresent(TokKind::kw_ENTRY); + string name; + LocTy name_loc = lexer_.GetLoc(); if (!ParseName(&name)) { return false; } auto builder = MakeUnique(name); + LocTy shape_loc = nullptr; Shape shape; + if (CanBeParamListToShape() && !ParseParamListToShape(&shape, &shape_loc)) { + return false; + } + string root_name; - if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'") || - !ParseShape(&shape) || !ParseInstructionList(builder.get(), &root_name)) { + if (!ParseInstructionList(builder.get(), &root_name)) { return false; } - HloInstruction* root = - tensorflow::gtl::FindPtrOrNull(instruction_pool_, root_name); + std::pair* root_node = + tensorflow::gtl::FindOrNull(instruction_pool_, root_name); // This means some instruction was marked as ROOT but we didn't find it in the // pool, which should not happen. - if (!root_name.empty() && root == nullptr) { + if (!root_name.empty() && root_node == nullptr) { LOG(FATAL) << "instruction " << root_name << " was marked as ROOT but the parser has not seen it before"; } + + HloInstruction* root = root_node == nullptr ? 
nullptr : root_node->first; // Now root can be either an existing instruction or a nullptr. If it's a // nullptr, the implementation of Builder will set the last instruction as // root instruction. - HloComputation* computation = - is_entry_computation - ? module_->AddEntryComputation(builder->Build(root)) - : module_->AddEmbeddedComputation(builder->Build(root)); - return AddComputation(name, computation); + computations_.emplace_back(builder->Build(root)); + HloComputation* computation = computations_.back().get(); + + if (!root) { + root = computation->root_instruction(); + } else { + CHECK_EQ(root, computation->root_instruction()); + } + + // If param_list_to_shape was present, check compatibility. + if (shape_loc != nullptr && !ShapeUtil::Compatible(root->shape(), shape)) { + return Error( + shape_loc, + StrCat("Shape of computation ", name, ", ", + ShapeUtil::HumanString(shape), + ", is not compatible with that of its root instruction ", + root_name, ", ", ShapeUtil::HumanString(root->shape()))); + } + + if (is_entry_computation) { + if (*entry_computation != nullptr) { + return Error(maybe_entry_loc, "expects only one ENTRY"); + } + *entry_computation = computation; + } + + return AddComputation(name, computation, name_loc); } // instruction_list ::= '{' instruction_list1 '}' @@ -307,13 +409,21 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, Shape shape; HloOpcode opcode; std::vector operands; + + LocTy maybe_root_loc = lexer_.GetLoc(); bool is_root = EatIfPresent(TokKind::kw_ROOT); + + const LocTy name_loc = lexer_.GetLoc(); if (!ParseName(&name) || !ParseToken(TokKind::kEqual, "expects '=' in instruction") || !ParseShape(&shape) || !ParseOpcode(&opcode)) { return false; } + if (is_root) { + if (!root_name->empty()) { + return Error(maybe_root_loc, "one computation should have only one ROOT"); + } *root_name = name; } @@ -395,7 +505,6 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kLe: case 
HloOpcode::kLt: case HloOpcode::kNe: - case HloOpcode::kDot: case HloOpcode::kMaximum: case HloOpcode::kMinimum: case HloOpcode::kPower: @@ -444,12 +553,11 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kCrossReplicaSum: { - if (!ParseOperands(&operands, /*expected_size=*/1) || - !ParseAttributes(attrs)) { + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } instruction = builder->AddInstruction( - HloInstruction::CreateCrossReplicaSum(shape, operands[0])); + HloInstruction::CreateCrossReplicaSum(shape, operands)); break; } case HloOpcode::kReshape: { @@ -590,6 +698,20 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, shape, /*lhs=*/operands[0], /*rhs=*/operands[1], *window, *dnums)); break; } + case HloOpcode::kFft: { + optional fft_type; + optional> fft_length; + attrs["fft_type"] = {/*required=*/true, AttrTy::kFftType, &fft_type}; + attrs["fft_length"] = {/*required=*/true, AttrTy::kBracedInt64List, + &fft_length}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateFft( + shape, operands[0], *fft_type, *fft_length)); + break; + } case HloOpcode::kBroadcast: { optional> broadcast_dimensions; attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, @@ -813,18 +935,113 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, return false; } instruction = builder->AddInstruction(HloInstruction::CreateOutfeed( - shape, operands[0], config ? *config : "")); + operands[0]->shape(), operands[0], config ? 
*config : "")); + break; + } + case HloOpcode::kRng: { + optional distribution; + attrs["distribution"] = {/*required=*/true, AttrTy::kDistribution, + &distribution}; + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateRng(shape, *distribution, operands)); + break; + } + case HloOpcode::kReducePrecision: { + optional exponent_bits; + optional mantissa_bits; + attrs["exponent_bits"] = {/*required=*/true, AttrTy::kInt64, + &exponent_bits}; + attrs["mantissa_bits"] = {/*required=*/true, AttrTy::kInt64, + &mantissa_bits}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + instruction = + builder->AddInstruction(HloInstruction::CreateReducePrecision( + shape, operands[0], static_cast(*exponent_bits), + static_cast(*mantissa_bits))); + break; + } + case HloOpcode::kConditional: { + optional true_computation; + optional false_computation; + attrs["true_computation"] = {/*required=*/true, AttrTy::kHloComputation, + &true_computation}; + attrs["false_computation"] = {/*required=*/true, AttrTy::kHloComputation, + &false_computation}; + if (!ParseOperands(&operands, /*expected_size=*/3) || + !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateConditional( + shape, /*pred=*/operands[0], + /*true_computation_arg=*/operands[1], *true_computation, + /*false_computation_arg=*/operands[2], *false_computation)); + break; + } + case HloOpcode::kCustomCall: { + optional custom_call_target; + attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString, + &custom_call_target}; + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateCustomCall( + shape, operands, *custom_call_target)); + break; + } + case HloOpcode::kDot: { + optional> lhs_contracting_dims; + attrs["lhs_contracting_dims"] = 
{ + /*required=*/false, AttrTy::kBracedInt64List, &lhs_contracting_dims}; + optional> rhs_contracting_dims; + attrs["rhs_contracting_dims"] = { + /*required=*/false, AttrTy::kBracedInt64List, &rhs_contracting_dims}; + optional> lhs_batch_dims; + attrs["lhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List, + &lhs_batch_dims}; + optional> rhs_batch_dims; + attrs["rhs_batch_dims"] = {/*required=*/false, AttrTy::kBracedInt64List, + &rhs_batch_dims}; + + if (!ParseOperands(&operands, /*expected_size=*/2) || + !ParseAttributes(attrs)) { + return false; + } + + DotDimensionNumbers dnum; + if (lhs_contracting_dims) { + *dnum.mutable_lhs_contracting_dimensions() = { + lhs_contracting_dims->begin(), lhs_contracting_dims->end()}; + } + if (rhs_contracting_dims) { + *dnum.mutable_rhs_contracting_dimensions() = { + rhs_contracting_dims->begin(), rhs_contracting_dims->end()}; + } + if (lhs_batch_dims) { + *dnum.mutable_lhs_batch_dimensions() = {lhs_batch_dims->begin(), + lhs_batch_dims->end()}; + } + if (rhs_batch_dims) { + *dnum.mutable_rhs_batch_dimensions() = {rhs_batch_dims->begin(), + rhs_batch_dims->end()}; + } + + instruction = builder->AddInstruction( + HloInstruction::CreateDot(shape, operands[0], operands[1], dnum)); break; } - case HloOpcode::kConditional: - case HloOpcode::kCustomCall: - case HloOpcode::kReducePrecision: - case HloOpcode::kRng: case HloOpcode::kTrace: return TokenError(StrCat("parsing not yet implemented for op: ", HloOpcodeString(opcode))); } + instruction->set_name(name); + // Add common attrs (sharding, control predecessors) to the instruction, if // they were seen. 
if (sharding) { @@ -835,15 +1052,15 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, for (auto* pre : *predecessors) { Status status = pre->AddControlDependencyTo(instruction); if (!status.ok()) { - return TokenError(StrCat("error adding control dependency for: ", name, - " status: ", status.ToString())); + return Error(name_loc, StrCat("error adding control dependency for: ", + name, " status: ", status.ToString())); } } } if (metadata) { instruction->set_metadata(*metadata); } - return AddInstruction(name, instruction); + return AddInstruction(name, instruction, name_loc); } // NOLINT(readability/fn_size) // ::= '{' (single_sharding | tuple_sharding) '}' @@ -889,6 +1106,7 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, return false; } + LocTy loc = lexer_.GetLoc(); bool maximal = false; bool replicated = false; std::vector devices; @@ -956,34 +1174,35 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, if (replicated) { if (!devices.empty()) { - return TokenError( - "replicated shardings should not have any devices assigned"); + return Error(loc, + "replicated shardings should not have any devices assigned"); } if (!ShapeUtil::Equal(tile_shape, Shape())) { - return TokenError( - "replicated shardings should not have any tile shape set"); + return Error(loc, + "replicated shardings should not have any tile shape set"); } sharding->set_type(OpSharding::Type::OpSharding_Type_REPLICATED); } else if (maximal) { if (devices.size() != 1) { - return TokenError( - "maximal shardings should have exactly one device assigned"); + return Error(loc, + "maximal shardings should have exactly one device assigned"); } if (!ShapeUtil::Equal(tile_shape, Shape())) { - return TokenError("maximal shardings should not have any tile shape set"); + return Error(loc, "maximal shardings should not have any tile shape set"); } sharding->set_type(OpSharding::Type::OpSharding_Type_MAXIMAL); sharding->add_tile_assignment_devices(devices[0]); } else { 
if (devices.size() <= 1) { - return TokenError( - "non-maximal shardings must have more than one device assigned"); + return Error( + loc, "non-maximal shardings must have more than one device assigned"); } if (ShapeUtil::Equal(tile_shape, Shape())) { - return TokenError("non-maximal shardings should have a tile shape set"); + return Error(loc, "non-maximal shardings should have a tile shape set"); } if (tile_assignment_dimensions.empty()) { - return TokenError( + return Error( + loc, "non-maximal shardings must have a tile assignment list including " "dimensions"); } @@ -1008,22 +1227,23 @@ bool HloParser::ParseInstructionNames( "expects '{' at the beginning of instruction name list")) { return false; } + LocTy loc = lexer_.GetLoc(); do { string name; if (!ParseName(&name)) { - return TokenError("expects a instruction name"); + return Error(loc, "expects a instruction name"); } - HloInstruction* instr = - tensorflow::gtl::FindPtrOrNull(instruction_pool_, name); + std::pair* instr = + tensorflow::gtl::FindOrNull(instruction_pool_, name); if (!instr) { return TokenError( Printf("instruction '%s' is not defined", name.c_str())); } - instructions->push_back(instr); + instructions->push_back(instr->first); } while (EatIfPresent(TokKind::kComma)); return ParseToken(TokKind::kRbrace, - "expects '}' at the end of control instructions"); + "expects '}' at the end of instruction name list"); } bool HloParser::SetValueInLiteral(int64 value, int64 linear_index, @@ -1058,6 +1278,8 @@ bool HloParser::SetValueInLiteral(double value, int64 linear_index, switch (shape.element_type()) { case F16: return SetValueInLiteralHelper(value, linear_index, literal); + case BF16: + return SetValueInLiteralHelper(value, linear_index, literal); case F32: return SetValueInLiteralHelper(value, linear_index, literal); case F64: @@ -1096,7 +1318,8 @@ bool HloParser::SetValueInLiteralHelper(ParsedElemT value, int64 linear_index, (std::numeric_limits::infinity() == value || 
-std::numeric_limits::infinity() == value))) { // Skip range checking for non-finite value. - } else if (literal->shape().element_type() == F16) { + } else if (literal->shape().element_type() == F16 || + literal->shape().element_type() == BF16) { if (value > kF16max || value < -kF16max) { return TokenError(StrCat( "value ", value, " is out of range for literal's primitive type ", @@ -1112,7 +1335,7 @@ bool HloParser::SetValueInLiteralHelper(ParsedElemT value, int64 linear_index, PrimitiveType_Name(literal->shape().element_type()))); } - literal->GetMutableArraySlice().at(linear_index) = + literal->data().at(linear_index) = static_cast(value); return true; } @@ -1179,9 +1402,19 @@ bool HloParser::ParseTupleLiteral(std::unique_ptr* literal, // non_tuple // ::= rank01 // ::= rank2345 -// rank2345 ::= shape nested_array +// rank2345 ::= shape sparse_or_nested_array bool HloParser::ParseNonTupleLiteral(std::unique_ptr* literal, const Shape& shape) { + if (LayoutUtil::IsSparseArray(shape)) { + return ParseSparseLiteral(literal, shape); + } + + CHECK(LayoutUtil::IsDenseArray(shape)); + return ParseDenseLiteral(literal, shape); +} + +bool HloParser::ParseDenseLiteral(std::unique_ptr* literal, + const Shape& shape) { const int64 rank = ShapeUtil::Rank(shape); if (rank > 1 && !EatShapeAndCheckCompatible(shape)) { return false; @@ -1282,26 +1515,28 @@ bool HloParser::ParseNonTupleLiteral(std::unique_ptr* literal, } lexer_.Lex(); } else if (primitive_util::IsIntegralType(shape.element_type())) { + LocTy loc = lexer_.GetLoc(); int64 value; if (!ParseInt64(&value)) { - return TokenError(StrCat("expects integer for primitive type: ", + return Error(loc, StrCat("expects integer for primitive type: ", PrimitiveType_Name(shape.element_type()))); } if (!SetValueInLiteral(value, linear_index++, literal->get())) { return false; } } else if (primitive_util::IsFloatingPointType(shape.element_type())) { + LocTy loc = lexer_.GetLoc(); double value; if (!ParseDouble(&value)) { - return 
TokenError( - StrCat("expect floating point value for primitive type: ", - PrimitiveType_Name(shape.element_type()))); + return Error( + loc, StrCat("expect floating point value for primitive type: ", + PrimitiveType_Name(shape.element_type()))); } if (!SetValueInLiteral(value, linear_index++, literal->get())) { return false; } } else { - return TokenError(StrCat("unsupported premitive type ", + return TokenError(StrCat("unsupported primitive type ", PrimitiveType_Name(shape.element_type()))); } break; @@ -1313,11 +1548,147 @@ bool HloParser::ParseNonTupleLiteral(std::unique_ptr* literal, return true; } +bool HloParser::ParseSparseLiteral(std::unique_ptr* literal, + const Shape& shape) { + if (!EatShapeAndCheckCompatible(shape)) { + return false; + } + + switch (shape.element_type()) { + case PRED: + return ParseSparseLiteralHelper(literal, shape); + case S8: + return ParseSparseLiteralHelper(literal, shape); + case S16: + return ParseSparseLiteralHelper(literal, shape); + case S32: + return ParseSparseLiteralHelper(literal, shape); + case S64: + return ParseSparseLiteralHelper(literal, shape); + case U8: + return ParseSparseLiteralHelper(literal, shape); + case U16: + return ParseSparseLiteralHelper(literal, shape); + case U32: + return ParseSparseLiteralHelper(literal, shape); + case U64: + return ParseSparseLiteralHelper(literal, shape); + case F16: + return ParseSparseLiteralHelper(literal, shape); + case F32: + return ParseSparseLiteralHelper(literal, shape); + case BF16: + return ParseSparseLiteralHelper(literal, shape); + case F64: + return ParseSparseLiteralHelper(literal, shape); + default: + return Error(lexer_.GetLoc(), + StrCat("invalid primitive type for sparse literal: ", + PrimitiveType_Name(shape.element_type()))); + } +} + +template +bool HloParser::ParseSparseLiteralHelper(std::unique_ptr* literal, + const Shape& shape) { + std::vector index; + + int64 rank = ShapeUtil::Rank(shape); + + *literal = MakeUnique(shape); + + if 
(!ParseToken(TokKind::kLbrace, + "expects '{' at the beginning of a sparse literal")) { + return false; + } + + for (;;) { + if (lexer_.GetKind() == TokKind::kRbrace) { + lexer_.Lex(); + break; + } + + LocTy index_loc = lexer_.GetLoc(); + index.clear(); + if (lexer_.GetKind() == TokKind::kInt) { + int64 single_index = lexer_.GetInt64Val(); + lexer_.Lex(); + if (rank != 1) { + return Error( + index_loc, + StrCat("invalid single-dimensional index for shape with rank ", + rank, ": ", single_index)); + } + index.push_back(single_index); + } else { + if (!ParseInt64List(TokKind::kLsquare, TokKind::kRsquare, TokKind::kComma, + &index)) { + return false; + } + if (index.size() != rank) { + return Error( + index_loc, + StrCat("invalid multi-dimension index for shape with rank ", rank, + ": [", tensorflow::str_util::Join(index, ", "), "]")); + } + } + if (!ParseToken(TokKind::kColon, + "expects ':' after after the sparse array index and before " + "the sparse array value")) { + return false; + } + LocTy value_loc = lexer_.GetLoc(); + LiteralNativeT value; + if (lexer_.GetKind() == TokKind::kw_true || + lexer_.GetKind() == TokKind::kw_false) { + value = static_cast(lexer_.GetKind() == TokKind::kw_true); + lexer_.Lex(); + } else if (primitive_util::IsIntegralType(shape.element_type())) { + int64 value_s64; + if (!ParseInt64(&value_s64)) { + return Error(value_loc, + StrCat("expects integer for primitive type: ", + PrimitiveType_Name(shape.element_type()))); + } + value = static_cast(value_s64); + } else if (primitive_util::IsFloatingPointType(shape.element_type())) { + double value_f64; + if (!ParseDouble(&value_f64)) { + return Error(value_loc, + StrCat("expects floating point value for primitive type: ", + PrimitiveType_Name(shape.element_type()))); + } + value = static_cast(value_f64); + } else { + LOG(FATAL) << "Unexpected element type: " + << PrimitiveType_Name(shape.element_type()); + } + if (lexer_.GetKind() != TokKind::kRbrace && + !ParseToken(TokKind::kComma, + 
"expects ',' separator between sparse array elements")) { + return false; + } + + if ((*literal)->sparse_element_count() + 1 == + LayoutUtil::MaxSparseElements(shape.layout())) { + return Error( + lexer_.GetLoc(), + StrCat("number of sparse elements exceeds maximum for layout: ", + ShapeUtil::HumanStringWithLayout(shape))); + } + + (*literal)->AppendSparseElement(index, value); + } + + (*literal)->SortSparseElements(); + return true; +} + // operands ::= '(' operands1 ')' // operands1 // ::= /*empty*/ // ::= operand (, operand)* -// operand ::= shape name +// operand ::= (shape)? name bool HloParser::ParseOperands(std::vector* operands) { if (!ParseToken(TokKind::kLparen, "expects '(' at the beginning of operands")) { @@ -1327,17 +1698,23 @@ bool HloParser::ParseOperands(std::vector* operands) { // empty } else { do { - Shape shape; + LocTy loc = lexer_.GetLoc(); string name; - if (!ParseShape(&shape) || !ParseName(&name)) { + if (CanBeShape()) { + Shape shape; + if (!ParseShape(&shape)) { + return false; + } + } + if (!ParseName(&name)) { return false; } - HloInstruction* instruction = - tensorflow::gtl::FindPtrOrNull(instruction_pool_, name); + std::pair* instruction = + tensorflow::gtl::FindOrNull(instruction_pool_, name); if (!instruction) { - return TokenError(StrCat("instruction does not exist: ", name)); + return Error(loc, StrCat("instruction does not exist: ", name)); } - operands->push_back(instruction); + operands->push_back(instruction->first); } while (EatIfPresent(TokKind::kComma)); } return ParseToken(TokKind::kRparen, "expects ')' at the end of operands"); @@ -1345,11 +1722,12 @@ bool HloParser::ParseOperands(std::vector* operands) { bool HloParser::ParseOperands(std::vector* operands, const int expected_size) { + LocTy loc = lexer_.GetLoc(); if (!ParseOperands(operands)) { return false; } if (expected_size != operands->size()) { - return TokenError(StrCat("expects ", expected_size, " operands, but has ", + return Error(loc, StrCat("expects ", 
expected_size, " operands, but has ", operands->size(), " operands")); } return true; @@ -1358,6 +1736,7 @@ bool HloParser::ParseOperands(std::vector* operands, // sub_attributes ::= '{' (','? attribute)* '}' bool HloParser::ParseSubAttributes( const std::unordered_map& attrs) { + LocTy loc = lexer_.GetLoc(); if (!ParseToken(TokKind::kLbrace, "expects '{' to start sub attributes")) { return false; } @@ -1376,7 +1755,7 @@ bool HloParser::ParseSubAttributes( for (const auto& attr_it : attrs) { if (attr_it.second.required && seen_attrs.find(attr_it.first) == seen_attrs.end()) { - return TokenError(Printf("sub-attribute %s is expected but not seen", + return Error(loc, Printf("sub-attribute %s is expected but not seen", attr_it.first.c_str())); } } @@ -1386,6 +1765,7 @@ bool HloParser::ParseSubAttributes( // attributes ::= (',' attribute)* bool HloParser::ParseAttributes( const std::unordered_map& attrs) { + LocTy loc = lexer_.GetLoc(); std::unordered_set seen_attrs; while (EatIfPresent(TokKind::kComma)) { if (!ParseAttributeHelper(attrs, &seen_attrs)) { @@ -1396,7 +1776,7 @@ bool HloParser::ParseAttributes( for (const auto& attr_it : attrs) { if (attr_it.second.required && seen_attrs.find(attr_it.first) == seen_attrs.end()) { - return TokenError(Printf("attribute %s is expected but not seen", + return Error(loc, Printf("attribute %s is expected but not seen", attr_it.first.c_str())); } } @@ -1406,21 +1786,23 @@ bool HloParser::ParseAttributes( bool HloParser::ParseAttributeHelper( const std::unordered_map& attrs, std::unordered_set* seen_attrs) { + LocTy loc = lexer_.GetLoc(); string name; if (!ParseAttributeName(&name)) { - return TokenError("error parsing attributes"); + return Error(loc, "error parsing attributes"); } VLOG(1) << "Parsing attribute " << name; if (!seen_attrs->insert(name).second) { - return TokenError(Printf("attribute %s already exists", name.c_str())); + return Error(loc, Printf("attribute %s already exists", name.c_str())); } auto attr_it = 
attrs.find(name); if (attr_it == attrs.end()) { - return TokenError(Printf("unexpected attribute %s", name.c_str())); + return Error(loc, Printf("unexpected attribute %s", name.c_str())); } AttrTy attr_type = attr_it->second.attr_type; void* attr_out_ptr = attr_it->second.result; bool success = [&] { + LocTy attr_loc = lexer_.GetLoc(); switch (attr_type) { case AttrTy::kInt64: { int64 result; @@ -1436,7 +1818,7 @@ bool HloParser::ParseAttributeHelper( return false; } if (result != static_cast(result)) { - return TokenError("value out of range for int32"); + return Error(attr_loc, "value out of range for int32"); } static_cast*>(attr_out_ptr) ->emplace(static_cast(result)); @@ -1449,7 +1831,7 @@ bool HloParser::ParseAttributeHelper( } if (result > std::numeric_limits::max() || result < std::numeric_limits::lowest()) { - return TokenError("value out of range for float"); + return Error(attr_loc, "value out of range for float"); } static_cast*>(attr_out_ptr) ->emplace(static_cast(result)); @@ -1463,6 +1845,14 @@ bool HloParser::ParseAttributeHelper( static_cast*>(attr_out_ptr)->emplace(result); return true; } + case AttrTy::kFftType: { + FftType result; + if (!ParseFftType(&result)) { + return false; + } + static_cast*>(attr_out_ptr)->emplace(result); + return true; + } case AttrTy::kWindow: { Window result; if (!ParseWindow(&result)) { @@ -1548,23 +1938,35 @@ bool HloParser::ParseAttributeHelper( static_cast*>(attr_out_ptr)->emplace(result); return true; } + case AttrTy::kDistribution: { + RandomDistribution result; + if (!ParseRandomDistribution(&result)) { + return false; + } + static_cast*>(attr_out_ptr) + ->emplace(result); + return true; + } } }(); if (!success) { - return TokenError(Printf("error parsing attribute %s", name.c_str())); + return Error(loc, Printf("error parsing attribute %s", name.c_str())); } return true; } bool HloParser::ParseComputationName(HloComputation** value) { string name; + LocTy loc = lexer_.GetLoc(); if (!ParseName(&name)) { - return 
TokenError("expects computation name"); + return Error(loc, "expects computation name"); } - *value = tensorflow::gtl::FindPtrOrNull(computation_pool_, name); - if (*value == nullptr) { - return TokenError(StrCat("computation does not exist: ", name)); + std::pair* computation = + tensorflow::gtl::FindOrNull(computation_pool_, name); + if (computation == nullptr) { + return Error(loc, StrCat("computation does not exist: ", name)); } + *value = computation->first; return true; } @@ -1572,6 +1974,7 @@ bool HloParser::ParseComputationName(HloComputation** value) { // The subattributes can appear in any order. 'size=' is required, others are // optional. bool HloParser::ParseWindow(Window* window) { + LocTy loc = lexer_.GetLoc(); if (!ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) { return false; } @@ -1581,10 +1984,12 @@ bool HloParser::ParseWindow(Window* window) { std::vector> pad; std::vector lhs_dilate; std::vector rhs_dilate; + std::vector rhs_reversal; while (lexer_.GetKind() != TokKind::kRbrace) { + LocTy attr_loc = lexer_.GetLoc(); string field_name; if (!ParseAttributeName(&field_name)) { - return TokenError("expects sub-attributes in window"); + return Error(attr_loc, "expects sub-attributes in window"); } bool ok = [&] { if (field_name == "size") { @@ -1602,7 +2007,10 @@ bool HloParser::ParseWindow(Window* window) { if (field_name == "pad") { return ParseWindowPad(&pad); } - return TokenError(StrCat("unexpected attribute name: ", field_name)); + if (field_name == "rhs_reversal") { + return ParseDxD("rhs_reversal", &rhs_reversal); + } + return Error(attr_loc, StrCat("unexpected attribute name: ", field_name)); }(); if (!ok) { return false; @@ -1610,20 +2018,20 @@ bool HloParser::ParseWindow(Window* window) { } if (size.empty()) { - return TokenError( - "sub-attribute 'size=' is required in the window attribute"); + return Error(loc, + "sub-attribute 'size=' is required in the window attribute"); } if (!stride.empty() && stride.size() 
!= size.size()) { - return TokenError("expects 'stride=' has the same size as 'size='"); + return Error(loc, "expects 'stride=' has the same size as 'size='"); } if (!lhs_dilate.empty() && lhs_dilate.size() != size.size()) { - return TokenError("expects 'lhs_dilate=' has the same size as 'size='"); + return Error(loc, "expects 'lhs_dilate=' has the same size as 'size='"); } if (!rhs_dilate.empty() && rhs_dilate.size() != size.size()) { - return TokenError("expects 'rhs_dilate=' has the same size as 'size='"); + return Error(loc, "expects 'rhs_dilate=' has the same size as 'size='"); } if (!pad.empty() && pad.size() != size.size()) { - return TokenError("expects 'pad=' has the same size as 'size='"); + return Error(loc, "expects 'pad=' has the same size as 'size='"); } for (int i = 0; i < size.size(); i++) { @@ -1638,6 +2046,8 @@ bool HloParser::ParseWindow(Window* window) { lhs_dilate.empty() ? 1 : lhs_dilate[i]); window->mutable_dimensions(i)->set_window_dilation( rhs_dilate.empty() ? 1 : rhs_dilate[i]); + window->mutable_dimensions(i)->set_window_reversal( + rhs_reversal.empty() ? 
false : (rhs_reversal[i] == 1)); } return ParseToken(TokKind::kRbrace, "expected '}' to end window attribute"); } @@ -1769,7 +2179,7 @@ bool HloParser::ParseConvolutionDimensionNumbers( // // {[2:3:4], [5:6:7], [8:9]} // -// The the parsed result will be: +// The parsed result will be: // // {/*starts=*/{2, 5, 8}, /*limits=*/{3, 6, 9}, /*strides=*/{4, 7, 1}} // @@ -1783,20 +2193,19 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) { return ParseToken(TokKind::kRbrace, "expects '}' to end ranges"); } do { + LocTy loc = lexer_.GetLoc(); ranges.emplace_back(); if (!ParseInt64List(TokKind::kLsquare, TokKind::kRsquare, TokKind::kColon, &ranges.back())) { return false; } - } while (EatIfPresent(TokKind::kComma)); - - for (const auto& range : ranges) { + const auto& range = ranges.back(); if (range.size() != 2 && range.size() != 3) { - return TokenError(Printf( - "expects [start:limit:step] or [start:limit], but sees %ld elements.", - range.size())); + return Error(loc, Printf("expects [start:limit:step] or [start:limit], " + "but sees %ld elements.", + range.size())); } - } + } while (EatIfPresent(TokKind::kComma)); for (const auto& range : ranges) { result->starts.push_back(range[0]); @@ -1832,6 +2241,19 @@ bool HloParser::ParseInt64List(const TokKind start, const TokKind end, end, StrCat("expects an int64 list to end with ", TokKindToString(end))); } +// param_list_to_shape ::= param_list '->' shape +bool HloParser::ParseParamListToShape(Shape* shape, LocTy* shape_loc) { + if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'")) { + return false; + } + *shape_loc = lexer_.GetLoc(); + return ParseShape(shape); +} + +bool HloParser::CanBeParamListToShape() { + return lexer_.GetKind() == TokKind::kLparen; +} + // param_list ::= '(' param_list1 ')' // param_list1 // ::= /*empty*/ @@ -1848,8 +2270,8 @@ bool HloParser::ParseParamList() { } else { do { Shape shape; - if (!ParseToken(TokKind::kName, "expects name in parameter") || - !ParseShape(&shape)) { 
+ string name; + if (!ParseName(&name) || !ParseShape(&shape)) { return false; } } while (EatIfPresent(TokKind::kComma)); @@ -1888,9 +2310,17 @@ bool HloParser::ParseShape(Shape* result) { return true; } +bool HloParser::CanBeShape() { + // A non-tuple shape starts with a kShape token; a tuple shape starts with + // '('. + return lexer_.GetKind() == TokKind::kShape || + lexer_.GetKind() == TokKind::kLparen; +} + bool HloParser::ParseName(string* result) { VLOG(1) << "ParseName"; - if (lexer_.GetKind() != TokKind::kName) { + if (lexer_.GetKind() != TokKind::kIdent && + lexer_.GetKind() != TokKind::kName) { return TokenError("expects name"); } *result = lexer_.GetStrVal(); @@ -1918,15 +2348,16 @@ bool HloParser::ParseString(string* result) { } bool HloParser::ParseDxD(const string& name, std::vector* result) { + LocTy loc = lexer_.GetLoc(); if (!result->empty()) { - return TokenError( - Printf("sub-attribute '%s=' already exists", name.c_str())); + return Error(loc, + Printf("sub-attribute '%s=' already exists", name.c_str())); } // 1D if (lexer_.GetKind() == TokKind::kInt) { int64 number; if (!ParseInt64(&number)) { - return TokenError(Printf("expects sub-attribute '%s=i'", name.c_str())); + return Error(loc, Printf("expects sub-attribute '%s=i'", name.c_str())); } result->push_back(number); return true; @@ -1935,8 +2366,8 @@ bool HloParser::ParseDxD(const string& name, std::vector* result) { if (lexer_.GetKind() == TokKind::kDxD) { string str = lexer_.GetStrVal(); if (!SplitAndParseAsInts(str, 'x', result)) { - return TokenError( - Printf("expects sub-attribute '%s=ixj...'", name.c_str())); + return Error(loc, + Printf("expects sub-attribute '%s=ixj...'", name.c_str())); } lexer_.Lex(); return true; @@ -1945,8 +2376,9 @@ bool HloParser::ParseDxD(const string& name, std::vector* result) { } bool HloParser::ParseWindowPad(std::vector>* pad) { + LocTy loc = lexer_.GetLoc(); if (!pad->empty()) { - return TokenError("sub-attribute 'pad=' already exists"); + return 
Error(loc, "sub-attribute 'pad=' already exists"); } if (lexer_.GetKind() != TokKind::kPad) { return TokenError("expects window pad pattern, e.g., '0_0x3_3'"); @@ -1957,8 +2389,8 @@ bool HloParser::ParseWindowPad(std::vector>* pad) { std::vector low_high; if (!SplitAndParseAsInts(padding_str[i], '_', &low_high) || low_high.size() != 2) { - return TokenError( - "expects padding_low and padding_high separated by '_'"); + return Error(loc, + "expects padding_low and padding_high separated by '_'"); } pad->push_back(low_high); } @@ -1974,15 +2406,16 @@ bool HloParser::ParsePaddingConfig(PaddingConfig* padding) { if (lexer_.GetKind() != TokKind::kPad) { return TokenError("expects padding config, e.g., '0_0_0x3_3_1'"); } + LocTy loc = lexer_.GetLoc(); string str = lexer_.GetStrVal(); std::vector padding_str = Split(str, 'x'); for (const auto& padding_dim_str : padding_str) { std::vector padding_dim; if (!SplitAndParseAsInts(padding_dim_str, '_', &padding_dim) || (padding_dim.size() != 2 && padding_dim.size() != 3)) { - return TokenError( - "expects padding config pattern like 'low_high_interior' or " - "'low_high'"); + return Error(loc, + "expects padding config pattern like 'low_high_interior' or " + "'low_high'"); } auto* dim = padding->add_dimensions(); dim->set_edge_padding_low(padding_dim[0]); @@ -2024,20 +2457,64 @@ bool HloParser::ParseMetadata(OpMetadata* metadata) { bool HloParser::ParseOpcode(HloOpcode* result) { VLOG(1) << "ParseOpcode"; - if (lexer_.GetKind() != TokKind::kOpcode) { + if (lexer_.GetKind() != TokKind::kIdent) { return TokenError("expects opcode"); } - *result = lexer_.GetOpcodeVal(); + string val = lexer_.GetStrVal(); + auto status_or_result = StringToHloOpcode(val); + if (!status_or_result.ok()) { + return TokenError( + Printf("expects opcode but sees: %s, error: %s", val.c_str(), + status_or_result.status().error_message().c_str())); + } + *result = status_or_result.ValueOrDie(); + lexer_.Lex(); + return true; +} + +bool 
HloParser::ParseFftType(FftType* result) { + VLOG(1) << "ParseFftType"; + if (lexer_.GetKind() != TokKind::kIdent) { + return TokenError("expects fft type"); + } + string val = lexer_.GetStrVal(); + if (!FftType_Parse(val, result) || !FftType_IsValid(*result)) { + return TokenError(Printf("expects fft type but sees: %s", val.c_str())); + } lexer_.Lex(); return true; } bool HloParser::ParseFusionKind(HloInstruction::FusionKind* result) { VLOG(1) << "ParseFusionKind"; - if (lexer_.GetKind() != TokKind::kFusionKind) { + if (lexer_.GetKind() != TokKind::kIdent) { return TokenError("expects fusion kind"); } - *result = lexer_.GetFusionKindVal(); + string val = lexer_.GetStrVal(); + auto status_or_result = StringToFusionKind(val); + if (!status_or_result.ok()) { + return TokenError( + Printf("expects fusion kind but sees: %s, error: %s", val.c_str(), + status_or_result.status().error_message().c_str())); + } + *result = status_or_result.ValueOrDie(); + lexer_.Lex(); + return true; +} + +bool HloParser::ParseRandomDistribution(RandomDistribution* result) { + VLOG(1) << "ParseRandomDistribution"; + if (lexer_.GetKind() != TokKind::kIdent) { + return TokenError("expects random distribution"); + } + string val = lexer_.GetStrVal(); + auto status_or_result = StringToRandomDistribution(val); + if (!status_or_result.ok()) { + return TokenError( + Printf("expects random distribution but sees: %s, error: %s", + val.c_str(), status_or_result.status().error_message().c_str())); + } + *result = status_or_result.ValueOrDie(); lexer_.Lex(); return true; } @@ -2103,20 +2580,24 @@ bool HloParser::EatIfPresent(TokKind kind) { return true; } -bool HloParser::AddInstruction(const string& name, - HloInstruction* instruction) { - auto result = instruction_pool_.insert({name, instruction}); +bool HloParser::AddInstruction(const string& name, HloInstruction* instruction, + LocTy name_loc) { + auto result = instruction_pool_.insert({name, {instruction, name_loc}}); if (!result.second) { - 
return TokenError(StrCat("instruction already exists: ", name)); + Error(name_loc, StrCat("instruction already exists: ", name)); + return Error(/*loc=*/result.first->second.second, + "instruction previously defined here"); } return true; } -bool HloParser::AddComputation(const string& name, - HloComputation* computation) { - auto result = computation_pool_.insert({name, computation}); +bool HloParser::AddComputation(const string& name, HloComputation* computation, + LocTy name_loc) { + auto result = computation_pool_.insert({name, {computation, name_loc}}); if (!result.second) { - return TokenError(StrCat("computation already exists: ", name)); + Error(name_loc, StrCat("computation already exists: ", name)); + return Error(/*loc=*/result.first->second.second, + "computation previously defined here"); } return true; } @@ -2127,7 +2608,7 @@ StatusOr> Parse(StringPiece str, const HloModuleConfig& config) { HloParser parser(str, config); if (!parser.Run()) { - return InvalidArgument("Syntax error: %s", parser.GetError().c_str()); + return InvalidArgument("Syntax error:\n%s", parser.GetError().c_str()); } return parser.ConsumeHloModule(); } diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc index 90cdb87a1ebcf59d291eebd52963a130f19f4403..b8c6b59204f897c7dc07b846370b5b776a19a808 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc @@ -25,7 +25,6 @@ namespace tools { namespace { using tensorflow::StringPiece; -using tensorflow::strings::StrCat; struct TestData { string test_name; @@ -46,7 +45,7 @@ std::vector CreateTestCases() { // ax + y { "AxpyParam", -R"(HloModule axpy_module: +R"(HloModule axpy_module ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { %alpha = f32[] parameter(0) @@ -62,7 +61,7 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { // pred constant { 
"ConstantPred", -R"(HloModule constant_pred_module: +R"(HloModule constant_pred_module ENTRY %constant_pred () -> pred[] { ROOT %constant = pred[] constant(true), metadata={op_type="const" op_name="\"it\'s not a problem\n" source_file="path/to/test.cc" source_line=68} @@ -73,7 +72,7 @@ ENTRY %constant_pred () -> pred[] { // s32 constant { "ConstantS32", -R"(HloModule constant_s32_module: +R"(HloModule constant_s32_module ENTRY %constant_s32 () -> s32[] { ROOT %constant = s32[] constant(-42) @@ -84,7 +83,7 @@ ENTRY %constant_s32 () -> s32[] { // f32 constant, but the value is not a decimal { "ConstantF32", -R"(HloModule ConstantF32_module: +R"(HloModule ConstantF32_module ENTRY %ConstantF32.v4 () -> f32[] { ROOT %constant = f32[] constant(42) @@ -95,7 +94,7 @@ ENTRY %ConstantF32.v4 () -> f32[] { // f32 constant, rank 1 empty array. { "ConstantF32R1Empty", -R"(HloModule ConstantF32Empty_module: +R"(HloModule ConstantF32Empty_module ENTRY %ConstantF32Empty.v4 () -> f32[0] { ROOT %constant = f32[0]{0} constant({}) @@ -106,7 +105,7 @@ ENTRY %ConstantF32Empty.v4 () -> f32[0] { // f32 constant, rank 4 empty array. 
{ "ConstantF32R4Empty", -R"(HloModule ConstantF32R4Empty_module: +R"(HloModule ConstantF32R4Empty_module ENTRY %ConstantF32R4Empty.v4 () -> f32[2,0,4,3] { ROOT %constant = f32[2,0,4,3]{3,2,1,0} constant(f32[2,0,4,3] { { /*i0=0*/ }, { /*i0=1*/ } }) @@ -117,7 +116,7 @@ ENTRY %ConstantF32R4Empty.v4 () -> f32[2,0,4,3] { // constant 4D { "Constant4D", -R"(HloModule Small_3x2x1x1_module: +R"(HloModule Small_3x2x1x1_module ENTRY %Small_3x2x1x1.v1 () -> f32[3,2,1,1] { ROOT %constant = f32[3,2,1,1]{3,2,1,0} constant(f32[3,2,1,1] { { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } }) @@ -128,7 +127,7 @@ ENTRY %Small_3x2x1x1.v1 () -> f32[3,2,1,1] { // non-finite constants: nan, inf, -inf { "ConstantNonFinite", -R"(HloModule IsFiniteR1F32s_module: +R"(HloModule IsFiniteR1F32s_module ENTRY %IsFiniteR1F32s.v2 () -> pred[6] { %constant = f32[6]{0} constant({nan, 7, nan, -1, inf, -inf}) @@ -140,18 +139,29 @@ ENTRY %IsFiniteR1F32s.v2 () -> pred[6] { // constant f16 { "ConstantF16", -R"(HloModule ConstantF16_module: +R"(HloModule ConstantF16_module ENTRY %ConstantF16.v4 () -> f16[] { ROOT %constant = f16[] constant(500) } +)" +}, +// bf16 +{ +"BF16", +R"(HloModule BF16 + +ENTRY %BF16.v4 () -> bf16[] { + ROOT %constant = bf16[] constant(500) +} + )" }, // constant + constant { "AddConstants", -R"(HloModule add_constants_module: +R"(HloModule add_constants_module ENTRY %add_constants () -> f32[] { %constant = f32[] constant(3.14) @@ -163,7 +173,7 @@ ENTRY %add_constants () -> f32[] { // tuple constant { "TupleConstant", -R"(HloModule TupleConstant_module: +R"(HloModule TupleConstant_module ENTRY %TupleConstant.v1 () -> (f32[2,1], f32[2]) { ROOT %constant = (f32[2,1]{1,0}, f32[2]{0}) constant((f32[2,1], f32[2]) ( f32[2,1] { { 1 }, { 2 } }, {2, 42} )) @@ -174,7 +184,7 @@ ENTRY %TupleConstant.v1 () -> (f32[2,1], f32[2]) { // v1 > v2 ? 
v1 : v2 { "SelectR1F32", -R"(HloModule SelectR1F32WithCmpR1F32sFromParamsSmall_module: +R"(HloModule SelectR1F32WithCmpR1F32sFromParamsSmall_module ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f32[4] { %v1 = f32[4]{0} parameter(0), sharding={maximal device=1} @@ -188,7 +198,7 @@ ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f3 // empty tuple { "EmptyTupleCreate", -R"(HloModule EmptyTupleCreate_module: +R"(HloModule EmptyTupleCreate_module ENTRY %EmptyTupleCreate.v1 () -> () { ROOT %tuple = () tuple() @@ -199,7 +209,7 @@ ENTRY %EmptyTupleCreate.v1 () -> () { // tuple { "TupleCreate", -R"(HloModule TupleCreate_module: +R"(HloModule TupleCreate_module ENTRY %TupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) { %v1 = f32[] parameter(0) @@ -212,7 +222,7 @@ ENTRY %TupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f }, { "ShardedTupleCreate", -R"(HloModule ShardedTupleCreate_module: +R"(HloModule ShardedTupleCreate_module ENTRY %ShardedTupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) { %v1 = f32[] parameter(0) @@ -227,7 +237,7 @@ ENTRY %ShardedTupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f3 // while (result < 5) { result = result + 1; } { "WhileWithScalarS32Result", -R"(HloModule WhileWithScalarS32Result_module: +R"(HloModule WhileWithScalarS32Result_module %body.v3 (prev.1: s32[]) -> s32[] { %constant = s32[] constant(1) @@ -251,7 +261,7 @@ ENTRY %WhileWithScalarS32Result.v2 () -> s32[] { // send and recv { "SendRecv", -R"(HloModule TwoSendRecvBothWayRecvFist_module: +R"(HloModule TwoSendRecvBothWayRecvFist_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { %recv = (f32[], u32[]) recv(), channel_id=15, sharding={maximal device=1} @@ -266,7 +276,7 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { // get-tuple-element { "GetTupleElement", -R"(HloModule GetTupleElement_module: +R"(HloModule 
GetTupleElement_module ENTRY %GetTupleElement.v4 () -> s32[2,3] { %constant = f32[3]{0} constant({1, 2, 3}) @@ -280,7 +290,7 @@ ENTRY %GetTupleElement.v4 () -> s32[2,3] { // call { "Call", -R"(HloModule CallR0F32IdentityScalar_module: +R"(HloModule CallR0F32IdentityScalar_module %Identity.v1 (x: f32[]) -> f32[] { ROOT %x = f32[] parameter(0) @@ -296,7 +306,7 @@ ENTRY %CallR0F32IdentityScalar.v2 () -> f32[] { // reduce window { "ReduceWindow", -R"(HloModule R4UnitWindow_module: +R"(HloModule R4UnitWindow_module %add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] { %lhs = f32[] parameter(0) @@ -315,7 +325,7 @@ ENTRY %R4UnitWindow.v3 (operand: f32[13,12,8,15]) -> f32[13,3,8,15] { // reduce window on scalar { "ReduceWindowScalar", -R"(HloModule reduce_window_scalar: +R"(HloModule reduce_window_scalar %add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] { %lhs = f32[] parameter(0) @@ -334,7 +344,7 @@ ENTRY %R4UnitWindowScalar () -> f32[] { // convolution { "Convolution", -R"(HloModule Convolve1D1Window_0_module: +R"(HloModule Convolve1D1Window_0_module ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] { %input = f32[1,2,1]{2,1,0} parameter(0) @@ -348,7 +358,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 // convolution rank 2 { "ConvolutionR2", -R"(HloModule ConvolveR2_module: +R"(HloModule ConvolveR2_module ENTRY %ConvolveR2.v3 (input: f32[1,2], filter: f32[1,1]) -> f32[1,2] { %input = f32[1,2]{1,0} parameter(0) @@ -356,12 +366,25 @@ ENTRY %ConvolveR2.v3 (input: f32[1,2], filter: f32[1,1]) -> f32[1,2] { ROOT %convolution = f32[1,2]{0,1} convolution(f32[1,2]{1,0} %input, f32[1,1]{1,0} %filter), dim_labels=bf_io->bf } +)" +}, +// convolution backward +{ +"ConvolutionBackward", +R"(HloModule ConvolveBackward_module + +ENTRY %ConvolveBackward (input: f32[128,7,7,512], filter: f32[3,3,512,512]) -> f32[128,14,14,512] { + %input = f32[128,7,7,512]{0,3,2,1} parameter(0) + %filter = f32[3,3,512,512]{3,2,1,0} 
parameter(1) + ROOT %convolution-base-dilated = f32[128,14,14,512]{0,3,2,1} convolution(f32[128,7,7,512]{0,3,2,1} %input, f32[3,3,512,512]{3,2,1,0} %filter), window={size=3x3 pad=1_2x1_2 lhs_dilate=2x2 rhs_reversal=1x1}, dim_labels=b01f_01oi->b01f +} + )" }, // reverse(constant) { "Reverse4D", -R"(HloModule Reverse4DFloatArrayOnDim01_module: +R"(HloModule Reverse4DFloatArrayOnDim01_module ENTRY %Reverse4DFloatArrayOnDim01.v2 () -> f32[4,3,2,1] { %constant = f32[4,3,2,1]{0,1,2,3} constant(f32[4,3,2,1] { { /*i0=0*/ { /*i1=0*/ {1}, {2} }, { /*i1=1*/ {3}, {4} }, { /*i1=2*/ {5}, {6} } }, { /*i0=1*/ { /*i1=0*/ {7}, {8} }, { /*i1=1*/ {9}, {10} }, { /*i1=2*/ {11}, {12} } }, { /*i0=2*/ { /*i1=0*/ {13}, {14} }, { /*i1=1*/ {15}, {16} }, { /*i1=2*/ {17}, {18} } }, { /*i0=3*/ { /*i1=0*/ {19}, {20} }, { /*i1=1*/ {21}, {22} }, { /*i1=2*/ {23}, {24} } } }) @@ -373,7 +396,7 @@ ENTRY %Reverse4DFloatArrayOnDim01.v2 () -> f32[4,3,2,1] { // concat { "Concat", -R"(HloModule Concat2x3With2x5_module: +R"(HloModule Concat2x3With2x5_module ENTRY %Concat2x3With2x5.v3 () -> f32[2,8] { %constant = f32[2,3]{1,0} constant(f32[2,3] { { 0, 1, 2 }, { 1000, 1001, 1002 } }) @@ -381,50 +404,12 @@ ENTRY %Concat2x3With2x5.v3 () -> f32[2,8] { ROOT %concatenate = f32[2,8]{1,0} concatenate(f32[2,3]{1,0} %constant, f32[2,5]{1,0} %constant.1), dimensions={1} } -)" -}, -// map -{ -"Map", -R"(HloModule MapBinaryAdder_module: - -%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] { - %lhs = f32[] parameter(0) - %rhs = f32[] parameter(1) - ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs) -} - -ENTRY %MapBinaryAdder.v3 (param0: f32[4], param1: f32[4]) -> f32[4] { - %param0 = f32[4]{0} parameter(0) - %param1 = f32[4]{0} parameter(1) - ROOT %map = f32[4]{0} map(f32[4]{0} %param0, f32[4]{0} %param1), to_apply=%add_F32.v3 -} - -)" -}, -// reduce -{ -"Reduce", -R"(HloModule ReduceR3ToR2_module: - -%add_F32.v3 (lhs: f32[], rhs: f32[]) -> f32[] { - %lhs = f32[] parameter(0) - %rhs = f32[] parameter(1) - ROOT %add = f32[] 
add(f32[] %lhs, f32[] %rhs) -} - -ENTRY %ReduceR3ToR2.v3 (input: f32[8,16,256]) -> f32[8,16] { - %input = f32[8,16,256]{2,1,0} parameter(0) - %constant = f32[] constant(0) - ROOT %reduce = f32[8,16]{1,0} reduce(f32[8,16,256]{2,1,0} %input, f32[] %constant), dimensions={2}, to_apply=%add_F32.v3 -} - )" }, // select and scatter { "SelectAndScatter", -R"(HloModule R4F32OverlapSmall_module: +R"(HloModule R4F32OverlapSmall_module %ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] { %lhs = f32[] parameter(0) @@ -450,7 +435,7 @@ ENTRY %R4F32OverlapSmall.v4 () -> f32[4,5,1,1] { // select and scatter on scalar { "SelectAndScatterScalar", -R"(HloModule select_and_scatter_scalar: +R"(HloModule select_and_scatter_scalar %ge_F32.v3 (lhs: f32[], rhs: f32[]) -> pred[] { %lhs = f32[] parameter(0) @@ -476,7 +461,7 @@ ENTRY %SelectAndScatterScalar () -> f32[] { // slice { "Slice", -R"(HloModule slice_module: +R"(HloModule slice_module ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] { %p0 = f32[3,3,4,4]{3,2,1,0} parameter(0) @@ -488,7 +473,7 @@ ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] { // slice, no stride { "SliceNoStride", -R"(HloModule Slice3x3x3_To_1x3x3_F32_module: +R"(HloModule Slice3x3x3_To_1x3x3_F32_module ENTRY %Slice3x3x3_To_1x3x3_F32.v2 () -> f32[1,3,3] { %constant = f32[3,3,3]{2,1,0} constant(f32[3,3,3] { { { 0, 1, 2 }, { 3, 4, 5 }, { 6, 7, 8 } }, { { 9, 10, 11 }, { 12, 13, 14 }, { 15, 16, 17 } }, { { 18, 19, 20 }, { 21, 22, 23 }, { 24, 25, 26 } } }) @@ -500,7 +485,7 @@ ENTRY %Slice3x3x3_To_1x3x3_F32.v2 () -> f32[1,3,3] { // slice R0 { "SliceR0", -R"(HloModule SliceR0_module: +R"(HloModule SliceR0_module ENTRY %SliceR0.v2 () -> s32[] { %constant = s32[] constant(1) @@ -512,7 +497,7 @@ ENTRY %SliceR0.v2 () -> s32[] { // transpose { "Transpose", -R"(HloModule Transpose_module: +R"(HloModule Transpose_module ENTRY %Transpose.v2 () -> s32[1,2,3] { %constant = s32[1,2,3]{2,1,0} constant(s32[1,2,3] { { { 1, 2, 3 }, { 4, 5, 6 } } }) @@ -524,7 +509,7 @@ ENTRY 
%Transpose.v2 () -> s32[1,2,3] { // Dynamic slice { "DynamicSlice", -R"(HloModule DynamicSlice_module: +R"(HloModule DynamicSlice_module ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[1]) -> s32[2,2,258] { %original_parameter = s32[2,2,258]{2,1,0} parameter(0) @@ -539,7 +524,7 @@ ENTRY %DynamicSlice.v5 (original_parameter: s32[2,2,258], start_index: s32[1]) - // Dynamic update slice { "DynamicUpdateSlice", -R"(HloModule DynamicUpdateSlice_module: +R"(HloModule DynamicUpdateSlice_module ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_indices: s32[4]) -> s32[1,1,25,1] { %input = s32[1,1,25,1]{3,2,1,0} parameter(0) @@ -553,7 +538,7 @@ ENTRY %DynamicUpdateSlice.v4 (input: s32[1,1,25,1], update: s32[1,1,2,1], start_ // batch norm training { "BatchNormTraining", -R"(HloModule BasicTraining_module: +R"(HloModule BasicTraining_module ENTRY %BasicTraining.v4 () -> (f32[2,2,1,2], f32[2], f32[2]) { %constant = f32[2,2,1,2]{3,2,1,0} constant(f32[2,2,1,2] { { /*i0=0*/ { /*i1=0*/ {1, 2} }, { /*i1=1*/ {3, 4} } }, { /*i0=1*/ { /*i1=0*/ {5, 6} }, { /*i1=1*/ {7, 8} } } }) @@ -567,7 +552,7 @@ ENTRY %BasicTraining.v4 () -> (f32[2,2,1,2], f32[2], f32[2]) { // batch norm inference { "BatchNormInference", -R"(HloModule BatchNormInference_module: +R"(HloModule BatchNormInference_module ENTRY %BatchNormInference.v6 (input: f32[2,2,2,2], offset: f32[2], scale: f32[2], mean: f32[2], variance: f32[2]) -> f32[2,2,2,2] { %input = f32[2,2,2,2]{3,2,1,0} parameter(0) @@ -583,7 +568,7 @@ ENTRY %BatchNormInference.v6 (input: f32[2,2,2,2], offset: f32[2], scale: f32[2] // batch norm grad { "BatchNormGrad", -R"(HloModule BatchNormGrad_module: +R"(HloModule BatchNormGrad_module ENTRY %BatchNormGrad.v4 (input: f32[2,2,2,2], scale: f32[2], mean: f32[2], variance: f32[2], grad_output: f32[2,2,2,2]) -> (f32[2,2,2,2], f32[2], f32[2]) { %input = f32[2,2,2,2]{3,2,1,0} parameter(0) @@ -594,12 +579,60 @@ ENTRY %BatchNormGrad.v4 (input: f32[2,2,2,2], 
scale: f32[2], mean: f32[2], varia ROOT %batch-norm-grad = (f32[2,2,2,2]{3,2,1,0}, f32[2]{0}, f32[2]{0}) batch-norm-grad(f32[2,2,2,2]{3,2,1,0} %input, f32[2]{0} %scale, f32[2]{0} %mean, f32[2]{0} %variance, f32[2,2,2,2]{3,2,1,0} %grad_output), epsilon=0.001, feature_index=0 } +)" +}, +// fft +{ +"Fft", +R"(HloModule Fft_module + +ENTRY %Fft (input: c64[8,32]) -> c64[8,32] { + %input = c64[8,32]{1,0} parameter(0) + ROOT %fft = c64[8,32]{1,0} fft(c64[8,32]{1,0} %input), fft_type=FFT, fft_length={32} +} + +)" +}, +// ifft +{ +"Ifft2d", +R"(HloModule Ifft2d_module + +ENTRY %Ifft2d (input: c64[5,8,32]) -> c64[5,8,32] { + %input = c64[5,8,32]{2,1,0} parameter(0) + ROOT %fft = c64[5,8,32]{2,1,0} fft(c64[5,8,32]{2,1,0} %input), fft_type=IFFT, fft_length={8,32} +} + +)" +}, +// rfft2d +{ +"Rfft2d", +R"(HloModule Rfft2d_module + +ENTRY %Rfft2d (input: f32[5,64,32]) -> c64[5,64,17] { + %input = f32[5,64,32]{2,1,0} parameter(0) + ROOT %fft = c64[5,64,17]{2,1,0} fft(f32[5,64,32]{2,1,0} %input), fft_type=RFFT, fft_length={64,32} +} + +)" +}, +// irfft3d +{ +"Irfft3d", +R"(HloModule Irfft3d_module + +ENTRY %Irfft3d (input: c64[5,64,128,33]) -> f32[5,64,128,64] { + %input = c64[5,64,128,33]{3,2,1,0} parameter(0) + ROOT %fft = f32[5,64,128,64]{3,2,1,0} fft(c64[5,64,128,33]{3,2,1,0} %input), fft_type=IRFFT, fft_length={64,128,64} +} + )" }, // pad { "Pad", -R"(HloModule Pad1DS3Array_module: +R"(HloModule Pad1DS3Array_module ENTRY %Pad1DS3Array.v3 () -> f32[8] { %constant = f32[3]{0} constant({1, 2, 3}) @@ -612,7 +645,7 @@ ENTRY %Pad1DS3Array.v3 () -> f32[8] { // pad has interior { "PadHasInterior", -R"(HloModule PadHasInterior_module: +R"(HloModule PadHasInterior_module ENTRY %PadHasInterior.v3 (input: f32[1,25,7,7]) -> f32[1,25,17,11] { %input = f32[1,25,7,7]{3,2,1,0} parameter(0) @@ -620,12 +653,25 @@ ENTRY %PadHasInterior.v3 (input: f32[1,25,7,7]) -> f32[1,25,17,11] { ROOT %pad = f32[1,25,17,11]{3,2,1,0} pad(f32[1,25,7,7]{3,2,1,0} %input, f32[] %constant), 
padding=0_0_0x0_0_0x2_2_1x2_2_0 } +)" +}, +// Negative padding +{ +"PadHasNegativePadding", +R"(HloModule PadHasNegativePadding_module + +ENTRY %PadHasNegativePadding (input: f32[1,25,7,7,10]) -> f32[1,15,6,3,29] { + %input = f32[1,25,7,7,10]{4,3,2,1,0} parameter(0) + %constant = f32[] constant(-5.123) + ROOT %pad = f32[1,15,6,3,29]{4,3,2,1,0} pad(f32[1,25,7,7,10]{4,3,2,1,0} %input, f32[] %constant), padding=0_0_0x0_-10_0x0_-1_0x-2_-2_0x-1_-1_3 +} + )" }, // fusion { "Fusion", -R"(HloModule fusion_module: +R"(HloModule fusion_module %fused_computation (constant.param_0: f32[3,2,1,1], constant.1.param_1: f32[2]) -> f32[3,2,1,1] { %constant.param_0 = f32[3,2,1,1]{3,2,1,0} parameter(0) @@ -640,22 +686,182 @@ ENTRY %fusion.v3 () -> f32[3,2,1,1] { ROOT %fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %constant, f32[2]{0} %constant.1), kind=kLoop, calls=%fused_computation } +)" +}, +{ +"Sparse", +R"(HloModule sparse_f32 + +ENTRY %sparse () -> f32[2,3,4] { + ROOT %foo = f32[2,3,4]sparse{10} constant(f32[2,3,4]{[0, 1, 2]: 1, [1, 2, 3]: 2, [2, 3, 4]: 3}) +} + +)" +}, +{ +"SparseEmpty", +R"(HloModule sparse_f32_empty + +ENTRY %sparse_f32_empty () -> f32[2,3,4] { + ROOT %foo = f32[2,3,4]sparse{10} constant(f32[2,3,4]{}) +} + +)" +}, +{ +"SparseR1", +R"(HloModule sparse_f32_r1 + +ENTRY %sparse_f32_r1 () -> f32[9] { + ROOT %foo = f32[9]sparse{10} constant(f32[9]{1: 2, 3: 4, 5: 6}) +} + +)" +}, + }); + // clang-format on +} + +std::vector CreateShortTestCases() { + // clang-format off + return std::vector({ +// map +{ +"Map", +R"(HloModule MapBinaryAdder_module + +add_F32.v3 { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY MapBinaryAdder.v3 { + param0 = f32[4]{0} parameter(0) + param1 = f32[4]{0} parameter(1) + ROOT map = f32[4]{0} map(param0, param1), to_apply=add_F32.v3 +} + +)" +}, +// reduce +{ +"Reduce", +R"(HloModule ReduceR3ToR2_module + +add_F32.v3 { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + 
ROOT add = f32[] add(lhs, rhs) +} + +ENTRY ReduceR3ToR2.v3 { + input = f32[8,16,256]{2,1,0} parameter(0) + constant = f32[] constant(0) + ROOT reduce = f32[8,16]{1,0} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3 +} + )" }, // infeed/outfeed { "InfeedOutfeed", -R"(HloModule outfeed_module: +R"(HloModule outfeed_module + +ENTRY InfeedToOutfeed { + infeed = (u32[3]{0}, pred[]) infeed() + outfeed = () outfeed(infeed) + ROOT infeed.1 = (u32[3]{0}, pred[]) infeed() + outfeed.1 = () outfeed(infeed.1) +} + +)" +}, +// Rng +{ +"Rng", +R"(HloModule rng_module + +ENTRY Rng { + constant = f32[] constant(0) + constant.1 = f32[] constant(1) + ROOT rng = f32[8]{0} rng(constant, constant.1), distribution=rng_uniform +} + +)" +}, +// Reduce precision +{ +"ReducePrevison", +R"(HloModule reduce_precision + +ENTRY ReducePrecision { + constant = f32[1]{0} constant({3.14159}) + ROOT reduce-precision = f32[1]{0} reduce-precision(constant), exponent_bits=8, mantissa_bits=10 +} + +)" +}, +// Conditional +{ +"Conditional", +R"(HloModule conditional + +Negate { + x = f32[] parameter(0) + ROOT negate = f32[] negate(x) +} + +Identity { + y = f32[] parameter(0) + ROOT copy = f32[] copy(y) +} -ENTRY %InfeedToOutfeed () -> (u32[3], pred[]) { - %infeed = (u32[3]{0}, pred[]) infeed() - %outfeed = () outfeed((u32[3]{0}, pred[]) %infeed) - ROOT %infeed.1 = (u32[3]{0}, pred[]) infeed() - %outfeed.1 = () outfeed((u32[3]{0}, pred[]) %infeed.1) +ENTRY Parameters1.v4 { + constant = pred[] constant(true) + constant.1 = f32[] constant(56) + constant.2 = f32[] constant(12) + ROOT conditional = f32[] conditional(constant, constant.1, constant.2), true_computation=Negate, false_computation=Identity } )" +}, +// CustomCall +{ +"CustomCall", +R"(HloModule custom_call + +ENTRY CustomCall { + constant = f32[1]{0} constant({12345}) + ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar" +} + +)" +}, +// Variables with non-default names +{ +"NonDefaultNames", 
+R"(HloModule add_constants_module + +ENTRY add_constants { + foo = f32[] constant(3.14) + ROOT bar = f32[] add(foo, foo) } + +)" +}, +{ +"Dot", +R"(HloModule dot + +ENTRY dot { + a = f32[2,10]{1,0} parameter(0) + b = f32[10,3]{1,0} parameter(1) + ROOT dot = f32[2,3]{1,0} dot(a, b), lhs_batch_dims={0}, lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +)" +}, }); // clang-format on } @@ -674,18 +880,35 @@ class HloParserTest : public ::testing::Test, void ExpectEqual() { const string& original = GetParam().module_string; auto result = Parse(original); - TF_EXPECT_OK(result.status()); + TF_ASSERT_OK(result.status()); + EXPECT_EQ(original, result.ValueOrDie()->ToString( + HloPrintOptions().set_print_large_constants(true))); + } +}; + +class HloParserShortTest : public HloParserTest { + protected: + void ExpectEqualShort() { + const string& original = GetParam().module_string; + auto result = Parse(original); + TF_ASSERT_OK(result.status()); EXPECT_EQ(original, - result.ValueOrDie()->ToString(/*include_large_constants=*/true)); + result.ValueOrDie()->ToString(HloPrintOptions::ShortParsable())); } }; TEST_P(HloParserTest, Run) { ExpectEqual(); } +TEST_P(HloParserShortTest, Run) { ExpectEqualShort(); } + INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTest, ::testing::ValuesIn(CreateTestCases()), TestDataToString); +INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserShortTest, + ::testing::ValuesIn(CreateShortTestCases()), + TestDataToString); + TEST_F(HloParserTest, Empty) { const string original = ""; auto result = Parse(original); @@ -749,7 +972,7 @@ ENTRY %blabla (x: f32[]) -> pred[] { } TEST_F(HloParserTest, MoreConstants) { - const string original = R"(HloModule SelectScalarS32True_module: + const string original = R"(HloModule SelectScalarS32True_module ENTRY %SelectScalarS32True.v4 () -> s32[] { %constant.2 = pred[] constant(true) @@ -766,7 +989,7 @@ ENTRY %SelectScalarS32True.v4 () -> s32[] { } TEST_F(HloParserTest, 
LiteralDimensionsMismatch_1) { - const string original = R"(HloModule some_2_module: + const string original = R"(HloModule some_2_module ENTRY %some_2 () -> f32[2] { ROOT %constant = f32[2]{0} constant({1,{2}}) @@ -780,7 +1003,7 @@ ENTRY %some_2 () -> f32[2] { } TEST_F(HloParserTest, LiteralDimensionsMismatch_2) { - const string original = R"(HloModule some_2x3_module: + const string original = R"(HloModule some_2x3_module ENTRY %some_2x3 () -> f32[2,3] { ROOT %constant = f32[2,3]{1,0} constant(f32[2,3] {1, 2, 3, 4, 5, 6}) @@ -794,7 +1017,7 @@ ENTRY %some_2x3 () -> f32[2,3] { } TEST_F(HloParserTest, LiteralDimensionsMismatch_3) { - const string original = R"(HloModule some_2x3x2_module: + const string original = R"(HloModule some_2x3x2_module ENTRY %some_2x3x2 () -> f32[2,3,2] { ROOT %constant = f32[2,3,2]{2,1,0} constant(f32[2,3,2] {{{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}, {11, 12}}}) @@ -809,7 +1032,7 @@ ENTRY %some_2x3x2 () -> f32[2,3,2] { TEST_F(HloParserTest, ConstantF16Overflow) { const string original = - R"(HloModule ConstantF16Overflow_module: + R"(HloModule ConstantF16Overflow_module ENTRY %ConstantF16Overflow.v4 () -> f16[] { ROOT %constant = f16[] constant(-65505) @@ -823,7 +1046,7 @@ ENTRY %ConstantF16Overflow.v4 () -> f16[] { } TEST_F(HloParserTest, ConstantWithExp) { - const string original = R"(HloModule ConstantWithExp_module: + const string original = R"(HloModule ConstantWithExp_module ENTRY %ConstantWithExp.v4 () -> f32[] { %constant.1 = f32[] constant(3e+2) @@ -838,7 +1061,7 @@ ENTRY %ConstantWithExp.v4 () -> f32[] { } TEST_F(HloParserTest, AttibutesAnyOrder) { - const string original = R"(HloModule any_order_module: + const string original = R"(HloModule any_order_module ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] { %input = f32[1,2,1]{2,1,0} parameter(0) @@ -852,7 +1075,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 } TEST_F(HloParserTest, InvalidDimLabels) { - 
string prefix = R"(HloModule invalid_dim_labels_module: + string prefix = R"(HloModule invalid_dim_labels_module ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] { %input = f32[1,2,1]{2,1,0} parameter(0) @@ -864,19 +1087,21 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 )"; - ExpectHasSubstr(Parse(StrCat(prefix, ",dim_labels=00_01_10", suffix)) - .status() - .error_message(), - "expects dim labels pattern"); + ExpectHasSubstr( + Parse(tensorflow::strings::StrCat(prefix, ",dim_labels=00_01_10", suffix)) + .status() + .error_message(), + "expects dim labels pattern"); - ExpectHasSubstr(Parse(StrCat(prefix, ",dim_labels=010_1100->010", suffix)) + ExpectHasSubstr(Parse(tensorflow::strings::StrCat( + prefix, ",dim_labels=010_1100->010", suffix)) .status() .error_message(), "must have the same rank"); } TEST_F(HloParserTest, UnexpectedAttribute) { - const string original = R"(HloModule unexpected_attr_module: + const string original = R"(HloModule unexpected_attr_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { %recv = (f32[], u32[]) recv(), channel_id=15 @@ -892,7 +1117,7 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { } TEST_F(HloParserTest, MissingAttribute) { - const string original = R"(HloModule missing_attr_module: + const string original = R"(HloModule missing_attr_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { %recv = (f32[], u32[]) recv(), channel_id=15 @@ -908,7 +1133,7 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { } TEST_F(HloParserTest, PredecessorUndefined) { - const string original = R"(HloModule pre_not_found_module: + const string original = R"(HloModule pre_not_found_module ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { %recv = (f32[], u32[]) recv(), channel_id=15 @@ -924,7 +1149,7 @@ ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { } TEST_F(HloParserTest, SliceAllowOmitStride1) { - const string original = R"(HloModule slice_module: + const 
string original = R"(HloModule slice_module ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] { %p0 = f32[3,3,4,4]{3,2,1,0} parameter(0) @@ -936,7 +1161,7 @@ ENTRY %slice.v2 (p0: f32[3,3,4,4]) -> f32[3,3,2,4] { } TEST_F(HloParserTest, PaddingConfigIsNotWindowPad) { - const string original = R"(HloModule window_pad_module: + const string original = R"(HloModule window_pad_module ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2,1] { %input = f32[1,2,1]{2,1,0} parameter(0) @@ -951,7 +1176,7 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 } TEST_F(HloParserTest, CommaBetweenSubAttributes) { - const string original = R"(HloModule test_comma_module: + const string original = R"(HloModule test_comma_module ENTRY %test_comma.v4 () -> f32[] { ROOT %constant = f32[] constant(-4.2), metadata={source_line=5, op_type="::const"} @@ -961,6 +1186,124 @@ ENTRY %test_comma.v4 () -> f32[] { TF_EXPECT_OK(Parse(original).status()); } +TEST_F(HloParserTest, ComputationShapeDoesNotMatchRootShape) { + const string original = R"(HloModule custom_call: + +ENTRY %CustomCall () -> f32[1] { + %constant = f32[1]{0} constant({12345}) + ROOT %foo = f32[1,2,3]{0,2,1} custom-call(f32[1]{0} %constant), custom_call_target="foo\"bar" +})"; + ExpectHasSubstr(Parse(original).status().error_message(), + "Shape of computation CustomCall, f32[1], is not compatible " + "with that of its root instruction foo, f32[1,2,3]"); +} + +TEST_F(HloParserTest, EntryComputationWithLayout) { + const string original = R"(HloModule layout: +add_F32.v3 { + lhs = f32[] parameter(0) + rhs = f32[] parameter(1) + ROOT add = f32[] add(lhs, rhs) +} + +ENTRY %Reduce (input: f32[8,16,256]) -> f32[8,16] { + input = f32[8,16,256]{0,1,2} parameter(0) + constant = f32[] constant(0) + ROOT reduce = f32[8,16]{0,1} reduce(input, constant), dimensions={2}, to_apply=add_F32.v3 +})"; + + auto module = Parse(original); + TF_ASSERT_OK(module.status()); + auto program_layout 
= module.ValueOrDie()->entry_computation_layout(); + ASSERT_EQ(program_layout.parameter_count(), 1); + auto param_layout = program_layout.parameter_layout(0).layout(); + auto result_layout = program_layout.result_layout().layout(); + EXPECT_TRUE( + LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1, 2}), param_layout)) + << "actual layout of parameter(0) is " + << LayoutUtil::HumanString(param_layout); + EXPECT_TRUE(LayoutUtil::Equal(LayoutUtil::MakeLayout({0, 1}), result_layout)) + << "actual layout of result is " + << LayoutUtil::HumanString(result_layout); +} + +TEST_F(HloParserTest, NoEntry) { + const string original = R"(HloModule no_entry: +c1 { + const1 = f32[1]{0} constant({12345}) +} +c2 { + const2 = f32[1]{0} constant({67890}) +})"; + auto module = Parse(original); + TF_ASSERT_OK(module.status()); + EXPECT_EQ(module.ValueOrDie()->entry_computation()->name(), "c2"); +} + +TEST_F(HloParserTest, NoRoot) { + const string original = R"(HloModule no_root: +ENTRY consts { + first = f32[1]{0} constant({12345}) + last = f32[1]{0} constant({67890}) +})"; + auto module = Parse(original); + TF_ASSERT_OK(module.status()); + EXPECT_EQ( + module.ValueOrDie()->entry_computation()->root_instruction()->name(), + "last"); +} + +TEST_F(HloParserTest, MultipleEntries) { + const string original = R"(HloModule multiple_entries: +ENTRY c1 { + const1 = f32[1]{0} constant({12345}) +} +ENTRY c2 { + const2 = f32[1]{0} constant({67890}) +})"; + ExpectHasSubstr(Parse(original).status().error_message(), + "expects only one ENTRY"); +} + +TEST_F(HloParserTest, MultipleRoots) { + const string original = R"(HloModule multiple_roots: +ENTRY consts { + ROOT const1 = f32[1]{0} constant({12345}) + ROOT const2 = f32[1]{0} constant({12345}) +})"; + ExpectHasSubstr(Parse(original).status().error_message(), + "one computation should have only one ROOT"); +} + +TEST_F(HloParserTest, InstructionExists) { + const string original = R"(HloModule comp_exists +c1 { + instr = f32[1]{0} constant({12345}) +} 
+c2 { + instr = f32[1]{0} constant({67890}) +})"; + + ExpectHasSubstr(Parse(original).status().error_message(), + R"(was parsing 3:3: error: instruction previously defined here + instr = f32[1]{0} constant({12345}) + ^)"); +} + +TEST_F(HloParserTest, ComputationExists) { + const string original = R"(HloModule comp_exists +comp { + const1 = f32[1]{0} constant({12345}) +} +comp { + const2 = f32[1]{0} constant({67890}) +})"; + ExpectHasSubstr(Parse(original).status().error_message(), + R"(was parsing 2:1: error: computation previously defined here +comp { +^)"); +} + } // namespace } // namespace tools } // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_token.h b/tensorflow/compiler/xla/tools/parser/hlo_token.h index 07e48804d053f31bdff6678f09ee2c1e3b731e0f..7928bee5c2097f353b182095a555c334d7b69c95 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_token.h +++ b/tensorflow/compiler/xla/tools/parser/hlo_token.h @@ -18,6 +18,9 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/types.h" + namespace xla { namespace tools { @@ -60,10 +63,9 @@ enum class TokKind { kDimLabels, // [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,} kDxD, // [0-9]+(x[0-9]+)+ kPad, // [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)* + kIdent, // other identifiers kString, // "abcd\"\n" kShape, // f32[2,3]{1,0} - kOpcode, // add - kFusionKind, // kLoop, kOutput, ... kInt, // 42 kDecimal, // 4.2 }; diff --git a/tensorflow/compiler/xla/tools/replay_computation.cc b/tensorflow/compiler/xla/tools/replay_computation.cc index ec3f6a0471e2ae965846f5ef7560e448fe9d8073..eda5effbb92db92c9317a956497a00c0ec15c27c 100644 --- a/tensorflow/compiler/xla/tools/replay_computation.cc +++ b/tensorflow/compiler/xla/tools/replay_computation.cc @@ -59,25 +59,33 @@ namespace xla { namespace tools { namespace { +// Command-line opts to this tool. See main() for descriptions of these +// fields. 
+struct Options { + string fake_infeed_shape; + bool use_fake_data = false; + bool print_result = true; + int num_runs = 1; +}; + // Invokes the given computation passing arbitrary data for every (unbound) // parameter if use_fake_data, Otherwise use recorded data if available. // // Similarly, infeeds fake data of shape fake_infeed_shape if it is provided; // otherwise, no infeed is performed. StatusOr> ReplayComputation( - const SessionModule& module, int num_runs, - tensorflow::StringPiece fake_infeed_shape, bool use_fake_data, - Client* client) { + const SessionModule& module, Client* client, const Options& opts) { TF_ASSIGN_OR_RETURN(Computation computation, client->LoadSnapshot(module)); std::vector> arguments; - if (use_fake_data) { + if (opts.use_fake_data) { arguments = MakeFakeArgumentsOrDie(computation, client); } else { // use recorded data if available for (const auto& proto : module.arguments()) { - Literal literal(proto); + TF_ASSIGN_OR_RETURN(std::unique_ptr literal, + Literal::CreateFromProto(proto)); TF_ASSIGN_OR_RETURN(std::unique_ptr data, - client->TransferToServer(literal)); + client->TransferToServer(*literal)); arguments.push_back(std::move(data)); } } @@ -86,12 +94,12 @@ StatusOr> ReplayComputation( // concurrent infeed occur via the fake_infeed_shape. tensorflow::gtl::optional pool; - if (!fake_infeed_shape.empty()) { + if (!opts.fake_infeed_shape.empty()) { pool.emplace(tensorflow::Env::Default(), "infeed", /*num_threads=*/1); - pool->Schedule([fake_infeed_shape, client]() { + pool->Schedule([opts, client]() { StatusOr shape_status = - ShapeUtil::ParseShapeString(fake_infeed_shape); + ShapeUtil::ParseShapeString(opts.fake_infeed_shape); TF_CHECK_OK(shape_status.status()); Shape shape = std::move(shape_status).ValueOrDie(); StatusOr> data_status = MakeFakeLiteral(shape); @@ -112,19 +120,19 @@ StatusOr> ReplayComputation( // Run the computation num_runs times, and return the result from the last // execution. 
std::unique_ptr result; - for (int i = 0; i < num_runs; ++i) { + for (int i = 0; i < opts.num_runs; ++i) { ExecutionProfile profile; - if (use_fake_data) { - // If using fake data, execute the computation but don't bother retrieving - // the result -- presumably it's uninteresting, since our data is fake. + if (opts.print_result) { + TF_ASSIGN_OR_RETURN(result, client->ExecuteAndTransfer( + computation, execute_arguments, + /*execution_options=*/nullptr, &profile)); + } else { + // If we're not printing the result, execute the computation but don't + // bother retrieving the result. This can be a significant speedup. TF_RETURN_IF_ERROR(client ->Execute(computation, execute_arguments, /*execution_options=*/nullptr, &profile) .status()); - } else { - TF_ASSIGN_OR_RETURN(result, client->ExecuteAndTransfer( - computation, execute_arguments, - /*execution_options=*/nullptr, &profile)); } LOG(INFO) << "Execution took " << static_cast(profile.compute_time_ns()) / 1e9 << "s"; @@ -133,16 +141,15 @@ StatusOr> ReplayComputation( return std::move(result); } -int RealMain(tensorflow::gtl::ArraySlice args, int num_runs, - tensorflow::StringPiece fake_infeed_shape, bool use_fake_data) { +int RealMain(tensorflow::gtl::ArraySlice args, const Options& opts) { Client* client = ClientLibrary::LocalClientOrDie(); tensorflow::Env* env = tensorflow::Env::Default(); int exit_status = EXIT_SUCCESS; for (char* arg : args) { SessionModule module; TF_CHECK_OK(tensorflow::ReadBinaryProto(env, arg, &module)); - StatusOr> result_status = ReplayComputation( - module, num_runs, fake_infeed_shape, use_fake_data, client); + StatusOr> result_status = + ReplayComputation(module, client, opts); if (!result_status.ok()) { fprintf(stderr, "%s: error: %s\n", arg, result_status.status().ToString().c_str()); @@ -156,12 +163,16 @@ int RealMain(tensorflow::gtl::ArraySlice args, int num_runs, ShapeUtil::HumanString(result->shape()).c_str(), result->ToString().c_str()); if (module.has_result()) { + 
std::unique_ptr literal = + Literal::CreateFromProto(module.result()).ConsumeValueOrDie(); fprintf(stdout, "was %s:%s\n", ShapeUtil::HumanString(module.result().shape()).c_str(), - Literal(module.result()).ToString().c_str()); + literal->ToString().c_str()); } } } + + ClientLibrary::DestroyLocalInstances(); return exit_status; } @@ -170,16 +181,15 @@ int RealMain(tensorflow::gtl::ArraySlice args, int num_runs, } // namespace xla int main(int argc, char** argv) { - // Flags - xla::string fake_infeed_shape; - bool use_fake_data = false; - int num_runs = 1; + xla::tools::Options opts; const std::vector flag_list = { - tensorflow::Flag("use_fake_data", &use_fake_data, + tensorflow::Flag("use_fake_data", &opts.use_fake_data, "Replay computation using fake data"), - tensorflow::Flag("num_runs", &num_runs, + tensorflow::Flag("print_result", &opts.print_result, + "Print the result of the computation to stdout"), + tensorflow::Flag("num_runs", &opts.num_runs, "Number of times to run each computation"), - tensorflow::Flag("fake_infeed_shape", &fake_infeed_shape, + tensorflow::Flag("fake_infeed_shape", &opts.fake_infeed_shape, "Shape of fake data to construct for (infinite) infeed"), }; xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list); @@ -191,5 +201,5 @@ int main(int argc, char** argv) { tensorflow::gtl::ArraySlice args(argv, argc); args.pop_front(); // Pop off the binary name, argv[0] - return xla::tools::RealMain(args, num_runs, fake_infeed_shape, use_fake_data); + return xla::tools::RealMain(args, opts); } diff --git a/tensorflow/compiler/xla/tools/show_literal.cc b/tensorflow/compiler/xla/tools/show_literal.cc index b50cb5e28eac14ed99af566939f8bd64e393ff64..fe8e72ba32bb4493b2751cfdfeb977f271092f9c 100644 --- a/tensorflow/compiler/xla/tools/show_literal.cc +++ b/tensorflow/compiler/xla/tools/show_literal.cc @@ -40,7 +40,8 @@ int main(int argc, char **argv) { xla::LiteralProto literal_proto; 
TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), argv[1], &literal_proto)); - xla::Literal literal(literal_proto); + std::unique_ptr literal = + xla::Literal::CreateFromProto(literal_proto).ConsumeValueOrDie(); LOG(INFO) << "literal: " << literal_proto.ShortDebugString(); - fprintf(stderr, "%s\n", literal.ToString().c_str()); + fprintf(stderr, "%s\n", literal->ToString().c_str()); } diff --git a/tensorflow/compiler/xla/tools/show_text_literal.cc b/tensorflow/compiler/xla/tools/show_text_literal.cc index bbe9902aa17a585c4bad5b732330305dfdd45302..8525873e913185554d18df8c8c3584bfcdcdcabe 100644 --- a/tensorflow/compiler/xla/tools/show_text_literal.cc +++ b/tensorflow/compiler/xla/tools/show_text_literal.cc @@ -39,13 +39,13 @@ int main(int argc, char **argv) { std::unique_ptr literal = xla::TextLiteralReader::ReadPath(argv[1]).ConsumeValueOrDie(); - LOG(INFO) << "literal: " << literal->ShortDebugString(); + LOG(INFO) << "literal: " << *literal; fprintf(stderr, "%s\n", literal->ToString().c_str()); if (literal->shape().element_type() == xla::F32) { - float min = - *std::min_element(literal->f32s().begin(), literal->f32s().end()); - float max = - *std::max_element(literal->f32s().begin(), literal->f32s().end()); + float min = *std::min_element(literal->data().begin(), + literal->data().end()); + float max = *std::max_element(literal->data().begin(), + literal->data().end()); fprintf(stderr, "min: %a=%f\n", min, min); fprintf(stderr, "max: %a=%f\n", max, max); } diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc index e595df3052c3de64de503d7627eff72dcba177ee..1f0c626bbb2d64ef4e67c9ec51485ae96ae73d04 100644 --- a/tensorflow/compiler/xla/util.cc +++ b/tensorflow/compiler/xla/util.cc @@ -30,9 +30,7 @@ limitations under the License. #include "tensorflow/core/platform/stacktrace.h" namespace xla { -namespace { -// Logs the provided status message with a backtrace. 
Status WithLogBacktrace(const Status& status) { CHECK(!status.ok()); VLOG(1) << status.ToString(); @@ -40,8 +38,6 @@ Status WithLogBacktrace(const Status& status) { return status; } -} // namespace - ScopedLoggingTimer::ScopedLoggingTimer(const string& label, bool enabled) : enabled(enabled), label(label) { if (enabled) { @@ -74,13 +70,18 @@ Status AppendStatus(Status prior, tensorflow::StringPiece context) { // Implementation note: we can't common these out (without using macros) because // they all need to va_start/va_end their varargs in their frame. -Status InvalidArgument(const char* format, ...) { +Status InvalidArgumentV(const char* format, va_list args) { string message; + tensorflow::strings::Appendv(&message, format, args); + return WithLogBacktrace(tensorflow::errors::InvalidArgument(message)); +} + +Status InvalidArgument(const char* format, ...) { va_list args; va_start(args, format); - tensorflow::strings::Appendv(&message, format, args); + Status result = InvalidArgumentV(format, args); va_end(args); - return WithLogBacktrace(tensorflow::errors::InvalidArgument(message)); + return result; } Status Unimplemented(const char* format, ...) 
{ @@ -191,9 +192,9 @@ std::vector ComposePermutations(tensorflow::gtl::ArraySlice p1, return output; } -bool IsIdentityPermutation(tensorflow::gtl::ArraySlice p) { - for (int64 i = 0; i < p.size(); ++i) { - if (p[i] != i) { +bool IsIdentityPermutation(tensorflow::gtl::ArraySlice permutation) { + for (int64 i = 0; i < permutation.size(); ++i) { + if (permutation[i] != i) { return false; } } @@ -338,7 +339,7 @@ std::vector> CommonFactors( string SanitizeFileName(string file_name) { for (char& c : file_name) { - if (c == '/' || c == '\\' || c == '[' || c == ']') { + if (c == '/' || c == '\\' || c == '[' || c == ']' || c == ' ') { c = '_'; } } diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index b722095d1f38bf8a984c3ce9092a65f8e0baa911..08df5b12b3a53a138f56705531baa3333b23c5d8 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -40,6 +40,13 @@ limitations under the License. namespace xla { +// Logs the provided status message with a backtrace. +// +// For use by Status-factories, logs a backtrace at the point where the status +// is created, such that we can use --vmodule=util=1 to see all status +// creation backtraces. +Status WithLogBacktrace(const Status& status); + // Ranks greater than 8 are very rare, so use InlinedVector to store // the bounds and indices. And for the rare cases of ranks greater than 8, // the InlinedVector will just behave like an std::vector<> and allocate the @@ -207,6 +214,27 @@ Status ResourceExhausted(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); Status NotFound(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); Status Unavailable(const char* format, ...) TF_PRINTF_ATTRIBUTE(1, 2); +// Passed-varargs variant of the InvalidArgument factory above. +Status InvalidArgumentV(const char* format, va_list args); + +template +Status UnimplementedStrCat(Args&&... 
concat) { + return Unimplemented( + "%s", tensorflow::strings::StrCat(std::forward(concat)...).c_str()); +} + +template +Status InternalErrorStrCat(Args&&... concat) { + return InternalError( + "%s", tensorflow::strings::StrCat(std::forward(concat)...).c_str()); +} + +template +Status ResourceExhaustedStrCat(Args&&... concat) { + return ResourceExhausted( + "%s", tensorflow::strings::StrCat(std::forward(concat)...).c_str()); +} + // Splits the lines of the original, replaces leading whitespace with the prefix // given by "indentation", and returns the string joined by newlines again. As a // side effect, any additional trailing whitespace is removed. @@ -239,11 +267,14 @@ std::vector Permute(tensorflow::gtl::ArraySlice permutation, // Override of the above that works around compile failures with gcc 7.1.1. // For details see https://github.com/tensorflow/tensorflow/issues/10843 +// Hide this workaround from MSVC as it causes ambiguous error. +#ifndef _MSC_VER template std::vector Permute(tensorflow::gtl::ArraySlice permutation, const std::vector& input) { return Permute(permutation, input); } +#endif // Inverts a permutation, i.e., output_permutation[input_permutation[i]] = i. std::vector InversePermutation( @@ -329,7 +360,7 @@ T CeilOfRatio(T dividend, T divisor) { } // Rounds the value up to a multiple of the divisor by first calling CeilOfRatio -// then multiplying by the divisor. For example: RoundUpToMultiple(13, 8) => 16 +// then multiplying by the divisor. For example: RoundUpToNearest(13, 8) => 16 template T RoundUpToNearest(T value, T divisor) { return CeilOfRatio(value, divisor) * divisor; @@ -337,7 +368,7 @@ T RoundUpToNearest(T value, T divisor) { // Rounds the value down to a multiple of the divisor by first calling // FloorOfRatio then multiplying by the divisor. 
For example: -// RoundUpToMultiple(13, 8) => 8 +// RoundDownToNearest(13, 8) => 8 template T RoundDownToNearest(T value, T divisor) { return FloorOfRatio(value, divisor) * divisor; @@ -395,6 +426,33 @@ std::vector> CommonFactors( // Removes illegal characters from filenames. string SanitizeFileName(string file_name); +template +bool c_all_of(Container container, Predicate predicate) { + return std::all_of(std::begin(container), std::end(container), predicate); +} + +template +OutputIterator c_transform(InputContainer input_container, + OutputIterator output_iterator, + UnaryOperation unary_op) { + return std::transform(std::begin(input_container), std::end(input_container), + output_iterator, unary_op); +} + +template +OutputIterator c_copy_if(InputContainer input_container, + OutputIterator output_iterator, + UnaryPredicate predicate) { + return std::copy_if(std::begin(input_container), std::end(input_container), + output_iterator, predicate); +} + +template +void c_sort(InputContainer& input_container, Comparator comparator) { + std::sort(input_container.begin(), input_container.end(), comparator); +} + } // namespace xla #define XLA_LOG_LINES(SEV, STRING) \ diff --git a/tensorflow/compiler/xla/window_util.cc b/tensorflow/compiler/xla/window_util.cc index 2e0eba8de0100fb4e7e45348618febd778c88c9a..93284b80f9e1f82c4b18dc7388754d5c01a7740c 100644 --- a/tensorflow/compiler/xla/window_util.cc +++ b/tensorflow/compiler/xla/window_util.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" @@ -25,6 +26,28 @@ limitations under the License. 
namespace xla { namespace window_util { +Window MakeWindow(tensorflow::gtl::ArraySlice sizes) { + Window window; + for (int64 size : sizes) { + auto* dimension = window.add_dimensions(); + dimension->set_size(size); + dimension->set_stride(1); + dimension->set_base_dilation(1); + dimension->set_window_dilation(1); + } + return window; +} + +PaddingConfig MakeSymmetricPadding(tensorflow::gtl::ArraySlice sizes) { + PaddingConfig config; + for (int64 size : sizes) { + auto* dimension = config.add_dimensions(); + dimension->set_edge_padding_low(size); + dimension->set_edge_padding_high(size); + } + return config; +} + /* static */ string ToString(const WindowDimension& dim) { using tensorflow::strings::StrAppend; using tensorflow::strings::StrCat; @@ -88,6 +111,11 @@ string ToString(const Window& window) { return StrCat(dim.window_dilation()); }); } + if (HasWindowReversal(window)) { + add_field(" rhs_reversal", [](const WindowDimension& dim) { + return StrCat(dim.window_reversal() ? 1 : 0); + }); + } return str; } @@ -109,13 +137,21 @@ bool HasPadding(const Window& window) { return false; } -bool HasEvenPadding(const Window& window) { +bool HasSymmetricPadding(const Window& window) { return std::all_of(window.dimensions().begin(), window.dimensions().end(), [](const WindowDimension& dim) { return dim.padding_low() == dim.padding_high(); }); } +bool HasSymmetricPadding(const PaddingConfig& padding_config) { + return std::all_of(padding_config.dimensions().begin(), + padding_config.dimensions().end(), + [](const PaddingConfig::PaddingConfigDimension& dim) { + return dim.edge_padding_low() == dim.edge_padding_high(); + }); +} + bool HasNegativePadding(const Window& window) { return std::any_of(window.dimensions().begin(), window.dimensions().end(), [](const WindowDimension& dim) { @@ -141,10 +177,25 @@ bool HasWindowDilation(const Window& window) { return false; } +bool HasWindowReversal(const Window& window) { + for (const auto& dim : window.dimensions()) { + if 
(dim.window_reversal()) { + return true; + } + } + return false; +} + bool HasDilation(const Window& window) { return HasBaseDilation(window) || HasWindowDilation(window); } +bool IsInactiveWindowDimension(const Window& window, int64 logical_dim) { + const WindowDimension& window_dim = window.dimensions(logical_dim); + return window_dim.size() == 1 && window_dim.stride() == 1 && + window_dim.padding_low() == 0 && window_dim.padding_high() == 0; +} + int64 DilatedBound(int64 bound, int64 dilation) { CHECK_GE(bound, 0); CHECK_GE(dilation, 1); diff --git a/tensorflow/compiler/xla/window_util.h b/tensorflow/compiler/xla/window_util.h index 235cb2d59d451a25dc4f824ab488f8cef6b03bfb..ba473e2c8c35202865a9a4981da7653fe1d6f552 100644 --- a/tensorflow/compiler/xla/window_util.h +++ b/tensorflow/compiler/xla/window_util.h @@ -18,10 +18,21 @@ limitations under the License. #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" namespace xla { namespace window_util { +// Creates a window with the given sizes in the dimensions and all strides set +// to 1. +Window MakeWindow(tensorflow::gtl::ArraySlice sizes); + +// Creates a padding config with symmetrical padding in each dimension, of value +// given by sizes; e.g. {0, 1, 2} would create a R3 padding config that had zero +// pixels of padding in dimension 0, one pixel of padding symmetrically, on each +// side of dimension 1, and two pixels of padding symmetrically on dimension 2. 
+PaddingConfig MakeSymmetricPadding(tensorflow::gtl::ArraySlice sizes); + string ToString(const WindowDimension& dim); string ToString(const Window& window); @@ -32,13 +43,24 @@ string ToString(const Window& window); bool HasStride(const Window& window); bool HasPadding(const Window& window); -bool HasEvenPadding(const Window& window); +bool HasSymmetricPadding(const Window& window); bool HasNegativePadding(const Window& window); +// As with HasSymmetricPadding(Window) above, returns whether the "padding low" +// is equivalent to the "padding high" for all dimensions, but works on a +// padding configuration. +bool HasSymmetricPadding(const PaddingConfig& padding_config); + bool HasBaseDilation(const Window& window); bool HasWindowDilation(const Window& window); bool HasDilation(const Window& window); +bool HasWindowReversal(const Window& window); + +// Returns true if the given logical dimension is inactive in the sense that it +// has window bound 1, no striding and no padding. +bool IsInactiveWindowDimension(const Window& window, int64 logical_dim); + // Returns the new bound after dilation. // // If a window with the given bound in some dimension is dilated with the given diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 127e5e81ac6d21945c7125ef913d236e8892758e..56162ab44e2e0e3e4478fe631888f243332dc1d8 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -82,8 +82,9 @@ message DebugOptions { // Dump all HLO modules as text into the provided directory path. string xla_generate_hlo_text_to = 7; - // Dump compilation artifacts in binary proto into this directory. - string xla_dump_hlo_proto_to = 8; + // Dump Hlo after all hlo passes are executed as proto binary into this + // directory. + string xla_dump_optimized_hlo_proto_to = 8; // Instrument the computation to collect per-HLO cycle counts. 
bool xla_hlo_profile = 9; @@ -175,6 +176,18 @@ message DebugOptions { // assignments, if available. bool xla_hlo_tfgraph_device_scopes = 93; + // If true, the GPU backend is free to use cudnn for HLO batch normalization + // ops. + bool xla_gpu_use_cudnn_batchnorm = 94; + + // Dump HLO before any hlo passes are executed as proto binary into this + // directory. + string xla_dump_unoptimized_hlo_proto_to = 95; + + // Dump HLO after each pass as an HloProto in binary file format into this + // directory. + string xla_dump_per_pass_hlo_proto_to = 96; + // Extra options to pass to the compilation backend; specific interpretation // of these values is left to the backend. map xla_backend_extra_options = 500; diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 2ba1a2d904e45e582ee4e8a4ea889ee69d55e747..3aea0217539b89b5d60ecfaf2605eee4b69af728 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -114,6 +114,17 @@ message PaddingConfig { repeated PaddingConfigDimension dimensions = 1; } +// A format specifies the method used by a layout to store an array in memory. +enum Format { + INVALID_FORMAT = 0; + // The default layout, with exactly one storage location per element (ignoring + // padding). + DENSE = 1; + // A sparsely encoded layout, providing only the index/value pairs of non-zero + // elements. + SPARSE = 2; +} + // A layout describes how the array is placed in (1D) memory space. This // includes the minor-to-major ordering of dimensions within a shape, as well as // any padding present in those dimensions. @@ -124,21 +135,30 @@ message PaddingConfig { // // See the XLA documentation for more information on shapes and layouts. message Layout { + // The method used to store the data in memory. The format determines which of + // the other fields are used by the layout. 
+ Format format = 4; + // Sequence of dimension numbers, from minor (fastest varying index) to major // (slowest varying index). This field is required. repeated int64 minor_to_major = 1; - // The width to which the layout of each dimension is padded up - // to. If present, the size of the padded_dimensions must equal the - // rank of the shape. The padding appears at the end of a dimension, - // not at the beginning. This kind of padding, unlike padding in - // e.g. convolution, is not part of the shape. + // The width to which the layout of each dimension is padded up to. If + // present, the size of the padded_dimensions must equal the rank of the + // shape. The padding appears at the end of a dimension, not at the + // beginning. This kind of padding, unlike padding in e.g. convolution, is not + // part of the shape. This field must be unset unless the format is DENSE. repeated int64 padded_dimensions = 2; - // Describes the values in the padding specified by - // padded_dimensions. + // Describes the values in the padding specified by padded_dimensions. This + // field must be unset unless the format is DENSE. PaddingValue padding_value = 3; + // The maximum number of elements that can be stored for SPARSE formats. This + // can be used to determine the maximum size in bytes of arrays stored in + // memory. This field must be unset unless the format is SPARSE. + int64 max_sparse_elements = 5; + // Important: if any field is added, be sure to modify ShapeUtil::Equal() // appropriately to account for the new field. } @@ -321,7 +341,8 @@ message LiteralProto { // The F16s and BF16s are encoded in little endian byte order bytes f16s = 11; bytes bf16s = 13; - // Next = 14 + repeated int64 sparse_indices = 14; + // Next = 15 } message WindowDimension { @@ -498,6 +519,23 @@ message CustomCallRequest { Shape shape = 4; } +message DotDimensionNumbers { + // The dimension numbers that represent the 'lhs' contracting dimensions. 
+ repeated int64 lhs_contracting_dimensions = 1; + // The dimension numbers that represent the 'rhs' contracting dimensions. + repeated int64 rhs_contracting_dimensions = 2; + // The dimension numbers that represent the 'lhs' batch dimensions. + repeated int64 lhs_batch_dimensions = 3; + // The dimension numbers that represent the 'rhs' batch dimensions. + repeated int64 rhs_batch_dimensions = 4; +}; + +message DotRequest { + ComputationDataHandle lhs = 2; + ComputationDataHandle rhs = 3; + DotDimensionNumbers dimension_numbers = 4; +} + message MapRequest { repeated ComputationDataHandle operands = 2; ComputationHandle to_apply = 3; @@ -651,6 +689,14 @@ message ConcatenateRequest { int64 dimension = 3; } +message ConditionalRequest { + ComputationDataHandle predicate = 2; + ComputationDataHandle true_operand = 3; + ComputationHandle true_computation = 4; + ComputationDataHandle false_operand = 5; + ComputationHandle false_computation = 6; +} + message WhileRequest { ComputationHandle condition = 2; ComputationHandle body = 3; @@ -732,9 +778,6 @@ enum BinaryOperation { BINOP_LT = 9; BINOP_NE = 10; - // Dot product, matrix multiply. - BINOP_DOT = 12; - // Element-wise maximum. BINOP_MAX = 14; @@ -780,9 +823,7 @@ enum RandomDistribution { // parameter[0] and standard deviation parameter[1]. RNG_NORMAL = 2; - // Creates a Bernoulli-distribution-generated random number with mean - // parameter[0]. 
- RNG_BERNOULLI = 3; + // Next: 4 } message RngRequest { @@ -885,6 +926,7 @@ message OpRequest { ConvolveRequest convolve_request = 8; CrossReplicaSumRequest cross_replica_sum_request = 9; CustomCallRequest custom_call_request = 10; + DotRequest dot_request = 43; DynamicSliceRequest dynamic_slice_request = 11; DynamicUpdateSliceRequest dynamic_update_slice_request = 12; GetTupleElementRequest get_tuple_element_request = 13; @@ -914,7 +956,8 @@ message OpRequest { BatchNormInferenceRequest batch_norm_inference_request = 38; FftRequest fft_request = 41; ConvertRequest bitcast_convert_request = 42; - // Next: 43 + ConditionalRequest conditional_request = 44; + // Next: 45 } } diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index b7ade951150412e0ad3f72c235f0677e68fce66e..bab37e8906e5c648acdc1556da7e5f4601776ff5 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -6,10 +6,17 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//tensorflow:__subpackages__"]) load("//third_party/mpi:mpi.bzl", "if_mpi") +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") +load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt") py_library( name = "contrib_py", - srcs = glob(["**/*.py"]), + srcs = glob( + ["**/*.py"], + exclude = [ + "**/*_test.py", + ], + ), srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ @@ -18,7 +25,9 @@ py_library( "//tensorflow/contrib/bayesflow:bayesflow_py", "//tensorflow/contrib/boosted_trees:init_py", "//tensorflow/contrib/cloud:cloud_py", + "//tensorflow/contrib/cluster_resolver:cluster_resolver_pip", "//tensorflow/contrib/cluster_resolver:cluster_resolver_py", + "//tensorflow/contrib/coder:coder_ops_py", "//tensorflow/contrib/compiler:compiler_py", "//tensorflow/contrib/copy_graph:copy_graph_py", "//tensorflow/contrib/crf:crf_py", @@ -29,6 +38,7 @@ py_library( "//tensorflow/contrib/eager/python:tfe", "//tensorflow/contrib/estimator:estimator_py", 
"//tensorflow/contrib/factorization:factorization_py", + "//tensorflow/contrib/feature_column:feature_column_py", "//tensorflow/contrib/ffmpeg:ffmpeg_ops_py", "//tensorflow/contrib/framework:framework_py", "//tensorflow/contrib/fused_conv:fused_conv_py", @@ -41,6 +51,7 @@ py_library( "//tensorflow/contrib/image:single_image_random_dot_stereograms_py", "//tensorflow/contrib/input_pipeline:input_pipeline_py", "//tensorflow/contrib/integrate:integrate_py", + "//tensorflow/contrib/kafka", "//tensorflow/contrib/keras", "//tensorflow/contrib/kernel_methods", "//tensorflow/contrib/kfac", @@ -48,6 +59,7 @@ py_library( "//tensorflow/contrib/layers:layers_py", "//tensorflow/contrib/learn", "//tensorflow/contrib/legacy_seq2seq:seq2seq_py", + "//tensorflow/contrib/libsvm", "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/contrib/linear_optimizer:sdca_estimator_py", "//tensorflow/contrib/linear_optimizer:sdca_ops_py", @@ -60,13 +72,15 @@ py_library( "//tensorflow/contrib/metrics:metrics_py", "//tensorflow/contrib/model_pruning", "//tensorflow/contrib/nccl:nccl_py", - "//tensorflow/contrib/ndlstm", "//tensorflow/contrib/nearest_neighbor:nearest_neighbor_py", "//tensorflow/contrib/nn:nn_py", "//tensorflow/contrib/opt:opt_py", + "//tensorflow/contrib/periodic_resample:init_py", "//tensorflow/contrib/predictor", "//tensorflow/contrib/quantization:quantization_py", "//tensorflow/contrib/quantize:quantize_graph", + "//tensorflow/contrib/py2tf", + "//tensorflow/contrib/receptive_field:receptive_field_py", "//tensorflow/contrib/reduce_slice_ops:reduce_slice_ops_py", "//tensorflow/contrib/remote_fused_graph/pylib:remote_fused_graph_ops_py", "//tensorflow/contrib/resampler:resampler_py", @@ -94,20 +108,22 @@ py_library( "//tensorflow/contrib/training:training_py", "//tensorflow/contrib/util:util_py", "//tensorflow/python:util", - ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_ops_py"]), + ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + if_tensorrt([ 
+ "//tensorflow/contrib/tensorrt:init_py", + ]), ) cc_library( name = "contrib_kernels", visibility = ["//visibility:public"], deps = [ - "//tensorflow/contrib/batching:batch_ops_kernels", "//tensorflow/contrib/boosted_trees:boosted_trees_kernels", + "//tensorflow/contrib/coder:all_kernels", "//tensorflow/contrib/cudnn_rnn:cudnn_rnn_kernels", + "//tensorflow/contrib/data/kernels:dataset_kernels", "//tensorflow/contrib/factorization/kernels:all_kernels", "//tensorflow/contrib/input_pipeline:input_pipeline_ops_kernels", "//tensorflow/contrib/layers:sparse_feature_cross_op_kernel", - "//tensorflow/contrib/nccl:nccl_kernels", "//tensorflow/contrib/nearest_neighbor:nearest_neighbor_ops_kernels", "//tensorflow/contrib/rnn:all_kernels", "//tensorflow/contrib/seq2seq:beam_search_ops_kernels", @@ -115,19 +131,23 @@ cc_library( "//tensorflow/contrib/tensor_forest:stats_ops_kernels", "//tensorflow/contrib/tensor_forest:tensor_forest_kernels", "//tensorflow/contrib/text:all_kernels", - ], + ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_collectives_py"]) + if_cuda([ + "//tensorflow/contrib/nccl:nccl_kernels", + ]), ) cc_library( name = "contrib_ops_op_lib", visibility = ["//visibility:public"], deps = [ - "//tensorflow/contrib/batching:batch_ops_op_lib", "//tensorflow/contrib/boosted_trees:boosted_trees_ops_op_lib", + "//tensorflow/contrib/coder:all_ops", "//tensorflow/contrib/cudnn_rnn:cudnn_rnn_ops_op_lib", + "//tensorflow/contrib/data:dataset_ops_op_lib", "//tensorflow/contrib/factorization:all_ops", "//tensorflow/contrib/framework:all_ops", "//tensorflow/contrib/input_pipeline:input_pipeline_ops_op_lib", + "//tensorflow/contrib/kafka:kafka_ops_op_lib", "//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib", "//tensorflow/contrib/nccl:nccl_ops_op_lib", "//tensorflow/contrib/nearest_neighbor:nearest_neighbor_ops_op_lib", diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index 
1eda1abfcf779ece7af3dbf2554c2a0a8c2611e9..4f6f539027b040de7554d09fe9118ff97aa006f8 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -19,9 +19,11 @@ from __future__ import division from __future__ import print_function # Add projects here, they will show up under tf.contrib. +from tensorflow.contrib import batching from tensorflow.contrib import bayesflow from tensorflow.contrib import cloud from tensorflow.contrib import cluster_resolver +from tensorflow.contrib import coder from tensorflow.contrib import compiler from tensorflow.contrib import copy_graph from tensorflow.contrib import crf @@ -31,6 +33,7 @@ from tensorflow.contrib import deprecated from tensorflow.contrib import distributions from tensorflow.contrib import estimator from tensorflow.contrib import factorization +from tensorflow.contrib import feature_column from tensorflow.contrib import framework from tensorflow.contrib import gan from tensorflow.contrib import graph_editor @@ -55,6 +58,7 @@ from tensorflow.contrib import model_pruning from tensorflow.contrib import nccl from tensorflow.contrib import nn from tensorflow.contrib import opt +from tensorflow.contrib import periodic_resample from tensorflow.contrib import predictor from tensorflow.contrib import quantization from tensorflow.contrib import quantize @@ -80,14 +84,14 @@ from tensorflow.contrib import training from tensorflow.contrib import util from tensorflow.contrib.eager.python import tfe as eager from tensorflow.contrib.lite.python import lite -from tensorflow.contrib.ndlstm import python as ndlstm +from tensorflow.contrib.receptive_field import receptive_field_api as receptive_field from tensorflow.contrib.remote_fused_graph import pylib as remote_fused_graph from tensorflow.contrib.specs import python as specs from tensorflow.contrib.summary import summary from tensorflow.python.util.lazy_loader import LazyLoader -ffmpeg = LazyLoader("ffmpeg", - globals(), "tensorflow.contrib.ffmpeg") +ffmpeg = 
LazyLoader("ffmpeg", globals(), + "tensorflow.contrib.ffmpeg") del LazyLoader del absolute_import diff --git a/tensorflow/contrib/all_reduce/python/all_reduce.py b/tensorflow/contrib/all_reduce/python/all_reduce.py index a5057da9fd43a88575813613d6ac9d17fd2b2e28..6658f0d9c13f6db17b25354cde2593d57f104f17 100644 --- a/tensorflow/contrib/all_reduce/python/all_reduce.py +++ b/tensorflow/contrib/all_reduce/python/all_reduce.py @@ -48,7 +48,7 @@ def _flatten_tensors(tensors): if shape.ndims is None: raise ValueError("At least one of the tensors in 'tensors' must have " "statically known rank.") - if len(shape) > 1: + if len(shape) != 1: reshaped = [] for t in tensors: with ops.colocate_with(t): @@ -289,7 +289,7 @@ def build_ring_all_reduce(input_tensors, num_workers, num_subchunks, chunks_by_dev) if pad_len > 0: output_tensors = _strip_padding(output_tensors, pad_len) - if len(shape) > 1: + if len(shape) != 1: output_tensors = _reshape_tensors(output_tensors, shape) return output_tensors @@ -466,7 +466,7 @@ def build_recursive_hd_all_reduce(input_tensors, red_op, un_op=None): if un_op: reduced_shards = [un_op(t) for t in reduced_shards] output_tensors = _build_recursive_hd_scatter(reduced_shards, devices) - if len(shape) > 1: + if len(shape) != 1: output_tensors = _reshape_tensors(output_tensors, shape) return output_tensors @@ -578,7 +578,7 @@ def build_shuffle_all_reduce(input_tensors, gather_devices, red_op, un_op=None): reduced_shards = _build_shuffle_gather(input_tensors, gather_devices, red_op, un_op) output_tensors = _build_shuffle_scatter(reduced_shards, dst_devices) - if len(shape) > 1: + if len(shape) != 1: output_tensors = _reshape_tensors(output_tensors, shape) return output_tensors @@ -744,21 +744,21 @@ def _build_nccl_hybrid(input_tensors, red_op, upper_level_f): level_2_output = upper_level_f(up_values) # Third stage: propagate within each worker using NCCL Broadcast for w in range(0, num_workers): - dst_devices = per_worker_devices[w][1:] - send_op, 
dst_tensors = nccl.broadcast(level_2_output[w], dst_devices) - # NOTE: need control dependency to ensure send_op executes - with ops.control_dependencies([send_op]): - with ops.device(per_worker_devices[w][0]): - dst_tensors.insert(0, array_ops.identity(level_2_output[w])) - down_values[w] = dst_tensors + dst_tensors = [] + with ops.device(per_worker_devices[w][0]): + broadcast_src = nccl.broadcast(array_ops.identity(level_2_output[w])) + for d in per_worker_devices[w]: + with ops.device(d): + dst_tensors.append(array_ops.identity(broadcast_src)) + down_values[w] = dst_tensors output_tensors = [v for sublist in down_values for v in sublist] - if len(shape) > 1: + if len(shape) != 1: output_tensors = _reshape_tensors(output_tensors, shape) return output_tensors def _reduce_non_singleton(input_tensors, red_f, un_op): - """If input_tenors has more than one element apply red_f, else apply un_op.""" + """If input_tensors has more than one element apply red_f, else apply un_op.""" if len(input_tensors) > 1: return red_f(input_tensors) else: @@ -831,7 +831,7 @@ def _build_shuffle_hybrid(input_tensors, gather_devices, red_op, upper_level_f): for w in range(0, num_workers): output_tensors += _build_shuffle_scatter( [level_2_output[w]], per_worker_devices[w]) - if len(shape) > 1: + if len(shape) != 1: output_tensors = _reshape_tensors(output_tensors, shape) return output_tensors diff --git a/tensorflow/contrib/all_reduce/python/all_reduce_test.py b/tensorflow/contrib/all_reduce/python/all_reduce_test.py index 0802b2736909c2a6f075ea2eac6d4dd3ab2918d8..47bab0a3670a90644972b2c961954a3036b8ecba 100644 --- a/tensorflow/contrib/all_reduce/python/all_reduce_test.py +++ b/tensorflow/contrib/all_reduce/python/all_reduce_test.py @@ -119,7 +119,7 @@ class AllReduceTest(test_util.TensorFlowTestCase): def _buildInitialVars(self, shape, dev_list): values = [] num_devices = len(dev_list) - dim = np.prod(shape) + dim = np.prod(shape) if shape else 1 for d in range(0, num_devices): with 
ops.device(dev_list[d]): npt = np.zeros(shape).astype(np.float32) @@ -164,6 +164,7 @@ class AllReduceTest(test_util.TensorFlowTestCase): (num_workers, num_gpus, shape, subdiv, elapsed)) def testRingAllReduce(self): + self._testRingAllReduce(1, 2, [], 1) self._testRingAllReduce(1, 2, [8], 1) self._testRingAllReduce(1, 2, [4, 4], 1) self._testRingAllReduce(6, 1, [8], 1) @@ -192,6 +193,7 @@ class AllReduceTest(test_util.TensorFlowTestCase): "elapsed=%f" % (num_workers, num_gpus, shape, elapsed)) def testShuffleAllReduce(self): + self._testShuffleAllReduce(1, 2, [], 1) self._testShuffleAllReduce(1, 2, [8], 1) self._testShuffleAllReduce(1, 2, [4, 4], 1) self._testShuffleAllReduce(1, 8, [32], 1) diff --git a/tensorflow/contrib/android/README.md b/tensorflow/contrib/android/README.md index f49e5857fe5255c2459793cb1389052a2ff5f88f..db37bcf73d144eb81c32a461a276d10be7e2d193 100644 --- a/tensorflow/contrib/android/README.md +++ b/tensorflow/contrib/android/README.md @@ -15,9 +15,9 @@ For prebuilt libraries, see the page for a recent build. The TensorFlow Inference Interface is also available as a -[JCenter package](https://bintray.com/google/tensorflow/tensorflow-android) and -can be included quite simply in your android project with a couple of lines in -the project's `build.gradle` file: +[JCenter package](https://bintray.com/google/tensorflow/tensorflow) +(see the tensorflow-android directory) and can be included quite simply in your +android project with a couple of lines in the project's `build.gradle` file: ``` allprojects { @@ -32,9 +32,9 @@ dependencies { ``` This will tell Gradle to use the -[latest version](https://bintray.com/google/tensorflow/tensorflow-android/_latestVersion) +[latest version](https://bintray.com/google/tensorflow/tensorflow/_latestVersion) of the TensorFlow AAR that has been released to -[https://bintray.com/google/tensorflow/tensorflow-android](https://bintray.com/google/tensorflow/tensorflow-android). 
+[JCenter](https://jcenter.bintray.com/org/tensorflow/tensorflow-android/). You may replace the `+` with an explicit version label if you wish to use a specific release of TensorFlow in your app. @@ -81,6 +81,11 @@ For documentation on building a self-contained AAR file with cmake, see [tensorflow/contrib/android/cmake](cmake). +### Makefile + +For documentation on building native TF libraries with make, including a CUDA-enabled variant for devices like the Nvidia Shield TV, see [tensorflow/contrib/makefile/README.md](../makefile/README.md) + + ## AssetManagerFileSystem This directory also contains a TensorFlow filesystem supporting the Android diff --git a/tensorflow/contrib/android/asset_manager_filesystem.h b/tensorflow/contrib/android/asset_manager_filesystem.h index 2b43939f148e360945e5d488d148fcb2c13008a6..665304b5eef1f8a3633c8c522259e20d744b1808 100644 --- a/tensorflow/contrib/android/asset_manager_filesystem.h +++ b/tensorflow/contrib/android/asset_manager_filesystem.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_ANDROID_ASSET_MANAGER_FILESYSTEM_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_ANDROID_ASSET_MANAGER_FILESYSTEM_H_ +#ifndef TENSORFLOW_CONTRIB_ANDROID_ASSET_MANAGER_FILESYSTEM_H_ +#define TENSORFLOW_CONTRIB_ANDROID_ASSET_MANAGER_FILESYSTEM_H_ #include #include @@ -79,4 +79,4 @@ class AssetManagerFileSystem : public FileSystem { }; } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_ANDROID_ASSET_MANAGER_FILESYSTEM_H_ +#endif // TENSORFLOW_CONTRIB_ANDROID_ASSET_MANAGER_FILESYSTEM_H_ diff --git a/tensorflow/contrib/android/cmake/CMakeLists.txt b/tensorflow/contrib/android/cmake/CMakeLists.txt index aba356d6167658f125001cbed6e3190c716ee7d6..a115d1610e2334a6626f29674f3dd195e3a3c648 100644 --- a/tensorflow/contrib/android/cmake/CMakeLists.txt +++ b/tensorflow/contrib/android/cmake/CMakeLists.txt @@ -34,6 +34,8 @@ add_library(lib_tf STATIC IMPORTED ) set_target_properties(lib_tf PROPERTIES IMPORTED_LOCATION ${PREBUILT_DIR}/lib/libtensorflow-core.a) # Change to compile flags should be replicated into bazel build file +# TODO: Consider options other than -O2 for binary size. +# e.g. -Os for gcc, and -Oz for clang. 
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DIS_SLIM_BUILD \ -std=c++11 -fno-rtti -fno-exceptions \ -O2 -Wno-narrowing -fomit-frame-pointer \ diff --git a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java index dc5b9fb88742d78d0f40207b589e29451a6358dd..abddadac5bcace9b1f992b69bdcc69c24b29cd13 100644 --- a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java +++ b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java @@ -194,6 +194,11 @@ public class TensorFlowInferenceInterface { * @param outputNames A list of output nodes which should be filled by the inference pass. */ public void run(String[] outputNames, boolean enableStats) { + run(outputNames, enableStats, new String[] {}); + } + + /** An overloaded version of runInference that allows supplying targetNodeNames as well */ + public void run(String[] outputNames, boolean enableStats, String[] targetNodeNames) { // Release any Tensors from the previous run calls. closeFetches(); @@ -204,6 +209,11 @@ public class TensorFlowInferenceInterface { runner.fetch(tid.name, tid.outputIndex); } + // Add targets. + for (String t : targetNodeNames) { + runner.addTarget(t); + } + // Run the session. try { if (enableStats) { diff --git a/tensorflow/contrib/android/jni/run_stats_jni.cc b/tensorflow/contrib/android/jni/run_stats_jni.cc index 119fa9cd2c378d2ba2383ea8b0e09e1b6083d84e..707853b59befc2625145ad96952fbf9f66d62b43 100644 --- a/tensorflow/contrib/android/jni/run_stats_jni.cc +++ b/tensorflow/contrib/android/jni/run_stats_jni.cc @@ -21,8 +21,8 @@ limitations under the License. 
#include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/util/stat_summarizer.h" -using tensorflow::StatSummarizer; using tensorflow::RunMetadata; +using tensorflow::StatSummarizer; namespace { StatSummarizer* requireHandle(JNIEnv* env, jlong handle) { diff --git a/tensorflow/contrib/batching/BUILD b/tensorflow/contrib/batching/BUILD index a111cfecb366fe245150cc71d2c43662d0d69090..ee67909133fc26ba98355db05a4b90d3dfa6b97b 100644 --- a/tensorflow/contrib/batching/BUILD +++ b/tensorflow/contrib/batching/BUILD @@ -12,7 +12,7 @@ cc_library( name = "batch_scheduler_hdrs", hdrs = ["batch_scheduler.h"], deps = [ - "//tensorflow/core:framework_headers_lib", + "//tensorflow/core/kernels/batching_util:batch_scheduler_hdrs", ], ) @@ -20,18 +20,7 @@ cc_library( name = "batch_scheduler", hdrs = ["batch_scheduler.h"], deps = [ - "//tensorflow/core:lib", - ], -) - -tf_cc_test( - name = "batch_scheduler_test", - srcs = ["batch_scheduler_test.cc"], - deps = [ - ":batch_scheduler", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", + "//tensorflow/core/kernels/batching_util:batch_scheduler", ], ) @@ -39,9 +28,7 @@ cc_library( name = "shared_batch_scheduler_hdrs", hdrs = ["shared_batch_scheduler.h"], deps = [ - ":batch_scheduler_hdrs", - "//tensorflow/contrib/batching/util:periodic_function_dynamic", - "//tensorflow/core:framework_headers_lib", + "//tensorflow/core/kernels/batching_util:shared_batch_scheduler_hdrs", ], ) @@ -49,46 +36,16 @@ cc_library( name = "shared_batch_scheduler", hdrs = ["shared_batch_scheduler.h"], deps = [ - ":batch_scheduler", - "//tensorflow/contrib/batching/util:periodic_function_dynamic", - "//tensorflow/core:lib", + "//tensorflow/core/kernels/batching_util:shared_batch_scheduler", ], alwayslink = 1, ) -tf_cc_test( - name = "shared_batch_scheduler_test", - srcs = ["shared_batch_scheduler_test.cc"], - deps = [ - ":shared_batch_scheduler", - "//tensorflow/contrib/batching/test_util:fake_clock_env", - 
"//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) - cc_library( name = "adaptive_shared_batch_scheduler", hdrs = ["adaptive_shared_batch_scheduler.h"], deps = [ - ":batch_scheduler", - "//tensorflow/contrib/batching/util:periodic_function_dynamic", - "//tensorflow/core:lib", - ], -) - -tf_cc_test( - name = "adaptive_shared_batch_scheduler_test", - srcs = ["adaptive_shared_batch_scheduler_test.cc"], - tags = ["manual"], # b/69013768 - deps = [ - ":adaptive_shared_batch_scheduler", - "//tensorflow/contrib/batching/test_util:fake_clock_env", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", + "//tensorflow/core/kernels/batching_util:adaptive_shared_batch_scheduler", ], ) @@ -96,34 +53,7 @@ cc_library( name = "basic_batch_scheduler", hdrs = ["basic_batch_scheduler.h"], deps = [ - ":shared_batch_scheduler", - ], -) - -tf_cc_test( - name = "basic_batch_scheduler_test", - srcs = ["basic_batch_scheduler_test.cc"], - deps = [ - ":basic_batch_scheduler", - ":batch_scheduler", - "//tensorflow/core:lib", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) - -tf_cc_test( - name = "basic_batch_scheduler_benchmark", - srcs = ["basic_batch_scheduler_benchmark.cc"], - tags = [ - "local", - "manual", - ], - deps = [ - ":basic_batch_scheduler", - "//tensorflow/core:lib", - "//tensorflow/core:tensorflow", - "//tensorflow/core:test", + "//tensorflow/core/kernels/batching_util:basic_batch_scheduler", ], ) @@ -137,48 +67,14 @@ load( ) load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") -tf_custom_op_library( - name = "python/ops/_batch_ops.so", - srcs = ["ops/batch_ops.cc"], - deps = [ - "//tensorflow/contrib/batching/kernels:batch_kernels", - ], -) - -tf_gen_op_libs( - op_lib_names = ["batch_ops"], -) - -tf_gen_op_wrapper_py( - name = "batch_ops", - deps = [":batch_ops_op_lib"], -) - -tf_kernel_library( - name = "batch_ops_kernels", - deps = 
[ - "//tensorflow/contrib/batching/kernels:batch_kernels", - "//tensorflow/contrib/batching/util:periodic_function", - "//tensorflow/core/kernels:concat_lib", - "//tensorflow/core/kernels:ops_util", - "//tensorflow/core/kernels:split_lib", - ], - alwayslink = 1, -) - -tf_custom_op_py_library( +py_library( name = "batch_py", srcs = glob(["python/ops/*.py"]) + ["__init__.py"], - dso = [":python/ops/_batch_ops.so"], - kernels = [ - ":batch_ops_kernels", - ":batch_ops_op_lib", - ], srcs_version = "PY2AND3", deps = [ - ":batch_ops", "//tensorflow/contrib/util:util_py", "//tensorflow/python:array_ops", + "//tensorflow/python:batch_ops_gen", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:gradients", @@ -188,6 +84,14 @@ tf_custom_op_py_library( ], ) +cc_library( + name = "batch_ops_kernels", + deps = [ + "//tensorflow/core/kernels:batch_kernels", + ], + alwayslink = 1, +) + py_test( name = "batch_ops_test", size = "small", @@ -203,6 +107,7 @@ py_test( "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", + "//tensorflow/python:framework", "//tensorflow/python:gradients", "//tensorflow/python:script_ops", ], diff --git a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h index 6ed177e001758ad8c566c7965e1ec10ae5235fc8..86250e6692004a12a1fa338767a5db1e4c2e4195 100644 --- a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h +++ b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h @@ -13,450 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_ +#ifndef TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_ +#define TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_ -#include -#include -#include -#include -#include +#include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h" -#include "tensorflow/contrib/batching/batch_scheduler.h" -#include "tensorflow/contrib/batching/util/periodic_function.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/core/threadpool.h" -#include "tensorflow/core/platform/cpu_info.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { -namespace serving { -namespace internal { -template -class ASBSBatch; - -template -class ASBSQueue; -} // namespace internal - -// Shared batch scheduler designed to minimize latency. The scheduler keeps -// track of a number of queues (one per model or model version) which are -// continuously enqueuing requests. The scheduler groups the requests into -// batches which it periodically sends off for processing (see -// shared_batch_scheduler.h for more details). The AdaptiveSharedBatchScheduler -// prioritizes batches by age (i.e. the batch's oldest request) irrespective of -// queue. The scheduler will process the oldest batch at an adjustable rate, -// regardless of batch size. The user can provide feedback to help set this rate -// to achieve some goal (i.e. minimize overall latency, limit cpu usage, etc). 
-// -// The rate (or rather, the corresponding period) is adjusted each time a batch -// is processed, using an exponentially weighted moving average to smooth -// potentially noisy feedback: -// ewma_feedback = ((N - 1) * ewma_feedback + feedback()) / N -// period *= (1 + K * emwa_feedback) -// -// Some potential use cases: -// Hardware Accelerators (GPUs & TPUs) - If some phase of batch processing -// involves serial processing by a device, from a latency perspective it is -// desirable to keep the device evenly loaded, avoiding the need to wait for -// the device to process prior batches. -// feedback = num_pending_on_device() - desired_pending. -// CPU utilization - If the batch processing is cpu dominated, you can reap -// latency gains when underutilized by increasing the processing rate, but -// back the rate off when the load increases to avoid overload. -// feedback = cpu_rate() - desired_cpu_rate. - -template -class AdaptiveSharedBatchScheduler - : public std::enable_shared_from_this< - AdaptiveSharedBatchScheduler> { - public: - struct Options { - // The name to use for the pool of batch threads. - string thread_pool_name = {"batch_threads"}; - // Number of batch processing threads; equivalently the maximum number of - // concurrently running batches. - int64 num_batch_threads = port::NumSchedulableCPUs(); - // The environment to use (typically only overridden by test code). - Env* env = Env::Default(); - // Initial batch scheduling period in microseconds. Will be altered for - // non-zero rate_feedback. - double initial_scheduling_period_micros = 500; - // Minimum batch scheduling period in microseconds. Recommend setting this - // value greater than 0, otherwise it may take a while to recover from a - // sustained time of negative scheduling_period_feedback (which may occur - // under low load). - double min_scheduling_period_micros = 100; - // Maximum batch scheduling period in microseconds. 
- double max_scheduling_period_micros = 10000; - // Feedback function used to modify the scheduling period each time a batch - // is scheduled. Should return values roughly O(1), with positive values - // resulting in an increased period. - std::function scheduling_period_feedback{[] { return 0.; }}; - // To handle potentially noisy scheduling_period_feedback, the period is - // adjusted using an exponentially weighted moving average over the previous - // feedback_smoothing_batches batches. Must be greater than 0. - int64 feedback_smoothing_batches = 10; - }; - - // Ownership is shared between the caller of Create() and any queues created - // via AddQueue(). - static Status Create( - const Options& options, - std::shared_ptr>* scheduler); - - struct QueueOptions { - // Maximum size of each batch. - int max_batch_size = 1000; - // Maximum number of enqueued (i.e. non-scheduled) batches. - int max_enqueued_batches = 10; - }; - - using BatchProcessor = std::function>)>; - - // Adds queue (and its callback) to be managed by this scheduler. - Status AddQueue(const QueueOptions& options, - BatchProcessor process_batch_callback, - std::unique_ptr>* queue); - - private: - // access to AddBatch, RemoveQueue, GetEnv. - friend class internal::ASBSQueue; - - explicit AdaptiveSharedBatchScheduler(const Options& options); - - // Batch scheduling function which runs every scheduling_period_ microseconds. - void ProcessOneBatch(); - - // Notifies scheduler of non-empty batch which is eligible for processing. - void AddBatch(internal::ASBSBatch*); - - // Removes queue from scheduler. - void RemoveQueue(const internal::ASBSQueue* queue); - - Env* GetEnv() const { return options_.env; } - - const Options options_; - - struct BatchCompare { - bool operator()(const internal::ASBSBatch* a, - const internal::ASBSBatch* b); - }; - - // Collection of batches added by AddBatch, ordered by age. Owned by scheduler - // until they are released for processing. 
- std::priority_queue*, - std::vector*>, BatchCompare> - batches_ GUARDED_BY(mu_); - - // Unowned queues and callbacks added by AddQueue. - std::unordered_map*, BatchProcessor> - queues_and_callbacks_ GUARDED_BY(mu_); - - mutex mu_; - - // Responsible for running ProcessOneBatch. PeriodicFunction was used in order - // to check for deletion so that the thread can be shut down. - std::unique_ptr scheduling_thread_; - - // Responsible for running the batch processing callbacks. - std::unique_ptr batch_thread_pool_; - - // Time interval in microseconds between successive ProcessOneBatch calls. - double scheduling_period_; - - // Exponentially weighted moving average of - // options_.scheduling_period_feedback() evaluated in each ProcessOneBatch - // call. - double ewma_feedback_ = 0; - - TF_DISALLOW_COPY_AND_ASSIGN(AdaptiveSharedBatchScheduler); -}; - -////////////////////////////////////////////////////////// -// Implementation details follow. API users need not read. - -namespace internal { -// Consolidates tasks into batches, passing them off to the -// AdaptiveSharedBatchScheduler for processing. -template -class ASBSQueue : public BatchScheduler { - public: - using QueueOptions = - typename AdaptiveSharedBatchScheduler::QueueOptions; - - ASBSQueue(std::shared_ptr> scheduler, - const QueueOptions& options); - - ~ASBSQueue() override; - - // Adds task to current batch. Fails if the task size is larger than the batch - // size or if the current batch is full and this queue's number of outstanding - // batches is at its maximum. - Status Schedule(std::unique_ptr* task) override; - - // Number of tasks waiting to be scheduled. - size_t NumEnqueuedTasks() const override; - - // Number of size 1 tasks which could currently be scheduled without failing. - size_t SchedulingCapacity() const override; - - // Notifies queue that a batch is about to be scheduled; the queue should not - // place any more tasks in this batch. 
- void ReleaseBatch(const ASBSBatch* batch); - - private: - std::shared_ptr> scheduler_; - const QueueOptions options_; - // Owned by scheduler_. - ASBSBatch* current_batch_ GUARDED_BY(mu_) = nullptr; - int64 num_enqueued_batches_ GUARDED_BY(mu_) = 0; - int64 num_enqueued_tasks_ GUARDED_BY(mu_) = 0; - mutable mutex mu_; - TF_DISALLOW_COPY_AND_ASSIGN(ASBSQueue); -}; - -// Batch which remembers when and by whom it was created. -template -class ASBSBatch : public Batch { - public: - ASBSBatch(ASBSQueue* queue, int64 creation_time_micros) - : queue_(queue), creation_time_micros_(creation_time_micros) {} - - ~ASBSBatch() override {} - - ASBSQueue* queue() const { return queue_; } - - int64 creation_time_micros() const { return creation_time_micros_; } - - private: - ASBSQueue* queue_; - const int64 creation_time_micros_; - TF_DISALLOW_COPY_AND_ASSIGN(ASBSBatch); -}; -} // namespace internal - -// ---------------- AdaptiveSharedBatchScheduler ---------------- - -template -Status AdaptiveSharedBatchScheduler::Create( - const Options& options, - std::shared_ptr>* scheduler) { - if (options.num_batch_threads < 1) { - return errors::InvalidArgument("num_batch_threads must be positive; was ", - options.num_batch_threads); - } - if (options.min_scheduling_period_micros < 0) { - return errors::InvalidArgument( - "min_scheduling_period_micros must be >= 0; was ", - options.min_scheduling_period_micros); - } - if (options.min_scheduling_period_micros > - options.initial_scheduling_period_micros) { - return errors::InvalidArgument( - "initial_scheduling_period_micros (", - options.initial_scheduling_period_micros, - ") must be >= min_scheduling_period_micros (", - options.min_scheduling_period_micros, ")"); - } - if (options.initial_scheduling_period_micros > - options.max_scheduling_period_micros) { - return errors::InvalidArgument( - "initial_scheduling_period_micros (", - options.initial_scheduling_period_micros, - ") must be <= max_scheduling_period_micros (", - 
options.max_scheduling_period_micros, ")"); - } - if (options.feedback_smoothing_batches < 1) { - return errors::InvalidArgument( - "feedback_smoothing_batches must be positive; was ", - options.feedback_smoothing_batches); - } - scheduler->reset(new AdaptiveSharedBatchScheduler(options)); - return Status::OK(); -} - -template -AdaptiveSharedBatchScheduler::AdaptiveSharedBatchScheduler( - const Options& options) - : options_(options), - scheduling_period_(options.initial_scheduling_period_micros) { - PeriodicFunction::Options opts; - opts.thread_name_prefix = "scheduling_thread"; - opts.env = GetEnv(); - scheduling_thread_.reset( - new PeriodicFunction([this] { ProcessOneBatch(); }, 0, opts)); - batch_thread_pool_.reset(new thread::ThreadPool( - GetEnv(), options.thread_pool_name, options.num_batch_threads)); -} - -template -Status AdaptiveSharedBatchScheduler::AddQueue( - const QueueOptions& options, BatchProcessor process_batch_callback, - std::unique_ptr>* queue) { - if (options.max_batch_size <= 0) { - return errors::InvalidArgument("max_batch_size must be positive; was ", - options.max_batch_size); - } - if (options.max_enqueued_batches <= 0) { - return errors::InvalidArgument( - "max_enqueued_batches must be positive; was ", - options.max_enqueued_batches); - } - internal::ASBSQueue* asbs_queue_raw; - queue->reset(asbs_queue_raw = new internal::ASBSQueue( - this->shared_from_this(), options)); - mutex_lock l(mu_); - queues_and_callbacks_[asbs_queue_raw] = process_batch_callback; - return Status::OK(); -} - -template -void AdaptiveSharedBatchScheduler::AddBatch( - internal::ASBSBatch* batch) { - mutex_lock l(mu_); - batches_.push(batch); -} - -template -void AdaptiveSharedBatchScheduler::RemoveQueue( - const internal::ASBSQueue* queue) { - mutex_lock l(mu_); - queues_and_callbacks_.erase(queue); -} - -template -void AdaptiveSharedBatchScheduler::ProcessOneBatch() { - static const double kFeedbackMultiplier = .001; - internal::ASBSBatch* batch = nullptr; - 
BatchProcessor callback; - const int64 start_time_micros = GetEnv()->NowMicros(); - { - mutex_lock l(mu_); - if (!batches_.empty()) { - batch = batches_.top(); - batches_.pop(); - callback = queues_and_callbacks_[batch->queue()]; - } - } - if (batch != nullptr) { - double feedback = options_.scheduling_period_feedback(); - const int64 N = options_.feedback_smoothing_batches; - ewma_feedback_ = ((N - 1) * ewma_feedback_ + feedback) / N; - scheduling_period_ *= (1 + kFeedbackMultiplier * ewma_feedback_); - if (scheduling_period_ < options_.min_scheduling_period_micros) { - scheduling_period_ = options_.min_scheduling_period_micros; - } else if (scheduling_period_ > options_.max_scheduling_period_micros) { - scheduling_period_ = options_.max_scheduling_period_micros; - } - // Queue may destroy itself after ReleaseBatch is called. - batch->queue()->ReleaseBatch(batch); - batch_thread_pool_->Schedule([callback, batch] { - callback(std::unique_ptr>(batch)); - }); - } - const int64 sleep_time = - scheduling_period_ - (GetEnv()->NowMicros() - start_time_micros); - if (sleep_time > 0) { - GetEnv()->SleepForMicroseconds(sleep_time); - } -} - -template -bool AdaptiveSharedBatchScheduler::BatchCompare::operator()( - const internal::ASBSBatch* a, - const internal::ASBSBatch* b) { - return a->creation_time_micros() > b->creation_time_micros(); -} - -// ---------------- ASBSQueue ---------------- - -namespace internal { -template -ASBSQueue::ASBSQueue( - std::shared_ptr> scheduler, - const QueueOptions& options) - : scheduler_(scheduler), options_(options) {} - -template -ASBSQueue::~ASBSQueue() { - // Wait until last batch has been scheduled. 
- const int kSleepMicros = 1000; - for (;;) { - { - mutex_lock l(mu_); - if (num_enqueued_batches_ == 0) { - break; - } - } - scheduler_->GetEnv()->SleepForMicroseconds(kSleepMicros); - } - scheduler_->RemoveQueue(this); -} - -template -Status ASBSQueue::Schedule(std::unique_ptr* task) { - ASBSBatch* new_batch = nullptr; - size_t size = (*task)->size(); - if (size > options_.max_batch_size) { - return errors::InvalidArgument("Task size ", size, - " is larger than maximum batch size ", - options_.max_batch_size); - } - { - mutex_lock l(mu_); - // Current batch is full, create another if allowed. - if (current_batch_ && - current_batch_->size() + size > options_.max_batch_size) { - if (num_enqueued_batches_ >= options_.max_enqueued_batches) { - return errors::Unavailable("The batch scheduling queue is full"); - } - current_batch_->Close(); - current_batch_ = nullptr; - } - if (!current_batch_) { - num_enqueued_batches_++; - current_batch_ = new_batch = - new ASBSBatch(this, scheduler_->GetEnv()->NowMicros()); - } - current_batch_->AddTask(std::move(*task)); - num_enqueued_tasks_++; - } - if (new_batch != nullptr) scheduler_->AddBatch(new_batch); - return Status::OK(); -} - -template -void ASBSQueue::ReleaseBatch(const ASBSBatch* batch) { - mutex_lock l(mu_); - num_enqueued_batches_--; - num_enqueued_tasks_ -= batch->num_tasks(); - if (batch == current_batch_) { - current_batch_->Close(); - current_batch_ = nullptr; - } -} - -template -size_t ASBSQueue::NumEnqueuedTasks() const { - mutex_lock l(mu_); - return num_enqueued_tasks_; -} - -template -size_t ASBSQueue::SchedulingCapacity() const { - mutex_lock l(mu_); - const int current_batch_capacity = - current_batch_ ? 
options_.max_batch_size - current_batch_->size() : 0; - const int spare_batches = - options_.max_enqueued_batches - num_enqueued_batches_; - return spare_batches * options_.max_batch_size + current_batch_capacity; -} -} // namespace internal -} // namespace serving -} // namespace tensorflow - -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_ +#endif // TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_ diff --git a/tensorflow/contrib/batching/basic_batch_scheduler.h b/tensorflow/contrib/batching/basic_batch_scheduler.h index 9d3805fbaf39978159dd2f4a754e6d41a07acf6a..d9b37da6933aa0847c229607f43d1d5d121a928c 100644 --- a/tensorflow/contrib/batching/basic_batch_scheduler.h +++ b/tensorflow/contrib/batching/basic_batch_scheduler.h @@ -13,252 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_BASIC_BATCH_SCHEDULER_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_BASIC_BATCH_SCHEDULER_H_ +#ifndef TENSORFLOW_CONTRIB_BATCHING_BASIC_BATCH_SCHEDULER_H_ +#define TENSORFLOW_CONTRIB_BATCHING_BASIC_BATCH_SCHEDULER_H_ -#include -#include -#include -#include -#include +#include "tensorflow/core/kernels/batching_util/basic_batch_scheduler.h" -#include "tensorflow/contrib/batching/shared_batch_scheduler.h" - -namespace tensorflow { -namespace serving { - -// A BatchScheduler implementation geared toward handling a single request type -// running on a specific set of hardware resources. A typical scenario is one in -// which all requests invoke the same machine-learned model on one GPU. -// -// If there are, say, two GPUs and two models each bound to one of the GPUs, one -// could use two BasicBatchScheduler instances to schedule the two model/GPU -// combinations independently. 
If multiple models must share a given GPU or -// other hardware resource, consider using SharedBatchScheduler instead. -// -// -// PARAMETERS AND BEHAVIOR: -// -// BasicBatchScheduler runs a fixed pool of threads, which it uses to process -// batches of tasks. It enforces a maximum batch size, and enqueues a bounded -// number of tasks. If the queue is nearly empty, such that a full batch cannot -// be formed, when a thread becomes free, it anyway schedules a batch -// immediately if a task has been in the queue for longer than a given timeout -// parameter. If the timeout parameter is set to 0, then the batch threads will -// always be kept busy (unless there are zero tasks waiting to be processed). -// -// For online serving, it is recommended to set the maximum number of enqueued -// batches worth of tasks equal to the number of batch threads, which allows -// enqueuing of enough tasks s.t. if every thread becomes available it can be -// kept busy, but no more. For bulk processing jobs and throughput-oriented -// benchmarks, you may want to set it much higher. -// -// When Schedule() is called, if the queue is full the call will fail with an -// UNAVAILABLE error (after which the client may retry again later). If the call -// succeeds, the maximum time the task will spend in the queue before being -// placed in a batch and assigned to a thread for processing, is the greater of: -// - the maximum time to process ceil(max_enqueued_batches/num_batch_threads) -// (1 in the recommended configuration) batches of previously-submitted tasks -// - the configured timeout parameter (which can be 0, as mentioned above) -// -// Unlike StreamingBatchScheduler, when BasicBatchScheduler assigns a batch to a -// thread, it closes the batch. The process-batch callback may assume that every -// batch it receives is closed at the outset. -// -// -// RECOMMENDED USE-CASES: -// -// BasicBatchScheduler is suitable for use-cases that feature a single kind of -// request (e.g. 
a server performing inference with a single machine-learned -// model, possibly evolving over time), with loose versioning semantics. -// Concretely, the following conditions should hold: -// -// A. All requests batched onto a given resource (e.g. a hardware accelerator, -// or a pool accelerators) are of the same type. For example, they all -// invoke the same machine-learned model. -// -// These variations are permitted: -// - The model may reside in a single servable, or it may be spread across -// multiple servables that are used in unison (e.g. a vocabulary lookup -// table servable and a tensorflow session servable). -// - The model's servable(s) may be static, or they may evolve over time -// (successive servable versions). -// - Zero or more of the servables are used in the request thread; the rest -// are used in the batch thread. In our running example, the vocabulary -// lookups and tensorflow runs may both be performed in the batch thread, -// or alternatively the vocabulary lookup may occur in the request thread -// with only the tensorflow run performed in the batch thread. -// -// In contrast, BasicBatchScheduler is not a good fit if the server -// hosts multiple distinct models running on a pool accelerators, with each -// request specifying which model it wants to use. BasicBatchScheduler -// has no facility to time-multiplex the batch threads across multiple -// models in a principled way. More basically, it cannot ensure that a given -// batch doesn't contain a mixture of requests for different models. -// -// B. Requests do not specify a particular version of the servable(s) that must -// be used. Instead, each request is content to use the "latest" version. -// -// BasicBatchScheduler does not constrain which requests get grouped -// together into a batch, so using this scheduler there is no way to achieve -// cohesion of versioned requests to version-specific batches. -// -// C. 
No servable version coordination needs to be performed between the -// request threads and the batch threads. Often, servables are only used in -// the batch threads, in which case this condition trivially holds. If -// servables are used in both threads, then the use-case must tolerate -// version skew across the servables used in the two kinds of threads. -// -// -// EXAMPLE USE-CASE FLOW: -// -// For such use-cases, request processing via BasicBatchScheduler generally -// follows this flow (given for illustration; variations are possible): -// 1. Optionally perform some pre-processing on each request in the request -// threads. -// 2. Route the requests to the batch scheduler, as batching::Task objects. -// (Since all requests are of the same type and are not versioned, the -// scheduler is free to group them into batches arbitrarily.) -// 3. Merge the requests into a single batched representation B. -// 4. Obtain handles to the servable(s) needed to process B. The simplest -// approach is to obtain the latest version of each servable. Alternatively, -// if cross-servable consistency is required (e.g. the vocabulary lookup -// table's version number must match that of the tensorflow session), -// identify an appropriate version number and obtain the servable handles -// accordingly. -// 5. Process B using the obtained servable handles, and split the result into -// individual per-request units. -// 6. Perform any post-processing in the batch thread and/or request thread. -// -// -// PERFORMANCE TUNING: See README.md. -// -template -class BasicBatchScheduler : public BatchScheduler { - public: - // TODO(b/25089730): Tune defaults based on best practices as they develop. - // (Keep them mirrored to the ones in SharedBatchScheduler::QueueOptions and - // SharedBatchScheduler::Options.) - struct Options { - // The maximum size of each batch. - // - // The scheduler may form batches of any size between 1 and this number - // (inclusive). 
If there is a need to quantize the batch sizes, i.e. only - // submit batches whose size is in a small set of allowed sizes, that can be - // done by adding padding in the process-batch callback. - int max_batch_size = 1000; - - // If a task has been enqueued for this amount of time (in microseconds), - // and a thread is available, the scheduler will immediately form a batch - // from enqueued tasks and assign the batch to the thread for processing, - // even if the batch's size is below 'max_batch_size'. - // - // This parameter offers a way to bound queue latency, so that a task isn't - // stuck in the queue indefinitely waiting for enough tasks to arrive to - // make a full batch. (The latency bound is given in the class documentation - // above.) - // - // The goal is to smooth out batch sizes under low request rates, and thus - // avoid latency spikes. - int64 batch_timeout_micros = 0; - - // The name to use for the pool of batch threads. - string thread_pool_name = {"batch_threads"}; - - // The number of threads to use to process batches. - // Must be >= 1, and should be tuned carefully. - int num_batch_threads = port::NumSchedulableCPUs(); - - // The maximum allowable number of enqueued (accepted by Schedule() but - // not yet being processed on a batch thread) tasks in terms of batches. - // If this limit is reached, Schedule() will return an UNAVAILABLE error. - // See the class documentation above for guidelines on how to tune this - // parameter. - int max_enqueued_batches = 10; - - // The following options are typically only overridden by test code. - - // The environment to use. 
- Env* env = Env::Default(); - }; - static Status Create(const Options& options, - std::function>)> - process_batch_callback, - std::unique_ptr* scheduler); - - ~BasicBatchScheduler() override = default; - - Status Schedule(std::unique_ptr* task) override; - size_t NumEnqueuedTasks() const override; - size_t SchedulingCapacity() const override; - - private: - explicit BasicBatchScheduler( - std::unique_ptr> shared_scheduler_queue); - - // This class is merely a thin wrapper around a SharedBatchScheduler with a - // single queue. - std::unique_ptr> shared_scheduler_queue_; - - TF_DISALLOW_COPY_AND_ASSIGN(BasicBatchScheduler); -}; - -////////// -// Implementation details follow. API users need not read. - -template -Status BasicBatchScheduler::Create( - const Options& options, - std::function>)> - process_batch_callback, - std::unique_ptr* scheduler) { - typename SharedBatchScheduler::Options shared_scheduler_options; - shared_scheduler_options.thread_pool_name = options.thread_pool_name; - shared_scheduler_options.num_batch_threads = options.num_batch_threads; - shared_scheduler_options.env = options.env; - std::shared_ptr> shared_scheduler; - TF_RETURN_IF_ERROR(SharedBatchScheduler::Create( - shared_scheduler_options, &shared_scheduler)); - - typename SharedBatchScheduler::QueueOptions - shared_scheduler_queue_options; - shared_scheduler_queue_options.max_batch_size = options.max_batch_size; - shared_scheduler_queue_options.batch_timeout_micros = - options.batch_timeout_micros; - shared_scheduler_queue_options.max_enqueued_batches = - options.max_enqueued_batches; - std::unique_ptr> shared_scheduler_queue; - TF_RETURN_IF_ERROR(shared_scheduler->AddQueue(shared_scheduler_queue_options, - process_batch_callback, - &shared_scheduler_queue)); - - scheduler->reset( - new BasicBatchScheduler(std::move(shared_scheduler_queue))); - return Status::OK(); -} - -template -Status BasicBatchScheduler::Schedule( - std::unique_ptr* task) { - return 
shared_scheduler_queue_->Schedule(task); -} - -template -size_t BasicBatchScheduler::NumEnqueuedTasks() const { - return shared_scheduler_queue_->NumEnqueuedTasks(); -} - -template -size_t BasicBatchScheduler::SchedulingCapacity() const { - return shared_scheduler_queue_->SchedulingCapacity(); -} - -template -BasicBatchScheduler::BasicBatchScheduler( - std::unique_ptr> shared_scheduler_queue) - : shared_scheduler_queue_(std::move(shared_scheduler_queue)) {} - -} // namespace serving -} // namespace tensorflow - -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_BASIC_BATCH_SCHEDULER_H_ +#endif // TENSORFLOW_CONTRIB_BATCHING_BASIC_BATCH_SCHEDULER_H_ diff --git a/tensorflow/contrib/batching/batch_scheduler.h b/tensorflow/contrib/batching/batch_scheduler.h index a5072f439abad3c5db79a514a7f2baff0b021b39..8e94e1fd8b969d4fef8dbc8c322557f9da3833e6 100644 --- a/tensorflow/contrib/batching/batch_scheduler.h +++ b/tensorflow/contrib/batching/batch_scheduler.h @@ -13,264 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// Abstractions for processing small tasks in a batched fashion, to reduce -// processing times and costs that can be amortized across multiple tasks. -// -// The core class is BatchScheduler, which groups tasks into batches. -// -// BatchScheduler encapsulates logic for aggregating multiple tasks into a -// batch, and kicking off processing of a batch on a thread pool it manages. -// -// This file defines an abstract BatchScheduler class. 
+#ifndef TENSORFLOW_CONTRIB_BATCHING_BATCH_SCHEDULER_H_ +#define TENSORFLOW_CONTRIB_BATCHING_BATCH_SCHEDULER_H_ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_BATCH_SCHEDULER_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_BATCH_SCHEDULER_H_ +#include "tensorflow/core/kernels/batching_util/batch_scheduler.h" -#include -#include -#include -#include -#include -#include - -#include "tensorflow/core/lib/core/notification.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { -namespace serving { - -// The abstract superclass for a unit of work to be done as part of a batch. -// -// An implementing subclass typically contains (or points to): -// (a) input data; -// (b) a thread-safe completion signal (e.g. a Notification); -// (c) a place to store the outcome (success, or some error), upon completion; -// (d) a place to store the output data, upon success. -// -// Items (b), (c) and (d) are typically non-owned pointers to data homed -// elsewhere, because a task's ownership gets transferred to a BatchScheduler -// (see below) and it may be deleted as soon as it is done executing. -class BatchTask { - public: - virtual ~BatchTask() = default; - - // Returns the size of the task, in terms of how much it contributes to the - // size of a batch. (A batch's size is the sum of its task sizes.) - virtual size_t size() const = 0; -}; - -// A thread-safe collection of BatchTasks, to be executed together in some -// fashion. -// -// At a given time, a batch is either "open" or "closed": an open batch can -// accept new tasks; a closed one cannot. A batch is monotonic: initially it is -// open and tasks can be added to it; then it is closed and its set of tasks -// remains fixed for the remainder of its life. 
A closed batch cannot be re- -// opened. Tasks can never be removed from a batch. -// -// Type parameter TaskType must be a subclass of BatchTask. -template -class Batch { - public: - Batch() = default; - virtual ~Batch(); // Blocks until the batch is closed. - - // Appends 'task' to the batch. After calling AddTask(), the newly-added task - // can be accessed via task(num_tasks()-1) or mutable_task(num_tasks()-1). - // Dies if the batch is closed. - void AddTask(std::unique_ptr task); - - // Removes the most recently added task. Returns nullptr if the batch is - // empty. - std::unique_ptr RemoveTask(); - - // Returns the number of tasks in the batch. - int num_tasks() const; - - // Returns true iff the batch contains 0 tasks. - bool empty() const; - - // Returns a reference to the ith task (in terms of insertion order). - const TaskType& task(int i) const; - - // Returns a pointer to the ith task (in terms of insertion order). - TaskType* mutable_task(int i); - - // Returns the sum of the task sizes. - size_t size() const; - - // Returns true iff the batch is currently closed. - bool IsClosed() const; - - // Blocks until the batch is closed. - void WaitUntilClosed() const; - - // Marks the batch as closed. Dies if called more than once. - void Close(); - - private: - mutable mutex mu_; - - // The tasks in the batch. - std::vector> tasks_ GUARDED_BY(mu_); - - // The sum of the sizes of the tasks in 'tasks_'. - size_t size_ GUARDED_BY(mu_) = 0; - - // Whether the batch has been closed. - Notification closed_; - - TF_DISALLOW_COPY_AND_ASSIGN(Batch); -}; - -// An abstract batch scheduler class. Collects individual tasks into batches, -// and processes each batch on a pool of "batch threads" that it manages. The -// actual logic for processing a batch is accomplished via a callback. -// -// Type parameter TaskType must be a subclass of BatchTask. 
-template -class BatchScheduler { - public: - virtual ~BatchScheduler() = default; - - // Submits a task to be processed as part of a batch. - // - // Ownership of '*task' is transferred to the callee iff the method returns - // Status::OK. In that case, '*task' is left as nullptr. Otherwise, '*task' is - // left as-is. - // - // If no batch processing capacity is available to process this task at the - // present time, and any task queue maintained by the implementing subclass is - // full, this method returns an UNAVAILABLE error code. The client may retry - // later. - // - // Other problems, such as the task size being larger than the maximum batch - // size, yield other, permanent error types. - // - // In all cases, this method returns "quickly" without blocking for any - // substantial amount of time. If the method returns Status::OK, the task is - // processed asynchronously, and any errors that occur during the processing - // of the batch that includes the task can be reported to 'task'. - virtual Status Schedule(std::unique_ptr* task) = 0; - - // Returns the number of tasks that have been scheduled (i.e. accepted by - // Schedule()), but have yet to be handed to a thread for execution as part of - // a batch. Note that this returns the number of tasks, not the aggregate task - // size (so if there is one task of size 3 and one task of size 5, this method - // returns 2 rather than 8). - virtual size_t NumEnqueuedTasks() const = 0; - - // Returns a guaranteed number of size 1 tasks that can be Schedule()d without - // getting an UNAVAILABLE error. In a typical implementation, returns the - // available space on a queue. - // - // There are two important caveats: - // 1. The guarantee does not extend to varying-size tasks due to possible - // internal fragmentation of batches. - // 2. The guarantee only holds in a single-thread environment or critical - // section, i.e. if an intervening thread cannot call Schedule(). 
- // - // This method is useful for monitoring, or for guaranteeing a future slot in - // the schedule (but being mindful about the caveats listed above). - virtual size_t SchedulingCapacity() const = 0; -}; - -////////// -// Implementation details follow. API users need not read. - -template -Batch::~Batch() { - WaitUntilClosed(); -} - -template -void Batch::AddTask(std::unique_ptr task) { - DCHECK(!IsClosed()); - { - mutex_lock l(mu_); - size_ += task->size(); - tasks_.push_back(std::move(task)); - } -} - -template -std::unique_ptr Batch::RemoveTask() { - { - mutex_lock l(mu_); - if (tasks_.empty()) { - return nullptr; - } - std::unique_ptr task = std::move(tasks_.back()); - tasks_.pop_back(); - return task; - } -} - -template -int Batch::num_tasks() const { - { - mutex_lock l(mu_); - return tasks_.size(); - } -} - -template -bool Batch::empty() const { - { - mutex_lock l(mu_); - return tasks_.empty(); - } -} - -template -const TaskType& Batch::task(int i) const { - DCHECK_GE(i, 0); - { - mutex_lock l(mu_); - DCHECK_LT(i, tasks_.size()); - return *tasks_[i].get(); - } -} - -template -TaskType* Batch::mutable_task(int i) { - DCHECK_GE(i, 0); - { - mutex_lock l(mu_); - DCHECK_LT(i, tasks_.size()); - return tasks_[i].get(); - } -} - -template -size_t Batch::size() const { - { - mutex_lock l(mu_); - return size_; - } -} - -template -bool Batch::IsClosed() const { - return const_cast(&closed_)->HasBeenNotified(); -} - -template -void Batch::WaitUntilClosed() const { - const_cast(&closed_)->WaitForNotification(); -} - -template -void Batch::Close() { - closed_.Notify(); -} - -} // namespace serving -} // namespace tensorflow - -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_BATCH_SCHEDULER_H_ +#endif // TENSORFLOW_CONTRIB_BATCHING_BATCH_SCHEDULER_H_ diff --git a/tensorflow/contrib/batching/kernels/BUILD b/tensorflow/contrib/batching/kernels/BUILD deleted file mode 100644 index 6e53dd9a5fc0201c5ed91d1eaf07f940e341fb5e..0000000000000000000000000000000000000000 --- 
a/tensorflow/contrib/batching/kernels/BUILD +++ /dev/null @@ -1,34 +0,0 @@ -# Description: -# Contains kernels for the batching ops. - -package(default_visibility = ["//tensorflow:__subpackages__"]) - -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -cc_library( - name = "batch_kernels", - srcs = ["batch_kernels.cc"], - deps = [ - "//tensorflow/contrib/batching:shared_batch_scheduler_hdrs", - "//tensorflow/contrib/batching/util:periodic_function_dynamic", - "//tensorflow/core:framework_headers_lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/kernels:concat_lib_hdrs", - "//tensorflow/core/kernels:ops_util_hdrs", - "//tensorflow/core/kernels:split_lib_hdrs", - ], - alwayslink = 1, -) - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) diff --git a/tensorflow/contrib/batching/ops/batch_ops.cc b/tensorflow/contrib/batching/ops/batch_ops.cc deleted file mode 100644 index 85e0ccba4aa372bdc21fb194263569b8b787bb6c..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/batching/ops/batch_ops.cc +++ /dev/null @@ -1,164 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/shape_inference.h" - -namespace tensorflow { - -REGISTER_OP("Batch") - .Input("in_tensors: T") - .Output("batched_tensors: T") - .Output("batch_index: int64") - .Output("id: int64") - .Attr("num_batch_threads: int") - .Attr("max_batch_size: int") - .Attr("batch_timeout_micros: int") - .Attr("allowed_batch_sizes: list(int) = []") - .Attr("grad_timeout_micros: int") - .Attr("container: string = ''") - .Attr("shared_name: string = ''") - .Attr("batching_queue: string = ''") - .Attr("T: list(type)") - .SetShapeFn([](shape_inference::InferenceContext* c) { - std::vector in_shapes; - TF_RETURN_IF_ERROR(c->input("in_tensors", &in_shapes)); - std::vector out_shapes(in_shapes.size()); - for (int i = 0; i < in_shapes.size(); ++i) { - TF_RETURN_IF_ERROR( - c->ReplaceDim(in_shapes[i], 0, c->UnknownDim(), &out_shapes[i])); - } - TF_RETURN_IF_ERROR(c->set_output("batched_tensors", out_shapes)); - TF_RETURN_IF_ERROR(c->set_output("id", {c->Scalar()})); - TF_RETURN_IF_ERROR(c->set_output( - "batch_index", - {c->MakeShape({shape_inference::DimensionOrConstant(c->UnknownDim()), - shape_inference::DimensionOrConstant(3)})})); - return Status::OK(); - }) - .Doc(R"doc( -Batches all input tensors nondeterministically. - -When many instances of this Op are being run concurrently with the same -container/shared_name in the same device, some will output zero-shaped Tensors -and others will output Tensors of size up to max_batch_size. - -All Tensors in in_tensors are batched together (so, for example, labels and -features should be batched with a single instance of this operation. - -Each invocation of batch emits an `id` scalar which will be used to identify -this particular invocation when doing unbatch or its gradient. 
- -Each op which emits a non-empty batch will also emit a non-empty batch_index -Tensor, which, is a [K, 3] matrix where each row contains the invocation's id, -start, and length of elements of each set of Tensors present in batched_tensors. - -Batched tensors are concatenated along the first dimension, and all tensors in -in_tensors must have the first dimension of the same size. - -in_tensors: The tensors to be batched. -num_batch_threads: Number of scheduling threads for processing batches of work. - Determines the number of batches processed in parallel. -max_batch_size: Batch sizes will never be bigger than this. -batch_timeout_micros: Maximum number of microseconds to wait before outputting - an incomplete batch. -allowed_batch_sizes: Optional list of allowed batch sizes. If left empty, does - nothing. Otherwise, supplies a list of batch sizes, causing the op to pad - batches up to one of those sizes. The entries must increase monotonically, and - the final entry must equal max_batch_size. -grad_timeout_micros: The timeout to use for the gradient. See Unbatch. -batched_tensors: Either empty tensors or a batch of concatenated Tensors. -batch_index: If out_tensors is non-empty, has information to invert it. -container: Controls the scope of sharing of this batch. -id: always contains a scalar with a unique ID for this invocation of Batch. -shared_name: Concurrently running instances of batch in the same device with the - same container and shared_name will batch their elements together. If left - empty, the op name will be used as the shared name. -T: the types of tensors to be batched. 
-)doc"); - -REGISTER_OP("Unbatch") - .Input("batched_tensor: T") - .Input("batch_index: int64") - .Input("id: int64") - .Output("unbatched_tensor: T") - .Attr("timeout_micros: int") - .Attr("container: string = ''") - .Attr("shared_name: string = ''") - .Attr("T: type") - .SetShapeFn([](shape_inference::InferenceContext* c) { - shape_inference::ShapeHandle out_shape; - TF_RETURN_IF_ERROR( - c->ReplaceDim(c->input(0), 0, c->UnknownDim(), &out_shape)); - c->set_output(0, out_shape); - return Status::OK(); - }) - .Doc(R"doc( -Reverses the operation of Batch for a single output Tensor. - -An instance of Unbatch either receives an empty batched_tensor, in which case it -asynchronously waits until the values become available from a concurrently -running instance of Unbatch with the same container and shared_name, or receives -a non-empty batched_tensor in which case it finalizes all other concurrently -running instances and outputs its own element from the batch. - -batched_tensor: The possibly transformed output of Batch. The size of the first - dimension should remain unchanged by the transformations for the operation to - work. -batch_index: The matching batch_index obtained from Batch. -id: The id scalar emitted by Batch. -unbatched_tensor: The Tensor corresponding to this execution. -timeout_micros: Maximum amount of time (in microseconds) to wait to receive the - batched input tensor associated with a given invocation of the op. -container: Container to control resource sharing. -shared_name: Instances of Unbatch with the same container and shared_name are - assumed to possibly belong to the same batch. If left empty, the op name will - be used as the shared name. 
-)doc"); - -REGISTER_OP("UnbatchGrad") - .Input("original_input: T") - .Input("batch_index: int64") - .Input("grad: T") - .Input("id: int64") - .Output("batched_grad: T") - .Attr("container: string = ''") - .Attr("shared_name: string = ''") - .Attr("T: type") - .SetShapeFn([](shape_inference::InferenceContext* c) { - c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(2)))); - return Status::OK(); - }) - .Doc(R"doc( -Gradient of Unbatch. - -Acts like Batch but using the given batch_index index of batching things as they -become available. This ensures that the gradients are propagated back in the -same session which did the forward pass. - -original_input: The input to the Unbatch operation this is the gradient of. -batch_index: The batch_index given to the Unbatch operation this is the gradient -of. -grad: The downstream gradient. -id: The id scalar emitted by Batch. -batched_grad: The return value, either an empty tensor or the batched gradient. -container: Container to control resource sharing. -shared_name: Instances of UnbatchGrad with the same container and shared_name - are assumed to possibly belong to the same batch. If left empty, the op name - will be used as the shared name. 
- )doc"); - -} // namespace tensorflow diff --git a/tensorflow/contrib/batching/python/ops/batch_ops.py b/tensorflow/contrib/batching/python/ops/batch_ops.py index cee4d7b4a9710e285957f27ace7c2762c473c5c7..921d6917a4e478c3e60771fdc3ae99febc33d2e3 100644 --- a/tensorflow/contrib/batching/python/ops/batch_ops.py +++ b/tensorflow/contrib/batching/python/ops/batch_ops.py @@ -18,18 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.batching.ops import gen_batch_ops +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_batch_ops # go/tf-wildcard-import # pylint: disable=wildcard-import -from tensorflow.contrib.batching.ops.gen_batch_ops import * +from tensorflow.python.ops.gen_batch_ops import * # pylint: enable=wildcard-import -from tensorflow.contrib.util import loader -from tensorflow.python.framework import ops -from tensorflow.python.platform import resource_loader - - -_batch_ops = loader.load_op_library( - resource_loader.get_path_to_datafile("_batch_ops.so")) @ops.RegisterGradient("Batch") @@ -59,10 +53,13 @@ def _UnbatchGrad(op, grad): # pylint: disable=invalid-name ] -def batch_function(num_batch_threads, max_batch_size, batch_timeout_micros, +def batch_function(num_batch_threads, + max_batch_size, + batch_timeout_micros, allowed_batch_sizes=None, grad_timeout_micros=60 * 1000 * 1000, - unbatch_timeout_micros=60 * 1000 * 1000): + unbatch_timeout_micros=60 * 1000 * 1000, + max_enqueued_batches=10): """Batches the computation done by the decorated function. So, for example, in the following code @@ -100,6 +97,7 @@ def batch_function(num_batch_threads, max_batch_size, batch_timeout_micros, documentation of the unbatch op for more details. Defaults to 60s. unbatch_timeout_micros: The timeout to use for unbatching. See the documentation of the unbatch op for more details. Defaults to 60s. 
+ max_enqueued_batches: The maximum depth of the batch queue. Defaults to 10. Returns: The decorated function will return the unbatched computation output Tensors. @@ -117,6 +115,7 @@ def batch_function(num_batch_threads, max_batch_size, batch_timeout_micros, num_batch_threads=num_batch_threads, max_batch_size=max_batch_size, batch_timeout_micros=batch_timeout_micros, + max_enqueued_batches=max_enqueued_batches, allowed_batch_sizes=allowed_batch_sizes, grad_timeout_micros=grad_timeout_micros, shared_name=name) diff --git a/tensorflow/contrib/batching/shared_batch_scheduler.h b/tensorflow/contrib/batching/shared_batch_scheduler.h index 41a3f99137ade2552432fee62ddce17d064148a4..83a59695d7db7e0a24fb437a3ea71a4d9e23c93f 100644 --- a/tensorflow/contrib/batching/shared_batch_scheduler.h +++ b/tensorflow/contrib/batching/shared_batch_scheduler.h @@ -13,688 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_SHARED_BATCH_SCHEDULER_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_SHARED_BATCH_SCHEDULER_H_ +#ifndef TENSORFLOW_CONTRIB_BATCHING_SHARED_BATCH_SCHEDULER_H_ +#define TENSORFLOW_CONTRIB_BATCHING_SHARED_BATCH_SCHEDULER_H_ -#include -#include -#include -#include -#include -#include -#include -#include +#include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h" -#include "tensorflow/contrib/batching/batch_scheduler.h" -#include "tensorflow/contrib/batching/util/periodic_function.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/platform/cpu_info.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { -namespace serving { -namespace 
internal { -template -class Queue; -} // namespace internal -} // namespace serving -} // namespace tensorflow - -namespace tensorflow { -namespace serving { - -// A batch scheduler for server instances that service multiple request types -// (e.g. multiple machine-learned models, or multiple versions of a model served -// concurrently), or even multiple distinct tasks for a given request. The -// scheduler multiplexes batches of different kinds of tasks onto a fixed-size -// thread pool (each batch contains tasks of a single type), in a carefully -// controlled manner. A common configuration is to set the number of threads -// equal to the number of hardware accelerator units, in which case the -// scheduler takes care of multiplexing the task types onto the shared hardware, -// in a manner that is both fair and efficient. -// -// Semantically, SharedBatchScheduler behaves like having N instances of -// BasicBatchScheduler (see basic_batch_scheduler.h), one per task type. The -// difference is that under the covers there is a single shared thread pool, -// instead of N independent ones, with their sharing deliberately coordinated. -// -// SharedBatchScheduler does not implement the BatchScheduler API; rather, it -// presents an abstraction of "queues", where each queue coresponds to one type -// of task. Tasks submitted to a given queue are placed in their own batches, -// and cannot be mixed with other tasks. Queues can be added and deleted -// dynamically, to accommodate e.g. versions of a model being brought up and -// down over the lifetime of a server. -// -// The batch thread pool round-robins through the queues, running one batch -// from a queue and then moving to the next queue. Each queue behaves like a -// BasicBatchScheduler instance, in the sense that it has maximum batch size and -// timeout parameters, which govern when a batch is eligible to be processed. 
-// -// Each queue is independently configured with a maximum size (in terms of the -// maximum number of batches worth of enqueued tasks). For online serving, it is -// recommended that the queue sizes be configured such that the sum of the sizes -// of the active queues roughly equal the number of batch threads. (The idea is -// that if all threads become available at roughly the same time, there will be -// enough enqueued work for them to take on, but no more.) -// -// If queue sizes are configured in the manner suggested above, the maximum time -// a task can spend in a queue before being placed in a batch and assigned to a -// thread for processing, is the greater of: -// - the maximum time to process one batch of tasks from any active queue -// - the configured timeout parameter for the task's queue (which can be 0) -// -// For bulk processing jobs and throughput-oriented benchmarks, you may want to -// set the maximum queue size to a large value. -// -// TODO(b/26539183): Support queue servicing policies other than round-robin. -// E.g. let each queue specify a "share" (an int >= 1), so e.g. with queues A -// and B having shares 1 and 2 respectively, the servicing pattern is ABBABB... -// -// -// PERFORMANCE TUNING: See README.md. -// -template -class SharedBatchScheduler - : public std::enable_shared_from_this> { - public: - // TODO(b/25089730): Tune defaults based on best practices as they develop. - struct Options { - // The name to use for the pool of batch threads. - string thread_pool_name = {"batch_threads"}; - - // The number of threads to use to process batches. - // Must be >= 1, and should be tuned carefully. - int num_batch_threads = port::NumSchedulableCPUs(); - - // The environment to use. - // (Typically only overridden by test code.) - Env* env = Env::Default(); - }; - // Ownership is shared between the caller of Create() and any queues created - // via AddQueue(). 
- static Status Create( - const Options& options, - std::shared_ptr>* scheduler); - - ~SharedBatchScheduler(); - - // Adds a queue to which tasks may be submitted. The returned queue implements - // the BatchScheduler API. Each queue has its own set of scheduling options, - // and its own callback to process batches of tasks submitted to the queue. - // - // The returned queue's destructor blocks until all tasks submitted to it have - // been processed. - struct QueueOptions { - // The maximum size of each batch. - // - // The scheduler may form batches of any size between 1 and this number - // (inclusive). If there is a need to quantize the batch sizes, i.e. only - // submit batches whose size is in a small set of allowed sizes, that can be - // done by adding padding in the process-batch callback. - int max_batch_size = 1000; - - // If a task has been enqueued for this amount of time (in microseconds), - // and a thread is available, the scheduler will immediately form a batch - // from enqueued tasks and assign the batch to the thread for processing, - // even if the batch's size is below 'max_batch_size'. - // - // This parameter offers a way to bound queue latency, so that a task isn't - // stuck in the queue indefinitely waiting for enough tasks to arrive to - // make a full batch. (The latency bound is given in the class documentation - // above.) - // - // The goal is to smooth out batch sizes under low request rates, and thus - // avoid latency spikes. - int64 batch_timeout_micros = 0; - - // The maximum allowable number of enqueued (accepted by Schedule() but - // not yet being processed on a batch thread) tasks in terms of batches. - // If this limit is reached, Schedule() will return an UNAVAILABLE error. - // See the class documentation above for guidelines on how to tune this - // parameter. 
- int max_enqueued_batches = 10; - }; - Status AddQueue(const QueueOptions& options, - std::function>)> - process_batch_callback, - std::unique_ptr>* queue); - - private: - explicit SharedBatchScheduler(const Options& options); - - // The code executed in 'batch_threads_'. Obtains a batch to process from the - // queue pointed to by 'next_queue_to_schedule_', and processes it. If that - // queue declines to provide a batch to process, moves onto the next queue. If - // no queues provide a batch to process, just sleeps briefly and exits. - void ThreadLogic(); - - const Options options_; - - mutex mu_; - - // A list of queues. (We use std::list instead of std::vector to ensure that - // iterators are not invalidated by adding/removing elements. It also offers - // efficient removal of elements from the middle.) - using QueueList = std::list>>; - - // All "active" queues, i.e. ones that either: - // - have not been removed, or - // - have been removed but are not yet empty. - QueueList queues_ GUARDED_BY(mu_); - - // An iterator over 'queues_', pointing to the queue from which the next - // available batch thread should grab work. - typename QueueList::iterator next_queue_to_schedule_ GUARDED_BY(mu_); - - // Used by idle batch threads to wait for work to enter the system. Notified - // whenever a batch becomes schedulable. - condition_variable schedulable_batch_cv_; - - // Threads that process batches obtained from the queues. - std::vector> batch_threads_; - - TF_DISALLOW_COPY_AND_ASSIGN(SharedBatchScheduler); -}; - -////////// -// Implementation details follow. API users need not read. - -namespace internal { - -// A task queue for SharedBatchScheduler. Accepts tasks and accumulates them -// into batches, and dispenses those batches to be processed via a "pull" -// interface. The queue's behavior is governed by maximum batch size, timeout -// and maximum queue length parameters; see their documentation in -// SharedBatchScheduler. 
-// -// The queue is implemented as a deque of batches, with these invariants: -// - The number of batches is between 1 and 'options_.max_enqueued_batches'. -// - The back-most batch is open; the rest are closed. -// -// Submitted tasks are added to the open batch. If that batch doesn't have room -// but the queue isn't full, then that batch is closed and a new open batch is -// started. -// -// Batch pull requests are handled by dequeuing the front-most batch if it is -// closed. If the front-most batch is open (i.e. the queue contains only one -// batch) and has reached the timeout, it is immediately closed and returned; -// otherwise no batch is returned for the request. -template -class Queue { - public: - using ProcessBatchCallback = - std::function>)>; - using SchedulableBatchCallback = std::function; - Queue(const typename SharedBatchScheduler::QueueOptions& options, - Env* env, ProcessBatchCallback process_batch_callback, - SchedulableBatchCallback schdulable_batch_callback); - - // Illegal to destruct unless the queue is empty. - ~Queue(); - - // Submits a task to the queue, with the same semantics as - // BatchScheduler::Schedule(). - Status Schedule(std::unique_ptr* task); - - // Returns the number of enqueued tasks, with the same semantics as - // BatchScheduler::NumEnqueuedTasks(). - size_t NumEnqueuedTasks() const; - - // Returns the queue capacity, with the same semantics as - // BatchScheduler::SchedulingCapacity(). - size_t SchedulingCapacity() const; - - // Called by a thread that is ready to process a batch, to request one from - // this queue. Either returns a batch that is ready to be processed, or - // nullptr if the queue declines to schedule a batch at this time. If it - // returns a batch, the batch is guaranteed to be closed. - std::unique_ptr> ScheduleBatch(); - - // Processes a batch that has been returned earlier by ScheduleBatch(). - void ProcessBatch(std::unique_ptr> batch); - - // Determines whether the queue is empty, i.e. 
has no tasks waiting or being - // processed. - bool IsEmpty() const; - - // Marks the queue closed, and waits until it is empty. - void CloseAndWaitUntilEmpty(); - - bool closed() const { - mutex_lock l(mu_); - return closed_; - } - - private: - // Same as IsEmpty(), but assumes the caller already holds a lock on 'mu_'. - bool IsEmptyInternal() const EXCLUSIVE_LOCKS_REQUIRED(mu_); - - // Closes the open batch residing at the back of 'batches_', and inserts a - // fresh open batch behind it. - void StartNewBatch() EXCLUSIVE_LOCKS_REQUIRED(mu_); - - // Determines whether the open batch residing at the back of 'batches_' is - // currently schedulable. - bool IsOpenBatchSchedulable() const EXCLUSIVE_LOCKS_REQUIRED(mu_); - - const typename SharedBatchScheduler::QueueOptions options_; - - // The environment to use. - Env* env_; - - // A callback invoked to processes a batch of work units. Always invoked from - // a batch thread. - ProcessBatchCallback process_batch_callback_; - - // A callback invoked to notify the scheduler that a new batch has become - // schedulable. - SchedulableBatchCallback schedulable_batch_callback_; - - mutable mutex mu_; - - // Whether this queue can accept new tasks. This variable is monotonic: it - // starts as false, and then at some point gets set to true and remains true - // for the duration of this object's life. - bool closed_ GUARDED_BY(mu_) = false; - - // The enqueued batches. See the invariants in the class comments above. - std::deque>> batches_ GUARDED_BY(mu_); - - // The time at which the first task was added to the open (back-most) batch - // in 'batches_'. Valid iff that batch contains at least one task. - uint64 open_batch_start_time_micros_ GUARDED_BY(mu_); - - // Whether this queue contains a batch that is eligible to be scheduled. Used - // to keep track of when to call 'schedulable_batch_callback_'. - bool schedulable_batch_ GUARDED_BY(mu_) = false; - - // The number of batches currently being processed by batch threads. 
- // Incremented in ScheduleBatch() and decremented in ProcessBatch(). - int num_batches_being_processed_ GUARDED_BY(mu_) = 0; - - // Used by CloseAndWaitUntilEmpty() to wait until the queue is empty, for the - // case in which the queue is not empty when CloseAndWaitUntilEmpty() starts. - // When ProcessBatch() dequeues the last batch and makes the queue empty, if - // 'empty_notification_' is non-null it calls 'empty_notification_->Notify()'. - Notification* empty_notification_ GUARDED_BY(mu_) = nullptr; - - TF_DISALLOW_COPY_AND_ASSIGN(Queue); -}; - -// A RAII-style object that points to a Queue and implements -// the BatchScheduler API. To be handed out to clients who call AddQueue(). -template -class QueueHandle : public BatchScheduler { - public: - QueueHandle(std::shared_ptr> scheduler, - Queue* queue); - ~QueueHandle() override; - - Status Schedule(std::unique_ptr* task) override; - size_t NumEnqueuedTasks() const override; - size_t SchedulingCapacity() const override; - - private: - // The scheduler that owns 'queue_'. - std::shared_ptr> scheduler_; - - // The queue this handle wraps. Owned by 'scheduler_', which keeps it alive at - // least until this class's destructor closes it. - Queue* queue_; - - TF_DISALLOW_COPY_AND_ASSIGN(QueueHandle); -}; - -} // namespace internal - -template -Status SharedBatchScheduler::Create( - const Options& options, - std::shared_ptr>* scheduler) { - if (options.num_batch_threads < 1) { - return errors::InvalidArgument("num_batch_threads must be positive; was ", - options.num_batch_threads); - } - scheduler->reset(new SharedBatchScheduler(options)); - return Status::OK(); -} - -template -SharedBatchScheduler::~SharedBatchScheduler() { - // Wait until the batch threads finish clearing out and deleting the closed - // queues. 
- for (;;) { - { - mutex_lock l(mu_); - if (queues_.empty()) { - break; - } - } - const int64 kSleepTimeMicros = 100; - options_.env->SleepForMicroseconds(kSleepTimeMicros); - } - // Delete the batch threads before allowing state the threads may access (e.g. - // 'mu_') to be deleted. - batch_threads_.clear(); -} - -template -Status SharedBatchScheduler::AddQueue( - const QueueOptions& options, - std::function>)> - process_batch_callback, - std::unique_ptr>* queue) { - if (options.max_batch_size <= 0) { - return errors::InvalidArgument("max_batch_size must be positive; was ", - options.max_batch_size); - } - if (options.batch_timeout_micros < 0) { - return errors::InvalidArgument( - "batch_timeout_micros must be non-negative; was ", - options.batch_timeout_micros); - } - if (options.max_enqueued_batches < 0) { - return errors::InvalidArgument( - "max_enqueued_batches must be non-negative; was ", - options.max_enqueued_batches); - } - - auto schedulable_batch_callback = [this] { - mutex_lock l(mu_); - schedulable_batch_cv_.notify_one(); - }; - auto internal_queue = - std::unique_ptr>(new internal::Queue( - options, options_.env, process_batch_callback, - schedulable_batch_callback)); - auto handle = std::unique_ptr>( - new internal::QueueHandle(this->shared_from_this(), - internal_queue.get())); - { - mutex_lock l(mu_); - queues_.push_back(std::move(internal_queue)); - if (next_queue_to_schedule_ == queues_.end()) { - next_queue_to_schedule_ = queues_.begin(); - } - } - *queue = std::move(handle); - return Status::OK(); -} - -template -SharedBatchScheduler::SharedBatchScheduler(const Options& options) - : options_(options), next_queue_to_schedule_(queues_.end()) { - // Kick off the batch threads. 
- PeriodicFunction::Options periodic_fn_options; - periodic_fn_options.thread_name_prefix = - strings::StrCat(options.thread_pool_name, "_"); - for (int i = 0; i < options.num_batch_threads; ++i) { - std::unique_ptr thread(new PeriodicFunction( - [this] { this->ThreadLogic(); }, - 0 /* function invocation interval time */, periodic_fn_options)); - batch_threads_.push_back(std::move(thread)); - } -} - -template -void SharedBatchScheduler::ThreadLogic() { - // A batch to process next (or nullptr if no work to do). - std::unique_ptr> batch_to_process; - // The queue with which 'batch_to_process' is associated. - internal::Queue* queue_for_batch = nullptr; - { - mutex_lock l(mu_); - - const int num_queues = queues_.size(); - for (int num_queues_tried = 0; - batch_to_process == nullptr && num_queues_tried < num_queues; - ++num_queues_tried) { - DCHECK(next_queue_to_schedule_ != queues_.end()); - - // If a closed queue responds to ScheduleBatch() with nullptr, the queue - // will never yield any further batches so we can drop it. To avoid a - // race, we take a snapshot of the queue's closedness state *before* - // calling ScheduleBatch(). - const bool queue_closed = (*next_queue_to_schedule_)->closed(); - - // Ask '*next_queue_to_schedule_' if it wants us to process a batch. - batch_to_process = (*next_queue_to_schedule_)->ScheduleBatch(); - if (batch_to_process != nullptr) { - queue_for_batch = next_queue_to_schedule_->get(); - } - - // Advance 'next_queue_to_schedule_'. - if (queue_closed && (*next_queue_to_schedule_)->IsEmpty() && - batch_to_process == nullptr) { - // We've encountered a closed queue with no work to do. Drop it. - DCHECK_NE(queue_for_batch, next_queue_to_schedule_->get()); - next_queue_to_schedule_ = queues_.erase(next_queue_to_schedule_); - } else { - ++next_queue_to_schedule_; - } - if (next_queue_to_schedule_ == queues_.end() && !queues_.empty()) { - // We've hit the end. Wrap to the first queue. 
- next_queue_to_schedule_ = queues_.begin(); - } - } - - if (batch_to_process == nullptr) { - // We couldn't find any work to do. Wait until a new batch becomes - // schedulable, or some time has elapsed, before checking again. - const int64 kTimeoutMillis = 1; // The smallest accepted granule of time. - WaitForMilliseconds(&l, &schedulable_batch_cv_, kTimeoutMillis); - return; - } - } - - queue_for_batch->ProcessBatch(std::move(batch_to_process)); -} - -namespace internal { - -template -Queue::Queue( - const typename SharedBatchScheduler::QueueOptions& options, - Env* env, ProcessBatchCallback process_batch_callback, - SchedulableBatchCallback schedulable_batch_callback) - : options_(options), - env_(env), - process_batch_callback_(process_batch_callback), - schedulable_batch_callback_(schedulable_batch_callback) { - // Create an initial, open batch. - batches_.emplace_back(new Batch); -} - -template -Queue::~Queue() { - mutex_lock l(mu_); - DCHECK(IsEmptyInternal()); - - // Close the (empty) open batch, so its destructor doesn't block. 
- batches_.back()->Close(); -} - -template -Status Queue::Schedule(std::unique_ptr* task) { - if ((*task)->size() > options_.max_batch_size) { - return errors::InvalidArgument("Task size ", (*task)->size(), - " is larger than maximum batch size ", - options_.max_batch_size); - } - - bool notify_of_schedulable_batch = false; - { - mutex_lock l(mu_); - - DCHECK(!closed_); - - if (batches_.back()->size() + (*task)->size() > options_.max_batch_size) { - if (batches_.size() >= options_.max_enqueued_batches) { - return errors::Unavailable( - "The batch scheduling queue to which this task was submitted is " - "full"); - } - StartNewBatch(); - } - if (batches_.back()->empty()) { - open_batch_start_time_micros_ = env_->NowMicros(); - } - batches_.back()->AddTask(std::move(*task)); - - if (!schedulable_batch_) { - if (batches_.size() > 1 || IsOpenBatchSchedulable()) { - schedulable_batch_ = true; - notify_of_schedulable_batch = true; - } - } - } - - if (notify_of_schedulable_batch) { - schedulable_batch_callback_(); - } - - return Status::OK(); -} - -template -size_t Queue::NumEnqueuedTasks() const { - mutex_lock l(mu_); - size_t num_enqueued_tasks = 0; - for (const auto& batch : batches_) { - num_enqueued_tasks += batch->num_tasks(); - } - return num_enqueued_tasks; -} - -template -size_t Queue::SchedulingCapacity() const { - mutex_lock l(mu_); - const int num_new_batches_schedulable = - options_.max_enqueued_batches - batches_.size(); - const int open_batch_capacity = - options_.max_batch_size - batches_.back()->size(); - return (num_new_batches_schedulable * options_.max_batch_size) + - open_batch_capacity; -} - -template -std::unique_ptr> Queue::ScheduleBatch() { - // The batch to schedule, which we may populate below. (If left as nullptr, - // that means we are electing not to schedule a batch at this time.) - std::unique_ptr> batch_to_schedule; - - { - mutex_lock l(mu_); - - // Consider closing the open batch at this time, to schedule it. 
- if (batches_.size() == 1 && IsOpenBatchSchedulable()) { - StartNewBatch(); - } - - if (batches_.size() >= 2) { - // There is at least one closed batch that is ready to be scheduled. - ++num_batches_being_processed_; - batch_to_schedule = std::move(batches_.front()); - batches_.pop_front(); - } else { - schedulable_batch_ = false; - } - } - - return batch_to_schedule; -} - -template -void Queue::ProcessBatch(std::unique_ptr> batch) { - process_batch_callback_(std::move(batch)); - - { - mutex_lock l(mu_); - --num_batches_being_processed_; - if (empty_notification_ != nullptr && IsEmptyInternal()) { - empty_notification_->Notify(); - } - } -} - -template -bool Queue::IsEmpty() const { - mutex_lock l(mu_); - return IsEmptyInternal(); -} - -template -void Queue::CloseAndWaitUntilEmpty() { - Notification empty; - { - mutex_lock l(mu_); - closed_ = true; - if (IsEmptyInternal()) { - empty.Notify(); - } else { - // Arrange for ProcessBatch() to notify when the queue becomes empty. - empty_notification_ = ∅ - } - } - empty.WaitForNotification(); -} - -template -bool Queue::IsEmptyInternal() const { - return num_batches_being_processed_ == 0 && batches_.size() == 1 && - batches_.back()->empty(); -} - -template -void Queue::StartNewBatch() { - batches_.back()->Close(); - batches_.emplace_back(new Batch); -} - -template -bool Queue::IsOpenBatchSchedulable() const { - Batch* open_batch = batches_.back().get(); - if (open_batch->empty()) { - return false; - } - return closed_ || open_batch->size() >= options_.max_batch_size || - env_->NowMicros() >= - open_batch_start_time_micros_ + options_.batch_timeout_micros; -} - -template -QueueHandle::QueueHandle( - std::shared_ptr> scheduler, - Queue* queue) - : scheduler_(scheduler), queue_(queue) {} - -template -QueueHandle::~QueueHandle() { - queue_->CloseAndWaitUntilEmpty(); -} - -template -Status QueueHandle::Schedule(std::unique_ptr* task) { - return queue_->Schedule(task); -} - -template -size_t QueueHandle::NumEnqueuedTasks() 
const { - return queue_->NumEnqueuedTasks(); -} - -template -size_t QueueHandle::SchedulingCapacity() const { - return queue_->SchedulingCapacity(); -} - -} // namespace internal - -} // namespace serving -} // namespace tensorflow - -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_SHARED_BATCH_SCHEDULER_H_ +#endif // TENSORFLOW_CONTRIB_BATCHING_SHARED_BATCH_SCHEDULER_H_ diff --git a/tensorflow/contrib/batching/test_util/BUILD b/tensorflow/contrib/batching/test_util/BUILD index d1ced0d8c367f44b520a9bba2db8a3e0969bab4c..6db627faad1df4a4b73082e74e7754829ff2b514 100644 --- a/tensorflow/contrib/batching/test_util/BUILD +++ b/tensorflow/contrib/batching/test_util/BUILD @@ -22,11 +22,9 @@ filegroup( cc_library( name = "fake_clock_env", testonly = 1, - srcs = ["fake_clock_env.cc"], hdrs = ["fake_clock_env.h"], visibility = ["//visibility:public"], deps = [ - "//tensorflow/core:lib", - "//tensorflow/core:tensorflow", + "//tensorflow/core/kernels/batching_util:fake_clock_env", ], ) diff --git a/tensorflow/contrib/batching/test_util/fake_clock_env.h b/tensorflow/contrib/batching/test_util/fake_clock_env.h index 35cafcb73c51feb4e9e15a61d1830c8ef6bc3e0f..40a39a5569854350c72a47102f3dac07b362ce8e 100644 --- a/tensorflow/contrib/batching/test_util/fake_clock_env.h +++ b/tensorflow/contrib/batching/test_util/fake_clock_env.h @@ -13,64 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_TEST_UTIL_FAKE_CLOCK_ENV_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_TEST_UTIL_FAKE_CLOCK_ENV_H_ +#ifndef TENSORFLOW_CONTRIB_BATCHING_TEST_UTIL_FAKE_CLOCK_ENV_H_ +#define TENSORFLOW_CONTRIB_BATCHING_TEST_UTIL_FAKE_CLOCK_ENV_H_ -#include -#include -#include +#include "tensorflow/core/kernels/batching_util/fake_clock_env.h" -#include "tensorflow/core/lib/core/notification.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { -namespace serving { -namespace test_util { - -// An Env implementation with a fake clock for NowMicros() and -// SleepForMicroseconds(). The clock doesn't advance on its own; it advances via -// an explicit Advance() method. -// All other Env virtual methods pass through to a wrapped Env. -class FakeClockEnv : public EnvWrapper { - public: - explicit FakeClockEnv(Env* wrapped); - ~FakeClockEnv() override = default; - - // Advance the clock by a certain number of microseconds. - void AdvanceByMicroseconds(int micros); - - // Blocks until there is a sleeping thread that is scheduled to wake up at - // the given (absolute) time. - void BlockUntilSleepingThread(uint64 wake_time); - - // Blocks until there are at least num_threads sleeping. - void BlockUntilThreadsAsleep(int num_threads); - - // Methods that this class implements. 
- uint64 NowMicros() override; - void SleepForMicroseconds(int64 micros) override; - - private: - mutex mu_; - - uint64 current_time_ GUARDED_BY(mu_) = 0; - - struct SleepingThread { - uint64 wake_time; - Notification* wake_notification; - }; - std::vector sleeping_threads_ GUARDED_BY(mu_); - - TF_DISALLOW_COPY_AND_ASSIGN(FakeClockEnv); -}; - -} // namespace test_util -} // namespace serving -} // namespace tensorflow - -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_TEST_UTIL_FAKE_CLOCK_ENV_H_ +#endif // TENSORFLOW_CONTRIB_BATCHING_TEST_UTIL_FAKE_CLOCK_ENV_H_ diff --git a/tensorflow/contrib/batching/util/BUILD b/tensorflow/contrib/batching/util/BUILD index f33a08cb817e9f2832be953ef6ff1aba04c4c288..2a84a7712a8fa66e89db41ff4e7ebe4f620029ca 100644 --- a/tensorflow/contrib/batching/util/BUILD +++ b/tensorflow/contrib/batching/util/BUILD @@ -22,12 +22,11 @@ filegroup( cc_library( name = "periodic_function_dynamic", - srcs = ["periodic_function.cc"], hdrs = ["periodic_function.h"], visibility = ["//visibility:public"], deps = [ - "//tensorflow/core:framework_headers_lib", - "//tensorflow/core:protos_all_cc", + "//tensorflow/core/kernels/batching_util:periodic_function_dynamic", + "//third_party/eigen3", ], ) @@ -36,17 +35,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":periodic_function_dynamic", - "//tensorflow/core:lib", - ], -) - -tf_cc_test( - name = "periodic_function_test", - srcs = ["periodic_function_test.cc"], - deps = [ - ":periodic_function_dynamic", - "//tensorflow/contrib/batching/test_util:fake_clock_env", - "//tensorflow/core:test", - "//tensorflow/core:test_main", + "//tensorflow/core/kernels/batching_util:periodic_function", ], ) diff --git a/tensorflow/contrib/batching/util/periodic_function.h b/tensorflow/contrib/batching/util/periodic_function.h index 2c032d802fe5f23a267db28dc869a253f16afc34..aa2ed0a385125fa090a7a56b6339a87eb2d57b1f 100644 --- a/tensorflow/contrib/batching/util/periodic_function.h +++ 
b/tensorflow/contrib/batching/util/periodic_function.h @@ -12,121 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_BATCHING_UTIL_PERIODIC_FUNCTION_H_ +#define TENSORFLOW_CONTRIB_BATCHING_UTIL_PERIODIC_FUNCTION_H_ -// PeriodicFunction will periodically call the given function with a specified -// period in a background thread. After Start() returns, the thread is -// guaranteed to have started. The destruction of the class causes the -// background thread to be destroyed as well. Start() should not be called more -// than once. -// -// PeriodicFunction runs the function as soon as any previous run both is -// complete and was started more than "interval_micros" earlier. Thus, runs are -// both serialized, and normally have a period of "interval_micros" if no run -// exceeds the time. -// -// Note that, if the function takes longer than two interval_micross to finish, -// then PeriodicFunction will "skip" at least one call to the function. For -// instance, if the period is 50ms and the function starts runs at time 0 for -// 150ms, then the function will immediately start executing again at time 150, -// but there will be no function runs corresponding to times 50 or 100. This is -// especially important to remember when using an environment with a simulated -// clock: advancing simulated time atomically over N interval_micross will not -// cause the function to be called N times. -// -// This object is thread-safe. -// -// Example: -// -// class Foo { -// public: -// Foo() : periodic_function_([this]() { Bar(); }, -// 1000 /* 1000us == 1ms*/) { -// } -// -// private: -// void Bar() { ... 
} -// -// PeriodicFunction periodic_function_; -// }; +#include "tensorflow/core/kernels/batching_util/periodic_function.h" -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_UTIL_PERIODIC_FUNCTION_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_UTIL_PERIODIC_FUNCTION_H_ - -#include -#include -#include - -#include "tensorflow/core/lib/core/notification.h" -#include "tensorflow/core/platform/env.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/thread_annotations.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { -namespace serving { - -namespace internal { -class PeriodicFunctionTestAccess; -} - -class PeriodicFunction { - public: - // Provides the ability to customize several aspects of the PeriodicFunction. - // Passed to constructor of PeriodicFunction. - struct Options { - Options() {} - - // Any standard thread options, such as stack size, should - // be passed via "thread_options". - ThreadOptions thread_options; - - // Specifies the thread name prefix (see the description in class - // Thread). - string thread_name_prefix = "periodic_function"; - - // The environment to use. Does not take ownership, but must remain alive - // for as long as the PeriodicFunction exists. - Env* env = Env::Default(); - - // Specifies the length of sleep before the first invocation of the - // function. - // This can be used for adding a random jitter to avoid synchronous behavior - // across multiple periodic functions. - int64 startup_delay_micros = 0; - }; - - // Also starts the background thread which will be calling the function. - PeriodicFunction(const std::function& function, int64 interval_micros, - const Options& options = Options()); - - ~PeriodicFunction(); - - private: - friend class internal::PeriodicFunctionTestAccess; - - // Notifies the background thread to stop. - void NotifyStop(); - - // (Blocking.) 
Loops forever calling "function_" every "interval_micros_". - void RunLoop(int64 start) LOCKS_EXCLUDED(mutex_); - - const std::function function_; // Actual client function - const int64 interval_micros_; // Interval between calls. - const Options options_; - - // Protects state below. - mutable mutex mutex_; - // Used to notify the thread to stop. - Notification stop_thread_; - - // Thread for running "function_" - std::unique_ptr thread_ GUARDED_BY(mutex_) = nullptr; - - TF_DISALLOW_COPY_AND_ASSIGN(PeriodicFunction); -}; - -} // namespace serving -} // namespace tensorflow - -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_UTIL_PERIODIC_FUNCTION_H_ +#endif // TENSORFLOW_CONTRIB_BATCHING_UTIL_PERIODIC_FUNCTION_H_ diff --git a/tensorflow/contrib/bayesflow/BUILD b/tensorflow/contrib/bayesflow/BUILD index a262d4aecdbb69dfcd8b88bc0a09060500d6b1c9..74712aeb67c3f0a31def78f25a0298f9c02c9590 100644 --- a/tensorflow/contrib/bayesflow/BUILD +++ b/tensorflow/contrib/bayesflow/BUILD @@ -99,6 +99,25 @@ cuda_py_test( ], ) +cuda_py_test( + name = "layers_conv_variational_test", + size = "small", + srcs = ["python/kernel_tests/layers_conv_variational_test.py"], + additional_deps = [ + ":bayesflow_py", + "//third_party/py/numpy", + "//tensorflow/contrib/distributions:distributions_py", + "//tensorflow/python/ops/distributions", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:gradients", + "//tensorflow/python:linalg_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", + ], +) + cuda_py_test( name = "layers_dense_variational_test", size = "small", @@ -118,6 +137,26 @@ cuda_py_test( ], ) +cuda_py_test( + name = "mcmc_diagnostics_test", + size = "small", + srcs = ["python/kernel_tests/mcmc_diagnostics_test.py"], + additional_deps = [ + ":bayesflow_py", + "//third_party/py/numpy", + "//tensorflow/python:spectral_ops_test_util", + 
"//tensorflow/contrib/distributions:distributions_py", + "//tensorflow/python/ops/distributions", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:random_seed", + ], +) + cuda_py_test( name = "monte_carlo_test", size = "small", @@ -156,6 +195,7 @@ cuda_py_test( "//tensorflow/python:variable_scope", "//tensorflow/python:variables", ], + tags = ["no_mac"], # b/73192243 ) cuda_py_test( @@ -198,6 +238,46 @@ cuda_py_test( "//tensorflow/python:platform_test", "//tensorflow/python:random_seed", ], + tags = ["notsan"], +) + +cuda_py_test( + name = "variable_utils_test", + size = "small", + srcs = ["python/kernel_tests/variable_utils_test.py"], + additional_deps = [ + ":bayesflow_py", + "//third_party/py/numpy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + +cuda_py_test( + name = "variational_sgd_optimizer_test", + size = "small", + srcs = ["python/kernel_tests/variational_sgd_optimizer_test.py"], + additional_deps = [ + ":bayesflow_py", + "//third_party/py/numpy", + "//tensorflow/contrib/distributions:distributions_py", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/python/ops/distributions", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:random_seed", + ], + tags = ["notsan"], ) filegroup( diff --git 
a/tensorflow/contrib/bayesflow/__init__.py b/tensorflow/contrib/bayesflow/__init__.py index 95b9452b1ada60c44672f37800ced2133d2bd8b2..528c4fbacd06c7b0defa0e32bd24a98b2bc07b64 100644 --- a/tensorflow/contrib/bayesflow/__init__.py +++ b/tensorflow/contrib/bayesflow/__init__.py @@ -26,9 +26,11 @@ from tensorflow.contrib.bayesflow.python.ops import custom_grad from tensorflow.contrib.bayesflow.python.ops import halton_sequence from tensorflow.contrib.bayesflow.python.ops import hmc from tensorflow.contrib.bayesflow.python.ops import layers +from tensorflow.contrib.bayesflow.python.ops import mcmc_diagnostics from tensorflow.contrib.bayesflow.python.ops import metropolis_hastings from tensorflow.contrib.bayesflow.python.ops import monte_carlo from tensorflow.contrib.bayesflow.python.ops import optimizers +from tensorflow.contrib.bayesflow.python.ops import variable_utils # pylint: enable=unused-import,line-too-long from tensorflow.python.util.all_util import remove_undocumented @@ -42,10 +44,12 @@ _allowed_symbols = [ 'hmc', 'layers', 'metropolis_hastings', + 'mcmc_diagnostics', 'monte_carlo', 'optimizers', 'special_math', 'stochastic_variables', + 'variable_utils', 'variational_inference', ] diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py index b1f108e5f01e4945ee83d8262f1d99877f0fe9f0..5bd834e56245ab4d874544cfd014fe59ae521ea8 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py @@ -12,40 +12,53 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for Hamiltonian Monte Carlo. 
-""" +"""Tests for Hamiltonian Monte Carlo.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections + import numpy as np -from scipy import special from scipy import stats from tensorflow.contrib.bayesflow.python.ops import hmc +from tensorflow.contrib.bayesflow.python.ops.hmc_impl import _compute_energy_change +from tensorflow.contrib.bayesflow.python.ops.hmc_impl import _leapfrog_integrator +from tensorflow.contrib.distributions.python.ops import independent as independent_lib +from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_linalg_ops +from tensorflow.python.ops import gradients_impl as gradients_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops.distributions import gamma as gamma_lib +from tensorflow.python.ops.distributions import normal as normal_lib from tensorflow.python.platform import test -from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.platform import tf_logging as logging_ops + + +def _reduce_variance(x, axis=None, keepdims=False): + sample_mean = math_ops.reduce_mean(x, axis, keepdims=True) + return math_ops.reduce_mean( + math_ops.squared_difference(x, sample_mean), axis, keepdims) -# TODO(b/66964210): Test float16. class HMCTest(test.TestCase): def setUp(self): self._shape_param = 5. self._rate_param = 10. - self._expected_x = (special.digamma(self._shape_param) - - np.log(self._rate_param)) - self._expected_exp_x = self._shape_param / self._rate_param random_seed.set_random_seed(10003) np.random.seed(10003) + def assertAllFinite(self, x): + self.assertAllEqual(np.ones_like(x).astype(bool), np.isfinite(x)) + def _log_gamma_log_prob(self, x, event_dims=()): """Computes log-pdf of a log-gamma random variable. 
@@ -60,63 +73,46 @@ class HMCTest(test.TestCase): self._rate_param * math_ops.exp(x), event_dims) - def _log_gamma_log_prob_grad(self, x, event_dims=()): - """Computes log-pdf and gradient of a log-gamma random variable. - - Args: - x: Value of the random variable. - event_dims: Dimensions not to treat as independent. Default is (), - i.e., all dimensions are independent. - - Returns: - log_prob: The log-pdf up to a normalizing constant. - grad: The gradient of the log-pdf with respect to x. - """ - return (math_ops.reduce_sum(self._shape_param * x - - self._rate_param * math_ops.exp(x), - event_dims), - self._shape_param - self._rate_param * math_ops.exp(x)) - - def _n_event_dims(self, x_shape, event_dims): - return np.prod([int(x_shape[i]) for i in event_dims]) - - def _integrator_conserves_energy(self, x, event_dims, sess, + def _integrator_conserves_energy(self, x, independent_chain_ndims, sess, feed_dict=None): - def potential_and_grad(x): - log_prob, grad = self._log_gamma_log_prob_grad(x, event_dims) - return -log_prob, -grad - - step_size = array_ops.placeholder(np.float32, [], name='step_size') - hmc_lf_steps = array_ops.placeholder(np.int32, [], name='hmc_lf_steps') + step_size = array_ops.placeholder(np.float32, [], name="step_size") + hmc_lf_steps = array_ops.placeholder(np.int32, [], name="hmc_lf_steps") if feed_dict is None: feed_dict = {} feed_dict[hmc_lf_steps] = 1000 - m = random_ops.random_normal(array_ops.shape(x)) - potential_0, grad_0 = potential_and_grad(x) - old_energy = potential_0 + 0.5 * math_ops.reduce_sum(m * m, - event_dims) - - _, new_m, potential_1, _ = ( - hmc.leapfrog_integrator(step_size, hmc_lf_steps, x, - m, potential_and_grad, grad_0)) + event_dims = math_ops.range(independent_chain_ndims, + array_ops.rank(x)) - new_energy = potential_1 + 0.5 * math_ops.reduce_sum(new_m * new_m, + m = random_ops.random_normal(array_ops.shape(x)) + log_prob_0 = self._log_gamma_log_prob(x, event_dims) + grad_0 = gradients_ops.gradients(log_prob_0, 
x) + old_energy = -log_prob_0 + 0.5 * math_ops.reduce_sum(m**2., event_dims) + + new_m, _, log_prob_1, _ = _leapfrog_integrator( + current_momentums=[m], + target_log_prob_fn=lambda x: self._log_gamma_log_prob(x, event_dims), + current_state_parts=[x], + step_sizes=[step_size], + num_leapfrog_steps=hmc_lf_steps, + current_target_log_prob=log_prob_0, + current_grads_target_log_prob=grad_0) + new_m = new_m[0] + + new_energy = -log_prob_1 + 0.5 * math_ops.reduce_sum(new_m * new_m, event_dims) x_shape = sess.run(x, feed_dict).shape - n_event_dims = self._n_event_dims(x_shape, event_dims) - feed_dict[step_size] = 0.1 / n_event_dims - old_energy_val, new_energy_val = sess.run([old_energy, new_energy], - feed_dict) - logging.vlog(1, 'average energy change: {}'.format( - abs(old_energy_val - new_energy_val).mean())) - - self.assertAllEqual(np.ones_like(new_energy_val, dtype=np.bool), - abs(old_energy_val - new_energy_val) < 1.) - - def _integrator_conserves_energy_wrapper(self, event_dims): + event_size = np.prod(x_shape[independent_chain_ndims:]) + feed_dict[step_size] = 0.1 / event_size + old_energy_, new_energy_ = sess.run([old_energy, new_energy], + feed_dict) + logging_ops.vlog(1, "average energy relative change: {}".format( + (1. - new_energy_ / old_energy_).mean())) + self.assertAllClose(old_energy_, new_energy_, atol=0., rtol=0.02) + + def _integrator_conserves_energy_wrapper(self, independent_chain_ndims): """Tests the long-term energy conservation of the leapfrog integrator. The leapfrog integrator is symplectic, so for sufficiently small step @@ -124,135 +120,310 @@ class HMCTest(test.TestCase): the energy of the system blowing up or collapsing. Args: - event_dims: A tuple of dimensions that should not be treated as - independent. This allows for multiple chains to be run independently - in parallel. Default is (), i.e., all dimensions are independent. 
+ independent_chain_ndims: Python `int` scalar representing the number of + dims associated with independent chains. """ - with self.test_session() as sess: - x_ph = array_ops.placeholder(np.float32, name='x_ph') - - feed_dict = {x_ph: np.zeros([50, 10, 2])} - self._integrator_conserves_energy(x_ph, event_dims, sess, feed_dict) + with self.test_session(graph=ops.Graph()) as sess: + x_ph = array_ops.placeholder(np.float32, name="x_ph") + feed_dict = {x_ph: np.random.rand(50, 10, 2)} + self._integrator_conserves_energy(x_ph, independent_chain_ndims, + sess, feed_dict) def testIntegratorEnergyConservationNullShape(self): - self._integrator_conserves_energy_wrapper([]) + self._integrator_conserves_energy_wrapper(0) def testIntegratorEnergyConservation1(self): - self._integrator_conserves_energy_wrapper([1]) + self._integrator_conserves_energy_wrapper(1) def testIntegratorEnergyConservation2(self): - self._integrator_conserves_energy_wrapper([2]) - - def testIntegratorEnergyConservation12(self): - self._integrator_conserves_energy_wrapper([1, 2]) - - def testIntegratorEnergyConservation012(self): - self._integrator_conserves_energy_wrapper([0, 1, 2]) - - def _chain_gets_correct_expectations(self, x, event_dims, sess, - feed_dict=None): + self._integrator_conserves_energy_wrapper(2) + + def testIntegratorEnergyConservation3(self): + self._integrator_conserves_energy_wrapper(3) + + def testSampleChainSeedReproducibleWorksCorrectly(self): + with self.test_session(graph=ops.Graph()) as sess: + num_results = 10 + independent_chain_ndims = 1 + + def log_gamma_log_prob(x): + event_dims = math_ops.range(independent_chain_ndims, + array_ops.rank(x)) + return self._log_gamma_log_prob(x, event_dims) + + kwargs = dict( + target_log_prob_fn=log_gamma_log_prob, + current_state=np.random.rand(4, 3, 2), + step_size=0.1, + num_leapfrog_steps=2, + num_burnin_steps=150, + seed=52, + ) + + samples0, kernel_results0 = hmc.sample_chain( + **dict(list(kwargs.items()) + list(dict( + 
num_results=2 * num_results, + num_steps_between_results=0).items()))) + + samples1, kernel_results1 = hmc.sample_chain( + **dict(list(kwargs.items()) + list(dict( + num_results=num_results, + num_steps_between_results=1).items()))) + + [ + samples0_, + samples1_, + target_log_prob0_, + target_log_prob1_, + ] = sess.run([ + samples0, + samples1, + kernel_results0.current_target_log_prob, + kernel_results1.current_target_log_prob, + ]) + self.assertAllClose(samples0_[::2], samples1_, + atol=1e-5, rtol=1e-5) + self.assertAllClose(target_log_prob0_[::2], target_log_prob1_, + atol=1e-5, rtol=1e-5) + + def _chain_gets_correct_expectations(self, x, independent_chain_ndims, + sess, feed_dict=None): + counter = collections.Counter() def log_gamma_log_prob(x): + counter["target_calls"] += 1 + event_dims = math_ops.range(independent_chain_ndims, + array_ops.rank(x)) return self._log_gamma_log_prob(x, event_dims) - step_size = array_ops.placeholder(np.float32, [], name='step_size') - hmc_lf_steps = array_ops.placeholder(np.int32, [], name='hmc_lf_steps') - hmc_n_steps = array_ops.placeholder(np.int32, [], name='hmc_n_steps') + num_results = array_ops.placeholder( + np.int32, [], name="num_results") + step_size = array_ops.placeholder( + np.float32, [], name="step_size") + num_leapfrog_steps = array_ops.placeholder( + np.int32, [], name="num_leapfrog_steps") if feed_dict is None: feed_dict = {} - feed_dict.update({step_size: 0.1, - hmc_lf_steps: 2, - hmc_n_steps: 300}) - - sample_chain, acceptance_prob_chain = hmc.chain([hmc_n_steps], - step_size, - hmc_lf_steps, - x, log_gamma_log_prob, - event_dims) - - acceptance_probs, samples = sess.run([acceptance_prob_chain, sample_chain], - feed_dict) - samples = samples[feed_dict[hmc_n_steps] // 2:] - expected_x_est = samples.mean() - expected_exp_x_est = np.exp(samples).mean() - - logging.vlog(1, 'True E[x, exp(x)]: {}\t{}'.format( - self._expected_x, self._expected_exp_x)) - logging.vlog(1, 'Estimated E[x, exp(x)]: {}\t{}'.format( - 
expected_x_est, expected_exp_x_est)) - self.assertNear(expected_x_est, self._expected_x, 2e-2) - self.assertNear(expected_exp_x_est, self._expected_exp_x, 2e-2) - self.assertTrue((acceptance_probs > 0.5).all()) - self.assertTrue((acceptance_probs <= 1.0).all()) - - def _chain_gets_correct_expectations_wrapper(self, event_dims): - with self.test_session() as sess: - x_ph = array_ops.placeholder(np.float32, name='x_ph') - - feed_dict = {x_ph: np.zeros([50, 10, 2])} - self._chain_gets_correct_expectations(x_ph, event_dims, sess, - feed_dict) + feed_dict.update({num_results: 150, + step_size: 0.05, + num_leapfrog_steps: 2}) + + samples, kernel_results = hmc.sample_chain( + num_results=num_results, + target_log_prob_fn=log_gamma_log_prob, + current_state=x, + step_size=step_size, + num_leapfrog_steps=num_leapfrog_steps, + num_burnin_steps=150, + seed=42) + + self.assertAllEqual(dict(target_calls=2), counter) + + expected_x = (math_ops.digamma(self._shape_param) + - np.log(self._rate_param)) + + expected_exp_x = self._shape_param / self._rate_param + + acceptance_probs_, samples_, expected_x_ = sess.run( + [kernel_results.acceptance_probs, samples, expected_x], + feed_dict) + + actual_x = samples_.mean() + actual_exp_x = np.exp(samples_).mean() + + logging_ops.vlog(1, "True E[x, exp(x)]: {}\t{}".format( + expected_x_, expected_exp_x)) + logging_ops.vlog(1, "Estimated E[x, exp(x)]: {}\t{}".format( + actual_x, actual_exp_x)) + self.assertNear(actual_x, expected_x_, 2e-2) + self.assertNear(actual_exp_x, expected_exp_x, 2e-2) + self.assertAllEqual(np.ones_like(acceptance_probs_, np.bool), + acceptance_probs_ > 0.5) + self.assertAllEqual(np.ones_like(acceptance_probs_, np.bool), + acceptance_probs_ <= 1.) 
+ + def _chain_gets_correct_expectations_wrapper(self, independent_chain_ndims): + with self.test_session(graph=ops.Graph()) as sess: + x_ph = array_ops.placeholder(np.float32, name="x_ph") + feed_dict = {x_ph: np.random.rand(50, 10, 2)} + self._chain_gets_correct_expectations(x_ph, independent_chain_ndims, + sess, feed_dict) def testHMCChainExpectationsNullShape(self): - self._chain_gets_correct_expectations_wrapper([]) + self._chain_gets_correct_expectations_wrapper(0) def testHMCChainExpectations1(self): - self._chain_gets_correct_expectations_wrapper([1]) + self._chain_gets_correct_expectations_wrapper(1) def testHMCChainExpectations2(self): - self._chain_gets_correct_expectations_wrapper([2]) - - def testHMCChainExpectations12(self): - self._chain_gets_correct_expectations_wrapper([1, 2]) - - def _kernel_leaves_target_invariant(self, initial_draws, event_dims, + self._chain_gets_correct_expectations_wrapper(2) + + def testKernelResultsUsingTruncatedDistribution(self): + def log_prob(x): + return array_ops.where( + x >= 0., + -x - x**2, # Non-constant gradient. + array_ops.fill(x.shape, math_ops.cast(-np.inf, x.dtype))) + # This log_prob has the property that it is likely to attract + # the HMC flow toward, and below, zero...but for x <=0, + # log_prob(x) = -inf, which should result in rejection, as well + # as a non-finite log_prob. Thus, this distribution gives us an opportunity + # to test out the kernel results ability to correctly capture rejections due + # to finite AND non-finite reasons. + # Why use a non-constant gradient? This ensures the leapfrog integrator + # will not be exact. + + num_results = 1000 + # Large step size, will give rejections due to integration error in addition + # to rejection due to going into a region of log_prob = -inf. + step_size = 0.1 + num_leapfrog_steps = 5 + num_chains = 2 + + with self.test_session(graph=ops.Graph()) as sess: + + # Start multiple independent chains. 
+ initial_state = ops.convert_to_tensor([0.1] * num_chains) + + states, kernel_results = hmc.sample_chain( + num_results=num_results, + target_log_prob_fn=log_prob, + current_state=initial_state, + step_size=step_size, + num_leapfrog_steps=num_leapfrog_steps, + seed=42) + + states_, kernel_results_ = sess.run([states, kernel_results]) + pstates_ = kernel_results_.proposed_state + + neg_inf_mask = np.isneginf(kernel_results_.proposed_target_log_prob) + + # First: Test that the mathematical properties of the above log prob + # function in conjunction with HMC show up as expected in kernel_results_. + + # We better have log_prob = -inf some of the time. + self.assertLess(0, neg_inf_mask.sum()) + # We better have some rejections due to something other than -inf. + self.assertLess(neg_inf_mask.sum(), (~kernel_results_.is_accepted).sum()) + # We better have been accepted a decent amount, even near the end of the + # chain, or else this HMC run just got stuck at some point. + self.assertLess( + 0.1, kernel_results_.is_accepted[int(0.9 * num_results):].mean()) + # We better not have any NaNs in proposed state or log_prob. + # We may have some NaN in grads, which involve multiplication/addition due + # to gradient rules. This is the known "NaN grad issue with tf.where." + self.assertAllEqual(np.zeros_like(states_), + np.isnan(kernel_results_.proposed_target_log_prob)) + self.assertAllEqual(np.zeros_like(states_), + np.isnan(states_)) + # We better not have any +inf in states, grads, or log_prob. + self.assertAllEqual(np.zeros_like(states_), + np.isposinf(kernel_results_.proposed_target_log_prob)) + self.assertAllEqual( + np.zeros_like(states_), + np.isposinf(kernel_results_.proposed_grads_target_log_prob[0])) + self.assertAllEqual(np.zeros_like(states_), + np.isposinf(states_)) + + # Second: Test that kernel_results is congruent with itself and + # acceptance/rejection of states. + + # Proposed state is negative iff proposed target log prob is -inf. 
+ np.testing.assert_array_less(pstates_[neg_inf_mask], 0.) + np.testing.assert_array_less(0., pstates_[~neg_inf_mask]) + + # Acceptance probs are zero whenever proposed state is negative. + self.assertAllEqual( + np.zeros_like(pstates_[neg_inf_mask]), + kernel_results_.acceptance_probs[neg_inf_mask]) + + # The move is accepted ==> state = proposed state. + self.assertAllEqual( + states_[kernel_results_.is_accepted], + pstates_[kernel_results_.is_accepted], + ) + # The move was rejected <==> state[t] == state[t - 1]. + for t in range(1, num_results): + for i in range(num_chains): + if kernel_results_.is_accepted[t, i]: + self.assertNotEqual(states_[t, i], states_[t - 1, i]) + else: + self.assertEqual(states_[t, i], states_[t - 1, i]) + + def _kernel_leaves_target_invariant(self, initial_draws, + independent_chain_ndims, sess, feed_dict=None): def log_gamma_log_prob(x): + event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x)) return self._log_gamma_log_prob(x, event_dims) def fake_log_prob(x): """Cooled version of the target distribution.""" return 1.1 * log_gamma_log_prob(x) - step_size = array_ops.placeholder(np.float32, [], name='step_size') + step_size = array_ops.placeholder(np.float32, [], name="step_size") if feed_dict is None: feed_dict = {} feed_dict[step_size] = 0.4 - sample, acceptance_probs, _, _ = hmc.kernel(step_size, 5, initial_draws, - log_gamma_log_prob, event_dims) - bad_sample, bad_acceptance_probs, _, _ = hmc.kernel( - step_size, 5, initial_draws, fake_log_prob, event_dims) - (acceptance_probs_val, bad_acceptance_probs_val, initial_draws_val, - updated_draws_val, fake_draws_val) = sess.run([acceptance_probs, - bad_acceptance_probs, - initial_draws, sample, - bad_sample], feed_dict) + sample, kernel_results = hmc.kernel( + target_log_prob_fn=log_gamma_log_prob, + current_state=initial_draws, + step_size=step_size, + num_leapfrog_steps=5, + seed=43) + + bad_sample, bad_kernel_results = hmc.kernel( + target_log_prob_fn=fake_log_prob, 
+ current_state=initial_draws, + step_size=step_size, + num_leapfrog_steps=5, + seed=44) + + [ + acceptance_probs_, + bad_acceptance_probs_, + initial_draws_, + updated_draws_, + fake_draws_, + ] = sess.run([ + kernel_results.acceptance_probs, + bad_kernel_results.acceptance_probs, + initial_draws, + sample, + bad_sample, + ], feed_dict) + # Confirm step size is small enough that we usually accept. - self.assertGreater(acceptance_probs_val.mean(), 0.5) - self.assertGreater(bad_acceptance_probs_val.mean(), 0.5) + self.assertGreater(acceptance_probs_.mean(), 0.5) + self.assertGreater(bad_acceptance_probs_.mean(), 0.5) + # Confirm step size is large enough that we sometimes reject. - self.assertLess(acceptance_probs_val.mean(), 0.99) - self.assertLess(bad_acceptance_probs_val.mean(), 0.99) - _, ks_p_value_true = stats.ks_2samp(initial_draws_val.flatten(), - updated_draws_val.flatten()) - _, ks_p_value_fake = stats.ks_2samp(initial_draws_val.flatten(), - fake_draws_val.flatten()) - logging.vlog(1, 'acceptance rate for true target: {}'.format( - acceptance_probs_val.mean())) - logging.vlog(1, 'acceptance rate for fake target: {}'.format( - bad_acceptance_probs_val.mean())) - logging.vlog(1, 'K-S p-value for true target: {}'.format(ks_p_value_true)) - logging.vlog(1, 'K-S p-value for fake target: {}'.format(ks_p_value_fake)) + self.assertLess(acceptance_probs_.mean(), 0.99) + self.assertLess(bad_acceptance_probs_.mean(), 0.99) + + _, ks_p_value_true = stats.ks_2samp(initial_draws_.flatten(), + updated_draws_.flatten()) + _, ks_p_value_fake = stats.ks_2samp(initial_draws_.flatten(), + fake_draws_.flatten()) + + logging_ops.vlog(1, "acceptance rate for true target: {}".format( + acceptance_probs_.mean())) + logging_ops.vlog(1, "acceptance rate for fake target: {}".format( + bad_acceptance_probs_.mean())) + logging_ops.vlog(1, "K-S p-value for true target: {}".format( + ks_p_value_true)) + logging_ops.vlog(1, "K-S p-value for fake target: {}".format( + ks_p_value_fake)) # 
Make sure that the MCMC update hasn't changed the empirical CDF much. self.assertGreater(ks_p_value_true, 1e-3) # Confirm that targeting the wrong distribution does # significantly change the empirical CDF. self.assertLess(ks_p_value_fake, 1e-6) - def _kernel_leaves_target_invariant_wrapper(self, event_dims): + def _kernel_leaves_target_invariant_wrapper(self, independent_chain_ndims): """Tests that the kernel leaves the target distribution invariant. Draws some independent samples from the target distribution, @@ -264,86 +435,429 @@ class HMCTest(test.TestCase): does change the target distribution. (And that we can detect that.) Args: - event_dims: A tuple of dimensions that should not be treated as - independent. This allows for multiple chains to be run independently - in parallel. Default is (), i.e., all dimensions are independent. + independent_chain_ndims: Python `int` scalar representing the number of + dims associated with independent chains. """ - with self.test_session() as sess: + with self.test_session(graph=ops.Graph()) as sess: initial_draws = np.log(np.random.gamma(self._shape_param, size=[50000, 2, 2])) initial_draws -= np.log(self._rate_param) - x_ph = array_ops.placeholder(np.float32, name='x_ph') + x_ph = array_ops.placeholder(np.float32, name="x_ph") feed_dict = {x_ph: initial_draws} - self._kernel_leaves_target_invariant(x_ph, event_dims, sess, - feed_dict) - - def testKernelLeavesTargetInvariantNullShape(self): - self._kernel_leaves_target_invariant_wrapper([]) + self._kernel_leaves_target_invariant(x_ph, independent_chain_ndims, + sess, feed_dict) def testKernelLeavesTargetInvariant1(self): - self._kernel_leaves_target_invariant_wrapper([1]) + self._kernel_leaves_target_invariant_wrapper(1) def testKernelLeavesTargetInvariant2(self): - self._kernel_leaves_target_invariant_wrapper([2]) + self._kernel_leaves_target_invariant_wrapper(2) - def testKernelLeavesTargetInvariant12(self): - self._kernel_leaves_target_invariant_wrapper([1, 2]) + def 
testKernelLeavesTargetInvariant3(self): + self._kernel_leaves_target_invariant_wrapper(3) + + def _ais_gets_correct_log_normalizer(self, init, independent_chain_ndims, + sess, feed_dict=None): + counter = collections.Counter() - def _ais_gets_correct_log_normalizer(self, init, event_dims, sess, - feed_dict=None): def proposal_log_prob(x): - return math_ops.reduce_sum(-0.5 * x * x - 0.5 * np.log(2*np.pi), - event_dims) + counter["proposal_calls"] += 1 + event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x)) + return -0.5 * math_ops.reduce_sum(x**2. + np.log(2 * np.pi), + axis=event_dims) def target_log_prob(x): + counter["target_calls"] += 1 + event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x)) return self._log_gamma_log_prob(x, event_dims) if feed_dict is None: feed_dict = {} - w, _, _ = hmc.ais_chain(200, 0.5, 2, init, target_log_prob, - proposal_log_prob, event_dims) - - w_val = sess.run(w, feed_dict) - init_shape = sess.run(init, feed_dict).shape - normalizer_multiplier = np.prod([init_shape[i] for i in event_dims]) - - true_normalizer = -self._shape_param * np.log(self._rate_param) - true_normalizer += special.gammaln(self._shape_param) - true_normalizer *= normalizer_multiplier - - n_weights = np.prod(w_val.shape) - normalized_w = np.exp(w_val - true_normalizer) - standard_error = np.std(normalized_w) / np.sqrt(n_weights) - logging.vlog(1, 'True normalizer {}, estimated {}, n_weights {}'.format( - true_normalizer, np.log(normalized_w.mean()) + true_normalizer, - n_weights)) - self.assertNear(normalized_w.mean(), 1.0, 4.0 * standard_error) - - def _ais_gets_correct_log_normalizer_wrapper(self, event_dims): + num_steps = 200 + + _, ais_weights, _ = hmc.sample_annealed_importance_chain( + proposal_log_prob_fn=proposal_log_prob, + num_steps=num_steps, + target_log_prob_fn=target_log_prob, + step_size=0.5, + current_state=init, + num_leapfrog_steps=2, + seed=45) + + # We have three calls because the calculation of `ais_weights` 
entails + # another call to the `convex_combined_log_prob_fn`. We could refactor + # things to avoid this, if needed (eg, b/72994218). + self.assertAllEqual(dict(target_calls=3, proposal_calls=3), counter) + + event_shape = array_ops.shape(init)[independent_chain_ndims:] + event_size = math_ops.reduce_prod(event_shape) + + log_true_normalizer = ( + -self._shape_param * math_ops.log(self._rate_param) + + math_ops.lgamma(self._shape_param)) + log_true_normalizer *= math_ops.cast(event_size, log_true_normalizer.dtype) + + log_estimated_normalizer = (math_ops.reduce_logsumexp(ais_weights) + - np.log(num_steps)) + + ratio_estimate_true = math_ops.exp(ais_weights - log_true_normalizer) + ais_weights_size = array_ops.size(ais_weights) + standard_error = math_ops.sqrt( + _reduce_variance(ratio_estimate_true) + / math_ops.cast(ais_weights_size, ratio_estimate_true.dtype)) + + [ + ratio_estimate_true_, + log_true_normalizer_, + log_estimated_normalizer_, + standard_error_, + ais_weights_size_, + event_size_, + ] = sess.run([ + ratio_estimate_true, + log_true_normalizer, + log_estimated_normalizer, + standard_error, + ais_weights_size, + event_size, + ], feed_dict) + + logging_ops.vlog(1, " log_true_normalizer: {}\n" + " log_estimated_normalizer: {}\n" + " ais_weights_size: {}\n" + " event_size: {}\n".format( + log_true_normalizer_, + log_estimated_normalizer_, + ais_weights_size_, + event_size_)) + self.assertNear(ratio_estimate_true_.mean(), 1., 4. 
* standard_error_) + + def _ais_gets_correct_log_normalizer_wrapper(self, independent_chain_ndims): """Tests that AIS yields reasonable estimates of normalizers.""" - with self.test_session() as sess: - x_ph = array_ops.placeholder(np.float32, name='x_ph') - + with self.test_session(graph=ops.Graph()) as sess: + x_ph = array_ops.placeholder(np.float32, name="x_ph") initial_draws = np.random.normal(size=[30, 2, 1]) - feed_dict = {x_ph: initial_draws} - - self._ais_gets_correct_log_normalizer(x_ph, event_dims, sess, - feed_dict) - - def testAISNullShape(self): - self._ais_gets_correct_log_normalizer_wrapper([]) + self._ais_gets_correct_log_normalizer( + x_ph, + independent_chain_ndims, + sess, + feed_dict={x_ph: initial_draws}) def testAIS1(self): - self._ais_gets_correct_log_normalizer_wrapper([1]) + self._ais_gets_correct_log_normalizer_wrapper(1) def testAIS2(self): - self._ais_gets_correct_log_normalizer_wrapper([2]) - - def testAIS12(self): - self._ais_gets_correct_log_normalizer_wrapper([1, 2]) - -if __name__ == '__main__': + self._ais_gets_correct_log_normalizer_wrapper(2) + + def testAIS3(self): + self._ais_gets_correct_log_normalizer_wrapper(3) + + def testSampleAIChainSeedReproducibleWorksCorrectly(self): + with self.test_session(graph=ops.Graph()) as sess: + independent_chain_ndims = 1 + x = np.random.rand(4, 3, 2) + + def proposal_log_prob(x): + event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x)) + return -0.5 * math_ops.reduce_sum(x**2. 
+ np.log(2 * np.pi), + axis=event_dims) + + def target_log_prob(x): + event_dims = math_ops.range(independent_chain_ndims, array_ops.rank(x)) + return self._log_gamma_log_prob(x, event_dims) + + ais_kwargs = dict( + proposal_log_prob_fn=proposal_log_prob, + num_steps=200, + target_log_prob_fn=target_log_prob, + step_size=0.5, + current_state=x, + num_leapfrog_steps=2, + seed=53) + + _, ais_weights0, _ = hmc.sample_annealed_importance_chain( + **ais_kwargs) + + _, ais_weights1, _ = hmc.sample_annealed_importance_chain( + **ais_kwargs) + + [ais_weights0_, ais_weights1_] = sess.run([ + ais_weights0, ais_weights1]) + + self.assertAllClose(ais_weights0_, ais_weights1_, + atol=1e-5, rtol=1e-5) + + def testNanRejection(self): + """Tests that an update that yields NaN potentials gets rejected. + + We run HMC with a target distribution that returns NaN + log-likelihoods if any element of x < 0, and unit-scale + exponential log-likelihoods otherwise. The exponential potential + pushes x towards 0, ensuring that any reasonably large update will + push us over the edge into NaN territory. 
+ """ + def _unbounded_exponential_log_prob(x): + """An exponential distribution with log-likelihood NaN for x < 0.""" + per_element_potentials = array_ops.where( + x < 0., + array_ops.fill(array_ops.shape(x), x.dtype.as_numpy_dtype(np.nan)), + -x) + return math_ops.reduce_sum(per_element_potentials) + + with self.test_session(graph=ops.Graph()) as sess: + initial_x = math_ops.linspace(0.01, 5, 10) + updated_x, kernel_results = hmc.kernel( + target_log_prob_fn=_unbounded_exponential_log_prob, + current_state=initial_x, + step_size=2., + num_leapfrog_steps=5, + seed=46) + initial_x_, updated_x_, acceptance_probs_ = sess.run( + [initial_x, updated_x, kernel_results.acceptance_probs]) + + logging_ops.vlog(1, "initial_x = {}".format(initial_x_)) + logging_ops.vlog(1, "updated_x = {}".format(updated_x_)) + logging_ops.vlog(1, "acceptance_probs = {}".format(acceptance_probs_)) + + self.assertAllEqual(initial_x_, updated_x_) + self.assertEqual(acceptance_probs_, 0.) + + def testNanFromGradsDontPropagate(self): + """Test that update with NaN gradients does not cause NaN in results.""" + def _nan_log_prob_with_nan_gradient(x): + return np.nan * math_ops.reduce_sum(x) + + with self.test_session(graph=ops.Graph()) as sess: + initial_x = math_ops.linspace(0.01, 5, 10) + updated_x, kernel_results = hmc.kernel( + target_log_prob_fn=_nan_log_prob_with_nan_gradient, + current_state=initial_x, + step_size=2., + num_leapfrog_steps=5, + seed=47) + initial_x_, updated_x_, acceptance_probs_ = sess.run( + [initial_x, updated_x, kernel_results.acceptance_probs]) + + logging_ops.vlog(1, "initial_x = {}".format(initial_x_)) + logging_ops.vlog(1, "updated_x = {}".format(updated_x_)) + logging_ops.vlog(1, "acceptance_probs = {}".format(acceptance_probs_)) + + self.assertAllEqual(initial_x_, updated_x_) + self.assertEqual(acceptance_probs_, 0.) 
+ + self.assertAllFinite( + gradients_ops.gradients(updated_x, initial_x)[0].eval()) + self.assertAllEqual([True], [g is None for g in gradients_ops.gradients( + kernel_results.proposed_grads_target_log_prob, initial_x)]) + self.assertAllEqual([False], [g is None for g in gradients_ops.gradients( + kernel_results.proposed_grads_target_log_prob, + kernel_results.proposed_state)]) + + # Gradients of the acceptance probs and new log prob are not finite. + # self.assertAllFinite( + # gradients_ops.gradients(acceptance_probs, initial_x)[0].eval()) + # self.assertAllFinite( + # gradients_ops.gradients(new_log_prob, initial_x)[0].eval()) + + def _testChainWorksDtype(self, dtype): + with self.test_session(graph=ops.Graph()) as sess: + states, kernel_results = hmc.sample_chain( + num_results=10, + target_log_prob_fn=lambda x: -math_ops.reduce_sum(x**2., axis=-1), + current_state=np.zeros(5).astype(dtype), + step_size=0.01, + num_leapfrog_steps=10, + seed=48) + states_, acceptance_probs_ = sess.run( + [states, kernel_results.acceptance_probs]) + self.assertEqual(dtype, states_.dtype) + self.assertEqual(dtype, acceptance_probs_.dtype) + + def testChainWorksIn64Bit(self): + self._testChainWorksDtype(np.float64) + + def testChainWorksIn16Bit(self): + self._testChainWorksDtype(np.float16) + + def testChainWorksCorrelatedMultivariate(self): + dtype = np.float32 + true_mean = dtype([0, 0]) + true_cov = dtype([[1, 0.5], + [0.5, 1]]) + num_results = 2000 + counter = collections.Counter() + with self.test_session(graph=ops.Graph()) as sess: + def target_log_prob(x, y): + counter["target_calls"] += 1 + # Corresponds to unnormalized MVN. 
+ # z = matmul(inv(chol(true_cov)), [x, y] - true_mean) + z = array_ops.stack([x, y], axis=-1) - true_mean + z = array_ops.squeeze( + gen_linalg_ops.matrix_triangular_solve( + np.linalg.cholesky(true_cov), + z[..., array_ops.newaxis]), + axis=-1) + return -0.5 * math_ops.reduce_sum(z**2., axis=-1) + states, _ = hmc.sample_chain( + num_results=num_results, + target_log_prob_fn=target_log_prob, + current_state=[dtype(-2), dtype(2)], + step_size=[0.5, 0.5], + num_leapfrog_steps=2, + num_burnin_steps=200, + num_steps_between_results=1, + seed=54) + self.assertAllEqual(dict(target_calls=2), counter) + states = array_ops.stack(states, axis=-1) + self.assertEqual(num_results, states.shape[0].value) + sample_mean = math_ops.reduce_mean(states, axis=0) + x = states - sample_mean + sample_cov = math_ops.matmul(x, x, transpose_a=True) / dtype(num_results) + [sample_mean_, sample_cov_] = sess.run([ + sample_mean, sample_cov]) + self.assertAllClose(true_mean, sample_mean_, + atol=0.05, rtol=0.) + self.assertAllClose(true_cov, sample_cov_, + atol=0., rtol=0.1) + + +class _EnergyComputationTest(object): + + def testHandlesNanFromPotential(self): + with self.test_session(graph=ops.Graph()) as sess: + x = [1, np.inf, -np.inf, np.nan] + target_log_prob, proposed_target_log_prob = [ + self.dtype(x.flatten()) for x in np.meshgrid(x, x)] + num_chains = len(target_log_prob) + dummy_momentums = [-1, 1] + momentums = [self.dtype([dummy_momentums] * num_chains)] + proposed_momentums = [self.dtype([dummy_momentums] * num_chains)] + + target_log_prob = ops.convert_to_tensor(target_log_prob) + momentums = [ops.convert_to_tensor(momentums[0])] + proposed_target_log_prob = ops.convert_to_tensor(proposed_target_log_prob) + proposed_momentums = [ops.convert_to_tensor(proposed_momentums[0])] + + energy = _compute_energy_change( + target_log_prob, + momentums, + proposed_target_log_prob, + proposed_momentums, + independent_chain_ndims=1) + grads = gradients_ops.gradients(energy, momentums) + + 
[actual_energy, grads_] = sess.run([energy, grads]) + + # Ensure energy is `inf` (note: that's positive inf) in weird cases and + # finite otherwise. + expected_energy = self.dtype([0] + [np.inf]*(num_chains - 1)) + self.assertAllEqual(expected_energy, actual_energy) + + # Ensure gradient is finite. + self.assertAllEqual(np.ones_like(grads_).astype(np.bool), + np.isfinite(grads_)) + + def testHandlesNanFromKinetic(self): + with self.test_session(graph=ops.Graph()) as sess: + x = [1, np.inf, -np.inf, np.nan] + momentums, proposed_momentums = [ + [np.reshape(self.dtype(x), [-1, 1])] + for x in np.meshgrid(x, x)] + num_chains = len(momentums[0]) + target_log_prob = np.ones(num_chains, self.dtype) + proposed_target_log_prob = np.ones(num_chains, self.dtype) + + target_log_prob = ops.convert_to_tensor(target_log_prob) + momentums = [ops.convert_to_tensor(momentums[0])] + proposed_target_log_prob = ops.convert_to_tensor(proposed_target_log_prob) + proposed_momentums = [ops.convert_to_tensor(proposed_momentums[0])] + + energy = _compute_energy_change( + target_log_prob, + momentums, + proposed_target_log_prob, + proposed_momentums, + independent_chain_ndims=1) + grads = gradients_ops.gradients(energy, momentums) + + [actual_energy, grads_] = sess.run([energy, grads]) + + # Ensure energy is `inf` (note: that's positive inf) in weird cases and + # finite otherwise. + expected_energy = self.dtype([0] + [np.inf]*(num_chains - 1)) + self.assertAllEqual(expected_energy, actual_energy) + + # Ensure gradient is finite. + g = grads_[0].reshape([len(x), len(x)])[:, 0] + self.assertAllEqual(np.ones_like(g).astype(np.bool), np.isfinite(g)) + + # The remaining gradients are nan because the momentum was itself nan or + # inf. 
+ g = grads_[0].reshape([len(x), len(x)])[:, 1:] + self.assertAllEqual(np.ones_like(g).astype(np.bool), np.isnan(g)) + + +class EnergyComputationTest16(test.TestCase, _EnergyComputationTest): + dtype = np.float16 + + +class EnergyComputationTest32(test.TestCase, _EnergyComputationTest): + dtype = np.float32 + + +class EnergyComputationTest64(test.TestCase, _EnergyComputationTest): + dtype = np.float64 + + +class _HMCHandlesLists(object): + + def testStateParts(self): + with self.test_session(graph=ops.Graph()) as sess: + dist_x = normal_lib.Normal(loc=self.dtype(0), scale=self.dtype(1)) + dist_y = independent_lib.Independent( + gamma_lib.Gamma(concentration=self.dtype([1, 2]), + rate=self.dtype([0.5, 0.75])), + reinterpreted_batch_ndims=1) + def target_log_prob(x, y): + return dist_x.log_prob(x) + dist_y.log_prob(y) + x0 = [dist_x.sample(seed=1), dist_y.sample(seed=2)] + samples, _ = hmc.sample_chain( + num_results=int(2e3), + target_log_prob_fn=target_log_prob, + current_state=x0, + step_size=0.85, + num_leapfrog_steps=3, + num_burnin_steps=int(250), + seed=49) + actual_means = [math_ops.reduce_mean(s, axis=0) for s in samples] + actual_vars = [_reduce_variance(s, axis=0) for s in samples] + expected_means = [dist_x.mean(), dist_y.mean()] + expected_vars = [dist_x.variance(), dist_y.variance()] + [ + actual_means_, + actual_vars_, + expected_means_, + expected_vars_, + ] = sess.run([ + actual_means, + actual_vars, + expected_means, + expected_vars, + ]) + self.assertAllClose(expected_means_, actual_means_, atol=0.05, rtol=0.16) + self.assertAllClose(expected_vars_, actual_vars_, atol=0., rtol=0.25) + + +class HMCHandlesLists32(_HMCHandlesLists, test.TestCase): + dtype = np.float32 + + +class HMCHandlesLists64(_HMCHandlesLists, test.TestCase): + dtype = np.float64 + + +if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/layers_conv_variational_test.py 
b/tensorflow/contrib/bayesflow/python/kernel_tests/layers_conv_variational_test.py new file mode 100644 index 0000000000000000000000000000000000000000..750afb6654311fea30a1dc6b31b20aa3b4160ae2 --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/layers_conv_variational_test.py @@ -0,0 +1,521 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for convolutional Bayesian layers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.bayesflow.python.ops import layers_conv_variational as prob_layers_lib +from tensorflow.contrib.bayesflow.python.ops import layers_util as prob_layers_util +from tensorflow.contrib.distributions.python.ops import independent as independent_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops.distributions import normal as normal_lib +from tensorflow.python.ops.distributions import util as distribution_util +from 
tensorflow.python.platform import test + + +class Counter(object): + """Helper class to manage incrementing a counting `int`.""" + + def __init__(self): + self._value = -1 + + @property + def value(self): + return self._value + + def __call__(self): + self._value += 1 + return self._value + + +class MockDistribution(independent_lib.Independent): + """Monitors layer calls to the underlying distribution.""" + + def __init__(self, result_sample, result_log_prob, loc=None, scale=None): + self.result_sample = result_sample + self.result_log_prob = result_log_prob + self.result_loc = loc + self.result_scale = scale + self.result_distribution = normal_lib.Normal(loc=0.0, scale=1.0) + if loc is not None and scale is not None: + self.result_distribution = normal_lib.Normal(loc=self.result_loc, + scale=self.result_scale) + self.called_log_prob = Counter() + self.called_sample = Counter() + self.called_loc = Counter() + self.called_scale = Counter() + + def log_prob(self, *args, **kwargs): + self.called_log_prob() + return self.result_log_prob + + def sample(self, *args, **kwargs): + self.called_sample() + return self.result_sample + + @property + def distribution(self): # for dummy check on Independent(Normal) + return self.result_distribution + + @property + def loc(self): + self.called_loc() + return self.result_loc + + @property + def scale(self): + self.called_scale() + return self.result_scale + + +class MockKLDivergence(object): + """Monitors layer calls to the divergence implementation.""" + + def __init__(self, result): + self.result = result + self.args = [] + self.called = Counter() + + def __call__(self, *args, **kwargs): + self.called() + self.args.append(args) + return self.result + + +class ConvVariational(test.TestCase): + + def _testKLPenaltyKernel(self, layer_class): + with self.test_session(): + layer = layer_class(filters=2, kernel_size=3) + if layer_class in (prob_layers_lib.Conv1DReparameterization, + prob_layers_lib.Conv1DFlipout): + inputs = 
random_ops.random_uniform([2, 3, 1], seed=1) + elif layer_class in (prob_layers_lib.Conv2DReparameterization, + prob_layers_lib.Conv2DFlipout): + inputs = random_ops.random_uniform([2, 3, 3, 1], seed=1) + elif layer_class in (prob_layers_lib.Conv3DReparameterization, + prob_layers_lib.Conv3DFlipout): + inputs = random_ops.random_uniform([2, 3, 3, 3, 1], seed=1) + + # No keys. + losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(len(losses), 0) + self.assertListEqual(layer.losses, losses) + + _ = layer(inputs) + + # Yes keys. + losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(len(losses), 1) + self.assertListEqual(layer.losses, losses) + + def _testKLPenaltyBoth(self, layer_class): + def _make_normal(dtype, *args): # pylint: disable=unused-argument + return normal_lib.Normal( + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)) + with self.test_session(): + layer = layer_class( + filters=2, + kernel_size=3, + bias_posterior_fn=prob_layers_util.default_mean_field_normal_fn(), + bias_prior_fn=_make_normal) + if layer_class in (prob_layers_lib.Conv1DReparameterization, + prob_layers_lib.Conv1DFlipout): + inputs = random_ops.random_uniform([2, 3, 1], seed=1) + elif layer_class in (prob_layers_lib.Conv2DReparameterization, + prob_layers_lib.Conv2DFlipout): + inputs = random_ops.random_uniform([2, 3, 3, 1], seed=1) + elif layer_class in (prob_layers_lib.Conv3DReparameterization, + prob_layers_lib.Conv3DFlipout): + inputs = random_ops.random_uniform([2, 3, 3, 3, 1], seed=1) + + # No keys. + losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(len(losses), 0) + self.assertListEqual(layer.losses, losses) + + _ = layer(inputs) + + # Yes keys. 
+ losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(len(losses), 2) + self.assertListEqual(layer.losses, losses) + + def _testConvSetUp(self, layer_class, batch_size, depth=None, + height=None, width=None, channels=None, filters=None, + **kwargs): + seed = Counter() + if layer_class in (prob_layers_lib.Conv1DReparameterization, + prob_layers_lib.Conv1DFlipout): + inputs = random_ops.random_uniform( + [batch_size, width, channels], seed=seed()) + kernel_size = (2,) + elif layer_class in (prob_layers_lib.Conv2DReparameterization, + prob_layers_lib.Conv2DFlipout): + inputs = random_ops.random_uniform( + [batch_size, height, width, channels], seed=seed()) + kernel_size = (2, 2) + elif layer_class in (prob_layers_lib.Conv3DReparameterization, + prob_layers_lib.Conv3DFlipout): + inputs = random_ops.random_uniform( + [batch_size, depth, height, width, channels], seed=seed()) + kernel_size = (2, 2, 2) + + kernel_shape = kernel_size + (channels, filters) + kernel_posterior = MockDistribution( + loc=random_ops.random_uniform(kernel_shape, seed=seed()), + scale=random_ops.random_uniform(kernel_shape, seed=seed()), + result_log_prob=random_ops.random_uniform(kernel_shape, seed=seed()), + result_sample=random_ops.random_uniform(kernel_shape, seed=seed())) + kernel_prior = MockDistribution( + result_log_prob=random_ops.random_uniform(kernel_shape, seed=seed()), + result_sample=random_ops.random_uniform(kernel_shape, seed=seed())) + kernel_divergence = MockKLDivergence( + result=random_ops.random_uniform(kernel_shape, seed=seed())) + + bias_size = (filters,) + bias_posterior = MockDistribution( + result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), + result_sample=random_ops.random_uniform(bias_size, seed=seed())) + bias_prior = MockDistribution( + result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), + result_sample=random_ops.random_uniform(bias_size, seed=seed())) + bias_divergence = MockKLDivergence( + 
result=random_ops.random_uniform(bias_size, seed=seed())) + + layer = layer_class( + filters=filters, + kernel_size=kernel_size, + padding="SAME", + kernel_posterior_fn=lambda *args: kernel_posterior, + kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), + kernel_prior_fn=lambda *args: kernel_prior, + kernel_divergence_fn=kernel_divergence, + bias_posterior_fn=lambda *args: bias_posterior, + bias_posterior_tensor_fn=lambda d: d.sample(seed=43), + bias_prior_fn=lambda *args: bias_prior, + bias_divergence_fn=bias_divergence, + **kwargs) + + outputs = layer(inputs) + + kl_penalty = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + return (kernel_posterior, kernel_prior, kernel_divergence, + bias_posterior, bias_prior, bias_divergence, + layer, inputs, outputs, kl_penalty, kernel_shape) + + def _testConvReparameterization(self, layer_class): + batch_size, depth, height, width, channels, filters = 2, 4, 4, 4, 3, 5 + with self.test_session() as sess: + (kernel_posterior, kernel_prior, kernel_divergence, + bias_posterior, bias_prior, bias_divergence, layer, inputs, + outputs, kl_penalty, kernel_shape) = self._testConvSetUp( + layer_class, batch_size, + depth=depth, height=height, width=width, channels=channels, + filters=filters) + + convolution_op = nn_ops.Convolution( + tensor_shape.TensorShape(inputs.shape), + filter_shape=tensor_shape.TensorShape(kernel_shape), + padding="SAME") + expected_outputs = convolution_op(inputs, kernel_posterior.result_sample) + expected_outputs = nn.bias_add(expected_outputs, + bias_posterior.result_sample, + data_format="NHWC") + + [ + expected_outputs_, actual_outputs_, + expected_kernel_, actual_kernel_, + expected_kernel_divergence_, actual_kernel_divergence_, + expected_bias_, actual_bias_, + expected_bias_divergence_, actual_bias_divergence_, + ] = sess.run([ + expected_outputs, outputs, + kernel_posterior.result_sample, layer.kernel_posterior_tensor, + kernel_divergence.result, kl_penalty[0], + 
bias_posterior.result_sample, layer.bias_posterior_tensor, + bias_divergence.result, kl_penalty[1], + ]) + + self.assertAllClose( + expected_kernel_, actual_kernel_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_bias_, actual_bias_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_outputs_, actual_outputs_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_kernel_divergence_, actual_kernel_divergence_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_bias_divergence_, actual_bias_divergence_, + rtol=1e-6, atol=0.) + + self.assertAllEqual( + [[kernel_posterior.distribution, + kernel_prior.distribution, + kernel_posterior.result_sample]], + kernel_divergence.args) + + self.assertAllEqual( + [[bias_posterior.distribution, + bias_prior.distribution, + bias_posterior.result_sample]], + bias_divergence.args) + + def _testConvFlipout(self, layer_class): + batch_size, depth, height, width, channels, filters = 2, 4, 4, 4, 3, 5 + with self.test_session() as sess: + (kernel_posterior, kernel_prior, kernel_divergence, + bias_posterior, bias_prior, bias_divergence, layer, inputs, + outputs, kl_penalty, kernel_shape) = self._testConvSetUp( + layer_class, batch_size, + depth=depth, height=height, width=width, channels=channels, + filters=filters, seed=44) + + convolution_op = nn_ops.Convolution( + tensor_shape.TensorShape(inputs.shape), + filter_shape=tensor_shape.TensorShape(kernel_shape), + padding="SAME") + + expected_kernel_posterior_affine = normal_lib.Normal( + loc=array_ops.zeros_like(kernel_posterior.result_loc), + scale=kernel_posterior.result_scale) + expected_kernel_posterior_affine_tensor = ( + expected_kernel_posterior_affine.sample(seed=42)) + + expected_outputs = convolution_op( + inputs, kernel_posterior.distribution.loc) + + input_shape = array_ops.shape(inputs) + output_shape = array_ops.shape(expected_outputs) + batch_shape = array_ops.expand_dims(input_shape[0], 0) + channels = input_shape[-1] + rank = len(inputs.get_shape()) - 2 
+ + sign_input = random_ops.random_uniform( + array_ops.concat([batch_shape, + array_ops.expand_dims(channels, 0)], 0), + minval=0, + maxval=2, + dtype=dtypes.int32, + seed=layer.seed) + sign_input = math_ops.cast(2 * sign_input - 1, inputs.dtype) + sign_output = random_ops.random_uniform( + array_ops.concat([batch_shape, + array_ops.expand_dims(filters, 0)], 0), + minval=0, + maxval=2, + dtype=dtypes.int32, + seed=distribution_util.gen_new_seed( + layer.seed, salt="conv_flipout")) + sign_output = math_ops.cast(2 * sign_output - 1, inputs.dtype) + for _ in range(rank): + sign_input = array_ops.expand_dims(sign_input, 1) # 2D ex: (B, 1, 1, C) + sign_output = array_ops.expand_dims(sign_output, 1) + + sign_input = array_ops.tile( # tile for element-wise op broadcasting + sign_input, + [1] + [input_shape[i + 1] for i in range(rank)] + [1]) + sign_output = array_ops.tile( + sign_output, + [1] + [output_shape[i + 1] for i in range(rank)] + [1]) + + perturbed_inputs = convolution_op( + inputs * sign_input, expected_kernel_posterior_affine_tensor) + perturbed_inputs *= sign_output + + expected_outputs += perturbed_inputs + expected_outputs = nn.bias_add(expected_outputs, + bias_posterior.result_sample, + data_format="NHWC") + + [ + expected_outputs_, actual_outputs_, + expected_kernel_divergence_, actual_kernel_divergence_, + expected_bias_, actual_bias_, + expected_bias_divergence_, actual_bias_divergence_, + ] = sess.run([ + expected_outputs, outputs, + kernel_divergence.result, kl_penalty[0], + bias_posterior.result_sample, layer.bias_posterior_tensor, + bias_divergence.result, kl_penalty[1], + ]) + + self.assertAllClose( + expected_bias_, actual_bias_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_outputs_, actual_outputs_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_kernel_divergence_, actual_kernel_divergence_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_bias_divergence_, actual_bias_divergence_, + rtol=1e-6, atol=0.) 
+ + self.assertAllEqual( + [[kernel_posterior.distribution, kernel_prior.distribution, None]], + kernel_divergence.args) + + self.assertAllEqual( + [[bias_posterior.distribution, + bias_prior.distribution, + bias_posterior.result_sample]], + bias_divergence.args) + + def _testRandomConvFlipout(self, layer_class): + batch_size, depth, height, width, channels, filters = 2, 4, 4, 4, 3, 5 + with self.test_session() as sess: + seed = Counter() + if layer_class in (prob_layers_lib.Conv1DReparameterization, + prob_layers_lib.Conv1DFlipout): + inputs = random_ops.random_uniform( + [batch_size, width, channels], seed=seed()) + kernel_size = (2,) + elif layer_class in (prob_layers_lib.Conv2DReparameterization, + prob_layers_lib.Conv2DFlipout): + inputs = random_ops.random_uniform( + [batch_size, height, width, channels], seed=seed()) + kernel_size = (2, 2) + elif layer_class in (prob_layers_lib.Conv3DReparameterization, + prob_layers_lib.Conv3DFlipout): + inputs = random_ops.random_uniform( + [batch_size, depth, height, width, channels], seed=seed()) + kernel_size = (2, 2, 2) + + kernel_shape = kernel_size + (channels, filters) + bias_size = (filters,) + + kernel_posterior = MockDistribution( + loc=random_ops.random_uniform( + kernel_shape, seed=seed()), + scale=random_ops.random_uniform( + kernel_shape, seed=seed()), + result_log_prob=random_ops.random_uniform( + kernel_shape, seed=seed()), + result_sample=random_ops.random_uniform( + kernel_shape, seed=seed())) + bias_posterior = MockDistribution( + loc=random_ops.random_uniform( + bias_size, seed=seed()), + scale=random_ops.random_uniform( + bias_size, seed=seed()), + result_log_prob=random_ops.random_uniform( + bias_size, seed=seed()), + result_sample=random_ops.random_uniform( + bias_size, seed=seed())) + layer_one = layer_class( + filters=filters, + kernel_size=kernel_size, + padding="SAME", + kernel_posterior_fn=lambda *args: kernel_posterior, + kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), + 
bias_posterior_fn=lambda *args: bias_posterior, + bias_posterior_tensor_fn=lambda d: d.sample(seed=43), + seed=44) + layer_two = layer_class( + filters=filters, + kernel_size=kernel_size, + padding="SAME", + kernel_posterior_fn=lambda *args: kernel_posterior, + kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), + bias_posterior_fn=lambda *args: bias_posterior, + bias_posterior_tensor_fn=lambda d: d.sample(seed=43), + seed=45) + + outputs_one = layer_one(inputs) + outputs_two = layer_two(inputs) + + outputs_one_, outputs_two_ = sess.run([ + outputs_one, outputs_two]) + + self.assertLess(np.sum(np.isclose(outputs_one_, outputs_two_)), + np.prod(outputs_one_.shape)) + + def testKLPenaltyKernelConv1DReparameterization(self): + self._testKLPenaltyKernel(prob_layers_lib.Conv1DReparameterization) + + def testKLPenaltyKernelConv2DReparameterization(self): + self._testKLPenaltyKernel(prob_layers_lib.Conv2DReparameterization) + + def testKLPenaltyKernelConv3DReparameterization(self): + self._testKLPenaltyKernel(prob_layers_lib.Conv3DReparameterization) + + def testKLPenaltyKernelConv1DFlipout(self): + self._testKLPenaltyKernel(prob_layers_lib.Conv1DFlipout) + + def testKLPenaltyKernelConv2DFlipout(self): + self._testKLPenaltyKernel(prob_layers_lib.Conv2DFlipout) + + def testKLPenaltyKernelConv3DFlipout(self): + self._testKLPenaltyKernel(prob_layers_lib.Conv3DFlipout) + + def testKLPenaltyBothConv1DReparameterization(self): + self._testKLPenaltyBoth(prob_layers_lib.Conv1DReparameterization) + + def testKLPenaltyBothConv2DReparameterization(self): + self._testKLPenaltyBoth(prob_layers_lib.Conv2DReparameterization) + + def testKLPenaltyBothConv3DReparameterization(self): + self._testKLPenaltyBoth(prob_layers_lib.Conv3DReparameterization) + + def testKLPenaltyBothConv1DFlipout(self): + self._testKLPenaltyBoth(prob_layers_lib.Conv1DFlipout) + + def testKLPenaltyBothConv2DFlipout(self): + self._testKLPenaltyBoth(prob_layers_lib.Conv2DFlipout) + + def 
testKLPenaltyBothConv3DFlipout(self): + self._testKLPenaltyBoth(prob_layers_lib.Conv3DFlipout) + + def testConv1DReparameterization(self): + self._testConvReparameterization(prob_layers_lib.Conv1DReparameterization) + + def testConv2DReparameterization(self): + self._testConvReparameterization(prob_layers_lib.Conv2DReparameterization) + + def testConv3DReparameterization(self): + self._testConvReparameterization(prob_layers_lib.Conv3DReparameterization) + + def testConv1DFlipout(self): + self._testConvFlipout(prob_layers_lib.Conv1DFlipout) + + def testConv2DFlipout(self): + self._testConvFlipout(prob_layers_lib.Conv2DFlipout) + + def testConv3DFlipout(self): + self._testConvFlipout(prob_layers_lib.Conv3DFlipout) + + def testRandomConv1DFlipout(self): + self._testRandomConvFlipout(prob_layers_lib.Conv1DFlipout) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py index 50358fd1c2b7635ffe2d08c5af3219bb0a11498b..342f38ccec7ec74db1b393d6cdc22300205cc547 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/layers_dense_variational_test.py @@ -18,11 +18,18 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.bayesflow.python.ops import layers_dense_variational_impl as prob_layers_lib +import numpy as np + +from tensorflow.contrib.bayesflow.python.ops import layers_dense_variational as prob_layers_lib +from tensorflow.contrib.bayesflow.python.ops import layers_util as prob_layers_util +from tensorflow.contrib.distributions.python.ops import independent as independent_lib +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from 
tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import normal as normal_lib +from tensorflow.python.ops.distributions import util as distribution_util from tensorflow.python.platform import test @@ -41,14 +48,18 @@ class Counter(object): return self._value -class MockDistribution(normal_lib.Normal): - """Monitors DenseVariational calls to the underlying distribution.""" +class MockDistribution(independent_lib.Independent): + """Monitors layer calls to the underlying distribution.""" def __init__(self, result_sample, result_log_prob, loc=None, scale=None): self.result_sample = result_sample self.result_log_prob = result_log_prob self.result_loc = loc self.result_scale = scale + self.result_distribution = normal_lib.Normal(loc=0.0, scale=1.0) + if loc is not None and scale is not None: + self.result_distribution = normal_lib.Normal(loc=self.result_loc, + scale=self.result_scale) self.called_log_prob = Counter() self.called_sample = Counter() self.called_loc = Counter() @@ -62,6 +73,10 @@ class MockDistribution(normal_lib.Normal): self.called_sample() return self.result_sample + @property + def distribution(self): # for dummy check on Independent(Normal) + return self.result_distribution + @property def loc(self): self.called_loc() @@ -74,7 +89,7 @@ class MockDistribution(normal_lib.Normal): class MockKLDivergence(object): - """Monitors DenseVariational calls to the divergence implementation.""" + """Monitors layer calls to the divergence implementation.""" def __init__(self, result): self.result = result @@ -87,94 +102,125 @@ class MockKLDivergence(object): return self.result -class DenseVariationalLocalReparametrization(test.TestCase): +class DenseVariational(test.TestCase): - def testKLPenaltyKernel(self): + def _testKLPenaltyKernel(self, layer_class): with self.test_session(): - dense_vi = prob_layers_lib.DenseVariational(units=2) + layer = layer_class(units=2) inputs = 
random_ops.random_uniform([2, 3], seed=1) # No keys. - loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - self.assertEqual(len(loss_keys), 0) - self.assertListEqual(dense_vi.losses, loss_keys) + losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(len(losses), 0) + self.assertListEqual(layer.losses, losses) - _ = dense_vi(inputs) + _ = layer(inputs) # Yes keys. - loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - self.assertEqual(len(loss_keys), 1) - self.assertListEqual(dense_vi.losses, loss_keys) + losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(len(losses), 1) + self.assertListEqual(layer.losses, losses) - def testKLPenaltyBoth(self): + def _testKLPenaltyBoth(self, layer_class): def _make_normal(dtype, *args): # pylint: disable=unused-argument return normal_lib.Normal( loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)) with self.test_session(): - dense_vi = prob_layers_lib.DenseVariational( + layer = layer_class( units=2, - bias_posterior_fn=prob_layers_lib.default_mean_field_normal_fn(), + bias_posterior_fn=prob_layers_util.default_mean_field_normal_fn(), bias_prior_fn=_make_normal) inputs = random_ops.random_uniform([2, 3], seed=1) # No keys. - loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - self.assertEqual(len(loss_keys), 0) - self.assertListEqual(dense_vi.losses, loss_keys) + losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(len(losses), 0) + self.assertListEqual(layer.losses, losses) - _ = dense_vi(inputs) + _ = layer(inputs) # Yes keys. 
- loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - self.assertEqual(len(loss_keys), 2) - self.assertListEqual(dense_vi.losses, loss_keys) - - def testVariationalNonLocal(self): + losses = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + self.assertEqual(len(losses), 2) + self.assertListEqual(layer.losses, losses) + + def _testDenseSetUp(self, layer_class, batch_size, in_size, out_size, + **kwargs): + seed = Counter() + inputs = random_ops.random_uniform([batch_size, in_size], seed=seed()) + + kernel_size = [in_size, out_size] + kernel_posterior = MockDistribution( + loc=random_ops.random_uniform(kernel_size, seed=seed()), + scale=random_ops.random_uniform(kernel_size, seed=seed()), + result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()), + result_sample=random_ops.random_uniform(kernel_size, seed=seed())) + kernel_prior = MockDistribution( + result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()), + result_sample=random_ops.random_uniform(kernel_size, seed=seed())) + kernel_divergence = MockKLDivergence( + result=random_ops.random_uniform(kernel_size, seed=seed())) + + bias_size = [out_size] + bias_posterior = MockDistribution( + result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), + result_sample=random_ops.random_uniform(bias_size, seed=seed())) + bias_prior = MockDistribution( + result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), + result_sample=random_ops.random_uniform(bias_size, seed=seed())) + bias_divergence = MockKLDivergence( + result=random_ops.random_uniform(bias_size, seed=seed())) + + layer = layer_class( + units=out_size, + kernel_posterior_fn=lambda *args: kernel_posterior, + kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), + kernel_prior_fn=lambda *args: kernel_prior, + kernel_divergence_fn=kernel_divergence, + bias_posterior_fn=lambda *args: bias_posterior, + bias_posterior_tensor_fn=lambda d: d.sample(seed=43), + bias_prior_fn=lambda *args: bias_prior, + 
bias_divergence_fn=bias_divergence, + **kwargs) + + outputs = layer(inputs) + + kl_penalty = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + return (kernel_posterior, kernel_prior, kernel_divergence, + bias_posterior, bias_prior, bias_divergence, + layer, inputs, outputs, kl_penalty) + + def testKLPenaltyKernelReparameterization(self): + self._testKLPenaltyKernel(prob_layers_lib.DenseReparameterization) + + def testKLPenaltyKernelLocalReparameterization(self): + self._testKLPenaltyKernel(prob_layers_lib.DenseLocalReparameterization) + + def testKLPenaltyKernelFlipout(self): + self._testKLPenaltyKernel(prob_layers_lib.DenseFlipout) + + def testKLPenaltyBothReparameterization(self): + self._testKLPenaltyBoth(prob_layers_lib.DenseReparameterization) + + def testKLPenaltyBothLocalReparameterization(self): + self._testKLPenaltyBoth(prob_layers_lib.DenseLocalReparameterization) + + def testKLPenaltyBothFlipout(self): + self._testKLPenaltyBoth(prob_layers_lib.DenseFlipout) + + def testDenseReparameterization(self): batch_size, in_size, out_size = 2, 3, 4 with self.test_session() as sess: - seed = Counter() - inputs = random_ops.random_uniform([batch_size, in_size], seed=seed()) - - kernel_size = [in_size, out_size] - kernel_posterior = MockDistribution( - result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()), - result_sample=random_ops.random_uniform(kernel_size, seed=seed())) - kernel_prior = MockDistribution( - result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()), - result_sample=random_ops.random_uniform(kernel_size, seed=seed())) - kernel_divergence = MockKLDivergence( - result=random_ops.random_uniform(kernel_size, seed=seed())) - - bias_size = [out_size] - bias_posterior = MockDistribution( - result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), - result_sample=random_ops.random_uniform(bias_size, seed=seed())) - bias_prior = MockDistribution( - result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), - 
result_sample=random_ops.random_uniform(bias_size, seed=seed())) - bias_divergence = MockKLDivergence( - result=random_ops.random_uniform(bias_size, seed=seed())) + (kernel_posterior, kernel_prior, kernel_divergence, + bias_posterior, bias_prior, bias_divergence, layer, inputs, + outputs, kl_penalty) = self._testDenseSetUp( + prob_layers_lib.DenseReparameterization, + batch_size, in_size, out_size) expected_outputs = ( math_ops.matmul(inputs, kernel_posterior.result_sample) + bias_posterior.result_sample) - dense_vi = prob_layers_lib.DenseVariational( - units=2, - kernel_use_local_reparameterization=False, - kernel_posterior_fn=lambda *args: kernel_posterior, - kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), - kernel_prior_fn=lambda *args: kernel_prior, - kernel_divergence_fn=kernel_divergence, - bias_posterior_fn=lambda *args: bias_posterior, - bias_posterior_tensor_fn=lambda d: d.sample(seed=43), - bias_prior_fn=lambda *args: bias_prior, - bias_divergence_fn=bias_divergence) - - outputs = dense_vi(inputs) - - kl_penalty = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) - [ expected_outputs_, actual_outputs_, expected_kernel_, actual_kernel_, @@ -183,9 +229,9 @@ class DenseVariationalLocalReparametrization(test.TestCase): expected_bias_divergence_, actual_bias_divergence_, ] = sess.run([ expected_outputs, outputs, - kernel_posterior.result_sample, dense_vi.kernel.posterior_tensor, + kernel_posterior.result_sample, layer.kernel_posterior_tensor, kernel_divergence.result, kl_penalty[0], - bias_posterior.result_sample, dense_vi.bias.posterior_tensor, + bias_posterior.result_sample, layer.bias_posterior_tensor, bias_divergence.result, kl_penalty[1], ]) @@ -206,40 +252,25 @@ class DenseVariationalLocalReparametrization(test.TestCase): rtol=1e-6, atol=0.) 
self.assertAllEqual( - [[kernel_posterior, kernel_prior, kernel_posterior.result_sample]], + [[kernel_posterior.distribution, + kernel_prior.distribution, + kernel_posterior.result_sample]], kernel_divergence.args) self.assertAllEqual( - [[bias_posterior, bias_prior, bias_posterior.result_sample]], + [[bias_posterior.distribution, + bias_prior.distribution, + bias_posterior.result_sample]], bias_divergence.args) - def testVariationalLocal(self): + def testDenseLocalReparameterization(self): batch_size, in_size, out_size = 2, 3, 4 with self.test_session() as sess: - seed = Counter() - inputs = random_ops.random_uniform([batch_size, in_size], seed=seed()) - - kernel_size = [in_size, out_size] - kernel_posterior = MockDistribution( - loc=random_ops.random_uniform(kernel_size, seed=seed()), - scale=random_ops.random_uniform(kernel_size, seed=seed()), - result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()), - result_sample=random_ops.random_uniform(kernel_size, seed=seed())) - kernel_prior = MockDistribution( - result_log_prob=random_ops.random_uniform(kernel_size, seed=seed()), - result_sample=random_ops.random_uniform(kernel_size, seed=seed())) - kernel_divergence = MockKLDivergence( - result=random_ops.random_uniform(kernel_size, seed=seed())) - - bias_size = [out_size] - bias_posterior = MockDistribution( - result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), - result_sample=random_ops.random_uniform(bias_size, seed=seed())) - bias_prior = MockDistribution( - result_log_prob=random_ops.random_uniform(bias_size, seed=seed()), - result_sample=random_ops.random_uniform(bias_size, seed=seed())) - bias_divergence = MockKLDivergence( - result=random_ops.random_uniform(bias_size, seed=seed())) + (kernel_posterior, kernel_prior, kernel_divergence, + bias_posterior, bias_prior, bias_divergence, layer, inputs, + outputs, kl_penalty) = self._testDenseSetUp( + prob_layers_lib.DenseLocalReparameterization, + batch_size, in_size, out_size) 
expected_kernel_posterior_affine = normal_lib.Normal( loc=math_ops.matmul(inputs, kernel_posterior.result_loc), @@ -250,21 +281,80 @@ class DenseVariationalLocalReparametrization(test.TestCase): expected_outputs = (expected_kernel_posterior_affine_tensor + bias_posterior.result_sample) - dense_vi = prob_layers_lib.DenseVariational( - units=2, - kernel_use_local_reparameterization=True, - kernel_posterior_fn=lambda *args: kernel_posterior, - kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), - kernel_prior_fn=lambda *args: kernel_prior, - kernel_divergence_fn=kernel_divergence, - bias_posterior_fn=lambda *args: bias_posterior, - bias_posterior_tensor_fn=lambda d: d.sample(seed=43), - bias_prior_fn=lambda *args: bias_prior, - bias_divergence_fn=bias_divergence) + [ + expected_outputs_, actual_outputs_, + expected_kernel_divergence_, actual_kernel_divergence_, + expected_bias_, actual_bias_, + expected_bias_divergence_, actual_bias_divergence_, + ] = sess.run([ + expected_outputs, outputs, + kernel_divergence.result, kl_penalty[0], + bias_posterior.result_sample, layer.bias_posterior_tensor, + bias_divergence.result, kl_penalty[1], + ]) - outputs = dense_vi(inputs) + self.assertAllClose( + expected_bias_, actual_bias_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_outputs_, actual_outputs_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_kernel_divergence_, actual_kernel_divergence_, + rtol=1e-6, atol=0.) + self.assertAllClose( + expected_bias_divergence_, actual_bias_divergence_, + rtol=1e-6, atol=0.) 
+ + self.assertAllEqual( + [[kernel_posterior.distribution, + kernel_prior.distribution, + None]], + kernel_divergence.args) + + self.assertAllEqual( + [[bias_posterior.distribution, + bias_prior.distribution, + bias_posterior.result_sample]], + bias_divergence.args) + + def testDenseFlipout(self): + batch_size, in_size, out_size = 2, 3, 4 + with self.test_session() as sess: + (kernel_posterior, kernel_prior, kernel_divergence, + bias_posterior, bias_prior, bias_divergence, layer, inputs, + outputs, kl_penalty) = self._testDenseSetUp( + prob_layers_lib.DenseFlipout, + batch_size, in_size, out_size, seed=44) + + expected_kernel_posterior_affine = normal_lib.Normal( + loc=array_ops.zeros_like(kernel_posterior.result_loc), + scale=kernel_posterior.result_scale) + expected_kernel_posterior_affine_tensor = ( + expected_kernel_posterior_affine.sample(seed=42)) - kl_penalty = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES) + sign_input = random_ops.random_uniform( + [batch_size, in_size], + minval=0, + maxval=2, + dtype=dtypes.int32, + seed=layer.seed) + sign_input = math_ops.cast(2 * sign_input - 1, inputs.dtype) + sign_output = random_ops.random_uniform( + [batch_size, out_size], + minval=0, + maxval=2, + dtype=dtypes.int32, + seed=distribution_util.gen_new_seed( + layer.seed, salt="dense_flipout")) + sign_output = math_ops.cast(2 * sign_output - 1, inputs.dtype) + perturbed_inputs = math_ops.matmul( + inputs * sign_input, expected_kernel_posterior_affine_tensor) + perturbed_inputs *= sign_output + + expected_outputs = math_ops.matmul(inputs, kernel_posterior.result_loc) + expected_outputs += perturbed_inputs + expected_outputs += bias_posterior.result_sample [ expected_outputs_, actual_outputs_, @@ -274,7 +364,7 @@ class DenseVariationalLocalReparametrization(test.TestCase): ] = sess.run([ expected_outputs, outputs, kernel_divergence.result, kl_penalty[0], - bias_posterior.result_sample, dense_vi.bias.posterior_tensor, + bias_posterior.result_sample, 
layer.bias_posterior_tensor, bias_divergence.result, kl_penalty[1], ]) @@ -292,13 +382,62 @@ class DenseVariationalLocalReparametrization(test.TestCase): rtol=1e-6, atol=0.) self.assertAllEqual( - [[kernel_posterior, kernel_prior, None]], + [[kernel_posterior.distribution, kernel_prior.distribution, None]], kernel_divergence.args) self.assertAllEqual( - [[bias_posterior, bias_prior, bias_posterior.result_sample]], + [[bias_posterior.distribution, + bias_prior.distribution, + bias_posterior.result_sample]], bias_divergence.args) + def testRandomDenseFlipout(self): + batch_size, in_size, out_size = 2, 3, 4 + with self.test_session() as sess: + seed = Counter() + inputs = random_ops.random_uniform([batch_size, in_size], seed=seed()) + + kernel_posterior = MockDistribution( + loc=random_ops.random_uniform( + [in_size, out_size], seed=seed()), + scale=random_ops.random_uniform( + [in_size, out_size], seed=seed()), + result_log_prob=random_ops.random_uniform( + [in_size, out_size], seed=seed()), + result_sample=random_ops.random_uniform( + [in_size, out_size], seed=seed())) + bias_posterior = MockDistribution( + loc=random_ops.random_uniform( + [out_size], seed=seed()), + scale=random_ops.random_uniform( + [out_size], seed=seed()), + result_log_prob=random_ops.random_uniform( + [out_size], seed=seed()), + result_sample=random_ops.random_uniform( + [out_size], seed=seed())) + layer_one = prob_layers_lib.DenseFlipout( + units=out_size, + kernel_posterior_fn=lambda *args: kernel_posterior, + kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), + bias_posterior_fn=lambda *args: bias_posterior, + bias_posterior_tensor_fn=lambda d: d.sample(seed=43), + seed=44) + layer_two = prob_layers_lib.DenseFlipout( + units=out_size, + kernel_posterior_fn=lambda *args: kernel_posterior, + kernel_posterior_tensor_fn=lambda d: d.sample(seed=42), + bias_posterior_fn=lambda *args: bias_posterior, + bias_posterior_tensor_fn=lambda d: d.sample(seed=43), + seed=45) + + outputs_one = 
layer_one(inputs) + outputs_two = layer_two(inputs) + + outputs_one_, outputs_two_ = sess.run([ + outputs_one, outputs_two]) + + self.assertLess(np.sum(np.isclose(outputs_one_, outputs_two_)), out_size) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/mcmc_diagnostics_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/mcmc_diagnostics_test.py new file mode 100644 index 0000000000000000000000000000000000000000..52e36e135d95c1ec919c710f35d59073c2134d05 --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/mcmc_diagnostics_test.py @@ -0,0 +1,445 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for MCMC diagnostic utilities.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.bayesflow.python.ops import mcmc_diagnostics_impl as mcmc_diagnostics +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import spectral_ops_test_util +from tensorflow.python.platform import test + +rng = np.random.RandomState(42) + + +class _EffectiveSampleSizeTest(object): + + @property + def use_static_shape(self): + raise NotImplementedError( + "Subclass failed to implement `use_static_shape`.") + + def _check_versus_expected_effective_sample_size(self, + x_, + expected_ess, + sess, + atol=1e-2, + rtol=1e-2, + filter_threshold=None, + filter_beyond_lag=None): + x = array_ops.placeholder_with_default( + input=x_, shape=x_.shape if self.use_static_shape else None) + ess = mcmc_diagnostics.effective_sample_size( + x, + filter_threshold=filter_threshold, + filter_beyond_lag=filter_beyond_lag) + if self.use_static_shape: + self.assertAllEqual(x.shape[1:], ess.shape) + + ess_ = sess.run(ess) + + self.assertAllClose( + np.ones_like(ess_) * expected_ess, ess_, atol=atol, rtol=rtol) + + def testIidRank1NormalHasFullEssMaxLags10(self): + # With a length 5000 iid normal sequence, and filter_beyond_lag = 10, we + # should have a good estimate of ESS, and it should be close to the full + # sequence length of 5000. + # The choice of filter_beyond_lag = 10 is a short cutoff, reasonable only + # since we know the correlation length should be zero right away. 
+ with self.test_session() as sess: + with spectral_ops_test_util.fft_kernel_label_map(): + self._check_versus_expected_effective_sample_size( + x_=rng.randn(5000).astype(np.float32), + expected_ess=5000, + sess=sess, + filter_beyond_lag=10, + filter_threshold=None, + rtol=0.3) + + def testIidRank2NormalHasFullEssMaxLags10(self): + # See similar test for Rank1Normal for reasoning. + with self.test_session() as sess: + with spectral_ops_test_util.fft_kernel_label_map(): + self._check_versus_expected_effective_sample_size( + x_=rng.randn(5000, 2).astype(np.float32), + expected_ess=5000, + sess=sess, + filter_beyond_lag=10, + filter_threshold=None, + rtol=0.3) + + def testIidRank1NormalHasFullEssMaxLagThresholdZero(self): + # With a length 5000 iid normal sequence, and filter_threshold = 0, + # we should have a super-duper estimate of ESS, and it should be very close + # to the full sequence length of 5000. + # The choice of filter_threshold = 0 means we cut off as soon as the + # auto-corr is below zero. This should happen very quickly, due to the fact + # that the theoretical auto-corr is [1, 0, 0,...] + with self.test_session() as sess: + with spectral_ops_test_util.fft_kernel_label_map(): + self._check_versus_expected_effective_sample_size( + x_=rng.randn(5000).astype(np.float32), + expected_ess=5000, + sess=sess, + filter_beyond_lag=None, + filter_threshold=0., + rtol=0.1) + + def testIidRank2NormalHasFullEssMaxLagThresholdZero(self): + # See similar test for Rank1Normal for reasoning. + with self.test_session() as sess: + with spectral_ops_test_util.fft_kernel_label_map(): + self._check_versus_expected_effective_sample_size( + x_=rng.randn(5000, 2).astype(np.float32), + expected_ess=5000, + sess=sess, + filter_beyond_lag=None, + filter_threshold=0., + rtol=0.1) + + def testLength10CorrelationHasEssOneTenthTotalLengthUsingMaxLags50(self): + # Create x_, such that + # x_[i] = iid_x_[0], i = 0,...,9 + # x_[i] = iid_x_[1], i = 10,..., 19, + # and so on.
+ iid_x_ = rng.randn(5000, 1).astype(np.float32) + x_ = (iid_x_ * np.ones((5000, 10)).astype(np.float32)).reshape((50000,)) + with self.test_session() as sess: + with spectral_ops_test_util.fft_kernel_label_map(): + self._check_versus_expected_effective_sample_size( + x_=x_, + expected_ess=50000 // 10, + sess=sess, + filter_beyond_lag=50, + filter_threshold=None, + rtol=0.2) + + def testLength10CorrelationHasEssOneTenthTotalLengthUsingMaxLagsThresholdZero( + self): + # Create x_, such that + # x_[i] = iid_x_[0], i = 0,...,9 + # x_[i] = iid_x_[1], i = 10,..., 19, + # and so on. + iid_x_ = rng.randn(5000, 1).astype(np.float32) + x_ = (iid_x_ * np.ones((5000, 10)).astype(np.float32)).reshape((50000,)) + with self.test_session() as sess: + with spectral_ops_test_util.fft_kernel_label_map(): + self._check_versus_expected_effective_sample_size( + x_=x_, + expected_ess=50000 // 10, + sess=sess, + filter_beyond_lag=None, + filter_threshold=0., + rtol=0.1) + + def testListArgs(self): + # x_ has correlation length 10 ==> ESS = N / 10 + # y_ has correlation length 1 ==> ESS = N + iid_x_ = rng.randn(5000, 1).astype(np.float32) + x_ = (iid_x_ * np.ones((5000, 10)).astype(np.float32)).reshape((50000,)) + y_ = rng.randn(50000).astype(np.float32) + states = [x_, x_, y_, y_] + filter_threshold = [0., None, 0., None] + filter_beyond_lag = [None, 5, None, 5] + + # See other tests for reasoning on tolerance. 
+ with self.test_session() as sess: + with spectral_ops_test_util.fft_kernel_label_map(): + ess = mcmc_diagnostics.effective_sample_size( + states, + filter_threshold=filter_threshold, + filter_beyond_lag=filter_beyond_lag) + ess_ = sess.run(ess) + self.assertAllEqual(4, len(ess_)) + + self.assertAllClose(50000 // 10, ess_[0], rtol=0.3) + self.assertAllClose(50000 // 10, ess_[1], rtol=0.3) + self.assertAllClose(50000, ess_[2], rtol=0.1) + self.assertAllClose(50000, ess_[3], rtol=0.1) + + def testMaxLagsThresholdLessThanNeg1SameAsNone(self): + # Setting both means we filter out items R_k from the auto-correlation + # sequence if k > filter_beyond_lag OR k >= j where R_j < filter_threshold. + + # x_ has correlation length 10. + iid_x_ = rng.randn(500, 1).astype(np.float32) + x_ = (iid_x_ * np.ones((500, 10)).astype(np.float32)).reshape((5000,)) + with self.test_session() as sess: + with spectral_ops_test_util.fft_kernel_label_map(): + x = array_ops.placeholder_with_default( + input=x_, shape=x_.shape if self.use_static_shape else None) + + ess_none_none = mcmc_diagnostics.effective_sample_size( + x, filter_threshold=None, filter_beyond_lag=None) + ess_none_200 = mcmc_diagnostics.effective_sample_size( + x, filter_threshold=None, filter_beyond_lag=200) + ess_neg2_200 = mcmc_diagnostics.effective_sample_size( + x, filter_threshold=-2., filter_beyond_lag=200) + ess_neg2_none = mcmc_diagnostics.effective_sample_size( + x, filter_threshold=-2., filter_beyond_lag=None) + ess_none_none_, ess_none_200_, ess_neg2_200_, ess_neg2_none_ = sess.run( + [ess_none_none, ess_none_200, ess_neg2_200, ess_neg2_none]) + + # filter_threshold=-2 <==> filter_threshold=None. + self.assertAllClose(ess_none_none_, ess_neg2_none_) + self.assertAllClose(ess_none_200_, ess_neg2_200_) + + def testMaxLagsArgsAddInAnOrManner(self): + # Setting both means we filter out items R_k from the auto-correlation + # sequence if k > filter_beyond_lag OR k >= j where R_j < filter_threshold. 
+ + # x_ has correlation length 10. + iid_x_ = rng.randn(500, 1).astype(np.float32) + x_ = (iid_x_ * np.ones((500, 10)).astype(np.float32)).reshape((5000,)) + with self.test_session() as sess: + with spectral_ops_test_util.fft_kernel_label_map(): + x = array_ops.placeholder_with_default( + input=x_, shape=x_.shape if self.use_static_shape else None) + + ess_1_9 = mcmc_diagnostics.effective_sample_size( + x, filter_threshold=1., filter_beyond_lag=9) + ess_1_none = mcmc_diagnostics.effective_sample_size( + x, filter_threshold=1., filter_beyond_lag=None) + ess_none_9 = mcmc_diagnostics.effective_sample_size( + x, filter_threshold=1., filter_beyond_lag=9) + ess_1_9_, ess_1_none_, ess_none_9_ = sess.run( + [ess_1_9, ess_1_none, ess_none_9]) + + # Since R_k = 1 for k < 10, and R_k < 1 for k >= 10, + # filter_threshold = 1 <==> filter_beyond_lag = 9. NOTE(review): ess_none_9 passes filter_threshold=1. like ess_1_9, not None -- confirm. + self.assertAllClose(ess_1_9_, ess_1_none_) + self.assertAllClose(ess_1_9_, ess_none_9_) + + + class EffectiveSampleSizeStaticTest(test.TestCase, _EffectiveSampleSizeTest): + + @property + def use_static_shape(self): + return True + + + class EffectiveSampleSizeDynamicTest(test.TestCase, _EffectiveSampleSizeTest): + + @property + def use_static_shape(self): + return False + + + class _PotentialScaleReductionTest(object): + + @property + def use_static_shape(self): + raise NotImplementedError( + "Subclass failed to impliment `use_static_shape`.") + + def testListOfStatesWhereFirstPassesSecondFails(self): + """Simple test showing API with two states. Read first!""" + n_samples = 1000 + + # state_0 is two scalar chains taken from iid Normal(0, 1). Will pass. + state_0 = rng.randn(n_samples, 2) + + # state_1 is three 4-variate chains taken from Normal(0, 1) that have been + # shifted. Since every chain is shifted, they are not the same, and the + # test should fail.
+ offset = np.array([1., -1., 2.]).reshape(3, 1) + state_1 = rng.randn(n_samples, 3, 4) + offset + + rhat = mcmc_diagnostics.potential_scale_reduction( + chains_states=[state_0, state_1], independent_chain_ndims=1) + + self.assertIsInstance(rhat, list) + with self.test_session() as sess: + rhat_0_, rhat_1_ = sess.run(rhat) + + # r_hat_0 should be close to 1, meaning test is passed. + self.assertAllEqual((), rhat_0_.shape) + self.assertAllClose(1., rhat_0_, rtol=0.02) + + # r_hat_1 should be greater than 1.2, meaning test has failed. + self.assertAllEqual((4,), rhat_1_.shape) + self.assertAllEqual(np.ones_like(rhat_1_).astype(bool), rhat_1_ > 1.2) + + def check_results(self, state_, independent_chain_shape, should_pass): + sample_ndims = 1 + independent_chain_ndims = len(independent_chain_shape) + with self.test_session(): + state = array_ops.placeholder_with_default( + input=state_, shape=state_.shape if self.use_static_shape else None) + + rhat = mcmc_diagnostics.potential_scale_reduction( + state, independent_chain_ndims=independent_chain_ndims) + + if self.use_static_shape: + self.assertAllEqual( + state_.shape[sample_ndims + independent_chain_ndims:], rhat.shape) + + rhat_ = rhat.eval() + if should_pass: + self.assertAllClose(np.ones_like(rhat_), rhat_, atol=0, rtol=0.02) + else: + self.assertAllEqual(np.ones_like(rhat_).astype(bool), rhat_ > 1.2) + + def iid_normal_chains_should_pass_wrapper(self, + sample_shape, + independent_chain_shape, + other_shape, + dtype=np.float32): + """Check results with iid normal chains.""" + + state_shape = sample_shape + independent_chain_shape + other_shape + state_ = rng.randn(*state_shape).astype(dtype) + + # The "other" dimensions do not have to be identical, just independent, so + # force them to not be identical. 
+ if other_shape: + state_ *= rng.rand(*other_shape).astype(dtype) + + self.check_results(state_, independent_chain_shape, should_pass=True) + + def testPassingIIDNdimsAreIndependentOneOtherZero(self): + self.iid_normal_chains_should_pass_wrapper( + sample_shape=[10000], independent_chain_shape=[4], other_shape=[]) + + def testPassingIIDNdimsAreIndependentOneOtherOne(self): + self.iid_normal_chains_should_pass_wrapper( + sample_shape=[10000], independent_chain_shape=[3], other_shape=[7]) + + def testPassingIIDNdimsAreIndependentOneOtherTwo(self): + self.iid_normal_chains_should_pass_wrapper( + sample_shape=[10000], independent_chain_shape=[2], other_shape=[5, 7]) + + def testPassingIIDNdimsAreIndependentTwoOtherTwo64Bit(self): + self.iid_normal_chains_should_pass_wrapper( + sample_shape=[10000], + independent_chain_shape=[2, 3], + other_shape=[5, 7], + dtype=np.float64) + + def offset_normal_chains_should_fail_wrapper( + self, sample_shape, independent_chain_shape, other_shape): + """Check results with normal chains that are offset from each other.""" + + state_shape = sample_shape + independent_chain_shape + other_shape + state_ = rng.randn(*state_shape) + + # Add a significant offset to the different (formerly iid) chains. 
+ offset = np.linspace( + 0, 2, num=np.prod(independent_chain_shape)).reshape([1] * len( + sample_shape) + independent_chain_shape + [1] * len(other_shape)) + state_ += offset + + self.check_results(state_, independent_chain_shape, should_pass=False) + + def testFailingOffsetNdimsAreSampleOneIndependentOneOtherOne(self): + self.offset_normal_chains_should_fail_wrapper( + sample_shape=[10000], independent_chain_shape=[2], other_shape=[5]) + + +class PotentialScaleReductionStaticTest(test.TestCase, + _PotentialScaleReductionTest): + + @property + def use_static_shape(self): + return True + + def testIndependentNdimsLessThanOneRaises(self): + with self.assertRaisesRegexp(ValueError, "independent_chain_ndims"): + mcmc_diagnostics.potential_scale_reduction( + rng.rand(2, 3, 4), independent_chain_ndims=0) + + +class PotentialScaleReductionDynamicTest(test.TestCase, + _PotentialScaleReductionTest): + + @property + def use_static_shape(self): + return False + + +class _ReduceVarianceTest(object): + + @property + def use_static_shape(self): + raise NotImplementedError( + "Subclass failed to impliment `use_static_shape`.") + + def check_versus_numpy(self, x_, axis, biased, keepdims): + with self.test_session(): + x_ = np.asarray(x_) + x = array_ops.placeholder_with_default( + input=x_, shape=x_.shape if self.use_static_shape else None) + var = mcmc_diagnostics._reduce_variance( + x, axis=axis, biased=biased, keepdims=keepdims) + np_var = np.var(x_, axis=axis, ddof=0 if biased else 1, keepdims=keepdims) + + if self.use_static_shape: + self.assertAllEqual(np_var.shape, var.shape) + + var_ = var.eval() + # We will mask below, which changes shape, so check shape explicitly here. 
+ self.assertAllEqual(np_var.shape, var_.shape) + + # We get NaN when we divide by zero due to the size being the same as ddof + nan_mask = np.isnan(np_var) + if nan_mask.any(): + self.assertTrue(np.isnan(var_[nan_mask]).all()) + self.assertAllClose(np_var[~nan_mask], var_[~nan_mask], atol=0, rtol=0.02) + + def testScalarBiasedTrue(self): + self.check_versus_numpy(x_=-1.234, axis=None, biased=True, keepdims=False) + + def testScalarBiasedFalse(self): + # This should result in NaN. + self.check_versus_numpy(x_=-1.234, axis=None, biased=False, keepdims=False) + + def testShape2x3x4AxisNoneBiasedFalseKeepdimsFalse(self): + self.check_versus_numpy( + x_=rng.randn(2, 3, 4), axis=None, biased=True, keepdims=False) + + def testShape2x3x4Axis1BiasedFalseKeepdimsTrue(self): + self.check_versus_numpy( + x_=rng.randn(2, 3, 4), axis=1, biased=True, keepdims=True) + + def testShape2x3x4x5Axis13BiasedFalseKeepdimsTrue(self): + self.check_versus_numpy( + x_=rng.randn(2, 3, 4, 5), axis=1, biased=True, keepdims=True) + + def testShape2x3x4x5Axis13BiasedFalseKeepdimsFalse(self): + self.check_versus_numpy( + x_=rng.randn(2, 3, 4, 5), axis=1, biased=False, keepdims=False) + + +class ReduceVarianceTestStaticShape(test.TestCase, _ReduceVarianceTest): + + @property + def use_static_shape(self): + return True + + +class ReduceVarianceTestDynamicShape(test.TestCase, _ReduceVarianceTest): + + @property + def use_static_shape(self): + return False + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/sgld_optimizer_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/sgld_optimizer_test.py index 66793383fdd5c71f136900197a91be6966e2f8c7..756c25683bd4b0c8c77e9e28485ca2a85582999c 100644 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/sgld_optimizer_test.py +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/sgld_optimizer_test.py @@ -1,4 +1,4 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -36,9 +36,9 @@ class SGLDOptimizerTest(test.TestCase): grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) decay_rate = 0.53 - sgd_op = SGLDOptimizer( - 3.0, preconditioner_decay_rate=decay_rate).apply_gradients( - zip([grads0, grads1], [var0, var1])) + sgd_optimizer = SGLDOptimizer(3.0, preconditioner_decay_rate=decay_rate) + sgd_op = sgd_optimizer.apply_gradients( + zip([grads0, grads1], [var0, var1])) variables.global_variables_initializer().run() # Fetch params to validate initial values self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) @@ -54,6 +54,7 @@ class SGLDOptimizerTest(test.TestCase): decay_rate + (1 - decay_rate) * 0.01**2 + 1e-8)) self.assertAllCloseAccordingToType( [3.0 - 3.0 * grads_scaled, 4.0 - 3.0 * grads_scaled], var1.eval()) + self.assertAllCloseAccordingToType(1, sgd_optimizer._counter.eval()) def testBasicMultiInstance(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: @@ -102,6 +103,8 @@ class SGLDOptimizerTest(test.TestCase): sgd_optimizer2.variable_scope) self.assertNotEqual(sgd_optimizer.variable_scope.name, sgd_optimizer2.variable_scope.name) + self.assertAllCloseAccordingToType(1, sgd_optimizer._counter.eval()) + self.assertAllCloseAccordingToType(1, sgd_optimizer2._counter.eval()) def testTensorLearningRate(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/variable_utils_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/variable_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f978cf86417dc5ff5412a3eee584330a266e0964 --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/variable_utils_test.py @@ -0,0 +1,135 @@ +# 
Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for utility functions related to managing `tf.Variable`s.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import warnings + +import numpy as np + +from tensorflow.contrib.bayesflow.python.ops import variable_utils + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import variable_scope as varscope_ops +from tensorflow.python.ops import variables as variables_ops +from tensorflow.python.platform import test + + +def test_fn(x): + x = ops.convert_to_tensor(x, name="x") + dtype = x.dtype.as_numpy_dtype + s = x.shape.as_list() + z = varscope_ops.get_variable( + name="z", + dtype=dtype, + initializer=np.arange(np.prod(s)).reshape(s).astype(dtype)) + y = varscope_ops.get_variable( + name="y", + dtype=dtype, + initializer=np.arange(np.prod(s)).reshape(s).astype(dtype)**2) + return x + y + z + + +class _WrapCallableTest(object): + + def testDefaultArgsWorkCorrectly(self): + with self.test_session(): + x = constant_op.constant(self.dtype([0.1, 0.2])) + wrapped_fn, vars_args = variable_utils.externalize_variables_as_args( + test_fn, [x]) + + varscope_ops.get_variable_scope().reuse_variables() + + result = wrapped_fn(self.dtype(2), [3, 
4, 5], 0.5) + + y_actual = varscope_ops.get_variable("y", dtype=self.dtype) + z_actual = varscope_ops.get_variable("z", dtype=self.dtype) + + variables_ops.global_variables_initializer().run() + result_ = result.eval() + + self.assertEqual(self.dtype, result_.dtype) + self.assertAllEqual([5.5, 6.5, 7.5], result_) + self.assertAllEqual([y_actual, z_actual], vars_args) + + def testNonDefaultArgsWorkCorrectly(self): + with self.test_session(): + x = constant_op.constant(self.dtype([0.1, 0.2])) + + _ = test_fn(self.dtype([0., 0.])) # Needed to create vars. + varscope_ops.get_variable_scope().reuse_variables() + + y_actual = varscope_ops.get_variable("y", dtype=self.dtype) + + wrapped_fn, vars_args = variable_utils.externalize_variables_as_args( + test_fn, [x], possible_ancestor_vars=[y_actual]) + + result = wrapped_fn(self.dtype([2, 3]), 0.5) # x, y + + variables_ops.global_variables_initializer().run() + result_ = result.eval() + + self.assertEqual(self.dtype, result_.dtype) + self.assertAllEqual([2.5, 4.5], result_) + self.assertAllEqual([y_actual], vars_args) + + def testWarnings(self): + with self.test_session(): + x = constant_op.constant(self.dtype([0.1, 0.2])) + wrapped_fn, _ = variable_utils.externalize_variables_as_args( + test_fn, [x], possible_ancestor_vars=[]) + varscope_ops.get_variable_scope().reuse_variables() + with warnings.catch_warnings(record=True) as w: + wrapped_fn(self.dtype(2)) + w = sorted(w, key=lambda w: str(w.message)) + self.assertEqual(2, len(w)) + self.assertRegexpMatches( + str(w[0].message), + r"Variable .* 'y:0' .* not found in bypass dict.") + self.assertRegexpMatches( + str(w[1].message), + r"Variable .* 'z:0' .* not found in bypass dict.") + + def testExceptions(self): + with self.test_session(): + x = constant_op.constant(self.dtype([0.1, 0.2])) + wrapped_fn, _ = variable_utils.externalize_variables_as_args( + test_fn, + [x], + possible_ancestor_vars=[], + assert_variable_override=True) + 
varscope_ops.get_variable_scope().reuse_variables() + with self.assertRaisesRegexp(ValueError, r"not found"): + wrapped_fn(self.dtype(2)) + + +class WrapCallableTest16(test.TestCase, _WrapCallableTest): + dtype = np.float16 + + +class WrapCallableTest32(test.TestCase, _WrapCallableTest): + dtype = np.float32 + + +class WrapCallableTest64(test.TestCase, _WrapCallableTest): + dtype = np.float64 + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/variational_sgd_optimizer_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/variational_sgd_optimizer_test.py new file mode 100644 index 0000000000000000000000000000000000000000..83c64dbe0fd586edcb784a5c09a4c133aaa99cff --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/variational_sgd_optimizer_test.py @@ -0,0 +1,268 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Functional test for GradientDescent.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from tensorflow.contrib.bayesflow.python.ops.optimizers import VariationalSGDOptimizer +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +class VariationalSGDOptimizerTest(test.TestCase): + + def testBasic(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + var0 = variables.Variable([1.1, 2.1], dtype=dtype) + var1 = variables.Variable([3.0, 4.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) + decay_rate = 0.53 + sgd_op = VariationalSGDOptimizer( + 1, + 1, + preconditioner_decay_rate=decay_rate, + max_learning_rate=3.0, + burnin_max_learning_rate=3.0, + use_single_learning_rate=True).apply_gradients( + zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) + self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) + # Run 1 step of sgd + sgd_op.run() + self.assertAllCloseAccordingToType([1.1 - 3.0 * 0.1, 2.1 - 3.0 * 0.1], + var0.eval()) + self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], + var1.eval()) + + def testBasicMultiInstance(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + var0 = variables.Variable([1.1, 2.1], dtype=dtype) + var1 = variables.Variable([3.0, 4.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) + grads1 = 
constant_op.constant([0.01, 0.01], dtype=dtype) + vara = variables.Variable([1.1, 2.1], dtype=dtype) + varb = variables.Variable([3.0, 4.0], dtype=dtype) + gradsa = constant_op.constant([0.1, 0.1], dtype=dtype) + gradsb = constant_op.constant([0.01, 0.01], dtype=dtype) + decay_rate = 0.5 + batch_size = 2 + total_num_examples = 10 + optimizer = VariationalSGDOptimizer( + batch_size, + total_num_examples, + max_learning_rate=1.0, + burnin_max_learning_rate=3.0, + preconditioner_decay_rate=decay_rate) + sgd_op = optimizer.apply_gradients( + zip([grads0, grads1], [var0, var1])) + optimizer2 = VariationalSGDOptimizer( + batch_size, + total_num_examples, + max_learning_rate=1.0, + burnin_max_learning_rate=10.0, + burnin=0, + preconditioner_decay_rate=decay_rate) + sgd_op2 = optimizer2.apply_gradients( + zip([gradsa, gradsb], [vara, varb])) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) + self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) + self.assertAllCloseAccordingToType([1.1, 2.1], vara.eval()) + self.assertAllCloseAccordingToType([3.0, 4.0], varb.eval()) + + # Run 1 step of sgd + sgd_op.run() + sgd_op2.run() + # Validate updated params + self.assertAllCloseAccordingToType([1.1 - 3. * 0.1, 2.1 - 3. * 0.1], + var0.eval()) + self.assertAllCloseAccordingToType([1.1 - 0.1, 2.1 - 0.1], vara.eval()) + + self.assertAllCloseAccordingToType([3.0 - 3. * 0.01, 4.0 - 3. 
* 0.01], + var1.eval()) + self.assertAllCloseAccordingToType([3.0 - 0.01, 4.0 - 0.01], + varb.eval()) + self.assertNotEqual(optimizer.variable_scope, + optimizer2.variable_scope) + self.assertNotEqual(optimizer.variable_scope.name, + optimizer2.variable_scope.name) + self.assertAllCloseAccordingToType(1, optimizer._counter.eval()) + self.assertAllCloseAccordingToType(1, optimizer2._counter.eval()) + + def testTensorLearningRate(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + var0 = variables.Variable([1.1, 2.1], dtype=dtype) + var1 = variables.Variable([3.0, 4.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) + lrate = constant_op.constant(3.0) + decay_rate = 0.5 + batch_size = 2 + total_num_examples = 10 + sgd_op = VariationalSGDOptimizer( + batch_size, + total_num_examples, + max_learning_rate=lrate, + burnin=0, + preconditioner_decay_rate=decay_rate).apply_gradients( + zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) + self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) + # Run 1 step of sgd + sgd_op.run() + # Validate updated params + self.assertAllCloseAccordingToType([1.1 - 3.0 * 0.1, 2.1 - 3.0 * 0.1], + var0.eval()) + self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], + var1.eval()) + + def testTensorDecayLearningRate(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + var0 = variables.Variable([1.1, 2.1], dtype=dtype) + var1 = variables.Variable([3.0, 4.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.01], dtype=dtype) + lrate = variables.Variable(3.0) + lrate_decay_op = lrate.assign_add(-3.) 
+ decay_rate = 0.5 + batch_size = 2 + total_num_examples = 10 + optimizer = VariationalSGDOptimizer( + batch_size, + total_num_examples, + max_learning_rate=lrate, + burnin=0, + preconditioner_decay_rate=decay_rate) + sgd_op = optimizer.apply_gradients(zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) + self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) + # Run 1 step of sgd + sgd_op.run() + # Validate updated params + self.assertAllCloseAccordingToType([1.1 - 3.0 * 0.1, 2.1 - 3.0 * 0.1], + var0.eval()) + self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], + var1.eval()) + # Update learning rate to 0 + lrate_decay_op.eval() + sgd_op.run() + # Validate params haven't changed + self.assertAllCloseAccordingToType([1.1 - 3.0 * 0.1, 2.1 - 3.0 * 0.1], + var0.eval()) + self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], + var1.eval()) + lrate_decay_op.eval() + + with self.assertRaises(errors.InvalidArgumentError): + sgd_op.run() + + def testGradWrtRef(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + opt = VariationalSGDOptimizer(1, 1, max_learning_rate=1.0) + values = [1.0, 3.0] + vars_ = [variables.Variable([v], dtype=dtype) for v in values] + grads_and_vars = opt.compute_gradients(vars_[0] + vars_[1], vars_) + variables.global_variables_initializer().run() + for grad, _ in grads_and_vars: + self.assertAllCloseAccordingToType([1.0], grad.eval()) + + def testWithGlobalStep(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + global_step = variables.Variable(0, trainable=False) + var0 = variables.Variable([1.1, 2.1], dtype=dtype) + var1 = variables.Variable([3.0, 4.0], dtype=dtype) + grads0 = constant_op.constant([0.1, 0.1], dtype=dtype) + grads1 = constant_op.constant([0.01, 0.01], 
dtype=dtype) + decay_rate = 0.1 + batch_size = 2 + total_num_examples = 10 + sgd_optimizer = VariationalSGDOptimizer( + batch_size, + total_num_examples, + max_learning_rate=3.0, + burnin=0, + preconditioner_decay_rate=decay_rate) + sgd_op = sgd_optimizer.apply_gradients( + zip([grads0, grads1], [var0, var1]), global_step=global_step) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllCloseAccordingToType([1.1, 2.1], var0.eval()) + self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval()) + # Run 1 step of sgd + sgd_op.run() + + # Validate updated params and global_step + self.assertAllCloseAccordingToType([1.1 - 3.0 * 0.1, 2.1 - 3.0 * 0.1], + var0.eval()) + self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01], + var1.eval()) + self.assertAllCloseAccordingToType(1, global_step.eval()) + self.assertAllCloseAccordingToType(1, sgd_optimizer._counter.eval()) + + def testSparseBasic(self): + for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: + with self.test_session(): + var0 = variables.Variable([[1.1], [2.1]], dtype=dtype) + var1 = variables.Variable([[3.0], [4.0]], dtype=dtype) + grads0 = ops.IndexedSlices( + constant_op.constant([0.1], shape=[1, 1], dtype=dtype), + constant_op.constant([0]), constant_op.constant([2, 1])) + grads1 = ops.IndexedSlices( + constant_op.constant([0.01], shape=[1, 1], dtype=dtype), + constant_op.constant([1]), constant_op.constant([2, 1])) + decay_rate = 0.1 + batch_size = 2 + total_num_examples = 10 + sgd_op = VariationalSGDOptimizer( + batch_size, + total_num_examples, + max_learning_rate=3.0, + burnin=0, + preconditioner_decay_rate=decay_rate).apply_gradients( + zip([grads0, grads1], [var0, var1])) + variables.global_variables_initializer().run() + # Fetch params to validate initial values + self.assertAllCloseAccordingToType([[1.1], [2.1]], var0.eval()) + self.assertAllCloseAccordingToType([[3.0], [4.0]], var1.eval()) + # Run 1 step of sgd + 
sgd_op.run() + # Validate updated params + self.assertAllCloseAccordingToType([[1.1 - 3.0 * 0.1], [2.1]], + var0.eval()) + self.assertAllCloseAccordingToType( + [[3.0 - 3.0 * 0], [4.0 - 3.0 * 0.01]], var1.eval()) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py b/tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py index ee3719232d8796c338247320fd8ef832a41df12b..d44fe6529a7ff0da0c6747e193fdb98a272a8da3 100644 --- a/tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py +++ b/tensorflow/contrib/bayesflow/python/ops/custom_grad_impl.py @@ -31,8 +31,7 @@ __all__ = [ ] -def custom_gradient(fx, gx, x, axis=(), - fx_gx_manually_stopped=False, +def custom_gradient(fx, gx, x, axis=(), fx_gx_manually_stopped=False, name=None): """Enables specifying a custom gradient. @@ -43,7 +42,8 @@ def custom_gradient(fx, gx, x, axis=(), h(x) = x * stop_gradient(g(x)) + stop_gradient(f(x) - x * g(x)) ``` - is such that `h(x) = stop(f(x))` and `grad[h(x), x] = stop_gradient(g(x)).` + is such that `h(x) = stop_gradient(f(x))` and `grad[h(x), x] = + stop_gradient(g(x)).` In addition to scalar-domain/scalar-range functions, this function also supports tensor-domain/scalar-range functions. However, in the latter case it diff --git a/tensorflow/contrib/bayesflow/python/ops/hmc.py b/tensorflow/contrib/bayesflow/python/ops/hmc.py index 977d42fc16bb91777a76c45ac24f3c5dc587f5fe..7fd5652c5c3e085b23c05baef6e3a42b7a42e08f 100644 --- a/tensorflow/contrib/bayesflow/python/ops/hmc.py +++ b/tensorflow/contrib/bayesflow/python/ops/hmc.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Hamiltonian Monte Carlo, a gradient-based MCMC algorithm. 
-""" +"""Hamiltonian Monte Carlo, a gradient-based MCMC algorithm.""" from __future__ import absolute_import from __future__ import division @@ -24,11 +23,9 @@ from tensorflow.contrib.bayesflow.python.ops.hmc_impl import * # pylint: disabl from tensorflow.python.util import all_util _allowed_symbols = [ - 'chain', - 'kernel', - 'leapfrog_integrator', - 'leapfrog_step', - 'ais_chain' + "sample_chain", + "sample_annealed_importance_chain", + "kernel", ] all_util.remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py b/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py index 333dce929530adceb30dcb63653a5bd009c059e0..f724910c59315867a42a56fab3deb36f5d3adb7a 100644 --- a/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py +++ b/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py @@ -14,183 +14,343 @@ # ============================================================================== """Hamiltonian Monte Carlo, a gradient-based MCMC algorithm. 
-@@chain -@@update -@@leapfrog_integrator -@@leapfrog_step -@@ais_chain +@@sample_chain +@@sample_annealed_importance_chain +@@kernel """ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import numpy as np +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import functional_ops -from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import gradients_impl as gradients_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops -from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.ops.distributions import util as distributions_util __all__ = [ - 'chain', - 'kernel', - 'leapfrog_integrator', - 'leapfrog_step', - 'ais_chain' + "sample_chain", + "sample_annealed_importance_chain", + "kernel", ] -def _make_potential_and_grad(target_log_prob_fn): - def potential_and_grad(x): - log_prob_result = -target_log_prob_fn(x) - grad_result = gradients_impl.gradients(math_ops.reduce_sum(log_prob_result), - x)[0] - return log_prob_result, grad_result - return potential_and_grad - - -def chain(n_iterations, step_size, n_leapfrog_steps, initial_x, - target_log_prob_fn, event_dims=(), name=None): +KernelResults = collections.namedtuple( + "KernelResults", + [ + "acceptance_probs", + "current_grads_target_log_prob", # "Current result" means "accepted". + "current_target_log_prob", # "Current result" means "accepted". 
+ "energy_change", + "is_accepted", + "proposed_grads_target_log_prob", + "proposed_state", + "proposed_target_log_prob", + "random_positive", + ]) + + +def _make_dummy_kernel_results( + dummy_state, + dummy_target_log_prob, + dummy_grads_target_log_prob): + return KernelResults( + acceptance_probs=dummy_target_log_prob, + current_grads_target_log_prob=dummy_grads_target_log_prob, + current_target_log_prob=dummy_target_log_prob, + energy_change=dummy_target_log_prob, + is_accepted=array_ops.ones_like(dummy_target_log_prob, dtypes.bool), + proposed_grads_target_log_prob=dummy_grads_target_log_prob, + proposed_state=dummy_state, + proposed_target_log_prob=dummy_target_log_prob, + random_positive=dummy_target_log_prob, + ) + + +def sample_chain( + num_results, + target_log_prob_fn, + current_state, + step_size, + num_leapfrog_steps, + num_burnin_steps=0, + num_steps_between_results=0, + seed=None, + current_target_log_prob=None, + current_grads_target_log_prob=None, + name=None): """Runs multiple iterations of one or more Hamiltonian Monte Carlo chains. - Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC) - algorithm that takes a series of gradient-informed steps to produce - a Metropolis proposal. This function samples from an HMC Markov - chain whose initial state is `initial_x` and whose stationary - distribution has log-density `target_log_prob_fn()`. - - This function can update multiple chains in parallel. It assumes - that all dimensions of `initial_x` not specified in `event_dims` are - independent, and should therefore be updated independently. The - output of `target_log_prob_fn()` should sum log-probabilities across - all event dimensions. Slices along dimensions not in `event_dims` - may have different target distributions; this is up to + Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC) algorithm + that takes a series of gradient-informed steps to produce a Metropolis + proposal. 
This function samples from an HMC Markov chain at `current_state` + and whose stationary distribution has log-unnormalized-density `target_log_prob_fn()`. - This function basically just wraps `hmc.kernel()` in a tf.scan() loop. + This function samples from multiple chains in parallel. It assumes that the + leftmost dimensions of (each) `current_state` (part) index an independent + chain. The function `target_log_prob_fn()` sums log-probabilities across + event dimensions (i.e., current state (part) rightmost dimensions). Each + element of the output of `target_log_prob_fn()` represents the (possibly + unnormalized) log-probability of the joint distribution over (all) the current + state (parts). - Args: - n_iterations: Integer number of Markov chain updates to run. - step_size: Scalar step size or array of step sizes for the - leapfrog integrator. Broadcasts to the shape of - `initial_x`. Larger step sizes lead to faster progress, but - too-large step sizes make rejection exponentially more likely. - When possible, it's often helpful to match per-variable step - sizes to the standard deviations of the target distribution in - each variable. - n_leapfrog_steps: Integer number of steps to run the leapfrog - integrator for. Total progress per HMC step is roughly - proportional to step_size * n_leapfrog_steps. - initial_x: Tensor of initial state(s) of the Markov chain(s). - target_log_prob_fn: Python callable which takes an argument like `initial_x` - and returns its (possibly unnormalized) log-density under the target - distribution. - event_dims: List of dimensions that should not be treated as - independent. This allows for multiple chains to be run independently - in parallel. Default is (), i.e., all dimensions are independent. - name: Python `str` name prefixed to Ops created by this function. + The `current_state` can be represented as a single `Tensor` or a `list` of + `Tensors` which collectively represent the current state.
When specifying a + `list`, one must also specify a list of `step_size`s. - Returns: - acceptance_probs: Tensor with the acceptance probabilities for each - iteration. Has shape matching `target_log_prob_fn(initial_x)`. - chain_states: Tensor with the state of the Markov chain at each iteration. - Has shape `[n_iterations, initial_x.shape[0],...,initial_x.shape[-1]`. + Note: `target_log_prob_fn` is called exactly twice. + + Only one out of every `num_steps_between_results + 1` steps is included in the + returned results. This "thinning" comes at a cost of reduced statistical + power, while reducing memory requirements and autocorrelation. For more + discussion see [1]. + + [1]: "Statistically efficient thinning of a Markov chain sampler." + Art B. Owen. April 2017. + http://statweb.stanford.edu/~owen/reports/bestthinning.pdf #### Examples: - ```python - # Sampling from a standard normal (note `log_joint()` is unnormalized): - def log_joint(x): - return tf.reduce_sum(-0.5 * tf.square(x)) - chain, acceptance_probs = hmc.chain(1000, 0.5, 2, tf.zeros(10), log_joint, - event_dims=[0]) - # Discard first half of chain as warmup/burn-in - warmed_up = chain[500:] - mean_est = tf.reduce_mean(warmed_up, 0) - var_est = tf.reduce_mean(tf.square(warmed_up), 0) - tf.square(mean_est) - ``` + ##### Sample from a diagonal-variance Gaussian.
```python - # Sampling from a diagonal-variance Gaussian: - variances = tf.linspace(1., 3., 10) - def log_joint(x): - return tf.reduce_sum(-0.5 / variances * tf.square(x)) - chain, acceptance_probs = hmc.chain(1000, 0.5, 2, tf.zeros(10), log_joint, - event_dims=[0]) - # Discard first half of chain as warmup/burn-in - warmed_up = chain[500:] - mean_est = tf.reduce_mean(warmed_up, 0) - var_est = tf.reduce_mean(tf.square(warmed_up), 0) - tf.square(mean_est) + tfd = tf.contrib.distributions + + def make_likelihood(true_variances): + return tfd.MultivariateNormalDiag( + scale_diag=tf.sqrt(true_variances)) + + dims = 10 + dtype = np.float32 + true_variances = tf.linspace(dtype(1), dtype(3), dims) + likelihood = make_likelihood(true_variances) + + states, kernel_results = hmc.sample_chain( + num_results=1000, + target_log_prob_fn=likelihood.log_prob, + current_state=tf.zeros(dims), + step_size=0.5, + num_leapfrog_steps=2, + num_burnin_steps=500) + + # Compute sample stats. + sample_mean = tf.reduce_mean(states, axis=0) + sample_var = tf.reduce_mean( + tf.squared_difference(states, sample_mean), + axis=0) ``` - ```python - # Sampling from factor-analysis posteriors with known factors W: - # mu[i, j] ~ Normal(0, 1) - # x[i] ~ Normal(matmul(mu[i], W), I) - def log_joint(mu, x, W): - prior = -0.5 * tf.reduce_sum(tf.square(mu), 1) - x_mean = tf.matmul(mu, W) - likelihood = -0.5 * tf.reduce_sum(tf.square(x - x_mean), 1) - return prior + likelihood - chain, acceptance_probs = hmc.chain(1000, 0.1, 2, - tf.zeros([x.shape[0], W.shape[0]]), - lambda mu: log_joint(mu, x, W), - event_dims=[1]) - # Discard first half of chain as warmup/burn-in - warmed_up = chain[500:] - mean_est = tf.reduce_mean(warmed_up, 0) - var_est = tf.reduce_mean(tf.square(warmed_up), 0) - tf.square(mean_est) + ##### Sampling from factor-analysis posteriors with known factors. 
+ + I.e., + + ```none + for i=1..n: + w[i] ~ Normal(0, eye(d)) # prior + x[i] ~ Normal(loc=matmul(w[i], F)) # likelihood ``` + where `F` denotes factors. + ```python - # Sampling from the posterior of a Bayesian regression model.: - - # Run 100 chains in parallel, each with a different initialization. - initial_beta = tf.random_normal([100, x.shape[1]]) - chain, acceptance_probs = hmc.chain(1000, 0.1, 10, initial_beta, - log_joint_partial, event_dims=[1]) - # Discard first halves of chains as warmup/burn-in - warmed_up = chain[500:] - # Averaging across samples within a chain and across chains - mean_est = tf.reduce_mean(warmed_up, [0, 1]) - var_est = tf.reduce_mean(tf.square(warmed_up), [0, 1]) - tf.square(mean_est) + tfd = tf.contrib.distributions + + def make_prior(dims, dtype): + return tfd.MultivariateNormalDiag( + loc=tf.zeros(dims, dtype)) + + def make_likelihood(weights, factors): + return tfd.MultivariateNormalDiag( + loc=tf.tensordot(weights, factors, axes=[[0], [-1]])) + + # Setup data. + num_weights = 10 + num_factors = 4 + num_chains = 100 + dtype = np.float32 + + prior = make_prior(num_weights, dtype) + weights = prior.sample(num_chains) + factors = np.random.randn(num_factors, num_weights).astype(dtype) + x = make_likelihood(weights, factors).sample(num_chains) + + def target_log_prob(w): + # Target joint is: `f(w) = p(w, x | factors)`. + return prior.log_prob(w) + make_likelihood(w, factors).log_prob(x) + + # Get `num_results` samples from `num_chains` independent chains. + chains_states, kernels_results = hmc.sample_chain( + num_results=1000, + target_log_prob_fn=target_log_prob, + current_state=tf.zeros([num_chains, dims], dtype), + step_size=0.1, + num_leapfrog_steps=2, + num_burnin_steps=500) + + # Compute sample stats. + sample_mean = tf.reduce_mean(chains_states, axis=[0, 1]) + sample_var = tf.reduce_mean( + tf.squared_difference(chains_states, sample_mean), + axis=[0, 1]) ``` + + Args: + num_results: Integer number of Markov chain draws. 
+ target_log_prob_fn: Python callable which takes an argument like + `current_state` (or `*current_state` if it's a list) and returns its + (possibly unnormalized) log-density under the target distribution. + current_state: `Tensor` or Python `list` of `Tensor`s representing the + current state(s) of the Markov chain(s). The first `r` dimensions index + independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`. + step_size: `Tensor` or Python `list` of `Tensor`s representing the step size + for the leapfrog integrator. Must broadcast with the shape of + `current_state`. Larger step sizes lead to faster progress, but too-large + step sizes make rejection exponentially more likely. When possible, it's + often helpful to match per-variable step sizes to the standard deviations + of the target distribution in each variable. + num_leapfrog_steps: Integer number of steps to run the leapfrog integrator + for. Total progress per HMC step is roughly proportional to `step_size * + num_leapfrog_steps`. + num_burnin_steps: Integer number of chain steps to take before starting to + collect results. + Default value: 0 (i.e., no burn-in). + num_steps_between_results: Integer number of chain steps between collecting + a result. Only one out of every `num_steps_between_results + 1` steps is + included in the returned results. This "thinning" comes at a cost of + reduced statistical power, while reducing memory requirements and + autocorrelation. For more discussion see [1]. + Default value: 0 (i.e., no subsampling). + seed: Python integer to seed the random number generator. + current_target_log_prob: (Optional) `Tensor` representing the value of + `target_log_prob_fn` at the `current_state`. The only reason to specify + this argument is to reduce TF graph size. + Default value: `None` (i.e., compute as needed).
+ current_grads_target_log_prob: (Optional) Python list of `Tensor`s + representing gradient of `target_log_prob` at the `current_state` and wrt + the `current_state`. Must have same shape as `current_state`. The only + reason to specify this argument is to reduce TF graph size. + Default value: `None` (i.e., compute as needed). + name: Python `str` name prefixed to Ops created by this function. + Default value: `None` (i.e., "hmc_sample_chain"). + + Returns: + accepted_states: Tensor or Python list of `Tensor`s representing the + state(s) of the Markov chain(s) at each result step. Has same shape as + input `current_state` but with a prepended `num_results`-size dimension. + kernel_results: `collections.namedtuple` of internal calculations used to + advance the chain. """ - with ops.name_scope(name, 'hmc_chain', [n_iterations, step_size, - n_leapfrog_steps, initial_x]): - initial_x = ops.convert_to_tensor(initial_x, name='initial_x') - non_event_shape = array_ops.shape(target_log_prob_fn(initial_x)) - - def body(a, _): - updated_x, acceptance_probs, log_prob, grad = kernel( - step_size, n_leapfrog_steps, a[0], target_log_prob_fn, event_dims, - a[2], a[3]) - return updated_x, acceptance_probs, log_prob, grad - - potential_and_grad = _make_potential_and_grad(target_log_prob_fn) - potential, grad = potential_and_grad(initial_x) - return functional_ops.scan(body, array_ops.zeros(n_iterations), - (initial_x, array_ops.zeros(non_event_shape), - -potential, -grad))[:2] - - -def ais_chain(n_iterations, step_size, n_leapfrog_steps, initial_x, - target_log_prob_fn, proposal_log_prob_fn, event_dims=(), - name=None): + with ops.name_scope( + name, "hmc_sample_chain", + [num_results, current_state, step_size, num_leapfrog_steps, + num_burnin_steps, num_steps_between_results, seed, + current_target_log_prob, current_grads_target_log_prob]): + with ops.name_scope("initialize"): + [ + current_state, + step_size, + current_target_log_prob, + current_grads_target_log_prob, + ] = 
_prepare_args( + target_log_prob_fn, + current_state, + step_size, + current_target_log_prob, + current_grads_target_log_prob) + num_results = ops.convert_to_tensor( + num_results, + dtype=dtypes.int32, + name="num_results") + num_leapfrog_steps = ops.convert_to_tensor( + num_leapfrog_steps, + dtype=dtypes.int32, + name="num_leapfrog_steps") + num_burnin_steps = ops.convert_to_tensor( + num_burnin_steps, + dtype=dtypes.int32, + name="num_burnin_steps") + num_steps_between_results = ops.convert_to_tensor( + num_steps_between_results, + dtype=dtypes.int32, + name="num_steps_between_results") + + def _run_chain(num_steps, current_state, kernel_results): + """Runs the chain(s) for `num_steps`.""" + def _loop_body(iter_, current_state, kernel_results): + return [iter_ + 1] + list(kernel( + target_log_prob_fn, + current_state, + step_size, + num_leapfrog_steps, + seed, + kernel_results.current_target_log_prob, + kernel_results.current_grads_target_log_prob)) + while_loop_kwargs = dict( + cond=lambda iter_, *args: iter_ < num_steps, + body=_loop_body, + loop_vars=[ + np.int32(0), + current_state, + kernel_results, + ], + ) + if seed is not None: + while_loop_kwargs["parallel_iterations"] = 1 + return control_flow_ops.while_loop( + **while_loop_kwargs)[1:] # Lop-off "iter_". + + def _scan_body(args_list, iter_): + """Closure which implements `tf.scan` body.""" + current_state, kernel_results = args_list + return _run_chain( + 1 + array_ops.where(math_ops.equal(iter_, 0), + num_burnin_steps, + num_steps_between_results), + current_state, + kernel_results) + + scan_kwargs = dict( + fn=_scan_body, + elems=math_ops.range(num_results), # iter_: used to choose burnin. 
+ initializer=[ + current_state, + _make_dummy_kernel_results( + current_state, + current_target_log_prob, + current_grads_target_log_prob), + ]) + if seed is not None: + scan_kwargs["parallel_iterations"] = 1 + return functional_ops.scan(**scan_kwargs) + + +def sample_annealed_importance_chain( + proposal_log_prob_fn, + num_steps, + target_log_prob_fn, + current_state, + step_size, + num_leapfrog_steps, + seed=None, + name=None): """Runs annealed importance sampling (AIS) to estimate normalizing constants. - This routine uses Hamiltonian Monte Carlo to sample from a series of + This function uses Hamiltonian Monte Carlo to sample from a series of distributions that slowly interpolates between an initial "proposal" - distribution + distribution: `exp(proposal_log_prob_fn(x) - proposal_log_normalizer)` - and the target distribution + and the target distribution: `exp(target_log_prob_fn(x) - target_log_normalizer)`, @@ -199,112 +359,203 @@ def ais_chain(n_iterations, step_size, n_leapfrog_steps, initial_x, normalizing constants of the initial distribution and the target distribution: - E[exp(w)] = exp(target_log_normalizer - proposal_log_normalizer). - - Args: - n_iterations: Integer number of Markov chain updates to run. More - iterations means more expense, but smoother annealing between q - and p, which in turn means exponentially lower variance for the - normalizing constant estimator. - step_size: Scalar step size or array of step sizes for the - leapfrog integrator. Broadcasts to the shape of - `initial_x`. Larger step sizes lead to faster progress, but - too-large step sizes make rejection exponentially more likely. - When possible, it's often helpful to match per-variable step - sizes to the standard deviations of the target distribution in - each variable. - n_leapfrog_steps: Integer number of steps to run the leapfrog - integrator for. Total progress per HMC step is roughly - proportional to step_size * n_leapfrog_steps. 
- initial_x: Tensor of initial state(s) of the Markov chain(s). Must - be a sample from q, or results will be incorrect. - target_log_prob_fn: Python callable which takes an argument like `initial_x` - and returns its (possibly unnormalized) log-density under the target - distribution. - proposal_log_prob_fn: Python callable that returns the log density of the - initial distribution. - event_dims: List of dimensions that should not be treated as - independent. This allows for multiple chains to be run independently - in parallel. Default is (), i.e., all dimensions are independent. - name: Python `str` name prefixed to Ops created by this function. + `E[exp(ais_weights)] = exp(target_log_normalizer - proposal_log_normalizer)`. - Returns: - ais_weights: Tensor with the estimated weight(s). Has shape matching - `target_log_prob_fn(initial_x)`. - chain_states: Tensor with the state(s) of the Markov chain(s) the final - iteration. Has shape matching `initial_x`. - acceptance_probs: Tensor with the acceptance probabilities for the final - iteration. Has shape matching `target_log_prob_fn(initial_x)`. + Note: `proposal_log_prob_fn` and `target_log_prob_fn` are called exactly three + times (although this may be reduced to two times, in the future). #### Examples: + ##### Estimate the normalizing constant of a log-gamma distribution. + ```python - # Estimating the normalizing constant of a log-gamma distribution: - def proposal_log_prob(x): - # Standard normal log-probability. This is properly normalized. - return tf.reduce_sum(-0.5 * tf.square(x) - 0.5 * np.log(2 * np.pi), 1) - def target_log_prob(x): - # Unnormalized log-gamma(2, 3) distribution. - # True normalizer is (lgamma(2) - 2 * log(3)) * x.shape[1] - return tf.reduce_sum(2. * x - 3. 
* tf.exp(x), 1) + tfd = tf.contrib.distributions + # Run 100 AIS chains in parallel - initial_x = tf.random_normal([100, 20]) - w, _, _ = hmc.ais_chain(1000, 0.2, 2, initial_x, target_log_prob, - proposal_log_prob, event_dims=[1]) - log_normalizer_estimate = tf.reduce_logsumexp(w) - np.log(100) + num_chains = 100 + dims = 20 + dtype = np.float32 + + proposal = tfd.MultivariateNormalDiag( + loc=tf.zeros([dims], dtype=dtype)) + + target = tfd.TransformedDistribution( + distribution=tfd.Gamma(concentration=dtype(2), + rate=dtype(3)), + bijector=tfd.bijectors.Invert(tfd.bijectors.Exp()), + event_shape=[dims]) + + chains_state, ais_weights, kernels_results = ( + hmc.sample_annealed_importance_chain( + proposal_log_prob_fn=proposal.log_prob, + num_steps=1000, + target_log_prob_fn=target.log_prob, + step_size=0.2, + current_state=proposal.sample(num_chains), + num_leapfrog_steps=2)) + + log_estimated_normalizer = (tf.reduce_logsumexp(ais_weights) + - np.log(num_chains)) + log_true_normalizer = tf.lgamma(2.) - 2. * tf.log(3.) ``` + ##### Estimate marginal likelihood of a Bayesian regression model. + ```python - # Estimating the marginal likelihood of a Bayesian regression model: - base_measure = -0.5 * np.log(2 * np.pi) - def proposal_log_prob(x): - # Standard normal log-probability. This is properly normalized. - return tf.reduce_sum(-0.5 * tf.square(x) + base_measure, 1) - def regression_log_joint(beta, x, y): - # This function returns a vector whose ith element is log p(beta[i], y | x). - # Each row of beta corresponds to the state of an independent Markov chain.
- log_prior = tf.reduce_sum(-0.5 * tf.square(beta) + base_measure, 1) - means = tf.matmul(beta, x, transpose_b=True) - log_likelihood = tf.reduce_sum(-0.5 * tf.square(y - means) + - base_measure, 1) - return log_prior + log_likelihood - def log_joint_partial(beta): - return regression_log_joint(beta, x, y) + tfd = tf.contrib.distributions + + def make_prior(dims, dtype): + return tfd.MultivariateNormalDiag( + loc=tf.zeros(dims, dtype)) + + def make_likelihood(weights, x): + return tfd.MultivariateNormalDiag( + loc=tf.tensordot(weights, x, axes=[[0], [-1]])) + # Run 100 AIS chains in parallel - initial_beta = tf.random_normal([100, x.shape[1]]) - w, beta_samples, _ = hmc.ais_chain(1000, 0.1, 2, initial_beta, - log_joint_partial, proposal_log_prob, - event_dims=[1]) - log_normalizer_estimate = tf.reduce_logsumexp(w) - np.log(100) + num_chains = 100 + dims = 10 + dtype = np.float32 + + # Make training data. + x = np.random.randn(num_chains, dims).astype(dtype) + true_weights = np.random.randn(dims).astype(dtype) + y = np.dot(x, true_weights) + np.random.randn(num_chains) + + # Setup model. + prior = make_prior(dims, dtype) + def target_log_prob_fn(weights): + return prior.log_prob(weights) + make_likelihood(weights, x).log_prob(y) + + proposal = tfd.MultivariateNormalDiag( + loc=tf.zeros(dims, dtype)) + + weight_samples, ais_weights, kernel_results = ( + hmc.sample_annealed_importance_chain( + num_steps=1000, + proposal_log_prob_fn=proposal.log_prob, + target_log_prob_fn=target_log_prob_fn, + current_state=tf.zeros([num_chains, dims], dtype), + step_size=0.1, + num_leapfrog_steps=2)) + log_normalizer_estimate = (tf.reduce_logsumexp(ais_weights) + - np.log(num_chains)) ``` + + Args: + proposal_log_prob_fn: Python callable that returns the log density of the + initial distribution. + num_steps: Integer number of Markov chain updates to run.
More + iterations means more expense, but smoother annealing between q + and p, which in turn means exponentially lower variance for the + normalizing constant estimator. + target_log_prob_fn: Python callable which takes an argument like + `current_state` (or `*current_state` if it's a list) and returns its + (possibly unnormalized) log-density under the target distribution. + current_state: `Tensor` or Python `list` of `Tensor`s representing the + current state(s) of the Markov chain(s). The first `r` dimensions index + independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`. + step_size: `Tensor` or Python `list` of `Tensor`s representing the step size + for the leapfrog integrator. Must broadcast with the shape of + `current_state`. Larger step sizes lead to faster progress, but too-large + step sizes make rejection exponentially more likely. When possible, it's + often helpful to match per-variable step sizes to the standard deviations + of the target distribution in each variable. + num_leapfrog_steps: Integer number of steps to run the leapfrog integrator + for. Total progress per HMC step is roughly proportional to `step_size * + num_leapfrog_steps`. + seed: Python integer to seed the random number generator. + name: Python `str` name prefixed to Ops created by this function. + Default value: `None` (i.e., "hmc_sample_annealed_importance_chain"). + + Returns: + accepted_state: `Tensor` or Python list of `Tensor`s representing the + state(s) of the Markov chain(s) at the final iteration. Has same shape as + input `current_state`. + ais_weights: Tensor with the estimated weight(s). Has shape matching + `target_log_prob_fn(current_state)`. + kernel_results: `collections.namedtuple` of internal calculations used to + advance the chain. 
""" - with ops.name_scope(name, 'hmc_ais_chain', - [n_iterations, step_size, n_leapfrog_steps, initial_x]): - non_event_shape = array_ops.shape(target_log_prob_fn(initial_x)) - - beta_series = math_ops.linspace(0., 1., n_iterations+1)[1:] - def _body(a, beta): # pylint: disable=missing-docstring - def log_prob_beta(x): - return ((1 - beta) * proposal_log_prob_fn(x) + - beta * target_log_prob_fn(x)) - last_x = a[0] - w = a[2] - w += (1. / n_iterations) * (target_log_prob_fn(last_x) - - proposal_log_prob_fn(last_x)) - # TODO(b/66917083): There's an opportunity for gradient reuse here. - updated_x, acceptance_probs, _, _ = kernel(step_size, n_leapfrog_steps, - last_x, log_prob_beta, - event_dims) - return updated_x, acceptance_probs, w - - x, acceptance_probs, w = functional_ops.scan( - _body, beta_series, (initial_x, array_ops.zeros(non_event_shape), - array_ops.zeros(non_event_shape))) - return w[-1], x[-1], acceptance_probs[-1] - - -def kernel(step_size, n_leapfrog_steps, x, target_log_prob_fn, event_dims=(), - x_log_prob=None, x_grad=None, name=None): + def make_convex_combined_log_prob_fn(iter_): + def _fn(*args): + p = proposal_log_prob_fn(*args) + t = target_log_prob_fn(*args) + dtype = p.dtype.base_dtype + beta = (math_ops.cast(iter_ + 1, dtype) + / math_ops.cast(num_steps, dtype)) + return (1. 
- beta) * p + beta * t + return _fn + + with ops.name_scope( + name, "hmc_sample_annealed_importance_chain", + [num_steps, current_state, step_size, num_leapfrog_steps, seed]): + with ops.name_scope("initialize"): + [ + current_state, + step_size, + current_log_prob, + current_grads_log_prob, + ] = _prepare_args( + make_convex_combined_log_prob_fn(iter_=0), + current_state, + step_size, + description="convex_combined_log_prob") + num_steps = ops.convert_to_tensor( + num_steps, + dtype=dtypes.int32, + name="num_steps") + num_leapfrog_steps = ops.convert_to_tensor( + num_leapfrog_steps, + dtype=dtypes.int32, + name="num_leapfrog_steps") + def _loop_body(iter_, ais_weights, current_state, kernel_results): + """Closure which implements `tf.while_loop` body.""" + current_state_parts = (list(current_state) + if _is_list_like(current_state) + else [current_state]) + # TODO(b/72994218): Consider refactoring things to avoid this unecessary + # call. + ais_weights += ((target_log_prob_fn(*current_state_parts) + - proposal_log_prob_fn(*current_state_parts)) + / math_ops.cast(num_steps, ais_weights.dtype)) + return [iter_ + 1, ais_weights] + list(kernel( + make_convex_combined_log_prob_fn(iter_), + current_state, + step_size, + num_leapfrog_steps, + seed, + kernel_results.current_target_log_prob, + kernel_results.current_grads_target_log_prob)) + + while_loop_kwargs = dict( + cond=lambda iter_, *args: iter_ < num_steps, + body=_loop_body, + loop_vars=[ + np.int32(0), # iter_ + array_ops.zeros_like(current_log_prob), # ais_weights + current_state, + _make_dummy_kernel_results(current_state, + current_log_prob, + current_grads_log_prob), + ]) + if seed is not None: + while_loop_kwargs["parallel_iterations"] = 1 + + [ais_weights, current_state, kernel_results] = control_flow_ops.while_loop( + **while_loop_kwargs)[1:] # Lop-off "iter_". 
+ + return [current_state, ais_weights, kernel_results] + + +def kernel(target_log_prob_fn, + current_state, + step_size, + num_leapfrog_steps, + seed=None, + current_target_log_prob=None, + current_grads_target_log_prob=None, + name=None): """Runs one iteration of Hamiltonian Monte Carlo. Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC) @@ -312,324 +563,623 @@ def kernel(step_size, n_leapfrog_steps, x, target_log_prob_fn, event_dims=(), a Metropolis proposal. This function applies one step of HMC to randomly update the variable `x`. - This function can update multiple chains in parallel. It assumes - that all dimensions of `x` not specified in `event_dims` are - independent, and should therefore be updated independently. The - output of `target_log_prob_fn()` should sum log-probabilities across - all event dimensions. Slices along dimensions not in `event_dims` - may have different target distributions; for example, if - `event_dims == (1,)`, then `x[0, :]` could have a different target - distribution from x[1, :]. This is up to `target_log_prob_fn()`. - - Args: - step_size: Scalar step size or array of step sizes for the - leapfrog integrator. Broadcasts to the shape of - `x`. Larger step sizes lead to faster progress, but - too-large step sizes make rejection exponentially more likely. - When possible, it's often helpful to match per-variable step - sizes to the standard deviations of the target distribution in - each variable. - n_leapfrog_steps: Integer number of steps to run the leapfrog - integrator for. Total progress per HMC step is roughly - proportional to step_size * n_leapfrog_steps. - x: Tensor containing the value(s) of the random variable(s) to update. - target_log_prob_fn: Python callable which takes an argument like `initial_x` - and returns its (possibly unnormalized) log-density under the target - distribution. - event_dims: List of dimensions that should not be treated as - independent. 
This allows for multiple chains to be run independently - in parallel. Default is (), i.e., all dimensions are independent. - x_log_prob (optional): Tensor containing the cached output of a previous - call to `target_log_prob_fn()` evaluated at `x` (such as that provided by - a previous call to `kernel()`). Providing `x_log_prob` and - `x_grad` saves one gradient computation per call to `kernel()`. - x_grad (optional): Tensor containing the cached gradient of - `target_log_prob_fn()` evaluated at `x` (such as that provided by - a previous call to `kernel()`). Providing `x_log_prob` and - `x_grad` saves one gradient computation per call to `kernel()`. - name: Python `str` name prefixed to Ops created by this function. - - Returns: - updated_x: The updated variable(s) x. Has shape matching `initial_x`. - acceptance_probs: Tensor with the acceptance probabilities for the final - iteration. This is useful for diagnosing step size problems etc. Has - shape matching `target_log_prob_fn(initial_x)`. - new_log_prob: The value of `target_log_prob_fn()` evaluated at `updated_x`. - new_grad: The value of the gradient of `target_log_prob_fn()` evaluated at - `updated_x`. + This function can update multiple chains in parallel. It assumes that all + leftmost dimensions of `current_state` index independent chain states (and are + therefore updated independently). The output of `target_log_prob_fn()` should + sum log-probabilities across all event dimensions. Slices along the rightmost + dimensions may have different target distributions; for example, + `current_state[0, :]` could have a different target distribution from + `current_state[1, :]`. This is up to `target_log_prob_fn()`. (The number of + independent chains is `tf.size(target_log_prob_fn(*current_state))`.) #### Examples: + ##### Simple chain with warm-up. 
+ ```python + tfd = tf.contrib.distributions + # Tuning acceptance rates: + dtype = np.float32 target_accept_rate = 0.631 - def target_log_prob(x): - # Standard normal - return tf.reduce_sum(-0.5 * tf.square(x)) - initial_x = tf.zeros([10]) - initial_log_prob = target_log_prob(initial_x) - initial_grad = tf.gradients(initial_log_prob, initial_x)[0] - # Algorithm state - x = tf.Variable(initial_x, name='x') - step_size = tf.Variable(1., name='step_size') - last_log_prob = tf.Variable(initial_log_prob, name='last_log_prob') - last_grad = tf.Variable(initial_grad, name='last_grad') - # Compute updates - new_x, acceptance_prob, log_prob, grad = hmc.kernel(step_size, 3, x, - target_log_prob, - event_dims=[0], - x_log_prob=last_log_prob) - x_update = tf.assign(x, new_x) - log_prob_update = tf.assign(last_log_prob, log_prob) - grad_update = tf.assign(last_grad, grad) - step_size_update = tf.assign(step_size, - tf.where(acceptance_prob > target_accept_rate, - step_size * 1.01, step_size / 1.01)) - adaptive_updates = [x_update, log_prob_update, grad_update, step_size_update] - sampling_updates = [x_update, log_prob_update, grad_update] - - sess = tf.Session() - sess.run(tf.global_variables_initializer()) + num_warmup_iter = 500 + num_chain_iter = 500 + + x = tf.get_variable(name="x", initializer=dtype(1)) + step_size = tf.get_variable(name="step_size", initializer=dtype(1)) + + target = tfd.Normal(loc=dtype(0), scale=dtype(1)) + + new_x, other_results = hmc.kernel( + target_log_prob_fn=target.log_prob, + current_state=x, + step_size=step_size, + num_leapfrog_steps=3)[:4] + + x_update = x.assign(new_x) + + step_size_update = step_size.assign_add( + step_size * tf.where( + other_results.acceptance_probs > target_accept_rate, + 0.01, -0.01)) + + warmup = tf.group([x_update, step_size_update]) + + tf.global_variables_initializer().run() + + sess.graph.finalize() # No more graph building. 
+ # Warm up the sampler and adapt the step size - for i in xrange(500): - sess.run(adaptive_updates) + for _ in xrange(num_warmup_iter): + sess.run(warmup) + # Collect samples without adapting step size - samples = np.zeros([500, 10]) - for i in xrange(500): - x_val, _ = sess.run([new_x, sampling_updates]) - samples[i] = x_val + samples = np.zeros([num_chain_iter]) + for i in xrange(num_chain_iter): + _, x_, target_log_prob_, grad_ = sess.run([ + x_update, + x, + other_results.current_target_log_prob, + other_results.current_grads_target_log_prob]) + samples[i] = x_ + + print(samples.mean(), samples.std()) + ``` + + ##### Sample from more complicated posterior. + + I.e., + + ```none + W ~ MVN(loc=0, scale=sigma * eye(dims)) + for i=1...num_samples: + X[i] ~ MVN(loc=0, scale=eye(dims)) + eps[i] ~ Normal(loc=0, scale=1) + Y[i] = X[i].T * W + eps[i] ``` ```python - # Empirical-Bayes estimation of a hyperparameter by MCMC-EM: - - # Problem setup - N = 150 - D = 10 - x = np.random.randn(N, D).astype(np.float32) - true_sigma = 0.5 - true_beta = true_sigma * np.random.randn(D).astype(np.float32) - y = x.dot(true_beta) + np.random.randn(N).astype(np.float32) - - def log_prior(beta, log_sigma): - return tf.reduce_sum(-0.5 / tf.exp(2 * log_sigma) * tf.square(beta) - - log_sigma) - def regression_log_joint(beta, log_sigma, x, y): - # This function returns log p(beta | log_sigma) + log p(y | x, beta).
- means = tf.matmul(tf.expand_dims(beta, 0), x, transpose_b=True) - means = tf.squeeze(means) - log_likelihood = tf.reduce_sum(-0.5 * tf.square(y - means)) - return log_prior(beta, log_sigma) + log_likelihood - def log_joint_partial(beta): - return regression_log_joint(beta, log_sigma, x, y) - # Our estimate of log(sigma) - log_sigma = tf.Variable(0., name='log_sigma') - # The state of the Markov chain - beta = tf.Variable(tf.random_normal([x.shape[1]]), name='beta') - new_beta, _, _, _ = hmc.kernel(0.1, 5, beta, log_joint_partial, - event_dims=[0]) - beta_update = tf.assign(beta, new_beta) + tfd = tf.contrib.distributions + + def make_training_data(num_samples, dims, sigma): + dt = np.asarray(sigma).dtype + zeros = tf.zeros(dims, dtype=dt) + x = tfd.MultivariateNormalDiag( + loc=zeros).sample(num_samples, seed=1) + w = tfd.MultivariateNormalDiag( + loc=zeros, + scale_identity_multiplier=sigma).sample(seed=2) + noise = tfd.Normal( + loc=dt(0), + scale=dt(1)).sample(num_samples, seed=3) + y = tf.tensordot(x, w, axes=[[1], [0]]) + noise + return y, x, w + + def make_prior(sigma, dims): + # p(w | sigma) + return tfd.MultivariateNormalDiag( + loc=tf.zeros([dims], dtype=sigma.dtype), + scale_identity_multiplier=sigma) + + def make_likelihood(x, w): + # p(y | x, w) + return tfd.MultivariateNormalDiag( + loc=tf.tensordot(x, w, axes=[[1], [0]])) + + # Setup assumptions. + dtype = np.float32 + num_samples = 150 + dims = 10 + num_iters = int(5e3) + + true_sigma = dtype(0.5) + y, x, true_weights = make_training_data(num_samples, dims, true_sigma) + + # Estimate of `log(true_sigma)`. + log_sigma = tf.get_variable(name="log_sigma", initializer=dtype(0)) + sigma = tf.exp(log_sigma) + + # State of the Markov chain. 
+ weights = tf.get_variable( + name="weights", + initializer=np.random.randn(dims).astype(dtype)) + + prior = make_prior(sigma, dims) + + def joint_log_prob_fn(w): + # f(w) = log p(w, y | x) + return prior.log_prob(w) + make_likelihood(x, w).log_prob(y) + + weights_update = weights.assign( + hmc.kernel(target_log_prob_fn=joint_log_prob_fn, + current_state=weights, + step_size=0.1, + num_leapfrog_steps=5)[0]) + + with tf.control_dependencies([weights_update]): + loss = -prior.log_prob(weights) + optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01) - with tf.control_dependencies([beta_update]): - log_sigma_update = optimizer.minimize(-log_prior(beta, log_sigma), - var_list=[log_sigma]) - - sess = tf.Session() - sess.run(tf.global_variables_initializer()) - log_sigma_history = np.zeros(1000) - for i in xrange(1000): - log_sigma_val, _ = sess.run([log_sigma, log_sigma_update]) - log_sigma_history[i] = log_sigma_val - # Should converge to something close to true_sigma - plt.plot(np.exp(log_sigma_history)) + log_sigma_update = optimizer.minimize(loss, var_list=[log_sigma]) + + sess.graph.finalize() # No more graph building. + + tf.global_variables_initializer().run() + + sigma_history = np.zeros(num_iters, dtype) + weights_history = np.zeros([num_iters, dims], dtype) + + for i in xrange(num_iters): + _, sigma_, weights_ = sess.run([log_sigma_update, sigma, weights]) + weights_history[i, :] = weights_ + sigma_history[i] = sigma_ + + true_weights_ = sess.run(true_weights) + + # Should converge to something close to true_sigma.
+ plt.plot(sigma_history); + plt.ylabel("sigma"); + plt.xlabel("iteration"); ``` - """ - with ops.name_scope(name, 'hmc_kernel', [step_size, n_leapfrog_steps, x]): - potential_and_grad = _make_potential_and_grad(target_log_prob_fn) - - x_shape = array_ops.shape(x) - m = random_ops.random_normal(x_shape) - - kinetic_0 = 0.5 * math_ops.reduce_sum(math_ops.square(m), event_dims) - - if (x_log_prob is not None) and (x_grad is not None): - log_potential_0, grad_0 = -x_log_prob, -x_grad # pylint: disable=invalid-unary-operand-type - else: - if x_log_prob is not None: - logging.warn('x_log_prob was provided, but x_grad was not,' - ' so x_log_prob was not used.') - if x_grad is not None: - logging.warn('x_grad was provided, but x_log_prob was not,' - ' so x_grad was not used.') - log_potential_0, grad_0 = potential_and_grad(x) - - new_x, new_m, log_potential_1, grad_1 = leapfrog_integrator( - step_size, n_leapfrog_steps, x, m, potential_and_grad, grad_0) - - kinetic_1 = 0.5 * math_ops.reduce_sum(math_ops.square(new_m), event_dims) - - # TODO(mhoffman): It seems like there may be an opportunity for nans here. - # I'm delaying addressing this because we're going to refactor this part - # to use the more general Metropolis abstraction anyway. - acceptance_probs = math_ops.exp(math_ops.minimum(0., log_potential_0 - - log_potential_1 + - kinetic_0 - kinetic_1)) - accepted = math_ops.cast( - random_ops.random_uniform(array_ops.shape(acceptance_probs)) < - acceptance_probs, np.float32) - new_log_prob = (-log_potential_0 * (1. - accepted) - - log_potential_1 * accepted) - - # TODO(b/65738010): This should work, but it doesn't for now. - # reduced_shape = math_ops.reduced_shape(x_shape, event_dims) - reduced_shape = array_ops.shape(math_ops.reduce_sum(x, event_dims, - keep_dims=True)) - accepted = array_ops.reshape(accepted, reduced_shape) - new_x = x * (1. - accepted) + new_x * accepted - new_grad = -grad_0 * (1. 
- accepted) - grad_1 * accepted - - return new_x, acceptance_probs, new_log_prob, new_grad - - -def leapfrog_integrator(step_size, n_steps, initial_position, initial_momentum, - potential_and_grad, initial_grad, name=None): - """Applies `n_steps` steps of the leapfrog integrator. - - This just wraps `leapfrog_step()` in a `tf.while_loop()`, reusing - gradient computations where possible. Args: - step_size: Scalar step size or array of step sizes for the - leapfrog integrator. Broadcasts to the shape of - `initial_position`. Larger step sizes lead to faster progress, but - too-large step sizes lead to larger discretization error and - worse energy conservation. - n_steps: Number of steps to run the leapfrog integrator. - initial_position: Tensor containing the value(s) of the position variable(s) - to update. - initial_momentum: Tensor containing the value(s) of the momentum variable(s) - to update. - potential_and_grad: Python callable that takes a position tensor like - `initial_position` and returns the potential energy and its gradient at - that position. - initial_grad: Tensor with the value of the gradient of the potential energy - at `initial_position`. + target_log_prob_fn: Python callable which takes an argument like + `current_state` (or `*current_state` if it's a list) and returns its + (possibly unnormalized) log-density under the target distribution. + current_state: `Tensor` or Python `list` of `Tensor`s representing the + current state(s) of the Markov chain(s). The first `r` dimensions index + independent chains, `r = tf.rank(target_log_prob_fn(*current_state))`. + step_size: `Tensor` or Python `list` of `Tensor`s representing the step size + for the leapfrog integrator. Must broadcast with the shape of + `current_state`. Larger step sizes lead to faster progress, but too-large + step sizes make rejection exponentially more likely. 
When possible, it's + often helpful to match per-variable step sizes to the standard deviations + of the target distribution in each variable. + num_leapfrog_steps: Integer number of steps to run the leapfrog integrator + for. Total progress per HMC step is roughly proportional to `step_size * + num_leapfrog_steps`. + seed: Python integer to seed the random number generator. + current_target_log_prob: (Optional) `Tensor` representing the value of + `target_log_prob_fn` at the `current_state`. The only reason to + specify this argument is to reduce TF graph size. + Default value: `None` (i.e., compute as needed). + current_grads_target_log_prob: (Optional) Python list of `Tensor`s + representing gradient of `current_target_log_prob` at the `current_state` + and wrt the `current_state`. Must have same shape as `current_state`. The + only reason to specify this argument is to reduce TF graph size. + Default value: `None` (i.e., compute as needed). name: Python `str` name prefixed to Ops created by this function. + Default value: `None` (i.e., "hmc_kernel"). Returns: - updated_position: Updated value of the position. - updated_momentum: Updated value of the momentum. - new_potential: Potential energy of the new position. Has shape matching - `potential_and_grad(initial_position)`. - new_grad: Gradient from potential_and_grad() evaluated at the new position. - Has shape matching `initial_position`. - - Example: Simple quadratic potential. 
- ```python - def potential_and_grad(position): - return tf.reduce_sum(0.5 * tf.square(position)), position - position = tf.placeholder(np.float32) - momentum = tf.placeholder(np.float32) - potential, grad = potential_and_grad(position) - new_position, new_momentum, new_potential, new_grad = hmc.leapfrog_integrator( - 0.1, 3, position, momentum, potential_and_grad, grad) - - sess = tf.Session() - position_val = np.random.randn(10) - momentum_val = np.random.randn(10) - potential_val, grad_val = sess.run([potential, grad], - {position: position_val}) - positions = np.zeros([100, 10]) - for i in xrange(100): - position_val, momentum_val, potential_val, grad_val = sess.run( - [new_position, new_momentum, new_potential, new_grad], - {position: position_val, momentum: momentum_val}) - positions[i] = position_val - # Should trace out sinusoidal dynamics. - plt.plot(positions[:, 0]) - ``` + accepted_state: Tensor or Python list of `Tensor`s representing the state(s) + of the Markov chain(s) at each result step. Has same shape as + `current_state`. + kernel_results: `collections.namedtuple` of internal calculations used to + advance the chain. + + Raises: + ValueError: if there isn't one `step_size` or a list with same length as + `current_state`. 
""" - def leapfrog_wrapper(step_size, x, m, grad, l): - x, m, _, grad = leapfrog_step(step_size, x, m, potential_and_grad, grad) - return step_size, x, m, grad, l + 1 + with ops.name_scope( + name, "hmc_kernel", + [current_state, step_size, num_leapfrog_steps, seed, + current_target_log_prob, current_grads_target_log_prob]): + with ops.name_scope("initialize"): + [current_state_parts, step_sizes, current_target_log_prob, + current_grads_target_log_prob] = _prepare_args( + target_log_prob_fn, current_state, step_size, + current_target_log_prob, current_grads_target_log_prob, + maybe_expand=True) + independent_chain_ndims = distributions_util.prefer_static_rank( + current_target_log_prob) + current_momentums = [] + for s in current_state_parts: + current_momentums.append(random_ops.random_normal( + shape=array_ops.shape(s), + dtype=s.dtype.base_dtype, + seed=seed)) + seed = distributions_util.gen_new_seed( + seed, salt="hmc_kernel_momentums") + + num_leapfrog_steps = ops.convert_to_tensor( + num_leapfrog_steps, + dtype=dtypes.int32, + name="num_leapfrog_steps") + [ + proposed_momentums, + proposed_state_parts, + proposed_target_log_prob, + proposed_grads_target_log_prob, + ] = _leapfrog_integrator(current_momentums, + target_log_prob_fn, + current_state_parts, + step_sizes, + num_leapfrog_steps, + current_target_log_prob, + current_grads_target_log_prob) + + energy_change = _compute_energy_change(current_target_log_prob, + current_momentums, + proposed_target_log_prob, + proposed_momentums, + independent_chain_ndims) + + # u < exp(min(-energy, 0)), where u~Uniform[0,1) + # ==> -log(u) >= max(e, 0) + # ==> -log(u) >= e + # (Perhaps surprisingly, we don't have a better way to obtain a random + # uniform from positive reals, i.e., `tf.random_uniform(minval=0, + # maxval=np.inf)` won't work.) 
+ random_uniform = random_ops.random_uniform( + shape=array_ops.shape(energy_change), + dtype=energy_change.dtype, + seed=seed) + random_positive = -math_ops.log(random_uniform) + is_accepted = random_positive >= energy_change + + accepted_target_log_prob = array_ops.where(is_accepted, + proposed_target_log_prob, + current_target_log_prob) + + accepted_state_parts = [_choose(is_accepted, + proposed_state_part, + current_state_part, + independent_chain_ndims) + for current_state_part, proposed_state_part + in zip(current_state_parts, proposed_state_parts)] + + accepted_grads_target_log_prob = [ + _choose(is_accepted, + proposed_grad, + grad, + independent_chain_ndims) + for proposed_grad, grad + in zip(proposed_grads_target_log_prob, current_grads_target_log_prob)] + + maybe_flatten = lambda x: x if _is_list_like(current_state) else x[0] + return [ + maybe_flatten(accepted_state_parts), + KernelResults( + acceptance_probs=math_ops.exp(math_ops.minimum(-energy_change, 0.)), + current_grads_target_log_prob=accepted_grads_target_log_prob, + current_target_log_prob=accepted_target_log_prob, + energy_change=energy_change, + is_accepted=is_accepted, + proposed_grads_target_log_prob=proposed_grads_target_log_prob, + proposed_state=maybe_flatten(proposed_state_parts), + proposed_target_log_prob=proposed_target_log_prob, + random_positive=random_positive, + ), + ] + + +def _leapfrog_integrator(current_momentums, + target_log_prob_fn, + current_state_parts, + step_sizes, + num_leapfrog_steps, + current_target_log_prob=None, + current_grads_target_log_prob=None, + name=None): + """Applies `num_leapfrog_steps` of the leapfrog integrator. + + Assumes a simple quadratic kinetic energy function: `0.5 ||momentum||**2`. 
- def counter_fn(a, b, c, d, counter): # pylint: disable=unused-argument - return counter < n_steps + #### Examples: - with ops.name_scope(name, 'leapfrog_integrator', - [step_size, n_steps, initial_position, initial_momentum, - initial_grad]): - _, new_x, new_m, new_grad, _ = control_flow_ops.while_loop( - counter_fn, leapfrog_wrapper, [step_size, initial_position, - initial_momentum, initial_grad, - array_ops.constant(0)], back_prop=False) - # We're counting on the runtime to eliminate this redundant computation. - new_potential, new_grad = potential_and_grad(new_x) - return new_x, new_m, new_potential, new_grad + ##### Simple quadratic potential. + ```python + tfd = tf.contrib.distributions -def leapfrog_step(step_size, position, momentum, potential_and_grad, grad, - name=None): - """Applies one step of the leapfrog integrator. + dims = 10 + num_iter = int(1e3) + dtype = np.float32 - Assumes a simple quadratic kinetic energy function: 0.5 * ||momentum||^2. + position = tf.placeholder(np.float32) + momentum = tf.placeholder(np.float32) + + [ + new_momentums, + new_positions, + ] = hmc._leapfrog_integrator( + current_momentums=[momentum], + target_log_prob_fn=tfd.MultivariateNormalDiag( + loc=tf.zeros(dims, dtype)).log_prob, + current_state_parts=[position], + step_sizes=[0.1], + num_leapfrog_steps=3)[:2] + + sess.graph.finalize() # No more graph building. + + momentum_ = np.random.randn(dims).astype(dtype) + position_ = np.random.randn(dims).astype(dtype) + + positions = np.zeros([num_iter, dims], dtype) + for i in xrange(num_iter): + position_, momentum_ = sess.run( + [new_positions[0], new_momentums[0]], + feed_dict={position: position_, momentum: momentum_}) + positions[i] = position_ + + plt.plot(positions[:, 0]); # Sinusoidal. + ``` Args: - step_size: Scalar step size or array of step sizes for the - leapfrog integrator. Broadcasts to the shape of - `position`.
Larger step sizes lead to faster progress, but - too-large step sizes lead to larger discretization error and - worse energy conservation. - position: Tensor containing the value(s) of the position variable(s) - to update. - momentum: Tensor containing the value(s) of the momentum variable(s) - to update. - potential_and_grad: Python callable that takes a position tensor like - `position` and returns the potential energy and its gradient at that - position. - grad: Tensor with the value of the gradient of the potential energy - at `position`. + current_momentums: Tensor containing the value(s) of the momentum + variable(s) to update. + target_log_prob_fn: Python callable which takes an argument like + `*current_state_parts` and returns its (possibly unnormalized) log-density + under the target distribution. + current_state_parts: Python `list` of `Tensor`s representing the current + state(s) of the Markov chain(s). The first `independent_chain_ndims` of + the `Tensor`(s) index different chains. + step_sizes: Python `list` of `Tensor`s representing the step size for the + leapfrog integrator. Must broadcast with the shape of + `current_state_parts`. Larger step sizes lead to faster progress, but + too-large step sizes make rejection exponentially more likely. When + possible, it's often helpful to match per-variable step sizes to the + standard deviations of the target distribution in each variable. + num_leapfrog_steps: Integer number of steps to run the leapfrog integrator + for. Total progress per HMC step is roughly proportional to `step_size * + num_leapfrog_steps`. + current_target_log_prob: (Optional) `Tensor` representing the value of + `target_log_prob_fn(*current_state_parts)`. The only reason to specify + this argument is to reduce TF graph size. + Default value: `None` (i.e., compute as needed). 
+ current_grads_target_log_prob: (Optional) Python list of `Tensor`s + representing gradient of `target_log_prob_fn(*current_state_parts`) wrt + `current_state_parts`. Must have same shape as `current_state_parts`. The + only reason to specify this argument is to reduce TF graph size. + Default value: `None` (i.e., compute as needed). name: Python `str` name prefixed to Ops created by this function. + Default value: `None` (i.e., "hmc_leapfrog_integrator"). Returns: - updated_position: Updated value of the position. - updated_momentum: Updated value of the momentum. - new_potential: Potential energy of the new position. Has shape matching - `potential_and_grad(position)`. - new_grad: Gradient from potential_and_grad() evaluated at the new position. - Has shape matching `position`. - - Example: Simple quadratic potential. - ```python - def potential_and_grad(position): - # Simple quadratic potential - return tf.reduce_sum(0.5 * tf.square(position)), position - position = tf.placeholder(np.float32) - momentum = tf.placeholder(np.float32) - potential, grad = potential_and_grad(position) - new_position, new_momentum, new_potential, new_grad = hmc.leapfrog_step( - 0.1, position, momentum, potential_and_grad, grad) - - sess = tf.Session() - position_val = np.random.randn(10) - momentum_val = np.random.randn(10) - potential_val, grad_val = sess.run([potential, grad], - {position: position_val}) - positions = np.zeros([100, 10]) - for i in xrange(100): - position_val, momentum_val, potential_val, grad_val = sess.run( - [new_position, new_momentum, new_potential, new_grad], - {position: position_val, momentum: momentum_val}) - positions[i] = position_val - # Should trace out sinusoidal dynamics. - plt.plot(positions[:, 0]) - ``` + proposed_momentums: Updated value of the momentum. + proposed_state_parts: Tensor or Python list of `Tensor`s representing the + state(s) of the Markov chain(s) at each result step. Has same shape as + input `current_state_parts`. 
+ proposed_target_log_prob: `Tensor` representing the value of + `target_log_prob_fn` at `accepted_state`. + proposed_grads_target_log_prob: Gradient of `proposed_target_log_prob` wrt + `accepted_state`. + + Raises: + ValueError: if `len(momentums) != len(state_parts)`. + ValueError: if `len(state_parts) != len(step_sizes)`. + ValueError: if `len(state_parts) != len(grads_target_log_prob)`. + TypeError: if `not target_log_prob.dtype.is_floating`. """ - with ops.name_scope(name, 'leapfrog_step', [step_size, position, momentum, - grad]): - momentum -= 0.5 * step_size * grad - position += step_size * momentum - potential, grad = potential_and_grad(position) - momentum -= 0.5 * step_size * grad - - return position, momentum, potential, grad + def _loop_body(step, + current_momentums, + current_state_parts, + ignore_current_target_log_prob, # pylint: disable=unused-argument + current_grads_target_log_prob): + return [step + 1] + list(_leapfrog_step(current_momentums, + target_log_prob_fn, + current_state_parts, + step_sizes, + current_grads_target_log_prob)) + + with ops.name_scope( + name, "hmc_leapfrog_integrator", + [current_momentums, current_state_parts, step_sizes, num_leapfrog_steps, + current_target_log_prob, current_grads_target_log_prob]): + if len(current_momentums) != len(current_state_parts): + raise ValueError("`momentums` must be in one-to-one correspondence " + "with `state_parts`") + num_leapfrog_steps = ops.convert_to_tensor(num_leapfrog_steps, + name="num_leapfrog_steps") + current_target_log_prob, current_grads_target_log_prob = ( + _maybe_call_fn_and_grads( + target_log_prob_fn, + current_state_parts, + current_target_log_prob, + current_grads_target_log_prob)) + return control_flow_ops.while_loop( + cond=lambda iter_, *args: iter_ < num_leapfrog_steps, + body=_loop_body, + loop_vars=[ + np.int32(0), # iter_ + current_momentums, + current_state_parts, + current_target_log_prob, + current_grads_target_log_prob, + ], + back_prop=False)[1:] # Lop-off 
"iter_". + + +def _leapfrog_step(current_momentums, + target_log_prob_fn, + current_state_parts, + step_sizes, + current_grads_target_log_prob, + name=None): + """Applies one step of the leapfrog integrator.""" + with ops.name_scope( + name, "_leapfrog_step", + [current_momentums, current_state_parts, step_sizes, + current_grads_target_log_prob]): + proposed_momentums = [m + 0.5 * ss * g for m, ss, g + in zip(current_momentums, + step_sizes, + current_grads_target_log_prob)] + proposed_state_parts = [x + ss * m for x, ss, m + in zip(current_state_parts, + step_sizes, + proposed_momentums)] + proposed_target_log_prob = target_log_prob_fn(*proposed_state_parts) + if not proposed_target_log_prob.dtype.is_floating: + raise TypeError("`target_log_prob_fn` must produce a `Tensor` " + "with `float` `dtype`.") + proposed_grads_target_log_prob = gradients_ops.gradients( + proposed_target_log_prob, proposed_state_parts) + if any(g is None for g in proposed_grads_target_log_prob): + raise ValueError( + "Encountered `None` gradient. 
Does your target `target_log_prob_fn` " + "access all `tf.Variable`s via `tf.get_variable`?\n" + " current_state_parts: {}\n" + " proposed_state_parts: {}\n" + " proposed_grads_target_log_prob: {}".format( + current_state_parts, + proposed_state_parts, + proposed_grads_target_log_prob)) + proposed_momentums = [m + 0.5 * ss * g for m, ss, g + in zip(proposed_momentums, + step_sizes, + proposed_grads_target_log_prob)] + return [ + proposed_momentums, + proposed_state_parts, + proposed_target_log_prob, + proposed_grads_target_log_prob, + ] + + +def _compute_energy_change(current_target_log_prob, + current_momentums, + proposed_target_log_prob, + proposed_momentums, + independent_chain_ndims, + name=None): + """Helper to `kernel` which computes the energy change.""" + with ops.name_scope( + name, "compute_energy_change", + ([current_target_log_prob, proposed_target_log_prob, + independent_chain_ndims] + + current_momentums + proposed_momentums)): + # Abbreviate lk0=log_kinetic_energy and lk1=proposed_log_kinetic_energy + # since they're a mouthful and lets us inline more. + lk0, lk1 = [], [] + for current_momentum, proposed_momentum in zip(current_momentums, + proposed_momentums): + axis = math_ops.range(independent_chain_ndims, + array_ops.rank(current_momentum)) + lk0.append(_log_sum_sq(current_momentum, axis)) + lk1.append(_log_sum_sq(proposed_momentum, axis)) + + lk0 = -np.log(2.) + math_ops.reduce_logsumexp(array_ops.stack(lk0, axis=-1), + axis=-1) + lk1 = -np.log(2.) + math_ops.reduce_logsumexp(array_ops.stack(lk1, axis=-1), + axis=-1) + lp0 = -current_target_log_prob # log_potential + lp1 = -proposed_target_log_prob # proposed_log_potential + x = array_ops.stack([lp1, math_ops.exp(lk1), -lp0, -math_ops.exp(lk0)], + axis=-1) + + # The sum is NaN if any element is NaN or we see both +Inf and -Inf. + # Thus we will replace such rows with infinite energy change which implies + # rejection. Recall that float-comparisons with NaN are always False. 
+ is_sum_determinate = ( + math_ops.reduce_all(math_ops.is_finite(x) | (x >= 0.), axis=-1) & + math_ops.reduce_all(math_ops.is_finite(x) | (x <= 0.), axis=-1)) + is_sum_determinate = array_ops.tile( + is_sum_determinate[..., array_ops.newaxis], + multiples=array_ops.concat([ + array_ops.ones(array_ops.rank(is_sum_determinate), + dtype=dtypes.int32), + [4], + ], axis=0)) + x = array_ops.where(is_sum_determinate, + x, + array_ops.fill(array_ops.shape(x), + value=x.dtype.as_numpy_dtype(np.inf))) + + return math_ops.reduce_sum(x, axis=-1) + + +def _choose(is_accepted, + accepted, + rejected, + independent_chain_ndims, + name=None): + """Helper to `kernel` which expand_dims `is_accepted` to apply tf.where.""" + def _expand_is_accepted_like(x): + with ops.name_scope("_choose"): + expand_shape = array_ops.concat([ + array_ops.shape(is_accepted), + array_ops.ones([array_ops.rank(x) - array_ops.rank(is_accepted)], + dtype=dtypes.int32), + ], axis=0) + multiples = array_ops.concat([ + array_ops.ones([array_ops.rank(is_accepted)], dtype=dtypes.int32), + array_ops.shape(x)[independent_chain_ndims:], + ], axis=0) + m = array_ops.tile(array_ops.reshape(is_accepted, expand_shape), + multiples) + m.set_shape(x.shape) + return m + with ops.name_scope(name, "_choose", values=[ + is_accepted, accepted, rejected, independent_chain_ndims]): + return array_ops.where(_expand_is_accepted_like(accepted), + accepted, + rejected) + + +def _maybe_call_fn_and_grads(fn, + fn_arg_list, + fn_result=None, + grads_fn_result=None, + description="target_log_prob"): + """Helper which computes `fn_result` and `grads` if needed.""" + fn_arg_list = (list(fn_arg_list) if _is_list_like(fn_arg_list) + else [fn_arg_list]) + if fn_result is None: + fn_result = fn(*fn_arg_list) + if not fn_result.dtype.is_floating: + raise TypeError("`{}` must be a `Tensor` with `float` `dtype`.".format( + description)) + if grads_fn_result is None: + grads_fn_result = gradients_ops.gradients( + fn_result, fn_arg_list) + if 
len(fn_arg_list) != len(grads_fn_result): + raise ValueError("`{}` must be in one-to-one correspondence with " + "`grads_{}`".format(*[description]*2)) + if any(g is None for g in grads_fn_result): + raise ValueError("Encountered `None` gradient.") + return fn_result, grads_fn_result + + +def _prepare_args(target_log_prob_fn, state, step_size, + target_log_prob=None, grads_target_log_prob=None, + maybe_expand=False, description="target_log_prob"): + """Helper which processes input args to meet list-like assumptions.""" + state_parts = list(state) if _is_list_like(state) else [state] + state_parts = [ops.convert_to_tensor(s, name="state") + for s in state_parts] + target_log_prob, grads_target_log_prob = _maybe_call_fn_and_grads( + target_log_prob_fn, + state_parts, + target_log_prob, + grads_target_log_prob, + description) + step_sizes = list(step_size) if _is_list_like(step_size) else [step_size] + step_sizes = [ + ops.convert_to_tensor( + s, name="step_size", dtype=target_log_prob.dtype) + for s in step_sizes] + if len(step_sizes) == 1: + step_sizes *= len(state_parts) + if len(state_parts) != len(step_sizes): + raise ValueError("There should be exactly one `step_size` or it should " + "have same length as `current_state`.") + maybe_flatten = lambda x: x if maybe_expand or _is_list_like(state) else x[0] + return [ + maybe_flatten(state_parts), + maybe_flatten(step_sizes), + target_log_prob, + grads_target_log_prob, + ] + + +def _is_list_like(x): + """Helper which returns `True` if input is `list`-like.""" + return isinstance(x, (tuple, list)) + + +def _log_sum_sq(x, axis=None): + """Computes log(sum(x**2)).""" + return math_ops.reduce_logsumexp(2. 
* math_ops.log(math_ops.abs(x)), axis) diff --git a/tensorflow/contrib/bayesflow/python/ops/layers.py b/tensorflow/contrib/bayesflow/python/ops/layers.py index dcead38af826a12e776160bdb251ba021e6b953c..a742b7c1aa593d6c08bf9d8d597c99c9fc4e7aed 100644 --- a/tensorflow/contrib/bayesflow/python/ops/layers.py +++ b/tensorflow/contrib/bayesflow/python/ops/layers.py @@ -23,13 +23,43 @@ from __future__ import print_function # go/tf-wildcard-import # pylint: disable=wildcard-import -from tensorflow.contrib.bayesflow.python.ops.layers_dense_variational_impl import * +from tensorflow.contrib.bayesflow.python.ops.layers_conv_variational import * +from tensorflow.contrib.bayesflow.python.ops.layers_dense_variational import * +from tensorflow.contrib.bayesflow.python.ops.layers_util import * # pylint: enable=wildcard-import from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ - 'DenseVariational', - 'dense_variational', + 'Convolution1DReparameterization', + 'Convolution2DReparameterization', + 'Convolution3DReparameterization', + 'Convolution1DFlipout', + 'Convolution2DFlipout', + 'Convolution3DFlipout', + 'Conv1DReparameterization', + 'Conv2DReparameterization', + 'Conv3DReparameterization', + 'Conv1DFlipout', + 'Conv2DFlipout', + 'Conv3DFlipout', + 'convolution1d_reparameterization', + 'convolution2d_reparameterization', + 'convolution3d_reparameterization', + 'convolution1d_flipout', + 'convolution2d_flipout', + 'convolution3d_flipout', + 'conv1d_reparameterization', + 'conv2d_reparameterization', + 'conv3d_reparameterization', + 'conv1d_flipout', + 'conv2d_flipout', + 'conv3d_flipout', + 'DenseReparameterization', + 'DenseLocalReparameterization', + 'DenseFlipout', + 'dense_reparameterization', + 'dense_local_reparameterization', + 'dense_flipout', 'default_loc_scale_fn', 'default_mean_field_normal_fn', ] diff --git a/tensorflow/contrib/bayesflow/python/ops/layers_conv_variational.py 
b/tensorflow/contrib/bayesflow/python/ops/layers_conv_variational.py new file mode 100644 index 0000000000000000000000000000000000000000..7723cfb442712626ff415f1412e3362f2392ce9f --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/ops/layers_conv_variational.py @@ -0,0 +1,2943 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Convolutional variational layer classes and their functional aliases. 
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.bayesflow.python.ops import layers_util +from tensorflow.contrib.distributions.python.ops import independent as independent_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.layers import base as layers_lib +from tensorflow.python.layers import utils +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import standard_ops +from tensorflow.python.ops.distributions import kullback_leibler as kl_lib +from tensorflow.python.ops.distributions import normal as normal_lib +from tensorflow.python.ops.distributions import util as distribution_util + + +class _ConvVariational(layers_lib.Layer): + """Abstract nD convolution layer (private, used as implementation base). + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. It may also include a bias addition and activation function + on the outputs. It assumes the `kernel` and/or `bias` are drawn from + distributions. + + By default, the layer implements a stochastic forward pass via + sampling from the kernel and bias posteriors, + ```none + outputs = f(inputs; kernel, bias), kernel, bias ~ posterior + ``` + where f denotes the layer's calculation. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Arguments: + rank: An integer, the rank of the convolution, e.g. "2" for 2D convolution. + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). 
+ kernel_size: An integer or tuple/list of n integers, specifying the + length of the convolution window. + strides: An integer or tuple/list of n integers, + specifying the stride length of the convolution. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, ..., channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, ...)`. + dilation_rate: An integer or tuple/list of n integers, specifying + the dilation rate to use for dilated convolution. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any `strides` value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + activity_regularizer: Optional regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. 
+ kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + name: A string, the name of the layer. + + Properties: + rank: Python integer, dimensionality of convolution. + filters: Python integer, dimensionality of the output space. + kernel_size: Size of the convolution window. + strides: Stride length of convolution. + padding: Python string describing padding approach. + data_format: Python string describing input data's dimensions. + dilation_rate: Dilation rate for an atrous convolution. + activation: Activation function (`callable`). + activity_regularizer: Regularizer function for the output. + kernel_posterior_fn: `callable` returning posterior. 
+ kernel_posterior_tensor_fn: `callable` operating on posterior. + kernel_prior_fn: `callable` returning prior. + kernel_divergence_fn: `callable` returning divergence. + bias_posterior_fn: `callable` returning posterior. + bias_posterior_tensor_fn: `callable` operating on posterior. + bias_prior_fn: `callable` returning prior. + bias_divergence_fn: `callable` returning divergence. + """ + + def __init__( + self, + rank, + filters, + kernel_size, + strides=1, + padding="valid", + data_format="channels_last", + dilation_rate=1, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + **kwargs): + super(_ConvVariational, self).__init__( + trainable=trainable, + name=name, + activity_regularizer=activity_regularizer, + **kwargs) + self.rank = rank + self.filters = filters + self.kernel_size = utils.normalize_tuple(kernel_size, rank, "kernel_size") + self.strides = utils.normalize_tuple(strides, rank, "strides") + self.padding = utils.normalize_padding(padding) + self.data_format = utils.normalize_data_format(data_format) + self.dilation_rate = utils.normalize_tuple( + dilation_rate, rank, "dilation_rate") + self.activation = activation + self.input_spec = layers_lib.InputSpec(ndim=self.rank + 2) + self.kernel_posterior_fn = kernel_posterior_fn + self.kernel_posterior_tensor_fn = kernel_posterior_tensor_fn + self.kernel_prior_fn = kernel_prior_fn + 
self.kernel_divergence_fn = kernel_divergence_fn + self.bias_posterior_fn = bias_posterior_fn + self.bias_posterior_tensor_fn = bias_posterior_tensor_fn + self.bias_prior_fn = bias_prior_fn + self.bias_divergence_fn = bias_divergence_fn + + def build(self, input_shape): + input_shape = tensor_shape.TensorShape(input_shape) + if self.data_format == "channels_first": + channel_axis = 1 + else: + channel_axis = -1 + if input_shape[channel_axis].value is None: + raise ValueError("The channel dimension of the inputs " + "should be defined. Found `None`.") + input_dim = input_shape[channel_axis].value + kernel_shape = self.kernel_size + (input_dim, self.filters) + dtype = dtypes.as_dtype(self.dtype) + + # Must have a posterior kernel. + self.kernel_posterior = self.kernel_posterior_fn( + dtype, kernel_shape, "kernel_posterior", + self.trainable, self.add_variable) + + if self.kernel_prior_fn is None: + self.kernel_prior = None + else: + self.kernel_prior = self.kernel_prior_fn( + dtype, kernel_shape, "kernel_prior", + self.trainable, self.add_variable) + self._built_kernel_divergence = False + + if self.bias_posterior_fn is None: + self.bias_posterior = None + else: + self.bias_posterior = self.bias_posterior_fn( + dtype, (self.filters,), "bias_posterior", + self.trainable, self.add_variable) + + if self.bias_prior_fn is None: + self.bias_prior = None + else: + self.bias_prior = self.bias_prior_fn( + dtype, (self.filters,), "bias_prior", + self.trainable, self.add_variable) + self._built_bias_divergence = False + + self.input_spec = layers_lib.InputSpec(ndim=self.rank + 2, + axes={channel_axis: input_dim}) + self._convolution_op = nn_ops.Convolution( + input_shape, + filter_shape=tensor_shape.TensorShape(kernel_shape), + dilation_rate=self.dilation_rate, + strides=self.strides, + padding=self.padding.upper(), + data_format=utils.convert_data_format(self.data_format, + self.rank + 2)) + + self.built = True + + def call(self, inputs): + inputs = 
ops.convert_to_tensor(inputs, dtype=self.dtype) + + outputs = self._apply_variational_kernel(inputs) + outputs = self._apply_variational_bias(outputs) + if self.activation is not None: + outputs = self.activation(outputs) + if not self._built_kernel_divergence: + kernel_posterior = self.kernel_posterior + kernel_prior = self.kernel_prior + if isinstance(self.kernel_posterior, independent_lib.Independent): + kernel_posterior = kernel_posterior.distribution + if isinstance(self.kernel_prior, independent_lib.Independent): + kernel_prior = kernel_prior.distribution + self._apply_divergence(self.kernel_divergence_fn, + kernel_posterior, + kernel_prior, + self.kernel_posterior_tensor, + name="divergence_kernel") + self._built_kernel_divergence = True + if not self._built_bias_divergence: + bias_posterior = self.bias_posterior + bias_prior = self.bias_prior + if isinstance(self.bias_posterior, independent_lib.Independent): + bias_posterior = bias_posterior.distribution + if isinstance(self.bias_prior, independent_lib.Independent): + bias_prior = bias_prior.distribution + self._apply_divergence(self.bias_divergence_fn, + bias_posterior, + bias_prior, + self.bias_posterior_tensor, + name="divergence_bias") + self._built_bias_divergence = True + return outputs + + def _apply_variational_bias(self, inputs): + if self.bias_posterior is None: + self.bias_posterior_tensor = None + return inputs + self.bias_posterior_tensor = self.bias_posterior_tensor_fn( + self.bias_posterior) + outputs = inputs + if self.data_format == "channels_first": + if self.rank == 1: + # nn.bias_add does not accept a 1D input tensor. + bias = array_ops.reshape(self.bias_posterior_tensor, + (1, self.filters, 1)) + outputs += bias + if self.rank == 2: + outputs = nn.bias_add(outputs, + self.bias_posterior_tensor, + data_format="NCHW") + if self.rank == 3: + # As of Mar 2017, direct addition is significantly slower than + # bias_add when computing gradients. 
To use bias_add, we collapse Z + # and Y into a single dimension to obtain a 4D input tensor. + outputs_shape = outputs.shape.as_list() + outputs_4d = array_ops.reshape(outputs, + [outputs_shape[0], outputs_shape[1], + outputs_shape[2] * outputs_shape[3], + outputs_shape[4]]) + outputs_4d = nn.bias_add(outputs_4d, + self.bias_posterior_tensor, + data_format="NCHW") + outputs = array_ops.reshape(outputs_4d, outputs_shape) + else: + outputs = nn.bias_add(outputs, + self.bias_posterior_tensor, + data_format="NHWC") + return outputs + + def _apply_divergence(self, divergence_fn, posterior, prior, + posterior_tensor, name): + if (divergence_fn is None or + posterior is None or + prior is None): + divergence = None + return + divergence = standard_ops.identity( + divergence_fn( + posterior, prior, posterior_tensor), + name=name) + self.add_loss(divergence) + + def _compute_output_shape(self, input_shape): + input_shape = tensor_shape.TensorShape(input_shape).as_list() + if self.data_format == "channels_last": + space = input_shape[1:-1] + new_space = [] + for i in range(len(space)): + new_dim = utils.conv_output_length( + space[i], + self.kernel_size[i], + padding=self.padding, + stride=self.strides[i], + dilation=self.dilation_rate[i]) + new_space.append(new_dim) + return tensor_shape.TensorShape([input_shape[0]] + new_space + + [self.filters]) + else: + space = input_shape[2:] + new_space = [] + for i in range(len(space)): + new_dim = utils.conv_output_length( + space[i], + self.kernel_size[i], + padding=self.padding, + stride=self.strides[i], + dilation=self.dilation_rate[i]) + new_space.append(new_dim) + return tensor_shape.TensorShape([input_shape[0], self.filters] + + new_space) + + +class _ConvReparameterization(_ConvVariational): + """Abstract nD convolution layer (private, used as implementation base). + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. 
It may also include a bias addition and activation function + on the outputs. It assumes the `kernel` and/or `bias` are drawn from + distributions. + + By default, the layer implements a stochastic forward pass via + sampling from the kernel and bias posteriors, + ```none + outputs = f(inputs; kernel, bias), kernel, bias ~ posterior + ``` + where f denotes the layer's calculation. It uses the reparameterization + estimator [1], which performs a Monte Carlo approximation of the + distribution integrating over the `kernel` and `bias`. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Arguments: + rank: An integer, the rank of the convolution, e.g. "2" for 2D convolution. + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of n integers, specifying the + length of the convolution window. + strides: An integer or tuple/list of n integers, + specifying the stride length of the convolution. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, ..., channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, ...)`. + dilation_rate: An integer or tuple/list of n integers, specifying + the dilation rate to use for dilated convolution. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any `strides` value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + activity_regularizer: Optional regularizer function for the output. 
+ trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. 
The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + name: A string, the name of the layer. + + Properties: + rank: Python integer, dimensionality of convolution. + filters: Python integer, dimensionality of the output space. + kernel_size: Size of the convolution window. + strides: Stride length of convolution. + padding: Python string describing padding approach. + data_format: Python string describing input data's dimensions. + dilation_rate: Dilation rate for an atrous convolution. + activation: Activation function (`callable`). + activity_regularizer: Regularizer function for the output. + kernel_posterior_fn: `callable` returning posterior. + kernel_posterior_tensor_fn: `callable` operating on posterior. + kernel_prior_fn: `callable` returning prior. + kernel_divergence_fn: `callable` returning divergence. + bias_posterior_fn: `callable` returning posterior. + bias_posterior_tensor_fn: `callable` operating on posterior. + bias_prior_fn: `callable` returning prior. + bias_divergence_fn: `callable` returning divergence. + + [1]: "Auto-Encoding Variational Bayes." + Diederik P. Kingma, Max Welling. + International Conference on Learning Representations, 2014. 
+ """ + + def __init__( + self, + rank, + filters, + kernel_size, + strides=1, + padding="valid", + data_format="channels_last", + dilation_rate=1, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + **kwargs): + super(_ConvReparameterization, self).__init__( + rank=rank, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + name=name, **kwargs) + + def _apply_variational_kernel(self, inputs): + self.kernel_posterior_tensor = self.kernel_posterior_tensor_fn( + self.kernel_posterior) + self.kernel_posterior_affine = None + self.kernel_posterior_affine_tensor = None + outputs = self._convolution_op(inputs, self.kernel_posterior_tensor) + return outputs + + +class Conv1DReparameterization(_ConvReparameterization): + """1D convolution layer (e.g. temporal convolution). 
+ + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. It may also include a bias addition and activation function + on the outputs. It assumes the `kernel` and/or `bias` are drawn from + distributions. + + By default, the layer implements a stochastic forward pass via + sampling from the kernel and bias posteriors, + ```none + outputs = f(inputs; kernel, bias), kernel, bias ~ posterior + ``` + where f denotes the layer's calculation. It uses the reparameterization + estimator [1], which performs a Monte Carlo approximation of the + distribution integrating over the `kernel` and `bias`. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Arguments: + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of a single integer, specifying the + length of the 1D convolution window. + strides: An integer or tuple/list of a single integer, + specifying the stride length of the convolution. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, length, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, length)`. + dilation_rate: An integer or tuple/list of a single integer, specifying + the dilation rate to use for dilated convolution. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any `strides` value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. 
+ activity_regularizer: Optional regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. 
Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + name: A string, the name of the layer. + + Properties: + filters: Python integer, dimensionality of the output space. + kernel_size: Size of the convolution window. + strides: Stride length of convolution. + padding: Python string describing padding approach. + data_format: Python string describing input data's dimensions. + dilation_rate: Dilation rate for an atrous convolution. + activation: Activation function (`callable`). + activity_regularizer: Regularizer function for the output. + kernel_posterior_fn: `callable` returning posterior. + kernel_posterior_tensor_fn: `callable` operating on posterior. + kernel_prior_fn: `callable` returning prior. + kernel_divergence_fn: `callable` returning divergence. + bias_posterior_fn: `callable` returning posterior. + bias_posterior_tensor_fn: `callable` operating on posterior. + bias_prior_fn: `callable` returning prior. + bias_divergence_fn: `callable` returning divergence. + + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`.
+ + ```python + tfp = tf.contrib.bayesflow + + net = tf.reshape(features, [-1, 128, 1]) + net = tfp.layers.Conv1DReparameterization(64, + kernel_size=5, + padding="SAME", + activation=tf.nn.relu)(net) + net = tf.reshape(net, [-1, 128 * 64]) + logits = tfp.layers.DenseReparameterization(10)(net) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses reparameterization gradients to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + + [1]: "Auto-Encoding Variational Bayes." + Diederik P. Kingma, Max Welling. + International Conference on Learning Representations, 2014. 
+ """ + + def __init__( + self, + filters, + kernel_size, + strides=1, + padding="valid", + data_format="channels_last", + dilation_rate=1, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + **kwargs): + super(Conv1DReparameterization, self).__init__( + rank=1, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + name=name, **kwargs) + + +def conv1d_reparameterization( + inputs, + filters, + kernel_size, + strides=1, + padding="valid", + data_format="channels_last", + dilation_rate=1, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + 
kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + reuse=None): + """Functional interface for 1D convolution layer (e.g. temporal convolution). + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. It may also include a bias addition and activation function + on the outputs. It assumes the `kernel` and/or `bias` are drawn from + distributions. + + By default, the layer implements a stochastic forward pass via + sampling from the kernel and bias posteriors, + ```none + outputs = f(inputs; kernel, bias), kernel, bias ~ posterior + ``` + where f denotes the layer's calculation. It uses the reparameterization + estimator [1], which performs a Monte Carlo approximation of the + distribution integrating over the `kernel` and `bias`. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Arguments: + inputs: Tensor input. + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of a single integer, specifying the + length of the 1D convolution window. + strides: An integer or tuple/list of a single integer, + specifying the stride length of the convolution. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. 
+ `channels_last` corresponds to inputs with shape + `(batch, length, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, length)`. + dilation_rate: An integer or tuple/list of a single integer, specifying + the dilation rate to use for dilated convolution. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any `strides` value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + activity_regularizer: Optional regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). 
+ bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + name: A string, the name of the layer. + reuse: Boolean, whether to reuse the weights of a previous layer + by the same name. + + Returns: + Output tensor. + + Raises: + ValueError: if eager execution is enabled. + + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`. + + ```python + tfp = tf.contrib.bayesflow + + net = tf.reshape(features, [-1, 128, 1]) + net = tfp.layers.conv1d_reparameterization(net, + filters=64, + kernel_size=5, + padding="SAME", + activation=tf.nn.relu) + net = tf.reshape(net, [-1, 128 * 64]) + logits = tfp.layers.dense_reparameterization(net, 10) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses reparameterization gradients to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound.
It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + + [1]: "Auto-Encoding Variational Bayes." + Diederik P. Kingma, Max Welling. + International Conference on Learning Representations, 2014. + """ + layer = Conv1DReparameterization( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + name=name, + dtype=inputs.dtype.base_dtype, + _scope=name, + _reuse=reuse) + return layer.apply(inputs) + + +class Conv2DReparameterization(_ConvReparameterization): + """2D convolution layer (e.g. spatial convolution over images). + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. It may also include a bias addition and activation function + on the outputs. It assumes the `kernel` and/or `bias` are drawn from + distributions. + + By default, the layer implements a stochastic forward pass via + sampling from the kernel and bias posteriors, + ```none + outputs = f(inputs; kernel, bias), kernel, bias ~ posterior + ``` + where f denotes the layer's calculation. It uses the reparameterization + estimator [1], which performs a Monte Carlo approximation of the + distribution integrating over the `kernel` and `bias`. 
+ + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Arguments: + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of 2 integers, specifying the + height and width of the 2D convolution window. + Can be a single integer to specify the same value for + all spatial dimensions. + strides: An integer or tuple/list of 2 integers, + specifying the strides of the convolution along the height and width. + Can be a single integer to specify the same value for + all spatial dimensions. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, height, width, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, height, width)`. + + dilation_rate: An integer or tuple/list of 2 integers, specifying + the dilation rate to use for dilated convolution. + Can be a single integer to specify the same value for + all spatial dimensions. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any stride value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + activity_regularizer: Optional regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. 
+ kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + name: A string, the name of the layer. + + Properties: + filters: Python integer, dimensionality of the output space. + kernel_size: Size of the convolution window. + strides: Stride length of convolution. + padding: Python string describing padding approach.
+ data_format: Python string describing input data's dimensions. + dilation_rate: Dilation rate for an atrous convolution. + activation: Activation function (`callable`). + activity_regularizer: Regularizer function for the output. + kernel_posterior_fn: `callable` returning posterior. + kernel_posterior_tensor_fn: `callable` operating on posterior. + kernel_prior_fn: `callable` returning prior. + kernel_divergence_fn: `callable` returning divergence. + bias_posterior_fn: `callable` returning posterior. + bias_posterior_tensor_fn: `callable` operating on posterior. + bias_prior_fn: `callable` returning prior. + bias_divergence_fn: `callable` returning divergence. + + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`. + + ```python + tfp = tf.contrib.bayesflow + + net = tf.reshape(features, [-1, 32, 32, 3]) + net = tfp.layers.Conv2DReparameterization(64, + kernel_size=5, + padding="SAME", + activation=tf.nn.relu)(net) + net = tf.layers.MaxPooling2D(pool_size=2, + strides=2, + padding="SAME")(net) + net = tf.reshape(net, [-1, 16 * 16 * 64]) + logits = tfp.layers.DenseReparameterization(10)(net) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses reparameterization gradients to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + + [1]: "Auto-Encoding Variational Bayes." + Diederik P. Kingma, Max Welling.
+ International Conference on Learning Representations, 2014. + """ + + def __init__( + self, + filters, + kernel_size, + strides=(1, 1), + padding="valid", + data_format="channels_last", + dilation_rate=(1, 1), + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + **kwargs): + super(Conv2DReparameterization, self).__init__( + rank=2, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + name=name, **kwargs) + + +def conv2d_reparameterization( + inputs, + filters, + kernel_size, + strides=(1, 1), + padding="valid", + data_format="channels_last", + dilation_rate=(1, 1), + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda 
+ loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + reuse=None): + """Functional interface for the 2D convolution layer. + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. It may also include a bias addition and activation function + on the outputs. It assumes the `kernel` and/or `bias` are drawn from + distributions. + + By default, the layer implements a stochastic forward pass via + sampling from the kernel and bias posteriors, + ```none + outputs = f(inputs; kernel, bias), kernel, bias ~ posterior + ``` + where f denotes the layer's calculation. It uses the reparameterization + estimator [1], which performs a Monte Carlo approximation of the + distribution integrating over the `kernel` and `bias`. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Arguments: + inputs: Tensor input. + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of 2 integers, specifying the + height and width of the 2D convolution window. + Can be a single integer to specify the same value for + all spatial dimensions. + strides: An integer or tuple/list of 2 integers, + specifying the strides of the convolution along the height and width. + Can be a single integer to specify the same value for + all spatial dimensions. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. 
+ padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, height, width, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, height, width)`. + + dilation_rate: An integer or tuple/list of 2 integers, specifying + the dilation rate to use for dilated convolution. + Can be a single integer to specify the same value for + all spatial dimensions. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any stride value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + activity_regularizer: Optional regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. 
+ bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + name: A string, the name of the layer. + reuse: Boolean, whether to reuse the weights of a previous layer + by the same name. + + Returns: + Output tensor. + + Raises: + ValueError: if eager execution is enabled. + + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`.
+ + ```python + tfp = tf.contrib.bayesflow + + net = tf.reshape(features, [-1, 32, 32, 3]) + net = tfp.layers.conv2d_reparameterization(net, + filters=64, + kernel_size=5, + padding="SAME", + activation=tf.nn.relu) + net = tf.layers.max_pooling2d(net, + pool_size=2, + strides=2, + padding="SAME") + net = tf.reshape(net, [-1, 16 * 16 * 64]) + logits = tfp.layers.dense_reparameterization(net, 10) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses reparameterization gradients to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + + [1]: "Auto-Encoding Variational Bayes." + Diederik P. Kingma, Max Welling. + International Conference on Learning Representations, 2014. + """ + layer = Conv2DReparameterization( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + name=name, + dtype=inputs.dtype.base_dtype, + _scope=name, + _reuse=reuse) + return layer.apply(inputs) + + +class Conv3DReparameterization(_ConvReparameterization): + """3D convolution layer (e.g.
spatial convolution over volumes). + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. It may also include a bias addition and activation function + on the outputs. It assumes the `kernel` and/or `bias` are drawn from + distributions. + + By default, the layer implements a stochastic forward pass via + sampling from the kernel and bias posteriors, + ```none + outputs = f(inputs; kernel, bias), kernel, bias ~ posterior + ``` + where f denotes the layer's calculation. It uses the reparameterization + estimator [1], which performs a Monte Carlo approximation of the + distribution integrating over the `kernel` and `bias`. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Arguments: + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of 3 integers, specifying the + depth, height and width of the 3D convolution window. + Can be a single integer to specify the same value for + all spatial dimensions. + strides: An integer or tuple/list of 3 integers, + specifying the strides of the convolution along the depth, + height and width. + Can be a single integer to specify the same value for + all spatial dimensions. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, depth, height, width, channels)` while `channels_first` + corresponds to inputs with shape + `(batch, channels, depth, height, width)`. 
+ dilation_rate: An integer or tuple/list of 3 integers, specifying + the dilation rate to use for dilated convolution. + Can be a single integer to specify the same value for + all spatial dimensions. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any stride value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + activity_regularizer: Optional regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. 
Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + name: A string, the name of the layer. + + Properties: + filters: Python integer, dimensionality of the output space. + kernel_size: Size of the convolution window. + strides: Stride length of convolution. + padding: Python string describing padding approach. + data_format: Python string describing input data's dimensions. + dilation_rate: Dilation rate for an atrous convolution. + activation: Activation function (`callable`). + activity_regularizer: Regularizer function for the output. + kernel_posterior_fn: `callable` returning posterior. + kernel_posterior_tensor_fn: `callable` operating on posterior. + kernel_prior_fn: `callable` returning prior. + kernel_divergence_fn: `callable` returning divergence. + bias_posterior_fn: `callable` returning posterior. + bias_posterior_tensor_fn: `callable` operating on posterior. + bias_prior_fn: `callable` returning prior. + bias_divergence_fn: `callable` returning divergence. + + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`.
+ + ```python + tfp = tf.contrib.bayesflow + + net = tf.reshape(features, [-1, 256, 32, 32, 3]) + net = tfp.layers.Conv3DReparameterization(64, + kernel_size=5, + padding="SAME", + activation=tf.nn.relu)(net) + net = tf.layers.MaxPooling3D(pool_size=2, + strides=2, + padding="SAME")(net) + net = tf.reshape(net, [-1, 128 * 16 * 16 * 64]) + logits = tfp.layers.DenseReparameterization(10)(net) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses reparameterization gradients to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + + [1]: "Auto-Encoding Variational Bayes." + Diederik P. Kingma, Max Welling. + International Conference on Learning Representations, 2014.
+ """ + + def __init__( + self, + filters, + kernel_size, + strides=(1, 1, 1), + padding="valid", + data_format="channels_last", + dilation_rate=(1, 1, 1), + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + **kwargs): + super(Conv3DReparameterization, self).__init__( + rank=3, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + name=name, **kwargs) + + +def conv3d_reparameterization( + inputs, + filters, + kernel_size, + strides=(1, 1, 1), + padding="valid", + data_format="channels_last", + dilation_rate=(1, 1, 1), + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), 
scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + reuse=None): + """Functional interface for the 3D convolution layer. + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. It may also include a bias addition and activation function + on the outputs. It assumes the `kernel` and/or `bias` are drawn from + distributions. + + By default, the layer implements a stochastic forward pass via + sampling from the kernel and bias posteriors, + ```none + outputs = f(inputs; kernel, bias), kernel, bias ~ posterior + ``` + where f denotes the layer's calculation. It uses the reparameterization + estimator [1], which performs a Monte Carlo approximation of the + distribution integrating over the `kernel` and `bias`. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Arguments: + inputs: Tensor input. + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of 3 integers, specifying the + depth, height and width of the 3D convolution window. + Can be a single integer to specify the same value for + all spatial dimensions. + strides: An integer or tuple/list of 3 integers, + specifying the strides of the convolution along the depth, + height and width. + Can be a single integer to specify the same value for + all spatial dimensions. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. 
+ padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, depth, height, width, channels)` while `channels_first` + corresponds to inputs with shape + `(batch, channels, depth, height, width)`. + dilation_rate: An integer or tuple/list of 3 integers, specifying + the dilation rate to use for dilated convolution. + Can be a single integer to specify the same value for + all spatial dimensions. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any stride value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + activity_regularizer: Optional regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. 
+ bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + name: A string, the name of the layer. + reuse: Boolean, whether to reuse the weights of a previous layer + by the same name. + + Returns: + Output tensor. + + Raises: + ValueError: if eager execution is enabled. + + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`.
+ + ```python + tfp = tf.contrib.bayesflow + + net = tf.reshape(features, [-1, 256, 32, 32, 3]) + net = tfp.layers.conv3d_reparameterization(net, + filters=64, + kernel_size=5, + padding="SAME", + activation=tf.nn.relu) + net = tf.layers.max_pooling3d(net, + pool_size=2, + strides=2, + padding="SAME") + net = tf.reshape(net, [-1, 128 * 16 * 16 * 64]) + logits = tfp.layers.dense_reparameterization(net, 10) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses reparameterization gradients to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + + [1]: "Auto-Encoding Variational Bayes." + Diederik P. Kingma, Max Welling. + International Conference on Learning Representations, 2014.
+ """ + layer = Conv3DReparameterization( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + name=name, + dtype=inputs.dtype.base_dtype, + _scope=name, + _reuse=reuse) + return layer.apply(inputs) + + +class _ConvFlipout(_ConvVariational): + """Abstract nD convolution layer (private, used as implementation base). + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. It may also include a bias addition and activation function + on the outputs. It assumes the `kernel` and/or `bias` are drawn from + distributions. + + By default, the layer implements a stochastic forward pass via + sampling from the kernel and bias posteriors, + ```none + outputs = f(inputs; kernel, bias), kernel, bias ~ posterior + ``` + where f denotes the layer's calculation. It uses the Flipout + estimator [1], which performs a Monte Carlo approximation of the + distribution integrating over the `kernel` and `bias`. Flipout uses + roughly twice as many floating point operations as the + reparameterization estimator but has the advantage of significantly + lower variance. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Arguments: + rank: An integer, the rank of the convolution, e.g. "2" for 2D convolution. 
+ filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of n integers, specifying the + length of the convolution window. + strides: An integer or tuple/list of n integers, + specifying the stride length of the convolution. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, ..., channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, ...)`. + dilation_rate: An integer or tuple/list of n integers, specifying + the dilation rate to use for dilated convolution. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any `strides` value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + activity_regularizer: Optional regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. 
+ kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + seed: Python scalar `int` which initializes the random number + generator. Default value: `None` (i.e., use global seed). + name: A string, the name of the layer. + + Properties: + rank: Python integer, dimensionality of convolution. + filters: Python integer, dimensionality of the output space. + kernel_size: Size of the convolution window. + strides: Stride length of convolution. + padding: Python string describing padding approach. + data_format: Python string describing input data's dimensions. + dilation_rate: Dilation rate for an atrous convolution. + activation: Activation function (`callable`). 
+ activity_regularizer: Regularizer function for the output. + kernel_posterior_fn: `callable` returning posterior. + kernel_posterior_tensor_fn: `callable` operating on posterior. + kernel_prior_fn: `callable` returning prior. + kernel_divergence_fn: `callable` returning divergence. + bias_posterior_fn: `callable` returning posterior. + bias_posterior_tensor_fn: `callable` operating on posterior. + bias_prior_fn: `callable` returning prior. + bias_divergence_fn: `callable` returning divergence. + seed: Python integer, used to create random seeds. + + [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on + Mini-Batches." + Anonymous. OpenReview, 2017. + https://openreview.net/forum?id=rJnpifWAb + """ + + def __init__( + self, + rank, + filters, + kernel_size, + strides=1, + padding="valid", + data_format="channels_last", + dilation_rate=1, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + seed=None, + name=None, + **kwargs): + super(_ConvFlipout, self).__init__( + rank=rank, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + 
kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + name=name, **kwargs) + self.seed = seed + + def _apply_variational_kernel(self, inputs): + """Applies the Flipout estimator to the kernel convolution. + + Convolves `inputs` with the mean of the kernel posterior, then adds a + second convolution of sign-flipped inputs with a zero-mean perturbation + obtained from `kernel_posterior_tensor_fn`, whose result is sign-flipped + again. The random signs are drawn per example and per channel and are + laid out according to `self.data_format`. + + Args: + inputs: Input `Tensor`. + + Returns: + `Tensor`, the mean convolution plus the perturbation convolution. + + Raises: + TypeError: if the kernel posterior is not an `Independent` distribution + wrapping a `Normal`. + """ + if (not isinstance(self.kernel_posterior, independent_lib.Independent) or + not isinstance(self.kernel_posterior.distribution, normal_lib.Normal)): + raise TypeError( + "`{}` requires " + "`kernel_posterior_fn` produce an instance of " + "`tf.distributions.Independent(tf.distributions.Normal)` " + "(saw: \"{}\").".format( + type(self).__name__, self.kernel_posterior.name)) + self.kernel_posterior_affine = normal_lib.Normal( + loc=array_ops.zeros_like(self.kernel_posterior.distribution.loc), + scale=self.kernel_posterior.distribution.scale) + self.kernel_posterior_affine_tensor = ( + self.kernel_posterior_tensor_fn(self.kernel_posterior_affine)) + self.kernel_posterior_tensor = None + + outputs = self._convolution_op( + inputs, self.kernel_posterior.distribution.loc) + + input_shape = array_ops.shape(inputs) + output_shape = array_ops.shape(outputs) + batch_shape = array_ops.expand_dims(input_shape[0], 0) + # The channel axis is 1 for `channels_first` and last otherwise; the sign + # tensors below must follow the same layout as `inputs` and `outputs`. + channels_first = self.data_format == "channels_first" + channels = input_shape[1] if channels_first else input_shape[-1] + + sign_input = layers_util.random_sign( + array_ops.concat([batch_shape, + array_ops.expand_dims(channels, 0)], 0), + dtype=inputs.dtype, + seed=self.seed) + sign_output = layers_util.random_sign( + array_ops.concat([batch_shape, + array_ops.expand_dims(self.filters, 0)], 0), + dtype=inputs.dtype, + seed=distribution_util.gen_new_seed( + self.seed, salt="conv_flipout")) 
+ if channels_first: + for _ in range(self.rank): + sign_input = array_ops.expand_dims(sign_input, -1) # 2D ex: (B, C, 1, 1) + sign_output = array_ops.expand_dims(sign_output, -1) + input_multiples = [1, 1] + [input_shape[i + 2] for i in range(self.rank)] + output_multiples = [1, 1] + [output_shape[i + 2] for i in range(self.rank)] + else: + for _ in range(self.rank): + sign_input = array_ops.expand_dims(sign_input, 1) # 2D ex: (B, 1, 1, C) + sign_output = array_ops.expand_dims(sign_output, 1) + input_multiples = [1] + [input_shape[i + 1] for i in range(self.rank)] + [1] + output_multiples = [1] + [output_shape[i + 1] for i in range(self.rank)] + [1] + + # Tile to the full input/output shapes for the element-wise products. + sign_input = array_ops.tile(sign_input, input_multiples) + sign_output = array_ops.tile(sign_output, output_multiples) + + perturbed_inputs = self._convolution_op( + inputs * sign_input, self.kernel_posterior_affine_tensor) * sign_output + + outputs += perturbed_inputs + return outputs + + +class Conv1DFlipout(_ConvFlipout): + """1D convolution layer (e.g. temporal convolution) with Flipout. + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. It may also include a bias addition and activation function + on the outputs. It assumes the `kernel` and/or `bias` are drawn from + distributions. + + By default, the layer implements a stochastic forward pass via + sampling from the kernel and bias posteriors, + ```none + outputs = f(inputs; kernel, bias), kernel, bias ~ posterior + ``` + where f denotes the layer's calculation. It uses the Flipout + estimator [1], which performs a Monte Carlo approximation of the + distribution integrating over the `kernel` and `bias`. Flipout uses + roughly twice as many floating point operations as the + reparameterization estimator but has the advantage of significantly + lower variance. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Arguments: + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of a single integer, specifying the + length of the 1D convolution window. + strides: An integer or tuple/list of a single integer, + specifying the stride length of the convolution. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. 
+ `channels_last` corresponds to inputs with shape + `(batch, length, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, length)`. + dilation_rate: An integer or tuple/list of a single integer, specifying + the dilation rate to use for dilated convolution. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any `strides` value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + activity_regularizer: Optional regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). 
+ bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + seed: Python scalar `int` which initializes the random number + generator. Default value: `None` (i.e., use global seed). + name: A string, the name of the layer. + + Properties: + filters: Python integer, dimensionality of the output space. + kernel_size: Size of the convolution window. + strides: Stride length of convolution. + padding: Python string describing padding approach. + data_format: Python string describing input data's dimensions. + dilation_rate: Dilation rate for an atrous convolution. + activation: Activation function (`callable`). + activity_regularizer: Regularizer function for the output. + kernel_posterior_fn: `callable` returning posterior. + kernel_posterior_tensor_fn: `callable` operating on posterior. + kernel_prior_fn: `callable` returning prior. + kernel_divergence_fn: `callable` returning divergence. + bias_posterior_fn: `callable` returning posterior. + bias_posterior_tensor_fn: `callable` operating on posterior. + bias_prior_fn: `callable` returning prior. + bias_divergence_fn: `callable` returning divergence. + seed: Python integer, used to create random seeds. 
+ + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`. + + ```python + tfp = tf.contrib.bayesflow + + net = tf.reshape(features, [-1, 128, 1]) + net = tfp.layers.Conv1DFlipout(64, + kernel_size=5, + padding="SAME", + activation=tf.nn.relu)(net) + net = tf.reshape(net, [-1, 128 * 64]) + logits = tfp.layers.DenseFlipout(10)(net) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses the Flipout gradient estimator to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + + [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on + Mini-Batches." + Anonymous. OpenReview, 2017. 
+ https://openreview.net/forum?id=rJnpifWAb + """ + + def __init__( + self, + filters, + kernel_size, + strides=1, + padding="valid", + data_format="channels_last", + dilation_rate=1, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + seed=None, + name=None, + **kwargs): + super(Conv1DFlipout, self).__init__( + rank=1, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + seed=seed, + name=name, **kwargs) + + +def conv1d_flipout( + inputs, + filters, + kernel_size, + strides=1, + padding="valid", + data_format="channels_last", + dilation_rate=1, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), 
scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + seed=None, + name=None, + reuse=None): + """Functional interface for 1D convolution layer (e.g. temporal convolution). + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. It may also include a bias addition and activation function + on the outputs. It assumes the `kernel` and/or `bias` are drawn from + distributions. + + By default, the layer implements a stochastic forward pass via + sampling from the kernel and bias posteriors, + ```none + outputs = f(inputs; kernel, bias), kernel, bias ~ posterior + ``` + where f denotes the layer's calculation. It uses the Flipout + estimator [1], which performs a Monte Carlo approximation of the + distribution integrating over the `kernel` and `bias`. Flipout uses + roughly twice as many floating point operations as the + reparameterization estimator but has the advantage of significantly + lower variance. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Arguments: + inputs: Tensor input. + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of a single integer, specifying the + length of the 1D convolution window. + strides: An integer or tuple/list of a single integer, + specifying the stride length of the convolution. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. 
+ padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, length, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, length)`. + dilation_rate: An integer or tuple/list of a single integer, specifying + the dilation rate to use for dilated convolution. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any `strides` value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + activity_regularizer: Optional regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. 
Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + seed: Python scalar `int` which initializes the random number + generator. Default value: `None` (i.e., use global seed). + name: A string, the name of the layer. + reuse: Boolean, whether to reuse the weights of a previous layer + by the same name. + + Returns: + Output tensor. + + Raises: + ValueError: if eager execution is enabled. + + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`. 
+ + ```python + tfp = tf.contrib.bayesflow + + net = tf.reshape(features, [-1, 128, 1]) + net = tfp.layers.conv1d_flipout(net, + filters=64, + kernel_size=5, + padding="SAME", + activation=tf.nn.relu) + net = tf.reshape(net, [-1, 128 * 64]) + logits = tfp.layers.dense_flipout(net, 10) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses the Flipout gradient estimator to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + + [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on + Mini-Batches." + Anonymous. OpenReview, 2017. + https://openreview.net/forum?id=rJnpifWAb + """ + layer = Conv1DFlipout( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + seed=seed, + name=name, + dtype=inputs.dtype.base_dtype, + _scope=name, + _reuse=reuse) + return layer.apply(inputs) + + +class Conv2DFlipout(_ConvFlipout): + """2D convolution layer (e.g. spatial convolution over images) with Flipout. 
+ + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. It may also include a bias addition and activation function + on the outputs. It assumes the `kernel` and/or `bias` are drawn from + distributions. + + By default, the layer implements a stochastic forward pass via + sampling from the kernel and bias posteriors, + ```none + outputs = f(inputs; kernel, bias), kernel, bias ~ posterior + ``` + where f denotes the layer's calculation. It uses the Flipout + estimator [1], which performs a Monte Carlo approximation of the + distribution integrating over the `kernel` and `bias`. Flipout uses + roughly twice as many floating point operations as the + reparameterization estimator but has the advantage of significantly + lower variance. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Arguments: + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of 2 integers, specifying the + height and width of the 2D convolution window. + Can be a single integer to specify the same value for + all spatial dimensions. + strides: An integer or tuple/list of 2 integers, + specifying the strides of the convolution along the height and width. + Can be a single integer to specify the same value for + all spatial dimensions. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. 
+ `channels_last` corresponds to inputs with shape + `(batch, height, width, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, height, width)`. + + dilation_rate: An integer or tuple/list of 2 integers, specifying + the dilation rate to use for dilated convolution. + Can be a single integer to specify the same value for + all spatial dimensions. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any stride value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + activity_regularizer: Optional regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. 
Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + seed: Python scalar `int` which initializes the random number + generator. Default value: `None` (i.e., use global seed). + name: A string, the name of the layer. + + Properties: + filters: Python integer, dimensionality of the output space. + kernel_size: Size of the convolution window. + strides: Stride length of convolution. + padding: Python string describing padding approach. + data_format: Python string describing input data's dimensions. + dilation_rate: Dilation rate for an atrous convolution. + activation: Activation function (`callable`). + activity_regularizer: Regularizer function for the output. + kernel_posterior_fn: `callable` returning posterior. + kernel_posterior_tensor_fn: `callable` operating on posterior. + kernel_prior_fn: `callable` returning prior. + kernel_divergence_fn: `callable` returning divergence. + bias_posterior_fn: `callable` returning posterior. + bias_posterior_tensor_fn: `callable` operating on posterior. + bias_prior_fn: `callable` returning prior. + bias_divergence_fn: `callable` returning divergence. + seed: Python integer, used to create random seeds. 
+ + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`. + + ```python + tfp = tf.contrib.bayesflow + + net = tf.reshape(features, [-1, 32, 32, 3]) + net = tfp.layers.Conv2DFlipout(64, + kernel_size=5, + padding="SAME", + activation=tf.nn.relu)(net) + net = tf.layers.MaxPooling2D(pool_size=2, + strides=2, + padding="SAME")(net) + net = tf.reshape(net, [-1, 16 * 16 * 64]) + logits = tfp.layers.DenseFlipout(10)(net) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses the Flipout gradient estimator to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + + [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on + Mini-Batches." + Anonymous. OpenReview, 2017. 
+ https://openreview.net/forum?id=rJnpifWAb + """ + + def __init__( + self, + filters, + kernel_size, + strides=(1, 1), + padding="valid", + data_format="channels_last", + dilation_rate=(1, 1), + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + seed=None, + name=None, + **kwargs): + super(Conv2DFlipout, self).__init__( + rank=2, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + seed=seed, + name=name, **kwargs) + + +def conv2d_flipout( + inputs, + filters, + kernel_size, + strides=(1, 1), + padding="valid", + data_format="channels_last", + dilation_rate=(1, 1), + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + 
loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + seed=None, + name=None, + reuse=None): + """Functional interface for the 2D convolution layer. + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. It may also include a bias addition and activation function + on the outputs. It assumes the `kernel` and/or `bias` are drawn from + distributions. + + By default, the layer implements a stochastic forward pass via + sampling from the kernel and bias posteriors, + ```none + outputs = f(inputs; kernel, bias), kernel, bias ~ posterior + ``` + where f denotes the layer's calculation. It uses the Flipout + estimator [1], which performs a Monte Carlo approximation of the + distribution integrating over the `kernel` and `bias`. Flipout uses + roughly twice as many floating point operations as the + reparameterization estimator but has the advantage of significantly + lower variance. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Arguments: + inputs: Tensor input. + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of 2 integers, specifying the + height and width of the 2D convolution window. + Can be a single integer to specify the same value for + all spatial dimensions. + strides: An integer or tuple/list of 2 integers, + specifying the strides of the convolution along the height and width. 
+ Can be a single integer to specify the same value for + all spatial dimensions. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, height, width, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, height, width)`. + + dilation_rate: An integer or tuple/list of 2 integers, specifying + the dilation rate to use for dilated convolution. + Can be a single integer to specify the same value for + all spatial dimensions. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any stride value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + activity_regularizer: Optional regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. 
The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + seed: Python scalar `int` which initializes the random number + generator. Default value: `None` (i.e., use global seed). + name: A string, the name of the layer. + reuse: Boolean, whether to reuse the weights of a previous layer + by the same name. + + Returns: + Output tensor. + + Raises: + ValueError: if eager execution is enabled. + + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`.
+ + ```python + tfp = tf.contrib.bayesflow + + net = tf.reshape(features, [-1, 32, 32, 3]) + net = tfp.layers.conv2d_flipout(net, + filters=64, + kernel_size=5, + padding="SAME", + activation=tf.nn.relu) + net = tf.layers.max_pooling2d(net, + pool_size=2, + strides=2, + padding="SAME") + net = tf.reshape(net, [-1, 8 * 8 * 64]) + logits = tfp.layers.dense_flipout(net, 10) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses the Flipout gradient estimator to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + + [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on + Mini-Batches." + Anonymous. OpenReview, 2017. + https://openreview.net/forum?id=rJnpifWAb + """ + layer = Conv2DFlipout( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + seed=seed, + name=name, + dtype=inputs.dtype.base_dtype, + _scope=name, + _reuse=reuse) + return layer.apply(inputs) + + +class Conv3DFlipout(_ConvFlipout): + """3D convolution layer (e.g. 
spatial convolution over volumes) with Flipout. + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. It may also include a bias addition and activation function + on the outputs. It assumes the `kernel` and/or `bias` are drawn from + distributions. + + By default, the layer implements a stochastic forward pass via + sampling from the kernel and bias posteriors, + ```none + outputs = f(inputs; kernel, bias), kernel, bias ~ posterior + ``` + where f denotes the layer's calculation. It uses the Flipout + estimator [1], which performs a Monte Carlo approximation of the + distribution integrating over the `kernel` and `bias`. Flipout uses + roughly twice as many floating point operations as the + reparameterization estimator but has the advantage of significantly + lower variance. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Arguments: + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of 3 integers, specifying the + depth, height and width of the 3D convolution window. + Can be a single integer to specify the same value for + all spatial dimensions. + strides: An integer or tuple/list of 3 integers, + specifying the strides of the convolution along the depth, + height and width. + Can be a single integer to specify the same value for + all spatial dimensions. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. 
+ `channels_last` corresponds to inputs with shape + `(batch, depth, height, width, channels)` while `channels_first` + corresponds to inputs with shape + `(batch, channels, depth, height, width)`. + dilation_rate: An integer or tuple/list of 3 integers, specifying + the dilation rate to use for dilated convolution. + Can be a single integer to specify the same value for + all spatial dimensions. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any stride value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + activity_regularizer: Optional regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. 
Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + seed: Python scalar `int` which initializes the random number + generator. Default value: `None` (i.e., use global seed). + name: A string, the name of the layer. + + Properties: + filters: Python integer, dimensionality of the output space. + kernel_size: Size of the convolution window. + strides: Stride length of convolution. + padding: Python string describing padding approach. + data_format: Python string describing input data's dimensions. + dilation_rate: Dilation rate for an atrous convolution. + activation: Activation function (`callable`). + activity_regularizer: Regularizer function for the output. + kernel_posterior_fn: `callable` returning posterior. + kernel_posterior_tensor_fn: `callable` operating on posterior. + kernel_prior_fn: `callable` returning prior. + kernel_divergence_fn: `callable` returning divergence. + bias_posterior_fn: `callable` returning posterior. + bias_posterior_tensor_fn: `callable` operating on posterior. + bias_prior_fn: `callable` returning prior. + bias_divergence_fn: `callable` returning divergence. + seed: Python integer, used to create random seeds.
+ + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`. + + ```python + tfp = tf.contrib.bayesflow + + net = tf.reshape(features, [-1, 256, 32, 32, 3]) + net = tfp.layers.Conv3DFlipout(64, + kernel_size=5, + padding="SAME", + activation=tf.nn.relu)(net) + net = tf.layers.MaxPooling3D(pool_size=2, + strides=2, + padding="SAME")(net) + net = tf.reshape(net, [-1, 128 * 16 * 16 * 64]) + logits = tfp.layers.DenseFlipout(10)(net) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses the Flipout gradient estimator to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + + [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on + Mini-Batches." + Anonymous. OpenReview, 2017.
+ https://openreview.net/forum?id=rJnpifWAb + """ + + def __init__( + self, + filters, + kernel_size, + strides=(1, 1, 1), + padding="valid", + data_format="channels_last", + dilation_rate=(1, 1, 1), + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + seed=None, + name=None, + **kwargs): + super(Conv3DFlipout, self).__init__( + rank=3, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + seed=seed, + name=name, **kwargs) + + +def conv3d_flipout( + inputs, + filters, + kernel_size, + strides=(1, 1, 1), + padding="valid", + data_format="channels_last", + dilation_rate=(1, 1, 1), + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + 
loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + seed=None, + name=None, + reuse=None): + """Functional interface for the 3D convolution layer. + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. It may also include a bias addition and activation function + on the outputs. It assumes the `kernel` and/or `bias` are drawn from + distributions. + + By default, the layer implements a stochastic forward pass via + sampling from the kernel and bias posteriors, + ```none + outputs = f(inputs; kernel, bias), kernel, bias ~ posterior + ``` + where f denotes the layer's calculation. It uses the Flipout + estimator [1], which performs a Monte Carlo approximation of the + distribution integrating over the `kernel` and `bias`. Flipout uses + roughly twice as many floating point operations as the + reparameterization estimator but has the advantage of significantly + lower variance. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Arguments: + inputs: Tensor input. + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of 3 integers, specifying the + depth, height and width of the 3D convolution window. + Can be a single integer to specify the same value for + all spatial dimensions. + strides: An integer or tuple/list of 3 integers, + specifying the strides of the convolution along the depth, + height and width. 
+ Can be a single integer to specify the same value for + all spatial dimensions. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, depth, height, width, channels)` while `channels_first` + corresponds to inputs with shape + `(batch, channels, depth, height, width)`. + dilation_rate: An integer or tuple/list of 3 integers, specifying + the dilation rate to use for dilated convolution. + Can be a single integer to specify the same value for + all spatial dimensions. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any stride value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + activity_regularizer: Optional regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. 
+ kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + seed: Python scalar `int` which initializes the random number + generator. Default value: `None` (i.e., use global seed). + name: A string, the name of the layer. + reuse: Boolean, whether to reuse the weights of a previous layer + by the same name. + + Returns: + Output tensor. + + Raises: + ValueError: if eager execution is enabled. + + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`.
+ + ```python + tfp = tf.contrib.bayesflow + + net = tf.reshape(features, [-1, 256, 32, 32, 3]) + net = tfp.layers.conv3d_flipout(net, + filters=64, + kernel_size=5, + padding="SAME", + activation=tf.nn.relu) + net = tf.layers.max_pooling3d(net, + pool_size=2, + strides=2, + padding="SAME") + net = tf.reshape(net, [-1, 128 * 16 * 16 * 64]) + logits = tfp.layers.dense_flipout(net, 10) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses the Flipout gradient estimator to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + + [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on + Mini-Batches." + Anonymous. OpenReview, 2017.
+ https://openreview.net/forum?id=rJnpifWAb + """ + layer = Conv3DFlipout( + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + seed=seed, + name=name, + dtype=inputs.dtype.base_dtype, + _scope=name, + _reuse=reuse) + return layer.apply(inputs) + + +# Aliases + +Convolution1DReparameterization = Conv1DReparameterization +Convolution2DReparameterization = Conv2DReparameterization +Convolution3DReparameterization = Conv3DReparameterization +convolution1d_reparameterization = conv1d_reparameterization +convolution2d_reparameterization = conv2d_reparameterization +convolution3d_reparameterization = conv3d_reparameterization +Convolution1DFlipout = Conv1DFlipout +Convolution2DFlipout = Conv2DFlipout +Convolution3DFlipout = Conv3DFlipout +convolution1d_flipout = conv1d_flipout +convolution2d_flipout = conv2d_flipout +convolution3d_flipout = conv3d_flipout diff --git a/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational.py b/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational.py new file mode 100644 index 0000000000000000000000000000000000000000..591a8e553de0c194786c7ee8693665f762711b2d --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational.py @@ -0,0 +1,1176 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Dense Bayesian layer using KL-divergence based variational inference. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.bayesflow.python.ops import layers_util +from tensorflow.contrib.distributions.python.ops import independent as independent_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.layers import base as layers_lib +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import standard_ops +from tensorflow.python.ops.distributions import kullback_leibler as kl_lib +from tensorflow.python.ops.distributions import normal as normal_lib +from tensorflow.python.ops.distributions import util as distribution_util + + +class _DenseVariational(layers_lib.Layer): + """Abstract densely-connected class (private, used as implementation base). + + This layer implements the Bayesian variational inference analogue to + a dense layer by assuming the `kernel` and/or the `bias` are drawn + from distributions. 
By default, the layer implements a stochastic + forward pass via sampling from the kernel and bias posteriors, + + ```none + kernel, bias ~ posterior + outputs = activation(matmul(inputs, kernel) + bias) + ``` + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Args: + units: Integer or Long, dimensionality of the output space. + activation: Activation function (`callable`). Set it to None to maintain a + linear activation. + activity_regularizer: Regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). 
+ bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + name: Python `str`, the name of the layer. Layers with the same name will + share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in + such cases. + reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous + layer by the same name. + + Properties: + units: Python integer, dimensionality of the output space. + activation: Activation function (`callable`). + activity_regularizer: Regularizer function for the output. + kernel_posterior_fn: `callable` returning posterior. + kernel_posterior_tensor_fn: `callable` operating on posterior. + kernel_prior_fn: `callable` returning prior. + kernel_divergence_fn: `callable` returning divergence. + bias_posterior_fn: `callable` returning posterior. + bias_posterior_tensor_fn: `callable` operating on posterior. + bias_prior_fn: `callable` returning prior. + bias_divergence_fn: `callable` returning divergence. 
+ """ + + def __init__( + self, + units, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + **kwargs): + super(_DenseVariational, self).__init__( + trainable=trainable, + name=name, + activity_regularizer=activity_regularizer, + **kwargs) + self.units = units + self.activation = activation + self.input_spec = layers_lib.InputSpec(min_ndim=2) + self.kernel_posterior_fn = kernel_posterior_fn + self.kernel_posterior_tensor_fn = kernel_posterior_tensor_fn + self.kernel_prior_fn = kernel_prior_fn + self.kernel_divergence_fn = kernel_divergence_fn + self.bias_posterior_fn = bias_posterior_fn + self.bias_posterior_tensor_fn = bias_posterior_tensor_fn + self.bias_prior_fn = bias_prior_fn + self.bias_divergence_fn = bias_divergence_fn + + def build(self, input_shape): + input_shape = tensor_shape.TensorShape(input_shape) + in_size = input_shape.with_rank_at_least(2)[-1].value + if in_size is None: + raise ValueError("The last dimension of the inputs to `Dense` " + "should be defined. Found `None`.") + self._input_spec = layers_lib.InputSpec(min_ndim=2, axes={-1: in_size}) + dtype = dtypes.as_dtype(self.dtype) + + # Must have a posterior kernel. 
+ self.kernel_posterior = self.kernel_posterior_fn( + dtype, [in_size, self.units], "kernel_posterior", + self.trainable, self.add_variable) + + if self.kernel_prior_fn is None: + self.kernel_prior = None + else: + self.kernel_prior = self.kernel_prior_fn( + dtype, [in_size, self.units], "kernel_prior", + self.trainable, self.add_variable) + self._built_kernel_divergence = False + + if self.bias_posterior_fn is None: + self.bias_posterior = None + else: + self.bias_posterior = self.bias_posterior_fn( + dtype, [self.units], "bias_posterior", + self.trainable, self.add_variable) + + if self.bias_prior_fn is None: + self.bias_prior = None + else: + self.bias_prior = self.bias_prior_fn( + dtype, [self.units], "bias_prior", + self.trainable, self.add_variable) + self._built_bias_divergence = False + + self.built = True + + def call(self, inputs): + inputs = ops.convert_to_tensor(inputs, dtype=self.dtype) + + outputs = self._apply_variational_kernel(inputs) + outputs = self._apply_variational_bias(outputs) + if self.activation is not None: + outputs = self.activation(outputs) # pylint: disable=not-callable + if not self._built_kernel_divergence: + kernel_posterior = self.kernel_posterior + kernel_prior = self.kernel_prior + if isinstance(self.kernel_posterior, independent_lib.Independent): + kernel_posterior = kernel_posterior.distribution + if isinstance(self.kernel_prior, independent_lib.Independent): + kernel_prior = kernel_prior.distribution + self._apply_divergence(self.kernel_divergence_fn, + kernel_posterior, + kernel_prior, + self.kernel_posterior_tensor, + name="divergence_kernel") + self._built_kernel_divergence = True + if not self._built_bias_divergence: + bias_posterior = self.bias_posterior + bias_prior = self.bias_prior + if isinstance(self.bias_posterior, independent_lib.Independent): + bias_posterior = bias_posterior.distribution + if isinstance(self.bias_prior, independent_lib.Independent): + bias_prior = bias_prior.distribution + 
self._apply_divergence(self.bias_divergence_fn, + bias_posterior, + bias_prior, + self.bias_posterior_tensor, + name="divergence_bias") + self._built_bias_divergence = True + return outputs + + def _apply_variational_bias(self, inputs): + if self.bias_posterior is None: + self.bias_posterior_tensor = None + return inputs + self.bias_posterior_tensor = self.bias_posterior_tensor_fn( + self.bias_posterior) + return nn.bias_add(inputs, self.bias_posterior_tensor) + + def _apply_divergence(self, divergence_fn, posterior, prior, + posterior_tensor, name): + if (divergence_fn is None or + posterior is None or + prior is None): + divergence = None + return + divergence = standard_ops.identity( + divergence_fn( + posterior, prior, posterior_tensor), + name=name) + self.add_loss(divergence) + + def _matmul(self, inputs, kernel): + if inputs.shape.ndims <= 2: + return standard_ops.matmul(inputs, kernel) + # To handle broadcasting, we must use `tensordot`. + return standard_ops.tensordot(inputs, kernel, axes=[[-1], [0]]) + + def _compute_output_shape(self, input_shape): + input_shape = tensor_shape.TensorShape(input_shape).with_rank_at_least(2) + if input_shape[-1].value is None: + raise ValueError( + "The innermost dimension of input_shape must be defined, " + "but saw: {}".format(input_shape)) + return input_shape[:-1].concatenate(self.units) + + +class DenseReparameterization(_DenseVariational): + """Densely-connected layer class with reparameterization estimator. + + This layer implements the Bayesian variational inference analogue to + a dense layer by assuming the `kernel` and/or the `bias` are drawn + from distributions. 
By default, the layer implements a stochastic + forward pass via sampling from the kernel and bias posteriors, + + ```none + kernel, bias ~ posterior + outputs = activation(matmul(inputs, kernel) + bias) + ``` + + It uses the reparameterization estimator [1], which performs a Monte Carlo + approximation of the distribution integrating over the `kernel` and + `bias`. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Args: + units: Integer or Long, dimensionality of the output space. + activation: Activation function (`callable`). Set it to None to maintain a + linear activation. + activity_regularizer: Regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. 
+ bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + name: Python `str`, the name of the layer. Layers with the same name will + share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in + such cases. + reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous + layer by the same name. + + Properties: + units: Python integer, dimensionality of the output space. + activation: Activation function (`callable`). + activity_regularizer: Regularizer function for the output. + kernel_posterior_fn: `callable` returning posterior. + kernel_posterior_tensor_fn: `callable` operating on posterior. + kernel_prior_fn: `callable` returning prior. + kernel_divergence_fn: `callable` returning divergence. + bias_posterior_fn: `callable` returning posterior. + bias_posterior_tensor_fn: `callable` operating on posterior. + bias_prior_fn: `callable` returning prior. + bias_divergence_fn: `callable` returning divergence. 
+ + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`. + + ```python + tfp = tf.contrib.bayesflow + + net = tfp.layers.DenseReparameterization( + 512, activation=tf.nn.relu)(features) + logits = tfp.layers.DenseReparameterization(10)(net) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses reparameterization gradients to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + + [1]: "Auto-Encoding Variational Bayes." + Diederik P. Kingma, Max Welling. + International Conference on Learning Representations, 2014. 
+ """ + + def __init__( + self, + units, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn( + is_singular=True), + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + **kwargs): + super(DenseReparameterization, self).__init__( + units=units, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + name=name, + **kwargs) + + def _apply_variational_kernel(self, inputs): + self.kernel_posterior_tensor = self.kernel_posterior_tensor_fn( + self.kernel_posterior) + self.kernel_posterior_affine = None + self.kernel_posterior_affine_tensor = None + return self._matmul(inputs, self.kernel_posterior_tensor) + + +def dense_reparameterization( + inputs, + units, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + 
bias_posterior_fn=layers_util.default_mean_field_normal_fn(is_singular=True), # pylint: disable=line-too-long + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + reuse=None): + """Densely-connected layer with reparameterization estimator. + + This layer implements the Bayesian variational inference analogue to + a dense layer by assuming the `kernel` and/or the `bias` are drawn + from distributions. By default, the layer implements a stochastic + forward pass via sampling from the kernel and bias posteriors, + + ```none + kernel, bias ~ posterior + outputs = activation(matmul(inputs, kernel) + bias) + ``` + + It uses the reparameterization estimator [1], which performs a Monte Carlo + approximation of the distribution integrating over the `kernel` and + `bias`. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Args: + inputs: Tensor input. + units: Integer or Long, dimensionality of the output space. + activation: Activation function (`callable`). Set it to None to maintain a + linear activation. + activity_regularizer: Regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. 
+ Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + name: Python `str`, the name of the layer. Layers with the same name will + share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in + such cases. + reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous + layer by the same name. + + Returns: + output: `Tensor` representing the affine transformed input under a random + draw from the surrogate posterior distribution.
+ + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`. + + ```python + tfp = tf.contrib.bayesflow + + net = tfp.layers.dense_reparameterization( + features, 512, activation=tf.nn.relu) + logits = tfp.layers.dense_reparameterization(net, 10) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses reparameterization gradients to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + + [1]: "Auto-Encoding Variational Bayes." + Diederik P. Kingma, Max Welling. + International Conference on Learning Representations, 2014. + """ + layer = DenseReparameterization( + units, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + name=name, + dtype=inputs.dtype.base_dtype, + _scope=name, + _reuse=reuse) + return layer.apply(inputs) + + +class DenseLocalReparameterization(_DenseVariational): + """Densely-connected layer class with local reparameterization estimator. 
+ + This layer implements the Bayesian variational inference analogue to + a dense layer by assuming the `kernel` and/or the `bias` are drawn + from distributions. By default, the layer implements a stochastic + forward pass via sampling from the kernel and bias posteriors, + + ```none + kernel, bias ~ posterior + outputs = activation(matmul(inputs, kernel) + bias) + ``` + + It uses the local reparameterization estimator [1], which performs a + Monte Carlo approximation of the distribution on the hidden units + induced by the `kernel` and `bias`. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Args: + units: Integer or Long, dimensionality of the output space. + activation: Activation function (`callable`). Set it to None to maintain a + linear activation. + activity_regularizer: Regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. 
The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + name: Python `str`, the name of the layer. Layers with the same name will + share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in + such cases. + reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous + layer by the same name. + + Properties: + units: Python integer, dimensionality of the output space. + activation: Activation function (`callable`). + activity_regularizer: Regularizer function for the output. + kernel_posterior_fn: `callable` returning posterior. + kernel_posterior_tensor_fn: `callable` operating on posterior. + kernel_prior_fn: `callable` returning prior. + kernel_divergence_fn: `callable` returning divergence. + bias_posterior_fn: `callable` returning posterior. + bias_posterior_tensor_fn: `callable` operating on posterior. + bias_prior_fn: `callable` returning prior. 
+ bias_divergence_fn: `callable` returning divergence. + + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`. + + ```python + tfp = tf.contrib.bayesflow + + net = tfp.layers.DenseLocalReparameterization( + 512, activation=tf.nn.relu)(features) + logits = tfp.layers.DenseLocalReparameterization(10)(net) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses local reparameterization gradients to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + + [1]: "Variational Dropout and the Local Reparameterization Trick." + Diederik P. Kingma, Tim Salimans, Max Welling. + Neural Information Processing Systems, 2015. 
+ """ + + def __init__( + self, + units, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn( + is_singular=True), + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + **kwargs): + super(DenseLocalReparameterization, self).__init__( + units=units, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + name=name, + **kwargs) + + def _apply_variational_kernel(self, inputs): + if (not isinstance(self.kernel_posterior, independent_lib.Independent) or + not isinstance(self.kernel_posterior.distribution, normal_lib.Normal)): + raise TypeError( + "`DenseLocalReparameterization` requires " + "`kernel_posterior_fn` produce an instance of " + "`tf.distributions.Independent(tf.distributions.Normal)` " + "(saw: \"{}\").".format(self.kernel_posterior.name)) + self.kernel_posterior_affine = normal_lib.Normal( + loc=self._matmul(inputs, self.kernel_posterior.distribution.loc), + scale=standard_ops.sqrt(self._matmul( + standard_ops.square(inputs), + standard_ops.square(self.kernel_posterior.distribution.scale)))) + self.kernel_posterior_affine_tensor = ( + 
self.kernel_posterior_tensor_fn(self.kernel_posterior_affine)) + self.kernel_posterior_tensor = None + return self.kernel_posterior_affine_tensor + + +def dense_local_reparameterization( + inputs, + units, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn( + is_singular=True), + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + name=None, + reuse=None): + """Densely-connected layer with local reparameterization estimator. + + This layer implements the Bayesian variational inference analogue to + a dense layer by assuming the `kernel` and/or the `bias` are drawn + from distributions. By default, the layer implements a stochastic + forward pass via sampling from the kernel and bias posteriors, + + ```none + kernel, bias ~ posterior + outputs = activation(matmul(inputs, kernel) + bias) + ``` + + It uses the local reparameterization estimator [1], which performs a + Monte Carlo approximation of the distribution on the hidden units + induced by the `kernel` and `bias`. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Args: + inputs: Tensor input. + units: Integer or Long, dimensionality of the output space. + activation: Activation function (`callable`). Set it to None to maintain a + linear activation. + activity_regularizer: Regularizer function for the output. 
+ trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. 
The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + name: Python `str`, the name of the layer. Layers with the same name will + share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in + such cases. + reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous + layer by the same name. + + Returns: + output: `Tensor` representing the affine transformed input under a random + draw from the surrogate posterior distribution. + + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`. + + ```python + tfp = tf.contrib.bayesflow + + net = tfp.layers.dense_local_reparameterization( + features, 512, activation=tf.nn.relu) + logits = tfp.layers.dense_local_reparameterization(net, 10) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses local reparameterization gradients to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + + [1]: "Variational Dropout and the Local Reparameterization Trick." + Diederik P. Kingma, Tim Salimans, Max Welling. + Neural Information Processing Systems, 2015.
+ """ + layer = DenseLocalReparameterization( + units, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + name=name, + dtype=inputs.dtype.base_dtype, + _scope=name, + _reuse=reuse) + return layer.apply(inputs) + + +class DenseFlipout(_DenseVariational): + """Densely-connected layer class with Flipout estimator. + + This layer implements the Bayesian variational inference analogue to + a dense layer by assuming the `kernel` and/or the `bias` are drawn + from distributions. By default, the layer implements a stochastic + forward pass via sampling from the kernel and bias posteriors, + + ```none + kernel, bias ~ posterior + outputs = activation(matmul(inputs, kernel) + bias) + ``` + + It uses the Flipout estimator [1], which performs a Monte Carlo + approximation of the distribution integrating over the `kernel` and + `bias`. Flipout uses roughly twice as many floating point operations + as the reparameterization estimator but has the advantage of + significantly lower variance. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Args: + units: Integer or Long, dimensionality of the output space. + activation: Activation function (`callable`). Set it to None to maintain a + linear activation. + activity_regularizer: Regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). 
+ kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. 
+ seed: Python scalar `int` which initializes the random number + generator. Default value: `None` (i.e., use global seed). + name: Python `str`, the name of the layer. Layers with the same name will + share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in + such cases. + reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous + layer by the same name. + + Properties: + units: Python integer, dimensionality of the output space. + activation: Activation function (`callable`). + activity_regularizer: Regularizer function for the output. + kernel_posterior_fn: `callable` returning posterior. + kernel_posterior_tensor_fn: `callable` operating on posterior. + kernel_prior_fn: `callable` returning prior. + kernel_divergence_fn: `callable` returning divergence. + bias_posterior_fn: `callable` returning posterior. + bias_posterior_tensor_fn: `callable` operating on posterior. + bias_prior_fn: `callable` returning prior. + bias_divergence_fn: `callable` returning divergence. + seed: Python integer, used to create random seeds. + + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`. + + ```python + tfp = tf.contrib.bayesflow + + net = tfp.layers.DenseFlipout( + 512, activation=tf.nn.relu)(features) + logits = tfp.layers.DenseFlipout(10)(net) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses the Flipout gradient estimator to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. 
It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + + [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on + Mini-Batches." + Anonymous. OpenReview, 2017. + https://openreview.net/forum?id=rJnpifWAb + """ + + def __init__( + self, + units, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn( + is_singular=True), + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + seed=None, + name=None, + **kwargs): + super(DenseFlipout, self).__init__( + units=units, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + name=name, + **kwargs) + self.seed = seed + + def _apply_variational_kernel(self, inputs): + if (not isinstance(self.kernel_posterior, independent_lib.Independent) or + not isinstance(self.kernel_posterior.distribution, normal_lib.Normal)): + raise TypeError( + "`DenseFlipout` requires " + "`kernel_posterior_fn` produce an instance of " + "`tf.distributions.Independent(tf.distributions.Normal)` " + "(saw: 
\"{}\").".format(self.kernel_posterior.name)) + self.kernel_posterior_affine = normal_lib.Normal( + loc=array_ops.zeros_like(self.kernel_posterior.distribution.loc), + scale=self.kernel_posterior.distribution.scale) + self.kernel_posterior_affine_tensor = ( + self.kernel_posterior_tensor_fn(self.kernel_posterior_affine)) + self.kernel_posterior_tensor = None + + input_shape = array_ops.shape(inputs) + batch_shape = input_shape[:-1] + + sign_input = layers_util.random_sign( + input_shape, + dtype=inputs.dtype, + seed=self.seed) + sign_output = layers_util.random_sign( + array_ops.concat([batch_shape, + array_ops.expand_dims(self.units, 0)], 0), + dtype=inputs.dtype, + seed=distribution_util.gen_new_seed( + self.seed, salt="dense_flipout")) + perturbed_inputs = self._matmul( + inputs * sign_input, self.kernel_posterior_affine_tensor) * sign_output + + outputs = self._matmul(inputs, self.kernel_posterior.distribution.loc) + outputs += perturbed_inputs + return outputs + + +def dense_flipout( + inputs, + units, + activation=None, + activity_regularizer=None, + trainable=True, + kernel_posterior_fn=layers_util.default_mean_field_normal_fn(), + kernel_posterior_tensor_fn=lambda d: d.sample(), + kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda + loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), + kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + bias_posterior_fn=layers_util.default_mean_field_normal_fn( + is_singular=True), + bias_posterior_tensor_fn=lambda d: d.sample(), + bias_prior_fn=None, + bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), + seed=None, + name=None, + reuse=None): + """Densely-connected layer with Flipout estimator. + + This layer implements the Bayesian variational inference analogue to + a dense layer by assuming the `kernel` and/or the `bias` are drawn + from distributions. 
By default, the layer implements a stochastic + forward pass via sampling from the kernel and bias posteriors, + + ```none + kernel, bias ~ posterior + outputs = activation(matmul(inputs, kernel) + bias) + ``` + + It uses the Flipout estimator [1], which performs a Monte Carlo + approximation of the distribution integrating over the `kernel` and + `bias`. Flipout uses roughly twice as many floating point operations + as the reparameterization estimator but has the advantage of + significantly lower variance. + + The arguments permit separate specification of the surrogate posterior + (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` + distributions. + + Args: + inputs: Tensor input. + units: Integer or Long, dimensionality of the output space. + activation: Activation function (`callable`). Set it to None to maintain a + linear activation. + activity_regularizer: Regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + kernel_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `kernel` parameter. Default value: + `default_mean_field_normal_fn()`. + kernel_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + kernel_prior_fn: Python `callable` which creates `tf.distributions` + instance. See `default_mean_field_normal_fn` docstring for required + parameter signature. + Default value: `tf.distributions.Normal(loc=0., scale=1.)`. + kernel_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. 
The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + bias_posterior_fn: Python `callable` which creates + `tf.distributions.Distribution` instance representing the surrogate + posterior of the `bias` parameter. Default value: + `default_mean_field_normal_fn(is_singular=True)` (which creates an + instance of `tf.distributions.Deterministic`). + bias_posterior_tensor_fn: Python `callable` which takes a + `tf.distributions.Distribution` instance and returns a representative + value. Default value: `lambda d: d.sample()`. + bias_prior_fn: Python `callable` which creates `tf.distributions` instance. + See `default_mean_field_normal_fn` docstring for required parameter + signature. Default value: `None` (no prior, no variational inference) + bias_divergence_fn: Python `callable` which takes the surrogate posterior + distribution, prior distribution and random variate sample(s) from the + surrogate posterior and computes or approximates the KL divergence. The + distributions are `tf.distributions.Distribution`-like instances and the + sample is a `Tensor`. + seed: Python scalar `int` which initializes the random number + generator. Default value: `None` (i.e., use global seed). + name: Python `str`, the name of the layer. Layers with the same name will + share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in + such cases. + reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous + layer by the same name. + + Returns: + output: `Tensor` representing a the affine transformed input under a random + draw from the surrogate posterior distribution. + + #### Examples + + We illustrate a Bayesian neural network with [variational inference]( + https://en.wikipedia.org/wiki/Variational_Bayesian_methods), + assuming a dataset of `features` and `labels`. 
+ + ```python + tfp = tf.contrib.bayesflow + + net = tfp.layers.dense_flipout( + features, 512, activation=tf.nn.relu) + logits = tfp.layers.dense_flipout(net, 10) + neg_log_likelihood = tf.nn.softmax_cross_entropy_with_logits( + labels=labels, logits=logits) + kl = sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)) + loss = neg_log_likelihood + kl + train_op = tf.train.AdamOptimizer().minimize(loss) + ``` + + It uses the Flipout gradient estimator to minimize the + Kullback-Leibler divergence up to a constant, also known as the + negative Evidence Lower Bound. It consists of the sum of two terms: + the expected negative log-likelihood, which we approximate via + Monte Carlo; and the KL divergence, which is added via regularizer + terms which are arguments to the layer. + + [1]: "Flipout: Efficient Pseudo-Independent Weight Perturbations on + Mini-Batches." + Anonymous. OpenReview, 2017. + https://openreview.net/forum?id=rJnpifWAb + """ + layer = DenseFlipout( + units, + activation=activation, + activity_regularizer=activity_regularizer, + trainable=trainable, + kernel_posterior_fn=kernel_posterior_fn, + kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, + kernel_prior_fn=kernel_prior_fn, + kernel_divergence_fn=kernel_divergence_fn, + bias_posterior_fn=bias_posterior_fn, + bias_posterior_tensor_fn=bias_posterior_tensor_fn, + bias_prior_fn=bias_prior_fn, + bias_divergence_fn=bias_divergence_fn, + seed=seed, + name=name, + dtype=inputs.dtype.base_dtype, + _scope=name, + _reuse=reuse) + return layer.apply(inputs) diff --git a/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational_impl.py b/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational_impl.py deleted file mode 100644 index b05ce0ffc1dd55ffb029b339a846a9aa5c877620..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/layers_dense_variational_impl.py +++ /dev/null @@ -1,797 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Dense Bayesian layer using KL-divergence based variational inference. - -@@DenseVariational -@@dense_variational - -@@default_loc_scale_fn -@@default_mean_field_normal_fn -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.distributions.python.ops import deterministic as deterministic_lib -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.layers import base as layers_lib -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import nn -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import standard_ops -from tensorflow.python.ops.distributions import kullback_leibler as kl_lib -from tensorflow.python.ops.distributions import normal as normal_lib - - -__all__ = [ - "DenseVariational", - "dense_variational", - "default_loc_scale_fn", - "default_mean_field_normal_fn", -] - - -def default_loc_scale_fn( - is_singular=False, - loc_initializer=init_ops.random_normal_initializer(stddev=0.1), - untransformed_scale_initializer=init_ops.random_normal_initializer( - mean=-3., stddev=0.1), - loc_regularizer=None, - untransformed_scale_regularizer=None, - loc_constraint=None, - 
untransformed_scale_constraint=None): - """Makes closure which creates `loc`, `scale` params from `tf.get_variable`. - - This function produces a closure which produces `loc`, `scale` using - `tf.get_variable`. The closure accepts the following arguments: - - dtype: Type of parameter's event. - shape: Python `list`-like representing the parameter's event shape. - name: Python `str` name prepended to any created (or existing) - `tf.Variable`s. - trainable: Python `bool` indicating all created `tf.Variable`s should be - added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. - add_variable_fn: `tf.get_variable`-like `callable` used to create (or - access existing) `tf.Variable`s. - - Args: - is_singular: Python `bool` indicating if `scale is None`. Default: `False`. - loc_initializer: Initializer function for the `loc` parameters. - The default is `tf.random_normal_initializer(mean=0., stddev=0.1)`. - untransformed_scale_initializer: Initializer function for the `scale` - parameters. Default value: `tf.random_normal_initializer(mean=-3., - stddev=0.1)`. This implies the softplus transformed result has mean - approximately `0.05` and std. deviation approximately `0.005`. - loc_regularizer: Regularizer function for the `loc` parameters. - The default (`None`) is to use the `tf.get_variable` default. - untransformed_scale_regularizer: Regularizer function for the `scale` - parameters. The default (`None`) is to use the `tf.get_variable` default. - loc_constraint: An optional projection function to be applied to the - loc after being updated by an `Optimizer`. The function must take as input - the unprojected variable and must return the projected variable (which - must have the same shape). Constraints are not safe to use when doing - asynchronous distributed training. - The default (`None`) is to use the `tf.get_variable` default. 
- untransformed_scale_constraint: An optional projection function to be - applied to the `scale` parameters after being updated by an `Optimizer` - (e.g. used to implement norm constraints or value constraints). The - function must take as input the unprojected variable and must return the - projected variable (which must have the same shape). Constraints are not - safe to use when doing asynchronous distributed training. The default - (`None`) is to use the `tf.get_variable` default. - - Returns: - default_loc_scale_fn: Python `callable` which instantiates `loc`, `scale` - parameters from args: `dtype, shape, name, trainable, add_variable_fn`. - """ - def _fn(dtype, shape, name, trainable, add_variable_fn): - """Creates `loc`, `scale` parameters.""" - loc = add_variable_fn( - name=name + "_loc", - shape=shape, - initializer=loc_initializer, - regularizer=loc_regularizer, - constraint=loc_constraint, - dtype=dtype, - trainable=trainable) - if is_singular: - return loc, None - untransformed_scale = add_variable_fn( - name=name + "_untransformed_scale", - shape=shape, - initializer=untransformed_scale_initializer, - regularizer=untransformed_scale_regularizer, - constraint=untransformed_scale_constraint, - dtype=dtype, - trainable=trainable) - scale = (np.finfo(dtype.as_numpy_dtype).eps + - nn_ops.softplus(untransformed_scale)) - return loc, scale - return _fn - - -def default_mean_field_normal_fn( - is_singular=False, - loc_initializer=None, - untransformed_scale_initializer=None, - loc_regularizer=None, - untransformed_scale_regularizer=None, - loc_constraint=None, - untransformed_scale_constraint=None): - """Creates a function to build Normal distributions with trainable params. - - This function produces a closure which produces `tf.distributions.Normal` - parameterized by a loc` and `scale` each created using `tf.get_variable`. 
The - produced closure accepts the following arguments: - - name: Python `str` name prepended to any created (or existing) - `tf.Variable`s. - shape: Python `list`-like representing the parameter's event shape. - dtype: Type of parameter's event. - trainable: Python `bool` indicating all created `tf.Variable`s should be - added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. - add_variable_fn: `tf.get_variable`-like `callable` used to create (or - access existing) `tf.Variable`s. - - Args: - is_singular: Python `bool` if `True`, forces the special case limit of - `scale->0`, i.e., a `Deterministic` distribution. - loc_initializer: Initializer function for the `loc` parameters. - If `None` (default), values are initialized using the default - initializer used by `tf.get_variable`. - untransformed_scale_initializer: Initializer function for the `scale` - parameters. If `None` (default), values are initialized using the default - initializer used by `tf.get_variable`. - loc_regularizer: Regularizer function for the `loc` parameters. - untransformed_scale_regularizer: Regularizer function for the `scale` - parameters. - loc_constraint: An optional projection function to be applied to the - loc after being updated by an `Optimizer`. The function must take as input - the unprojected variable and must return the projected variable (which - must have the same shape). Constraints are not safe to use when doing - asynchronous distributed training. - untransformed_scale_constraint: An optional projection function to be - applied to the `scale` parameters after being updated by an `Optimizer` - (e.g. used to implement norm constraints or value constraints). The - function must take as input the unprojected variable and must return the - projected variable (which must have the same shape). Constraints are not - safe to use when doing asynchronous distributed training. 
- - Returns: - make_normal_fn: Python `callable` which creates a `tf.distributions.Normal` - using from args: `dtype, shape, name, trainable, add_variable_fn`. - """ - loc_scale_fn_ = default_loc_scale_fn( - is_singular, - loc_initializer, - untransformed_scale_initializer, - loc_regularizer, - untransformed_scale_regularizer, - loc_constraint, - untransformed_scale_constraint) - def _fn(dtype, shape, name, trainable, add_variable_fn): - """Creates a batch of `Deterministic` or `Normal` distributions.""" - loc, scale = loc_scale_fn_(dtype, shape, name, trainable, add_variable_fn) - if scale is None: - return deterministic_lib.Deterministic(loc=loc) - return normal_lib.Normal(loc=loc, scale=scale) - return _fn - - -class DenseVariational(layers_lib.Layer): - """Densely-connected variational class. - - This layer implements the Bayesian variational inference analogue to: - `outputs = activation(matmul(inputs, kernel) + bias)` - by assuming the `kernel` and/or the `bias` are random variables. - - The layer implements a stochastic dense calculation by making a Monte Carlo - approximation of a [variational Bayesian method based on KL divergence]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), i.e., - - ```none - -log p(y|x) = -log int_{R**d} p(y|x,w) p(w) dw - = -log int_{R**d} p(y,w|x) q(w|x) / q(w|x) dw - <= E_q(W|x)[-log p(y,W|x) + log q(W|x)] # Jensen's - = E_q(W|x)[-log p(y|x,W)] + KL[q(W|x), p(W)] - ~= m**-1 sum{ -log(y|x,w[j]) : w[j] ~ q(W|x), j=1..m } - + KL[q(W|x), p(W)] - ``` - - where `W` denotes the (independent) `kernel` and `bias` random variables, `w` - is a random variate or outcome of `W`, `y` is the label, `x` is the evidence`, - and `~=` denotes an approximation which becomes exact as `m->inf`. The above - bound is sometimes referred to as the negative Evidence Lower BOund or - negative [ELBO](https://arxiv.org/abs/1601.00670). In context of a DNN, this - layer is appropriate to use when the final loss is a negative log-likelihood. 
- - The Monte-Carlo sum portion is used for the feed-forward calculation of the - DNN. The KL divergence portion can be added to the final loss via: - `loss += sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))`. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - random variables (which together comprise `W`). - - Args: - units: Integer or Long, dimensionality of the output space. - activation: Activation function (`callable`). Set it to None to maintain a - linear activation. - activity_regularizer: Regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_use_local_reparameterization: Python `bool` indicating whether - `kernel` calculation should employ the Local Reparameterization Trick. - When `True`, `kernel_posterior_fn` must create an instance of - `tf.distributions.Normal`. - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. 
- bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - name: Python `str`, the name of the layer. Layers with the same name will - share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in - such cases. - reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous - layer by the same name. - - Properties: - units: Python integer, dimensionality of the output space. - activation: Activation function (`callable`). - activity_regularizer: Regularizer function for the output. - kernel_use_local_reparameterization: Python `bool` indicating whether - `kernel` calculation should employ the Local Reparameterization Trick. - kernel: `VariationalKernelParamater` instance containing all `kernel` - related properties and `callable`s. - bias: `VariationalParameter` instance containing all `kernel` - related properties and `callable`s. 
- """ - - def __init__( - self, - units, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_use_local_reparameterization=True, - kernel_posterior_fn=default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=default_mean_field_normal_fn(is_singular=True), - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - **kwargs): - super(DenseVariational, self).__init__( - trainable=trainable, - name=name, - activity_regularizer=activity_regularizer, - **kwargs) - self._units = units - self._activation = activation - self._input_spec = layers_lib.InputSpec(min_ndim=2) - self._kernel_use_local_reparameterization = ( - kernel_use_local_reparameterization) - self._kernel = VariationalKernelParameter( - kernel_posterior_fn, - kernel_posterior_tensor_fn, - kernel_prior_fn, - kernel_divergence_fn) - self._bias = VariationalParameter( - bias_posterior_fn, - bias_posterior_tensor_fn, - bias_prior_fn, - bias_divergence_fn) - - @property - def units(self): - return self._units - - @property - def activation(self): - return self._activation - - @property - def input_spec(self): - return self._input_spec - - @input_spec.setter - def input_spec(self, value): - self._input_spec = value - - @property - def kernel_use_local_reparameterization(self): - return self._kernel_use_local_reparameterization - - @property - def kernel(self): - return self._kernel - - @property - def bias(self): - return self._bias - - def build(self, input_shape): - input_shape = tensor_shape.TensorShape(input_shape) - in_size = input_shape.with_rank_at_least(2)[-1].value - if in_size is None: - raise ValueError("The last 
dimension of the inputs to `Dense` " - "should be defined. Found `None`.") - self._input_spec = layers_lib.InputSpec(min_ndim=2, axes={-1: in_size}) - dtype = dtypes.as_dtype(self.dtype) - - # Must have a posterior kernel. - self.kernel.posterior = self.kernel.posterior_fn( - dtype, [in_size, self.units], "kernel_posterior", - self.trainable, self.add_variable) - - if self.kernel.prior_fn is None: - self.kernel_prior = None - else: - self.kernel.prior = self.kernel.prior_fn( - dtype, [in_size, self.units], "kernel_prior", - self.trainable, self.add_variable) - self._built_kernel_divergence = False - - if self.bias.posterior_fn is None: - self.bias.posterior = None - else: - self.bias.posterior = self.bias.posterior_fn( - dtype, [self.units], "bias_posterior", - self.trainable, self.add_variable) - - if self.bias.prior_fn is None: - self.bias.prior = None - else: - self.bias.prior = self.bias.prior_fn( - dtype, [self.units], "bias_prior", - self.trainable, self.add_variable) - self._built_bias_divergence = False - - self.built = True - - def call(self, inputs): - inputs = ops.convert_to_tensor(inputs, dtype=self.dtype) - - outputs = self._apply_variational_kernel(inputs) - outputs = self._apply_variational_bias(outputs) - if self.activation is not None: - outputs = self.activation(outputs) # pylint: disable=not-callable - if not self._built_kernel_divergence: - self._apply_divergence(self.kernel, name="divergence_kernel") - self._built_kernel_divergence = True - if not self._built_bias_divergence: - self._apply_divergence(self.bias, name="divergence_bias") - self._built_bias_divergence = True - return outputs - - def _apply_variational_kernel(self, inputs): - if not self.kernel_use_local_reparameterization: - self.kernel.posterior_tensor = self.kernel.posterior_tensor_fn( - self.kernel.posterior) - self.kernel.posterior_affine = None - self.kernel.posterior_affine_tensor = None - return self._matmul(inputs, self.kernel.posterior_tensor) - if not 
isinstance(self.kernel.posterior, normal_lib.Normal): - raise TypeError("`kernel_use_local_reparameterization=True` requires " - "`kernel_posterior_fn` produce an instance of " - "`tf.distributions.Normal` (saw: \"{}\").".format( - type(self.kernel.posterior).__name__)) - self.kernel.posterior_affine = normal_lib.Normal( - loc=self._matmul(inputs, self.kernel.posterior.loc), - scale=standard_ops.sqrt(self._matmul( - standard_ops.square(inputs), - standard_ops.square(self.kernel.posterior.scale)))) - self.kernel.posterior_affine_tensor = ( - self.kernel.posterior_tensor_fn(self.kernel.posterior_affine)) - self.kernel.posterior_tensor = None - return self.kernel.posterior_affine_tensor - - def _apply_variational_bias(self, inputs): - if self.bias.posterior is None: - self.bias.posterior_tensor = None - return inputs - self.bias.posterior_tensor = self.bias.posterior_tensor_fn( - self.bias.posterior) - return nn.bias_add(inputs, self.bias.posterior_tensor) - - def _apply_divergence(self, param, name): - if (param.divergence_fn is None or - param.posterior is None or - param.prior is None): - param.divergence = None - return - param.divergence = standard_ops.identity( - param.divergence_fn( - param.posterior, param.prior, param.posterior_tensor), - name=name) - self.add_loss(param.divergence) - - def _matmul(self, inputs, kernel): - if inputs.shape.ndims <= 2: - return standard_ops.matmul(inputs, kernel) - # To handle broadcasting, we must use `tensordot`. 
- return standard_ops.tensordot(inputs, kernel, axes=[[-1], [0]]) - - def _compute_output_shape(self, input_shape): - input_shape = tensor_shape.TensorShape(input_shape).with_rank_at_least(2) - if input_shape[-1].value is None: - raise ValueError( - "The innermost dimension of input_shape must be defined, " - "but saw: {}".format(input_shape)) - return input_shape[:-1].concatenate(self.units) - - -def dense_variational( - inputs, - units, - activation=None, - activity_regularizer=None, - trainable=True, - kernel_use_local_reparameterization=True, - kernel_posterior_fn=default_mean_field_normal_fn(), - kernel_posterior_tensor_fn=lambda d: d.sample(), - kernel_prior_fn=lambda dtype, *args: normal_lib.Normal( # pylint: disable=g-long-lambda - loc=dtype.as_numpy_dtype(0.), scale=dtype.as_numpy_dtype(1.)), - kernel_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - bias_posterior_fn=default_mean_field_normal_fn(is_singular=True), - bias_posterior_tensor_fn=lambda d: d.sample(), - bias_prior_fn=None, - bias_divergence_fn=lambda q, p, ignore: kl_lib.kl_divergence(q, p), - name=None, - reuse=None): - """Densely-connected variational layer. - - This layer implements the Bayesian variational inference analogue to: - `outputs = activation(matmul(inputs, kernel) + bias)` - by assuming the `kernel` and/or the `bias` are random variables. 
- - The layer implements a stochastic dense calculation by making a Monte Carlo - approximation of a [variational Bayesian method based on KL divergence]( - https://en.wikipedia.org/wiki/Variational_Bayesian_methods), i.e., - - ```none - -log p(y|x) = -log int_{R**d} p(y|x,w) p(w) dw - = -log int_{R**d} p(y,w|x) q(w|x) / q(w|x) dw - <= E_q(W|x)[-log p(y,W|x) + log q(W|x)] # Jensen's - = E_q(W|x)[-log p(y|x,W)] + KL[q(W|x), p(W)] - ~= m**-1 sum{ -log(y|x,w[j]) : w[j] ~ q(W|x), j=1..m } - + KL[q(W|x), p(W)] - ``` - - where `W` denotes the (independent) `kernel` and `bias` random variables, `w` - is a random variate or outcome of `W`, `y` is the label, `x` is the evidence`, - and `~=` denotes an approximation which becomes exact as `m->inf`. The above - bound is sometimes referred to as the negative Evidence Lower BOund or - negative [ELBO](https://arxiv.org/abs/1601.00670). In context of a DNN, this - layer is appropriate to use when the final loss is a negative log-likelihood. - - The Monte-Carlo sum portion is used for the feed-forward calculation of the - DNN. The KL divergence portion can be added to the final loss via: - `loss += sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))`. - - The arguments permit separate specification of the surrogate posterior - (`q(W|x)`), prior (`p(W)`), and divergence for both the `kernel` and `bias` - random variables (which together comprise `W`). - - Args: - inputs: Tensor input. - units: Integer or Long, dimensionality of the output space. - activation: Activation function (`callable`). Set it to None to maintain a - linear activation. - activity_regularizer: Regularizer function for the output. - trainable: Boolean, if `True` also add variables to the graph collection - `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). - kernel_use_local_reparameterization: Python `bool` indicating whether - `kernel` calculation should employ the Local Reparameterization Trick. 
- When `True`, `kernel_posterior_fn` must create an instance of - `tf.distributions.Normal`. - kernel_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `kernel` parameter. Default value: - `default_mean_field_normal_fn()`. - kernel_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - kernel_prior_fn: Python `callable` which creates `tf.distributions` - instance. See `default_mean_field_normal_fn` docstring for required - parameter signature. - Default value: `tf.distributions.Normal(loc=0., scale=1.)`. - kernel_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - bias_posterior_fn: Python `callable` which creates - `tf.distributions.Distribution` instance representing the surrogate - posterior of the `bias` parameter. Default value: - `default_mean_field_normal_fn(is_singular=True)` (which creates an - instance of `tf.distributions.Deterministic`). - bias_posterior_tensor_fn: Python `callable` which takes a - `tf.distributions.Distribution` instance and returns a representative - value. Default value: `lambda d: d.sample()`. - bias_prior_fn: Python `callable` which creates `tf.distributions` instance. - See `default_mean_field_normal_fn` docstring for required parameter - signature. Default value: `None` (no prior, no variational inference) - bias_divergence_fn: Python `callable` which takes the surrogate posterior - distribution, prior distribution and random variate sample(s) from the - surrogate posterior and computes or approximates the KL divergence. 
The - distributions are `tf.distributions.Distribution`-like instances and the - sample is a `Tensor`. - name: Python `str`, the name of the layer. Layers with the same name will - share `tf.Variable`s, but to avoid mistakes we require `reuse=True` in - such cases. - reuse: Python `bool`, whether to reuse the `tf.Variable`s of a previous - layer by the same name. - - Returns: - output: `Tensor` representing a the affine transformed input under a random - draw from the surrogate posterior distribution. - """ - layer = DenseVariational( - units, - activation=activation, - activity_regularizer=activity_regularizer, - trainable=trainable, - kernel_use_local_reparameterization=( - kernel_use_local_reparameterization), - kernel_posterior_fn=kernel_posterior_fn, - kernel_posterior_tensor_fn=kernel_posterior_tensor_fn, - kernel_prior_fn=kernel_prior_fn, - kernel_divergence_fn=kernel_divergence_fn, - bias_posterior_fn=bias_posterior_fn, - bias_posterior_tensor_fn=bias_posterior_tensor_fn, - bias_prior_fn=bias_prior_fn, - bias_divergence_fn=bias_divergence_fn, - name=name, - dtype=inputs.dtype.base_dtype, - _scope=name, - _reuse=reuse) - return layer.apply(inputs) - - -class NotSet(object): - """Helper to track whether a `VariationalParameter` value has been set.""" - pass - - -class VariationalParameter(object): - """Struct-like container of variational parameter properties. - - A `VariationalParameter` is intitialized with Python `callable`s which set the - value of correspondingly named members. Corresponding values have "set once" - semantics, i.e., once set to any value they are immutable. - """ - - def __init__( - self, - posterior_fn, - posterior_tensor_fn, - prior_fn, - divergence_fn): - """Creates the `VariationalParameter` struct-like object. - - Args: - posterior_fn: Python `callable` which creates a - `tf.distribution.Distribution` like object representing the posterior - distribution. 
See `VariationalParameter.posterior_fn` for `callable`'s - required parameters. - posterior_tensor_fn: Python `callable` which computes a `Tensor` - which represents the `posterior`. - prior_fn: Python `callable` which creates a - `tf.distribution.Distribution` like object representing the prior - distribution. See `VariationalParameter.prior_fn` for `callable`'s - required parameters. - divergence_fn: Python `callable` which computes the KL divergence from - `posterior` to `prior`. See `VariationalParameter.divergence_fn` for - required `callable`'s parameters. - """ - self._posterior_fn = posterior_fn - self._posterior = NotSet() - self._posterior_tensor_fn = posterior_tensor_fn - self._posterior_tensor = NotSet() - self._prior_fn = prior_fn - self._prior = NotSet() - self._divergence_fn = divergence_fn - self._divergence = NotSet() - self._init_helper() - - @property - def posterior_fn(self): - """`callable` which creates `tf.distributions.Distribution`-like posterior. - - The `callable` must accept the following parameters: - name: Python `str` name prepended to any created (or existing) - `tf.Variable`s. - shape: Python `list`-like representing the parameter's event shape. - dtype: Type of parameter's event. - trainable: Python `bool` indicating all created `tf.Variable`s should be - added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. - add_variable_fn: `tf.get_variable`-like `callable` used to create (or - access existing) `tf.Variable`s. - - Returns: - posterior_fn: The Python `callable` specified in `__init__`. 
- """ - return self._posterior_fn - - @property - def posterior(self): - """`tf.distributions.Distribution`-like instance representing posterior.""" - return self._posterior - - @posterior.setter - def posterior(self, value): - """One-time setter of the `posterior` distribution.""" - if not isinstance(self._posterior, NotSet): - raise ValueError("Cannot override already set attribute.") - self._posterior = value - - @property - def posterior_tensor_fn(self): - """Creates `Tensor` representing the `posterior` distribution. - - The `callable` must accept the following parameters: - posterior: `tf.distributions.Distribution`-like instance. - - Returns: - posterior_tensor_fn: The Python `callable` specified in - `__init__`. - """ - return self._posterior_tensor_fn - - @property - def posterior_tensor(self): - """`Tensor` representing the `posterior` distribution.""" - return self._posterior_tensor - - @posterior_tensor.setter - def posterior_tensor(self, value): - """One-time setter of the `posterior_tensor`.""" - if not isinstance(self._posterior_tensor, NotSet): - raise ValueError("Cannot override already set attribute.") - self._posterior_tensor = value - - @property - def prior_fn(self): - """`callable` which creates `tf.distributions.Distribution`-like prior. - - The `callable` must accept the following parameters: - name: Python `str` name prepended to any created (or existing) - `tf.Variable`s. - shape: Python `list`-like representing the parameter's event shape. - dtype: Type of parameter's event. - trainable: Python `bool` indicating all created `tf.Variable`s should be - added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. - add_variable_fn: `tf.get_variable`-like `callable` used to create (or - access existing) `tf.Variable`s. - - Returns: - prior_fn: The Python `callable` specified in `__init__`. 
- """ - return self._prior_fn - - @property - def prior(self): - """`tf.distributions.Distribution`-like instance representing posterior.""" - return self._prior - - @prior.setter - def prior(self, value): - """One-time setter of the `prior` distribution.""" - if not isinstance(self._prior, NotSet): - raise ValueError("Cannot override already set attribute.") - self._prior = value - - @property - def divergence_fn(self): - """`callable` which computes KL-divergence `Tensor` from posterior to prior. - - The `callable` must accept the following parameters: - posterior: `tf.distributions.Distribution`-like instance. - prior: `tf.distributions.Distribution`-like instance. - posterior_tensor: `Tensor` representing value of posterior. - - Returns: - divergence_fn: The Python `callable` specified in `__init__`. - """ - return self._divergence_fn - - @property - def divergence(self): - """`Tensor` representing KL-divergence from posterior to prior.""" - return self._divergence - - @divergence.setter - def divergence(self, value): - """One-time setter of the `divergence`.""" - if not isinstance(self._divergence, NotSet): - raise ValueError("Cannot override already set attribute.") - self._divergence = value - - def _init_helper(self): - pass - - -class VariationalKernelParameter(VariationalParameter): - """Struct-like container of variational kernel properties. - - A `VariationalKernelParameter` is intitialized with Python `callable`s which - set the value of correspondingly named members. Corresponding values have "set - once" semantics, i.e., once set to any value they are immutable. 
- """ - - @property - def posterior_affine(self): - """`tf.distributions.Distribution` affine transformed posterior.""" - return self._posterior_affine - - @posterior_affine.setter - def posterior_affine(self, value): - """One-time setter of `posterior_affine`.""" - if not isinstance(self._posterior_affine, NotSet): - raise ValueError("Cannot override already set attribute.") - self._posterior_affine = value - - @property - def posterior_affine_tensor(self): - """`Tensor` representing the `posterior_affine` distribution.""" - return self._posterior_affine_tensor - - @posterior_affine_tensor.setter - def posterior_affine_tensor(self, value): - """One-time setter of the `posterior_affine_tensor`.""" - if not isinstance(self._posterior_affine_tensor, NotSet): - raise ValueError("Cannot override already set attribute.") - self._posterior_affine_tensor = value - - def _init_helper(self): - self._posterior_affine = NotSet() - self._posterior_affine_tensor = NotSet() diff --git a/tensorflow/contrib/bayesflow/python/ops/layers_util.py b/tensorflow/contrib/bayesflow/python/ops/layers_util.py new file mode 100644 index 0000000000000000000000000000000000000000..8c1fb203f7328e8260e49b4326d813fbe133613e --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/ops/layers_util.py @@ -0,0 +1,191 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Utilities for probabilistic layers. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import deterministic as deterministic_lib +from tensorflow.contrib.distributions.python.ops import independent as independent_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops.distributions import normal as normal_lib + + +def default_loc_scale_fn( + is_singular=False, + loc_initializer=init_ops.random_normal_initializer(stddev=0.1), + untransformed_scale_initializer=init_ops.random_normal_initializer( + mean=-3., stddev=0.1), + loc_regularizer=None, + untransformed_scale_regularizer=None, + loc_constraint=None, + untransformed_scale_constraint=None): + """Makes closure which creates `loc`, `scale` params from `tf.get_variable`. + + This function produces a closure which produces `loc`, `scale` using + `tf.get_variable`. The closure accepts the following arguments: + + dtype: Type of parameter's event. + shape: Python `list`-like representing the parameter's event shape. + name: Python `str` name prepended to any created (or existing) + `tf.Variable`s. + trainable: Python `bool` indicating all created `tf.Variable`s should be + added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. + add_variable_fn: `tf.get_variable`-like `callable` used to create (or + access existing) `tf.Variable`s. + + Args: + is_singular: Python `bool` indicating if `scale is None`. Default: `False`. + loc_initializer: Initializer function for the `loc` parameters. 
+ The default is `tf.random_normal_initializer(mean=0., stddev=0.1)`. + untransformed_scale_initializer: Initializer function for the `scale` + parameters. Default value: `tf.random_normal_initializer(mean=-3., + stddev=0.1)`. This implies the softplus transformed result has mean + approximately `0.05` and std. deviation approximately `0.005`. + loc_regularizer: Regularizer function for the `loc` parameters. + The default (`None`) is to use the `tf.get_variable` default. + untransformed_scale_regularizer: Regularizer function for the `scale` + parameters. The default (`None`) is to use the `tf.get_variable` default. + loc_constraint: An optional projection function to be applied to the + loc after being updated by an `Optimizer`. The function must take as input + the unprojected variable and must return the projected variable (which + must have the same shape). Constraints are not safe to use when doing + asynchronous distributed training. + The default (`None`) is to use the `tf.get_variable` default. + untransformed_scale_constraint: An optional projection function to be + applied to the `scale` parameters after being updated by an `Optimizer` + (e.g. used to implement norm constraints or value constraints). The + function must take as input the unprojected variable and must return the + projected variable (which must have the same shape). Constraints are not + safe to use when doing asynchronous distributed training. The default + (`None`) is to use the `tf.get_variable` default. + + Returns: + default_loc_scale_fn: Python `callable` which instantiates `loc`, `scale` + parameters from args: `dtype, shape, name, trainable, add_variable_fn`. 
+ """ + def _fn(dtype, shape, name, trainable, add_variable_fn): + """Creates `loc`, `scale` parameters.""" + loc = add_variable_fn( + name=name + "_loc", + shape=shape, + initializer=loc_initializer, + regularizer=loc_regularizer, + constraint=loc_constraint, + dtype=dtype, + trainable=trainable) + if is_singular: + return loc, None + untransformed_scale = add_variable_fn( + name=name + "_untransformed_scale", + shape=shape, + initializer=untransformed_scale_initializer, + regularizer=untransformed_scale_regularizer, + constraint=untransformed_scale_constraint, + dtype=dtype, + trainable=trainable) + scale = (np.finfo(dtype.as_numpy_dtype).eps + + nn_ops.softplus(untransformed_scale)) + return loc, scale + return _fn + + +def default_mean_field_normal_fn( + is_singular=False, + loc_initializer=None, + untransformed_scale_initializer=None, + loc_regularizer=None, + untransformed_scale_regularizer=None, + loc_constraint=None, + untransformed_scale_constraint=None): + """Creates a function to build Normal distributions with trainable params. + + This function produces a closure which produces `tf.distributions.Normal` + parameterized by a loc` and `scale` each created using `tf.get_variable`. The + produced closure accepts the following arguments: + + name: Python `str` name prepended to any created (or existing) + `tf.Variable`s. + shape: Python `list`-like representing the parameter's event shape. + dtype: Type of parameter's event. + trainable: Python `bool` indicating all created `tf.Variable`s should be + added to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. + add_variable_fn: `tf.get_variable`-like `callable` used to create (or + access existing) `tf.Variable`s. + + Args: + is_singular: Python `bool` if `True`, forces the special case limit of + `scale->0`, i.e., a `Deterministic` distribution. + loc_initializer: Initializer function for the `loc` parameters. 
+ If `None` (default), values are initialized using the default + initializer used by `tf.get_variable`. + untransformed_scale_initializer: Initializer function for the `scale` + parameters. If `None` (default), values are initialized using the default + initializer used by `tf.get_variable`. + loc_regularizer: Regularizer function for the `loc` parameters. + untransformed_scale_regularizer: Regularizer function for the `scale` + parameters. + loc_constraint: An optional projection function to be applied to the + loc after being updated by an `Optimizer`. The function must take as input + the unprojected variable and must return the projected variable (which + must have the same shape). Constraints are not safe to use when doing + asynchronous distributed training. + untransformed_scale_constraint: An optional projection function to be + applied to the `scale` parameters after being updated by an `Optimizer` + (e.g. used to implement norm constraints or value constraints). The + function must take as input the unprojected variable and must return the + projected variable (which must have the same shape). Constraints are not + safe to use when doing asynchronous distributed training. + + Returns: + make_normal_fn: Python `callable` which creates a `tf.distributions.Normal` + from args: `dtype, shape, name, trainable, add_variable_fn`.
+ """ + loc_scale_fn_ = default_loc_scale_fn( + is_singular, + loc_initializer, + untransformed_scale_initializer, + loc_regularizer, + untransformed_scale_regularizer, + loc_constraint, + untransformed_scale_constraint) + def _fn(dtype, shape, name, trainable, add_variable_fn): + """Creates multivariate `Deterministic` or `Normal` distribution.""" + loc, scale = loc_scale_fn_(dtype, shape, name, trainable, add_variable_fn) + if scale is None: + dist = deterministic_lib.Deterministic(loc=loc) + else: + dist = normal_lib.Normal(loc=loc, scale=scale) + reinterpreted_batch_ndims = array_ops.shape(dist.batch_shape_tensor())[0] + return independent_lib.Independent( + dist, reinterpreted_batch_ndims=reinterpreted_batch_ndims) + return _fn + + +def random_sign(shape, dtype=dtypes.float32, seed=None): + """Draw values from {-1, 1} uniformly, i.e., Rademacher distribution.""" + random_bernoulli = random_ops.random_uniform(shape, minval=0, maxval=2, + dtype=dtypes.int32, + seed=seed) + return math_ops.cast(2 * random_bernoulli - 1, dtype) diff --git a/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics.py b/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics.py new file mode 100644 index 0000000000000000000000000000000000000000..f3a645eafc249d1c39e0d4a238ae7ec8755c78d8 --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics.py @@ -0,0 +1,32 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for Markov Chain Monte Carlo (MCMC) sampling.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# go/tf-wildcard-import +# pylint: disable=wildcard-import +from tensorflow.contrib.bayesflow.python.ops.mcmc_diagnostics_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + "effective_sample_size", + "potential_scale_reduction", +] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics_impl.py b/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..0424b6952bc89ce7fe5b00b0135c9a5fe1faa8cf --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/ops/mcmc_diagnostics_impl.py @@ -0,0 +1,400 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for Markov Chain Monte Carlo (MCMC) sampling. 
+ +@@effective_sample_size +@@potential_scale_reduction +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distributions.python.ops import sample_stats +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops + +__all__ = [ + "effective_sample_size", + "potential_scale_reduction", +] + + +def effective_sample_size(states, + filter_threshold=0., + filter_beyond_lag=None, + name=None): + """Estimate a lower bound on effective sample size for each independent chain. + + Roughly speaking, "effective sample size" (ESS) is the size of an iid sample + with the same variance as `state`. + + More precisely, given a stationary sequence of possibly correlated random + variables `X_1, X_2,...,X_N`, each identically distributed ESS is the number + such that + + ```Variance{ N**-1 * Sum{X_i} } = ESS**-1 * Variance{ X_1 }.``` + + If the sequence is uncorrelated, `ESS = N`. In general, one should expect + `ESS <= N`, with more highly correlated sequences having smaller `ESS`. + + #### Example of using ESS to estimate standard error. + + ``` + tfd = tf.contrib.distributions + tfb = tf.contrib.bayesflow + + target = tfd.MultivariateNormalDiag(scale_diag=[1., 2.]) + + # Get 1000 states from one chain. 
+ states = tfb.hmc.sample_chain( + num_results=1000, + target_log_prob_fn=target.log_prob, + current_state=tf.constant([0., 0.]), + step_size=0.05, + num_leapfrog_steps=20, + num_burnin_steps=200) + states.shape + ==> (1000, 2) + + ess = effective_sample_size(states) + ==> Shape (2,) Tensor + + mean, variance = tf.nn.moments(states, axis=0) + standard_error = tf.sqrt(variance / ess) + ``` + + Some math shows that, with `R_k` the auto-correlation sequence, + `R_k := Covariance{X_1, X_{1+k}} / Variance{X_1}`, we have + + ```ESS(N) = N / [ 1 + 2 * ( (N - 1) / N * R_1 + ... + 1 / N * R_{N-1} ) ]``` + + This function estimates the above by first estimating the auto-correlation. + Since `R_k` must be estimated using only `N - k` samples, it becomes + progressively noisier for larger `k`. For this reason, the summation over + `R_k` should be truncated at some number `filter_beyond_lag < N`. Since many + MCMC methods generate chains where `R_k > 0`, a reasonable criterion is to + truncate at the first index where the estimated auto-correlation becomes + negative. + + The arguments `filter_beyond_lag`, `filter_threshold` are filters intended to + remove noisy tail terms from `R_k`. They combine in an "OR" manner meaning + terms are removed if they were to be filtered under the `filter_beyond_lag` OR + `filter_threshold` criteria. + + Args: + states: `Tensor` or list of `Tensor` objects. Dimension zero should index + identically distributed states. + filter_threshold: `Tensor` or list of `Tensor` objects. + Must broadcast with `states`. The auto-correlation sequence is truncated + after the first appearance of a term less than `filter_threshold`. + Setting to `None` means we use no threshold filter. Since `|R_k| <= 1`, + setting to any number less than `-1` has the same effect. + filter_beyond_lag: `Tensor` or list of `Tensor` objects. Must be + `int`-like and scalar valued. The auto-correlation sequence is truncated + to this length.
Setting to `None` means we do not filter based on number + of lags. + name: `String` name to prepend to created ops. + + Returns: + ess: `Tensor` or list of `Tensor` objects. The effective sample size of + each component of `states`. Shape will be `states.shape[1:]`. + + Raises: + ValueError: If `states` and `filter_threshold` or `states` and + `filter_beyond_lag` are both lists with different lengths. + """ + states_was_list = _is_list_like(states) + + # Convert all args to lists. + if not states_was_list: + states = [states] + + filter_beyond_lag = _broadcast_maybelist_arg(states, filter_beyond_lag, + "filter_beyond_lag") + filter_threshold = _broadcast_maybelist_arg(states, filter_threshold, + "filter_threshold") + + # Process items, one at a time. + with ops.name_scope(name, "effective_sample_size"): + ess_list = [ + _effective_sample_size_single_state(s, ml, mlt) + for (s, ml, mlt) in zip(states, filter_beyond_lag, filter_threshold) + ] + + if states_was_list: + return ess_list + return ess_list[0] + + +def _effective_sample_size_single_state(states, filter_beyond_lag, + filter_threshold): + """ESS computation for one single Tensor argument.""" + + with ops.name_scope( + "effective_sample_size_single_state", + values=[states, filter_beyond_lag, filter_threshold]): + + states = ops.convert_to_tensor(states, name="states") + dt = states.dtype + + # filter_beyond_lag == None ==> auto_corr is the full sequence. + auto_corr = sample_stats.auto_correlation( + states, axis=0, max_lags=filter_beyond_lag) + if filter_threshold is not None: + filter_threshold = ops.convert_to_tensor( + filter_threshold, dtype=dt, name="filter_threshold") + # Get a binary mask to zero out values of auto_corr below the threshold. + # mask[i, ...] = 1 if auto_corr[j, ...] > threshold for all j <= i, + # mask[i, ...] = 0, otherwise. + # So, along dimension zero, the mask will look like [1, 1, ..., 0, 0,...] 
+ # Building step by step, + # Assume auto_corr = [1, 0.5, 0.0, 0.3], and filter_threshold = 0.2. + # Step 1: mask = [False, False, True, False] + mask = auto_corr < filter_threshold + # Step 2: mask = [0, 0, 1, 1] + mask = math_ops.cast(mask, dtype=dt) + # Step 3: mask = [0, 0, 1, 2] + mask = math_ops.cumsum(mask, axis=0) + # Step 4: mask = [1, 1, 0, 0] + mask = math_ops.maximum(1. - mask, 0.) + auto_corr *= mask + + # With R[k] := auto_corr[k, ...], + # ESS = N / {1 + 2 * Sum_{k=1}^N (N - k) / N * R[k]} + # = N / {-1 + 2 * Sum_{k=0}^N (N - k) / N * R[k]} (since R[0] = 1) + # approx N / {-1 + 2 * Sum_{k=0}^M (N - k) / N * R[k]} + # where M is the filter_beyond_lag truncation point chosen above. + + # Get the factor (N - k) / N, and give it shape [M, 1,...,1], having total + # ndims the same as auto_corr + n = _axis_size(states, axis=0) + k = math_ops.range(0., _axis_size(auto_corr, axis=0)) + nk_factor = (n - k) / n + if auto_corr.shape.ndims is not None: + new_shape = [-1] + [1] * (auto_corr.shape.ndims - 1) + else: + new_shape = array_ops.concat( + ([-1], + array_ops.ones([array_ops.rank(auto_corr) - 1], dtype=dtypes.int32)), + axis=0) + nk_factor = array_ops.reshape(nk_factor, new_shape) + + return n / (-1 + 2 * math_ops.reduce_sum(nk_factor * auto_corr, axis=0)) + + +def potential_scale_reduction(chains_states, + independent_chain_ndims=1, + name=None): + """Gelman and Rubin's potential scale reduction factor for chain convergence. + + Given `N > 1` states from each of `C > 1` independent chains, the potential + scale reduction factor, commonly referred to as R-hat, measures convergence of + the chains (to the same target) by testing for equality of means. + Specifically, R-hat measures the degree to which variance (of the means) + between chains exceeds what one would expect if the chains were identically + distributed. See [1], [2]. 
+ + Some guidelines: + + * The initial state of the chains should be drawn from a distribution + overdispersed with respect to the target. + * If all chains converge to the target, then as `N --> infinity`, R-hat --> 1. + Before that, R-hat > 1 (except in pathological cases, e.g. if the chain + paths were identical). + * The above holds for any number of chains `C > 1`. Increasing `C` does + improve effectiveness of the diagnostic. + * Sometimes, R-hat < 1.2 is used to indicate approximate convergence, but of + course this is problem dependent. See [2]. + * R-hat only measures non-convergence of the mean. If higher moments, or other + statistics are desired, a different diagnostic should be used. See [2]. + + #### Examples + + Diagnosing convergence by monitoring 10 chains that each attempt to + sample from a 2-variate normal. + + ```python + tfd = tf.contrib.distributions + tfb = tf.contrib.bayesflow + + target = tfd.MultivariateNormalDiag(scale_diag=[1., 2.]) + + # Get 10 (2x) overdispersed initial states. + initial_state = target.sample(10) * 2. + ==> (10, 2) + + # Get 1000 samples from the 10 independent chains. + chains_states, _ = tfb.hmc.sample_chain( + num_results=1000, + target_log_prob_fn=target.log_prob, + current_state=initial_state, + step_size=0.05, + num_leapfrog_steps=20, + num_burnin_steps=200) + chains_states.shape + ==> (1000, 10, 2) + + rhat = tfb.mcmc_diagnostics.potential_scale_reduction( + chains_states, independent_chain_ndims=1) + + # The second dimension needed a longer burn-in. + rhat.eval() + ==> [1.05, 1.3] + ``` + + To see why R-hat is reasonable, let `X` be a random variable drawn uniformly + from the combined states (combined over all chains).
Then, in the limit + `N, C --> infinity`, with `E`, `Var` denoting expectation and variance, + + ```R-hat = ( E[Var[X | chain]] + Var[E[X | chain]] ) / E[Var[X | chain]].``` + + Using the law of total variance, the numerator is the variance of the combined + states, and the denominator is the total variance minus the variance of the + individual chain means. If the chains are all drawing from the same + distribution, they will have the same mean, and thus the ratio should be one. + + [1] "Inference from Iterative Simulation Using Multiple Sequences" + Andrew Gelman and Donald B. Rubin + Statist. Sci. Volume 7, Number 4 (1992), 457-472. + [2] "General Methods for Monitoring Convergence of Iterative Simulations" + Stephen P. Brooks and Andrew Gelman + Journal of Computational and Graphical Statistics, 1998. Vol 7, No. 4. + + Args: + chains_states: `Tensor` or Python `list` of `Tensor`s representing the + state(s) of a Markov Chain at each result step. The `ith` state is + assumed to have shape `[Ni, Ci1, Ci2,...,CiD] + A`. + Dimension `0` indexes the `Ni > 1` result steps of the Markov Chain. + Dimensions `1` through `D` index the `Ci1 x ... x CiD` independent + chains to be tested for convergence to the same target. + The remaining dimensions, `A`, can have any shape (even empty). + independent_chain_ndims: Integer type `Tensor` with value `>= 1` giving the + number of dimensions, from `dim = 1` to `dim = D`, + holding independent chain results to be tested for convergence. + name: `String` name to prepend to created ops. Default: + `potential_scale_reduction`. + + Returns: + `Tensor` or Python `list` of `Tensor`s representing the R-hat statistic for + the state(s). Same `dtype` as `state`, and shape equal to + `state.shape[1 + independent_chain_ndims:]`. + + Raises: + ValueError: If `independent_chain_ndims < 1`.
+ """ + chains_states_was_list = _is_list_like(chains_states) + if not chains_states_was_list: + chains_states = [chains_states] + + # tensor_util.constant_value returns None iff a constant value (as a numpy + # array) is not efficiently computable. Therefore, we try constant_value then + # check for None. + icn_const_ = tensor_util.constant_value( + ops.convert_to_tensor(independent_chain_ndims)) + if icn_const_ is not None: + independent_chain_ndims = icn_const_ + if icn_const_ < 1: + raise ValueError( + "Argument `independent_chain_ndims` must be `>= 1`, found: {}".format( + independent_chain_ndims)) + + with ops.name_scope(name, "potential_scale_reduction"): + rhat_list = [ + _potential_scale_reduction_single_state(s, independent_chain_ndims) + for s in chains_states + ] + + if chains_states_was_list: + return rhat_list + return rhat_list[0] + + +def _potential_scale_reduction_single_state(state, independent_chain_ndims): + """potential_scale_reduction for one single state `Tensor`.""" + with ops.name_scope( + "potential_scale_reduction_single_state", + values=[state, independent_chain_ndims]): + # We assume exactly one leading dimension indexes e.g. correlated samples + # from each Markov chain. + state = ops.convert_to_tensor(state, name="state") + sample_ndims = 1 + + sample_axis = math_ops.range(0, sample_ndims) + chain_axis = math_ops.range(sample_ndims, + sample_ndims + independent_chain_ndims) + sample_and_chain_axis = math_ops.range( + 0, sample_ndims + independent_chain_ndims) + + n = _axis_size(state, sample_axis) + m = _axis_size(state, chain_axis) + + # In the language of [2], + # B / n is the between chain variance, the variance of the chain means. + # W is the within sequence variance, the mean of the chain variances. 
+ b_div_n = _reduce_variance( + math_ops.reduce_mean(state, sample_axis, keepdims=True), + sample_and_chain_axis, + biased=False) + w = math_ops.reduce_mean( + _reduce_variance(state, sample_axis, keepdims=True, biased=True), + sample_and_chain_axis) + + # sigma^2_+ is an estimate of the true variance, which would be unbiased if + # each chain was drawn from the target. c.f. "law of total variance." + sigma_2_plus = w + b_div_n + + return ((m + 1.) / m) * sigma_2_plus / w - (n - 1.) / (m * n) + + +# TODO(b/72873233) Move some variant of this to sample_stats. +def _reduce_variance(x, axis=None, biased=True, keepdims=False): + with ops.name_scope("reduce_variance"): + x = ops.convert_to_tensor(x, name="x") + mean = math_ops.reduce_mean(x, axis=axis, keepdims=True) + biased_var = math_ops.reduce_mean( + math_ops.squared_difference(x, mean), axis=axis, keepdims=keepdims) + if biased: + return biased_var + n = _axis_size(x, axis) + return (n / (n - 1.)) * biased_var + + +def _axis_size(x, axis=None): + """Get number of elements of `x` in `axis`, as type `x.dtype`.""" + if axis is None: + return math_ops.cast(array_ops.size(x), x.dtype) + return math_ops.cast( + math_ops.reduce_prod(array_ops.gather(array_ops.shape(x), axis)), x.dtype) + + +def _is_list_like(x): + """Helper which returns `True` if input is `list`-like.""" + return isinstance(x, (tuple, list)) + + +def _broadcast_maybelist_arg(states, secondary_arg, name): + """Broadcast a listable secondary_arg to that of states.""" + if _is_list_like(secondary_arg): + if len(secondary_arg) != len(states): + raise ValueError("Argument `%s` was a list of different length ({}) than " + "`states` ({})".format(name, len(states))) + else: + secondary_arg = [secondary_arg] * len(states) + + return secondary_arg diff --git a/tensorflow/contrib/bayesflow/python/ops/optimizers.py b/tensorflow/contrib/bayesflow/python/ops/optimizers.py index ee32e6b5c3d9efaeaf73436638c5eea55f2cfc70..fb70628d1083836281e9327e83e109493276c64f 100644 
--- a/tensorflow/contrib/bayesflow/python/ops/optimizers.py +++ b/tensorflow/contrib/bayesflow/python/ops/optimizers.py @@ -24,11 +24,13 @@ from __future__ import print_function # go/tf-wildcard-import # pylint: disable=wildcard-import from tensorflow.contrib.bayesflow.python.ops.sgld_optimizer import * +from tensorflow.contrib.bayesflow.python.ops.variational_sgd_optimizer import * # pylint: enable=wildcard-import from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ 'SGLDOptimizer', + 'VariationalSGDOptimizer', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/ops/sgld_optimizer.py b/tensorflow/contrib/bayesflow/python/ops/sgld_optimizer.py index 5d36ea7a2b51aa45cdc253992a2a58634c068987..7786656398e3c87704227be95b3cd23a38785249 100644 --- a/tensorflow/contrib/bayesflow/python/ops/sgld_optimizer.py +++ b/tensorflow/contrib/bayesflow/python/ops/sgld_optimizer.py @@ -189,6 +189,10 @@ class SGLDOptimizer(optimizer.Optimizer): new_grad, use_locking=self._use_locking).op + def _finish(self, update_ops, name_scope): + update_ops.append([self._counter.assign_add(1)]) + return control_flow_ops.group(*update_ops, name=name_scope) + @property def variable_scope(self): """Variable scope of all calls to `tf.get_variable`.""" diff --git a/tensorflow/contrib/bayesflow/python/ops/variable_utils.py b/tensorflow/contrib/bayesflow/python/ops/variable_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..eadf6f4d5fa1c776e2c71c66c4b64b8f5ac98359 --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/ops/variable_utils.py @@ -0,0 +1,29 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utility functions related to managing `tf.Variable`s.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# go/tf-wildcard-import +from tensorflow.contrib.bayesflow.python.ops.variable_utils_impl import * # pylint: disable=wildcard-import,unused-wildcard-import,g-importing-member +from tensorflow.python.util import all_util + +_allowed_symbols = [ + "externalize_variables_as_args", +] + +all_util.remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/ops/variable_utils_impl.py b/tensorflow/contrib/bayesflow/python/ops/variable_utils_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..ca3d75b5bfee093449026c7d1d62e3bdeff6b096 --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/ops/variable_utils_impl.py @@ -0,0 +1,157 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Utility functions related to managing `tf.Variable`s. + +@@externalize_variables_as_args +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import warnings + +from tensorflow.python.framework import ops +from tensorflow.python.ops import gradients_impl as gradients_ops +from tensorflow.python.ops import variable_scope as varscope_ops +from tensorflow.python.ops import variables as variables_ops + +__all__ = [ + "externalize_variables_as_args", +] + + +# Cause all warnings to always be triggered. +# Not having this means subsequent calls wont trigger the warning. +warnings.simplefilter("always") + + +def externalize_variables_as_args(fn, + fn_args=(), + ancestor_variables=None, + possible_ancestor_vars=None, + assert_variable_override=False, + name=None): + """"Converts variables within a callable into explicit args. + + Makes a new callable from `fn` which has arguments `list(fn_args) + + list(ancestor_variables)`. If `ancestor_variables` is not specified, it is + inferred by checking which of `possible_ancestor_vars` actually influences the + return value of `fn` (concretely, gradient of `fn(*fn_args)` is not `None`). + By default `possible_ancestor_vars` is `tf.trainable_variables() + + tf.get_collection(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES)`. 
+ + #### Examples: + + ```python + num_samples = 2 + num_dims = 1 + dtype = np.float32 + + def foo(x): + x = tf.convert_to_tensor(x, dtype=dtype, name="x") + s = x.shape.as_list() + y = tf.get_variable( + name="y", + dtype=dtype, + initializer=np.arange(np.prod(s)).reshape(s).astype(dtype)) + return x + y + + x = tf.constant(dtype([0.1, 0.2])) + + wrapped_foo, discovered_ancestor_variables = ( + externalize_variables_as_args(foo, [x])) + + new_x = dtype([[1.], [2.]]) + new_y = dtype([[3.], [4.]]) + new_result = wrapped_foo(new_x, new_y) + # ==> [[4.], [6.]] + + discovered_ancestor_variables == [tf.get_variable("y", dtype)] + # ==> [True] + ``` + + Args: + fn: Python callable which returns a `Tensor` and accepts `*fn_args`. + fn_args: Python list of args to `fn`. Represents dummy arguments passed to + `fn` to trace its execution; actual values are unimportant. These args are + only used to construct the output of `fn` and to resolve the ancestor + `tf.Variable`s. + Default value: `()` (i.e., `fn` takes no args). + ancestor_variables: Python list of `tf.Variable`s. When `None` the list is + expanded to non-`None` gradients of `fn(*fn_args)`. By directly providing + the `ancestor_variables` the internal call to `fn` is avoided. + Default value: `None` (i.e., `tf.Variable` dependencies are discovered). + possible_ancestor_vars: Python list of possible `tf.Variable`s which might + be a dependency of computing `fn(*fn_args)`. + Default value: `None` (i.e., expanded as described above). + assert_variable_override: Python `bool` indicating that not finding a + `tf.Variable` in the override list is an exception. + Default value: `False` (i.e., missing a `Variable` triggers a `warning`). + name: Python `str` name prefixed to Ops created by this function. + Default value: `None` (i.e., "externalize_variables_as_args"). + + Returns: + wrapped_fn: Python callable taking arguments like + `*(list(fn_args) + discovered_ancestor_variables)`. 
+ discovered_ancestor_variables: Python list of `tf.Variable`s known to be a + dependency of `fn(*fn_args)`. + + Raises: + ValueError: if `assert_variable_override` is `True` and `Variable` is + requested but not overridden. + """ + def _make_bypassing_custom_getter_fn(new_var_dict): + """Return dict value rather than what would otherwise be dict key.""" + def _custom_getter(getter, *args, **kwargs): + v = getter(*args, **kwargs) + new_v = new_var_dict.get(v, None) + if new_v is None: + msg = "Variable \"{}\" not found in bypass dict.".format(v) + if assert_variable_override: + raise ValueError(msg) + warnings.warn(msg) + return v + return new_v + return _custom_getter + + with ops.name_scope(name, "externalize_variables_as_args"): + if ancestor_variables is not None and not ancestor_variables: + return fn, () + if ancestor_variables is None: + y = fn(*fn_args) # Side-effect: adds trainable vars. + if possible_ancestor_vars is None: + possible_ancestor_vars = ( + variables_ops.trainable_variables() + + ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES)) + # TODO(b/72873296): Add a dedicated op for identifying ancestors. 
+ ancestors = [v for g, v + in zip(gradients_ops.gradients(y, possible_ancestor_vars), + possible_ancestor_vars) + if g is not None] + ancestor_variables = sorted(ancestors, key=lambda v: v.name) + n = len(fn_args) + def _fn(*args): + with ops.name_scope("wrapped_fn"): + vars_dict = dict( + (k, ops.convert_to_tensor( + v, dtype=k.dtype.base_dtype, name=k.op.name)) + for k, v in zip(ancestor_variables, args[n:])) + with varscope_ops.variable_scope( + varscope_ops.get_variable_scope(), + reuse=True, + custom_getter=_make_bypassing_custom_getter_fn(vars_dict)): + return fn(*args[:n]) + return _fn, ancestor_variables diff --git a/tensorflow/contrib/bayesflow/python/ops/variational_sgd_optimizer.py b/tensorflow/contrib/bayesflow/python/ops/variational_sgd_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..4d5f0cfe9713a011b32c5aba8d429847d81f33e2 --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/ops/variational_sgd_optimizer.py @@ -0,0 +1,279 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""An optimizer module for constant stochastic gradient descent.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope as varscope_ops +from tensorflow.python.training import optimizer +from tensorflow.python.training import training_ops + + +class VariationalSGDOptimizer(optimizer.Optimizer): + """An optimizer module for constant stochastic gradient descent. + + This implements an optimizer module for the constant stochastic gradient + descent algorithm [1]. The optimization variable is regarded as an + approximate sample from the posterior . + + Note: If a prior is included in the loss, it should be scaled by + `1/num_pseudo_batches`, where num_pseudo_batches is the number of minibatches + in the data. I.e., it should be divided by the `num_pseudo_batches` term + described below. + + [1]: "Stochastic Gradient Descent as Approximate Bayesian Inference + Stephan Mandt, Matthew D. Hoffman, David M. Blei. + ArXiv:1704.04289, 2017. https://arxiv.org/abs/1704.04289 + + Args: + batch_size: Scalar `int`-like `Tensor`. The number of examples in a + minibatch in the data set. Note: Assumes the loss is taken as the mean + over a minibatch. Otherwise if the sum was taken set this to 1. + total_num_examples: Scalar `int`-like `Tensor`. The total number of examples + in the data set. + max_learning_rate: Scalar `float`-like `Tensor`. 
A maximum allowable + effective coordinate-wise learning rate. The algorithm scales down any + effective learning rate (i.e. after preconditioning) that is larger than + this. (Default: `1`) + preconditioner_decay_rate: Scalar `float`-like `Tensor`. The exponential + decay rate of the rescaling of the preconditioner (RMSprop). (This is + "alpha" in [1]). Should be smaller than but nearly `1` to approximate + sampling from the posterior. (Default: `0.95`) + burnin: Scalar `int`-like `Tensor`. The number of iterations to collect + gradient statistics to update the preconditioner before starting to draw + noisy samples. (Default: `25`) + burnin_max_learning_rate: Scalar `float`-like `Tensor`. Maximum learning + rate to use during the burnin period. + (Default: `1e-6`) + use_single_learning_rate: Boolean indicating whether one single learning + rate is used or coordinate_wise learning rates are used. + (Default: `False`) + name: Python `str` describing ops managed by this function. + (Default: `"VariationalSGDOptimizer"`) + variable_scope: Variable scope used for calls to `tf.get_variable`. + If `None`, a new variable scope is created using name + `ops.get_default_graph().unique_name(name or default_name)`. + + Raises: + InvalidArgumentError: If preconditioner_decay_rate is a `Tensor` not in + `[0, 1]`. 
+ """ + + def __init__(self, + batch_size, + total_num_examples, + max_learning_rate=1.0, + preconditioner_decay_rate=0.95, + burnin=25, + burnin_max_learning_rate=1e-6, + use_single_learning_rate=False, + name=None, + variable_scope=None): + default_name = 'VariationalSGDOptimizer' + with ops.name_scope(name, default_name, [ + max_learning_rate, preconditioner_decay_rate, batch_size, burnin, + burnin_max_learning_rate + ]): + if variable_scope is None: + var_scope_name = ops.get_default_graph().unique_name( + name or default_name) + with varscope_ops.variable_scope(var_scope_name) as scope: + self._variable_scope = scope + else: + self._variable_scope = variable_scope + + self._preconditioner_decay_rate = ops.convert_to_tensor( + preconditioner_decay_rate, name='preconditioner_decay_rate') + self._batch_size = ops.convert_to_tensor(batch_size, name='batch_size') + self._total_num_examples = ops.convert_to_tensor( + total_num_examples, name='total_num_examples') + self._burnin = ops.convert_to_tensor(burnin, name='burnin') + self._burnin_max_learning_rate = ops.convert_to_tensor( + burnin_max_learning_rate, name='burnin_max_learning_rate') + self._max_learning_rate = ops.convert_to_tensor( + max_learning_rate, name='max_learning_rate') + self._use_single_learning_rate = use_single_learning_rate + + with varscope_ops.variable_scope(self._variable_scope): + self._counter = varscope_ops.get_variable( + 'counter', initializer=0, trainable=False) + + self._preconditioner_decay_rate = control_flow_ops.with_dependencies([ + check_ops.assert_non_negative( + self._preconditioner_decay_rate, + message='`preconditioner_decay_rate` must be non-negative'), + check_ops.assert_less_equal( + self._preconditioner_decay_rate, + 1., + message='`preconditioner_decay_rate` must be at most 1.'), + ], self._preconditioner_decay_rate) + + self._batch_size = control_flow_ops.with_dependencies([ + check_ops.assert_greater( + self._batch_size, + 0, + message='`batch_size` must be greater 
than zero') + ], self._batch_size) + + self._total_num_examples = control_flow_ops.with_dependencies([ + check_ops.assert_greater( + self._total_num_examples, + 0, + message='`total_num_examples` must be greater than zero') + ], self._total_num_examples) + + self._burnin = control_flow_ops.with_dependencies([ + check_ops.assert_non_negative( + self._burnin, message='`burnin` must be non-negative'), + check_ops.assert_integer( + self._burnin, message='`burnin` must be an integer') + ], self._burnin) + + self._burnin_max_learning_rate = control_flow_ops.with_dependencies([ + check_ops.assert_non_negative( + self._burnin_max_learning_rate, + message='`burnin_max_learning_rate` must be non-negative') + ], self._burnin_max_learning_rate) + + self._max_learning_rate = control_flow_ops.with_dependencies([ + check_ops.assert_non_negative( + self._max_learning_rate, + message='`max_learning_rate` must be non-negative') + ], self._max_learning_rate) + + super(VariationalSGDOptimizer, self).__init__( + use_locking=False, name=name or default_name) + + def _create_slots(self, var_list): + for v in var_list: + init_moment = init_ops.zeros_initializer(dtype=v.dtype) + self._get_or_make_slot_with_initializer( + v, init_moment, v.get_shape(), v.dtype, 'first_moment', self._name) + self._get_or_make_slot_with_initializer( + v, init_moment, v.get_shape(), v.dtype, 'second_moment', self._name) + + def _prepare(self): + self._decay_tensor = ops.convert_to_tensor( + self._preconditioner_decay_rate, name='preconditioner_decay_rate') + self._batch_size_tensor = ops.convert_to_tensor( + self._batch_size, name='batch_size_tensor') + + super(VariationalSGDOptimizer, self)._prepare() + + def _get_coordinatewise_learning_rate(self, grad, var): + # Compute the learning rate using a moving average for the diagonal of BB^T + avg_first = self.get_slot(var, 'first_moment') + avg_second = self.get_slot(var, 'second_moment') + decay_tensor = math_ops.cast(self._decay_tensor, var.dtype) + batch_size 
= math_ops.cast(self._batch_size_tensor, var.dtype) + + # Create an estimator for the moving average of gradient mean and variance + # via Welford's algorithm + if isinstance(grad, ops.Tensor): + delta = grad - avg_first + first_moment_update = avg_first.assign_add( + array_ops.where(self._counter < 1, math_ops.cast(1, var.dtype), + 1. - decay_tensor) * delta) + + with ops.control_dependencies([first_moment_update]): + second_moment_update = avg_second.assign_add( + math_ops.cast(self._counter < 1, var.dtype) * + -(1. - decay_tensor) * ( + avg_second - decay_tensor * math_ops.square(delta))) + diag_preconditioner = control_flow_ops.with_dependencies( + [second_moment_update], + clip_ops.clip_by_value(avg_second, 1e-12, 1e12)) + elif isinstance(grad, ops.IndexedSlices): + delta = grad.values - array_ops.gather_nd(avg_first, grad.indices) + first_moment_update = state_ops.scatter_add( + avg_first, + grad.indices, + array_ops.where(self._counter < 1, + math_ops.cast(1., var.dtype), + 1. - decay_tensor) * delta) + + with ops.control_dependencies([first_moment_update]): + avg_second = state_ops.scatter_add( + avg_second, + grad.indices, + math_ops.cast(self._counter < 1, var.dtype) * + -(1. - decay_tensor) * ( + array_ops.gather_nd(avg_second, grad.indices) - decay_tensor * + math_ops.square(delta))) + avg_second = array_ops.gather_nd(avg_second, grad.indices) + # TODO(b/70783772) + diag_preconditioner = clip_ops.clip_by_value(avg_second, 1e-12, 1e12) + else: + raise errors.InvalidArgumentError( + None, None, 'grad must of type Tensor or IndexedSlice') + + diag_preconditioner *= batch_size + + if self._use_single_learning_rate: + diag_preconditioner = math_ops.reduce_mean(diag_preconditioner) + + # From Theorem 2 Corollary 1 of Mandt et al. 2017 + return 2. 
* batch_size / ( + math_ops.cast(self._total_num_examples, var.dtype.base_dtype) * + diag_preconditioner) + + def _apply_dense(self, grad, var): + + max_learning_rate = array_ops.where(self._counter < self._burnin, + self._burnin_max_learning_rate, + self._max_learning_rate) + + learn_rates = clip_ops.clip_by_value( + self._get_coordinatewise_learning_rate(grad, var), 0.0, + math_ops.cast(max_learning_rate, var.dtype.base_dtype)) + + newgrad = grad * learn_rates + return training_ops.apply_gradient_descent( + var, + math_ops.cast(1.0, var.dtype), + newgrad, + use_locking=self._use_locking).op + + def _apply_sparse(self, grad, var): + + max_learning_rate = array_ops.where(self._counter < self._burnin, + self._burnin_max_learning_rate, + self._max_learning_rate) + + learn_rate = clip_ops.clip_by_value( + self._get_coordinatewise_learning_rate(grad, var), 0.0, + math_ops.cast(max_learning_rate, var.dtype)) + delta = grad.values * learn_rate + + return state_ops.scatter_sub(var, grad.indices, delta, + use_locking=self._use_locking) + + def _finish(self, update_ops, name_scope): + update_ops.append([self._counter.assign_add(1)]) + return control_flow_ops.group(*update_ops, name=name_scope) + + @property + def variable_scope(self): + """Variable scope of all calls to `tf.get_variable`.""" + return self._variable_scope diff --git a/tensorflow/contrib/boosted_trees/BUILD b/tensorflow/contrib/boosted_trees/BUILD index 7072f56420ac9e576b20b62c0aa67498857403a7..6fdcd0f996ee011842a5add79f06264a28a2145c 100644 --- a/tensorflow/contrib/boosted_trees/BUILD +++ b/tensorflow/contrib/boosted_trees/BUILD @@ -196,6 +196,7 @@ py_test( name = "quantile_ops_test", size = "small", srcs = ["python/kernel_tests/quantile_ops_test.py"], + shard_count = 3, srcs_version = "PY2AND3", deps = [ ":quantile_ops_py", @@ -601,6 +602,7 @@ py_library( ":init_py", "//tensorflow/contrib/boosted_trees:gbdt_batch", "//tensorflow/contrib/boosted_trees/estimator_batch:custom_export_strategy", + 
"//tensorflow/contrib/boosted_trees/estimator_batch:dnn_tree_combined_estimator", "//tensorflow/contrib/boosted_trees/estimator_batch:init_py", "//tensorflow/contrib/boosted_trees/estimator_batch:trainer_hooks", "//tensorflow/contrib/boosted_trees/lib:categorical_split_handler", diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD index 7792c7127c0285dc2eb5b213da054674f6a81d64..289f5bb3140974d8c37f4938ceef27275b099f9a 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD +++ b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD @@ -50,6 +50,7 @@ py_library( deps = [ "//tensorflow/contrib/learn", "//tensorflow/core:protos_all_py", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_ops", "//tensorflow/python:platform", "//tensorflow/python:training", @@ -129,3 +130,38 @@ py_library( "//tensorflow/python:math_ops", ], ) + +py_library( + name = "dnn_tree_combined_estimator", + srcs = ["dnn_tree_combined_estimator.py"], + srcs_version = "PY2AND3", + deps = [ + ":trainer_hooks", + "//tensorflow/contrib/boosted_trees:gbdt_batch", + "//tensorflow/contrib/boosted_trees:model_ops_py", + "//tensorflow/contrib/learn", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + ], +) + +py_test( + name = "dnn_tree_combined_estimator_test", + size = "small", + srcs = ["dnn_tree_combined_estimator_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_gpu", + "no_pip_gpu", + "notsan", + ], + deps = [ + ":dnn_tree_combined_estimator", + "//tensorflow/contrib/boosted_trees:gbdt_batch", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + ], +) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py 
b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py index ef8dee91b6cc05c4c3dd5eb3c81de4fb65b473e3..31f5c444817b9b82723c86bea3504d4934e57eb8 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py @@ -33,6 +33,8 @@ from tensorflow.python.platform import gfile from tensorflow.python.saved_model import loader as saved_model_loader from tensorflow.python.saved_model import tag_constants +_SPARSE_FLOAT_FEATURE_NAME_TEMPLATE = "%s_%d" + def make_custom_export_strategy(name, convert_fn, @@ -147,13 +149,15 @@ def convert_to_universal_format(dtec, sorted_feature_names, inequality_test.threshold.float_value = split.threshold elif node_type == "sparse_float_binary_split_default_left": split = gtflow_node.sparse_float_binary_split_default_left.split - node.default_direction = ( - generic_tree_model_pb2.BinaryNode.LEFT) - # TODO(nponomareva): adjust this id assignement when we allow multi- - # column sparse tensors. + node.default_direction = (generic_tree_model_pb2.BinaryNode.LEFT) feature_id = split.feature_column + num_dense inequality_test = node.inequality_left_child_test - inequality_test.feature_id.id.value = sorted_feature_names[feature_id] + inequality_test.feature_id.id.value = ( + _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE % + (sorted_feature_names[feature_id], split.dimension_id)) + model_and_features.features.pop(sorted_feature_names[feature_id]) + (model_and_features.features[inequality_test.feature_id.id.value] + .SetInParent()) inequality_test.type = ( generic_tree_model_pb2.InequalityTest.LESS_OR_EQUAL) inequality_test.threshold.float_value = split.threshold @@ -165,7 +169,12 @@ def convert_to_universal_format(dtec, sorted_feature_names, # column sparse tensors. 
feature_id = split.feature_column + num_dense inequality_test = node.inequality_left_child_test - inequality_test.feature_id.id.value = sorted_feature_names[feature_id] + inequality_test.feature_id.id.value = ( + _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE % + (sorted_feature_names[feature_id], split.dimension_id)) + model_and_features.features.pop(sorted_feature_names[feature_id]) + (model_and_features.features[inequality_test.feature_id.id.value] + .SetInParent()) inequality_test.type = ( generic_tree_model_pb2.InequalityTest.LESS_OR_EQUAL) inequality_test.threshold.float_value = split.threshold @@ -201,10 +210,14 @@ def _get_feature_importances(dtec, feature_names, num_dense_floats, split_column = feature_names[split.feature_column] elif node_type == "sparse_float_binary_split_default_left": split = tree_node.sparse_float_binary_split_default_left.split - split_column = feature_names[split.feature_column + num_dense_floats] + split_column = _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE % ( + feature_names[split.feature_column + num_dense_floats], + split.dimension_id) elif node_type == "sparse_float_binary_split_default_right": split = tree_node.sparse_float_binary_split_default_right.split - split_column = feature_names[split.feature_column + num_dense_floats] + split_column = _SPARSE_FLOAT_FEATURE_NAME_TEMPLATE % ( + feature_names[split.feature_column + num_dense_floats], + split.dimension_id) elif node_type == "categorical_id_binary_split": split = tree_node.categorical_id_binary_split split_column = feature_names[split.feature_column + num_dense_floats + diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy_test.py index 4ed18b2d34c5af47826ab1c058f5d13797593bd4..67ec0e16bf815e9dbea6567cc87c3980a825a004 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy_test.py +++ 
b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy_test.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for the conversion code from GTFlow format to Chauffeur.""" +"""Tests for the conversion code and for feature importances export. + +Tests that cover conversion from TFBT format to a tensorflow.contrib. +decision_tree generic_tree_model format and feature importances export. +""" from __future__ import absolute_import from __future__ import division @@ -95,10 +99,31 @@ class ConvertModelTest(test_util.TensorFlowTestCase): } } } + nodes { + sparse_float_binary_split_default_right { + split { + feature_column: 1 + dimension_id:3 + threshold: -0.4 + left_id: 7 + right_id: 8 + } + } + node_metadata { + gain: 3600 + } + } + nodes { + leaf { + vector { + value: 0.36 + } + } + } nodes { leaf { vector { - value: 0.3 + value: 18 } } } @@ -108,17 +133,25 @@ class ConvertModelTest(test_util.TensorFlowTestCase): """ dtec = tree_config_pb2.DecisionTreeEnsembleConfig() text_format.Merge(dtec_str, dtec) - feature_columns = ["feature_b", "feature_a", "feature_d"] + feature_columns = [ + "feature_b", + "feature_a", + "feature_a_m", + "feature_d", + ] return dtec, feature_columns def testConvertModel(self): dtec, feature_columns = self._make_trees() + # Assume 2 sparse float columns, one with 1 dimension, the second one with + # 5 dimensions. # The feature columns in the order they were added. out = custom_export_strategy.convert_to_universal_format( - dtec, feature_columns, 1, 1, - 1) + dtec, feature_columns, 1, 2, 1) + # Features a and a_m are sparse float features, a_m is multidimensional. 
expected_tree = """ - features { key: "feature_a" } + features { key: "feature_a_0" } + features { key: "feature_a_m_3" } features { key: "feature_b" } features { key: "feature_d" } model { @@ -169,7 +202,6 @@ class ConvertModelTest(test_util.TensorFlowTestCase): } } } - nodes { node_id { value: 1 @@ -196,7 +228,7 @@ class ConvertModelTest(test_util.TensorFlowTestCase): inequality_left_child_test { feature_id { id { - value: "feature_a" + value: "feature_a_0" } } threshold { @@ -259,14 +291,51 @@ class ConvertModelTest(test_util.TensorFlowTestCase): node_id { value: 6 } + binary_node { + left_child_id { + value: 7 + } + right_child_id { + value: 8 + } + default_direction: RIGHT + inequality_left_child_test { + feature_id { + id { + value: "feature_a_m_3" + } + } + threshold { + float_value: -0.4 + } + } + } + } + nodes { + node_id { + value: 7 + } leaf { vector { value { - float_value: 0.03 + float_value: 0.036 } } } } + nodes { + node_id { + value: 8 + } + leaf { + vector { + value { + float_value: 1.8 + } + } + } + } + } } submodel_id { @@ -280,12 +349,15 @@ class ConvertModelTest(test_util.TensorFlowTestCase): def testFeatureImportance(self): dtec, feature_columns = self._make_trees() feature_importances = custom_export_strategy._get_feature_importances( - dtec, feature_columns, 1, 1, 1) - self.assertItemsEqual(["feature_b", "feature_a", "feature_d"], - feature_importances.keys()) + dtec, feature_columns, 1, 2, 1) + self.assertItemsEqual( + ["feature_b", "feature_a_0", "feature_a_m_3", "feature_d"], + feature_importances.keys()) self.assertAlmostEqual(50.0, feature_importances["feature_b"], places=4) - self.assertAlmostEqual(50.0, feature_importances["feature_a"], places=4) + self.assertAlmostEqual(50.0, feature_importances["feature_a_0"], places=4) self.assertAlmostEqual(50.0, feature_importances["feature_d"], places=4) + self.assertAlmostEqual( + 360.0, feature_importances["feature_a_m_3"], places=4) if __name__ == "__main__": diff --git 
a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..cec3892b57655dc967b4e7926f7f5a6a30084487 --- /dev/null +++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator.py @@ -0,0 +1,515 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TensorFlow estimators for combined DNN + GBDT training model. + +The combined model trains a DNN first, then trains boosted trees to boost the +logits of the DNN. The input layer of the DNN (including the embeddings learned +over sparse features) can optionally be provided to the boosted trees as +an additional input feature. 
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six + +from tensorflow.contrib import layers +from tensorflow.contrib.boosted_trees.estimator_batch import trainer_hooks +from tensorflow.contrib.boosted_trees.python.ops import model_ops +from tensorflow.contrib.boosted_trees.python.training.functions import gbdt_batch +from tensorflow.contrib.layers.python.layers import optimizers +from tensorflow.contrib.learn.python.learn.estimators import estimator +from tensorflow.contrib.learn.python.learn.estimators import head as head_lib +from tensorflow.contrib.learn.python.learn.estimators import model_fn +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.summary import summary +from tensorflow.python.training import training_util + + +_DNN_LEARNING_RATE = 0.001 + + +def _get_optimizer(optimizer): + if callable(optimizer): + return optimizer() + else: + return optimizer + + +def _add_hidden_layer_summary(value, tag): + summary.scalar("%s_fraction_of_zero_values" % tag, nn.zero_fraction(value)) + summary.histogram("%s_activation" % tag, value) + + +def _dnn_tree_combined_model_fn( + features, labels, mode, head, dnn_hidden_units, + dnn_feature_columns, tree_learner_config, num_trees, + tree_examples_per_layer, + config=None, dnn_optimizer="Adagrad", + dnn_activation_fn=nn.relu, dnn_dropout=None, + dnn_input_layer_partitioner=None, + dnn_input_layer_to_tree=True, dnn_steps_to_train=10000, + tree_feature_columns=None, + tree_center_bias=True): + """DNN and GBDT combined model_fn. + + Args: + features: `dict` of `Tensor` objects. + labels: Labels used to train on. + mode: Mode we are in. (TRAIN/EVAL/INFER) + head: A `Head` instance. 
+ dnn_hidden_units: List of hidden units per layer. + dnn_feature_columns: An iterable containing all the feature columns + used by the model's DNN. + tree_learner_config: A config for the tree learner. + num_trees: Number of trees to grow model to after training DNN. + tree_examples_per_layer: Number of examples to accumulate before + growing the tree a layer. This value has a big impact on model + quality and should be set equal to the number of examples in + training dataset if possible. It can also be a function that computes + the number of examples based on the depth of the layer that's + being built. + config: `RunConfig` of the estimator. + dnn_optimizer: string, `Optimizer` object, or callable that defines the + optimizer to use for training the DNN. If `None`, will use the Adagrad + optimizer with default learning rate of 0.001. + dnn_activation_fn: Activation function applied to each layer of the DNN. + If `None`, will use `tf.nn.relu`. + dnn_dropout: When not `None`, the probability to drop out a given + unit in the DNN. + dnn_input_layer_partitioner: Partitioner for input layer of the DNN. + Defaults to `min_max_variable_partitioner` with `min_slice_size` 64 << 20. + dnn_input_layer_to_tree: Whether to provide the DNN's input layer + as a feature to the tree. + dnn_steps_to_train: Number of steps to train dnn for before switching + to gbdt. + tree_feature_columns: An iterable containing all the feature columns + used by the model's boosted trees. If dnn_input_layer_to_tree is + set to True, these features are in addition to dnn_feature_columns. + tree_center_bias: Whether a separate tree should be created for + first fitting the bias. + + Returns: + A `ModelFnOps` object. + Raises: + ValueError: if inputs are not valid. + """ + if not isinstance(features, dict): + raise ValueError("features should be a dictionary of `Tensor`s. 
" + "Given type: {}".format(type(features))) + + if not dnn_feature_columns: + raise ValueError("dnn_feature_columns must be specified") + + # Build DNN Logits. + dnn_parent_scope = "dnn" + dnn_partitioner = dnn_input_layer_partitioner or ( + partitioned_variables.min_max_variable_partitioner( + max_partitions=config.num_ps_replicas, + min_slice_size=64 << 20)) + + with variable_scope.variable_scope( + dnn_parent_scope, + values=tuple(six.itervalues(features)), + partitioner=dnn_partitioner): + + with variable_scope.variable_scope( + "input_from_feature_columns", + values=tuple(six.itervalues(features)), + partitioner=dnn_partitioner) as input_layer_scope: + input_layer = layers.input_from_feature_columns( + columns_to_tensors=features, + feature_columns=dnn_feature_columns, + weight_collections=[dnn_parent_scope], + scope=input_layer_scope) + previous_layer = input_layer + for layer_id, num_hidden_units in enumerate(dnn_hidden_units): + with variable_scope.variable_scope( + "hiddenlayer_%d" % layer_id, + values=(previous_layer,)) as hidden_layer_scope: + net = layers.fully_connected( + previous_layer, + num_hidden_units, + activation_fn=dnn_activation_fn, + variables_collections=[dnn_parent_scope], + scope=hidden_layer_scope) + if dnn_dropout is not None and mode == model_fn.ModeKeys.TRAIN: + net = layers.dropout(net, keep_prob=(1.0 - dnn_dropout)) + _add_hidden_layer_summary(net, hidden_layer_scope.name) + previous_layer = net + with variable_scope.variable_scope( + "logits", + values=(previous_layer,)) as logits_scope: + dnn_logits = layers.fully_connected( + previous_layer, + head.logits_dimension, + activation_fn=None, + variables_collections=[dnn_parent_scope], + scope=logits_scope) + _add_hidden_layer_summary(dnn_logits, logits_scope.name) + + def _dnn_train_op_fn(loss): + """Returns the op to optimize the loss.""" + return optimizers.optimize_loss( + loss=loss, + global_step=training_util.get_global_step(), + learning_rate=_DNN_LEARNING_RATE, + 
optimizer=_get_optimizer(dnn_optimizer), + name=dnn_parent_scope, + variables=ops.get_collection( + ops.GraphKeys.TRAINABLE_VARIABLES, + scope=dnn_parent_scope), + # Empty summaries to prevent optimizers from logging training_loss. + summaries=[]) + + # Build Tree Logits. + global_step = training_util.get_global_step() + with ops.device(global_step.device): + ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, + tree_ensemble_config="", # Initialize an empty ensemble. + name="ensemble_model") + + tree_features = features.copy() + if dnn_input_layer_to_tree: + tree_features["dnn_input_layer"] = input_layer + tree_feature_columns.append(layers.real_valued_column("dnn_input_layer")) + gbdt_model = gbdt_batch.GradientBoostedDecisionTreeModel( + is_chief=config.is_chief, + num_ps_replicas=config.num_ps_replicas, + ensemble_handle=ensemble_handle, + center_bias=tree_center_bias, + examples_per_layer=tree_examples_per_layer, + learner_config=tree_learner_config, + feature_columns=tree_feature_columns, + logits_dimension=head.logits_dimension, + features=tree_features) + + with ops.name_scope("gbdt"): + predictions_dict = gbdt_model.predict(mode) + tree_logits = predictions_dict["predictions"] + + def _tree_train_op_fn(loss): + """Returns the op to optimize the loss.""" + update_op = gbdt_model.train(loss, predictions_dict, labels) + with ops.control_dependencies( + [update_op]), (ops.colocate_with(global_step)): + update_op = state_ops.assign_add(global_step, 1).op + return update_op + + tree_train_logits = dnn_logits + tree_logits + + def _no_train_op_fn(loss): + """Returns a no-op.""" + del loss + return control_flow_ops.no_op() + + model_fn_ops = head.create_model_fn_ops( + features=features, + mode=mode, + labels=labels, + train_op_fn=_no_train_op_fn, + logits=tree_train_logits) + dnn_train_op = head.create_model_fn_ops( + features=features, + mode=mode, + labels=labels, + train_op_fn=_dnn_train_op_fn, + logits=dnn_logits).train_op + tree_train_op = 
head.create_model_fn_ops( + features=tree_features, + mode=mode, + labels=labels, + train_op_fn=_tree_train_op_fn, + logits=tree_train_logits).train_op + + if tree_center_bias: + num_trees += 1 + finalized_trees, attempted_trees = gbdt_model.get_number_of_trees_tensor() + + model_fn_ops.training_hooks.extend([ + trainer_hooks.SwitchTrainOp( + dnn_train_op, dnn_steps_to_train, tree_train_op), + trainer_hooks.StopAfterNTrees( + num_trees, attempted_trees, finalized_trees)]) + + return model_fn_ops + + +class DNNBoostedTreeCombinedClassifier(estimator.Estimator): + """A classifier that uses a combined DNN/GBDT model.""" + + def __init__(self, + dnn_hidden_units, + dnn_feature_columns, + tree_learner_config, + num_trees, + tree_examples_per_layer, + n_classes=2, + weight_column_name=None, + model_dir=None, + config=None, + label_name=None, + label_keys=None, + feature_engineering_fn=None, + dnn_optimizer="Adagrad", + dnn_activation_fn=nn.relu, + dnn_dropout=None, + dnn_input_layer_partitioner=None, + dnn_input_layer_to_tree=True, + dnn_steps_to_train=10000, + tree_feature_columns=None, + tree_center_bias=True): + """Initializes a DNNBoostedTreeCombinedClassifier instance. + + Args: + dnn_hidden_units: List of hidden units per layer for DNN. + dnn_feature_columns: An iterable containing all the feature columns + used by the model's DNN. + tree_learner_config: A config for the tree learner. + num_trees: Number of trees to grow model to after training DNN. + tree_examples_per_layer: Number of examples to accumulate before + growing the tree a layer. This value has a big impact on model + quality and should be set equal to the number of examples in + training dataset if possible. It can also be a function that computes + the number of examples based on the depth of the layer that's + being built. + n_classes: The number of label classes. + weight_column_name: The name of weight column. + model_dir: Directory for model exports. + config: `RunConfig` of the estimator. 
+ label_name: String, name of the key in label dict. Can be null if label + is a tensor (single headed models). + label_keys: Optional list of strings with size `[n_classes]` defining the + label vocabulary. Only supported for `n_classes` > 2. + feature_engineering_fn: Feature engineering function. Takes features and + labels which are the output of `input_fn` and returns features and + labels which will be fed into the model. + dnn_optimizer: string, `Optimizer` object, or callable that defines the + optimizer to use for training the DNN. If `None`, will use the Adagrad + optimizer with default learning rate. + dnn_activation_fn: Activation function applied to each layer of the DNN. + If `None`, will use `tf.nn.relu`. + dnn_dropout: When not `None`, the probability to drop out a given + unit in the DNN. + dnn_input_layer_partitioner: Partitioner for input layer of the DNN. + Defaults to `min_max_variable_partitioner` with `min_slice_size` + 64 << 20. + dnn_input_layer_to_tree: Whether to provide the DNN's input layer + as a feature to the tree. + dnn_steps_to_train: Number of steps to train dnn for before switching + to gbdt. + tree_feature_columns: An iterable containing all the feature columns + used by the model's boosted trees. If dnn_input_layer_to_tree is + set to True, these features are in addition to dnn_feature_columns. + tree_center_bias: Whether a separate tree should be created for + first fitting the bias. 
+ """ + head = head_lib.multi_class_head( + n_classes=n_classes, + label_name=label_name, + label_keys=label_keys, + weight_column_name=weight_column_name, + enable_centered_bias=False) + + def _model_fn(features, labels, mode, config): + return _dnn_tree_combined_model_fn( + features, labels, mode, head, dnn_hidden_units, dnn_feature_columns, + tree_learner_config, num_trees, tree_examples_per_layer, config, + dnn_optimizer, dnn_activation_fn, dnn_dropout, + dnn_input_layer_partitioner, dnn_input_layer_to_tree, + dnn_steps_to_train, + tree_feature_columns, tree_center_bias) + + super(DNNBoostedTreeCombinedClassifier, self).__init__( + model_fn=_model_fn, model_dir=model_dir, + config=config, feature_engineering_fn=feature_engineering_fn) + + +class DNNBoostedTreeCombinedRegressor(estimator.Estimator): + """A regressor that uses a combined DNN/GBDT model.""" + + def __init__(self, + dnn_hidden_units, + dnn_feature_columns, + tree_learner_config, + num_trees, + tree_examples_per_layer, + weight_column_name=None, + model_dir=None, + config=None, + label_name=None, + label_dimension=1, + feature_engineering_fn=None, + dnn_optimizer="Adagrad", + dnn_activation_fn=nn.relu, + dnn_dropout=None, + dnn_input_layer_partitioner=None, + dnn_input_layer_to_tree=True, + dnn_steps_to_train=10000, + tree_feature_columns=None, + tree_center_bias=True): + """Initializes a DNNBoostedTreeCombinedRegressor instance. + + Args: + dnn_hidden_units: List of hidden units per layer for DNN. + dnn_feature_columns: An iterable containing all the feature columns + used by the model's DNN. + tree_learner_config: A config for the tree learner. + num_trees: Number of trees to grow model to after training DNN. + tree_examples_per_layer: Number of examples to accumulate before + growing the tree a layer. This value has a big impact on model + quality and should be set equal to the number of examples in + training dataset if possible. 
It can also be a function that computes + the number of examples based on the depth of the layer that's + being built. + weight_column_name: The name of weight column. + model_dir: Directory for model exports. + config: `RunConfig` of the estimator. + label_name: String, name of the key in label dict. Can be null if label + is a tensor (single headed models). + label_dimension: Number of regression labels per example. This is the size + of the last dimension of the labels `Tensor` (typically, this has shape + `[batch_size, label_dimension]`). + feature_engineering_fn: Feature engineering function. Takes features and + labels which are the output of `input_fn` and returns features and + labels which will be fed into the model. + dnn_optimizer: string, `Optimizer` object, or callable that defines the + optimizer to use for training the DNN. If `None`, will use the Adagrad + optimizer with default learning rate. + dnn_activation_fn: Activation function applied to each layer of the DNN. + If `None`, will use `tf.nn.relu`. + dnn_dropout: When not `None`, the probability to drop out a given + unit in the DNN. + dnn_input_layer_partitioner: Partitioner for input layer of the DNN. + Defaults to `min_max_variable_partitioner` with `min_slice_size` + 64 << 20. + dnn_input_layer_to_tree: Whether to provide the DNN's input layer + as a feature to the tree. + dnn_steps_to_train: Number of steps to train dnn for before switching + to gbdt. + tree_feature_columns: An iterable containing all the feature columns + used by the model's boosted trees. If dnn_input_layer_to_tree is + set to True, these features are in addition to dnn_feature_columns. + tree_center_bias: Whether a separate tree should be created for + first fitting the bias. 
+ """ + head = head_lib.regression_head( + label_name=label_name, + label_dimension=label_dimension, + weight_column_name=weight_column_name, + enable_centered_bias=False) + + # num_classes needed for GradientBoostedDecisionTreeModel + if label_dimension == 1: + tree_learner_config.num_classes = 2 + else: + tree_learner_config.num_classes = label_dimension + + def _model_fn(features, labels, mode, config): + return _dnn_tree_combined_model_fn( + features, labels, mode, head, dnn_hidden_units, dnn_feature_columns, + tree_learner_config, num_trees, tree_examples_per_layer, config, + dnn_optimizer, dnn_activation_fn, dnn_dropout, + dnn_input_layer_partitioner, dnn_input_layer_to_tree, + dnn_steps_to_train, tree_feature_columns, tree_center_bias) + + super(DNNBoostedTreeCombinedRegressor, self).__init__( + model_fn=_model_fn, model_dir=model_dir, + config=config, feature_engineering_fn=feature_engineering_fn) + + +class DNNBoostedTreeCombinedEstimator(estimator.Estimator): + """An estimator that uses a combined DNN/GBDT model. + + Useful for training with user specified `Head`. + """ + + def __init__(self, + dnn_hidden_units, + dnn_feature_columns, + tree_learner_config, + num_trees, + tree_examples_per_layer, + head, + model_dir=None, + config=None, + feature_engineering_fn=None, + dnn_optimizer="Adagrad", + dnn_activation_fn=nn.relu, + dnn_dropout=None, + dnn_input_layer_partitioner=None, + dnn_input_layer_to_tree=True, + dnn_steps_to_train=10000, + tree_feature_columns=None, + tree_center_bias=True): + """Initializes a DNNBoostedTreeCombinedEstimator instance. + + Args: + dnn_hidden_units: List of hidden units per layer for DNN. + dnn_feature_columns: An iterable containing all the feature columns + used by the model's DNN. + tree_learner_config: A config for the tree learner. + num_trees: Number of trees to grow model to after training DNN. + tree_examples_per_layer: Number of examples to accumulate before + growing the tree a layer. 
This value has a big impact on model + quality and should be set equal to the number of examples in + training dataset if possible. It can also be a function that computes + the number of examples based on the depth of the layer that's + being built. + head: `Head` instance. + model_dir: Directory for model exports. + config: `RunConfig` of the estimator. + feature_engineering_fn: Feature engineering function. Takes features and + labels which are the output of `input_fn` and returns features and + labels which will be fed into the model. + dnn_optimizer: string, `Optimizer` object, or callable that defines the + optimizer to use for training the DNN. If `None`, will use the Adagrad + optimizer with default learning rate. + dnn_activation_fn: Activation function applied to each layer of the DNN. + If `None`, will use `tf.nn.relu`. + dnn_dropout: When not `None`, the probability to drop out a given + unit in the DNN. + dnn_input_layer_partitioner: Partitioner for input layer of the DNN. + Defaults to `min_max_variable_partitioner` with `min_slice_size` + 64 << 20. + dnn_input_layer_to_tree: Whether to provide the DNN's input layer + as a feature to the tree. + dnn_steps_to_train: Number of steps to train dnn for before switching + to gbdt. + tree_feature_columns: An iterable containing all the feature columns + used by the model's boosted trees. If dnn_input_layer_to_tree is + set to True, these features are in addition to dnn_feature_columns. + tree_center_bias: Whether a separate tree should be created for + first fitting the bias. 
+ """ + def _model_fn(features, labels, mode, config): + return _dnn_tree_combined_model_fn( + features, labels, mode, head, dnn_hidden_units, dnn_feature_columns, + tree_learner_config, num_trees, tree_examples_per_layer, config, + dnn_optimizer, dnn_activation_fn, dnn_dropout, + dnn_input_layer_partitioner, dnn_input_layer_to_tree, + dnn_steps_to_train, + tree_feature_columns, tree_center_bias) + + super(DNNBoostedTreeCombinedEstimator, self).__init__( + model_fn=_model_fn, model_dir=model_dir, + config=config, feature_engineering_fn=feature_engineering_fn) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py new file mode 100644 index 0000000000000000000000000000000000000000..83d58c561008e8a5a69eb503d1605bb9e940f281 --- /dev/null +++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py @@ -0,0 +1,105 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for combined DNN + GBDT estimators.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tempfile + +from tensorflow.contrib.boosted_trees.estimator_batch import dnn_tree_combined_estimator as estimator +from tensorflow.contrib.boosted_trees.proto import learner_pb2 +from tensorflow.contrib.layers.python.layers import feature_column +from tensorflow.contrib.learn.python.learn.estimators import estimator_test_utils +from tensorflow.contrib.learn.python.learn.estimators import run_config +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util +from tensorflow.python.platform import googletest + + +def _train_input_fn(): + features = { + "x": constant_op.constant([[2.], [1.], [1.]]) + } + label = constant_op.constant([[1], [0], [0]], dtype=dtypes.int32) + return features, label + + +def _eval_input_fn(): + features = { + "x": constant_op.constant([[1.], [2.], [2.]]) + } + label = constant_op.constant([[0], [1], [1]], dtype=dtypes.int32) + return features, label + + +class DNNBoostedTreeCombinedTest(test_util.TensorFlowTestCase): + + def testClassifierContract(self): + estimator_test_utils.assert_estimator_contract( + self, estimator.DNNBoostedTreeCombinedClassifier) + + def testRegressorContract(self): + estimator_test_utils.assert_estimator_contract( + self, estimator.DNNBoostedTreeCombinedRegressor) + + def testEstimatorContract(self): + estimator_test_utils.assert_estimator_contract( + self, estimator.DNNBoostedTreeCombinedEstimator) + + def testNoDNNFeatureColumns(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + + with self.assertRaisesRegexp( + ValueError, + "dnn_feature_columns must be specified"): + classifier = estimator.DNNBoostedTreeCombinedClassifier( + 
dnn_hidden_units=[1], + dnn_feature_columns=[], + tree_learner_config=learner_config, + num_trees=1, + tree_examples_per_layer=3, + n_classes=2) + classifier.fit(input_fn=_train_input_fn, steps=5) + + def testFitAndEvaluateDontThrowException(self): + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + learner_config.constraints.max_tree_depth = 1 + model_dir = tempfile.mkdtemp() + config = run_config.RunConfig() + + classifier = estimator.DNNBoostedTreeCombinedClassifier( + dnn_hidden_units=[1], + dnn_feature_columns=[feature_column.real_valued_column("x")], + tree_learner_config=learner_config, + num_trees=1, + tree_examples_per_layer=3, + n_classes=2, + model_dir=model_dir, + config=config, + dnn_steps_to_train=10, + dnn_input_layer_to_tree=False, + tree_feature_columns=[feature_column.real_valued_column("x")]) + + classifier.fit(input_fn=_train_input_fn, steps=15) + classifier.evaluate(input_fn=_eval_input_fn, steps=1) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/trainer_hooks.py b/tensorflow/contrib/boosted_trees/estimator_batch/trainer_hooks.py index 79193fffc3d3fa97e20a12181bf20e6ad86dcb58..2e4151cac40f770e2bece70d752122eb7f34dd40 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/trainer_hooks.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/trainer_hooks.py @@ -24,6 +24,7 @@ from tensorflow.contrib.learn.python.learn import session_run_hook from tensorflow.contrib.learn.python.learn.session_run_hook import SessionRunArgs from tensorflow.core.framework.summary_pb2 import Summary from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import training_util from tensorflow.python.training.summary_io import SummaryWriterCache @@ -175,3 +176,40 @@ class StopAfterNTrees(session_run_hook.SessionRunHook): 
logging.info("Requesting stop since we have reached %d trees.", num_finalized_trees) run_context.request_stop() + + +class SwitchTrainOp(session_run_hook.SessionRunHook): + """Hook that switches the train op after specified number of steps. + + Hook that replaces the train op depending on the number of steps of training + that have taken place. The first_train_op is used till train_steps steps + are reached. Thereafter the second_train_op is used. + """ + + def __init__(self, first_train_op, train_steps, second_train_op): + """Initializes a `SwitchTrainOp`.""" + self._first_train_op = first_train_op + self._second_train_op = second_train_op + self._train_steps = train_steps + + def _get_train_op_for_global_step(self, current_step): + """Gets train_op for current global step.""" + if current_step < self._train_steps: + return self._first_train_op + return self._second_train_op + + def begin(self): + self._global_step_tensor = training_util.get_global_step() + self._current_train_op = control_flow_ops.no_op() + if self._global_step_tensor is None: + raise RuntimeError( + "Global step should be created to use SwitchTrainOp.") + + def before_run(self, run_context): # pylint: disable=unused-argument + return session_run_hook.SessionRunArgs( + {"global_step": self._global_step_tensor, + "train_op": self._current_train_op}) + + def after_run(self, run_context, run_values): + self._current_train_op = self._get_train_op_for_global_step( + run_values.results["global_step"]) diff --git a/tensorflow/contrib/boosted_trees/examples/boston_combined.py b/tensorflow/contrib/boosted_trees/examples/boston_combined.py new file mode 100644 index 0000000000000000000000000000000000000000..e04b56afbfd266dc13a5b0d78d171ea273415ee3 --- /dev/null +++ b/tensorflow/contrib/boosted_trees/examples/boston_combined.py @@ -0,0 +1,165 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Regression on Boston housing data using DNNBoostedTreeCombinedRegressor. + + Example Usage: + + python tensorflow/contrib/boosted_trees/examples/boston_combined.py \ + --batch_size=404 --output_dir="/tmp/boston" \ + --dnn_hidden_units="8,4" --dnn_steps_to_train=1000 \ + --tree_depth=4 --tree_learning_rate=0.1 \ + --num_trees=100 --tree_l2=0.001 --num_eval_steps=1 \ + --vmodule=training_ops=1 + + When training is done, mean squared error on eval data is reported. 
+ Point tensorboard to the directory for the run to see how the training + progresses: + + tensorboard --logdir=/tmp/boston + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import sys +import tensorflow as tf + +from tensorflow.contrib.boosted_trees.estimator_batch.dnn_tree_combined_estimator import DNNBoostedTreeCombinedRegressor +from tensorflow.contrib.boosted_trees.proto import learner_pb2 +from tensorflow.contrib.layers.python.layers import feature_column +from tensorflow.contrib.learn.python.learn import learn_runner +from tensorflow.contrib.learn.python.learn.utils import input_fn_utils +from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils + +_BOSTON_NUM_FEATURES = 13 + + +def _get_estimator(output_dir, feature_cols): + """Configures DNNBoostedTreeCombinedRegressor based on flags.""" + learner_config = learner_pb2.LearnerConfig() + learner_config.learning_rate_tuner.fixed.learning_rate = ( + FLAGS.tree_learning_rate) + learner_config.regularization.l1 = 0.0 + learner_config.regularization.l2 = FLAGS.tree_l2 + learner_config.constraints.max_tree_depth = FLAGS.tree_depth + + run_config = tf.contrib.learn.RunConfig(save_summary_steps=1) + + # Create a DNNBoostedTreeCombinedRegressor estimator. + estimator = DNNBoostedTreeCombinedRegressor( + dnn_hidden_units=[int(x) for x in FLAGS.dnn_hidden_units.split(",")], + dnn_feature_columns=feature_cols, + tree_learner_config=learner_config, + num_trees=FLAGS.num_trees, + # This should be the number of examples. For large datasets it can be + # larger than the batch_size. 
+ tree_examples_per_layer=FLAGS.batch_size, + model_dir=output_dir, + config=run_config, + dnn_input_layer_to_tree=True, + dnn_steps_to_train=FLAGS.dnn_steps_to_train) + return estimator + + +def _make_experiment_fn(output_dir): + """Creates experiment for DNNBoostedTreeCombinedRegressor.""" + (x_train, y_train), (x_test, + y_test) = tf.keras.datasets.boston_housing.load_data() + + train_input_fn = tf.estimator.inputs.numpy_input_fn( + x={"x": x_train}, + y=y_train, + batch_size=FLAGS.batch_size, + num_epochs=None, + shuffle=True) + eval_input_fn = tf.estimator.inputs.numpy_input_fn( + x={"x": x_test}, y=y_test, num_epochs=1, shuffle=False) + + feature_columns = [ + feature_column.real_valued_column("x", dimension=_BOSTON_NUM_FEATURES) + ] + feature_spec = tf.contrib.layers.create_feature_spec_for_parsing( + feature_columns) + serving_input_fn = input_fn_utils.build_parsing_serving_input_fn(feature_spec) + export_strategies = [ + saved_model_export_utils.make_export_strategy(serving_input_fn)] + return tf.contrib.learn.Experiment( + estimator=_get_estimator(output_dir, feature_columns), + train_input_fn=train_input_fn, + eval_input_fn=eval_input_fn, + train_steps=None, + eval_steps=FLAGS.num_eval_steps, + eval_metrics=None, + export_strategies=export_strategies) + + +def main(unused_argv): + learn_runner.run( + experiment_fn=_make_experiment_fn, + output_dir=FLAGS.output_dir, + schedule="train_and_evaluate") + + +if __name__ == "__main__": + tf.logging.set_verbosity(tf.logging.INFO) + parser = argparse.ArgumentParser() + # Define the list of flags that users can change. 
+ parser.add_argument( + "--batch_size", + type=int, + default=1000, + help="The batch size for reading data.") + parser.add_argument( + "--output_dir", + type=str, + required=True, + help="Choose the dir for the output.") + parser.add_argument( + "--num_eval_steps", + type=int, + default=1, + help="The number of steps to run evaluation for.") + # Flags for configuring DNNBoostedTreeCombinedRegressor. + parser.add_argument( + "--dnn_hidden_units", + type=str, + default="8,4", + help="Hidden layers for DNN.") + parser.add_argument( + "--dnn_steps_to_train", + type=int, + default=1000, + help="Number of steps to train DNN.") + parser.add_argument( + "--tree_depth", type=int, default=4, help="Maximum depth of trees.") + parser.add_argument( + "--tree_l2", type=float, default=1.0, help="l2 regularization per batch.") + parser.add_argument( + "--tree_learning_rate", + type=float, + default=0.1, + help=("Learning rate (shrinkage weight) with which each " + "new tree is added.")) + parser.add_argument( + "--num_trees", + type=int, + default=None, + required=True, + help="Number of trees to grow before stopping.") + + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/boosted_trees/kernels/model_ops.cc b/tensorflow/contrib/boosted_trees/kernels/model_ops.cc index 4b5d5ba0de6c3995ee2da7a44ab0ba099cbf1b35..754b7bc3270d647fc381033b769eadd7b791771e 100644 --- a/tensorflow/contrib/boosted_trees/kernels/model_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/model_ops.cc @@ -48,8 +48,9 @@ class CreateTreeEnsembleVariableOp : public OpKernel { if (!result->InitFromSerialized(tree_ensemble_config_t->scalar()(), stamp_token)) { result->Unref(); - OP_REQUIRES(context, false, errors::InvalidArgument( - "Unable to parse tree ensemble config.")); + OP_REQUIRES( + context, false, + errors::InvalidArgument("Unable to parse tree ensemble config.")); } // Only create one, if one does not exist already. 
Report status for all diff --git a/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc b/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc index f8086b0c2bb93eae6af0336bbe33fc23f8fcde22..b3fe38614e05801b223f0c96f7a70ce7e432a70b 100644 --- a/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc @@ -47,8 +47,8 @@ namespace boosted_trees { using boosted_trees::learner::LearnerConfig; using boosted_trees::learner::LearningRateConfig; using boosted_trees::learner::LearningRateDropoutDrivenConfig; -using boosted_trees::models::MultipleAdditiveTrees; using boosted_trees::models::DecisionTreeEnsembleResource; +using boosted_trees::models::MultipleAdditiveTrees; using boosted_trees::utils::DropoutUtils; using boosted_trees::utils::TensorUtils; diff --git a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc index 8600c8c53caa5fd4274ba6730fc764d8315d680c..0f4c2298f56be48bb32f52d5d44cff8afe284f1e 100644 --- a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc @@ -36,22 +36,21 @@ namespace tensorflow { using ::boosted_trees::QuantileConfig; -using boosted_trees::utils::TensorUtils; using boosted_trees::QuantileStreamResource; +using boosted_trees::utils::TensorUtils; namespace { const char* const kExampleWeightsName = "example_weights"; const char* const kMaxElementsName = "max_elements"; -const char* const kHandleName = "handle"; const char* const kNextStampTokenName = "next_stamp_token"; const char* const kStampTokenName = "stamp_token"; const char* const kAreBucketsReadyName = "are_buckets_ready"; +const char* const kGenerateQuantiles = "generate_quantiles"; // Names for sparse arguments. 
const char* const kNumSparseFeaturesName = "num_sparse_features"; const char* const kSparseBucketsName = "sparse_buckets"; const char* const kSparseValuesName = "sparse_values"; const char* const kSparseIndicesName = "sparse_indices"; -const char* const kSparseStreamsStateName = "sparse_streams_state"; const char* const kSparseSummariesName = "sparse_summaries"; const char* const kSparseConfigName = "sparse_config"; const char* const kSparseOutputTensorName = "sparse_quantiles"; @@ -59,7 +58,6 @@ const char* const kSparseOutputTensorName = "sparse_quantiles"; const char* const kDenseBucketsName = "dense_buckets"; const char* const kDenseConfigName = "dense_config"; const char* const kDenseOutputTensorName = "dense_quantiles"; -const char* const kDenseStreamsStateName = "dense_streams_state"; const char* const kDenseSummariesName = "dense_summaries"; const char* const kDenseValuesName = "dense_values"; const char* const kNumDenseFeaturesName = "num_dense_features"; @@ -182,6 +180,16 @@ std::vector GenerateBoundaries(const QuantileStream& stream, return boundaries; } +// Generates quantiles on a finalized QuantileStream. +std::vector GenerateQuantiles(const QuantileStream& stream, + int num_quantiles) { + // Do not de-dup boundaries. Exactly num_quantiles+1 boundary values + // will be returned. + std::vector boundaries = stream.GenerateQuantiles(num_quantiles); + CHECK_EQ(boundaries.size(), num_quantiles + 1); + return boundaries; +} + // Copies quantiles to output list. 
void CopyBoundaries(OpKernelContext* const context, const std::vector& boundaries, const int64 index, @@ -224,6 +232,8 @@ class CreateQuantileAccumulatorOp : public OpKernel { OP_REQUIRES_OK(context, context->GetAttr(kNumQuantilesName, &num_quantiles_)); OP_REQUIRES_OK(context, context->GetAttr(kMaxElementsName, &max_elements_)); + OP_REQUIRES_OK(context, + context->GetAttr(kGenerateQuantiles, &generate_quantiles_)); } void Compute(OpKernelContext* context) override { @@ -231,9 +241,9 @@ class CreateQuantileAccumulatorOp : public OpKernel { // other exceptions. If one already exists, it unrefs the new one. const Tensor* stamp_token_t; OP_REQUIRES_OK(context, context->input(kStampTokenName, &stamp_token_t)); - auto result = - new QuantileStreamResource(epsilon_, num_quantiles_, max_elements_, - stamp_token_t->scalar()()); + auto result = new QuantileStreamResource(epsilon_, num_quantiles_, + max_elements_, generate_quantiles_, + stamp_token_t->scalar()()); auto status = CreateResource(context, HandleFromInput(context, 0), result); if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) { OP_REQUIRES(context, false, status); @@ -246,6 +256,7 @@ class CreateQuantileAccumulatorOp : public OpKernel { // An upperbound on the number of enteries that the summaries might have // for a feature. int64 max_elements_; + bool generate_quantiles_; }; REGISTER_KERNEL_BUILDER(Name("CreateQuantileAccumulator").Device(DEVICE_CPU), @@ -373,7 +384,7 @@ class MakeQuantileSummariesOp : public OpKernel { protobuf::Arena arena; ::boosted_trees::QuantileSummaryState* summary_proto = protobuf::Arena::CreateMessage< - ::boosted_trees::QuantileSummaryState>(&arena); + ::boosted_trees::QuantileSummaryState>(&arena); const auto& summary = stream.GetFinalSummary(); CopySummaryToProto(summary, summary_proto); // Output to tensor. 
@@ -597,10 +608,15 @@ class QuantileAccumulatorFlushOp : public OpKernel { << "Passed stamp token: " << stamp_token << " " << "Current token: " << streams_resource->stamp(); QuantileStream* stream = streams_resource->stream(stamp_token); + bool generate_quantiles = streams_resource->generate_quantiles(); stream->Finalize(); + streams_resource->set_boundaries( stamp_token, - GenerateBoundaries(*stream, streams_resource->num_quantiles())); + generate_quantiles + ? GenerateQuantiles(*stream, streams_resource->num_quantiles()) + : GenerateBoundaries(*stream, streams_resource->num_quantiles())); + streams_resource->Reset(next_stamp_token); } }; diff --git a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc index 18b4abd654ea3541d646a43ac901aca1a678446f..44a8ffaf4b2f5a9c11b3abc46ce55a18c80ad318 100644 --- a/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/split_handler_ops.cc @@ -34,10 +34,10 @@ namespace tensorflow { +using boosted_trees::learner::LearnerConfig_MultiClassStrategy; using boosted_trees::learner::SplitInfo; using boosted_trees::learner::stochastic::GradientStats; using boosted_trees::learner::stochastic::NodeStats; -using boosted_trees::learner::LearnerConfig_MultiClassStrategy; namespace { const int32 DUMMY_FEATURE_DIMENSION = -1; @@ -47,9 +47,8 @@ class BaseBuildSplitOp : public OpKernel { public: explicit BaseBuildSplitOp(OpKernelConstruction* const context) : OpKernel(context) { - OP_REQUIRES_OK( - context, - context->GetAttr("feature_column_group_id", &feature_column_group_id_)); + OP_REQUIRES_OK(context, context->GetAttr("feature_column_group_id", + &feature_column_group_id_)); OP_REQUIRES_OK(context, context->GetAttr("l1_regularization", &l1_regularization_)); OP_REQUIRES_OK(context, diff --git a/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc 
b/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc index a9a229c8ae0c26bba5f0a684dad7e546298577bb..90a0655201f8cb8df6fc6417cb51216dec91b4d7 100644 --- a/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/stats_accumulator_ops.cc @@ -134,10 +134,9 @@ void SerializeScalarAccumulatorToOutput( OpKernelContext* context) { int64 num_slots = accumulator_resource.values().size(); Tensor* partition_ids_t = nullptr; - OP_REQUIRES_OK( - context, - context->allocate_output("output_partition_ids", TensorShape({num_slots}), - &partition_ids_t)); + OP_REQUIRES_OK(context, context->allocate_output("output_partition_ids", + TensorShape({num_slots}), + &partition_ids_t)); auto partition_ids = partition_ids_t->vec(); // Feature ids tensor has ids of feature columns and their dimensions. @@ -149,15 +148,14 @@ void SerializeScalarAccumulatorToOutput( Tensor* gradients_t = nullptr; OP_REQUIRES_OK( - context, - context->allocate_output("output_gradients", TensorShape({num_slots}), - &gradients_t)); + context, context->allocate_output( + "output_gradients", TensorShape({num_slots}), &gradients_t)); auto gradients = gradients_t->vec(); Tensor* hessians_t = nullptr; - OP_REQUIRES_OK(context, - context->allocate_output( - "output_hessians", TensorShape({num_slots}), &hessians_t)); + OP_REQUIRES_OK( + context, context->allocate_output("output_hessians", + TensorShape({num_slots}), &hessians_t)); auto hessians = hessians_t->vec(); int i = 0; @@ -177,10 +175,9 @@ void SerializeTensorAccumulatorToOutput( OpKernelContext* context) { int64 num_slots = accumulator_resource.values().size(); Tensor* partition_ids_t = nullptr; - OP_REQUIRES_OK( - context, - context->allocate_output("output_partition_ids", TensorShape({num_slots}), - &partition_ids_t)); + OP_REQUIRES_OK(context, context->allocate_output("output_partition_ids", + TensorShape({num_slots}), + &partition_ids_t)); auto partition_ids = partition_ids_t->vec(); 
Tensor* feature_ids_t = nullptr; @@ -202,9 +199,8 @@ void SerializeTensorAccumulatorToOutput( int64 num_hessian_elements = hessian_shape.num_elements(); hessian_shape.InsertDim(0, num_slots); Tensor* hessians_t = nullptr; - OP_REQUIRES_OK( - context, - context->allocate_output("output_hessians", hessian_shape, &hessians_t)); + OP_REQUIRES_OK(context, context->allocate_output("output_hessians", + hessian_shape, &hessians_t)); auto hessians = hessians_t->flat_outer_dims(); int i = 0; diff --git a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc index c77d90e243c304ec8e9a10a0b63401f9bd825c3e..7f8dea1d3c2a04b725843f6e2932a0cdfbc7733c 100644 --- a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc @@ -361,10 +361,27 @@ class GrowTreeEnsembleOp : public OpKernel { // Increment attempt stats. ensemble_resource->IncrementAttempts(); + // In case we want to do feature selection and we have reached the limit, + // build a list of handlers used so far to avoid adding new features. + std::vector allowed_handlers; + if (learner_config_.constraints().max_number_of_unique_feature_columns() > + 0) { + allowed_handlers = ensemble_resource->GetUsedHandlers(); + // TODO(soroush): We can disable handlers that are not going to be used to + // avoid unnecessary computations. + if (allowed_handlers.size() < + learner_config_.constraints() + .max_number_of_unique_feature_columns()) { + // We have not reached the limit yet. Empty the list of allow features + // which means we can keep adding new features. + allowed_handlers.clear(); + } + } + // Find best splits for each active partition. 
std::map best_splits; - FindBestSplitsPerPartition(context, partition_ids_list, gains_list, - splits_list, &best_splits); + FindBestSplitsPerPartition(context, allowed_handlers, partition_ids_list, + gains_list, splits_list, &best_splits); // No-op if no new splits can be considered. if (best_splits.empty()) { @@ -381,7 +398,8 @@ class GrowTreeEnsembleOp : public OpKernel { // Split tree nodes. for (auto& split_entry : best_splits) { - SplitTreeNode(split_entry.first, &split_entry.second, tree_config); + SplitTreeNode(split_entry.first, &split_entry.second, tree_config, + ensemble_resource); } // Post-prune finalized tree if needed. @@ -403,12 +421,20 @@ class GrowTreeEnsembleOp : public OpKernel { // Helper method which effectively does a reduce over all split candidates // and finds the best split for each partition. void FindBestSplitsPerPartition( - OpKernelContext* const context, const OpInputList& partition_ids_list, - const OpInputList& gains_list, const OpInputList& splits_list, + OpKernelContext* const context, + const std::vector& allowed_handlers, // Empty means all handlers. + const OpInputList& partition_ids_list, const OpInputList& gains_list, + const OpInputList& splits_list, std::map* best_splits) { // Find best split per partition going through every feature candidate. // TODO(salehay): Is this worth parallelizing? for (int64 handler_id = 0; handler_id < num_handlers_; ++handler_id) { + if (!allowed_handlers.empty()) { + if (!std::binary_search(allowed_handlers.begin(), + allowed_handlers.end(), handler_id)) { + continue; + } + } const auto& partition_ids = partition_ids_list[handler_id].vec(); const auto& gains = gains_list[handler_id].vec(); const auto& splits = splits_list[handler_id].vec(); @@ -592,8 +618,10 @@ class GrowTreeEnsembleOp : public OpKernel { // Helper method to split a tree node and append its respective // leaf children given the split candidate. 
- void SplitTreeNode(const int32 node_id, SplitCandidate* split, - boosted_trees::trees::DecisionTreeConfig* tree_config) { + void SplitTreeNode( + const int32 node_id, SplitCandidate* split, + boosted_trees::trees::DecisionTreeConfig* tree_config, + boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource) { // No-op if we have no real node. CHECK(node_id < tree_config->nodes_size()) << "Invalid node " << node_id << " to split."; @@ -633,6 +661,9 @@ class GrowTreeEnsembleOp : public OpKernel { // Replace node in tree. (*tree_config->mutable_nodes(node_id)) = *split->split_info.mutable_split_node(); + if (learner_config_.constraints().max_number_of_unique_feature_columns()) { + ensemble_resource->MaybeAddUsedHandler(split->handler_id); + } } void PruneTree(boosted_trees::trees::DecisionTreeConfig* tree_config) { diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py index 72e20aaa127cda592bd314786cddb925cc87a075..7df514cd207c5e781f3b4abaa2020016b197669d 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py @@ -436,7 +436,7 @@ def dense_make_stats_update(is_active, are_buckets_ready, float_column, quantized_feature = quantile_ops.quantiles([float_column], [], [quantile_buckets], [], []) quantized_feature = math_ops.cast(quantized_feature[0], dtypes.int64) - quantized_feature = array_ops.squeeze(quantized_feature) + quantized_feature = array_ops.squeeze(quantized_feature, axis=0) return (example_partition_ids, quantized_feature, gradients, hessians) def not_ready_inputs_fn(): @@ -468,7 +468,7 @@ def sparse_make_stats_update( [sparse_column_indices]) quantized_feature = math_ops.cast(quantized_feature[1], dtypes.int64) - quantized_feature = array_ops.squeeze(quantized_feature) + quantized_feature = 
array_ops.squeeze(quantized_feature, axis=0) example_indices, _ = array_ops.split( sparse_column_indices, num_or_size_splits=2, axis=1) diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py index ee16a5f838a65f20db4436eb86527518621b6d8d..54d03018d9e266beabbbabd78ebbb80cfe689c04 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py +++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler_test.py @@ -1121,6 +1121,87 @@ class SparseSplitHandlerTest(test_util.TensorFlowTestCase): self.assertEqual(len(gains), 0) self.assertEqual(len(splits), 0) + def testDegenerativeCase(self): + with self.test_session() as sess: + # One data example only, one leaf and thus one quantile bucket.The same + # situation is when all examples have the same values. This case was + # causing before a failure. + gradients = array_ops.constant([0.2]) + hessians = array_ops.constant([0.12]) + example_partitions = array_ops.constant([1], dtype=dtypes.int32) + indices = array_ops.constant([[0, 0]], dtype=dtypes.int64) + values = array_ops.constant([0.58]) + sparse_column = sparse_tensor.SparseTensor(indices, values, [1, 1]) + + gradient_shape = tensor_shape.scalar() + hessian_shape = tensor_shape.scalar() + class_id = -1 + + split_handler = ordinal_split_handler.SparseSplitHandler( + l1_regularization=0, + l2_regularization=2, + tree_complexity_regularization=0, + min_node_weight=0, + epsilon=0.01, + num_quantiles=2, + feature_column_group_id=0, + sparse_float_column=sparse_column, + init_stamp_token=0, + gradient_shape=gradient_shape, + hessian_shape=hessian_shape, + multiclass_strategy=learner_pb2.LearnerConfig.TREE_PER_CLASS) + resources.initialize_resources(resources.shared_resources()).run() + + empty_gradients, empty_hessians = get_empty_tensors( + gradient_shape, hessian_shape) + example_weights = 
array_ops.ones([1, 1], dtypes.float32) + + update_1 = split_handler.update_stats_sync( + 0, + example_partitions, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + with ops.control_dependencies([update_1]): + are_splits_ready = split_handler.make_splits(0, 1, class_id)[0] + + with ops.control_dependencies([are_splits_ready]): + update_2 = split_handler.update_stats_sync( + 1, + example_partitions, + gradients, + hessians, + empty_gradients, + empty_hessians, + example_weights, + is_active=array_ops.constant([True, True])) + with ops.control_dependencies([update_2]): + are_splits_ready2, partitions, gains, splits = ( + split_handler.make_splits(1, 2, class_id)) + are_splits_ready, are_splits_ready2, partitions, gains, splits = ( + sess.run([ + are_splits_ready, are_splits_ready2, partitions, gains, splits + ])) + + # During the first iteration, inequality split handlers are not going to + # have any splits. Make sure that we return not_ready in that case. 
+ self.assertFalse(are_splits_ready) + self.assertTrue(are_splits_ready2) + + self.assertAllEqual([1], partitions) + self.assertAllEqual([0.0], gains) + + split_info = split_info_pb2.SplitInfo() + split_info.ParseFromString(splits[0]) + split_node = split_info.split_node.sparse_float_binary_split_default_left + + self.assertEqual(0, split_node.split.feature_column) + + self.assertAllClose(0.58, split_node.split.threshold) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/contrib/boosted_trees/lib/learner/common/accumulators/class-partition-key.h b/tensorflow/contrib/boosted_trees/lib/learner/common/accumulators/class-partition-key.h index e1bef0278846e7ff6abc91e8c57f780af45e8b41..3c54868951a6db93a8b685c8da4dfc78996b7b1f 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/common/accumulators/class-partition-key.h +++ b/tensorflow/contrib/boosted_trees/lib/learner/common/accumulators/class-partition-key.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
// ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_ACCUMULATORS_CLASS_PARTITION_KEY_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_ACCUMULATORS_CLASS_PARTITION_KEY_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_ACCUMULATORS_CLASS_PARTITION_KEY_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_ACCUMULATORS_CLASS_PARTITION_KEY_H_ #include "tensorflow/core/lib/hash/hash.h" @@ -58,4 +58,4 @@ struct ClassPartitionKey { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_ACCUMULATORS_CLASS_PARTITION_KEY_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_ACCUMULATORS_CLASS_PARTITION_KEY_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/learner/common/accumulators/feature-stats-accumulator.h b/tensorflow/contrib/boosted_trees/lib/learner/common/accumulators/feature-stats-accumulator.h index 3814edb5675be74794a08e00becb649f8fc53fdb..ec4e7c52bb5f4536a50192e1b5fcc019dd7b2511 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/common/accumulators/feature-stats-accumulator.h +++ b/tensorflow/contrib/boosted_trees/lib/learner/common/accumulators/feature-stats-accumulator.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
// ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_ACCUMULATORS_FEATURE_STATS_ACCUMULATOR_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_ACCUMULATORS_FEATURE_STATS_ACCUMULATOR_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_ACCUMULATORS_FEATURE_STATS_ACCUMULATOR_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_ACCUMULATORS_FEATURE_STATS_ACCUMULATOR_H_ #include #include @@ -79,4 +79,4 @@ class FeatureStatsAccumulator { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_ACCUMULATORS_FEATURE_STATS_ACCUMULATOR_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_ACCUMULATORS_FEATURE_STATS_ACCUMULATOR_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/learner/common/partitioners/example_partitioner.h b/tensorflow/contrib/boosted_trees/lib/learner/common/partitioners/example_partitioner.h index aed0d9fdac108dff4576cc1563dae420340387be..37a71037041445e6a6fcf6290015b93cffef1618 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/common/partitioners/example_partitioner.h +++ b/tensorflow/contrib/boosted_trees/lib/learner/common/partitioners/example_partitioner.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
// ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_PARTITIONERS_EXAMPLE_PARTITIONER_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_PARTITIONERS_EXAMPLE_PARTITIONER_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_PARTITIONERS_EXAMPLE_PARTITIONER_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_PARTITIONERS_EXAMPLE_PARTITIONER_H_ #include #include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h" @@ -50,4 +50,4 @@ class ExamplePartitioner { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_PARTITIONERS_EXAMPLE_PARTITIONER_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_PARTITIONERS_EXAMPLE_PARTITIONER_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/learner/common/stats/feature-split-candidate.h b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/feature-split-candidate.h index 339c2e0fded10e6a7b140da62e152e2868ffd164..382b85cf0b2c146f82fa79551c569b9c70d9b7a6 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/common/stats/feature-split-candidate.h +++ b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/feature-split-candidate.h @@ -13,8 +13,8 @@ // limitations under the License. 
// // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_FEATURE_SPLIT_CANDIDATE_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_FEATURE_SPLIT_CANDIDATE_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_FEATURE_SPLIT_CANDIDATE_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_FEATURE_SPLIT_CANDIDATE_H_ #include "tensorflow/contrib/boosted_trees/lib/learner/common/stats/split-stats.h" #include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h" @@ -58,4 +58,4 @@ struct FeatureSplitCandidate { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_FEATURE_SPLIT_CANDIDATE_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_FEATURE_SPLIT_CANDIDATE_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/learner/common/stats/gradient-stats.h b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/gradient-stats.h index 34e3ddb777242553d62035a51f1aec33d0f9ba54..3dd03215d88abc223a2d081d11901ffd3fb7aaa9 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/common/stats/gradient-stats.h +++ b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/gradient-stats.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
// ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_GRADIENT_STATS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_GRADIENT_STATS_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_GRADIENT_STATS_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_GRADIENT_STATS_H_ #include @@ -190,4 +190,4 @@ inline GradientStats operator-(const GradientStats& a, const GradientStats& b) { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_GRADIENT_STATS_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_GRADIENT_STATS_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h index 642a183aec5c7e591579fa5ee91d45729bfb624d..cd925f6b65e569538212e9c26aef0abc8482960b 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h +++ b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
// ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_NODE_STATS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_NODE_STATS_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_NODE_STATS_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_NODE_STATS_H_ #include "third_party/eigen3/Eigen/Core" #include "third_party/eigen3/Eigen/Eigenvalues" @@ -298,4 +298,4 @@ struct NodeStats { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_NODE_STATS_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_NODE_STATS_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats_test.cc b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats_test.cc index f867e77d3ef0609774628b2a9c36ca52bcf2a957..8bca132acfde9397942b198db9a8d4c0e4d74897 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/node-stats_test.cc @@ -17,8 +17,8 @@ #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/platform/test.h" -using tensorflow::test::AsTensor; using std::vector; +using tensorflow::test::AsTensor; namespace tensorflow { namespace boosted_trees { diff --git a/tensorflow/contrib/boosted_trees/lib/learner/common/stats/split-stats.h b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/split-stats.h index 054ccd9a8cd0be0c48b14cca013f15677deba900..81ee2774bdab91f492064455055181c56ef6a065 100644 --- a/tensorflow/contrib/boosted_trees/lib/learner/common/stats/split-stats.h +++ b/tensorflow/contrib/boosted_trees/lib/learner/common/stats/split-stats.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the 
License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_SPLIT_STATS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_SPLIT_STATS_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_SPLIT_STATS_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_SPLIT_STATS_H_ #include @@ -81,4 +81,4 @@ struct SplitStats { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_SPLIT_STATS_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_LEARNER_COMMON_STATS_SPLIT_STATS_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h index ee29a8aa797b96d41ec2d77bf831ee287d5443e7..cc3dc226cdbc88fc7010ada1e7f0e6c0a3913c5f 100644 --- a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h +++ b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
// ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_MODELS_MULTIPLE_ADDITIVE_TREES_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_MODELS_MULTIPLE_ADDITIVE_TREES_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_MODELS_MULTIPLE_ADDITIVE_TREES_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_MODELS_MULTIPLE_ADDITIVE_TREES_H_ #include @@ -45,4 +45,4 @@ class MultipleAdditiveTrees { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_MODELS_MULTIPLE_ADDITIVE_TREES_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_MODELS_MULTIPLE_ADDITIVE_TREES_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h index 70037d5bd8f446bdbbfcc468edb8a76c05e4fab7..804b218f1c08338df80f8dd2e6135f5d92b9928e 100644 --- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h +++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
// ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_ #include #include @@ -129,4 +129,4 @@ constexpr decltype(CompareFn()) } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h index fd577ad712f228fa8016a48942511a3263aae5da..8ad97fedc923ac50bcaad86e0ba2c2e46df6821b 100644 --- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h +++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h @@ -12,12 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. 
// ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_STREAM_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_STREAM_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_STREAM_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_STREAM_H_ +#include #include #include -#include #include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_buffer.h" #include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h" @@ -322,4 +322,4 @@ WeightedQuantilesStream::GetQuantileSpecs( } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_STREAM_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_STREAM_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h index c329c6d4f7363a7738b06648943fe1dbd065cce5..aec232f3cbb096f0aa51e4362a821882391f8027 100644 --- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h +++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
// ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_ #include #include @@ -334,4 +334,4 @@ constexpr decltype(CompareFn()) } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/testutil/batch_features_testutil.h b/tensorflow/contrib/boosted_trees/lib/testutil/batch_features_testutil.h index d95878ec87b9e903930d2016bb573eee2573f776..b98190b10dc88d5bba9023e771844a2bd6c9a45d 100644 --- a/tensorflow/contrib/boosted_trees/lib/testutil/batch_features_testutil.h +++ b/tensorflow/contrib/boosted_trees/lib/testutil/batch_features_testutil.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
// ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_BATCH_FEATURES_TESTUTIL_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_BATCH_FEATURES_TESTUTIL_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_BATCH_FEATURES_TESTUTIL_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_BATCH_FEATURES_TESTUTIL_H_ #include "tensorflow/contrib/boosted_trees/lib/utils/batch_features.h" #include "tensorflow/core/framework/tensor.h" @@ -42,4 +42,4 @@ void RandomlyInitializeBatchFeatures( } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_BATCH_FEATURES_TESTUTIL_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_BATCH_FEATURES_TESTUTIL_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.cc b/tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.cc index cbe26ba918d384ad903fb854ca3e88e84d16a923..705b65e9db9f1aed9af1be153240d57e163c2d5b 100644 --- a/tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.cc +++ b/tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.cc @@ -22,9 +22,9 @@ namespace tensorflow { namespace boosted_trees { namespace testutil { +using boosted_trees::trees::DenseFloatBinarySplit; using tensorflow::boosted_trees::trees::DecisionTreeConfig; using tensorflow::boosted_trees::trees::TreeNode; -using boosted_trees::trees::DenseFloatBinarySplit; namespace { diff --git a/tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.h b/tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.h index 5e12429ba778344edda623d149e017661f1e0222..1838b4cee21afb5df72a9b902f0ec0ce6f7ac627 100644 --- a/tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.h +++ b/tensorflow/contrib/boosted_trees/lib/testutil/random_tree_gen.h @@ -12,8 +12,8 @@ // See the License for the specific language governing 
permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_RANDOM_TREE_GEN_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_RANDOM_TREE_GEN_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_RANDOM_TREE_GEN_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_RANDOM_TREE_GEN_H_ #include @@ -72,4 +72,4 @@ class RandomTreeGen { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_RANDOM_TREE_GEN_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TESTUTIL_RANDOM_TREE_GEN_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h index 604ff02744b25b136bd935bf85635731730effe8..43526c229a65d45a2b0ced4aa1262d489526fc7b 100644 --- a/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h +++ b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
// ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TREES_DECISION_TREE_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TREES_DECISION_TREE_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TREES_DECISION_TREE_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TREES_DECISION_TREE_H_ #include "tensorflow/contrib/boosted_trees/lib/utils/example.h" #include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h" // NOLINT @@ -46,4 +46,4 @@ class DecisionTree { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TREES_DECISION_TREE_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_TREES_DECISION_TREE_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h index badc629a118f768d5aa25ef1b94b8190e6910c7f..da5e7448519cb7f4092f7bbbe1b526271008ec22 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h +++ b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
// ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_BATCH_FEATURES_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_BATCH_FEATURES_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_BATCH_FEATURES_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_BATCH_FEATURES_H_ #include #include "tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.h" @@ -92,4 +92,4 @@ class BatchFeatures { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_BATCH_FEATURES_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_BATCH_FEATURES_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/utils/batch_features_test.cc b/tensorflow/contrib/boosted_trees/lib/utils/batch_features_test.cc index 9de3e32b097a151b3bd6f5c30df2db0938b65e9c..609519e8b1153a27d987c5f9ca9bfcc9ee6717d6 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/batch_features_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/batch_features_test.cc @@ -25,8 +25,8 @@ namespace boosted_trees { namespace utils { namespace { -using test::AsTensor; using errors::InvalidArgument; +using test::AsTensor; class BatchFeaturesTest : public ::testing::Test {}; diff --git a/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.cc b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.cc index 38f0151255bbf4fcd87f1d0d76fd111649ee4a12..db34db998a7442c69f2ab468f4557d991429f4ee 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.cc @@ -23,10 +23,10 @@ #include "tensorflow/core/lib/random/simple_philox.h" #include "tensorflow/core/platform/logging.h" +using tensorflow::Status; using tensorflow::boosted_trees::learner::LearningRateDropoutDrivenConfig; using tensorflow::random::PhiloxRandom; using tensorflow::random::SimplePhilox; -using 
tensorflow::Status; namespace tensorflow { namespace boosted_trees { diff --git a/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h index c3f1c918ca5f603cf9470071017d8ee384dc9320..928bfbfe5c9394ab4083aabced4c8e1149bb10aa 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h +++ b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_DROPOUT_UTILS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_DROPOUT_UTILS_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_DROPOUT_UTILS_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_DROPOUT_UTILS_H_ #include #include @@ -74,4 +74,4 @@ class DropoutUtils { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_DROPOUT_UTILS_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_DROPOUT_UTILS_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils_test.cc b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils_test.cc index ce7632e58987f5890beaded5dd305724f950e1e8..02f972c8e00e8229426ac53d8f20765484787b6e 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/dropout_utils_test.cc @@ -26,9 +26,9 @@ #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/env.h" +using std::unordered_set; using tensorflow::boosted_trees::learner::LearningRateDropoutDrivenConfig; using tensorflow::boosted_trees::trees::DecisionTreeEnsembleConfig; -using std::unordered_set; namespace tensorflow { namespace boosted_trees { diff --git 
a/tensorflow/contrib/boosted_trees/lib/utils/example.h b/tensorflow/contrib/boosted_trees/lib/utils/example.h index 54f60e1dee49a4a40b84fcc6e042fac1858aa187..1371ff337f78dd1c38f2bd0ba86911642f3aeb3e 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/example.h +++ b/tensorflow/contrib/boosted_trees/lib/utils/example.h @@ -13,8 +13,8 @@ // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLE_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLE_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLE_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLE_H_ #include #include @@ -131,4 +131,4 @@ struct Example { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLE_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLE_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.h b/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.h index 5b33c8158879ec65425ac77b5338ee98fbdf07db..1b654e1c44e545fb97216ad950f3cd2d3240ffd0 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.h +++ b/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.h @@ -13,8 +13,8 @@ // limitations under the License. 
// ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLES_ITERABLE_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLES_ITERABLE_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLES_ITERABLE_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLES_ITERABLE_H_ #include @@ -205,4 +205,4 @@ class ExamplesIterable { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLES_ITERABLE_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLES_ITERABLE_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/utils/macros.h b/tensorflow/contrib/boosted_trees/lib/utils/macros.h index 28ea0a4dc191af66ced574d78d9873cc8335f491..9a53fb2ef7d0581986885f3bc8233d91b67c0166 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/macros.h +++ b/tensorflow/contrib/boosted_trees/lib/utils/macros.h @@ -13,8 +13,8 @@ // limitations under the License. 
// ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_MACROS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_MACROS_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_MACROS_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_MACROS_H_ #include "tensorflow/core/platform/macros.h" @@ -23,4 +23,4 @@ return (STATUS); \ } -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_MACROS_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_MACROS_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/utils/optional_value.h b/tensorflow/contrib/boosted_trees/lib/utils/optional_value.h index c141fe059d48072c6c4495535eafec9633616d21..b2166f53d7a037fb8ec53d5295b98bb82b17d4c7 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/optional_value.h +++ b/tensorflow/contrib/boosted_trees/lib/utils/optional_value.h @@ -13,8 +13,8 @@ // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_OPTIONAL_VALUE_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_OPTIONAL_VALUE_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_OPTIONAL_VALUE_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_OPTIONAL_VALUE_H_ #include "tensorflow/core/platform/logging.h" @@ -44,4 +44,4 @@ class OptionalValue { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_OPTIONAL_VALUE_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_OPTIONAL_VALUE_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h b/tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h index c80431b5587cecc0bce22f6150a69d30397529da..ec06787e1db69514c9e60f6d152f3b0c7de23842 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h +++ 
b/tensorflow/contrib/boosted_trees/lib/utils/parallel_for.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LIB_UTILS_PARALLEL_FOR_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LIB_UTILS_PARALLEL_FOR_H_ +#ifndef TENSORFLOW_CONTRIB_LIB_UTILS_PARALLEL_FOR_H_ +#define TENSORFLOW_CONTRIB_LIB_UTILS_PARALLEL_FOR_H_ #include "tensorflow/core/lib/core/threadpool.h" @@ -30,4 +30,4 @@ void ParallelFor(int64 batch_size, int64 desired_parallelism, } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LIB_UTILS_PARALLEL_FOR_H_ +#endif // TENSORFLOW_CONTRIB_LIB_UTILS_PARALLEL_FOR_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/utils/random.h b/tensorflow/contrib/boosted_trees/lib/utils/random.h index 6dd55fcacc42b88116737ab6fb413852ffc1473d..546d344f5585458f10699a644621f0adf26b6446 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/random.h +++ b/tensorflow/contrib/boosted_trees/lib/utils/random.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
// ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LIB_UTILS_RANDOM_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LIB_UTILS_RANDOM_H_ +#ifndef TENSORFLOW_CONTRIB_LIB_UTILS_RANDOM_H_ +#define TENSORFLOW_CONTRIB_LIB_UTILS_RANDOM_H_ #include "tensorflow/core/lib/random/simple_philox.h" @@ -36,4 +36,4 @@ inline int32 PoissonBootstrap(random::SimplePhilox* rng) { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LIB_UTILS_RANDOM_H_ +#endif // TENSORFLOW_CONTRIB_LIB_UTILS_RANDOM_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.cc b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.cc index 0d46565a1962b88cbb267f3d6043610758790578..1297aa884938f2f099a32568acc80c6cd8162651 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.cc @@ -51,7 +51,7 @@ class IndicesRowIterator return tmp; } - reference operator*() { return iter_->ix()(row_idx_, 0); } + reference operator*() const { return iter_->ix()(row_idx_, 0); } pointer operator->() { return &iter_->ix()(row_idx_, 0); } @@ -97,7 +97,7 @@ class IndicesRowIterator } bool operator<(const IndicesRowIterator& other) const { - return (row_idx_ < other.row_idx_); + return (row_idx_ < other.row_idx_); } bool operator==(const IndicesRowIterator& other) const { diff --git a/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.h b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.h index 9664c9d1c6a0c0c8b1bbd1506944c54d2310c611..87fb1fbf5ae3cc6bcf25f68a180d1d9b21ef4d6f 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.h +++ b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.h @@ -13,8 +13,8 @@ // limitations under the License. 
// ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_SPARSE_COLUMN_ITERABLE_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_SPARSE_COLUMN_ITERABLE_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_SPARSE_COLUMN_ITERABLE_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_SPARSE_COLUMN_ITERABLE_H_ #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" @@ -127,4 +127,4 @@ class SparseColumnIterable { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_SPARSE_COLUMN_ITERABLE_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_SPARSE_COLUMN_ITERABLE_H_ diff --git a/tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h b/tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h index 58f5e5a0d18788375cd8166d1fcbdc7c294ba5e2..475d3718eccc2b23260b7cf5286abdd31ef1bad6 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h +++ b/tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h @@ -13,8 +13,8 @@ // limitations under the License. 
// ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_TENSOR_UTILS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_TENSOR_UTILS_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_TENSOR_UTILS_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_TENSOR_UTILS_H_ #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" @@ -57,4 +57,4 @@ class TensorUtils { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_TENSOR_UTILS_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_TENSOR_UTILS_H_ diff --git a/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc b/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc index 1fa70bafddb0c94f47d006d5694bea941edaddf9..ae99d53a2cf805d70d60746cd44f73f7fd9dc6e2 100644 --- a/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc @@ -19,8 +19,8 @@ namespace tensorflow { namespace boosted_trees { -using shape_inference::InferenceContext; using shape_inference::DimensionHandle; +using shape_inference::InferenceContext; using shape_inference::ShapeHandle; REGISTER_RESOURCE_HANDLE_OP(QuantileStreamResource); @@ -39,6 +39,7 @@ REGISTER_OP("CreateQuantileAccumulator") .Attr("max_elements: int = 1099511627776") // 1 << 40 .Attr("epsilon: float") .Attr("num_quantiles: int") + .Attr("generate_quantiles: bool=False") .Input("quantile_accumulator_handle: resource") .Input("stamp_token: int64") .SetShapeFn([](shape_inference::InferenceContext* c) { diff --git a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc index 0d27ddaf3a1d540efee268c2bcca217077ff5871..5d0ebbf73ce1272b51a475f67984db3a181b7130 100644 --- a/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc +++ 
b/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc @@ -18,9 +18,9 @@ namespace tensorflow { +using shape_inference::DimensionHandle; using shape_inference::InferenceContext; using shape_inference::ShapeHandle; -using shape_inference::DimensionHandle; REGISTER_OP("BuildDenseInequalitySplits") .Attr("feature_column_group_id: int") diff --git a/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc b/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc index 0354f7853cbedf22d0a299273b4dbd225b3121ab..179505eef01f79bb149137400468b84285fe478a 100644 --- a/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc @@ -19,9 +19,9 @@ namespace tensorflow { namespace boosted_trees { +using shape_inference::DimensionHandle; using shape_inference::InferenceContext; using shape_inference::ShapeHandle; -using shape_inference::DimensionHandle; REGISTER_RESOURCE_HANDLE_OP(StatsAccumulatorScalarResource); diff --git a/tensorflow/contrib/boosted_trees/proto/learner.proto b/tensorflow/contrib/boosted_trees/proto/learner.proto index 919e7cd81427c27cf892bc77998f52406d2bcf15..d84ba7438e7f03685d5bafca52ff8283f0fce898 100644 --- a/tensorflow/contrib/boosted_trees/proto/learner.proto +++ b/tensorflow/contrib/boosted_trees/proto/learner.proto @@ -22,6 +22,10 @@ message TreeConstraintsConfig { // Min hessian weight per node. float min_node_weight = 2; + + // Maximum number of unique features used in the tree. Zero means there is no + // limit. + int64 max_number_of_unique_feature_columns = 3; } // LearningRateConfig describes all supported learning rate tuners. 
diff --git a/tensorflow/contrib/boosted_trees/proto/tree_config.proto b/tensorflow/contrib/boosted_trees/proto/tree_config.proto index fc570c1083d01a65760a456c109dad93afd9f62a..4407c4d981785a279b6296f4726a221cacb4c5b1 100644 --- a/tensorflow/contrib/boosted_trees/proto/tree_config.proto +++ b/tensorflow/contrib/boosted_trees/proto/tree_config.proto @@ -128,6 +128,10 @@ message GrowingMetadata { // Number of layers that we have attempted to build. After pruning, these // layers might have been removed. int64 num_layers_attempted = 2; + + // Sorted list of column handlers that have been used in at least one split + // so far. + repeated int64 used_handler_ids = 3; } // DecisionTreeEnsembleConfig describes an ensemble of decision trees. diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py index 888d5c57ed33446c8b6f18d2d1e393647613d132..81f58de28cbe98bb996c6665114eeb0030ee52f9 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/quantile_ops_test.py @@ -106,9 +106,11 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase): | 6 | 16 | [16, 17, 18, 19, 20, 21] """ + num_quantiles = 3 with self.test_session() as sess: accumulator = quantile_ops.QuantileAccumulator( - init_stamp_token=0, num_quantiles=3, epsilon=0.001, name="q1") + init_stamp_token=0, num_quantiles=num_quantiles, + epsilon=0.001, name="q1") resources.initialize_resources(resources.shared_resources()).run() input_column = array_ops.placeholder(dtypes.float32) weights = array_ops.placeholder(dtypes.float32) @@ -131,8 +133,128 @@ class QuantileBucketsOpTest(test_util.TensorFlowTestCase): buckets, are_ready_flush = (sess.run( [buckets, are_ready_flush])) self.assertEqual(True, are_ready_flush) + self.assertEqual(num_quantiles + 1, len(buckets)) self.assertAllEqual([1, 86., 170., 253.], buckets) + def 
testStreamingQuantileBucketsLowPrecisionInput(self): + """Tests inputs that simulate low precision float16 values.""" + + num_quantiles = 3 + # set generate_quantiles to True since the test will generate fewer + # boundaries otherwise. + with self.test_session() as sess: + accumulator = quantile_ops.QuantileAccumulator( + init_stamp_token=0, num_quantiles=num_quantiles, + epsilon=0.001, name="q1", generate_quantiles=True) + resources.initialize_resources(resources.shared_resources()).run() + input_column = array_ops.placeholder(dtypes.float32) + weights = array_ops.placeholder(dtypes.float32) + update = accumulator.add_summary( + stamp_token=0, + column=input_column, + example_weights=weights) + + with self.test_session() as sess: + # This input is generated from integers in the range [2030, 2060] + # but represented with float16 precision. Integers <= 2048 are + # exactly represented, whereas numbers > 2048 are rounded; and hence + # numbers > 2048 are repeated. For precision loss / rounding, see: + # https://en.wikipedia.org/wiki/Half-precision_floating-point_format. + # + # The intent of the test is not handling of float16 values, but to + # validate the number of buckets returned, in cases where the input + # may contain repeated values. 
+ inputs = [ + 2030.0, 2031.0, 2032.0, 2033.0, 2034.0, 2035.0, 2036.0, 2037.0, + 2038.0, 2039.0, 2040.0, 2041.0, 2042.0, 2043.0, 2044.0, 2045.0, + 2046.0, 2047.0, 2048.0, 2048.0, 2050.0, 2052.0, 2052.0, 2052.0, + 2054.0, 2056.0, 2056.0, 2056.0, 2058.0, 2060.0 + ] + sess.run(update, + {input_column: inputs, + weights: [1] * len(inputs)}) + + with self.test_session() as sess: + sess.run(accumulator.flush(stamp_token=0, next_stamp_token=1)) + are_ready_flush, buckets = (accumulator.get_buckets(stamp_token=1)) + buckets, are_ready_flush = (sess.run( + [buckets, are_ready_flush])) + self.assertEqual(True, are_ready_flush) + self.assertEqual(num_quantiles + 1, len(buckets)) + self.assertAllEqual([2030, 2040, 2050, 2060], buckets) + + def _testStreamingQuantileBucketsHelper( + self, inputs, num_quantiles=3, expected_buckets=None): + """Helper to test quantile buckets on different inputs.""" + + # set generate_quantiles to True since the test will generate fewer + # boundaries otherwise. + with self.test_session() as sess: + accumulator = quantile_ops.QuantileAccumulator( + init_stamp_token=0, num_quantiles=num_quantiles, + epsilon=0.001, name="q1", generate_quantiles=True) + resources.initialize_resources(resources.shared_resources()).run() + input_column = array_ops.placeholder(dtypes.float32) + weights = array_ops.placeholder(dtypes.float32) + update = accumulator.add_summary( + stamp_token=0, + column=input_column, + example_weights=weights) + + with self.test_session() as sess: + sess.run(update, + {input_column: inputs, + weights: [1] * len(inputs)}) + + with self.test_session() as sess: + sess.run(accumulator.flush(stamp_token=0, next_stamp_token=1)) + are_ready_flush, buckets = (accumulator.get_buckets(stamp_token=1)) + buckets, are_ready_flush = (sess.run( + [buckets, are_ready_flush])) + self.assertEqual(True, are_ready_flush) + # By default, use 3 quantiles, 4 boundaries for simplicity. 
+ self.assertEqual(num_quantiles + 1, len(buckets)) + if expected_buckets: + self.assertAllEqual(buckets, expected_buckets) + + def testStreamingQuantileBucketsRepeatedSingleValue(self): + inputs = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1] + self._testStreamingQuantileBucketsHelper(inputs) + + def testStreamingQ2antileBucketsRepeatedTwoValues(self): + inputs = [1, 1, 1, 2, 2, 2, 2, 2, 1, 1] + self._testStreamingQuantileBucketsHelper(inputs) + + def testStreamingQ2antileBucketsRepeatedTwoValuesUnbalanced(self): + inputs = [7, 7, 7, 2, 7, 7, 2, 2, 7, 7] + self._testStreamingQuantileBucketsHelper(inputs) + + def testStreamingQuantileBucketsFewerInputstThanBuckets(self): + inputs = [5] + self._testStreamingQuantileBucketsHelper(inputs) + + def testStreamingQuantileBucketsEqualDistributionInSequence(self): + # Input pattern is of the form [1, 1, 1, 2, 2, 2, 3, 3, 3, ...] + ones = 100 * [1] + inputs = [] + for i in range(1, 101): + inputs += [i * k for k in ones] + # Expect 100 equally spaced buckets. + expected_buckets = range(1, 101) + self._testStreamingQuantileBucketsHelper( + inputs, num_quantiles=99, expected_buckets=expected_buckets) + + def testStreamingQuantileBucketsEqualDistributionInterleaved(self): + # Input pattern is of the form [1, 2, 3, 1, 2, 3, 1, 2, 3, ...] + sequence = range(1, 101) + inputs = [] + for _ in range(1, 101): + inputs += sequence + # Expect 100 equally spaced buckets. + expected_buckets = range(1, 101) + self._testStreamingQuantileBucketsHelper( + inputs, num_quantiles=99, expected_buckets=expected_buckets) + def testStreamingQuantileBuckets(self): """Sets up the quantile summary op test as follows. 
diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py index c2e65b643df90e88aadb0bb9acaf692da35b1a16..8ca1aabacaf53b66aaba184962922294427d6803 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/training_ops_test.py @@ -63,7 +63,7 @@ def _gen_learner_config(num_classes, if dropout_prob_of_skipping is not None: config.learning_rate_tuner.dropout.dropout_prob_of_skipping = ( dropout_prob_of_skipping) - return config.SerializeToString() + return config def _gen_dense_split_info(fc, threshold, left_weight, right_weight): @@ -145,7 +145,7 @@ class CenterTreeEnsembleBiasOpTest(test_util.TensorFlowTestCase): pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE, # Dropout does not change anything here. - dropout_probability=0.5) + dropout_probability=0.5).SerializeToString() # Center bias for the initial step. grads = constant_op.constant([0.4, -0.3]) @@ -296,7 +296,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE, # Dropout does not change anything here, tree is not finalized. - dropout_probability=0.5) + dropout_probability=0.5).SerializeToString() # Prepare handler inputs. # Note that handlers 1 & 3 have the same gain but different splits. @@ -443,7 +443,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE, # Dropout does not change anything here - tree is not finalized. - dropout_probability=0.5) + dropout_probability=0.5).SerializeToString() # Prepare handler inputs. 
# Handler 1 only has a candidate for partition 1, handler 2 has candidates @@ -632,7 +632,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): max_depth=1, min_node_weight=0, pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, - growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) + growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString( + ) # Prepare handler inputs. handler1_partitions = np.array([0], dtype=np.int32) @@ -772,7 +773,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): max_depth=1, min_node_weight=0, pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, - growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) + growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString( + ) # Prepare handler inputs. # All handlers have negative gain. @@ -837,7 +839,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): max_depth=1, min_node_weight=0, pruning_mode=learner_pb2.LearnerConfig.POST_PRUNE, - growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) + growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString( + ) # Prepare handler inputs. # Note that handlers 1 & 3 have the same gain but different splits. @@ -943,7 +946,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): max_depth=2, min_node_weight=0, pruning_mode=learner_pb2.LearnerConfig.POST_PRUNE, - growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) + growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString( + ) # Prepare handler inputs. # All handlers have negative gain. @@ -1090,7 +1094,8 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): max_depth=2, min_node_weight=0, pruning_mode=learner_pb2.LearnerConfig.POST_PRUNE, - growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) + growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE).SerializeToString( + ) # Prepare handler inputs. # Second handler has positive gain. 
@@ -1330,7 +1335,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, growing_mode=learner_pb2.LearnerConfig.LAYER_BY_LAYER, # Dropout will have no effect, since the tree will not be fully grown. - dropout_probability=1.0) + dropout_probability=1.0).SerializeToString() # Prepare handler inputs. # Handler 1 only has a candidate for partition 1, handler 2 has candidates @@ -1538,7 +1543,7 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): min_node_weight=0, pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE, - dropout_probability=1.0) + dropout_probability=1.0).SerializeToString() # Prepare handler inputs. handler1_partitions = np.array([0], dtype=np.int32) @@ -1583,6 +1588,301 @@ class GrowTreeEnsembleOpTest(test_util.TensorFlowTestCase): self.assertEqual( 2, tree_ensemble_config.tree_metadata[2].num_tree_weight_updates) + def testGrowExistingEnsembleTreeWithFeatureSelectionCanStillGrow(self): + """Test growing a tree with feature selection.""" + with self.test_session() as session: + # Create existing ensemble with one root split and one bias tree. 
+ tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() + text_format.Merge(""" + trees { + nodes { + leaf { + vector { + value: -0.32 + value: 0.28 + } + } + } + } + trees { + nodes { + categorical_id_binary_split { + feature_column: 3 + feature_id: 7 + left_id: 1 + right_id: 2 + } + node_metadata { + gain: 1.3 + } + } + nodes { + leaf { + sparse_vector { + index: 0 + value: 2.3 + } + } + } + nodes { + leaf { + sparse_vector { + index: 0 + value: -0.9 + } + } + } + } + tree_weights: 0.7 + tree_weights: 1 + tree_metadata { + num_tree_weight_updates: 1 + num_layers_grown: 1 + is_finalized: true + } + tree_metadata { + num_tree_weight_updates: 5 + num_layers_grown: 1 + is_finalized: true + } + growing_metadata { + num_trees_attempted: 2 + num_layers_attempted: 2 + used_handler_ids: 2 + used_handler_ids: 5 + } + """, tree_ensemble_config) + tree_ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, + tree_ensemble_config=tree_ensemble_config.SerializeToString(), + name="tree_ensemble") + resources.initialize_resources(resources.shared_resources()).run() + + # Prepare learner config. + learner_config = _gen_learner_config( + num_classes=2, + l1_reg=0, + l2_reg=0, + tree_complexity=0, + max_depth=1, + min_node_weight=0, + pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, + growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) + # There are 2 handler_ids in used_handler_ids already but one of them + # is handler 2, so we can still grow trees. + learner_config.constraints.max_number_of_unique_feature_columns = 2 + learner_config = learner_config.SerializeToString() + # Prepare handler inputs. 
+ handler1_partitions = np.array([0], dtype=np.int32) + handler1_gains = np.array([7.62], dtype=np.float32) + handler1_split = [_gen_dense_split_info(5, 0.52, -4.375, 7.143)] + handler2_partitions = np.array([0], dtype=np.int32) + handler2_gains = np.array([0.63], dtype=np.float32) + handler2_split = [_gen_dense_split_info(2, 0.23, -0.6, 0.24)] + handler3_partitions = np.array([0], dtype=np.int32) + handler3_gains = np.array([7.62], dtype=np.float32) + handler3_split = [_gen_categorical_split_info(8, 7, -4.375, 7.143)] + + # Grow tree ensemble. + grow_op = training_ops.grow_tree_ensemble( + tree_ensemble_handle, + stamp_token=0, + next_stamp_token=1, + learning_rate=1, + partition_ids=[ + handler1_partitions, handler2_partitions, handler3_partitions + ], + gains=[handler1_gains, handler2_gains, handler3_gains], + splits=[handler1_split, handler2_split, handler3_split], + learner_config=learner_config, + dropout_seed=123, + center_bias=True) + session.run(grow_op) + + # Expect a new tree to be added with the split from handler 1. + _, serialized = session.run( + model_ops.tree_ensemble_serialize(tree_ensemble_handle)) + tree_ensemble_config.ParseFromString(serialized) + self.assertEqual(3, len(tree_ensemble_config.trees)) + self.assertEqual( + 2, len(tree_ensemble_config.growing_metadata.used_handler_ids)) + + def testGrowExistingEnsembleTreeWithFeatureSelectionEmptyEnsemble(self): + """Test growing a tree with feature selection with empty ensemble.""" + with self.test_session() as session: + # Create an empty ensemble (no trees and no used handler ids yet). + tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() + tree_ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, + tree_ensemble_config=tree_ensemble_config.SerializeToString(), + name="tree_ensemble") + resources.initialize_resources(resources.shared_resources()).run() + + # Prepare learner config. 
+ learner_config = _gen_learner_config( + num_classes=2, + l1_reg=0, + l2_reg=0, + tree_complexity=0, + max_depth=1, + min_node_weight=0, + pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, + growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) + learner_config.constraints.max_number_of_unique_feature_columns = 2 + learner_config = learner_config.SerializeToString() + # Prepare handler inputs. + handler1_partitions = np.array([0], dtype=np.int32) + handler1_gains = np.array([7.62], dtype=np.float32) + handler1_split = [_gen_dense_split_info(5, 0.52, -4.375, 7.143)] + handler2_partitions = np.array([0], dtype=np.int32) + handler2_gains = np.array([0.63], dtype=np.float32) + handler2_split = [_gen_dense_split_info(2, 0.23, -0.6, 0.24)] + handler3_partitions = np.array([0], dtype=np.int32) + handler3_gains = np.array([7.62], dtype=np.float32) + handler3_split = [_gen_categorical_split_info(8, 7, -4.375, 7.143)] + + # Grow tree ensemble. + grow_op = training_ops.grow_tree_ensemble( + tree_ensemble_handle, + stamp_token=0, + next_stamp_token=1, + learning_rate=1, + partition_ids=[ + handler1_partitions, handler2_partitions, handler3_partitions + ], + gains=[handler1_gains, handler2_gains, handler3_gains], + splits=[handler1_split, handler2_split, handler3_split], + learner_config=learner_config, + dropout_seed=123, + center_bias=True) + session.run(grow_op) + + _, serialized = session.run( + model_ops.tree_ensemble_serialize(tree_ensemble_handle)) + tree_ensemble_config.ParseFromString(serialized) + self.assertEqual(1, len(tree_ensemble_config.trees)) + self.assertEqual( + 1, len(tree_ensemble_config.growing_metadata.used_handler_ids)) + + def testGrowExistingEnsembleTreeWithFeatureSelectionCantGrow(self): + """Test that feature selection blocks growth once the limit is reached.""" + with self.test_session() as session: + # Create existing ensemble with one root split and one bias tree. 
+ tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() + text_format.Merge(""" + trees { + nodes { + leaf { + vector { + value: -0.32 + value: 0.28 + } + } + } + } + trees { + nodes { + categorical_id_binary_split { + feature_column: 3 + feature_id: 7 + left_id: 1 + right_id: 2 + } + node_metadata { + gain: 1.3 + } + } + nodes { + leaf { + sparse_vector { + index: 0 + value: 2.3 + } + } + } + nodes { + leaf { + sparse_vector { + index: 0 + value: -0.9 + } + } + } + } + tree_weights: 0.7 + tree_weights: 1 + tree_metadata { + num_tree_weight_updates: 1 + num_layers_grown: 1 + is_finalized: true + } + tree_metadata { + num_tree_weight_updates: 5 + num_layers_grown: 1 + is_finalized: true + } + growing_metadata { + num_trees_attempted: 2 + num_layers_attempted: 2 + used_handler_ids: 4 + used_handler_ids: 5 + } + """, tree_ensemble_config) + tree_ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, + tree_ensemble_config=tree_ensemble_config.SerializeToString(), + name="tree_ensemble") + resources.initialize_resources(resources.shared_resources()).run() + + # Prepare learner config. + learner_config = _gen_learner_config( + num_classes=2, + l1_reg=0, + l2_reg=0, + tree_complexity=0, + max_depth=1, + min_node_weight=0, + pruning_mode=learner_pb2.LearnerConfig.PRE_PRUNE, + growing_mode=learner_pb2.LearnerConfig.WHOLE_TREE) + learner_config.constraints.max_number_of_unique_feature_columns = 2 + learner_config = learner_config.SerializeToString() + # Prepare handler inputs. 
+ handler1_partitions = np.array([0], dtype=np.int32) + handler1_gains = np.array([7.62], dtype=np.float32) + handler1_split = [_gen_dense_split_info(5, 0.52, -4.375, 7.143)] + handler2_partitions = np.array([0], dtype=np.int32) + handler2_gains = np.array([0.63], dtype=np.float32) + handler2_split = [_gen_dense_split_info(2, 0.23, -0.6, 0.24)] + handler3_partitions = np.array([0], dtype=np.int32) + handler3_gains = np.array([7.62], dtype=np.float32) + handler3_split = [_gen_categorical_split_info(8, 7, -4.375, 7.143)] + + # Grow tree ensemble. + grow_op = training_ops.grow_tree_ensemble( + tree_ensemble_handle, + stamp_token=0, + next_stamp_token=1, + learning_rate=1, + partition_ids=[ + handler1_partitions, handler2_partitions, handler3_partitions + ], + gains=[handler1_gains, handler2_gains, handler3_gains], + splits=[handler1_split, handler2_split, handler3_split], + learner_config=learner_config, + dropout_seed=123, + center_bias=True) + session.run(grow_op) + + _, serialized = session.run( + model_ops.tree_ensemble_serialize(tree_ensemble_handle)) + tree_ensemble_config.ParseFromString(serialized) + # We can't grow a tree since we have reached the limit of 2 unique + # features [4, 5] and the only available splits are from + # handlers [0, 1, 2]. 
+ self.assertEqual(2, len(tree_ensemble_config.trees)) + self.assertEqual( + 2, len(tree_ensemble_config.growing_metadata.used_handler_ids)) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py b/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py index 23168bf4935e92bcb5072348361ae04861641b6d..7a5f329b7ab3216972180ccbb4c85f2537175422 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py +++ b/tensorflow/contrib/boosted_trees/python/ops/batch_ops_utils.py @@ -81,32 +81,32 @@ def _scheduled_stamp_resource_op_runner(batch, stamp): if not batch: return arg_keys = set(batch[0].args.keys()) - grouped_args = collections.defaultdict(list) + grouped_args = collections.OrderedDict() resource_handles = [] # Check that the set of arguments is the same across all the scheduled ops. for op in batch: if set(op.args.keys()) != arg_keys: raise ValueError("Mismatching arguments: %s, %s.", op.args, arg_keys) for key in arg_keys: - grouped_args[key].append(op.args[key]) + grouped_args.setdefault(key, []).append(op.args[key]) resource_handles.append(op.resource_handle) # Move all the inputs to the op device in one RPC. - grouped_args = { - k: _move_tensors(v, resource_handles[0].device) - for k, v in grouped_args.items() - } + grouped_args = collections.OrderedDict( + (k, _move_tensors(v, resource_handles[0].device)) + for k, v in sorted(grouped_args.items())) with ops.device(resource_handles[0].device): return batch[0].op(resource_handles, stamp, **grouped_args) def run_handler_scheduled_ops(per_handler_ops, stamp, worker_device): """Given a dictionary of ops for each handler, runs them in batch.""" - batched_ops = collections.defaultdict(list) + batched_ops = collections.OrderedDict() # Group the ops by their batching_key. Ops that share the same batching key # can be executed together. 
for handler in per_handler_ops.keys(): for op in per_handler_ops[handler]: - batched_ops[(op.batching_key(), op.batch_runner_fn())].append(op) + key = (op.batching_key(), op.batch_runner_fn()) + batched_ops.setdefault(key, []).append(op) op_results = {} for batch in batched_ops.values(): # Run each of the batched ops using its runner. diff --git a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py index 7e8e15e7d8c89d1adaa472b1da7e8bb3c73ca17e..97d57e8b23608d4c3a8719426a75056fc6417d1d 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py @@ -45,18 +45,24 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): init_stamp_token, epsilon, num_quantiles, + max_elements=None, name=None, - container=None): + container=None, + generate_quantiles=False): """Creates a QuantileAccumulator object. Args: init_stamp_token: The initial value for the stamp token. epsilon: Error bound on the quantile computation. num_quantiles: Number of quantiles to produce from the final summary. + max_elements: Maximum number of elements added to the accumulator. name: the name to save the accumulator under. container: An optional `string`. Defaults to `""` + generate_quantiles: Generate quantiles instead of approximate boundaries. + If true, exactly `num_quantiles` will be produced in the final summary. 
""" self._epsilon = epsilon + self._generate_quantiles = generate_quantiles name = _PATTERN.sub("", name) with ops.name_scope(name, "QuantileAccumulator") as name: @@ -67,7 +73,9 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): self._quantile_accumulator_handle, init_stamp_token, epsilon=epsilon, - num_quantiles=num_quantiles) + max_elements=max_elements, + num_quantiles=num_quantiles, + generate_quantiles=generate_quantiles) is_initialized_op = gen_quantile_ops.quantile_accumulator_is_initialized( self._quantile_accumulator_handle) resources.register_resource(self._quantile_accumulator_handle, @@ -173,7 +181,14 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject): summaries=summary) def flush(self, stamp_token, next_stamp_token): - """Finalizes quantile summary stream and resets it for next iteration.""" + """Finalizes quantile summary stream and resets it for next iteration. + + Args: + stamp_token: Expected current token. + next_stamp_token: Next value for the token. + Returns: + A list of quantiles or approximate boundaries. + """ return gen_quantile_ops.quantile_accumulator_flush( quantile_accumulator_handle=self._quantile_accumulator_handle, stamp_token=stamp_token, diff --git a/tensorflow/contrib/boosted_trees/python/training/__init__.py b/tensorflow/contrib/boosted_trees/python/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b569ac5fdb60e0907c322ad73aca65645e548d94 --- /dev/null +++ b/tensorflow/contrib/boosted_trees/python/training/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""training module under boosted_trees.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/__init__.py b/tensorflow/contrib/boosted_trees/python/training/functions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c1750117cd7c311515b4bca6882d55f496daac0e --- /dev/null +++ b/tensorflow/contrib/boosted_trees/python/training/functions/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""functions module under boosted_trees.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index 6094dae6b59d8b05bb12a28cf167a536e6825287..f0b66dcbbe1c5167b9993e66b30b1dc8a839c380 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import copy from tensorflow.contrib import learn @@ -163,7 +164,7 @@ def extract_features(features, feature_columns): scope = "gbdt" with variable_scope.variable_scope(scope): feature_columns = list(feature_columns) - transformed_features = {} + transformed_features = collections.OrderedDict() for fc in feature_columns: # pylint: disable=protected-access if isinstance(fc, feature_column_lib._EmbeddingColumn): @@ -322,9 +323,11 @@ class GradientBoostedDecisionTreeModel(object): self._feature_columns = feature_columns self._learner_config_serialized = learner_config.SerializeToString() self._attempted_trees = variables.Variable( - initial_value=array_ops.zeros([], dtypes.int64), trainable=False) + initial_value=array_ops.zeros([], dtypes.int64), trainable=False, + name="attempted_trees") self._finalized_trees = variables.Variable( - initial_value=array_ops.zeros([], dtypes.int64), trainable=False) + initial_value=array_ops.zeros([], dtypes.int64), trainable=False, + name="finalized_trees") if not features: raise ValueError("Features dictionary must be specified.") (fc_names, dense_floats, sparse_float_indices, sparse_float_values, @@ -679,13 +682,13 @@ class 
GradientBoostedDecisionTreeModel(object): control_flow_ops.no_op)) # Update handler stats. - handler_reads = {} + handler_reads = collections.OrderedDict() for handler in handlers: handler_reads[handler] = handler.scheduled_reads() handler_results = batch_ops_utils.run_handler_scheduled_ops( handler_reads, ensemble_stamp, worker_device) - per_handler_updates = {} + per_handler_updates = collections.OrderedDict() # Two values per handler. First one is if the handler is active for the # current layer. The second one is if the handler is going to be active # for the next layer. diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py index 16e24d97ddee0751e0b808b89080074c1b4baba7..dba51d4f527792d2a8dedc693f74c07119fd231d 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py @@ -912,8 +912,10 @@ class GbdtTest(test_util.TensorFlowTestCase): self.assertEqual(1, len(output.trees[0].nodes[2].leaf.sparse_vector.index)) self.assertEqual(3, output.trees[0].nodes[2].leaf.sparse_vector.index[0]) - self.assertAlmostEqual( - 0.893284678459, output.trees[0].nodes[2].leaf.sparse_vector.value[0]) + self.assertAllClose( + 0.893284678459, + output.trees[0].nodes[2].leaf.sparse_vector.value[0], + atol=1e-4, rtol=1e-4) if __name__ == "__main__": diff --git a/tensorflow/contrib/boosted_trees/python/utils/__init__.py b/tensorflow/contrib/boosted_trees/python/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6ceb150c26552584d631948f5eef2fedfa690894 --- /dev/null +++ b/tensorflow/contrib/boosted_trees/python/utils/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""utils module under boosted_trees.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function diff --git a/tensorflow/contrib/boosted_trees/python/utils/losses.py b/tensorflow/contrib/boosted_trees/python/utils/losses.py index 1e8b3ac08a74a94a0e5729e42ace91398a7b5c94..ab7ac2aba605db22a8ed370049b27d55cf1d413a 100644 --- a/tensorflow/contrib/boosted_trees/python/utils/losses.py +++ b/tensorflow/contrib/boosted_trees/python/utils/losses.py @@ -78,7 +78,7 @@ def per_example_maxent_loss(labels, weights, logits, num_classes, eps=1e-15): # Calculate softmax probabilities for each class. unnormalized_probs = math_ops.exp(logits) - normalizers = math_ops.reduce_sum(unnormalized_probs, 1, keep_dims=True) + normalizers = math_ops.reduce_sum(unnormalized_probs, 1, keepdims=True) softmax_predictions = math_ops.divide(unnormalized_probs, math_ops.add(normalizers, eps)) @@ -120,7 +120,7 @@ def per_example_squared_loss(labels, weights, predictions): update_op: An update operation to update the loss's internal state. 
""" unweighted_loss = math_ops.reduce_sum( - math_ops.square(predictions - labels), 1, keep_dims=True) + math_ops.square(predictions - labels), 1, keepdims=True) return unweighted_loss * weights, control_flow_ops.no_op() diff --git a/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h b/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h index 284ad5cdb9abf374650940ade7bb36663d72c0dd..3ebf28ea442edf87815c39971ae9e01a2a8aae9a 100644 --- a/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h +++ b/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_ #include "tensorflow/contrib/boosted_trees/lib/trees/decision_tree.h" #include "tensorflow/contrib/boosted_trees/resources/stamped_resource.h" @@ -111,6 +111,35 @@ class DecisionTreeEnsembleResource : public StampedResource { return decision_tree_ensemble_->tree_weights(index); } + void MaybeAddUsedHandler(const int32 handler_id) { + protobuf::RepeatedField* used_ids = + decision_tree_ensemble_->mutable_growing_metadata() + ->mutable_used_handler_ids(); + protobuf::RepeatedField::iterator first = + std::lower_bound(used_ids->begin(), used_ids->end(), handler_id); + if (first == used_ids->end()) { + used_ids->Add(handler_id); + return; + } + if (handler_id == *first) { + // It is a duplicate entry. 
+ return; + } + used_ids->Add(handler_id); + std::rotate(first, used_ids->end() - 1, used_ids->end()); + } + + std::vector GetUsedHandlers() const { + std::vector result; + result.reserve( + decision_tree_ensemble_->growing_metadata().used_handler_ids().size()); + for (int64 h : + decision_tree_ensemble_->growing_metadata().used_handler_ids()) { + result.push_back(h); + } + return result; + } + // Sets the weight of i'th tree, and increment num_updates in tree_metadata. void SetTreeWeight(const int32 index, const float weight, const int32 increment_num_updates) { @@ -150,4 +179,4 @@ class DecisionTreeEnsembleResource : public StampedResource { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_DECISION_TREE_ENSEMBLE_RESOURCE_H_ diff --git a/tensorflow/contrib/boosted_trees/resources/quantile_stream_resource.h b/tensorflow/contrib/boosted_trees/resources/quantile_stream_resource.h index fb29f79e578e8e52b67de631c527be35b7772b41..fdaaae7f472c8f564ab45a8366d3746cbf1158ee 100644 --- a/tensorflow/contrib/boosted_trees/resources/quantile_stream_resource.h +++ b/tensorflow/contrib/boosted_trees/resources/quantile_stream_resource.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. 
// ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_QUANTILE_STREAM_RESOURCE_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_QUANTILE_STREAM_RESOURCE_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_QUANTILE_STREAM_RESOURCE_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_QUANTILE_STREAM_RESOURCE_H_ #include "tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_stream.h" #include "tensorflow/contrib/boosted_trees/proto/quantiles.pb.h" // NOLINT @@ -32,12 +32,14 @@ using QuantileStream = class QuantileStreamResource : public StampedResource { public: QuantileStreamResource(const float epsilon, const int32 num_quantiles, - const int64 max_elements, int64 stamp_token) + const int64 max_elements, bool generate_quantiles, + int64 stamp_token) : stream_(epsilon, max_elements), are_buckets_ready_(false), epsilon_(epsilon), num_quantiles_(num_quantiles), - max_elements_(max_elements) { + max_elements_(max_elements), + generate_quantiles_(generate_quantiles) { set_stamp(stamp_token); } @@ -74,6 +76,11 @@ class QuantileStreamResource : public StampedResource { are_buckets_ready_ = are_buckets_ready; } + bool generate_quantiles() const { return generate_quantiles_; } + void set_generate_quantiles(bool generate_quantiles) { + generate_quantiles_ = generate_quantiles; + } + private: ~QuantileStreamResource() override {} @@ -95,10 +102,15 @@ class QuantileStreamResource : public StampedResource { const int32 num_quantiles_; // An upper-bound for the number of elements. int64 max_elements_; + + // Generate quantiles instead of approximate boundaries. + // If true, exactly `num_quantiles` will be produced in the final summary. 
+ bool generate_quantiles_; + TF_DISALLOW_COPY_AND_ASSIGN(QuantileStreamResource); }; } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_QUANTILE_STREAM_RESOURCE_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_QUANTILE_STREAM_RESOURCE_H_ diff --git a/tensorflow/contrib/boosted_trees/resources/stamped_resource.h b/tensorflow/contrib/boosted_trees/resources/stamped_resource.h index aabeeb98516eda6f7e8e7e296d6860fe5d8d5ec3..957bbe8d61d3dd32adba1a7f0cf840c69bce6273 100644 --- a/tensorflow/contrib/boosted_trees/resources/stamped_resource.h +++ b/tensorflow/contrib/boosted_trees/resources/stamped_resource.h @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_STAMPED_RESOURCE_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_STAMPED_RESOURCE_H_ +#ifndef TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_STAMPED_RESOURCE_H_ +#define TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_STAMPED_RESOURCE_H_ #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/platform/mutex.h" @@ -39,4 +39,4 @@ class StampedResource : public ResourceBase { } // namespace boosted_trees } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_STAMPED_RESOURCE_H_ +#endif // TENSORFLOW_CONTRIB_BOOSTED_TREES_RESOURCES_STAMPED_RESOURCE_H_ diff --git a/tensorflow/contrib/cloud/BUILD b/tensorflow/contrib/cloud/BUILD index aa8f5ed12bc6f779e3c1a923b9225ec283189747..fe8bd072afd43a64fa62a65bd8900b5a98dbe761 100644 --- a/tensorflow/contrib/cloud/BUILD +++ b/tensorflow/contrib/cloud/BUILD @@ -60,9 +60,7 @@ tf_py_test( size = "small", srcs = ["python/ops/bigquery_reader_ops_test.py"], additional_deps = [ - ":bigquery_reader_ops_op_lib", ":cloud_py", - 
"//tensorflow/contrib/cloud/kernels:bigquery_reader_ops", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc index 51821f6653550afd2d2e8a49b7337ff8ba0b5489..1bfd27305d569668a0bd67d876e59eec082296b3 100644 --- a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc +++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc @@ -18,7 +18,6 @@ limitations under the License. #include "tensorflow/core/lib/strings/numbers.h" namespace tensorflow { - namespace { constexpr size_t kBufferSize = 1024 * 1024; // In bytes. @@ -40,33 +39,6 @@ Status ParseJson(StringPiece json, Json::Value* result) { return Status::OK(); } -string ColumnTypeToString(BigQueryTableAccessor::ColumnType enum_type) { - switch (enum_type) { - case BigQueryTableAccessor::ColumnType::kRecord: - return "RECORD"; - case BigQueryTableAccessor::ColumnType::kString: - return "STRING"; - case BigQueryTableAccessor::ColumnType::kBytes: - return "BYTES"; - case BigQueryTableAccessor::ColumnType::kInteger: - return "INTEGER"; - case BigQueryTableAccessor::ColumnType::kFloat: - return "FLOAT"; - case BigQueryTableAccessor::ColumnType::kBoolean: - return "BOOLEAN"; - case BigQueryTableAccessor::ColumnType::kTimestamp: - return "TIMESTAMP"; - case BigQueryTableAccessor::ColumnType::kDate: - return "DATE"; - case BigQueryTableAccessor::ColumnType::kTime: - return "TIME"; - case BigQueryTableAccessor::ColumnType::kDatetime: - return "DATETIME"; - case BigQueryTableAccessor::ColumnType::kNone: - return "NONE"; - } -} - Status ParseColumnType(const string& type, BigQueryTableAccessor::ColumnType* enum_type) { if (type == "RECORD") { @@ -202,22 +174,21 @@ Status BigQueryTableAccessor::ReadRow(int64* row_id, Example* example) { std::unique_ptr request(http_request_factory_->Create()); std::vector output_buffer; 
output_buffer.reserve(kBufferSize); - TF_RETURN_IF_ERROR(request->Init()); // The first time that we access BigQuery there is no page token. After that // we use the page token (which returns rows faster). if (!next_page_token_.empty()) { - TF_RETURN_IF_ERROR(request->SetUri(strings::StrCat( + request->SetUri(strings::StrCat( BigQueryUriPrefix(), "data?maxResults=", ComputeMaxResultsArg(), - "&pageToken=", request->EscapeString(next_page_token_)))); + "&pageToken=", request->EscapeString(next_page_token_))); first_buffered_row_index_ += row_buffer_.size(); } else { - TF_RETURN_IF_ERROR(request->SetUri(strings::StrCat( + request->SetUri(strings::StrCat( BigQueryUriPrefix(), "data?maxResults=", ComputeMaxResultsArg(), - "&startIndex=", first_buffered_row_index_))); + "&startIndex=", first_buffered_row_index_)); } - TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token)); - TF_RETURN_IF_ERROR(request->SetResultBuffer(&output_buffer)); + request->AddAuthBearerHeader(auth_token); + request->SetResultBuffer(&output_buffer); TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading rows from ", FullTableName()); @@ -293,10 +264,9 @@ Status BigQueryTableAccessor::ReadSchema() { std::unique_ptr request(http_request_factory_->Create()); std::vector output_buffer; output_buffer.reserve(kBufferSize); - TF_RETURN_IF_ERROR(request->Init()); - TF_RETURN_IF_ERROR(request->SetUri(BigQueryUriPrefix())); - TF_RETURN_IF_ERROR(request->AddAuthBearerHeader(auth_token)); - TF_RETURN_IF_ERROR(request->SetResultBuffer(&output_buffer)); + request->SetUri(BigQueryUriPrefix()); + request->AddAuthBearerHeader(auth_token); + request->SetResultBuffer(&output_buffer); TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading schema for ", FullTableName()); diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h index 7d0eee59ae2f47503c4f8994ef356ce0dc336733..b349063715c903c982cfe2fb116b6525e35ff63b 
100644 --- a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h +++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_PARTITION_ACCESSOR_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_PARTITION_ACCESSOR_H_ +#ifndef TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_PARTITION_ACCESSOR_H_ +#define TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_PARTITION_ACCESSOR_H_ #include #include @@ -205,4 +205,4 @@ class BigQueryTableAccessor { }; } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_PARTITION_ACCESSOR_H_ +#endif // TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_PARTITION_ACCESSOR_H_ diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test_data.h b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test_data.h index b2b11f4f57800d55ebc86273fcda71e673ff143a..fea6b15640ded74432f35112bc5d5d68e641c9dc 100644 --- a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test_data.h +++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor_test_data.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_TABLE_ACCESSOR_TEST_DATA_H_ -#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_TABLE_ACCESSOR_TEST_DATA_H_ +#ifndef TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_TABLE_ACCESSOR_TEST_DATA_H_ +#define TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_TABLE_ACCESSOR_TEST_DATA_H_ #include @@ -399,6 +399,6 @@ const string kTestEmptyRow = R"({ }]}]})"; } // namespace -} // namepsace tensorflow +} // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_TABLE_ACCESSOR_TEST_DATA_H_ +#endif // TENSORFLOW_CORE_KERNELS_CLOUD_BIGQUERY_TABLE_ACCESSOR_TEST_DATA_H_ diff --git a/tensorflow/contrib/cluster_resolver/BUILD b/tensorflow/contrib/cluster_resolver/BUILD index 15abd2be0385eb776ff4f76484133efb6e34f076..80e18a43a71cc9d6c9e2ccf5836e50c6427a30f6 100644 --- a/tensorflow/contrib/cluster_resolver/BUILD +++ b/tensorflow/contrib/cluster_resolver/BUILD @@ -34,6 +34,7 @@ py_library( ":cluster_resolver_py", ":gce_cluster_resolver_py", ":tpu_cluster_resolver_py", + "//tensorflow/python:util", ], ) diff --git a/tensorflow/contrib/cluster_resolver/__init__.py b/tensorflow/contrib/cluster_resolver/__init__.py index d17501e87e79158b1602ac6ddecc091bd86f2c2d..b4d8cd4a7cf42e910e7506dbeec8656a2cef62eb 100644 --- a/tensorflow/contrib/cluster_resolver/__init__.py +++ b/tensorflow/contrib/cluster_resolver/__init__.py @@ -26,3 +26,15 @@ from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import from tensorflow.contrib.cluster_resolver.python.training.gce_cluster_resolver import GceClusterResolver from tensorflow.contrib.cluster_resolver.python.training.tpu_cluster_resolver import TPUClusterResolver # pylint: enable=wildcard-import,unused-import + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + 'ClusterResolver', + 'SimpleClusterResolver', + 'UnionClusterResolver', + 
'GceClusterResolver', + 'TPUClusterResolver', +] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index c74da9cabd6816bc9c7891e32937534cff2d677d..a6a6e642e4e4c721b94821a70d55d6fe931347d6 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -18,6 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function + +from six.moves.urllib.request import Request +from six.moves.urllib.request import urlopen + from tensorflow.contrib.cluster_resolver.python.training.cluster_resolver import ClusterResolver from tensorflow.python.training.server_lib import ClusterSpec @@ -38,10 +42,16 @@ class TPUClusterResolver(ClusterResolver): Cloud Platform project. """ + def _requestComputeMetadata(self, path): + req = Request('http://metadata/computeMetadata/v1/%s' % path, + headers={'Metadata-Flavor': 'Google'}) + resp = urlopen(req) + return resp.read() + def __init__(self, - project, - zone, tpu_names, + zone=None, + project=None, job_name='tpu_worker', credentials='default', service=None): @@ -51,9 +61,13 @@ class TPUClusterResolver(ClusterResolver): for the IP addresses and ports of each Cloud TPU listed. Args: - project: Name of the GCP project containing Cloud TPUs - zone: Zone where the TPUs are located tpu_names: A list of names of the target Cloud TPUs. + zone: Zone where the TPUs are located. If omitted or empty, we will assume + that the zone of the TPU is the same as the zone of the GCE VM, which we + will try to discover from the GCE metadata service. + project: Name of the GCP project containing Cloud TPUs. If omitted or + empty, we will try to discover the project name of the GCE VM from the + GCE metadata service. 
job_name: Name of the TensorFlow job the TPUs belong to. credentials: GCE Credentials. If None, then we use default credentials from the oauth2client @@ -65,6 +79,13 @@ class TPUClusterResolver(ClusterResolver): ImportError: If the googleapiclient is not installed. """ + if not project: + project = self._requestComputeMetadata('/project/project-id') + + if not zone: + zone_path = self._requestComputeMetadata('/instance/zone') + zone = zone_path.split('/')[-1] + self._project = project self._zone = zone self._tpu_names = tpu_names @@ -122,7 +143,8 @@ class TPUClusterResolver(ClusterResolver): request = self._service.projects().locations().nodes().get(name=full_name) response = request.execute() - instance_url = '%s:%s' % (response['ipAddress'], response['port']) - worker_list.append(instance_url) + if 'health' in response and response['health'] == 'HEALTHY': + instance_url = '%s:%s' % (response['ipAddress'], response['port']) + worker_list.append(instance_url) return ClusterSpec({self._job_name: worker_list}) diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py index db7419be06b58e1c5737f69f2c7fd9fee44b9d95..4fd34629cf74f90869c77b8cb098d3c585a49404 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py @@ -48,6 +48,15 @@ class MockNodeClass(object): return MockRequestClass(name, self._tpu_map) +def mock_request_compute_metadata(cls, *args, **kwargs): + del cls, kwargs # Unused. 
+ if args[0] == '/project/project-id': + return 'test-project' + elif args[0] == '/instance/zone': + return 'projects/test-project/locations/us-central1-c' + return '' + + class TPUClusterResolverTest(test.TestCase): def _verifyClusterSpecEquality(self, cluster_spec, expected_proto): @@ -89,11 +98,37 @@ class TPUClusterResolverTest(test.TestCase): return mock_client + @mock.patch.object(TPUClusterResolver, + '_requestComputeMetadata', + mock_request_compute_metadata) + def testRetrieveProjectAndZoneFromMetadata(self): + tpu_map = { + 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { + 'ipAddress': '10.1.2.3', + 'port': '8470', + 'health': 'HEALTHY' + } + } + + tpu_cluster_resolver = TPUClusterResolver( + project=None, + zone=None, + tpu_names=['test-tpu-1'], + credentials=None, + service=self.mock_service_client(tpu_map=tpu_map)) + + actual_cluster_spec = tpu_cluster_resolver.cluster_spec() + expected_proto = """ + job { name: 'tpu_worker' tasks { key: 0 value: '10.1.2.3:8470' } } + """ + self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + def testSimpleSuccessfulRetrieval(self): tpu_map = { 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { 'ipAddress': '10.1.2.3', - 'port': '8470' + 'port': '8470', + 'health': 'HEALTHY' } } @@ -114,11 +149,13 @@ class TPUClusterResolverTest(test.TestCase): tpu_map = { 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { 'ipAddress': '10.1.2.3', - 'port': '8470' + 'port': '8470', + 'health': 'HEALTHY' }, 'projects/test-project/locations/us-central1-c/nodes/test-tpu-2': { 'ipAddress': '10.4.5.6', - 'port': '8470' + 'port': '8470', + 'health': 'HEALTHY' } } @@ -136,15 +173,54 @@ class TPUClusterResolverTest(test.TestCase): """ self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + def testHealthyTpuNodeRetrieval(self): + tpu_map = { + 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { + 'ipAddress': '10.1.2.3', + 'port': '8470', + 
'health': 'HEALTHY' + }, + 'projects/test-project/locations/us-central1-c/nodes/test-tpu-2': { + 'ipAddress': '10.4.5.6', + 'port': '8470', + }, + 'projects/test-project/locations/us-central1-c/nodes/test-tpu-3': { + 'ipAddress': '10.7.8.9', + 'port': '8470', + 'health': 'UNHEALTHY' + } + } + + tpu_cluster_resolver = TPUClusterResolver( + project='test-project', + zone='us-central1-c', + tpu_names=['test-tpu-2', 'test-tpu-1', 'test-tpu-3'], + credentials=None, + service=self.mock_service_client(tpu_map=tpu_map)) + + actual_cluster_spec = tpu_cluster_resolver.cluster_spec() + expected_proto = """ + job { + name: 'tpu_worker' + tasks { + key: 0 + value: '10.1.2.3:8470' + } + } + """ + self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + def testGetMasterMultipleEntries(self): tpu_map = { 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { 'ipAddress': '10.1.2.3', - 'port': '8470' + 'port': '8470', + 'health': 'HEALTHY' }, 'projects/test-project/locations/us-central1-c/nodes/test-tpu-2': { 'ipAddress': '10.4.5.6', - 'port': '8470' + 'port': '8470', + 'health': 'HEALTHY' } } diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt index ba708673b0d562f928230f427406147ab22f0007..16317f538f3890661f1b59ea39fe67dcf04d0d0a 100644 --- a/tensorflow/contrib/cmake/CMakeLists.txt +++ b/tensorflow/contrib/cmake/CMakeLists.txt @@ -18,7 +18,6 @@ cmake_policy(SET CMP0022 NEW) # Options option(tensorflow_VERBOSE "Enable for verbose output" OFF) -option(tensorflow_ENABLE_GPU "Enable GPU support" OFF) option(tensorflow_ENABLE_SSL_SUPPORT "Enable boringssl support" OFF) option(tensorflow_ENABLE_GRPC_SUPPORT "Enable gRPC support" ON) option(tensorflow_ENABLE_HDFS_SUPPORT "Enable HDFS support" OFF) @@ -34,6 +33,13 @@ option(tensorflow_BUILD_SHARED_LIB "Build TensorFlow as a shared library" OFF) option(tensorflow_OPTIMIZE_FOR_NATIVE_ARCH "Enable compiler optimizations for the native processor architecture (if 
available)" ON) option(tensorflow_WIN_CPU_SIMD_OPTIONS "Enables CPU SIMD instructions") option(tensorflow_ENABLE_SNAPPY_SUPPORT "Enable SNAPPY compression support" ON) +option(tensorflow_DISABLE_EIGEN_FORCEINLINE "Disable forceinline, to speed up build on windows." OFF) + +# GPU, CUDA and cuDNN options +option(tensorflow_ENABLE_GPU "Enable GPU support" OFF) +set(tensorflow_CUDA_VERSION "9.0" CACHE STRING "CUDA version to build against") +set(tensorflow_CUDNN_VERSION "7" CACHE STRING "cuDNN version to build against") + if(HAIKU) option(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE "Enable PIE support" OFF) else() @@ -46,6 +52,7 @@ if (NOT WIN32) # for targets that link ${CMAKE_THREAD_LIBS_INIT}. find_package (Threads) + # Options for linking CUDA/CUDNN libraries option(tensorflow_PATH_STATIC_LIB "Additional library search path for libcudnn_static.a, libnccl_static.a, libculibos.a" /usr/local/cuda/lib64/) option(tensorflow_CUDNN_INCLUDE "cudnn.h header install path" /usr/include/) if (NOT tensorflow_CUDNN_INCLUDE) @@ -53,12 +60,28 @@ if (NOT WIN32) set(tensorflow_CUDNN_INCLUDE /usr/include) endif (NOT tensorflow_CUDNN_INCLUDE) option(tensorflow_PATH_CUDNN_STATIC_LIB "Override PATH_STATIC_LIB for libcudnn_static.a" ${tensorflow_PATH_STATIC_LIB}) + if (NOT tensorflow_PATH_CUDNN_STATIC_LIB) + # option's default value is OFF. Fill it with real default values + set (tensorflow_PATH_CUDNN_STATIC_LIB ${tensorflow_PATH_STATIC_LIB}) + endif (NOT tensorflow_PATH_CUDNN_STATIC_LIB) option(tensorflow_PATH_NCCL_STATIC_LIB "Override PATH_STATIC_LIB for libnccl_static.a" ${tensorflow_PATH_STATIC_LIB}) + if (NOT tensorflow_PATH_NCCL_STATIC_LIB) + # option's default value is OFF. 
Fill it with real default values + set (tensorflow_PATH_NCCL_STATIC_LIB ${tensorflow_PATH_STATIC_LIB}) + endif (NOT tensorflow_PATH_NCCL_STATIC_LIB) option(tensorflow_CUDA_LIBRARY_PATH "Designate the default CUDA library paths" /usr/local/cuda/lib64) if (NOT tensorflow_CUDA_LIBRARY_PATH) # option's default value is OFF. Fill it with real default values set(tensorflow_CUDA_LIBRARY_PATH /usr/local/cuda/lib64) endif (NOT tensorflow_CUDA_LIBRARY_PATH) + + # Options for linking other libraries + option(systemlib_ZLIB "Use the system installed library as shared objects instead of downloading ZLIB and statically linking to it: ZLIB" OFF) + + option(systemlib_ALL "Turn on every possible systemlib_* options" OFF) + if (systemlib_ALL) + set (systmelib_ZLIB ON) + endif (systemlib_ALL) endif() if (WIN32) @@ -92,6 +115,13 @@ else() set(CMAKE_POSITION_INDEPENDENT_CODE OFF) endif() +# TODO(jart): We should make this only apply to snapfn.cc +add_definitions(-DSQLITE_OMIT_LOAD_EXTENSION) + +if (tensorflow_DISABLE_EIGEN_FORCEINLINE) + add_definitions(-DEIGEN_STRONG_INLINE=inline) +endif() + add_definitions(-DEIGEN_AVOID_STL_ARRAY) if(WIN32) add_definitions(-DNOMINMAX -D_WIN32_WINNT=0x0A00 -DLANG_CXX11 -DCOMPILER_MSVC) @@ -113,6 +143,9 @@ if(WIN32) set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} /D_ITERATOR_DEBUG_LEVEL=0") set(CMAKE_CXX_FLAGS_MINSIZEREL "${CMAKE_CXX_FLAGS_MINSIZEREL} /D_ITERATOR_DEBUG_LEVEL=0") set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} /D_ITERATOR_DEBUG_LEVEL=0") + + # Try to avoid flaky failures due to failed generation of generate.stamp files. 
+ set(CMAKE_SUPPRESS_REGENERATION ON) endif() if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") @@ -160,13 +193,14 @@ include(protobuf) include(re2) include(cub) include(sqlite) -include(double_conversion) if (tensorflow_BUILD_CC_TESTS) include(googletest) endif() +add_definitions(${ADD_CFLAGS}) +link_directories(${ADD_LINK_DIRECTORY}) + set(tensorflow_EXTERNAL_LIBRARIES - ${zlib_STATIC_LIBRARIES} ${gif_STATIC_LIBRARIES} ${png_STATIC_LIBRARIES} ${jpeg_STATIC_LIBRARIES} @@ -179,8 +213,16 @@ set(tensorflow_EXTERNAL_LIBRARIES ${protobuf_STATIC_LIBRARIES} ${re2_STATIC_LIBRARIES} ${sqlite_STATIC_LIBRARIES} - ${double_conversion_STATIC_LIBRARIES} ) + +if (systemlib_ZLIB) + set(tensorflow_EXTERNAL_LIBRARIES ${tensorflow_EXTERNAL_LIBRARIES} + ${ZLIB_LIBRARIES}) +else (systemlib_ZLIB) + set(tensorflow_EXTERNAL_LIBRARIES ${tensorflow_EXTERNAL_LIBRARIES} + ${zlib_STATIC_LIBRARIES}) +endif (systemlib_ZLIB) + set(tensorflow_EXTERNAL_DEPENDENCIES zlib_copy_headers_to_destination gif_copy_headers_to_destination @@ -198,7 +240,6 @@ set(tensorflow_EXTERNAL_DEPENDENCIES fft2d re2 sqlite_copy_headers_to_destination - double_conversion ) include_directories( @@ -221,7 +262,6 @@ include_directories( ${PROTOBUF_INCLUDE_DIRS} ${re2_INCLUDE_DIR} ${sqlite_INCLUDE_DIR} - ${double_conversion_INCLUDE_DIR} ) if(tensorflow_ENABLE_SSL_SUPPORT) @@ -266,7 +306,21 @@ if (tensorflow_ENABLE_GPU) list(APPEND CMAKE_LIBRARY_PATH "${tensorflow_CUDA_LIBRARY_PATH}/stubs") endif (NOT WIN32) - find_package(CUDA 8.0 REQUIRED) + # later command will make use of the value in tensorflow_CUDA_VERSION + find_package(CUDA ${tensorflow_CUDA_VERSION} REQUIRED EXACT) + + # Test compatibility of compiler on CUDA + try_compile(CUDA_TEST_COMPILE_C + ${CMAKE_CURRENT_BINARY_DIR}/tests/cuda + ${CMAKE_CURRENT_SOURCE_DIR}/tests/cuda/compatibility_test.c + CMAKE_FLAGS -DINCLUDE_DIRECTORIES=${CUDA_INCLUDE_DIRS}) + try_compile(CUDA_TEST_COMPILE_CXX + ${CMAKE_CURRENT_BINARY_DIR}/tests/cuda + 
${CMAKE_CURRENT_SOURCE_DIR}/tests/cuda/compatibility_test.cc + CMAKE_FLAGS -DINCLUDE_DIRECTORIES=${CUDA_INCLUDE_DIRS}) + if(NOT (CUDA_TEST_COMPILE_C AND CUDA_TEST_COMPILE_CXX)) + message(FATAL_ERROR "Selected compiler (or version) is not supported for CUDA") + endif() # by default we assume compute cabability 3.5 and 5.2. If you change this change it in # CUDA_NVCC_FLAGS and cuda_config.h below @@ -320,13 +374,16 @@ if (tensorflow_ENABLE_GPU) ${CUDA_curand_LIBRARY} ${CUDA_cupti_LIBRARY} ${CUDA_cusolver_LIBRARY} ${cudnn_STATIC_LIBRARY} ${culibos_STATIC_LIBRARY} ${nccl_STATIC_LIBRARY}) endif (WIN32) + # Remove "." from CUDA version variable. + string(REPLACE "." "" short_CUDA_VER ${tensorflow_CUDA_VERSION}) + # create cuda_config.h FILE(WRITE ${tensorflow_source_dir}/third_party/gpus/cuda/cuda_config.h "#ifndef CUDA_CUDA_CONFIG_H_\n" "#define CUDA_CUDA_CONFIG_H_\n" "#define TF_CUDA_CAPABILITIES CudaVersion(\"3.0\"),CudaVersion(\"3.5\"),CudaVersion(\"5.2\")\n" - "#define TF_CUDA_VERSION \"64_80\"\n" - "#define TF_CUDNN_VERSION \"64_6\"\n" + "#define TF_CUDA_VERSION \"64_${short_CUDA_VER}\"\n" + "#define TF_CUDNN_VERSION \"64_${tensorflow_CUDNN_VERSION}\"\n" "#define TF_CUDA_TOOLKIT_PATH \"${CUDA_TOOLKIT_ROOT_DIR}\"\n" "#endif // CUDA_CUDA_CONFIG_H_\n" ) @@ -341,6 +398,8 @@ if (tensorflow_ENABLE_GPU) ${CUDA_TOOLKIT_TARGET_DIR}/include/cufft.h ${CUDA_TOOLKIT_TARGET_DIR}/include/curand.h ${CUDA_TOOLKIT_TARGET_DIR}/include/cuda_runtime_api.h ${CUDA_TOOLKIT_TARGET_DIR}/include/cusolverDn.h + ${CUDA_TOOLKIT_TARGET_DIR}/include/cuda_fp16.h + ${CUDA_TOOLKIT_TARGET_DIR}/include/device_functions.h DESTINATION ${tensorflow_source_dir}/third_party/gpus/cuda/include ) else(WIN32) @@ -364,15 +423,15 @@ if (tensorflow_ENABLE_GPU) if(WIN32) set(tensorflow_BUILD_INFO_FLAGS --build_config cuda --key_value msvcp_dll_name=msvcp140.dll - cudart_dll_name=cudart64_80.dll - cuda_version_number=8.0 + cudart_dll_name=cudart64_${short_CUDA_VER}.dll + 
cuda_version_number=${tensorflow_CUDA_VERSION} nvcuda_dll_name=nvcuda.dll - cudnn_dll_name=cudnn64_6.dll - cudnn_version_number=6) + cudnn_dll_name=cudnn64_${tensorflow_CUDNN_VERSION}.dll + cudnn_version_number=${tensorflow_CUDNN_VERSION}) else(WIN32) set(tensorflow_BUILD_INFO_FLAGS --build_config cuda --key_value - cuda_version_number=8.0 - cudnn_version_number=6) + cuda_version_number=${tensorflow_CUDA_VERSION} + cudnn_version_number=${tensorflow_CUDNN_VERSION}) endif(WIN32) else(tensorflow_ENABLE_GPU) set(tensorflow_BUILD_INFO_FLAGS --build_config cpu --key_value @@ -387,10 +446,8 @@ endif() # Let's get to work! include(tf_core_framework.cmake) -# NOTE: Disabled until issue #3996 is fixed. -# include(tf_stream_executor.cmake) if (tensorflow_ENABLE_GPU) - include(tf_stream_executor.cmake) + include(tf_stream_executor.cmake) endif() include(tf_core_cpu.cmake) diff --git a/tensorflow/contrib/cmake/README.md b/tensorflow/contrib/cmake/README.md index 4ddfec5960d2b759bacb376202cd8dab6ef2b024..8f85a75ee466dbac524a1266dc2522109ca77cd5 100644 --- a/tensorflow/contrib/cmake/README.md +++ b/tensorflow/contrib/cmake/README.md @@ -19,23 +19,6 @@ for instructions on how to install a pre-built TensorFlow package on Windows. ### Current known limitations * It is not possible to load a custom Op library. * GCS file system is not supported. -* The following Ops are not currently implemented: - - Dequantize - - QuantizeAndDequantize - - QuantizedAvgPool - - QuantizedBatchNomWithGlobalNormalization - - QuantizedBiasAdd - - QuantizedConcat - - QuantizedConv2D - - QuantizedMatmul - - QuantizedMaxPoo - - QuantizeDownAndShrinkRange - - QuantizedRelu - - QuantizedRelu6 - - QuantizedReshape - - QuantizeV2 - - RequantizationRange - - Requantize ## Building with CMake @@ -47,7 +30,7 @@ bindings. * CMake version 3.5 or later. -* [Git](http://git-scm.com) +* [Git](https://git-scm.com) * [SWIG](http://www.swig.org/download.html) @@ -65,7 +48,7 @@ bindings. 
* Microsoft Windows 10 - Microsoft Visual Studio Enterprise 2015 with Visual C++ 2015 - - [Anaconda 4.1.1 (Python 3.5 64-bit)](https://www.continuum.io/downloads) + - [Anaconda 4.1.1 (Python 3.5 64-bit)](https://www.anaconda.com/download/) - [Git for Windows version 2.9.2.windows.1](https://git-scm.com/download/win) - [swigwin-3.0.10](http://www.swig.org/download.html) - [NVidia CUDA Toolkit 8.0](https://developer.nvidia.com/cuda-downloads) diff --git a/tensorflow/contrib/cmake/external/boringssl.cmake b/tensorflow/contrib/cmake/external/boringssl.cmake index cca8444e2ae9952ea7c69a9392580ead715d363b..3c4bb01e24fd121c9d0fc3594cc25de37af0e8a1 100644 --- a/tensorflow/contrib/cmake/external/boringssl.cmake +++ b/tensorflow/contrib/cmake/external/boringssl.cmake @@ -37,13 +37,10 @@ ExternalProject_Add(boringssl GIT_TAG ${boringssl_TAG} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" # BUILD_IN_SOURCE 1 + BUILD_BYPRODUCTS ${boringssl_STATIC_LIBRARIES} INSTALL_COMMAND "" CMAKE_CACHE_ARGS - if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE) - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON - else() - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF - endif() + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE} -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF ) diff --git a/tensorflow/contrib/cmake/external/double_conversion.cmake b/tensorflow/contrib/cmake/external/double_conversion.cmake deleted file mode 100644 index 527ccdc8d887cb4c2e7d2412c99a8bc682568472..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/cmake/external/double_conversion.cmake +++ /dev/null @@ -1,54 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -include (ExternalProject) - -set(double_conversion_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/double_conversion/src/double_conversion) -set(double_conversion_URL https://github.com/google/double-conversion.git) -set(double_conversion_TAG 5664746) -set(double_conversion_BUILD ${double_conversion_INCLUDE_DIR}) -set(double_conversion_LIBRARIES ${double_conversion_BUILD}/double-conversion/libdouble-conversion.so) -set(double_conversion_INCLUDES ${double_conversion_BUILD}) - -if(WIN32) - set(double_conversion_STATIC_LIBRARIES ${double_conversion_BUILD}/double-conversion/$(Configuration)/double-conversion.lib) -else() - set(double_conversion_STATIC_LIBRARIES ${double_conversion_BUILD}/double-conversion/libdouble-conversion.a) -endif() - -set(double_conversion_HEADERS - "${double_conversion_INCLUDE_DIR}/double-conversion/bignum-dtoa.h" - "${double_conversion_INCLUDE_DIR}/double-conversion/cached-powers.h" - "${double_conversion_INCLUDE_DIR}/double-conversion/double-conversion.h" - "${double_conversion_INCLUDE_DIR}/double-conversion/fixed-dtoa.h" - "${double_conversion_INCLUDE_DIR}/double-conversion/strtod.h" - "${double_conversion_INCLUDE_DIR}/double-conversion/bignum.h" - "${double_conversion_INCLUDE_DIR}/double-conversion/diy-fp.h" - "${double_conversion_INCLUDE_DIR}/double-conversion/fast-dtoa.h" - "${double_conversion_INCLUDE_DIR}/double-conversion/ieee.h" - "${double_conversion_INCLUDE_DIR}/double-conversion/utils.h" -) - -ExternalProject_Add(double_conversion - PREFIX 
double_conversion - GIT_REPOSITORY ${double_conversion_URL} - GIT_TAG ${double_conversion_TAG} - DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" - BUILD_IN_SOURCE 1 - INSTALL_COMMAND "" - CMAKE_CACHE_ARGS - -DCMAKE_BUILD_TYPE:STRING=Release - -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON -) diff --git a/tensorflow/contrib/cmake/external/farmhash.cmake b/tensorflow/contrib/cmake/external/farmhash.cmake index 0cd0c1030c73d5218411f281d2b077af217e8275..d51569bc213f2bd354571a00910714e787120951 100644 --- a/tensorflow/contrib/cmake/external/farmhash.cmake +++ b/tensorflow/contrib/cmake/external/farmhash.cmake @@ -33,6 +33,7 @@ if(WIN32) URL_HASH ${farmhash_HASH} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" BUILD_IN_SOURCE 1 + BUILD_BYPRODUCTS ${farmhash_STATIC_LIBRARIES} PATCH_COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_CURRENT_SOURCE_DIR}/patches/farmhash/CMakeLists.txt ${farmhash_BUILD} INSTALL_DIR ${farmhash_INSTALL} CMAKE_CACHE_ARGS diff --git a/tensorflow/contrib/cmake/external/fft2d.cmake b/tensorflow/contrib/cmake/external/fft2d.cmake index d3af2a46761c0f7f0b5db134af8400fc93f2f095..a7bc50d5bcd4384d5c943d681fd7cd6fa1ffa796 100644 --- a/tensorflow/contrib/cmake/external/fft2d.cmake +++ b/tensorflow/contrib/cmake/external/fft2d.cmake @@ -29,6 +29,7 @@ if(WIN32) URL_HASH ${fft2d_HASH} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" BUILD_IN_SOURCE 1 + BUILD_BYPRODUCTS ${fft2d_STATIC_LIBRARIES} PATCH_COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_CURRENT_SOURCE_DIR}/patches/fft2d/CMakeLists.txt ${fft2d_BUILD}/src/fft2d/CMakeLists.txt INSTALL_DIR ${fft2d_INSTALL} CMAKE_CACHE_ARGS diff --git a/tensorflow/contrib/cmake/external/gemmlowp.cmake b/tensorflow/contrib/cmake/external/gemmlowp.cmake index 3b146657bfc9bdd54db14839195af45972e67aff..a235442dc5c0a07e249653381436eeae81575883 100644 --- a/tensorflow/contrib/cmake/external/gemmlowp.cmake +++ b/tensorflow/contrib/cmake/external/gemmlowp.cmake @@ -14,8 +14,8 @@ # 
============================================================================== include (ExternalProject) -set(gemmlowp_URL https://mirror.bazel.build/github.com/google/gemmlowp/archive/010bb3e71a26ca1d0884a167081d092b43563996.zip) -set(gemmlowp_HASH SHA256=dd2557072bde12141419cb8320a9c25e6ec41a8ae53c2ac78c076a347bb46d9d) +set(gemmlowp_URL https://github.com/google/gemmlowp/archive/6a2a90822e8546fc2bfa7044de0faf1c1cb4862f.zip) +set(gemmlowp_HASH SHA256=3447948d219f3270383766bbe08942888c0eb4e0ca6663c0e0548502ec5bb77d) set(gemmlowp_BUILD ${CMAKE_CURRENT_BINARY_DIR}/gemmlowp/src/gemmlowp) set(gemmlowp_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/gemmlowp/src/gemmlowp) diff --git a/tensorflow/contrib/cmake/external/gif.cmake b/tensorflow/contrib/cmake/external/gif.cmake index 3d53c51fffcec1602a3b5553cdf3b225e3b0ae46..e1f8d13f8ea47b83e4a1840afac7398ef226eb45 100644 --- a/tensorflow/contrib/cmake/external/gif.cmake +++ b/tensorflow/contrib/cmake/external/gif.cmake @@ -33,6 +33,7 @@ if(WIN32) PREFIX gif URL ${gif_URL} URL_HASH ${gif_HASH} + BUILD_BYPRODUCTS ${gif_STATIC_LIBRARIES} PATCH_COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_SOURCE_DIR}/patches/gif/CMakeLists.txt ${gif_BUILD} INSTALL_DIR ${gif_INSTALL} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" diff --git a/tensorflow/contrib/cmake/external/googletest.cmake b/tensorflow/contrib/cmake/external/googletest.cmake index d09bb02890f25a0312e62c876c1729e57a059e82..7cc5ae6390934773635cf7a4dff77a3cbfb41ba1 100644 --- a/tensorflow/contrib/cmake/external/googletest.cmake +++ b/tensorflow/contrib/cmake/external/googletest.cmake @@ -20,8 +20,13 @@ set(googletest_BUILD ${CMAKE_CURRENT_BINARY_DIR}/googletest/) set(googletest_TAG ec44c6c1675c25b9827aacd08c02433cccde7780) if(WIN32) - set(googletest_STATIC_LIBRARIES - ${CMAKE_CURRENT_BINARY_DIR}/googletest/src/googletest/googletest/$(Configuration)/gtest.lib) + if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") + set(googletest_STATIC_LIBRARIES + 
${CMAKE_CURRENT_BINARY_DIR}/googletest/src/googletest/googletest/$(Configuration)/gtest.lib) + else() + set(googletest_STATIC_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/googletest/src/googletest/googletest/gtest.lib) + endif() else() set(googletest_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/googletest/src/googletest/googletest/${CMAKE_BUILD_TYPE}/gtest.a) @@ -33,6 +38,7 @@ ExternalProject_Add(googletest GIT_TAG ${googletest_TAG} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" BUILD_IN_SOURCE 1 + BUILD_BYPRODUCTS ${googletest_STATIC_LIBRARIES} #PATCH_COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_SOURCE_DIR}/patches/grpc/CMakeLists.txt ${GRPC_BUILD} INSTALL_COMMAND "" CMAKE_CACHE_ARGS diff --git a/tensorflow/contrib/cmake/external/grpc.cmake b/tensorflow/contrib/cmake/external/grpc.cmake index 41ea0b48a4600d7ca2dd2f4a61c14ec0cc5b4734..a9f43a3ecba4830533efcc13f8c4c1c61fe1ef78 100644 --- a/tensorflow/contrib/cmake/external/grpc.cmake +++ b/tensorflow/contrib/cmake/external/grpc.cmake @@ -17,13 +17,20 @@ include (ExternalProject) set(GRPC_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/include) set(GRPC_URL https://github.com/grpc/grpc.git) set(GRPC_BUILD ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc) -set(GRPC_TAG 54e8f37e537794c2d814c1604c1282125f64f093) +set(GRPC_TAG 730b778632e79cc3c96ad237f282d687ee325ce7) if(WIN32) - set(grpc_STATIC_LIBRARIES - ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/Release/grpc++_unsecure.lib - ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/Release/grpc_unsecure.lib - ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/Release/gpr.lib) + if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") + set(grpc_STATIC_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/Release/grpc++_unsecure.lib + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/Release/grpc_unsecure.lib + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/Release/gpr.lib) + else() + set(grpc_STATIC_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/grpc++_unsecure.lib + 
${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/grpc_unsecure.lib + ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/gpr.lib) + endif() else() set(grpc_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/grpc/src/grpc/libgrpc++_unsecure.a @@ -40,6 +47,7 @@ ExternalProject_Add(grpc GIT_TAG ${GRPC_TAG} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" BUILD_IN_SOURCE 1 + BUILD_BYPRODUCTS ${grpc_STATIC_LIBRARIES} BUILD_COMMAND ${CMAKE_COMMAND} --build . --config Release --target grpc++_unsecure COMMAND ${CMAKE_COMMAND} --build . --config Release --target grpc_cpp_plugin INSTALL_COMMAND "" diff --git a/tensorflow/contrib/cmake/external/highwayhash.cmake b/tensorflow/contrib/cmake/external/highwayhash.cmake index 2c23bef8a331de356c93dbf9d0e91d8bb13bd6c8..a6e8a38d8c2ee3deb5453c264e0c5eb23248301f 100644 --- a/tensorflow/contrib/cmake/external/highwayhash.cmake +++ b/tensorflow/contrib/cmake/external/highwayhash.cmake @@ -42,6 +42,7 @@ ExternalProject_Add(highwayhash GIT_TAG ${highwayhash_TAG} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" BUILD_IN_SOURCE 1 + BUILD_BYPRODUCTS ${highwayhash_STATIC_LIBRARIES} PATCH_COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_CURRENT_SOURCE_DIR}/patches/highwayhash/CMakeLists.txt ${highwayhash_BUILD} INSTALL_DIR ${highwayhash_INSTALL} CMAKE_CACHE_ARGS diff --git a/tensorflow/contrib/cmake/external/jemalloc.cmake b/tensorflow/contrib/cmake/external/jemalloc.cmake index 198ba13e64e4b6df57c4325a0104b1a6745d173a..afadcc007d66414be3306e91e7186a00b6e587ce 100644 --- a/tensorflow/contrib/cmake/external/jemalloc.cmake +++ b/tensorflow/contrib/cmake/external/jemalloc.cmake @@ -24,8 +24,11 @@ if (WIN32) ${jemalloc_INCLUDE_DIRS} ${CMAKE_CURRENT_BINARY_DIR}/jemalloc/src/jemalloc/include/msvc_compat ) - set(jemalloc_ADDITIONAL_CMAKE_OPTIONS -A x64) - set(jemalloc_STATIC_LIBRARIES ${jemalloc_BUILD}/Release/jemalloc.lib) + if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") + set(jemalloc_STATIC_LIBRARIES ${jemalloc_BUILD}/Release/jemalloc.lib) + else() + set(jemalloc_STATIC_LIBRARIES 
${jemalloc_BUILD}/jemalloc.lib) + endif() else() set(jemalloc_STATIC_LIBRARIES ${jemalloc_BUILD}/Release/jemalloc.a) endif() @@ -36,12 +39,12 @@ ExternalProject_Add(jemalloc URL_HASH ${jemalloc_HASH} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" BUILD_IN_SOURCE 1 - CONFIGURE_COMMAND ${CMAKE_COMMAND} + BUILD_BYPRODUCTS ${jemalloc_STATIC_LIBRARIES} + BUILD_COMMAND ${CMAKE_COMMAND} --build . --config Release --target jemalloc + INSTALL_COMMAND ${CMAKE_COMMAND} -E echo "Skipping install step." + CMAKE_CACHE_ARGS -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF -Dwith-jemalloc-prefix:STRING=jemalloc_ -Dwithout-export:BOOL=ON - ${jemalloc_ADDITIONAL_CMAKE_OPTIONS} - BUILD_COMMAND ${CMAKE_COMMAND} --build . --config Release --target jemalloc - INSTALL_COMMAND ${CMAKE_COMMAND} -E echo "Skipping install step." ) diff --git a/tensorflow/contrib/cmake/external/jpeg.cmake b/tensorflow/contrib/cmake/external/jpeg.cmake index d9a165e856c588880ebdf996666d70c9e7f53da8..c1c5842aa4454f1c95ec284392194a89d47ee8d5 100644 --- a/tensorflow/contrib/cmake/external/jpeg.cmake +++ b/tensorflow/contrib/cmake/external/jpeg.cmake @@ -46,6 +46,7 @@ if (WIN32) PREFIX jpeg URL ${jpeg_URL} URL_HASH ${jpeg_HASH} + BUILD_BYPRODUCTS ${jpeg_STATIC_LIBRARIES} PATCH_COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_CURRENT_SOURCE_DIR}/patches/jpeg/CMakeLists.txt ${jpeg_BUILD} INSTALL_DIR ${jpeg_INSTALL} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" diff --git a/tensorflow/contrib/cmake/external/jsoncpp.cmake b/tensorflow/contrib/cmake/external/jsoncpp.cmake index d2ae4c76e8cd175cdc3ba41fdf4e4009f8237309..84c52e3652ff935c287d32c0c80fd407e1213f29 100644 --- a/tensorflow/contrib/cmake/external/jsoncpp.cmake +++ b/tensorflow/contrib/cmake/external/jsoncpp.cmake @@ -23,7 +23,11 @@ set(jsoncpp_LIBRARIES ${jsoncpp_BUILD}/obj/so/libjsoncpp.so) set(jsoncpp_INCLUDES ${jsoncpp_BUILD}) if(WIN32) - set(jsoncpp_STATIC_LIBRARIES ${jsoncpp_BUILD}/$(Configuration)/jsoncpp.lib) + if(${CMAKE_GENERATOR} MATCHES "Visual 
Studio.*") + set(jsoncpp_STATIC_LIBRARIES ${jsoncpp_BUILD}/$(Configuration)/jsoncpp.lib) + else() + set(jsoncpp_STATIC_LIBRARIES ${jsoncpp_BUILD}/jsoncpp.lib) + endif() else() set(jsoncpp_STATIC_LIBRARIES ${jsoncpp_BUILD}/libjsoncpp.a) endif() @@ -40,13 +44,10 @@ ExternalProject_Add(jsoncpp GIT_TAG ${jsoncpp_TAG} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" BUILD_IN_SOURCE 1 + BUILD_BYPRODUCTS ${jsoncpp_STATIC_LIBRARIES} INSTALL_COMMAND "" CMAKE_CACHE_ARGS - if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE) - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON - else() - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF - endif() + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE} -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF ) diff --git a/tensorflow/contrib/cmake/external/lmdb.cmake b/tensorflow/contrib/cmake/external/lmdb.cmake index e41384f023ca9fc4cba697917b491af5a9db92bc..ed5ab788acc5625b9c8020fce15f027d98433096 100644 --- a/tensorflow/contrib/cmake/external/lmdb.cmake +++ b/tensorflow/contrib/cmake/external/lmdb.cmake @@ -20,31 +20,28 @@ set(lmdb_HASH SHA256=108532fb94c6f227558d45be3f3347b52539f0f58290a7bb31ec06c462d set(lmdb_BUILD ${CMAKE_BINARY_DIR}/lmdb/src/lmdb) set(lmdb_INSTALL ${CMAKE_BINARY_DIR}/lmdb/install) +if(WIN32) + set(lmdb_STATIC_LIBRARIES ${lmdb_INSTALL}/lib/lmdb.lib) +else() + set(lmdb_STATIC_LIBRARIES ${lmdb_INSTALL}/lib/liblmdb.a) +endif() + ExternalProject_Add(lmdb PREFIX lmdb URL ${lmdb_URL} URL_HASH ${lmdb_HASH} + BUILD_BYPRODUCTS ${lmdb_STATIC_LIBRARIES} PATCH_COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_CURRENT_SOURCE_DIR}/patches/lmdb/CMakeLists.txt ${lmdb_BUILD} INSTALL_DIR ${lmdb_INSTALL} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" CMAKE_CACHE_ARGS - if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE) - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON - else() - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF - endif() + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE} 
-DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF -DCMAKE_INSTALL_PREFIX:STRING=${lmdb_INSTALL} ) -if(WIN32) - set(lmdb_STATIC_LIBRARIES ${lmdb_INSTALL}/lib/lmdb.lib) -else() - set(lmdb_STATIC_LIBRARIES ${lmdb_INSTALL}/lib/liblmdb.a) -endif() - set(lmdb_HEADERS "${lmdb_INSTALL}/include/lmdb.h" "${lmdb_INSTALL}/include/midl.h" diff --git a/tensorflow/contrib/cmake/external/nsync.cmake b/tensorflow/contrib/cmake/external/nsync.cmake index 155c91cb97dbe5ef33c318efb5544a9fa22166c7..f3a37ff5088e3f9e54e38c0edb5777c27b26969f 100644 --- a/tensorflow/contrib/cmake/external/nsync.cmake +++ b/tensorflow/contrib/cmake/external/nsync.cmake @@ -16,7 +16,7 @@ include (ExternalProject) set(nsync_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/nsync/public) set(nsync_URL https://github.com/google/nsync) -set(nsync_TAG 93815892dddafe9146a5f7e7042281d59d0f4323) +set(nsync_TAG 8502189abfa44c249c01c2cad64e6ed660a9a668) set(nsync_BUILD ${CMAKE_CURRENT_BINARY_DIR}/nsync/src/nsync) set(nsync_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/nsync/install) @@ -42,6 +42,7 @@ ExternalProject_Add(nsync GIT_TAG ${nsync_TAG} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" BUILD_IN_SOURCE 1 + BUILD_BYPRODUCTS ${nsync_STATIC_LIBRARIES} PATCH_COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_CURRENT_SOURCE_DIR}/patches/nsync/CMakeLists.txt ${nsync_BUILD} INSTALL_DIR ${nsync_INSTALL} CMAKE_CACHE_ARGS diff --git a/tensorflow/contrib/cmake/external/png.cmake b/tensorflow/contrib/cmake/external/png.cmake index aad6618f52f909096fd2388e867ef3a965d033cb..6cd66a65990e7a2b963b52b310061b551752cd4d 100644 --- a/tensorflow/contrib/cmake/external/png.cmake +++ b/tensorflow/contrib/cmake/external/png.cmake @@ -21,9 +21,19 @@ set(png_BUILD ${CMAKE_BINARY_DIR}/png/src/png) set(png_INSTALL ${CMAKE_BINARY_DIR}/png/install) if(WIN32) - set(png_STATIC_LIBRARIES - debug ${CMAKE_BINARY_DIR}/png/install/lib/libpng12_staticd.lib - optimized ${CMAKE_BINARY_DIR}/png/install/lib/libpng12_static.lib) + if(${CMAKE_GENERATOR} 
MATCHES "Visual Studio.*") + set(png_STATIC_LIBRARIES + debug ${CMAKE_BINARY_DIR}/png/install/lib/libpng12_staticd.lib + optimized ${CMAKE_BINARY_DIR}/png/install/lib/libpng12_static.lib) + else() + if(CMAKE_BUILD_TYPE EQUAL Debug) + set(png_STATIC_LIBRARIES + ${CMAKE_BINARY_DIR}/png/install/lib/libpng12_staticd.lib) + else() + set(png_STATIC_LIBRARIES + ${CMAKE_BINARY_DIR}/png/install/lib/libpng12_static.lib) + endif() + endif() else() set(png_STATIC_LIBRARIES ${CMAKE_BINARY_DIR}/png/install/lib/libpng12.a) endif() @@ -38,14 +48,11 @@ ExternalProject_Add(png DEPENDS zlib URL ${png_URL} URL_HASH ${png_HASH} + BUILD_BYPRODUCTS ${png_STATIC_LIBRARIES} INSTALL_DIR ${png_INSTALL} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" CMAKE_CACHE_ARGS - if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE) - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON - else() - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF - endif() + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE} -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF -DCMAKE_INSTALL_PREFIX:STRING=${png_INSTALL} diff --git a/tensorflow/contrib/cmake/external/protobuf.cmake b/tensorflow/contrib/cmake/external/protobuf.cmake index b53857a47bfbf797af02fe7f69474263119161cd..aba8a5244e17d717293deec6d9b6e8e725ef010e 100644 --- a/tensorflow/contrib/cmake/external/protobuf.cmake +++ b/tensorflow/contrib/cmake/external/protobuf.cmake @@ -16,14 +16,37 @@ include (ExternalProject) set(PROTOBUF_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/src) set(PROTOBUF_URL https://github.com/google/protobuf.git) -set(PROTOBUF_TAG b04e5cba356212e4e8c66c61bbe0c3a20537c5b9) +set(PROTOBUF_TAG 396336eb961b75f03b25824fe86cf6490fb75e3a) if(WIN32) - set(protobuf_STATIC_LIBRARIES - debug ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/$(Configuration)/libprotobufd.lib - optimized ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/$(Configuration)/libprotobuf.lib) - set(PROTOBUF_PROTOC_EXECUTABLE 
${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/$(Configuration)/protoc.exe) - set(PROTOBUF_ADDITIONAL_CMAKE_OPTIONS -Dprotobuf_MSVC_STATIC_RUNTIME:BOOL=OFF -A x64) + if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") + set(protobuf_STATIC_LIBRARIES + debug ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/$(Configuration)/libprotobufd.lib + optimized ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/$(Configuration)/libprotobuf.lib) + set(PROTOBUF_PROTOC_EXECUTABLE ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/$(Configuration)/protoc.exe) + else() + if(CMAKE_BUILD_TYPE EQUAL Debug) + set(protobuf_STATIC_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/libprotobufd.lib) + else() + set(protobuf_STATIC_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/libprotobuf.lib) + endif() + set(PROTOBUF_PROTOC_EXECUTABLE ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/protoc.exe) + endif() + + # This section is to make sure CONFIGURE_COMMAND use the same generator settings + set(PROTOBUF_GENERATOR_PLATFORM) + if (CMAKE_GENERATOR_PLATFORM) + set(PROTOBUF_GENERATOR_PLATFORM -A ${CMAKE_GENERATOR_PLATFORM}) + endif() + set(PROTOBUF_GENERATOR_TOOLSET) + if (CMAKE_GENERATOR_TOOLSET) + set(PROTOBUF_GENERATOR_TOOLSET -T ${CMAKE_GENERATOR_TOOLSET}) + endif() + set(PROTOBUF_ADDITIONAL_CMAKE_OPTIONS -Dprotobuf_MSVC_STATIC_RUNTIME:BOOL=OFF + -G${CMAKE_GENERATOR} ${PROTOBUF_GENERATOR_PLATFORM} ${PROTOBUF_GENERATOR_TOOLSET}) + # End of section else() set(protobuf_STATIC_LIBRARIES ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/libprotobuf.a) set(PROTOBUF_PROTOC_EXECUTABLE ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/protoc) @@ -36,20 +59,23 @@ ExternalProject_Add(protobuf GIT_TAG ${PROTOBUF_TAG} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" BUILD_IN_SOURCE 1 + BUILD_BYPRODUCTS ${PROTOBUF_PROTOC_EXECUTABLE} ${protobuf_STATIC_LIBRARIES} SOURCE_DIR ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf + # SOURCE_SUBDIR cmake/ # Requires CMake 3.7, this will allow removal 
of CONFIGURE_COMMAND + # CONFIGURE_COMMAND resets some settings made in CMAKE_CACHE_ARGS and the generator used CONFIGURE_COMMAND ${CMAKE_COMMAND} cmake/ - -Dprotobuf_BUILD_TESTS=OFF - -DCMAKE_POSITION_INDEPENDENT_CODE=ON + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE} + -DCMAKE_BUILD_TYPE:STRING=Release + -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF + -Dprotobuf_BUILD_TESTS:BOOL=OFF -DZLIB_ROOT=${ZLIB_INSTALL} ${PROTOBUF_ADDITIONAL_CMAKE_OPTIONS} INSTALL_COMMAND "" CMAKE_CACHE_ARGS - if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE) - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON - else() - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF - endif() + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE} -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF + -Dprotobuf_BUILD_TESTS:BOOL=OFF + -Dprotobuf_MSVC_STATIC_RUNTIME:BOOL=OFF -DZLIB_ROOT:STRING=${ZLIB_INSTALL} ) diff --git a/tensorflow/contrib/cmake/external/re2.cmake b/tensorflow/contrib/cmake/external/re2.cmake index d10f5959f71dd350e6e2bcb81be8882b203fb231..c4bc0b1707bf9e86ea41234c8155fd6321c4c33b 100644 --- a/tensorflow/contrib/cmake/external/re2.cmake +++ b/tensorflow/contrib/cmake/external/re2.cmake @@ -21,7 +21,11 @@ set(re2_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/re2/install) set(re2_TAG e7efc48) if(WIN32) - set(re2_STATIC_LIBRARIES ${re2_BUILD}/$(Configuration)/re2.lib) + if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") + set(re2_STATIC_LIBRARIES ${re2_BUILD}/$(Configuration)/re2.lib) + else() + set(re2_STATIC_LIBRARIES ${re2_BUILD}/re2.lib) + endif() else() set(re2_STATIC_LIBRARIES ${re2_BUILD}/libre2.a) endif() @@ -36,13 +40,10 @@ ExternalProject_Add(re2 GIT_TAG ${re2_TAG} INSTALL_DIR ${re2_INSTALL} BUILD_IN_SOURCE 1 + BUILD_BYPRODUCTS ${re2_STATIC_LIBRARIES} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" CMAKE_CACHE_ARGS - if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE) - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON - else() - 
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF - endif() + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE} -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_INSTALL_PREFIX:STRING=${re2_INSTALL} -DRE2_BUILD_TESTING:BOOL=OFF diff --git a/tensorflow/contrib/cmake/external/snappy.cmake b/tensorflow/contrib/cmake/external/snappy.cmake index 926c271fd9ea6e2a30251aa408bd49859ae95070..f54197643b06781dad35b40f526f28d301047299 100644 --- a/tensorflow/contrib/cmake/external/snappy.cmake +++ b/tensorflow/contrib/cmake/external/snappy.cmake @@ -20,7 +20,11 @@ set(snappy_BUILD ${CMAKE_CURRENT_BINARY_DIR}/snappy/src/snappy) set(snappy_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/snappy/src/snappy) if(WIN32) - set(snappy_STATIC_LIBRARIES ${snappy_BUILD}/$(Configuration)/snappy.lib) + if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") + set(snappy_STATIC_LIBRARIES ${snappy_BUILD}/$(Configuration)/snappy.lib) + else() + set(snappy_STATIC_LIBRARIES ${snappy_BUILD}/snappy.lib) + endif() else() set(snappy_STATIC_LIBRARIES ${snappy_BUILD}/libsnappy.a) endif() @@ -35,20 +39,17 @@ ExternalProject_Add(snappy GIT_TAG ${snappy_TAG} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" BUILD_IN_SOURCE 1 + BUILD_BYPRODUCTS ${snappy_STATIC_LIBRARIES} INSTALL_COMMAND "" LOG_DOWNLOAD ON LOG_CONFIGURE ON LOG_BUILD ON CMAKE_CACHE_ARGS - if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE) - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON - else() - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF - endif() + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE} -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF -DSNAPPY_BUILD_TESTS:BOOL=OFF ) # actually enables snappy in the source code -add_definitions(-DTF_USE_SNAPPY) \ No newline at end of file +add_definitions(-DTF_USE_SNAPPY) diff --git a/tensorflow/contrib/cmake/external/sqlite.cmake b/tensorflow/contrib/cmake/external/sqlite.cmake index 
785039a46983747557607562675349c150e064ad..57c4ae76517e4d7247093edd5e5bd95a83258d87 100644 --- a/tensorflow/contrib/cmake/external/sqlite.cmake +++ b/tensorflow/contrib/cmake/external/sqlite.cmake @@ -28,6 +28,7 @@ endif() set(sqlite_HEADERS "${sqlite_BUILD}/sqlite3.h" + "${sqlite_BUILD}/sqlite3ext.h" ) if (WIN32) @@ -35,6 +36,7 @@ if (WIN32) PREFIX sqlite URL ${sqlite_URL} URL_HASH ${sqlite_HASH} + BUILD_BYPRODUCTS ${sqlite_STATIC_LIBRARIES} PATCH_COMMAND ${CMAKE_COMMAND} -E copy_if_different ${CMAKE_CURRENT_SOURCE_DIR}/patches/sqlite/CMakeLists.txt ${sqlite_BUILD} INSTALL_DIR ${sqlite_INSTALL} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" @@ -53,11 +55,7 @@ else() INSTALL_DIR ${sqlite_INSTALL} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" CMAKE_CACHE_ARGS - if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE) - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON - else() - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF - endif() + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE} -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF -DCMAKE_INSTALL_PREFIX:STRING=${sqlite_INSTALL} diff --git a/tensorflow/contrib/cmake/external/zlib.cmake b/tensorflow/contrib/cmake/external/zlib.cmake index f10f84336e8b1c0a2c7de7ea1f8b8af7c21f8b51..116d42309394b92407cef79c9d3a975f494bc3ff 100644 --- a/tensorflow/contrib/cmake/external/zlib.cmake +++ b/tensorflow/contrib/cmake/external/zlib.cmake @@ -12,54 +12,75 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -include (ExternalProject) +if (systemlib_ZLIB) + find_package(PkgConfig) + pkg_search_module(ZLIB REQUIRED zlib) + set(zlib_INCLUDE_DIR ${ZLIB_INCLUDE_DIRS}) + set(ADD_LINK_DIRECTORY ${ADD_LINK_DIRECTORY} ${ZLIB_LIBRARY_DIRS}) + set(ADD_CFLAGS ${ADD_CFLAGS} ${ZLIB_CFLAGS_OTHER}) -set(zlib_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/zlib_archive) -set(ZLIB_URL https://github.com/madler/zlib) -set(ZLIB_BUILD ${CMAKE_CURRENT_BINARY_DIR}/zlib/src/zlib) -set(ZLIB_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/zlib/install) -set(ZLIB_TAG 50893291621658f355bc5b4d450a8d06a563053d) + # To meet DEPENDS zlib from other projects. + # If we hit this line, zlib is already built and installed to the system. + add_custom_target(zlib) + add_custom_target(zlib_copy_headers_to_destination) -if(WIN32) - set(zlib_STATIC_LIBRARIES - debug ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/zlibstaticd.lib - optimized ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/zlibstatic.lib) -else() - set(zlib_STATIC_LIBRARIES - ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/libz.a) -endif() +else (systemlib_ZLIB) + include (ExternalProject) -set(ZLIB_HEADERS - "${ZLIB_INSTALL}/include/zconf.h" - "${ZLIB_INSTALL}/include/zlib.h" -) + set(zlib_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/zlib_archive) + set(ZLIB_URL https://github.com/madler/zlib) + set(ZLIB_BUILD ${CMAKE_CURRENT_BINARY_DIR}/zlib/src/zlib) + set(ZLIB_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/zlib/install) + set(ZLIB_TAG 50893291621658f355bc5b4d450a8d06a563053d) -ExternalProject_Add(zlib - PREFIX zlib - GIT_REPOSITORY ${ZLIB_URL} - GIT_TAG ${ZLIB_TAG} - INSTALL_DIR ${ZLIB_INSTALL} - BUILD_IN_SOURCE 1 - DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" - CMAKE_CACHE_ARGS - if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE) - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON - else() - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF - endif() - -DCMAKE_BUILD_TYPE:STRING=Release - 
-DCMAKE_INSTALL_PREFIX:STRING=${ZLIB_INSTALL} -) + if(WIN32) + if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") + set(zlib_STATIC_LIBRARIES + debug ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/zlibstaticd.lib + optimized ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/zlibstatic.lib) + else() + if(CMAKE_BUILD_TYPE EQUAL Debug) + set(zlib_STATIC_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/zlibstaticd.lib) + else() + set(zlib_STATIC_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/zlibstatic.lib) + endif() + endif() + else() + set(zlib_STATIC_LIBRARIES + ${CMAKE_CURRENT_BINARY_DIR}/zlib/install/lib/libz.a) + endif() -# put zlib includes in the directory where they are expected -add_custom_target(zlib_create_destination_dir - COMMAND ${CMAKE_COMMAND} -E make_directory ${zlib_INCLUDE_DIR} - DEPENDS zlib) + set(ZLIB_HEADERS + "${ZLIB_INSTALL}/include/zconf.h" + "${ZLIB_INSTALL}/include/zlib.h" + ) -add_custom_target(zlib_copy_headers_to_destination - DEPENDS zlib_create_destination_dir) + ExternalProject_Add(zlib + PREFIX zlib + GIT_REPOSITORY ${ZLIB_URL} + GIT_TAG ${ZLIB_TAG} + INSTALL_DIR ${ZLIB_INSTALL} + BUILD_IN_SOURCE 1 + BUILD_BYPRODUCTS ${zlib_STATIC_LIBRARIES} + DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" + CMAKE_CACHE_ARGS + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=${tensorflow_ENABLE_POSITION_INDEPENDENT_CODE} + -DCMAKE_BUILD_TYPE:STRING=Release + -DCMAKE_INSTALL_PREFIX:STRING=${ZLIB_INSTALL} + ) -foreach(header_file ${ZLIB_HEADERS}) - add_custom_command(TARGET zlib_copy_headers_to_destination PRE_BUILD - COMMAND ${CMAKE_COMMAND} -E copy_if_different ${header_file} ${zlib_INCLUDE_DIR}) -endforeach() + # put zlib includes in the directory where they are expected + add_custom_target(zlib_create_destination_dir + COMMAND ${CMAKE_COMMAND} -E make_directory ${zlib_INCLUDE_DIR} + DEPENDS zlib) + + add_custom_target(zlib_copy_headers_to_destination + DEPENDS zlib_create_destination_dir) + + foreach(header_file ${ZLIB_HEADERS}) + add_custom_command(TARGET 
zlib_copy_headers_to_destination PRE_BUILD + COMMAND ${CMAKE_COMMAND} -E copy_if_different ${header_file} ${zlib_INCLUDE_DIR}) + endforeach() +endif (systemlib_ZLIB) diff --git a/tensorflow/tools/ci_build/install/install_cmake_for_clang.sh b/tensorflow/contrib/cmake/make.sh similarity index 83% rename from tensorflow/tools/ci_build/install/install_cmake_for_clang.sh rename to tensorflow/contrib/cmake/make.sh index 3e626a69ab5e6b7f8d1b4997b459301606501a8e..eed3c34aba1f0326ec741169a187eb2982f253a3 100755 --- a/tensorflow/tools/ci_build/install/install_cmake_for_clang.sh +++ b/tensorflow/contrib/cmake/make.sh @@ -14,6 +14,13 @@ # limitations under the License. # ============================================================================== -CMAKE_URL="https://cmake.org/files/v3.7/cmake-3.7.2-Linux-x86_64.tar.gz" +( +cd "$(dirname "$0")" || exit 1 +mkdir -p _build -wget -O - "${CMAKE_URL}" | tar xzf - -C /usr/local --strip-components=1 +( +cd _build || exit 1 +rm -rf -- * +cmake .. +) +) diff --git a/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt b/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt index 594c2492d4fd68b50c8493321a2c4dcc2d41917e..aaae18a313dd082b428654091c9411600c981ec9 100644 --- a/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt +++ b/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt @@ -158,12 +158,21 @@ if (NOT "${NSYNC_LANGUAGE}X" STREQUAL "c++11X") elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "NetBSDX") include_directories ("${PROJECT_SOURCE_DIR}/platform/netbsd") set (NSYNC_POSIX ON) + set (NSYNC_OS_EXTRA_SRC + "platform/posix/src/nsync_semaphore_mutex.c" + ) elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "FreeBSDX") include_directories ("${PROJECT_SOURCE_DIR}/platform/freebsd") set (NSYNC_POSIX ON) + set (NSYNC_OS_EXTRA_SRC + "platform/posix/src/nsync_semaphore_mutex.c" + ) elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "OpenBSDX") include_directories ("${PROJECT_SOURCE_DIR}/platform/openbsd") set (NSYNC_POSIX ON) + set (NSYNC_OS_EXTRA_SRC +
"platform/posix/src/nsync_semaphore_mutex.c" + ) endif () endif () diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt new file mode 100644 index 0000000000000000000000000000000000000000..bfe53c01b3b5fb9db8a5d8fa280d1d7f98974882 --- /dev/null +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -0,0 +1,458 @@ +# python_sanity_test.py will complain about invalid or missing entries +# problematic entries can be commented for temporary whitelisting +tensorflow +tensorflow/core +tensorflow/core/example +tensorflow/core/framework +tensorflow/core/lib +tensorflow/core/lib/core +tensorflow/core/profiler +tensorflow/core/protobuf +tensorflow/core/util +tensorflow/examples +tensorflow/examples/tutorials +tensorflow/examples/tutorials/mnist +tensorflow/python +tensorflow/python/client +tensorflow/python/data +tensorflow/python/data/ops +tensorflow/python/data/util +tensorflow/python/debug +tensorflow/python/debug/cli +tensorflow/python/debug/examples +tensorflow/python/debug/lib +tensorflow/python/debug/wrappers +tensorflow/python/eager +tensorflow/python/estimator +tensorflow/python/estimator/canned +tensorflow/python/estimator/export +tensorflow/python/estimator/inputs +tensorflow/python/estimator/inputs/queues +tensorflow/python/feature_column +tensorflow/python/framework +tensorflow/python/grappler +tensorflow/python/keras +tensorflow/python/keras/activations +tensorflow/python/keras/applications +tensorflow/python/keras/applications/densenet +tensorflow/python/keras/applications/inception_resnet_v2 +tensorflow/python/keras/applications/inception_v3 +tensorflow/python/keras/applications/mobilenet +tensorflow/python/keras/applications/nasnet +tensorflow/python/keras/applications/resnet50 +tensorflow/python/keras/applications/vgg16 +tensorflow/python/keras/applications/vgg19 +tensorflow/python/keras/applications/xception +tensorflow/python/keras/backend +tensorflow/python/keras/callbacks 
+tensorflow/python/keras/constraints +tensorflow/python/keras/datasets +tensorflow/python/keras/datasets/boston_housing +tensorflow/python/keras/datasets/cifar10 +tensorflow/python/keras/datasets/cifar100 +tensorflow/python/keras/datasets/fashion_mnist +tensorflow/python/keras/datasets/imdb +tensorflow/python/keras/datasets/mnist +tensorflow/python/keras/datasets/reuters +tensorflow/python/keras/estimator +tensorflow/python/keras/initializers +tensorflow/python/keras/layers +tensorflow/python/keras/losses +tensorflow/python/keras/metrics +tensorflow/python/keras/models +tensorflow/python/keras/optimizers +tensorflow/python/keras/preprocessing +tensorflow/python/keras/preprocessing/image +tensorflow/python/keras/preprocessing/sequence +tensorflow/python/keras/preprocessing/text +tensorflow/python/keras/regularizers +tensorflow/python/keras/utils +tensorflow/python/keras/wrappers +tensorflow/python/keras/wrappers/scikit_learn +tensorflow/python/keras/_impl +tensorflow/python/keras/_impl/keras +tensorflow/python/keras/_impl/keras/applications +tensorflow/python/keras/_impl/keras/datasets +tensorflow/python/keras/_impl/keras/engine +tensorflow/python/keras/_impl/keras/layers +tensorflow/python/keras/_impl/keras/preprocessing +tensorflow/python/keras/_impl/keras/utils +tensorflow/python/keras/_impl/keras/wrappers +tensorflow/python/kernel_tests +tensorflow/python/kernel_tests/distributions +tensorflow/python/kernel_tests/linalg +tensorflow/python/kernel_tests/random +tensorflow/python/layers +tensorflow/python/lib +tensorflow/python/lib/core +tensorflow/python/lib/io +tensorflow/python/ops +tensorflow/python/ops/distributions +tensorflow/python/ops/linalg +tensorflow/python/ops/losses +tensorflow/python/platform +tensorflow/python/profiler +tensorflow/python/profiler/internal +tensorflow/python/saved_model +tensorflow/python/summary +tensorflow/python/summary/writer +tensorflow/python/tools +tensorflow/python/training +tensorflow/python/user_ops +tensorflow/python/util 
+tensorflow/python/util/protobuf +tensorflow/tools +tensorflow/tools/graph_transforms +tensorflow/contrib +tensorflow/contrib/all_reduce +tensorflow/contrib/all_reduce/python +tensorflow/contrib/android +tensorflow/contrib/android/java +tensorflow/contrib/android/java/org +tensorflow/contrib/android/java/org/tensorflow +tensorflow/contrib/android/java/org/tensorflow/contrib +tensorflow/contrib/android/java/org/tensorflow/contrib/android +tensorflow/contrib/android/jni +tensorflow/contrib/batching +tensorflow/contrib/batching/python +tensorflow/contrib/batching/python/ops +tensorflow/contrib/bayesflow +tensorflow/contrib/bayesflow/python +tensorflow/contrib/bayesflow/python/ops +tensorflow/contrib/boosted_trees +tensorflow/contrib/boosted_trees/estimator_batch +tensorflow/contrib/boosted_trees/kernels +tensorflow/contrib/boosted_trees/ops +tensorflow/contrib/boosted_trees/proto +tensorflow/contrib/boosted_trees/python +tensorflow/contrib/boosted_trees/python/ops +tensorflow/contrib/cloud +tensorflow/contrib/cloud/kernels +tensorflow/contrib/cloud/ops +tensorflow/contrib/cloud/python +tensorflow/contrib/cloud/python/ops +tensorflow/contrib/cluster_resolver +tensorflow/contrib/cluster_resolver/python +tensorflow/contrib/cluster_resolver/python/training +tensorflow/contrib/coder +tensorflow/contrib/coder/kernels +tensorflow/contrib/coder/ops +tensorflow/contrib/coder/python +tensorflow/contrib/coder/python/ops +tensorflow/contrib/compiler +tensorflow/contrib/copy_graph +tensorflow/contrib/copy_graph/python +tensorflow/contrib/copy_graph/python/util +tensorflow/contrib/crf +tensorflow/contrib/crf/python +tensorflow/contrib/crf/python/ops +tensorflow/contrib/cudnn_rnn +tensorflow/contrib/cudnn_rnn/kernels +tensorflow/contrib/cudnn_rnn/ops +tensorflow/contrib/cudnn_rnn/python +tensorflow/contrib/cudnn_rnn/python/layers +tensorflow/contrib/cudnn_rnn/python/ops +tensorflow/contrib/data +tensorflow/contrib/data/kernels +tensorflow/contrib/data/python 
+tensorflow/contrib/data/python/kernel_tests +tensorflow/contrib/data/python/ops +tensorflow/contrib/decision_trees +tensorflow/contrib/decision_trees/proto +tensorflow/contrib/deprecated +tensorflow/contrib/distributions +tensorflow/contrib/distributions/python +tensorflow/contrib/distributions/python/ops +tensorflow/contrib/distributions/python/ops/bijectors +tensorflow/contrib/eager +tensorflow/contrib/eager/python +tensorflow/contrib/estimator +tensorflow/contrib/estimator/python +tensorflow/contrib/estimator/python/estimator +tensorflow/contrib/factorization +tensorflow/contrib/factorization/examples +tensorflow/contrib/factorization/kernels +tensorflow/contrib/factorization/ops +tensorflow/contrib/factorization/python +tensorflow/contrib/factorization/python/ops +tensorflow/contrib/feature_column +tensorflow/contrib/feature_column/python +tensorflow/contrib/feature_column/python/feature_column +tensorflow/contrib/ffmpeg +tensorflow/contrib/ffmpeg/default +tensorflow/contrib/framework +tensorflow/contrib/framework/kernels +tensorflow/contrib/framework/ops +tensorflow/contrib/framework/python +tensorflow/contrib/framework/python/framework +tensorflow/contrib/framework/python/ops +tensorflow/contrib/fused_conv +tensorflow/contrib/fused_conv/kernels +tensorflow/contrib/fused_conv/python +tensorflow/contrib/fused_conv/python/ops +tensorflow/contrib/gan +tensorflow/contrib/gan/python +tensorflow/contrib/gan/python/estimator +tensorflow/contrib/gan/python/estimator/python +tensorflow/contrib/gan/python/eval +tensorflow/contrib/gan/python/eval/python +tensorflow/contrib/gan/python/features +tensorflow/contrib/gan/python/features/python +tensorflow/contrib/gan/python/losses +tensorflow/contrib/gan/python/losses/python +tensorflow/contrib/graph_editor +tensorflow/contrib/graph_editor/examples +tensorflow/contrib/grid_rnn +tensorflow/contrib/grid_rnn/python +tensorflow/contrib/grid_rnn/python/ops +tensorflow/contrib/hooks +tensorflow/contrib/hooks/python 
+tensorflow/contrib/image +tensorflow/contrib/image/kernels +tensorflow/contrib/image/ops +tensorflow/contrib/image/python +tensorflow/contrib/image/python/ops +tensorflow/contrib/input_pipeline +tensorflow/contrib/input_pipeline/kernels +tensorflow/contrib/input_pipeline/ops +tensorflow/contrib/input_pipeline/python +tensorflow/contrib/input_pipeline/python/ops +tensorflow/contrib/integrate +tensorflow/contrib/integrate/python +tensorflow/contrib/integrate/python/ops +tensorflow/contrib/kafka/python +tensorflow/contrib/kafka/python/ops +tensorflow/contrib/keras +tensorflow/contrib/keras/api +tensorflow/contrib/keras/api/keras +tensorflow/contrib/keras/api/keras/activations +tensorflow/contrib/keras/api/keras/applications +tensorflow/contrib/keras/api/keras/applications/inception_v3 +tensorflow/contrib/keras/api/keras/applications/mobilenet +tensorflow/contrib/keras/api/keras/applications/resnet50 +tensorflow/contrib/keras/api/keras/applications/vgg16 +tensorflow/contrib/keras/api/keras/applications/vgg19 +tensorflow/contrib/keras/api/keras/applications/xception +tensorflow/contrib/keras/api/keras/backend +tensorflow/contrib/keras/api/keras/callbacks +tensorflow/contrib/keras/api/keras/constraints +tensorflow/contrib/keras/api/keras/datasets +tensorflow/contrib/keras/api/keras/datasets/boston_housing +tensorflow/contrib/keras/api/keras/datasets/cifar10 +tensorflow/contrib/keras/api/keras/datasets/cifar100 +tensorflow/contrib/keras/api/keras/datasets/imdb +tensorflow/contrib/keras/api/keras/datasets/mnist +tensorflow/contrib/keras/api/keras/datasets/reuters +tensorflow/contrib/keras/api/keras/initializers +tensorflow/contrib/keras/api/keras/layers +tensorflow/contrib/keras/api/keras/losses +tensorflow/contrib/keras/api/keras/metrics +tensorflow/contrib/keras/api/keras/models +tensorflow/contrib/keras/api/keras/optimizers +tensorflow/contrib/keras/api/keras/preprocessing +tensorflow/contrib/keras/api/keras/preprocessing/image 
+tensorflow/contrib/keras/api/keras/preprocessing/sequence +tensorflow/contrib/keras/api/keras/preprocessing/text +tensorflow/contrib/keras/api/keras/regularizers +tensorflow/contrib/keras/api/keras/utils +tensorflow/contrib/keras/api/keras/wrappers +tensorflow/contrib/keras/api/keras/wrappers/scikit_learn +tensorflow/contrib/kernel_methods +tensorflow/contrib/kernel_methods/python +tensorflow/contrib/kernel_methods/python/mappers +tensorflow/contrib/kfac +tensorflow/contrib/kfac/examples +tensorflow/contrib/kfac/python +tensorflow/contrib/kfac/python/ops +tensorflow/contrib/labeled_tensor +tensorflow/contrib/labeled_tensor/python +tensorflow/contrib/labeled_tensor/python/ops +tensorflow/contrib/layers +tensorflow/contrib/layers/kernels +tensorflow/contrib/layers/ops +tensorflow/contrib/layers/python +tensorflow/contrib/layers/python/layers +tensorflow/contrib/layers/python/ops +tensorflow/contrib/learn +tensorflow/contrib/learn/python +tensorflow/contrib/learn/python/learn +tensorflow/contrib/learn/python/learn/datasets +tensorflow/contrib/learn/python/learn/datasets/data +tensorflow/contrib/learn/python/learn/estimators +tensorflow/contrib/learn/python/learn/learn_io +tensorflow/contrib/learn/python/learn/ops +tensorflow/contrib/learn/python/learn/preprocessing +tensorflow/contrib/learn/python/learn/utils +tensorflow/contrib/legacy_seq2seq +tensorflow/contrib/legacy_seq2seq/python +tensorflow/contrib/legacy_seq2seq/python/ops +tensorflow/contrib/libsvm +tensorflow/contrib/libsvm/python +tensorflow/contrib/libsvm/python/kernel_tests +tensorflow/contrib/libsvm/python/ops +tensorflow/contrib/linalg +tensorflow/contrib/linalg/python +tensorflow/contrib/linalg/python/ops +tensorflow/contrib/linear_optimizer +tensorflow/contrib/linear_optimizer/kernels +tensorflow/contrib/linear_optimizer/kernels/g3doc +tensorflow/contrib/linear_optimizer/python +tensorflow/contrib/linear_optimizer/python/ops +# TODO(drpngx): Fix failing imports +# tensorflow/contrib/lite +# 
tensorflow/contrib/lite/python +# tensorflow/contrib/lite/toco +# tensorflow/contrib/lite/toco/python +tensorflow/contrib/lookup +tensorflow/contrib/losses +tensorflow/contrib/losses/python +tensorflow/contrib/losses/python/losses +tensorflow/contrib/losses/python/metric_learning +tensorflow/contrib/makefile +tensorflow/contrib/memory_stats +tensorflow/contrib/memory_stats/kernels +tensorflow/contrib/memory_stats/ops +tensorflow/contrib/memory_stats/python +tensorflow/contrib/memory_stats/python/ops +tensorflow/contrib/meta_graph_transform +tensorflow/contrib/metrics +tensorflow/contrib/metrics/python +tensorflow/contrib/metrics/python/metrics +tensorflow/contrib/metrics/python/ops +tensorflow/contrib/mpi_collectives/python +tensorflow/contrib/mpi_collectives/python/ops +tensorflow/contrib/model_pruning +tensorflow/contrib/model_pruning/examples +tensorflow/contrib/model_pruning/examples/cifar10 +tensorflow/contrib/model_pruning/python +tensorflow/contrib/model_pruning/python/layers +tensorflow/contrib/nccl +tensorflow/contrib/nccl/kernels +tensorflow/contrib/nccl/ops +tensorflow/contrib/nccl/python +tensorflow/contrib/nccl/python/ops +tensorflow/contrib/nearest_neighbor/kernels +tensorflow/contrib/nearest_neighbor/ops +tensorflow/contrib/nearest_neighbor/python +tensorflow/contrib/nearest_neighbor/python/ops +tensorflow/contrib/nn +tensorflow/contrib/nn/python +tensorflow/contrib/nn/python/ops +tensorflow/contrib/opt +tensorflow/contrib/opt/python +tensorflow/contrib/opt/python/training +tensorflow/contrib/pi_examples +tensorflow/contrib/pi_examples/camera +tensorflow/contrib/pi_examples/label_image +tensorflow/contrib/pi_examples/label_image/data +tensorflow/contrib/periodic_resample +tensorflow/contrib/periodic_resample/python +tensorflow/contrib/periodic_resample/python/ops +tensorflow/contrib/predictor +tensorflow/contrib/quantization +tensorflow/contrib/quantization/python +tensorflow/contrib/quantize +tensorflow/contrib/quantize/python 
+tensorflow/contrib/receptive_field +tensorflow/contrib/receptive_field/python +tensorflow/contrib/receptive_field/python/util +tensorflow/contrib/receptive_field/python/util/examples +tensorflow/contrib/reduce_slice_ops +tensorflow/contrib/reduce_slice_ops/kernels +tensorflow/contrib/reduce_slice_ops/ops +tensorflow/contrib/reduce_slice_ops/python +tensorflow/contrib/reduce_slice_ops/python/ops +tensorflow/contrib/remote_fused_graph +tensorflow/contrib/remote_fused_graph/pylib +tensorflow/contrib/remote_fused_graph/pylib/python +tensorflow/contrib/remote_fused_graph/pylib/python/ops +tensorflow/contrib/resampler +tensorflow/contrib/resampler/kernels +tensorflow/contrib/resampler/ops +tensorflow/contrib/resampler/python +tensorflow/contrib/resampler/python/ops +tensorflow/contrib/rnn +tensorflow/contrib/rnn/kernels +tensorflow/contrib/rnn/ops +tensorflow/contrib/rnn/python +tensorflow/contrib/rnn/python/kernel_tests +tensorflow/contrib/rnn/python/ops +tensorflow/contrib/saved_model +tensorflow/contrib/saved_model/python +tensorflow/contrib/saved_model/python/saved_model +tensorflow/contrib/seq2seq +tensorflow/contrib/seq2seq/kernels +tensorflow/contrib/seq2seq/ops +tensorflow/contrib/seq2seq/python +tensorflow/contrib/seq2seq/python/ops +tensorflow/contrib/session_bundle +tensorflow/contrib/session_bundle/example +tensorflow/contrib/signal +tensorflow/contrib/signal/python +tensorflow/contrib/signal/python/ops +tensorflow/contrib/slim +tensorflow/contrib/slim/python +tensorflow/contrib/slim/python/slim +tensorflow/contrib/slim/python/slim/data +tensorflow/contrib/slim/python/slim/nets +tensorflow/contrib/solvers +tensorflow/contrib/solvers/python +tensorflow/contrib/solvers/python/ops +tensorflow/contrib/sparsemax +tensorflow/contrib/sparsemax/python +tensorflow/contrib/sparsemax/python/ops +tensorflow/contrib/specs +tensorflow/contrib/specs/python +tensorflow/contrib/staging +tensorflow/contrib/stat_summarizer +tensorflow/contrib/stat_summarizer/python 
+tensorflow/contrib/stateless +tensorflow/contrib/stateless/python +tensorflow/contrib/summary +tensorflow/contrib/tensorboard +tensorflow/contrib/tensorboard/plugins +tensorflow/contrib/tensorboard/plugins/projector +tensorflow/contrib/tensorboard/plugins/trace +# TODO(sami): Add cmake implementations. +# tensorflow/contrib/tensorrt/python +# tensorflow/contrib/tensorrt/python/ops +tensorflow/contrib/tensor_forest +tensorflow/contrib/tensor_forest/client +tensorflow/contrib/tensor_forest/hybrid +tensorflow/contrib/tensor_forest/hybrid/core +tensorflow/contrib/tensor_forest/hybrid/core/ops +tensorflow/contrib/tensor_forest/hybrid/python +tensorflow/contrib/tensor_forest/hybrid/python/layers +tensorflow/contrib/tensor_forest/hybrid/python/models +tensorflow/contrib/tensor_forest/hybrid/python/ops +tensorflow/contrib/tensor_forest/kernels +tensorflow/contrib/tensor_forest/proto +tensorflow/contrib/tensor_forest/python +tensorflow/contrib/tensor_forest/python/ops +tensorflow/contrib/testing +tensorflow/contrib/testing/python +tensorflow/contrib/testing/python/framework +tensorflow/contrib/text +tensorflow/contrib/text/kernels +tensorflow/contrib/text/ops +tensorflow/contrib/text/python +tensorflow/contrib/text/python/ops +tensorflow/contrib/tfprof +tensorflow/contrib/timeseries +tensorflow/contrib/timeseries/examples +tensorflow/contrib/timeseries/examples/data +tensorflow/contrib/timeseries/python +tensorflow/contrib/timeseries/python/timeseries +tensorflow/contrib/timeseries/python/timeseries/state_space_models +tensorflow/contrib/tpu +tensorflow/contrib/tpu/ops +tensorflow/contrib/tpu/profiler +tensorflow/contrib/tpu/proto +tensorflow/contrib/tpu/python +tensorflow/contrib/tpu/python/ops +tensorflow/contrib/tpu/python/profiler +tensorflow/contrib/tpu/python/tpu +tensorflow/contrib/training +tensorflow/contrib/training/python +tensorflow/contrib/training/python/training +tensorflow/contrib/util diff --git a/tensorflow/contrib/cmake/python_protos.txt 
b/tensorflow/contrib/cmake/python_protos.txt new file mode 100644 index 0000000000000000000000000000000000000000..8a9c406d8b118c10ddcaafb0e4fc242aa79cdb57 --- /dev/null +++ b/tensorflow/contrib/cmake/python_protos.txt @@ -0,0 +1,19 @@ +tensorflow/core +tensorflow/core/profiler +tensorflow/python +tensorflow/contrib/boosted_trees/proto +tensorflow/contrib/cloud/kernels +tensorflow/contrib/decision_trees/proto +tensorflow/contrib/gdr +tensorflow/contrib/lite/toco +tensorflow/contrib/mpi +tensorflow/contrib/mpi_collectives +tensorflow/contrib/session_bundle +tensorflow/contrib/tensor_forest/proto +tensorflow/contrib/tensorboard/graph_explorer/proto +tensorflow/contrib/tensorboard/plugins/projector +tensorflow/contrib/tensorboard/plugins/trace +tensorflow/contrib/tpu/proto +tensorflow/contrib/tpu/profiler +tensorflow/contrib/training/python/training +tensorflow/contrib/verbs diff --git a/tensorflow/contrib/cmake/python_protos_cc.txt b/tensorflow/contrib/cmake/python_protos_cc.txt new file mode 100644 index 0000000000000000000000000000000000000000..d4a257b25c814a1464308d0e6ce3ce65d21f6a36 --- /dev/null +++ b/tensorflow/contrib/cmake/python_protos_cc.txt @@ -0,0 +1,5 @@ +tensorflow/core/profiler +tensorflow/python +tensorflow/contrib/session_bundle +tensorflow/contrib/tensorboard +tensorflow/contrib/training diff --git a/tensorflow/contrib/cmake/python_sanity_test.py b/tensorflow/contrib/cmake/python_sanity_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e0056823a80833329bcb1f275a3384a33127bb40 --- /dev/null +++ b/tensorflow/contrib/cmake/python_sanity_test.py @@ -0,0 +1,128 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Complain about invalid or missing entries in python_*.txt files. + +Problematic entries can be commented for temporary whitelisting. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import unittest + + +def abs_path(path): + root = os.path.dirname(__file__) + + for _ in range(3): + root = os.path.join(root, os.pardir) + + path = os.path.join(root, path) + path = os.path.abspath(path) + return path + + +def read_entries(test): + with open(abs_path(test.entries_file), "r") as f: + lines = f.readlines() + + lines = [line.strip() for line in lines] + lines = [line for line in lines if line] + + test.entries = [] + test.whitelist = [] + + for line in lines: + # line is comment + if line.startswith("#"): + line = line[1:].strip() + # whitelist entry + if line.startswith("tensorflow/"): + test.whitelist.append(line) + # line has comment -> strip comment + elif line.find("#") != -1: + line = line[:line.find("#")].strip() + test.entries.append(line) + else: + test.entries.append(line) + + +def test_invalid_directories(test): + for entry in test.entries: + if not os.path.isdir(abs_path(entry)): + problem = "'" + test.entries_file + "' contains invalid '" + entry + "'" + solution = ("Please remove the invalid entry (or add the missing " + "directory).") + raise AssertionError(problem + "\n" + solution) + + +def test_missing_directory(test, path): + if path in test.whitelist: + return + + dir_exists = 
os.path.isdir(abs_path(path)) + entry_exists = path in test.entries + + if dir_exists and not entry_exists: + problem = "'" + test.entries_file + "' is missing '" + path + "'" + solution = "Please add the missing entry (comment to whitelist if needed)." + raise AssertionError(problem + "\n" + solution) + + +class PythonModuleTest(unittest.TestCase): + + def setUp(self): + self.entries_file = "tensorflow/contrib/cmake/python_modules.txt" + read_entries(self) + + def testInvalidEntries(self): + test_invalid_directories(self) + + def testMissingModules(self): + module_names = next(os.walk(abs_path("tensorflow/contrib")))[1] + + for module_name in module_names: + path = "tensorflow/contrib/" + module_name + + test_missing_directory(self, path + "/python") + test_missing_directory(self, path + "/python/ops") + test_missing_directory(self, path + "/python/kernels") + test_missing_directory(self, path + "/python/layers") + + +class PythonProtoTest(unittest.TestCase): + + def setUp(self): + self.entries_file = "tensorflow/contrib/cmake/python_protos.txt" + read_entries(self) + + def testInvalidEntries(self): + test_invalid_directories(self) + + +class PythonProtoCCTest(unittest.TestCase): + + def setUp(self): + self.entries_file = "tensorflow/contrib/cmake/python_protos_cc.txt" + read_entries(self) + + def testInvalidEntries(self): + test_invalid_directories(self) + + +if __name__ == "__main__": + unittest.main() diff --git a/tensorflow/contrib/cmake/tests/cuda/compatibility_test.c b/tensorflow/contrib/cmake/tests/cuda/compatibility_test.c new file mode 100644 index 0000000000000000000000000000000000000000..968ab13a0c43793341431248713f81010c87f148 --- /dev/null +++ b/tensorflow/contrib/cmake/tests/cuda/compatibility_test.c @@ -0,0 +1,7 @@ +// This is a program to test if compiler is compatible with CUDA. 
+#define __CUDACC__ +#include "crt/host_config.h" + +int main(void) { + return 0; +} diff --git a/tensorflow/contrib/cmake/tests/cuda/compatibility_test.cc b/tensorflow/contrib/cmake/tests/cuda/compatibility_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..968ab13a0c43793341431248713f81010c87f148 --- /dev/null +++ b/tensorflow/contrib/cmake/tests/cuda/compatibility_test.cc @@ -0,0 +1,7 @@ +// This is a program to test if compiler is compatible with CUDA. +#define __CUDACC__ +#include "crt/host_config.h" + +int main(void) { + return 0; +} diff --git a/tensorflow/contrib/cmake/tf_cc_ops.cmake b/tensorflow/contrib/cmake/tf_cc_ops.cmake index 6e2ac203f9a7f96cb14752a91483840a9eb6b451..f73da0b8ab18af1eca4c2bd577604595f8b8ec6d 100644 --- a/tensorflow/contrib/cmake/tf_cc_ops.cmake +++ b/tensorflow/contrib/cmake/tf_cc_ops.cmake @@ -83,7 +83,7 @@ foreach(tf_cc_op_lib_name ${tf_cc_op_lib_names}) ${cc_ops_target_dir}/${tf_cc_op_lib_name}.cc ${cc_ops_target_dir}/${tf_cc_op_lib_name}_internal.h ${cc_ops_target_dir}/${tf_cc_op_lib_name}_internal.cc - COMMAND ${tf_cc_op_lib_name}_gen_cc ${cc_ops_target_dir}/${tf_cc_op_lib_name}.h ${cc_ops_target_dir}/${tf_cc_op_lib_name}.cc ${tensorflow_source_dir}/tensorflow/cc/ops/op_gen_overrides.pbtxt ${cc_ops_include_internal} ${tensorflow_source_dir}/tensorflow/core/api_def/base_api + COMMAND ${tf_cc_op_lib_name}_gen_cc ${cc_ops_target_dir}/${tf_cc_op_lib_name}.h ${cc_ops_target_dir}/${tf_cc_op_lib_name}.cc ${cc_ops_include_internal} ${tensorflow_source_dir}/tensorflow/core/api_def/base_api DEPENDS ${tf_cc_op_lib_name}_gen_cc create_cc_ops_header_dir ) @@ -149,7 +149,11 @@ add_library(tf_cc OBJECT ${tf_cc_srcs}) add_dependencies(tf_cc tf_cc_framework tf_cc_ops) if (WIN32) - set (pywrap_tensorflow_lib "${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}/pywrap_tensorflow_internal.lib") + if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") + set (pywrap_tensorflow_lib 
"${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}/pywrap_tensorflow_internal.lib") + else() + set (pywrap_tensorflow_lib "${CMAKE_CURRENT_BINARY_DIR}/pywrap_tensorflow_internal.lib") + endif() else (WIN32) set (pywrap_tensorflow_lib "${CMAKE_CURRENT_BINARY_DIR}/libpywrap_tensorflow_internal.so") endif (WIN32) diff --git a/tensorflow/contrib/cmake/tf_core_cpu.cmake b/tensorflow/contrib/cmake/tf_core_cpu.cmake index 5c01ca382fb9cc7a01a6f2b60a510c59f0aa7119..96ac60d095dbc84470ff1be92f4bf52bb420fc52 100644 --- a/tensorflow/contrib/cmake/tf_core_cpu.cmake +++ b/tensorflow/contrib/cmake/tf_core_cpu.cmake @@ -50,6 +50,12 @@ file(GLOB_RECURSE tf_core_cpu_exclude_srcs "${tensorflow_source_dir}/tensorflow/core/graph/edgeset.cc" "${tensorflow_source_dir}/tensorflow/core/graph/graph.h" "${tensorflow_source_dir}/tensorflow/core/graph/graph.cc" + "${tensorflow_source_dir}/tensorflow/core/graph/graph_def_builder.h" + "${tensorflow_source_dir}/tensorflow/core/graph/graph_def_builder.cc" + "${tensorflow_source_dir}/tensorflow/core/graph/node_builder.h" + "${tensorflow_source_dir}/tensorflow/core/graph/node_builder.cc" + "${tensorflow_source_dir}/tensorflow/core/graph/tensor_id.h" + "${tensorflow_source_dir}/tensorflow/core/graph/tensor_id.cc" "${tensorflow_source_dir}/tensorflow/core/graph/while_context.h" "${tensorflow_source_dir}/tensorflow/core/graph/while_context.cc" "${tensorflow_source_dir}/tensorflow/core/grappler/clusters/single_machine.h" @@ -63,7 +69,7 @@ if (tensorflow_ENABLE_GPU) file(GLOB_RECURSE tf_core_gpu_srcs "${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu/*.cc" "${tensorflow_source_dir}/tensorflow/core/platform/default/gpu/cupti_wrapper.cc" - "${tensorflow_source_dir}/tensorflow/core/platform/default/gpu_tracer.cc" + "${tensorflow_source_dir}/tensorflow/core/platform/default/device_tracer.cc" "${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu_device_factory.cc" "${tensorflow_source_dir}/tensorflow/core/grappler/devices.h" 
"${tensorflow_source_dir}/tensorflow/core/grappler/devices.cc" diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index c607546f4a5244fb6e7cd12db874f07a962f6f4d..a1c320347fe60f87806736befc677541a93e7e93 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -126,7 +126,9 @@ endfunction() file(GLOB_RECURSE tf_protos_cc_srcs RELATIVE ${tensorflow_source_dir} "${tensorflow_source_dir}/tensorflow/core/*.proto" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/proto/*.proto" + "${tensorflow_source_dir}/tensorflow/contrib/tpu/proto/*.proto" ) + RELATIVE_PROTOBUF_GENERATE_CPP(PROTO_SRCS PROTO_HDRS ${tensorflow_source_dir} ${tf_protos_cc_srcs} ) @@ -191,10 +193,6 @@ file(GLOB_RECURSE tf_core_lib_srcs "${tensorflow_source_dir}/tensorflow/core/lib/*.h" "${tensorflow_source_dir}/tensorflow/core/lib/*.cc" "${tensorflow_source_dir}/tensorflow/core/public/*.h" - # TODO(@jart): Move StatusOr into core. 
- "${tensorflow_source_dir}/tensorflow/compiler/xla/statusor.cc" - "${tensorflow_source_dir}/tensorflow/compiler/xla/statusor.h" - "${tensorflow_source_dir}/tensorflow/compiler/xla/statusor_internals.h" ) file(GLOB tf_core_platform_srcs @@ -211,7 +209,7 @@ if (NOT tensorflow_ENABLE_GPU) list(REMOVE_ITEM tf_core_platform_srcs ${tf_core_platform_gpu_srcs}) else() file(GLOB tf_core_platform_srcs_exclude - "${tensorflow_source_dir}/tensorflow/core/platform/default/gpu_tracer.cc") + "${tensorflow_source_dir}/tensorflow/core/platform/default/device_tracer.cc") list(REMOVE_ITEM tf_core_platform_srcs ${tf_core_platform_srcs_exclude}) endif() @@ -294,6 +292,12 @@ file(GLOB_RECURSE tf_core_framework_srcs "${tensorflow_source_dir}/tensorflow/core/graph/edgeset.cc" "${tensorflow_source_dir}/tensorflow/core/graph/graph.h" "${tensorflow_source_dir}/tensorflow/core/graph/graph.cc" + "${tensorflow_source_dir}/tensorflow/core/graph/graph_def_builder.h" + "${tensorflow_source_dir}/tensorflow/core/graph/graph_def_builder.cc" + "${tensorflow_source_dir}/tensorflow/core/graph/node_builder.h" + "${tensorflow_source_dir}/tensorflow/core/graph/node_builder.cc" + "${tensorflow_source_dir}/tensorflow/core/graph/tensor_id.h" + "${tensorflow_source_dir}/tensorflow/core/graph/tensor_id.cc" "${tensorflow_source_dir}/tensorflow/core/graph/while_context.h" "${tensorflow_source_dir}/tensorflow/core/graph/while_context.cc" "${tensorflow_source_dir}/tensorflow/core/util/*.h" @@ -317,8 +321,15 @@ file(GLOB_RECURSE tf_core_framework_exclude_srcs "${tensorflow_source_dir}/tensorflow/core/util/*test*.cc" "${tensorflow_source_dir}/tensorflow/core/util/*main.cc" "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/*test*.cc" + "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/loader.cc" + "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/vacuum.cc" ) +# TODO(jart): Why doesn't this work? 
+# set_source_files_properties( +# ${tensorflow_source_dir}/tensorflow/contrib/tensorboard/db/snapfn.cc +# PROPERTIES COMPILE_FLAGS -DSQLITE_OMIT_LOAD_EXTENSION) + list(REMOVE_ITEM tf_core_framework_srcs ${tf_core_framework_exclude_srcs}) add_library(tf_core_framework OBJECT diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index 2d015908a890fd7757bf212573f4ebce8ba8b30d..f219d5eb577afa9edaadca09aef9869c81d2bd87 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -63,10 +63,15 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/training_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder.cc" + "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops_util.cc" + "${tensorflow_source_dir}/tensorflow/contrib/coder/ops/coder_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/prefetching_kernels.cc" - "${tensorflow_source_dir}/tensorflow/contrib/data/ops/prefetching_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/data/ops/dataset_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/clustering_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/wals_solver_ops.cc" @@ -79,12 +84,15 @@ 
if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/image/kernels/bipartite_match_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/image/kernels/image_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/image/kernels/segmentation_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/image/kernels/single_image_random_dot_stereograms_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/image/ops/distort_image_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/image/ops/image_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc" "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc" + "${tensorflow_source_dir}/tensorflow/contrib/libsvm/kernels/decode_libsvm_op.cc" + "${tensorflow_source_dir}/tensorflow/contrib/libsvm/ops/libsvm_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_manager.cc" "${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/nccl/ops/nccl_ops.cc" @@ -150,9 +158,6 @@ list(REMOVE_ITEM tf_core_kernels_srcs ${tf_core_kernels_exclude_srcs}) if(WIN32) file(GLOB_RECURSE tf_core_kernels_windows_exclude_srcs # not working on windows yet - "${tensorflow_source_dir}/tensorflow/core/kernels/meta_support.*" - "${tensorflow_source_dir}/tensorflow/core/kernels/*quantiz*.h" - "${tensorflow_source_dir}/tensorflow/core/kernels/*quantiz*.cc" "${tensorflow_source_dir}/tensorflow/core/kernels/neon/*" # not in core - those are loaded dynamically as dll "${tensorflow_source_dir}/tensorflow/contrib/nearest_neighbor/kernels/hyperplane_lsh_probes.cc" diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index 
4a61ed7a3548b1992ddc71acb8a7761e252296ea..59e094812aaf4da2549d96314fc550e5635f9de8 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -15,6 +15,7 @@ set(tf_op_lib_names "audio_ops" "array_ops" + "batch_ops" "bitwise_ops" "candidate_sampling_ops" "checkpoint_ops" @@ -26,8 +27,10 @@ set(tf_op_lib_names "image_ops" "io_ops" "linalg_ops" + "list_ops" "lookup_ops" "logging_ops" + "manip_ops" "math_ops" "nn_ops" "no_op" @@ -80,8 +83,9 @@ GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_training "${tensorflow_source_dir}/ten GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_prediction "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_quantiles "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_stats_accumulator "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc") +GENERATE_CONTRIB_OP_LIBRARY(coder "${tensorflow_source_dir}/tensorflow/contrib/coder/ops/coder_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(cudnn_rnn "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc") -GENERATE_CONTRIB_OP_LIBRARY(data_prefetching "${tensorflow_source_dir}/tensorflow/contrib/data/ops/prefetching_ops.cc") +GENERATE_CONTRIB_OP_LIBRARY(data_dataset "${tensorflow_source_dir}/tensorflow/contrib/data/ops/dataset_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(factorization_clustering "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/clustering_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(factorization_factorization "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/factorization_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(framework_variable "${tensorflow_source_dir}/tensorflow/contrib/framework/ops/variable_ops.cc") @@ -92,6 +96,7 @@ GENERATE_CONTRIB_OP_LIBRARY(image_sirds "${tensorflow_source_dir}/tensorflow/con 
GENERATE_CONTRIB_OP_LIBRARY(layers_sparse_feature_cross "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc") GENERATE_CONTRIB_OP_LIBRARY(memory_stats "${tensorflow_source_dir}/tensorflow/contrib/memory_stats/ops/memory_stats_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(nccl "${tensorflow_source_dir}/tensorflow/contrib/nccl/ops/nccl_ops.cc") +GENERATE_CONTRIB_OP_LIBRARY(periodic_resample "${tensorflow_source_dir}/tensorflow/contrib/periodic_resample/ops/array_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(nearest_neighbor "${tensorflow_source_dir}/tensorflow/contrib/nearest_neighbor/ops/nearest_neighbor_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(resampler "${tensorflow_source_dir}/tensorflow/contrib/resampler/ops/resampler_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(rnn_gru "${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/gru_ops.cc") diff --git a/tensorflow/contrib/cmake/tf_core_profiler.cmake b/tensorflow/contrib/cmake/tf_core_profiler.cmake index 61ed6a1e145299125d037b48b8b644cae1ce96e7..b91a7f43e5c03e933d10572e54e0c8c914c55f71 100644 --- a/tensorflow/contrib/cmake/tf_core_profiler.cmake +++ b/tensorflow/contrib/cmake/tf_core_profiler.cmake @@ -17,6 +17,8 @@ ######################################################## file(GLOB_RECURSE tf_core_profiler_srcs "${tensorflow_source_dir}/tensorflow/core/profiler/*.proto" + "${tensorflow_source_dir}/tensorflow/core/profiler/tfprof_options.h" + "${tensorflow_source_dir}/tensorflow/core/profiler/tfprof_options.cc" "${tensorflow_source_dir}/tensorflow/core/profiler/internal/*.h" "${tensorflow_source_dir}/tensorflow/core/profiler/internal/*.cc" "${tensorflow_source_dir}/tensorflow/core/profiler/internal/advisor/*.h" diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 0128946e45ea48f47a8be0df66e498bb0240de11..b730ebd3baacafe8ae401e8987104f3062372954 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -120,33 +120,46 @@ 
function(RELATIVE_PROTOBUF_GENERATE_CPP SRCS HDRS ROOT_DIR) set(${HDRS} ${${HDRS}} PARENT_SCOPE) endfunction() -file(GLOB_RECURSE tf_protos_python_srcs RELATIVE ${tensorflow_source_dir} - "${tensorflow_source_dir}/tensorflow/core/*.proto" - "${tensorflow_source_dir}/tensorflow/core/profiler/*.proto" - "${tensorflow_source_dir}/tensorflow/python/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/proto/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/decision_trees/proto/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/session_bundle/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/proto/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/tpu/proto/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/tpu/profiler/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/training/*.proto" -) +FILE(READ python_protos.txt python_protos) +# Convert file contents into a CMake list (where each element in the list is one line of the file) +STRING(REGEX REPLACE ";" "\\\\;" python_protos "${python_protos}") +STRING(REGEX REPLACE "\n" ";" python_protos "${python_protos}") + +foreach(python_proto ${python_protos}) + if(NOT python_proto MATCHES "^\#") + STRING(REGEX REPLACE " *\#.*" "" python_proto "${python_proto}") + if(NOT EXISTS "${tensorflow_source_dir}/${python_proto}") + message(SEND_ERROR "Python proto directory not found: ${python_proto}") + endif() + file(GLOB_RECURSE tf_python_protos_src RELATIVE ${tensorflow_source_dir} + "${tensorflow_source_dir}/${python_proto}/*.proto" + ) + list(APPEND tf_python_protos_srcs ${tf_python_protos_src}) + endif() +endforeach(python_proto) + RELATIVE_PROTOBUF_GENERATE_PYTHON( - ${tensorflow_source_dir} PYTHON_PROTO_GENFILES ${tf_protos_python_srcs} + ${tensorflow_source_dir} PYTHON_PROTO_GENFILES ${tf_python_protos_srcs} ) -# NOTE(mrry): Avoid regenerating the tensorflow/core protos because this -# can cause 
benign-but-failing-on-Windows-due-to-file-locking conflicts -# when two rules attempt to generate the same file. -file(GLOB_RECURSE tf_python_protos_cc_srcs RELATIVE ${tensorflow_source_dir} - "${tensorflow_source_dir}/tensorflow/core/profiler/*.proto" - "${tensorflow_source_dir}/tensorflow/python/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/session_bundle/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/tensorboard/*.proto" - "${tensorflow_source_dir}/tensorflow/contrib/training/*.proto" -) +FILE(READ python_protos_cc.txt python_protos_cc) +# Convert file contents into a CMake list (where each element in the list is one line of the file) +STRING(REGEX REPLACE ";" "\\\\;" python_protos_cc "${python_protos_cc}") +STRING(REGEX REPLACE "\n" ";" python_protos_cc "${python_protos_cc}") + +foreach(python_proto_cc ${python_protos_cc}) + if(NOT python_proto_cc MATCHES "^\#") + STRING(REGEX REPLACE " *\#.*" "" python_proto_cc "${python_proto_cc}") + if(NOT EXISTS "${tensorflow_source_dir}/${python_proto_cc}") + message(SEND_ERROR "Python proto CC directory not found: ${python_proto_cc}") + endif() + file(GLOB_RECURSE tf_python_protos_cc_src RELATIVE ${tensorflow_source_dir} + "${tensorflow_source_dir}/${python_proto_cc}/*.proto" + ) + list(APPEND tf_python_protos_cc_srcs ${tf_python_protos_cc_src}) + endif() +endforeach(python_proto_cc) + RELATIVE_PROTOBUF_GENERATE_CPP(PROTO_SRCS PROTO_HDRS ${tensorflow_source_dir} ${tf_python_protos_cc_srcs} ) @@ -192,315 +205,21 @@ function(add_python_module MODULE_NAME) endif() endfunction() -add_python_module("tensorflow") -add_python_module("tensorflow/core") -add_python_module("tensorflow/core/example") -add_python_module("tensorflow/core/framework") -add_python_module("tensorflow/core/lib") -add_python_module("tensorflow/core/lib/core") -add_python_module("tensorflow/core/protobuf") -add_python_module("tensorflow/core/util") -add_python_module("tensorflow/examples") 
-add_python_module("tensorflow/examples/tutorials") -add_python_module("tensorflow/examples/tutorials/mnist") -add_python_module("tensorflow/python") -add_python_module("tensorflow/python/client") -add_python_module("tensorflow/python/data") -add_python_module("tensorflow/python/data/ops") -add_python_module("tensorflow/python/data/util") -add_python_module("tensorflow/python/debug") -add_python_module("tensorflow/python/debug/cli") -add_python_module("tensorflow/python/debug/examples") -add_python_module("tensorflow/python/debug/lib") -add_python_module("tensorflow/python/debug/wrappers") -add_python_module("tensorflow/python/eager") -add_python_module("tensorflow/python/estimator") -add_python_module("tensorflow/python/estimator/canned") -add_python_module("tensorflow/python/estimator/export") -add_python_module("tensorflow/python/estimator/inputs") -add_python_module("tensorflow/python/estimator/inputs/queues") -add_python_module("tensorflow/python/feature_column") -add_python_module("tensorflow/python/framework") -add_python_module("tensorflow/python/grappler") -add_python_module("tensorflow/python/keras") -add_python_module("tensorflow/python/keras/activations") -add_python_module("tensorflow/python/keras/applications") -add_python_module("tensorflow/python/keras/applications/inception_resnet_v2") -add_python_module("tensorflow/python/keras/applications/inception_v3") -add_python_module("tensorflow/python/keras/applications/mobilenet") -add_python_module("tensorflow/python/keras/applications/resnet50") -add_python_module("tensorflow/python/keras/applications/vgg16") -add_python_module("tensorflow/python/keras/applications/vgg19") -add_python_module("tensorflow/python/keras/applications/xception") -add_python_module("tensorflow/python/keras/backend") -add_python_module("tensorflow/python/keras/callbacks") -add_python_module("tensorflow/python/keras/constraints") -add_python_module("tensorflow/python/keras/datasets") 
-add_python_module("tensorflow/python/keras/datasets/boston_housing") -add_python_module("tensorflow/python/keras/datasets/cifar10") -add_python_module("tensorflow/python/keras/datasets/cifar100") -add_python_module("tensorflow/python/keras/datasets/fashion_mnist") -add_python_module("tensorflow/python/keras/datasets/imdb") -add_python_module("tensorflow/python/keras/datasets/mnist") -add_python_module("tensorflow/python/keras/datasets/reuters") -add_python_module("tensorflow/python/keras/estimator") -add_python_module("tensorflow/python/keras/initializers") -add_python_module("tensorflow/python/keras/layers") -add_python_module("tensorflow/python/keras/losses") -add_python_module("tensorflow/python/keras/metrics") -add_python_module("tensorflow/python/keras/models") -add_python_module("tensorflow/python/keras/optimizers") -add_python_module("tensorflow/python/keras/preprocessing") -add_python_module("tensorflow/python/keras/preprocessing/image") -add_python_module("tensorflow/python/keras/preprocessing/sequence") -add_python_module("tensorflow/python/keras/preprocessing/text") -add_python_module("tensorflow/python/keras/regularizers") -add_python_module("tensorflow/python/keras/utils") -add_python_module("tensorflow/python/keras/wrappers") -add_python_module("tensorflow/python/keras/wrappers/scikit_learn") -add_python_module("tensorflow/python/keras/_impl") -add_python_module("tensorflow/python/keras/_impl/keras") -add_python_module("tensorflow/python/keras/_impl/keras/applications") -add_python_module("tensorflow/python/keras/_impl/keras/datasets") -add_python_module("tensorflow/python/keras/_impl/keras/engine") -add_python_module("tensorflow/python/keras/_impl/keras/layers") -add_python_module("tensorflow/python/keras/_impl/keras/preprocessing") -add_python_module("tensorflow/python/keras/_impl/keras/utils") -add_python_module("tensorflow/python/keras/_impl/keras/wrappers") -add_python_module("tensorflow/python/kernel_tests") 
-add_python_module("tensorflow/python/kernel_tests/distributions") -add_python_module("tensorflow/python/kernel_tests/linalg") -add_python_module("tensorflow/python/layers") -add_python_module("tensorflow/python/lib") -add_python_module("tensorflow/python/lib/core") -add_python_module("tensorflow/python/lib/io") -add_python_module("tensorflow/python/ops") -add_python_module("tensorflow/python/ops/distributions") -add_python_module("tensorflow/python/ops/linalg") -add_python_module("tensorflow/python/ops/losses") -add_python_module("tensorflow/python/platform") -add_python_module("tensorflow/python/platform/default") -add_python_module("tensorflow/python/platform/summary") -add_python_module("tensorflow/python/profiler/") -add_python_module("tensorflow/python/profiler/internal") -add_python_module("tensorflow/python/saved_model") -add_python_module("tensorflow/python/summary") -add_python_module("tensorflow/python/summary/writer") -add_python_module("tensorflow/python/tools") -add_python_module("tensorflow/python/training") -add_python_module("tensorflow/python/user_ops") -add_python_module("tensorflow/python/util") -add_python_module("tensorflow/python/util/protobuf") -add_python_module("tensorflow/tools") -add_python_module("tensorflow/tools/graph_transforms") -add_python_module("tensorflow/contrib") -add_python_module("tensorflow/contrib/all_reduce") -add_python_module("tensorflow/contrib/all_reduce/python") -add_python_module("tensorflow/contrib/android") -add_python_module("tensorflow/contrib/android/java") -add_python_module("tensorflow/contrib/android/java/org") -add_python_module("tensorflow/contrib/android/java/org/tensorflow") -add_python_module("tensorflow/contrib/android/java/org/tensorflow/contrib") -add_python_module("tensorflow/contrib/android/java/org/tensorflow/contrib/android") -add_python_module("tensorflow/contrib/android/jni") -add_python_module("tensorflow/contrib/bayesflow") -add_python_module("tensorflow/contrib/bayesflow/examples") 
-add_python_module("tensorflow/contrib/bayesflow/examples/reinforce_simple") -add_python_module("tensorflow/contrib/bayesflow/python") -add_python_module("tensorflow/contrib/bayesflow/python/kernel_tests") -add_python_module("tensorflow/contrib/bayesflow/python/ops") -add_python_module("tensorflow/contrib/boosted_trees") -add_python_module("tensorflow/contrib/boosted_trees/estimator_batch") -add_python_module("tensorflow/contrib/boosted_trees/ops") -add_python_module("tensorflow/contrib/boosted_trees/proto") -add_python_module("tensorflow/contrib/boosted_trees/python") -add_python_module("tensorflow/contrib/boosted_trees/python/kernel_tests") -add_python_module("tensorflow/contrib/boosted_trees/python/ops") -add_python_module("tensorflow/contrib/cloud") -add_python_module("tensorflow/contrib/cloud/kernels") -add_python_module("tensorflow/contrib/cloud/ops") -add_python_module("tensorflow/contrib/cloud/python") -add_python_module("tensorflow/contrib/cloud/python/ops") -add_python_module("tensorflow/contrib/cluster_resolver") -add_python_module("tensorflow/contrib/cluster_resolver/python") -add_python_module("tensorflow/contrib/cluster_resolver/python/training") -add_python_module("tensorflow/contrib/compiler") -add_python_module("tensorflow/contrib/copy_graph") -add_python_module("tensorflow/contrib/copy_graph/python") -add_python_module("tensorflow/contrib/copy_graph/python/util") -add_python_module("tensorflow/contrib/crf") -add_python_module("tensorflow/contrib/crf/python") -add_python_module("tensorflow/contrib/crf/python/kernel_tests") -add_python_module("tensorflow/contrib/crf/python/ops") -add_python_module("tensorflow/contrib/cudnn_rnn") -add_python_module("tensorflow/contrib/cudnn_rnn/kernels") -add_python_module("tensorflow/contrib/cudnn_rnn/ops") -add_python_module("tensorflow/contrib/cudnn_rnn/python") -add_python_module("tensorflow/contrib/cudnn_rnn/python/kernel_tests") -add_python_module("tensorflow/contrib/cudnn_rnn/python/layers") 
-add_python_module("tensorflow/contrib/cudnn_rnn/python/ops") -add_python_module("tensorflow/contrib/data") -add_python_module("tensorflow/contrib/data/python") -add_python_module("tensorflow/contrib/data/python/kernel_tests") -add_python_module("tensorflow/contrib/data/python/ops") -add_python_module("tensorflow/contrib/decision_trees") -add_python_module("tensorflow/contrib/decision_trees/proto") -add_python_module("tensorflow/contrib/deprecated") -add_python_module("tensorflow/contrib/distributions") -add_python_module("tensorflow/contrib/distributions/python") -add_python_module("tensorflow/contrib/distributions/python/kernel_tests") -add_python_module("tensorflow/contrib/distributions/python/ops") -add_python_module("tensorflow/contrib/distributions/python/ops/bijectors") -add_python_module("tensorflow/contrib/eager") -add_python_module("tensorflow/contrib/eager/python") -add_python_module("tensorflow/contrib/estimator") -add_python_module("tensorflow/contrib/estimator/python") -add_python_module("tensorflow/contrib/estimator/python/estimator") -add_python_module("tensorflow/contrib/factorization") -add_python_module("tensorflow/contrib/factorization/examples") -add_python_module("tensorflow/contrib/factorization/kernels") -add_python_module("tensorflow/contrib/factorization/ops") -add_python_module("tensorflow/contrib/factorization/python") -add_python_module("tensorflow/contrib/factorization/python/kernel_tests") -add_python_module("tensorflow/contrib/factorization/python/ops") -add_python_module("tensorflow/contrib/ffmpeg") -add_python_module("tensorflow/contrib/ffmpeg/default") -add_python_module("tensorflow/contrib/ffmpeg/testdata") -add_python_module("tensorflow/contrib/framework") -add_python_module("tensorflow/contrib/framework/kernels") -add_python_module("tensorflow/contrib/framework/ops") -add_python_module("tensorflow/contrib/framework/python") -add_python_module("tensorflow/contrib/framework/python/framework") 
-add_python_module("tensorflow/contrib/framework/python/ops") -add_python_module("tensorflow/contrib/gan") -add_python_module("tensorflow/contrib/gan/python") -add_python_module("tensorflow/contrib/gan/python/eval") -add_python_module("tensorflow/contrib/gan/python/eval/python") -add_python_module("tensorflow/contrib/gan/python/features") -add_python_module("tensorflow/contrib/gan/python/features/python") -add_python_module("tensorflow/contrib/gan/python/estimator") -add_python_module("tensorflow/contrib/gan/python/estimator/python") -add_python_module("tensorflow/contrib/gan/python/losses") -add_python_module("tensorflow/contrib/gan/python/losses/python") -add_python_module("tensorflow/contrib/graph_editor") -add_python_module("tensorflow/contrib/graph_editor/examples") -add_python_module("tensorflow/contrib/graph_editor/tests") -add_python_module("tensorflow/contrib/grid_rnn") -add_python_module("tensorflow/contrib/grid_rnn/python") -add_python_module("tensorflow/contrib/grid_rnn/python/kernel_tests") -add_python_module("tensorflow/contrib/grid_rnn/python/ops") -add_python_module("tensorflow/contrib/hooks") -add_python_module("tensorflow/contrib/image") -add_python_module("tensorflow/contrib/image/ops") -add_python_module("tensorflow/contrib/image/python") -add_python_module("tensorflow/contrib/image/python/ops") -add_python_module("tensorflow/contrib/input_pipeline") -add_python_module("tensorflow/contrib/input_pipeline/ops") -add_python_module("tensorflow/contrib/input_pipeline/python") -add_python_module("tensorflow/contrib/input_pipeline/python/ops") -add_python_module("tensorflow/contrib/integrate") -add_python_module("tensorflow/contrib/integrate/python") -add_python_module("tensorflow/contrib/integrate/python/ops") -add_python_module("tensorflow/contrib/ios_examples") -add_python_module("tensorflow/contrib/ios_examples/benchmark") -add_python_module("tensorflow/contrib/ios_examples/benchmark/benchmark.xcodeproj") 
-add_python_module("tensorflow/contrib/ios_examples/benchmark/data") -add_python_module("tensorflow/contrib/ios_examples/camera") -add_python_module("tensorflow/contrib/ios_examples/camera/camera_example.xcodeproj") -add_python_module("tensorflow/contrib/ios_examples/camera/en.lproj") -add_python_module("tensorflow/contrib/ios_examples/simple") -add_python_module("tensorflow/contrib/ios_examples/simple/data") -add_python_module("tensorflow/contrib/ios_examples/simple/tf_ios_makefile_example.xcodeproj") -add_python_module("tensorflow/contrib/keras") -add_python_module("tensorflow/contrib/keras/api") -add_python_module("tensorflow/contrib/keras/api/keras") -add_python_module("tensorflow/contrib/keras/api/keras/activations") -add_python_module("tensorflow/contrib/keras/api/keras/applications") -add_python_module("tensorflow/contrib/keras/api/keras/applications/inception_v3") -add_python_module("tensorflow/contrib/keras/api/keras/applications/mobilenet") -add_python_module("tensorflow/contrib/keras/api/keras/applications/resnet50") -add_python_module("tensorflow/contrib/keras/api/keras/applications/vgg16") -add_python_module("tensorflow/contrib/keras/api/keras/applications/vgg19") -add_python_module("tensorflow/contrib/keras/api/keras/applications/xception") -add_python_module("tensorflow/contrib/keras/api/keras/backend") -add_python_module("tensorflow/contrib/keras/api/keras/callbacks") -add_python_module("tensorflow/contrib/keras/api/keras/constraints") -add_python_module("tensorflow/contrib/keras/api/keras/datasets") -add_python_module("tensorflow/contrib/keras/api/keras/datasets/boston_housing") -add_python_module("tensorflow/contrib/keras/api/keras/datasets/cifar10") -add_python_module("tensorflow/contrib/keras/api/keras/datasets/cifar100") -add_python_module("tensorflow/contrib/keras/api/keras/datasets/imdb") -add_python_module("tensorflow/contrib/keras/api/keras/datasets/mnist") -add_python_module("tensorflow/contrib/keras/api/keras/datasets/reuters") 
-add_python_module("tensorflow/contrib/keras/api/keras/initializers") -add_python_module("tensorflow/contrib/keras/api/keras/layers") -add_python_module("tensorflow/contrib/keras/api/keras/losses") -add_python_module("tensorflow/contrib/keras/api/keras/metrics") -add_python_module("tensorflow/contrib/keras/api/keras/models") -add_python_module("tensorflow/contrib/keras/api/keras/optimizers") -add_python_module("tensorflow/contrib/keras/api/keras/preprocessing") -add_python_module("tensorflow/contrib/keras/api/keras/preprocessing/image") -add_python_module("tensorflow/contrib/keras/api/keras/preprocessing/sequence") -add_python_module("tensorflow/contrib/keras/api/keras/preprocessing/text") -add_python_module("tensorflow/contrib/keras/api/keras/regularizers") -add_python_module("tensorflow/contrib/keras/api/keras/utils") -add_python_module("tensorflow/contrib/keras/api/keras/wrappers") -add_python_module("tensorflow/contrib/keras/api/keras/wrappers/scikit_learn") -add_python_module("tensorflow/contrib/keras/python") -add_python_module("tensorflow/contrib/keras/python/keras") -add_python_module("tensorflow/contrib/keras/python/keras/applications") -add_python_module("tensorflow/contrib/keras/python/keras/datasets") -add_python_module("tensorflow/contrib/keras/python/keras/engine") -add_python_module("tensorflow/contrib/keras/python/keras/layers") -add_python_module("tensorflow/contrib/keras/python/keras/preprocessing") -add_python_module("tensorflow/contrib/keras/python/keras/utils") -add_python_module("tensorflow/contrib/keras/python/keras/wrappers") -add_python_module("tensorflow/contrib/kernel_methods") -add_python_module("tensorflow/contrib/kernel_methods/python") -add_python_module("tensorflow/contrib/kernel_methods/python/mappers") -add_python_module("tensorflow/contrib/kfac") -add_python_module("tensorflow/contrib/kfac/examples") -add_python_module("tensorflow/contrib/kfac/python") -add_python_module("tensorflow/contrib/kfac/python/ops") 
-add_python_module("tensorflow/contrib/labeled_tensor") -add_python_module("tensorflow/contrib/labeled_tensor/python") -add_python_module("tensorflow/contrib/labeled_tensor/python/ops") -add_python_module("tensorflow/contrib/layers") -add_python_module("tensorflow/contrib/layers/kernels") -add_python_module("tensorflow/contrib/layers/ops") -add_python_module("tensorflow/contrib/layers/python") -add_python_module("tensorflow/contrib/layers/python/kernel_tests") -add_python_module("tensorflow/contrib/layers/python/layers") -add_python_module("tensorflow/contrib/layers/python/ops") -add_python_module("tensorflow/contrib/learn") -add_python_module("tensorflow/contrib/learn/python") -add_python_module("tensorflow/contrib/learn/python/learn") -add_python_module("tensorflow/contrib/learn/python/learn/dataframe") -add_python_module("tensorflow/contrib/learn/python/learn/dataframe/queues") -add_python_module("tensorflow/contrib/learn/python/learn/dataframe/transforms") -add_python_module("tensorflow/contrib/learn/python/learn/datasets") -add_python_module("tensorflow/contrib/learn/python/learn/datasets/data") -add_python_module("tensorflow/contrib/learn/python/learn/estimators") -add_python_module("tensorflow/contrib/learn/python/learn/learn_io") -add_python_module("tensorflow/contrib/learn/python/learn/ops") -add_python_module("tensorflow/contrib/learn/python/learn/preprocessing") -add_python_module("tensorflow/contrib/learn/python/learn/preprocessing/tests") -add_python_module("tensorflow/contrib/learn/python/learn/tests") -add_python_module("tensorflow/contrib/learn/python/learn/tests/dataframe") -add_python_module("tensorflow/contrib/learn/python/learn/utils") -add_python_module("tensorflow/contrib/legacy_seq2seq") -add_python_module("tensorflow/contrib/legacy_seq2seq/python") -add_python_module("tensorflow/contrib/legacy_seq2seq/python/ops") -add_python_module("tensorflow/contrib/linalg") -add_python_module("tensorflow/contrib/linalg/python") 
-add_python_module("tensorflow/contrib/linalg/python/ops") -add_python_module("tensorflow/contrib/linalg/python/kernel_tests") -add_python_module("tensorflow/contrib/linear_optimizer") -add_python_module("tensorflow/contrib/linear_optimizer/kernels") -add_python_module("tensorflow/contrib/linear_optimizer/kernels/g3doc") -add_python_module("tensorflow/contrib/linear_optimizer/python") -add_python_module("tensorflow/contrib/linear_optimizer/python/kernel_tests") -add_python_module("tensorflow/contrib/linear_optimizer/python/ops") +FILE(READ python_modules.txt python_modules) +# Convert file contents into a CMake list (where each element in the list is one line of the file) +STRING(REGEX REPLACE ";" "\\\\;" python_modules "${python_modules}") +STRING(REGEX REPLACE "\n" ";" python_modules "${python_modules}") + +foreach(python_module ${python_modules}) + if(NOT python_module MATCHES "^\#") + STRING(REGEX REPLACE " *\#.*" "" python_module "${python_module}") + if(NOT EXISTS "${tensorflow_source_dir}/${python_module}") + message(SEND_ERROR "Python module not found: ${python_module}") + endif() + add_python_module(${python_module}) + endif() +endforeach(python_module) + add_custom_command(TARGET tf_python_touchup_modules PRE_BUILD COMMAND ${CMAKE_COMMAND} -E make_directory "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/lite") @@ -514,157 +233,6 @@ add_custom_command( TARGET tf_python_copy_scripts_to_destination PRE_BUILD COMMAND ${CMAKE_COMMAND} -E touch ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/lite/python/lite.py) -add_python_module("tensorflow/contrib/lookup") -add_python_module("tensorflow/contrib/losses") -add_python_module("tensorflow/contrib/losses/python") -add_python_module("tensorflow/contrib/losses/python/losses") -add_python_module("tensorflow/contrib/losses/python/metric_learning") -add_python_module("tensorflow/contrib/makefile") -add_python_module("tensorflow/contrib/makefile/test") 
-add_python_module("tensorflow/contrib/memory_stats") -add_python_module("tensorflow/contrib/memory_stats/kernels") -add_python_module("tensorflow/contrib/memory_stats/ops") -add_python_module("tensorflow/contrib/memory_stats/python") -add_python_module("tensorflow/contrib/memory_stats/python/kernel_tests") -add_python_module("tensorflow/contrib/memory_stats/python/ops") -add_python_module("tensorflow/contrib/meta_graph_transform") -add_python_module("tensorflow/contrib/metrics") -add_python_module("tensorflow/contrib/metrics/kernels") -add_python_module("tensorflow/contrib/metrics/ops") -add_python_module("tensorflow/contrib/metrics/python") -add_python_module("tensorflow/contrib/metrics/python/kernel_tests") -add_python_module("tensorflow/contrib/metrics/python/metrics") -add_python_module("tensorflow/contrib/metrics/python/ops") -add_python_module("tensorflow/contrib/model_pruning") -add_python_module("tensorflow/contrib/model_pruning/examples") -add_python_module("tensorflow/contrib/model_pruning/examples/cifar10") -add_python_module("tensorflow/contrib/model_pruning/python") -add_python_module("tensorflow/contrib/model_pruning/python/layers") -add_python_module("tensorflow/contrib/ndlstm") -add_python_module("tensorflow/contrib/ndlstm/python") -add_python_module("tensorflow/contrib/nn") -add_python_module("tensorflow/contrib/nn/python") -add_python_module("tensorflow/contrib/nn/python/ops") -add_python_module("tensorflow/contrib/nccl") -add_python_module("tensorflow/contrib/nccl/kernels") -add_python_module("tensorflow/contrib/nccl/ops") -add_python_module("tensorflow/contrib/nccl/python") -add_python_module("tensorflow/contrib/nccl/python/ops") -add_python_module("tensorflow/contrib/nearest_neighbor/kernels") -add_python_module("tensorflow/contrib/nearest_neighbor/ops") -add_python_module("tensorflow/contrib/nearest_neighbor/python") -add_python_module("tensorflow/contrib/nearest_neighbor/python/kernel_tests") 
-add_python_module("tensorflow/contrib/nearest_neighbor/python/ops") -add_python_module("tensorflow/contrib/opt") -add_python_module("tensorflow/contrib/opt/python") -add_python_module("tensorflow/contrib/opt/python/training") -add_python_module("tensorflow/contrib/pi_examples") -add_python_module("tensorflow/contrib/pi_examples/camera") -add_python_module("tensorflow/contrib/pi_examples/label_image") -add_python_module("tensorflow/contrib/pi_examples/label_image/data") -add_python_module("tensorflow/contrib/predictor") -add_python_module("tensorflow/contrib/quantization") -add_python_module("tensorflow/contrib/quantization/python") -add_python_module("tensorflow/contrib/quantize") -add_python_module("tensorflow/contrib/quantize/python") -add_python_module("tensorflow/contrib/remote_fused_graph/pylib") -add_python_module("tensorflow/contrib/remote_fused_graph/pylib/python") -add_python_module("tensorflow/contrib/remote_fused_graph/pylib/python/ops") -add_python_module("tensorflow/contrib/resampler") -add_python_module("tensorflow/contrib/resampler/kernels") -add_python_module("tensorflow/contrib/resampler/ops") -add_python_module("tensorflow/contrib/resampler/python") -add_python_module("tensorflow/contrib/resampler/python/ops") -add_python_module("tensorflow/contrib/rnn") -add_python_module("tensorflow/contrib/rnn/kernels") -add_python_module("tensorflow/contrib/rnn/ops") -add_python_module("tensorflow/contrib/rnn/python") -add_python_module("tensorflow/contrib/rnn/python/kernel_tests") -add_python_module("tensorflow/contrib/rnn/python/ops") -add_python_module("tensorflow/contrib/saved_model") -add_python_module("tensorflow/contrib/saved_model/python") -add_python_module("tensorflow/contrib/saved_model/python/saved_model") -add_python_module("tensorflow/contrib/seq2seq") -add_python_module("tensorflow/contrib/seq2seq/kernels") -add_python_module("tensorflow/contrib/seq2seq/ops") -add_python_module("tensorflow/contrib/seq2seq/python") 
-add_python_module("tensorflow/contrib/seq2seq/python/kernel_tests") -add_python_module("tensorflow/contrib/seq2seq/python/ops") -add_python_module("tensorflow/contrib/session_bundle") -add_python_module("tensorflow/contrib/session_bundle/example") -add_python_module("tensorflow/contrib/session_bundle/testdata") -add_python_module("tensorflow/contrib/signal") -add_python_module("tensorflow/contrib/signal/python") -add_python_module("tensorflow/contrib/signal/python/ops") -add_python_module("tensorflow/contrib/slim") -add_python_module("tensorflow/contrib/slim/python") -add_python_module("tensorflow/contrib/slim/python/slim") -add_python_module("tensorflow/contrib/slim/python/slim/data") -add_python_module("tensorflow/contrib/slim/python/slim/nets") -add_python_module("tensorflow/contrib/solvers") -add_python_module("tensorflow/contrib/solvers/python") -add_python_module("tensorflow/contrib/solvers/python/ops") -add_python_module("tensorflow/contrib/sparsemax") -add_python_module("tensorflow/contrib/sparsemax/python") -add_python_module("tensorflow/contrib/sparsemax/python/ops") -add_python_module("tensorflow/contrib/specs") -add_python_module("tensorflow/contrib/specs/python") -add_python_module("tensorflow/contrib/staging") -add_python_module("tensorflow/contrib/stat_summarizer") -add_python_module("tensorflow/contrib/stateless") -add_python_module("tensorflow/contrib/tensorboard") -add_python_module("tensorflow/contrib/tensorboard/plugins") -add_python_module("tensorflow/contrib/tensorboard/plugins/projector") -add_python_module("tensorflow/contrib/tensor_forest") -add_python_module("tensorflow/contrib/tensor_forest/client") -add_python_module("tensorflow/contrib/tensor_forest/core") -add_python_module("tensorflow/contrib/tensor_forest/core/ops") -add_python_module("tensorflow/contrib/tensor_forest/data") -add_python_module("tensorflow/contrib/tensor_forest/hybrid") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/core") 
-add_python_module("tensorflow/contrib/tensor_forest/hybrid/core/ops") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/ops") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/python") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/python/kernel_tests") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/python/layers") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/python/models") -add_python_module("tensorflow/contrib/tensor_forest/hybrid/python/ops") -add_python_module("tensorflow/contrib/tensor_forest/python") -add_python_module("tensorflow/contrib/tensor_forest/python/kernel_tests") -add_python_module("tensorflow/contrib/tensor_forest/python/ops") -add_python_module("tensorflow/contrib/testing") -add_python_module("tensorflow/contrib/testing/python") -add_python_module("tensorflow/contrib/testing/python/framework") -add_python_module("tensorflow/contrib/text") -add_python_module("tensorflow/contrib/text/kernels") -add_python_module("tensorflow/contrib/text/ops") -add_python_module("tensorflow/contrib/text/python") -add_python_module("tensorflow/contrib/text/python/ops") -add_python_module("tensorflow/contrib/tfprof") -add_python_module("tensorflow/contrib/timeseries") -add_python_module("tensorflow/contrib/timeseries/examples") -add_python_module("tensorflow/contrib/timeseries/examples/data") -add_python_module("tensorflow/contrib/timeseries/python") -add_python_module("tensorflow/contrib/timeseries/python/timeseries") -add_python_module("tensorflow/contrib/timeseries/python/timeseries/state_space_models") -add_python_module("tensorflow/contrib/tpu") -add_python_module("tensorflow/contrib/tpu/ops") -add_python_module("tensorflow/contrib/tpu/profiler") -add_python_module("tensorflow/contrib/tpu/python") -add_python_module("tensorflow/contrib/tpu/python/ops") -add_python_module("tensorflow/contrib/tpu/python/profiler") -add_python_module("tensorflow/contrib/tpu/python/tpu") 
-add_python_module("tensorflow/contrib/training") -add_python_module("tensorflow/contrib/training/python") -add_python_module("tensorflow/contrib/training/python/training") -add_python_module("tensorflow/contrib/util") -add_python_module("tensorflow/contrib/reduce_slice_ops") -add_python_module("tensorflow/contrib/reduce_slice_ops/kernels") -add_python_module("tensorflow/contrib/reduce_slice_ops/ops") -add_python_module("tensorflow/contrib/reduce_slice_ops/python") -add_python_module("tensorflow/contrib/reduce_slice_ops/python/kernel_tests") -add_python_module("tensorflow/contrib/reduce_slice_ops/python/ops") -add_python_module("tensorflow/contrib/summary") # Generate the tensorflow.python.platform.build_info module. set(BUILD_INFO_PY "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/platform/build_info.py") @@ -739,7 +307,7 @@ function(GENERATE_PYTHON_OP_LIB tf_python_op_lib_name) # containing the wrappers. add_custom_command( OUTPUT ${GENERATE_PYTHON_OP_LIB_DESTINATION} - COMMAND ${tf_python_op_lib_name}_gen_python @${tensorflow_source_dir}/tensorflow/python/ops/hidden_ops.txt ${require_shape_fn} > ${GENERATE_PYTHON_OP_LIB_DESTINATION} + COMMAND ${tf_python_op_lib_name}_gen_python ${tensorflow_source_dir}/tensorflow/core/api_def/base_api,${tensorflow_source_dir}/tensorflow/core/api_def/python_api ${require_shape_fn} > ${GENERATE_PYTHON_OP_LIB_DESTINATION} DEPENDS ${tf_python_op_lib_name}_gen_python ) @@ -749,6 +317,7 @@ endfunction() GENERATE_PYTHON_OP_LIB("audio_ops") GENERATE_PYTHON_OP_LIB("array_ops") +GENERATE_PYTHON_OP_LIB("batch_ops") GENERATE_PYTHON_OP_LIB("bitwise_ops") GENERATE_PYTHON_OP_LIB("math_ops") GENERATE_PYTHON_OP_LIB("functional_ops") @@ -762,9 +331,11 @@ GENERATE_PYTHON_OP_LIB("dataset_ops") GENERATE_PYTHON_OP_LIB("image_ops") GENERATE_PYTHON_OP_LIB("io_ops") GENERATE_PYTHON_OP_LIB("linalg_ops") +GENERATE_PYTHON_OP_LIB("list_ops") GENERATE_PYTHON_OP_LIB("logging_ops") GENERATE_PYTHON_OP_LIB("lookup_ops") GENERATE_PYTHON_OP_LIB("nn_ops") 
+GENERATE_PYTHON_OP_LIB("manip_ops") GENERATE_PYTHON_OP_LIB("parsing_ops") GENERATE_PYTHON_OP_LIB("random_ops") GENERATE_PYTHON_OP_LIB("remote_fused_graph_ops" @@ -793,10 +364,12 @@ GENERATE_PYTHON_OP_LIB("contrib_boosted_trees_quantiles_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/boosted_trees/python/ops/gen_quantile_ops.py) GENERATE_PYTHON_OP_LIB("contrib_boosted_trees_stats_accumulator_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/boosted_trees/python/ops/gen_stats_accumulator_ops.py) +GENERATE_PYTHON_OP_LIB("contrib_coder_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/coder/python/ops/gen_coder_ops.py) GENERATE_PYTHON_OP_LIB("contrib_cudnn_rnn_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cudnn_rnn/ops/gen_cudnn_rnn_ops.py) -GENERATE_PYTHON_OP_LIB("contrib_data_prefetching_ops" - DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/data/python/ops/gen_prefetching_ops.py) +GENERATE_PYTHON_OP_LIB("contrib_data_dataset_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/data/python/ops/gen_dataset_ops.py) GENERATE_PYTHON_OP_LIB("contrib_factorization_clustering_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/factorization/python/ops/gen_clustering_ops.py) GENERATE_PYTHON_OP_LIB("contrib_factorization_factorization_ops" @@ -817,6 +390,9 @@ GENERATE_PYTHON_OP_LIB("contrib_memory_stats_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/memory_stats/ops/gen_memory_stats_ops.py) GENERATE_PYTHON_OP_LIB("contrib_nccl_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/nccl/ops/gen_nccl_ops.py) +GENERATE_PYTHON_OP_LIB("contrib_periodic_resample_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/periodic_resample/python/ops/gen_periodic_resample_op.py) + GENERATE_PYTHON_OP_LIB("contrib_nearest_neighbor_ops" DESTINATION 
${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/nearest_neighbor/ops/gen_nearest_neighbor_ops.py) GENERATE_PYTHON_OP_LIB("contrib_resampler_ops" @@ -889,6 +465,8 @@ set (pywrap_tensorflow_internal_src "${tensorflow_source_dir}/tensorflow/python/framework/cpp_shape_inference.cc" "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.h" "${tensorflow_source_dir}/tensorflow/python/framework/python_op_gen.cc" + "${tensorflow_source_dir}/tensorflow/python/lib/core/bfloat16.h" + "${tensorflow_source_dir}/tensorflow/python/lib/core/bfloat16.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/numpy.h" "${tensorflow_source_dir}/tensorflow/python/lib/core/numpy.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/ndarray_tensor.h" @@ -899,6 +477,8 @@ set (pywrap_tensorflow_internal_src "${tensorflow_source_dir}/tensorflow/python/lib/core/py_func.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/py_seq_tensor.h" "${tensorflow_source_dir}/tensorflow/python/lib/core/py_seq_tensor.cc" + "${tensorflow_source_dir}/tensorflow/python/lib/core/py_util.h" + "${tensorflow_source_dir}/tensorflow/python/lib/core/py_util.cc" "${tensorflow_source_dir}/tensorflow/python/lib/core/safe_ptr.h" "${tensorflow_source_dir}/tensorflow/python/lib/core/safe_ptr.cc" "${tensorflow_source_dir}/tensorflow/python/lib/io/py_record_reader.h" @@ -961,7 +541,11 @@ if(WIN32) ${nsync_STATIC_LIBRARIES} ) - set(pywrap_tensorflow_deffile "${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}/pywrap_tensorflow.def") + if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") + set(pywrap_tensorflow_deffile "${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}/pywrap_tensorflow.def") + else() + set(pywrap_tensorflow_deffile "${CMAKE_CURRENT_BINARY_DIR}/pywrap_tensorflow.def") + endif() set_source_files_properties(${pywrap_tensorflow_deffile} PROPERTIES GENERATED TRUE) add_custom_command(TARGET pywrap_tensorflow_internal_static POST_BUILD @@ -969,6 +553,7 @@ if(WIN32) --input 
"${pywrap_tensorflow_internal_static_dependencies}" --output "${pywrap_tensorflow_deffile}" --target _pywrap_tensorflow_internal.pyd + BYPRODUCTS ${pywrap_tensorflow_deffile} # Required for Ninja ) endif(WIN32) @@ -1015,6 +600,20 @@ target_link_libraries(pywrap_tensorflow_internal PRIVATE ) if(WIN32) + + # include contrib/periodic_resample as .so + # + set(tf_periodic_resample_srcs + "${tensorflow_source_dir}/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.cc" + "${tensorflow_source_dir}/tensorflow/contrib/periodic_resample/kernels/periodic_resample_op.h" + "${tensorflow_source_dir}/tensorflow/contrib/periodic_resample/ops/array_ops.cc" + ) + + AddUserOps(TARGET _periodic_resample_op + SOURCES "${tf_periodic_resample_srcs}" + DEPENDS pywrap_tensorflow_internal tf_python_ops + DISTCOPY ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/periodic_resample/python/ops/) + # include contrib/nearest_neighbor as .so # set(tf_nearest_neighbor_srcs @@ -1108,11 +707,19 @@ add_custom_command(TARGET tf_python_copy_scripts_to_destination PRE_BUILD ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/testing/python/framework/) if(WIN32) - add_custom_command(TARGET tf_python_build_pip_package POST_BUILD - COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/$(Configuration)/pywrap_tensorflow_internal.dll - ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/_pywrap_tensorflow_internal.pyd - COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/$(Configuration)/pywrap_tensorflow_internal.lib - ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/) + if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") + add_custom_command(TARGET tf_python_build_pip_package POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/$(Configuration)/pywrap_tensorflow_internal.dll + ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/_pywrap_tensorflow_internal.pyd + COMMAND ${CMAKE_COMMAND} -E copy 
${CMAKE_CURRENT_BINARY_DIR}/$(Configuration)/pywrap_tensorflow_internal.lib + ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/) + else() + add_custom_command(TARGET tf_python_build_pip_package POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/pywrap_tensorflow_internal.dll + ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/_pywrap_tensorflow_internal.pyd + COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/pywrap_tensorflow_internal.lib + ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/) + endif() else() add_custom_command(TARGET tf_python_build_pip_package POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_BINARY_DIR}/libpywrap_tensorflow_internal.so diff --git a/tensorflow/contrib/cmake/tf_shared_lib.cmake b/tensorflow/contrib/cmake/tf_shared_lib.cmake index 571d2b0decb5e9afcec2314f9837546f0974e90d..6d36d5fc5c2854b2d7d2542a3cb12e033e193b88 100644 --- a/tensorflow/contrib/cmake/tf_shared_lib.cmake +++ b/tensorflow/contrib/cmake/tf_shared_lib.cmake @@ -46,7 +46,11 @@ if(WIN32) $ ) - set(tensorflow_deffile "${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}/tensorflow.def") + if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*") + set(tensorflow_deffile "${CMAKE_CURRENT_BINARY_DIR}/${CMAKE_BUILD_TYPE}/tensorflow.def") + else() + set(tensorflow_deffile "${CMAKE_CURRENT_BINARY_DIR}/tensorflow.def") + endif() set_source_files_properties(${tensorflow_deffile} PROPERTIES GENERATED TRUE) add_custom_command(TARGET tensorflow_static POST_BUILD diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index 18b71d1f9a47b717200a01ffc368f8c8daaa1519..1c4ebd7f0c1113bcd0857fb0858df2248499f920 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -139,21 +139,27 @@ if (tensorflow_BUILD_PYTHON_TESTS) file(GLOB_RECURSE tf_test_src_py ${tf_test_rnn_src_py} + "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/*.py" 
"${tensorflow_source_dir}/tensorflow/python/debug/cli/*_test.py" "${tensorflow_source_dir}/tensorflow/python/debug/lib/*_test.py" "${tensorflow_source_dir}/tensorflow/python/debug/wrappers/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/estimator/python/estimator/*_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/*.py" "${tensorflow_source_dir}/tensorflow/python/meta_graph_transform/*_test.py" + "${tensorflow_source_dir}/tensorflow/python/ops/quantized_conv_ops_test.py" + "${tensorflow_source_dir}/tensorflow/python/ops/quantized_ops_test.py" "${tensorflow_source_dir}/tensorflow/python/platform/build_info_test.py" "${tensorflow_source_dir}/tensorflow/python/profiler/*_test.py" "${tensorflow_source_dir}/tensorflow/python/profiler/internal/*_test.py" "${tensorflow_source_dir}/tensorflow/python/saved_model/*_test.py" "${tensorflow_source_dir}/tensorflow/python/training/*_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/coder/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/data/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/factorization/*_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/feature_column/python/feature_column/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/image/*_test.py" "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/*_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/periodic_resample/python/kernel_tests/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/nearest_neighbor/python/kernel_tests/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/seq2seq/python/kernel_tests/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/stateless/python/kernel_tests/*_test.py" @@ -186,6 +192,7 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/profiler/pprof_profiler_test.py" # flaky test "${tensorflow_source_dir}/tensorflow/python/profiler/internal/run_metadata_test.py" + 
"${tensorflow_source_dir}/tensorflow/python/profiler/model_analyzer_test.py" # Fails because uses data dependencies with bazel "${tensorflow_source_dir}/tensorflow/python/saved_model/saved_model_test.py" # requires scipy @@ -216,15 +223,20 @@ if (tensorflow_BUILD_PYTHON_TESTS) # TFDBG grpc:// mode is not yet available on Windows. "${tensorflow_source_dir}/tensorflow/python/debug/lib/dist_session_debug_grpc_test.py" "${tensorflow_source_dir}/tensorflow/python/debug/lib/session_debug_grpc_test.py" + "${tensorflow_source_dir}/tensorflow/python/debug/lib/source_remote_test.py" # stl on windows handles overflows different "${tensorflow_source_dir}/tensorflow/python/kernel_tests/as_string_op_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/string_to_number_op_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/clip_ops_test.py" + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/list_ops_test.py" # Needs portpicker. "${tensorflow_source_dir}/tensorflow/python/kernel_tests/tensor_array_ops_test.py" # Needs portpicker. # Numerical issues, calculations off. "${tensorflow_source_dir}/tensorflow/python/kernel_tests/concat_op_test.py" "${tensorflow_source_dir}/tensorflow/contrib/factorization/python/ops/wals_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/periodic_resample/python/kernel_tests/periodic_resample_op_test.py" "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/utils/data_utils_test.py" + "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/backend_test.py" + "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/preprocessing/image_test.py" # Float division by zero "${tensorflow_source_dir}/tensorflow/python/kernel_tests/benchmark_test.py" # Flaky, for unknown reasons. Cannot reproduce in terminal. Revisit once we can get stack traces. 
@@ -233,11 +245,11 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/training/sync_replicas_optimizer_test.py" "${tensorflow_source_dir}/tensorflow/python/debug/lib/session_debug_grpc_test.py" "${tensorflow_source_dir}tensorflow/python/training/localhost_cluster_performance_test.py" - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/iterator_ops_cluster_test.py" + "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/functional_ops_test.py" "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py" # Type error in testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU. - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/iterator_ops_test.py" + "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/iterator_ops_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py" "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py" # IteratorGetMax OutOfRangeError @@ -261,11 +273,10 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/kernel_tests/linalg_grad_test.py" # cudaSolver handle creation fails. "${tensorflow_source_dir}/tensorflow/python/kernel_tests/array_ops_test.py" # depends on python/framework/test_ops # Dataset tests - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/dataset_constructor_op_test.py" # Segfaults on windows + "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py" # Segfaults on windows "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py" # Segfaults on Windows. - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/iterator_ops_cluster_test.py" - # Broken tensorboard test due to cmake issues. 
- "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py" # Needs portpicker + "${tensorflow_source_dir}/tensorflow/python/data/kernel_tests/iterator_ops_cluster_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py" # Deadlocks "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/sloppy_transformation_dataset_op_test.py" # b/65430561 # tensor_forest tests (also note that we exclude the hybrid tests for now) "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/count_extremely_random_stats_op_test.py" # Results in wrong order. @@ -294,6 +305,13 @@ if (tensorflow_BUILD_PYTHON_TESTS) # Test should only be run manually "${tensorflow_source_dir}/tensorflow/python/kernel_tests/reduction_ops_test_big.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/svd_op_test.py" + # Depends on python/framework/test_ops + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/array_ops_test.py" + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/control_flow_util_test.py" + # Flaky replicate_model_fn_test + "${tensorflow_source_dir}/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py" # b/71901810 + # Broken io_utils_test + "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/utils/io_utils_test.py" # b/72894325 ) endif() list(REMOVE_ITEM tf_test_src_py ${tf_test_src_py_exclude}) @@ -361,7 +379,6 @@ if (tensorflow_BUILD_CC_TESTS) "${tensorflow_source_dir}/tensorflow/core/distributed_runtime/tensor_coding_test.cc" "${tensorflow_source_dir}/tensorflow/core/kernels/remote_fused_graph_rewriter_transform_test.cc" "${tensorflow_source_dir}/tensorflow/core/kernels/hexagon/graph_transferer_test.cc" - "${tensorflow_source_dir}/tensorflow/core/kernels/hexagon/quantized_matmul_op_for_hexagon_test.cc" ) if (NOT tensorflow_ENABLE_GPU) diff --git a/tensorflow/contrib/cmake/tf_tools.cmake 
b/tensorflow/contrib/cmake/tf_tools.cmake index cb58a2e7df85b2f214654eff5547c5788592f208..58c7df95c821b4d1aa2cc63c8aaf4039518b83ca 100644 --- a/tensorflow/contrib/cmake/tf_tools.cmake +++ b/tensorflow/contrib/cmake/tf_tools.cmake @@ -48,9 +48,6 @@ file(GLOB_RECURSE tf_tools_transform_graph_lib_exclude_srcs "${tensorflow_source_dir}/tensorflow/tools/graph_transforms/compare_graphs.cc" "${tensorflow_source_dir}/tensorflow/tools/graph_transforms/summarize_graph_main.cc" "${tensorflow_source_dir}/tensorflow/tools/graph_transforms/transform_graph_main.cc" - "${tensorflow_source_dir}/tensorflow/tools/graph_transforms/quantize_nodes.cc" - "${tensorflow_source_dir}/tensorflow/tools/graph_transforms/quantize_weights.cc" - "${tensorflow_source_dir}/tensorflow/tools/graph_transforms/round_weights.cc" ) list(REMOVE_ITEM tf_tools_transform_graph_lib_srcs ${tf_tools_transform_graph_lib_exclude_srcs}) diff --git a/tensorflow/contrib/cmake/tools/create_def_file.py b/tensorflow/contrib/cmake/tools/create_def_file.py index f67698eb99a38eae307b52e55de748a67b798cbd..53c2285699a6ca94e1e6b147080338b507f4d768 100644 --- a/tensorflow/contrib/cmake/tools/create_def_file.py +++ b/tensorflow/contrib/cmake/tools/create_def_file.py @@ -31,7 +31,7 @@ from __future__ import division from __future__ import print_function import argparse -import io +import codecs import os import re import subprocess @@ -103,7 +103,7 @@ def main(): for lib_path in args.input: proc = subprocess.Popen([DUMPBIN, "/nologo", "/linkermember:1", lib_path], stdout=subprocess.PIPE) - for line in io.TextIOWrapper(proc.stdout, encoding="utf-8"): + for line in codecs.getreader("utf-8")(proc.stdout): cols = line.split() if len(cols) < 2: continue @@ -131,7 +131,7 @@ def main(): # We compare on undname but use the decorated name from candidates. 
dupes = 0 proc = subprocess.Popen([UNDNAME, tmpfile.name], stdout=subprocess.PIPE) - for idx, line in enumerate(io.TextIOWrapper(proc.stdout, encoding="utf-8")): + for idx, line in enumerate(codecs.getreader("utf-8")(proc.stdout)): decorated = candidates[idx] if decorated in taken: # Symbol is already in output, done. diff --git a/tensorflow/contrib/coder/BUILD b/tensorflow/contrib/coder/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..ec3d550b70d2aaa23b989c44f3d86fa87cffb335 --- /dev/null +++ b/tensorflow/contrib/coder/BUILD @@ -0,0 +1,167 @@ +# Description: +# Contains entropy coding related modules. + +package(default_visibility = [ + "//learning/brain:__subpackages__", + "//tensorflow:__subpackages__", +]) + +licenses(["notice"]) # Apache 2.0 + +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", + "tf_custom_op_library", + "tf_custom_op_py_library", + "tf_gen_op_libs", + "tf_gen_op_wrapper_py", + "tf_kernel_library", + "tf_py_test", +) + +cc_library( + name = "range_coder", + srcs = [ + "kernels/range_coder.cc", + ], + hdrs = [ + "kernels/range_coder.h", + ], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "range_coder_test", + size = "small", + srcs = ["kernels/range_coder_test.cc"], + deps = [ + ":range_coder", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +tf_gen_op_libs( + op_lib_names = ["coder_ops"], + deps = [ + "//tensorflow/core:lib", + ], +) + +tf_kernel_library( + name = "range_coder_ops", + srcs = [ + "kernels/range_coder_ops.cc", + "kernels/range_coder_ops_util.cc", + ], + hdrs = [ + "kernels/range_coder_ops_util.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":coder_ops_op_lib", + ":range_coder", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], + alwayslink = 1, +) + +tf_cc_test( + name = 
"range_coder_ops_test", + size = "small", + srcs = ["kernels/range_coder_ops_test.cc"], + deps = [ + ":range_coder", + ":range_coder_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:ops_testutil", + ], +) + +cc_library( + name = "all_ops", + deps = [":coder_ops_op_lib"], +) + +cc_library( + name = "all_kernels", + deps = [":range_coder_ops"], +) + +tf_custom_op_library( + name = "python/ops/_coder_ops.so", + srcs = [ + "kernels/range_coder.cc", + "kernels/range_coder.h", + "kernels/range_coder_ops.cc", + "kernels/range_coder_ops_util.cc", + "kernels/range_coder_ops_util.h", + "ops/coder_ops.cc", + ], +) + +tf_gen_op_wrapper_py( + name = "gen_coder_ops", + out = "python/ops/gen_coder_ops.py", + deps = [":coder_ops_op_lib"], +) + +tf_custom_op_py_library( + name = "coder_ops_py", + srcs = [ + "__init__.py", + "python/ops/coder_ops.py", + ], + dso = [ + ":python/ops/_coder_ops.so", + ], + kernels = [ + ":all_kernels", + ], + srcs_version = "PY2AND3", + deps = [ + ":gen_coder_ops", + "//tensorflow/contrib/util:util_py", + ], +) + +tf_py_test( + name = "coder_ops_py_test", + srcs = [ + "python/ops/coder_ops_test.py", + ], + additional_deps = [ + ":coder_ops_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + ], + main = "python/ops/coder_ops_test.py", +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), +) diff --git a/tensorflow/contrib/coder/README.md b/tensorflow/contrib/coder/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c6c379c458893551b765327c0c1cbfff7f24f9c3 --- /dev/null +++ 
b/tensorflow/contrib/coder/README.md @@ -0,0 +1,73 @@ +# Entropy coder + +This module contains a range encoder and a range decoder, which encode integer +data into a string and decode it back using cumulative distribution functions (CDF). + +## Data and CDF values + +The data to be encoded should be non-negative integers in half-open interval +`[0, m)`. Then a CDF is represented as an integral vector of length `m + 1` +where `CDF(i) = f(Pr(X < i) * 2^precision)` for i = 0,1,...,m, and `precision` +is an attribute in range `0 < precision <= 16`. The function `f` maps real +values into integers, e.g., round or floor. It is important that to encode a +number `i`, `CDF(i + 1) - CDF(i)` cannot be zero. + +Note that we used `Pr(X < i)` not `Pr(X <= i)`, and therefore CDF(0) = 0 always. + +## RangeEncode: data shapes and CDF shapes + +For each data element, its CDF has to be provided. Therefore the shape of CDF +should be `data.shape + (m + 1,)` in NumPy-like notation. For example, if `data` +is a 2-D tensor of shape (10, 10) and its elements are in `[0, 64)`, then the +CDF tensor should have shape (10, 10, 65). + +This may make the CDF tensor too large, and in many applications all data elements +may have the same probability distribution. To handle this, `RangeEncode` +supports limited broadcasting of CDF into data. Broadcasting is limited in the +following sense: + +- All CDF axes but the last one are broadcast into data but not the other way + around, +- The number of CDF axes does not extend, i.e., `CDF.ndim == data.ndim + 1`. + +In the previous example where data has shape (10, 10), the following are +acceptable CDF shapes: + +- (10, 10, 65) +- (1, 10, 65) +- (10, 1, 65) +- (1, 1, 65) + +## RangeDecode + +`RangeEncode` encodes neither data shape nor termination character. Therefore +the decoder should know how many characters are encoded into the string, and +`RangeDecode` takes the encoded data shape as the second argument. The same +shape restrictions as `RangeEncode` inputs apply here.
+ +## Example + +```python +data = tf.random_uniform((128, 128), 0, 10, dtype=tf.int32) + +histogram = tf.bincount(data, minlength=10, maxlength=10) +cdf = tf.cumsum(histogram, exclusive=False) +# CDF should have length m + 1. +cdf = tf.pad(cdf, [[1, 0]]) +# CDF axis count must be one more than data. +cdf = tf.reshape(cdf, [1, 1, -1]) + +# Note that data has 2^14 elements, and therefore the last CDF value is 2^14. +data = tf.cast(data, tf.int16) +encoded = coder.range_encode(data, cdf, precision=14) +decoded = coder.range_decode(encoded, tf.shape(data), cdf, precision=14) + +# data and decoded should be the same. +sess = tf.Session() +x, y = sess.run((data, decoded)) +assert np.all(x == y) +``` + +## Authors +Sung Jin Hwang (github: [ssjhv](https://github.com/ssjhv)) and Nick Johnston +(github: [nmjohn](https://github.com/nmjohn)) diff --git a/tensorflow/contrib/coder/__init__.py b/tensorflow/contrib/coder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b7e663e6f1359f399cdaa80e037635a8f7546b37 --- /dev/null +++ b/tensorflow/contrib/coder/__init__.py @@ -0,0 +1,26 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# ============================================================================== +"""Entropy code operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=wildcard-import +from tensorflow.contrib.coder.python.ops.coder_ops import * +# pylint: enable=wildcard-import + +from tensorflow.python.util.all_util import remove_undocumented +remove_undocumented(__name__) diff --git a/tensorflow/contrib/coder/kernels/range_coder.cc b/tensorflow/contrib/coder/kernels/range_coder.cc new file mode 100644 index 0000000000000000000000000000000000000000..21b35155ff317c6afbb1b86745f05385726505b6 --- /dev/null +++ b/tensorflow/contrib/coder/kernels/range_coder.cc @@ -0,0 +1,374 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Range coder implementation, based on [1]. +// +// [1] G. N. N. Martin, "Range coding: an algorithm for removing redundancy from +// a digitised message", presented to the Video & Data Recording Conference, +// held in Southampton, July 24-27, 1979. 
+// +#include "tensorflow/contrib/coder/kernels/range_coder.h" + +#include +#include + +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +RangeEncoder::RangeEncoder(int precision) : precision_(precision) { + CHECK_GT(precision, 0); + CHECK_LE(precision, 16); +} + +void RangeEncoder::Encode(int32 lower, int32 upper, string* sink) { + // Input requirement: 0 <= lower < upper <= 2^precision. + DCHECK_LE(0, lower); + DCHECK_LT(lower, upper); + DCHECK_LE(upper, 1 << precision_); + + // `base` and `size` represent a half-open interval [base, base + size). + // Loop invariant: 2^16 <= size <= 2^32. + // + // Note that keeping size above 2^16 is important. Since the interval sizes + // are quantized to up to 16 bits, the smallest interval size the encode may + // handle is 2^-16. If size is smaller than 2^16, a small interval input may + // collapse the encoder range into an empty interval. + const uint64 size = static_cast(size_minus1_) + 1; + DCHECK_NE(size >> 16, 0); + + // For short notation, let u := lower and v := upper. + // + // The input u, v represents a half-open interval [u, v) / 2^precision. + // This narrows the current interval roughly to + // [base + (size * u) / 2^precision, base + (size * v) / 2^precision). + // + // TODO(sjhwang): Try rounding if it helps improve compression ratio, at the + // expense of more operations. In the test using Zipf distribution, the + // overhead over the theoretical compression ratio was ~0.01%. + // NOTE: The max value of `size` is 2^32 and size > 0. Therefore `size * u` + // can be rewritten as `(size - 1) * u + u` and all the computation can be + // done in 32-bit mode. If 32-bit multiply is faster, then rewrite. 
+ const uint32 a = (size * static_cast(lower)) >> precision_; + const uint32 b = ((size * static_cast(upper)) >> precision_) - 1; + DCHECK_LE(a, b); + + // Let's confirm the RHS of a, b fit in uint32 type. + // Recall that 0 <= u < 2^precision, and size <= 2^32. Therefore + // (size * u) / 2^precision < size <= 2^32, + // and the value of a fits in uint32 type. Similarly, since v <= 2^precision, + // (size * v) / 2^precision - 1 <= size - 1 < 2^32. + // For lower bound of b, note that 1 <= v, 2^16 <= size, and precision <= 16. + // Therefore (size * v) / 2^precision - 1 >= 2^16 / 2^precision - 1 >= 0. + + // The new interval is [base + a, base + b] = [base + a, base + b + 1). + base_ += a; // May overflow. + size_minus1_ = b - a; + const bool base_overflow = (base_ < a); + + // The encoder has two states. Let's call them state 0 and state 1. + // State 0 is when base < base + size <= 2^32. + // State 1 is when base < 2^32 < base + size. + // + // The encoder initially starts in state 0, with base = 0, size = 2^32. + // + // TODO(sjhwang): Requires some profiling, but the encoder stays in state 0 + // most of the time. Should optimize code for state 0. + // + // Each Encode() has up to two places where the interval changes: + // #1. Refine the interval. [base, base + size) -> [base + a, base + b + 1). + // #2. Expand interval if the new size is too small, + // and each change may cause a state transition. + // + // First, consider when the current state is 0. + // + // In this case, the next state after #1 is always state 0, since refining + // interval only shrinks the interval, therefore new_base + new_size <= 2^32. + // + // Let us explain #2. + // + // Recall that at the beginning of each Encode(), the encoder requires + // 2^16 < size <= 2^32. As precision <= 16, the new interval size can be as + // small as 1, but never zero.
+ // + // To keep size above 2^16, if new size is smaller than or equal to 2^16, the + // encoder would left-shift base and size by 16 bits: size' <- size * 2^16. + // Note that new size' is now in the range [2^16, 2^32]. + // + // Since size is left-shifted, the same should be applied to base as well. + // However, after the left-shift, base will then contain 48 bits instead of 32 + // bits. Therefore prior to the shift, The upper 16 bits in base should be + // stored somewhere else. + // + // If the upper 16 bits of all values in the interval were the same, i.e., if + // base[32:16] == (base + size - 1)[32:16], then base[32:16] can be written + // out to `output` string, since any further Encode() only narrows down the + // interval and that 16 bits would never change. + // + // If the upper 16 bits were not all the same, since this happens only when + // size <= 2^16, the upper 16 bits may differ only by one, i.e., + // base[32:16] + 1 == (base + size - 1)[32:16]. At this stage, it is not + // determined yet whether base[32:16] should be written to the output or + // (base[32:16] + 1) should be written to the output. In this case, + // (base[32:16] + 1) is temporarily stored in `delay`, and base is + // left-shifted by 16 bits. + // + // In the latter case, the condition implies that (base // 2^16) and + // ((base + size - 1) // 2^16) were different. Therefore after left-shift by + // 16 bits, the new (base + size) is greater than 2^32, i.e., the encoder + // transition to state 1. + // + // ==== Summary ==== + // To detect the current encoder state, + // state 0: delay == 0 iff (base mod 2^32) < (base + size) mod 2^32, + // state 1: delay != 0 iff (base + size) mod 2^32 <= base mod 2^32, + // because size <= 2^32. + // + // ==== Summary for state 0 ==== + // 1. Interval refinement does not cause state transition. + // 2. Interval expansion may cause state transition, depending on the upper 16 + // bits of base and base + size - 1. 
+ // + // Now suppose the previous state was 1. This means that + // base <= 2^32 < base + size. + // + // When in state 1, an interval refinement may trigger state transition. + // After Encode() refines the interval, there are three possibilities: + // #1. base <= 2^32 < base + size (unchanged), + // #2. 2^32 <= base < base + size (base overflowed), + // #3. base < base + size <= 2^32 (base + size - 1 underflowed). + // + // In case #1, the encoder remains in state 1. + // In case #2 or #3, the encoder state changes to state 0. + // + // ==== State transition for interval refinement ==== + // 1. state 0 -> state 0, + // 2. state 1 -> state 0 or state 1. + // + // Therefore if the new state is 1, then the previous state must have been + // state 1. + if (base_ + size_minus1_ < base_) { + // If statement checked if 2^32 < base + size. The new state is 1, hence the + // previous state was also state 1. + DCHECK_NE(((base_ - a) + size) >> 32, 0); + DCHECK_NE(delay_ & 0xFFFF, 0); + + // Like in state 0, if the new size is <= 2^16, then base and size should + // be left-shifted by 16 bits. Combine the conditions + // base <= 2^32 < base + size and size <= 2^16 to conclude that + // base[32:16] >= 0xFFFF and (base + size - 1)[32:16] = 0x0000. + // + // Note that 2^32 - base < size, and since base is at least 0xFFFF0000, + // 2^16 - base[16:0] < size. Let base' and size' be the new base and size + // after the bit-shift. Then 2^32 - base' < size' => 2^32 < base' + size'. + // Therefore the encoder remains in state 1. + // + // Lastly, `delay` is modified. Conceptually, delay has to be changed to + // delay' <- delay * 2^16 + (base + size - 1)[32:16]. + // Since we know above that (base + size - 1)[32:16] = 0x0000, there is no + // need to explicitly do the computation above, but rather store how many + // trailing zeros there were. 
For this reason, the lower 16 bits of + // `delay` stores the delayed value when state changed from 0 to 1, and + // delay[32:16] stores the # of trailing zeros (in bytes). + // + // ==== State transition for interval expansion ==== + // 1. state 0 -> state 0 or state 1, + // 2. state 1 -> state 1. + if (size_minus1_ >> 16 == 0) { + DCHECK_EQ(base_ >> 16, 0xFFFF); + base_ <<= 16; + size_minus1_ <<= 16; + size_minus1_ |= 0xFFFF; + // TODO(sjhwang): It is possible that for a very long input, delay + // overflows in the addition below. If overflow is detected, the delay is too + // long and the encoder should forcefully move to state 0. In such case, + // base can be raised to 2^32 (force case #2), or (base + size) can be + // lowered to 2^32 (force case #3), depending on which transition + // keeps size larger. + CHECK_LT(delay_, static_cast(1) << 62); + delay_ += 0x20000; // Two more bytes of zeros. Check overflow? + } + return; + } + + // If reached here, the current state is 0. + // First handle the case when the previous state was state 1. + if (delay_ != 0) { + // In case #2 or #3, the encoder state changes to state 0. Recall that when + // the encoder state changed from state 0 to state 1, the top 16 bits of + // (base + size - 1) was temporarily stored in `delay`, because the output + // could be either (delay - 1) or (delay). + // + // And from above, the delayed value encoded in `delay` is + // delay' <- delay[16:0] * 2^(8 * delay[MAX:16]) + // + // In case #2, base overflowed and the interval moved above 2^32. So delay' + // is the converged value after interval refinements. Write out + // delay[16:0] and write delay[MAX:16] bytes of 0x00. + // + // In case #3, the interval moved below 2^32. So (delay' - 1) is the + // converged value after interval refinement. Write out (delay[16:0] - 1) + // and write delay[MAX:16] bytes of 0xFF. + if (base_overflow) { + // Case #2.
+ DCHECK_NE((static_cast(base_ - a) + a) >> 32, 0); + sink->push_back(static_cast(delay_ >> 8)); + sink->push_back(static_cast(delay_ >> 0)); + sink->append(delay_ >> 16, static_cast(0)); + } else { + // Case #3. + DCHECK_EQ(static_cast(base_ + size_minus1_) >> 32, 0); + --delay_; + sink->push_back(static_cast(delay_ >> 8)); + sink->push_back(static_cast(delay_ >> 0)); + sink->append(delay_ >> 16, static_cast(0xFF)); + } + // Reset to state 0. + delay_ = 0; + } + + if (size_minus1_ >> 16 == 0) { + const uint32 top = base_ >> 16; + + base_ <<= 16; + size_minus1_ <<= 16; + size_minus1_ |= 0xFFFF; + + if (base_ <= base_ + size_minus1_) { + // Still in state 0. Write the top 16 bits. + sink->push_back(static_cast(top >> 8)); + sink->push_back(static_cast(top)); + } else { + // New state is 1. + DCHECK_LT(top, 0xFFFF); + delay_ = top + 1; + } + } +} + +void RangeEncoder::Finalize(string* sink) { + // Finalize the encode by writing out any number in the interval + // [base, base + size). + // + // Trailing zeros are not explicitly written out as decoder can fill in zeros + // by default. + if (delay_ != 0) { + // The last state was state 1. Since base < 2^32 < base + size, pick 2^32 + // (state 1, case #2). + // NOTE: It is a bit difficult to trigger this code path on purpose. + // TODO(sjhwang): Find a way to trigger this code path for test coverage. + sink->push_back(static_cast(delay_ >> 8)); + if ((delay_ & 0xFF) != 0) { + sink->push_back(static_cast(delay_)); + } + } else if (base_ != 0) { + // If base == 0, then pick 0 from [base, base + size) and no zeros are + // explicitly written. + // + // Otherwise, pick (base + (2^16 - base[16:0])), i.e., round up base to the + // next multiple of 2^16. As 2^16 < size, this value should be in the + // interval [base, base + size).
+ const uint32 mid = ((base_ - 1) >> 16) + 1; + DCHECK_EQ(mid & 0xFFFF, mid); + sink->push_back(static_cast(mid >> 8)); + if ((mid & 0xFF) != 0) { + sink->push_back(static_cast(mid >> 0)); + } + } + + base_ = 0; + size_minus1_ = std::numeric_limits::max(); + delay_ = 0; +} + +RangeDecoder::RangeDecoder(const string& source, int precision) + : current_(source.begin()), + begin_(source.begin()), + end_(source.end()), + precision_(precision) { + CHECK_LE(precision, 16); + + Read16BitValue(); + Read16BitValue(); +} + +int32 RangeDecoder::Decode(tensorflow::gtl::ArraySlice cdf) { + const uint64 size = static_cast(size_minus1_) + 1; + const uint64 offset = + ((static_cast(value_ - base_) + 1) << precision_) - 1; + + // This is similar to std::lower_bound() with std::less_equal as comparison. + // After the binary search, `pv` points to the smallest number v that + // satisfies offset < (size * v) / 2^precision. + + // Assumes that cdf[0] == 0. Therefore (size * cdf[0]) / 2^precision is always + // less than or equal to offset. + const int32* pv = cdf.data() + 1; + // `len` can be cdf.size() - 2 if there is guarantee that the last element of + // cdf is 2^precision. + auto len = cdf.size() - 1; + DCHECK_GT(len, 0); + + do { + const auto half = len / 2; + const int32* mid = pv + half; + DCHECK_GE(*mid, 0); + DCHECK_LE(*mid, 1 << precision_); + if (size * static_cast(*mid) <= offset) { + pv = mid + 1; + len -= half + 1; + } else { + len = half; + } + } while (len > 0); + + // If (size * v) / 2^precision <= offset for all v in cdf, then pv points to + // one after the last element of cdf. That is a decoding error. + // + // TODO(sjhwang): Consider returning -1 to indicate error. Or start len = + // cdf.size() - 2 instead and give up detecting this error.
+ CHECK_LT(pv, cdf.data() + cdf.size()); + + const uint32 a = (size * static_cast(*(pv - 1))) >> precision_; + const uint32 b = ((size * static_cast(*pv)) >> precision_) - 1; + DCHECK_LE(a, offset >> precision_); + DCHECK_LE(offset >> precision_, b); + + base_ += a; + size_minus1_ = b - a; + + if (size_minus1_ >> 16 == 0) { + base_ <<= 16; + size_minus1_ <<= 16; + size_minus1_ |= 0xFFFF; + + Read16BitValue(); + } + + return pv - cdf.data() - 1; +} + +void RangeDecoder::Read16BitValue() { + value_ <<= 8; + if (current_ != end_) { + value_ |= static_cast(*current_++); + } + value_ <<= 8; + if (current_ != end_) { + value_ |= static_cast(*current_++); + } +} +} // namespace tensorflow diff --git a/tensorflow/contrib/coder/kernels/range_coder.h b/tensorflow/contrib/coder/kernels/range_coder.h new file mode 100644 index 0000000000000000000000000000000000000000..f46413072e34a55128d7854b9c312dfdde457d85 --- /dev/null +++ b/tensorflow/contrib/coder/kernels/range_coder.h @@ -0,0 +1,109 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_CODER_KERNELS_RANGE_CODER_H_ +#define TENSORFLOW_CONTRIB_CODER_KERNELS_RANGE_CODER_H_ + +#include +#include + +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +class RangeEncoder { + public: + // `precision` determines the granularity of probability masses passed to + // Encode() function below. + // + // REQUIRES: 0 < precision <= 16. + explicit RangeEncoder(int precision); + + // Encodes a half-open interval [lower / 2^precision, upper / 2^precision). + // Suppose each character to be encoded is from an integer-valued + // distribution. When encoding a random character x0, the arguments lower and + // upper represent + // Pr(X < x0) = lower / 2^precision, + // Pr(X < x0 + 1) = upper / 2^precision, + // where X is a random variable following the distribution. + // + // For example, assume that the distribution has possible outputs 0, 1, 2, ... + // To encode value 0, lower = 0 and upper = Pr(X = 0). + // To encode value 1, lower = Pr(X = 0) and upper = Pr(X = 0 or 1). + // To encode value 2, lower = Pr(X = 0 or 1) and upper = Pr(X = 0, 1, or 2). + // ... + // + // REQUIRES: 0 <= lower < upper <= 2^precision. + void Encode(int32 lower, int32 upper, string* sink); + + // The encode may contain some under-determined values from previous encoding. + // After Encode() calls, Finalize() must be called. Otherwise the encoded + // string may not be decoded. + void Finalize(string* sink); + + private: + uint32 base_ = 0; + uint32 size_minus1_ = std::numeric_limits::max(); + uint64 delay_ = 0; + + const int precision_; +}; + +class RangeDecoder { + public: + // Holds a reference to `source`. The caller has to make sure that `source` + // outlives the decoder object. + // + // REQUIRES: `precision` must be the same as the encoder's precision. + // REQUIRES: 0 < precision <= 16. 
+ RangeDecoder(const string& source, int precision); + + // Decodes a character from `source` using CDF. The size of `cdf` should be + // one more than the number of the character in the alphabet. + // + // If x0, x1, x2, ... are the possible characters (in increasing order) from + // the distribution, then + // cdf[0] = 0 + // cdf[1] = Pr(X <= x0), + // cdf[2] = Pr(X <= x1), + // cdf[3] = Pr(X <= x2), + // ... + // + // The returned value is an index to `cdf` where the decoded character + // corresponds to. + // + // REQUIRES: cdf.size() > 1. + // REQUIRES: cdf[i] <= cdf[i + 1] for i = 0, 1, ..., cdf.size() - 2. + // REQUIRES: cdf[cdf.size() - 1] <= 2^precision. + // + // In practice the last element of `cdf` should equal to 2^precision. + int32 Decode(gtl::ArraySlice cdf); + + private: + void Read16BitValue(); + + uint32 base_ = 0; + uint32 size_minus1_ = std::numeric_limits::max(); + uint32 value_ = 0; + + string::const_iterator current_; + const string::const_iterator begin_; + const string::const_iterator end_; + + const int precision_; +}; +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_CODER_KERNELS_RANGE_CODER_H_ diff --git a/tensorflow/contrib/coder/kernels/range_coder_ops.cc b/tensorflow/contrib/coder/kernels/range_coder_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..cde7982530fea6407aaf074f7af4a22263d50da3 --- /dev/null +++ b/tensorflow/contrib/coder/kernels/range_coder_ops.cc @@ -0,0 +1,307 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/coder/kernels/range_coder.h" +#include "tensorflow/contrib/coder/kernels/range_coder_ops_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace { +// A helper class to iterate over data and cdf simultaneously, while cdf is +// broadcasted to data. +// NOTE: Moving this class out of anonymous namespace impacts compiler +// optimization and affects performance. When moving this code around (e.g., +// into a library header), be sure to check the benchmark tests. +template +class BroadcastRange { + public: + BroadcastRange(T* data_pointer, gtl::ArraySlice data_shape, + const U* cdf_pointer, gtl::ArraySlice cdf_shape) + : data_pointer_(data_pointer), cdf_pointer_(cdf_pointer) { + CHECK(!data_shape.empty()); + CHECK_EQ(data_shape.size(), N); + CHECK_EQ(cdf_shape.size(), N + 1); + + std::copy(data_shape.begin(), data_shape.end(), &data_shape_[0]); + data_index_.fill(0); + + const int64 innermost_stride = cdf_shape[N]; + cdf_displace_.fill(innermost_stride); + + // Pre-compute the pointer displacement for cdf. 
+ int64 stride = innermost_stride; + for (int i = N - 1; i >= 0; --i) { + const bool broadcasting = (cdf_shape[i] <= 1); + + // When the data linear index advances by one, the cdf linear index + // advances by `innermost_stride`. + // + // Suppose that the i-th axis coordinate of data increased by one, and + // that i-th axis is broadcasting. The cdf linear index should be wound + // back by i-th axis stride, so that i-th axis coordinate of cdf is + // effectively kept at 0. + if (broadcasting) { + cdf_displace_[i] -= stride; + } + stride *= cdf_shape[i]; + } + } + + // Returns the pointers to the current iterating locations to data and cdf + // tensors. + // + // Note that this function does not track whether data pointer is running past + // the end of data buffer. The caller has to make sure Next() is called no + // more than that. + std::pair Next() { + std::pair return_value = {data_pointer_, cdf_pointer_}; + + int i = N - 1; + for (; i > 0; --i) { + ++data_index_[i]; + if (data_index_[i] < data_shape_[i]) { + break; + } + data_index_[i] = 0; + } + + // Advance data pointer by one. + data_pointer_ += 1; + + // For cdf pointer, it's more complicated because of broadcasting. When i-th + // coordinate increase by one, and if i-th axis is broadcasting, then we + // need to rewind back the pointer so that the effective i-th axis + // coordinate for cdf is always 0. This value is precomputed as + // cdf_displace_. 
+ cdf_pointer_ += cdf_displace_[i]; + return return_value; + } + + private: + std::array data_shape_; + std::array cdf_displace_; + std::array data_index_; + + T* data_pointer_; + const U* cdf_pointer_; +}; + +Status CheckCdfShape(const TensorShape& data_shape, + const TensorShape& cdf_shape) { + if (TF_PREDICT_FALSE(cdf_shape.dims() != data_shape.dims() + 1)) { + return errors::InvalidArgument( + "`cdf` should have one more axis than `data`: data shape=", + data_shape.DebugString(), ", cdf shape=", cdf_shape.DebugString()); + } + + if (TF_PREDICT_FALSE(cdf_shape.dim_size(cdf_shape.dims() - 1) <= 1)) { + return errors::InvalidArgument( + "The last dimension of `cdf` should be > 1: ", cdf_shape.DebugString()); + } + + return Status::OK(); +} + +// Non-incremental encoder op ------------------------------------------------- +class RangeEncodeOp : public OpKernel { + public: + explicit RangeEncodeOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("precision", &precision_)); + OP_REQUIRES(context, 0 < precision_ && precision_ <= 16, + errors::InvalidArgument("`precision` must be in [1, 16]: ", + precision_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& data = context->input(0); + const Tensor& cdf = context->input(1); + + OP_REQUIRES_OK(context, CheckCdfShape(data.shape(), cdf.shape())); + + std::vector data_shape, cdf_shape; + OP_REQUIRES_OK( + context, MergeAxes(data.shape(), cdf.shape(), &data_shape, &cdf_shape)); + + Tensor* output_tensor; + OP_REQUIRES_OK(context, + context->allocate_output(0, TensorShape{}, &output_tensor)); + string* output = &output_tensor->scalar()(); + + switch (data_shape.size()) { +#define RANGE_ENCODE_CASE(dims) \ + case dims: { \ + RangeEncodeImpl(data.flat(), data_shape, \ + cdf.flat_inner_dims(), cdf_shape, output); \ + } break + RANGE_ENCODE_CASE(1); + RANGE_ENCODE_CASE(2); + RANGE_ENCODE_CASE(3); + RANGE_ENCODE_CASE(4); + RANGE_ENCODE_CASE(5); + 
RANGE_ENCODE_CASE(6); +#undef RANGE_ENCODE_CASE + default: + context->CtxFailure(errors::InvalidArgument( + "Irregular broadcast pattern: ", data.shape().DebugString(), ", ", + cdf.shape().DebugString())); + return; + } + } + + private: + template + void RangeEncodeImpl(TTypes::ConstFlat data, + gtl::ArraySlice data_shape, + TTypes::ConstMatrix cdf, + gtl::ArraySlice cdf_shape, string* output) const { + const int64 data_size = data.size(); + const int64 cdf_size = cdf.size(); + const int64 chip_size = cdf.dimension(1); + + BroadcastRange view{data.data(), data_shape, + cdf.data(), cdf_shape}; + RangeEncoder encoder{precision_}; + for (int64 linear = 0; linear < data_size; ++linear) { + const auto pair = view.Next(); + + const int64 index = *pair.first; + DCHECK_GE(index, 0); + DCHECK_LT(index + 1, chip_size); + + const int32* cdf_slice = pair.second; + DCHECK_LE(cdf_slice + chip_size, cdf.data() + cdf_size); + + const int32 lower = cdf_slice[index]; + const int32 upper = cdf_slice[index + 1]; + encoder.Encode(lower, upper, output); + } + + encoder.Finalize(output); + } + + int precision_; +}; + +REGISTER_KERNEL_BUILDER(Name("RangeEncode").Device(DEVICE_CPU), RangeEncodeOp); + +// Non-incremental decoder op ------------------------------------------------- +class RangeDecodeOp : public OpKernel { + public: + explicit RangeDecodeOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr("precision", &precision_)); + OP_REQUIRES(context, 0 < precision_ && precision_ <= 16, + errors::InvalidArgument("`precision` must be in [1, 16]: ", + precision_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& encoded_tensor = context->input(0); + const Tensor& shape = context->input(1); + const Tensor& cdf = context->input(2); + + OP_REQUIRES(context, TensorShapeUtils::IsScalar(encoded_tensor.shape()), + errors::InvalidArgument("Invalid `encoded` shape: ", + encoded_tensor.shape().DebugString())); + 
OP_REQUIRES(context, TensorShapeUtils::IsVector(shape.shape()), + errors::InvalidArgument("Invalid `shape` shape: ", + shape.shape().DebugString())); + TensorShape output_shape; + OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(shape.vec(), + &output_shape)); + OP_REQUIRES_OK(context, CheckCdfShape(output_shape, cdf.shape())); + + std::vector data_shape, cdf_shape; + OP_REQUIRES_OK( + context, MergeAxes(output_shape, cdf.shape(), &data_shape, &cdf_shape)); + + const string& encoded = encoded_tensor.scalar()(); + + Tensor* output; + OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); + + switch (data_shape.size()) { +#define RANGE_DECODE_CASE(dim) \ + case dim: { \ + RangeDecodeImpl(output->flat(), data_shape, \ + cdf.flat_inner_dims(), cdf_shape, encoded); \ + } break + RANGE_DECODE_CASE(1); + RANGE_DECODE_CASE(2); + RANGE_DECODE_CASE(3); + RANGE_DECODE_CASE(4); + RANGE_DECODE_CASE(5); + RANGE_DECODE_CASE(6); +#undef RANGE_DECODE_CASE + default: + context->CtxFailure(errors::InvalidArgument( + "Irregular broadcast pattern: ", output_shape.DebugString(), ", ", + cdf.shape().DebugString())); + return; + } + } + + private: + template + void RangeDecodeImpl(TTypes::Flat output, + gtl::ArraySlice output_shape, + TTypes::ConstMatrix cdf, + gtl::ArraySlice cdf_shape, + const string& encoded) const { + BroadcastRange view{output.data(), output_shape, + cdf.data(), cdf_shape}; + + RangeDecoder decoder{encoded, precision_}; + + const int64 output_size = output.size(); + const int64 cdf_size = cdf.size(); + const auto chip_size = + static_cast::size_type>(cdf.dimension(1)); + + for (int64 i = 0; i < output_size; ++i) { + const auto pair = view.Next(); + + int16* data = pair.first; + DCHECK_LT(data, output.data() + output_size); + + const int32* cdf_slice = pair.second; + DCHECK_LE(cdf_slice + chip_size, cdf.data() + cdf_size); + + *data = decoder.Decode(gtl::ArraySlice{cdf_slice, chip_size}); + } + } + + int precision_; +}; + 
+REGISTER_KERNEL_BUILDER(Name("RangeDecode").Device(DEVICE_CPU), RangeDecodeOp); +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/coder/kernels/range_coder_ops_test.cc b/tensorflow/contrib/coder/kernels/range_coder_ops_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ae4d9d2836a0f89a9765004a85bc3c292b0e484f --- /dev/null +++ b/tensorflow/contrib/coder/kernels/range_coder_ops_test.cc @@ -0,0 +1,521 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "tensorflow/contrib/coder/kernels/range_coder.h" +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/common_runtime/shape_refiner.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/graph/testlib.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/lib/core/bits.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { +namespace { +int LogUniform(random::SimplePhilox* gen, uint32 n) { + CHECK_GT(n, 0); + + // Split [0, n) into {0}, [1, 2), [2, 4), [4, 8), ..., [2^(m-1), n). + const int m = Log2Ceiling(n); + + int outcome; + do { + // Uniform() consumes at least 32 bits per call, therefore this is somewhat + // wasteful implementation. Since this is used only for test, we do not + // refine this implementation further. + const int k = gen->Uniform(m + 1) - 1; + // If k == -1, then sample from {0}. + // If k == 0, then sample from [1, 2). + // If k == 1, then sample from [2, 4), ... and so on. 
+ if (k < 1) { + outcome = k + 1; + } else { + outcome = (1 << k) + gen->Uniform(1 << k); + } + } while (n <= outcome); + return outcome; +} + +std::vector ComputeStrides(const TensorShape& shape) { + std::vector stride(shape.dims()); + int64 current = 1; + for (int i = shape.dims() - 1; i >= 0; --i) { + stride[i] = current; + current *= shape.dim_size(i); + } + return stride; +} + +class RangeCoderOpsTest : public OpsTestBase { + protected: + Status RunEncodeOp(int precision, gtl::ArraySlice input, + Tensor* output) { + TF_RETURN_IF_ERROR(NodeDefBuilder("encode", "RangeEncode") + .Input(tensorflow::FakeInput(DT_INT16)) + .Input(tensorflow::FakeInput(DT_INT32)) + .Attr("precision", precision) + .Finalize(node_def())); + TF_RETURN_IF_ERROR(InitOp()); + + inputs_.clear(); + std::vector copies(input.size()); + for (int i = 0; i < input.size(); ++i) { + copies[i] = input[i]; + inputs_.emplace_back(&copies[i]); + } + + TF_RETURN_IF_ERROR(RunOpKernel()); + + *output = *GetOutput(0); + inputs_.clear(); + + return Status::OK(); + } + + Status RunDecodeOp(int precision, gtl::ArraySlice input, + Tensor* output) { + TF_RETURN_IF_ERROR(NodeDefBuilder("decode", "RangeDecode") + .Input(tensorflow::FakeInput(DT_STRING)) + .Input(tensorflow::FakeInput(DT_INT32)) + .Input(tensorflow::FakeInput(DT_INT32)) + .Attr("precision", precision) + .Finalize(node_def())); + TF_RETURN_IF_ERROR(InitOp()); + + inputs_.clear(); + std::vector copies(input.size()); + for (int i = 0; i < input.size(); ++i) { + copies[i] = input[i]; + inputs_.emplace_back(&copies[i]); + } + + TF_RETURN_IF_ERROR(RunOpKernel()); + + *output = *GetOutput(0); + inputs_.clear(); + + return Status::OK(); + } + + void TestEncodeAndDecode(int precision, const Tensor& data, + const Tensor& cdf) { + Tensor encoded; + TF_ASSERT_OK(RunEncodeOp(precision, {data, cdf}, &encoded)); + + const TensorShape& data_shape = data.shape(); + Tensor shape{DT_INT32, {data_shape.dims()}}; + for (int i = 0; i < data_shape.dims(); ++i) { + 
shape.flat()(i) = data_shape.dim_size(i); + } + + Tensor decoded; + TF_ASSERT_OK(RunDecodeOp(precision, {encoded, shape, cdf}, &decoded)); + + EXPECT_EQ(decoded.dtype(), data.dtype()); + EXPECT_EQ(decoded.shape(), data.shape()); + EXPECT_EQ(decoded.tensor_data(), data.tensor_data()); + } + + void PopulateMaxValues(random::SimplePhilox* gen, Tensor* maxvalue_tensor, + int min_maxvalue, int max_maxvalue) { + const int range = max_maxvalue - min_maxvalue; + TTypes::Flat flat = maxvalue_tensor->flat(); + + for (int64 i = 0; i < flat.size(); ++i) { + flat(i) = min_maxvalue + gen->Uniform(range); + } + } + + void BuildCdf(random::SimplePhilox* gen, Tensor* data_tensor, + Tensor* cdf_tensor, const Tensor& maxvalue_tensor) { + CHECK(TensorShapeUtils::StartsWith(cdf_tensor->shape(), + maxvalue_tensor.shape())); + CHECK_EQ(cdf_tensor->dims(), maxvalue_tensor.dims() + 1); + const int64 chip_size = cdf_tensor->dim_size(cdf_tensor->dims() - 1); + + std::vector data_stride = ComputeStrides(data_tensor->shape()); + std::vector cdf_stride = ComputeStrides(cdf_tensor->shape()); + + for (int i = 0; i < cdf_tensor->dims(); ++i) { + if (cdf_tensor->dim_size(i) == 1) { + cdf_stride[i] = 0; + } + } + + Tensor histogram_tensor{DT_INT32, cdf_tensor->shape()}; + TTypes::Flat data = data_tensor->flat(); + TTypes::Flat histogram = histogram_tensor.flat(); + TTypes::ConstFlat maxvalue = maxvalue_tensor.flat(); + histogram.setZero(); + + for (int64 index = 0; index < data.size(); ++index) { + int64 temp = index; + int64 offset = 0; + for (int dim = 0; dim < data_stride.size(); ++dim) { + const int64 coord = temp / data_stride[dim]; + offset += coord * cdf_stride[dim]; + temp -= coord * data_stride[dim]; + } + ASSERT_EQ(temp, 0); + + const int64 maxvalue_offset = offset / chip_size; + CHECK_EQ(maxvalue_offset * chip_size, offset); + CHECK_LT(maxvalue(maxvalue_offset) + 1, chip_size); + const int value = LogUniform(gen, maxvalue(maxvalue_offset)); + data(index) = value; + histogram(offset + 
value + 1) += 1; + } + + cdf_tensor->flat_inner_dims() = + histogram_tensor.flat_inner_dims().cumsum(1); + } +}; + +TEST_F(RangeCoderOpsTest, NoBroadcast) { + constexpr int kPrecision = 14; + constexpr int kMaxValue = 10; + + Tensor data{DT_INT16, {1, 32, 32, 16}}; + Tensor temp{DT_INT32, {1, 1, 1, 1, kMaxValue + 2}}; + Tensor maxvalue{DT_INT16, {1, 1, 1, 1}}; + maxvalue.flat()(0) = kMaxValue; + + ASSERT_LE(data.shape().num_elements(), 1 << kPrecision); + + random::PhiloxRandom philox(random::New64(), random::New64()); + random::SimplePhilox gen(&philox); + BuildCdf(&gen, &data, &temp, maxvalue); + + const Eigen::array broadcast = {1, 32, 32, 16, 1}; + + Tensor cdf{DT_INT32, {1, 32, 32, 16, kMaxValue + 2}}; + cdf.tensor() = temp.tensor().broadcast(broadcast); + + TestEncodeAndDecode(kPrecision, data, cdf); +} + +TEST_F(RangeCoderOpsTest, Broadcast1Axis) { + constexpr int kPrecision = 9; + constexpr int kDimensionSize = 1 << kPrecision; + constexpr int kMinMaxValue = 10; + constexpr int kMaxMaxValue = 64; + + random::PhiloxRandom philox(random::New64(), random::New64()); + random::SimplePhilox gen(&philox); + Tensor data{DT_INT16, {1, kDimensionSize, kDimensionSize}}; + + Tensor maxvalue{DT_INT16, {kDimensionSize}}; + PopulateMaxValues(&gen, &maxvalue, kMinMaxValue, kMaxMaxValue); + + { + // Axis 1. + Tensor maxvalue1; + ASSERT_TRUE(maxvalue1.CopyFrom(maxvalue, {1, 1, kDimensionSize})); + + Tensor cdf{DT_INT32, {1, 1, kDimensionSize, kMaxMaxValue + 2}}; + BuildCdf(&gen, &data, &cdf, maxvalue1); + TestEncodeAndDecode(kPrecision, data, cdf); + } + + { + // Axis 2. 
+ Tensor maxvalue2; + ASSERT_TRUE(maxvalue2.CopyFrom(maxvalue, {1, kDimensionSize, 1})); + + Tensor cdf{DT_INT32, {1, kDimensionSize, 1, kMaxMaxValue + 2}}; + BuildCdf(&gen, &data, &cdf, maxvalue2); + TestEncodeAndDecode(kPrecision, data, cdf); + } +} + +TEST_F(RangeCoderOpsTest, Broadcast2Axes) { + constexpr int kPrecision = 13; + constexpr int kDimensionSize1 = 1 << (kPrecision / 2); + constexpr int kDimensionSize2 = 1 << (kPrecision - kPrecision / 2); + constexpr int kMinMaxValue = 10; + constexpr int kMaxMaxValue = 64; + + random::PhiloxRandom philox(random::New64(), random::New64()); + random::SimplePhilox gen(&philox); + Tensor maxvalue{DT_INT16, {2, 1, 1, 7}}; + PopulateMaxValues(&gen, &maxvalue, kMinMaxValue, kMaxMaxValue); + + Tensor data{DT_INT16, {2, kDimensionSize1, kDimensionSize2, 7}}; + Tensor cdf{DT_INT32, {2, 1, 1, 7, kMaxMaxValue + 2}}; + BuildCdf(&gen, &data, &cdf, maxvalue); + TestEncodeAndDecode(kPrecision, data, cdf); +} + +TEST_F(RangeCoderOpsTest, InvalidCdfShape) { + Tensor data{DT_INT16, {3, 3}}; + Tensor cdf{DT_INT32, {3, 3}}; + + Tensor unused; + { + const Status status = RunEncodeOp(10, {data, cdf}, &unused); + EXPECT_FALSE(status.ok()); + EXPECT_NE(status.error_message().find("`cdf` should have one more axis"), + string::npos); + } + + Tensor empty{DT_STRING, {}}; + Tensor shape{DT_INT32, {2}}; + shape.vec().setValues({3, 3}); + { + const Status status = RunDecodeOp(10, {empty, shape, cdf}, &unused); + EXPECT_FALSE(status.ok()); + EXPECT_NE(status.error_message().find("`cdf` should have one more axis"), + string::npos); + } + + cdf = Tensor{DT_INT32, {3, 3, 1}}; + { + const Status status = RunEncodeOp(10, {data, cdf}, &unused); + EXPECT_FALSE(status.ok()); + EXPECT_NE( + status.error_message().find("last dimension of `cdf` should be > 1"), + string::npos); + } + { + const Status status = RunDecodeOp(10, {empty, shape, cdf}, &unused); + EXPECT_FALSE(status.ok()); + EXPECT_NE( + status.error_message().find("last dimension of `cdf` should 
be > 1"), + string::npos); + } +} + +TEST_F(RangeCoderOpsTest, DecoderShapeFn) { + Tensor encoded_tensor{DT_STRING, {}}; + Tensor shape_tensor{DT_INT32, {3}}; + Tensor cdf_tensor{DT_INT32, {4, 6, 8, 2}}; + + shape_tensor.flat().setValues({4, 6, 8}); + + Graph g{OpRegistry::Global()}; + Node* encoded = test::graph::Constant(&g, encoded_tensor); + Node* shape = test::graph::Constant(&g, shape_tensor); + Node* cdf = test::graph::Constant(&g, cdf_tensor); + Node* decode; + TF_ASSERT_OK(NodeBuilder("range_decode", "RangeDecode", g.op_registry()) + .Input(encoded) + .Input(shape) + .Input(cdf) + .Attr("precision", 10) + .Finalize(&g, &decode)); + + ShapeRefiner refiner{g.versions().producer(), g.op_registry()}; + TF_ASSERT_OK(refiner.AddNode(encoded)); + TF_ASSERT_OK(refiner.AddNode(shape)); + TF_ASSERT_OK(refiner.AddNode(cdf)); + TF_ASSERT_OK(refiner.AddNode(decode)); + + auto* context = refiner.GetContext(decode); + ASSERT_NE(context, nullptr); + + ASSERT_EQ(context->num_outputs(), 1); + auto shape_handle = context->output(0); + + ASSERT_EQ(context->Rank(shape_handle), 3); + EXPECT_EQ(context->Value(context->Dim(shape_handle, 0)), 4); + EXPECT_EQ(context->Value(context->Dim(shape_handle, 1)), 6); + EXPECT_EQ(context->Value(context->Dim(shape_handle, 2)), 8); +} + +TEST_F(RangeCoderOpsTest, InvalidBroadcast) { + Tensor data{DT_INT16, {3, 3}}; + Tensor cdf{DT_INT32, {3, 2, 2}}; + + Tensor unused; + { + const Status status = RunEncodeOp(10, {data, cdf}, &unused); + EXPECT_FALSE(status.ok()); + EXPECT_NE(status.error_message().find("Cannot broadcast shape"), + string::npos); + } + + data = Tensor{DT_INT16, {3, 1}}; + cdf = Tensor{DT_INT32, {3, 3, 2}}; + Tensor empty{DT_STRING, {}}; + Tensor shape{DT_INT32, {2}}; + shape.vec().setValues({3, 1}); + { + const Status status = RunDecodeOp(10, {empty, shape, cdf}, &unused); + EXPECT_FALSE(status.ok()); + EXPECT_NE(status.error_message().find("Cannot broadcast shape"), + string::npos); + } + + std::vector shape_vector = {2, 2, 2, 
2, 2, 2, 2, 2, 2}; + data = Tensor{DT_INT16, TensorShape{shape_vector}}; + cdf = Tensor{DT_INT32, {2, 1, 2, 1, 2, 1, 2, 1, 2, 2}}; + { + const Status status = RunEncodeOp(10, {data, cdf}, &unused); + EXPECT_FALSE(status.ok()); + EXPECT_NE(status.error_message().find("Irregular broadcast"), string::npos); + } + + shape = Tensor{DT_INT32, {static_cast(shape_vector.size())}}; + for (int i = 0; i < shape_vector.size(); ++i) { + shape.flat()(i) = shape_vector[i]; + } + { + const Status status = RunDecodeOp(10, {empty, shape, cdf}, &unused); + EXPECT_FALSE(status.ok()); + EXPECT_NE(status.error_message().find("Irregular broadcast"), string::npos); + } +} + +// Benchmark ------------------------------------------------------------- + +// This function creates RangeEncode graph with CDF built from a separate data +// sample. +Graph* CreateRangeEncodeFullBroadcastGraph(const TensorShape& shape, + int precision) { + CHECK_EQ(shape.dims(), 4); + + constexpr int kAlphabetSize = 70; + + Tensor histogram{DT_INT32, {kAlphabetSize + 1}}; + TTypes::Vec h = histogram.vec(); + h.setConstant(1); + h(0) = 0; + + random::PhiloxRandom philox(random::New64(), random::New64()); + random::SimplePhilox gen(&philox); + for (int i = 0; i < (1 << precision) - kAlphabetSize; ++i) { + const int value = LogUniform(&gen, kAlphabetSize - 1); + h(value + 1) += 1; + } + + Tensor cdf{DT_INT32, {1, 1, 1, 1, kAlphabetSize + 1}}; + cdf.flat() = h.cumsum(0); + + Tensor data{DT_INT16, shape}; + TTypes::Flat d = data.flat(); + for (int64 i = 0; i < d.size(); ++i) { + d(i) = LogUniform(&gen, kAlphabetSize - 1); + } + + Graph* g = new Graph(OpRegistry::Global()); + TF_CHECK_OK(NodeBuilder("range_encode", "RangeEncode", g->op_registry()) + .Input(test::graph::Constant(g, data)) + .Input(test::graph::Constant(g, cdf)) + .Attr("precision", precision) + .Finalize(g, nullptr)); + return g; +} + +// This function creates RangeDecode graph with CDF built from a separate data +// sample. 
+Graph* CreateRangeDecodeFullBroadcastGraph(const TensorShape& shape, + int precision) { + CHECK_EQ(shape.dims(), 4); + + constexpr int kAlphabetSize = 200; + const int64 num_elements = shape.num_elements(); + + Tensor histogram{DT_INT32, {kAlphabetSize + 1}}; + TTypes::Vec h = histogram.vec(); + h.setConstant(1); + h(0) = 0; + + random::PhiloxRandom philox(random::New64(), random::New64()); + random::SimplePhilox gen(&philox); + for (int i = 0; i < (1 << precision) - kAlphabetSize; ++i) { + const int value = LogUniform(&gen, kAlphabetSize - 1); + h(value + 1) += 1; + } + + Tensor cdf_tensor{DT_INT32, {1, 1, 1, 1, kAlphabetSize + 1}}; + TTypes::Flat cdf = cdf_tensor.flat(); + cdf = h.cumsum(0); + + Tensor string_tensor{DT_STRING, TensorShape{}}; + string& sink = string_tensor.scalar()(); + + RangeEncoder encoder{precision}; + for (int64 i = 0; i < num_elements; ++i) { + const int value = LogUniform(&gen, kAlphabetSize - 1); + encoder.Encode(cdf(value), cdf(value + 1), &sink); + } + encoder.Finalize(&sink); + + Tensor shape_tensor{DT_INT32, {shape.dims()}}; + for (int i = 0; i < shape.dims(); ++i) { + shape_tensor.flat()(i) = shape.dim_size(i); + } + + Graph* g = new Graph(OpRegistry::Global()); + TF_CHECK_OK(NodeBuilder("range_decode", "RangeDecode", g->op_registry()) + .Input(test::graph::Constant(g, string_tensor)) + .Input(test::graph::Constant(g, shape_tensor)) + .Input(test::graph::Constant(g, cdf_tensor)) + .Attr("precision", precision) + .Finalize(g, nullptr)); + return g; +} + +void RunTensorFlowBenchmark(int iters, Graph* g, int64 num_elements) { + SessionOptions opts; + opts.config.set_intra_op_parallelism_threads(1); + opts.config.set_inter_op_parallelism_threads(1); + + testing::UseRealTime(); + test::Benchmark("cpu", g, &opts).Run(iters); + + const int64 num_items = static_cast(iters) * num_elements; + testing::ItemsProcessed(num_items); +} + +void BM_RangeEncodeFullBroadcast(int iters, int code_size) { + constexpr int kPrecision = 14; + const 
TensorShape shape = {1, code_size, code_size, 256}; + Graph* g = CreateRangeEncodeFullBroadcastGraph(shape, kPrecision); + RunTensorFlowBenchmark(iters, g, shape.num_elements()); +} + +BENCHMARK(BM_RangeEncodeFullBroadcast)->Arg(32)->Arg(64); + +void BM_RangeDecodeFullBroadcast(int iters, int code_size) { + constexpr int kPrecision = 14; + const TensorShape shape = {1, code_size, code_size, 256}; + Graph* g = CreateRangeDecodeFullBroadcastGraph(shape, kPrecision); + RunTensorFlowBenchmark(iters, g, shape.num_elements()); +} + +BENCHMARK(BM_RangeDecodeFullBroadcast)->Arg(32)->Arg(64); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/coder/kernels/range_coder_ops_util.cc b/tensorflow/contrib/coder/kernels/range_coder_ops_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..d66730cb4881ea92b5477047c500291fa9c0c290 --- /dev/null +++ b/tensorflow/contrib/coder/kernels/range_coder_ops_util.cc @@ -0,0 +1,85 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/contrib/coder/kernels/range_coder_ops_util.h" + +#include + +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +using tensorflow::errors::InvalidArgument; + +namespace tensorflow { +Status MergeAxes(const TensorShape& broadcast_shape, + const TensorShape& storage_shape, + std::vector* merged_broadcast_shape_pointer, + std::vector* merged_storage_shape_pointer) { + CHECK_EQ(storage_shape.dims(), broadcast_shape.dims() + 1); + + std::vector& merged_broadcast_shape = *merged_broadcast_shape_pointer; + std::vector& merged_storage_shape = *merged_storage_shape_pointer; + + // The shapes are simplified so that the conversions between linear index + // and coordinates takes less CPU cycles. Two adjacent dimensions are + // merged if they both are broadcasting dimensions or if they both are + // non-broadcasting dimensions. + merged_broadcast_shape.resize(1); + merged_broadcast_shape[0] = 1; + merged_storage_shape.resize(1); + merged_storage_shape[0] = 1; + + for (int i = 0, j = 0; j < broadcast_shape.dims(); ++j) { + if (TF_PREDICT_FALSE( + (broadcast_shape.dim_size(j) != storage_shape.dim_size(j)) && + (storage_shape.dim_size(j) != 1))) { + return InvalidArgument("Cannot broadcast shape ", + storage_shape.DebugString(), " to ", + broadcast_shape.DebugString()); + } + + const bool was_broadcasting = (merged_storage_shape[i] == 1); + const bool is_broadcasting = (storage_shape.dim_size(j) == 1); + + // Merge two adjacent axes if they both are broadcasting or both are + // non-broadcasting axes. The second and the third conditions in the if + // clause below are when the previously merged axis or the next j-th axis + // may be interpreted as either a broadcasting or a non-broadcasting axis. 
+ const bool merge = (was_broadcasting == is_broadcasting) || + (broadcast_shape.dim_size(j) <= 1) || + (merged_broadcast_shape[i] <= 1); + + if (merge) { + merged_broadcast_shape[i] *= broadcast_shape.dim_size(j); + merged_storage_shape[i] *= storage_shape.dim_size(j); + } else { + // Move to the next axis. + merged_broadcast_shape.push_back(broadcast_shape.dim_size(j)); + merged_storage_shape.push_back(storage_shape.dim_size(j)); + ++i; + } + } + + int64 storage_stride = 1; + for (int i = broadcast_shape.dims(); i < storage_shape.dims(); ++i) { + storage_stride *= storage_shape.dim_size(i); + } + merged_storage_shape.push_back(storage_stride); + + return Status::OK(); +} +} // namespace tensorflow diff --git a/tensorflow/contrib/coder/kernels/range_coder_ops_util.h b/tensorflow/contrib/coder/kernels/range_coder_ops_util.h new file mode 100644 index 0000000000000000000000000000000000000000..b8aabcef62e9de53810397960f871abc4adc0cf9 --- /dev/null +++ b/tensorflow/contrib/coder/kernels/range_coder_ops_util.h @@ -0,0 +1,33 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_CODER_KERNELS_RANGE_CODER_OPS_UTIL_H_ +#define TENSORFLOW_CONTRIB_CODER_KERNELS_RANGE_CODER_OPS_UTIL_H_ + +#include + +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +// The shapes are simplified to reduce indexing cost. +Status MergeAxes(const TensorShape& broadcast_shape, + const TensorShape& storage_shape, + std::vector* merged_broadcast_shape_pointer, + std::vector* merged_storage_shape_pointer); +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_CODER_KERNELS_RANGE_CODER_OPS_UTIL_H_ diff --git a/tensorflow/contrib/coder/kernels/range_coder_test.cc b/tensorflow/contrib/coder/kernels/range_coder_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..442994bf7c7566c1cbe1c439050a69e5b9a4208e --- /dev/null +++ b/tensorflow/contrib/coder/kernels/range_coder_test.cc @@ -0,0 +1,116 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/contrib/coder/kernels/range_coder.h" + +#include + +#include "tensorflow/core/lib/random/distribution_sampler.h" +#include "tensorflow/core/lib/random/philox_random.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { +void RangeEncodeDecodeTest(int precision, random::SimplePhilox* gen) { + constexpr int kAlphabetSize = 256; + + std::vector distribution_weight; + distribution_weight.reserve(kAlphabetSize); + for (int i = 1; i <= kAlphabetSize; ++i) { + distribution_weight.push_back(std::pow(static_cast(i), -2.0f)); + } + + random::DistributionSampler sampler(distribution_weight); + + const int multiplier = (precision > 7) ? 32 : 1; + std::vector histogram(kAlphabetSize, multiplier - 1); + + const int data_size = + (multiplier << precision) - histogram.size() * (multiplier - 1); + CHECK_GE(data_size, 0); + std::vector data(data_size); + for (uint8& x : data) { + x = sampler.Sample(gen); + ++histogram[x]; + } + + std::vector cdf(histogram.size() + 1, 0); + int partial_sum = 0; + for (int i = 0; i < histogram.size(); ++i) { + partial_sum += histogram[i]; + cdf[i + 1] = partial_sum / multiplier; + } + + ASSERT_EQ(cdf.front(), 0); + ASSERT_EQ(cdf.back(), 1 << precision); + + std::vector ideal_code_length(histogram.size()); + const double normalizer = static_cast(1 << precision); + for (int i = 0; i < ideal_code_length.size(); ++i) { + ideal_code_length[i] = -std::log2((cdf[i + 1] - cdf[i]) / normalizer); + } + + RangeEncoder encoder(precision); + string encoded; + double ideal_length = 0.0; + for (uint8 x : data) { + encoder.Encode(cdf[x], cdf[x + 1], &encoded); + ideal_length += ideal_code_length[x]; + } + encoder.Finalize(&encoded); + + LOG(INFO) << "Encoded string length (bits): " << 8 * encoded.size() + << ", whereas ideal " << 
ideal_length << " (" + << (8 * encoded.size()) / ideal_length << " of ideal) " + << " (ideal compression rate " << ideal_length / (8 * data.size()) + << ")"; + + RangeDecoder decoder(encoded, precision); + for (int i = 0; i < data.size(); ++i) { + const int32 decoded = decoder.Decode(cdf); + ASSERT_EQ(decoded, static_cast(data[i])) << i; + } +} + +TEST(RangeCoderTest, Precision1To11) { + random::PhiloxRandom gen(random::New64(), random::New64()); + random::SimplePhilox rand(&gen); + const int precision = 1 + rand.Uniform(11); + RangeEncodeDecodeTest(precision, &rand); +} + +TEST(RangeCoderTest, Precision12To16) { + random::PhiloxRandom gen(random::New64(), random::New64()); + random::SimplePhilox rand(&gen); + for (int precision = 12; precision < 17; ++precision) { + RangeEncodeDecodeTest(precision, &rand); + } +} + +TEST(RangeCoderTest, FinalizeState0) { + constexpr int kPrecision = 2; + + string output; + RangeEncoder encoder(kPrecision); + encoder.Encode(0, 2, &output); + encoder.Finalize(&output); + + RangeDecoder decoder(output, kPrecision); + EXPECT_EQ(decoder.Decode({0, 2, 4}), 0); +} +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/coder/ops/coder_ops.cc b/tensorflow/contrib/coder/ops/coder_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..9056d1a6963d7be92f499db31385fb6afe2dc515 --- /dev/null +++ b/tensorflow/contrib/coder/ops/coder_ops.cc @@ -0,0 +1,119 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +using shape_inference::InferenceContext; +using shape_inference::ShapeHandle; + +// clang-format off +REGISTER_OP("RangeEncode") + .Input("data: int16") + .Input("cdf: int32") + .Output("encoded: string") + .Attr("precision: int >= 1") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Using the provided cumulative distribution functions (CDF) inside `cdf`, returns +a range-code of `data`. + +The shape of `cdf` should have one more axis than the shape of `data`, and the +prefix `cdf.shape[:-1]` should be broadcastable to `data.shape`. That is, for +every `i = 0,...,rank(data) - 1`, the op requires that either +`cdf.shape[i] == 1` or `cdf.shape[i] == data.shape[i]`. Note that this +broadcasting is limited in the sense that the number of axes must match, and +broadcasts only `cdf` but not `data`. + +`data` should have an upper bound `m > 0` such that each element is an integer +in range `[0, m)`. Then the last dimension size of `cdf` must be `m + 1`. For +each element of `data`, the innermost strip of `cdf` is a vector representing a +CDF. For each k = 0,...,m, `cdf[..., k] / 2^precision` is the probability that +an outcome is less than `k` (not less than or equal to). + +``` + cdf[..., 0] / 2^precision = Pr(data[...] < 0) + cdf[..., 1] / 2^precision = Pr(data[...] < 1) = Pr(data[...] <= 0) + cdf[..., 2] / 2^precision = Pr(data[...] < 2) = Pr(data[...] <= 1) + ... + cdf[..., m] / 2^precision = Pr(data[...] < m) = 1 +``` + +Therefore each element of `cdf` must be in `[0, 2^precision]`. 
+ +Ideally `cdf[..., m]` should equal to `2^precision` but this is not a hard +requirement as long as `cdf[..., m] <= 2^precision`. + +The encoded string neither contains the shape information of the encoded data +nor a termination symbol. Therefore the shape of the encoded data must be +explicitly provided to the decoder. + +Implementation notes: + +- Because of potential performance issues, the op does not check whether +elements of `data` is in the correct range `[0, m)`, or if `cdf` satisfies +monotonic increase property. + +- For the range coder to decode the encoded string correctly, the decoder should +be able to reproduce the internal states of the encoder precisely. Otherwise, +the decoding would fail and once an error occur, all subsequent decoded values +are incorrect. For this reason, the range coder uses integer arithmetics and +avoids using any floating point operations internally, and `cdf` should contain +integers representing quantized probability mass rather than floating points. + +data: An int32 tensor. +cdf: An int32 tensor representing the CDF's of `data`. Each integer is divided + by `2^precision` to represent a fraction. +encoded: A range-coded scalar string. +precision: The number of bits for probability quantization. Must be <= 16. +)doc"); + + +REGISTER_OP("RangeDecode") + .Input("encoded: string") + .Input("shape: int32") + .Input("cdf: int32") + .Output("decoded: int16") + .Attr("precision: int >= 1") + .SetShapeFn([] (InferenceContext* c) { + ShapeHandle out; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out)); + c->set_output(0, out); + return Status::OK(); + }) + .Doc(R"doc( +Decodes a range-coded `code` into an int32 tensor of shape `shape`. + +This is the reverse op of RangeEncode. The shape of the tensor that was encoded +should be known by the caller. + +Implementation notes: + +- If wrong input was given (e.g., corrupt `encoded` string, or `cdf` or +`precision` do not match encoder), the decode is unsuccessful. 
Because of +potential performance issues, the decoder does not return error status. + +encoded: A scalar string tensor from RangeEncode. +shape: An int32 1-D tensor representing the shape of the data encoded by + RangeEncode. +decoded: An int32 tensor with shape equal to `shape`. +precision: The number of bits for probability quantization. Must be <= 16, and + must match the precision used by RangeEncode that produced `encoded`. +)doc"); +// clang-format on +} // namespace tensorflow diff --git a/tensorflow/contrib/coder/python/ops/coder_ops.py b/tensorflow/contrib/coder/python/ops/coder_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..bb262e338baf1d9c3c043f03a02c2d2851e22b49 --- /dev/null +++ b/tensorflow/contrib/coder/python/ops/coder_ops.py @@ -0,0 +1,30 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Range coder operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=wildcard-import,unused-import +from tensorflow.contrib.coder.python.ops import gen_coder_ops +from tensorflow.contrib.coder.python.ops.gen_coder_ops import * +# pylint: enable=wildcard-import,unused-import +from tensorflow.contrib.util import loader +from tensorflow.python.platform import resource_loader + + +_coder_ops = loader.load_op_library( + resource_loader.get_path_to_datafile("_coder_ops.so")) diff --git a/tensorflow/contrib/coder/python/ops/coder_ops_test.py b/tensorflow/contrib/coder/python/ops/coder_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d5e14e7a641b5673e97882daf2b5a1796ee1bbef --- /dev/null +++ b/tensorflow/contrib/coder/python/ops/coder_ops_test.py @@ -0,0 +1,53 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Coder operations tests.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.coder.python.ops import coder_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.platform import test + + +class CoderOpsTest(test.TestCase): + """Coder ops test. + + Coder ops have C++ tests. Python test just ensures that Python binding is not + broken. + """ + + def testReadmeExample(self): + data = random_ops.random_uniform((128, 128), 0, 10, dtype=dtypes.int32) + histogram = math_ops.bincount(data, minlength=10, maxlength=10) + cdf = math_ops.cumsum(histogram, exclusive=False) + cdf = array_ops.pad(cdf, [[1, 0]]) + cdf = array_ops.reshape(cdf, [1, 1, -1]) + + data = math_ops.cast(data, dtypes.int16) + encoded = coder_ops.range_encode(data, cdf, precision=14) + decoded = coder_ops.range_decode( + encoded, array_ops.shape(data), cdf, precision=14) + + with self.test_session() as sess: + self.assertAllEqual(*sess.run((data, decoded))) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/compiler/jit_test.py b/tensorflow/contrib/compiler/jit_test.py index 2108e42bce4eba1eed158fe85888f1699a69ba7e..29a593f6bcfa05dcafcdb2f94087380ad720dba1 100644 --- a/tensorflow/contrib/compiler/jit_test.py +++ b/tensorflow/contrib/compiler/jit_test.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import function from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed +from tensorflow.python.framework import test_util from tensorflow.python.ops import gradients from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops @@ 
-169,6 +170,7 @@ class JITTest(test.TestCase): self.assertEqual(b"jit_scope_0", func_attrs["_XlaScope"].s) +@test_util.with_c_api class CompilationEnabledInGradientTest(test.TestCase): def testCompilationInGradient(self): @@ -188,7 +190,7 @@ class CompilationEnabledInGradientTest(test.TestCase): for cg in c_grad_ops: self.assertTrue(cg.get_attr("_XlaCompile")) for ncg in nc_grad_ops: - with self.assertRaisesRegexp(ValueError, "No attr named"): + with self.assertRaisesRegexp(ValueError, "[Nn]o attr named"): ncg.get_attr("_XlaCompile") # d/dx (x ** 4) = 4 * (x ** 3) diff --git a/tensorflow/contrib/copy_graph/__init__.py b/tensorflow/contrib/copy_graph/__init__.py index 30a0aac140b576c501595fd6c8767b7dddde8e58..61ee39e4be1f0471309bb2672476dd9100cbfd49 100644 --- a/tensorflow/contrib/copy_graph/__init__.py +++ b/tensorflow/contrib/copy_graph/__init__.py @@ -13,8 +13,6 @@ # limitations under the License. # ============================================================================== """Functions to copy elements between graphs. - -See the @{$python/contrib.copy_graph} guide. 
""" from __future__ import absolute_import diff --git a/tensorflow/contrib/copy_graph/python/util/copy_elements.py b/tensorflow/contrib/copy_graph/python/util/copy_elements.py index d060eda0a74010db10d9506b2a1c2345b2731709..b806799202bff4f2f6dbf717fbeea74a04b8cd6e 100644 --- a/tensorflow/contrib/copy_graph/python/util/copy_elements.py +++ b/tensorflow/contrib/copy_graph/python/util/copy_elements.py @@ -35,10 +35,10 @@ from tensorflow.python.ops.variables import Variable from tensorflow.python.client.session import Session from tensorflow.python.framework import ops -__all__ = ["copy_op_to_graph", "copy_variable_to_graph", "get_copied_op"] +__all__ = ['copy_op_to_graph', 'copy_variable_to_graph', 'get_copied_op'] -def copy_variable_to_graph(org_instance, to_graph, scope=""): +def copy_variable_to_graph(org_instance, to_graph, scope=''): """Given a `Variable` instance from one `Graph`, initializes and returns a copy of it from another `Graph`, under the specified scope (default `""`). @@ -56,12 +56,11 @@ def copy_variable_to_graph(org_instance, to_graph, scope=""): """ if not isinstance(org_instance, Variable): - raise TypeError(str(org_instance) + " is not a Variable") + raise TypeError(str(org_instance) + ' is not a Variable') #The name of the new variable - if scope != "": - new_name = (scope + '/' + - org_instance.name[:org_instance.name.index(':')]) + if scope != '': + new_name = (scope + '/' + org_instance.name[:org_instance.name.index(':')]) else: new_name = org_instance.name[:org_instance.name.index(':')] @@ -73,15 +72,15 @@ def copy_variable_to_graph(org_instance, to_graph, scope=""): for name, collection in org_instance.graph._collections.items(): if org_instance in collection: if (name == ops.GraphKeys.GLOBAL_VARIABLES or - name == ops.GraphKeys.TRAINABLE_VARIABLES or - scope == ''): + name == ops.GraphKeys.TRAINABLE_VARIABLES or scope == ''): collections.append(name) else: collections.append(scope + '/' + name) #See if its trainable. 
- trainable = (org_instance in org_instance.graph.get_collection( - ops.GraphKeys.TRAINABLE_VARIABLES)) + trainable = ( + org_instance in org_instance.graph.get_collection( + ops.GraphKeys.TRAINABLE_VARIABLES)) #Get the initial value with org_instance.graph.as_default(): temp_session = Session() @@ -89,17 +88,17 @@ def copy_variable_to_graph(org_instance, to_graph, scope=""): #Initialize the new variable with to_graph.as_default(): - new_var = Variable(init_value, - trainable, - name=new_name, - collections=collections, - validate_shape=False) + new_var = Variable( + init_value, + trainable, + name=new_name, + collections=collections, + validate_shape=False) return new_var -def copy_op_to_graph(org_instance, to_graph, variables, - scope=""): +def copy_op_to_graph(org_instance, to_graph, variables, scope=''): """Returns a copy of an operation from another Graph under a specified scope. Given an `Operation` `org_instance` from one `Graph`, @@ -139,14 +138,12 @@ def copy_op_to_graph(org_instance, to_graph, variables, #If a variable by the new name already exists, return the #correspondng tensor that will act as an input if new_name in copied_variables: - return to_graph.get_tensor_by_name( - copied_variables[new_name].name) + return to_graph.get_tensor_by_name(copied_variables[new_name].name) #If an instance of the same name exists, return appropriately try: - already_present = to_graph.as_graph_element(new_name, - allow_tensor=True, - allow_operation=True) + already_present = to_graph.as_graph_element( + new_name, allow_tensor=True, allow_operation=True) return already_present except: pass @@ -184,20 +181,21 @@ def copy_op_to_graph(org_instance, to_graph, variables, #If it has an original_op parameter, copy it if op._original_op is not None: - new_original_op = copy_op_to_graph(op._original_op, to_graph, - variables, scope) + new_original_op = copy_op_to_graph(op._original_op, to_graph, variables, + scope) else: new_original_op = None #If it has control inputs, call 
this function recursively on each. - new_control_inputs = [copy_op_to_graph(x, to_graph, variables, - scope) - for x in op.control_inputs] + new_control_inputs = [ + copy_op_to_graph(x, to_graph, variables, scope) + for x in op.control_inputs + ] #If it has inputs, call this function recursively on each. - new_inputs = [copy_op_to_graph(x, to_graph, variables, - scope) - for x in op.inputs] + new_inputs = [ + copy_op_to_graph(x, to_graph, variables, scope) for x in op.inputs + ] #Make a new node_def based on that of the original. #An instance of tensorflow.core.framework.node_def_pb2.NodeDef, it @@ -216,15 +214,11 @@ def copy_op_to_graph(org_instance, to_graph, variables, op_def = deepcopy(op._op_def) #Initialize a new Operation instance - new_op = ops.Operation(new_node_def, - to_graph, - new_inputs, - output_types, - new_control_inputs, - input_types, - new_original_op, + new_op = ops.Operation(new_node_def, to_graph, new_inputs, output_types, + new_control_inputs, input_types, new_original_op, op_def) #Use Graph's hidden methods to add the op + to_graph._add_op(new_op) # pylint: disable=protected-access to_graph._record_op_seen_by_control_dependencies(new_op) for device_function in reversed(to_graph._device_function_stack): new_op._set_device(device_function(new_op)) @@ -232,10 +226,10 @@ def copy_op_to_graph(org_instance, to_graph, variables, return new_op else: - raise TypeError("Could not copy instance: " + str(org_instance)) + raise TypeError('Could not copy instance: ' + str(org_instance)) -def get_copied_op(org_instance, graph, scope=""): +def get_copied_op(org_instance, graph, scope=''): """Given an `Operation` instance from some `Graph`, returns its namesake from `graph`, under the specified scope (default `""`). 
@@ -258,5 +252,5 @@ def get_copied_op(org_instance, graph, scope=""): else: new_name = org_instance.name - return graph.as_graph_element(new_name, allow_tensor=True, - allow_operation=True) + return graph.as_graph_element( + new_name, allow_tensor=True, allow_operation=True) diff --git a/tensorflow/contrib/copy_graph/python/util/copy_test.py b/tensorflow/contrib/copy_graph/python/util/copy_test.py index 2798d31229d048561f8ebd9b63d3df94a44c45c7..05744bec4e05405c04b5ec442e72e4495737ab5b 100644 --- a/tensorflow/contrib/copy_graph/python/util/copy_test.py +++ b/tensorflow/contrib/copy_graph/python/util/copy_test.py @@ -17,9 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np from tensorflow.contrib.copy_graph.python.util import copy_elements -from tensorflow.contrib.framework.python.framework import tensor_util from tensorflow.python.client import session as session_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops diff --git a/tensorflow/contrib/crf/__init__.py b/tensorflow/contrib/crf/__init__.py index bc749339bd4d49c8372bc731da98732f8c19cbe1..046c509626bc2eb20a65c0b38495ff37c294e0e1 100644 --- a/tensorflow/contrib/crf/__init__.py +++ b/tensorflow/contrib/crf/__init__.py @@ -16,15 +16,15 @@ See the @{$python/contrib.crf} guide. 
-@@crf_sequence_score -@@crf_log_norm -@@crf_log_likelihood -@@crf_unary_score @@crf_binary_score @@crf_decode -@@CrfForwardRnnCell -@@CrfDecodeForwardRnnCell +@@crf_log_likelihood +@@crf_log_norm +@@crf_sequence_score +@@crf_unary_score @@CrfDecodeBackwardRnnCell +@@CrfDecodeForwardRnnCell +@@CrfForwardRnnCell @@viterbi_decode """ @@ -32,16 +32,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.crf.python.ops.crf import _lengths_to_masks from tensorflow.contrib.crf.python.ops.crf import crf_binary_score from tensorflow.contrib.crf.python.ops.crf import crf_decode from tensorflow.contrib.crf.python.ops.crf import crf_log_likelihood from tensorflow.contrib.crf.python.ops.crf import crf_log_norm from tensorflow.contrib.crf.python.ops.crf import crf_sequence_score from tensorflow.contrib.crf.python.ops.crf import crf_unary_score -from tensorflow.contrib.crf.python.ops.crf import CrfForwardRnnCell -from tensorflow.contrib.crf.python.ops.crf import CrfDecodeForwardRnnCell from tensorflow.contrib.crf.python.ops.crf import CrfDecodeBackwardRnnCell +from tensorflow.contrib.crf.python.ops.crf import CrfDecodeForwardRnnCell +from tensorflow.contrib.crf.python.ops.crf import CrfForwardRnnCell from tensorflow.contrib.crf.python.ops.crf import viterbi_decode from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py index b47fb426a193e0fcc075deafae3eaab698f18ec9..721dc4d0801d1f0e116921888e3851a95e0b72b0 100644 --- a/tensorflow/contrib/crf/python/kernel_tests/crf_test.py +++ b/tensorflow/contrib/crf/python/kernel_tests/crf_test.py @@ -179,17 +179,6 @@ class CrfTest(test.TestCase): tf_total_log_likelihood = sess.run(total_log_likelihood) self.assertAllClose(tf_total_log_likelihood, 0.0) - def testLengthsToMasks(self): - with self.test_session() as sess: - 
sequence_lengths = [4, 1, 8, 2] - max_sequence_length = max(sequence_lengths) - mask = crf._lengths_to_masks(sequence_lengths, max_sequence_length) - tf_mask = sess.run(mask) - self.assertEqual(len(tf_mask), len(sequence_lengths)) - for m, l in zip(tf_mask, sequence_lengths): - self.assertAllEqual(m[:l], [1] * l) - self.assertAllEqual(m[l:], [0] * (len(m) - l)) - def testViterbiDecode(self): inputs = np.array( [[4, 5, -3], [3, -1, 3], [-1, 2, 1], [0, 0, 0]], dtype=np.float32) diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py index 7f5ae937b26f465076c6976429697c35924432e5..faa78769b98699af59047aed2865771120110fc2 100644 --- a/tensorflow/contrib/crf/python/ops/crf.py +++ b/tensorflow/contrib/crf/python/ops/crf.py @@ -70,25 +70,6 @@ __all__ = [ ] -def _lengths_to_masks(lengths, max_length): - """Creates a binary matrix that can be used to mask away padding. - - Args: - lengths: A vector of integers representing lengths. - max_length: An integer indicating the maximum length. All values in - lengths should be less than max_length. - Returns: - masks: Masks that can be used to get rid of padding. - """ - tiled_ranges = array_ops.tile( - array_ops.expand_dims(math_ops.range(max_length), 0), - [array_ops.shape(lengths)[0], 1]) - lengths = array_ops.expand_dims(lengths, 1) - masks = math_ops.to_float( - math_ops.to_int64(tiled_ranges) < math_ops.to_int64(lengths)) - return masks - - def crf_sequence_score(inputs, tag_indices, sequence_lengths, transition_params): """Computes the unnormalized score for a tag sequence. @@ -185,8 +166,8 @@ def crf_log_likelihood(inputs, sequence_lengths: A [batch_size] vector of true sequence lengths. transition_params: A [num_tags, num_tags] transition matrix, if available. Returns: - log_likelihood: A scalar containing the log-likelihood of the given sequence - of tag indices. 
+ log_likelihood: A [batch_size] `Tensor` containing the log-likelihood of + each example, given the sequence of tag indices. transition_params: A [num_tags, num_tags] transition matrix. This is either provided by the caller or created in this function. """ @@ -201,7 +182,7 @@ def crf_log_likelihood(inputs, transition_params) log_norm = crf_log_norm(inputs, sequence_lengths, transition_params) - # Normalize the scores to get the log-likelihood. + # Normalize the scores to get the log-likelihood per example. log_likelihood = sequence_scores - log_norm return log_likelihood, transition_params @@ -234,7 +215,9 @@ def crf_unary_score(tag_indices, sequence_lengths, inputs): array_ops.gather(flattened_inputs, flattened_tag_indices), [batch_size, max_seq_len]) - masks = _lengths_to_masks(sequence_lengths, array_ops.shape(tag_indices)[1]) + masks = array_ops.sequence_mask(sequence_lengths, + maxlen=array_ops.shape(tag_indices)[1], + dtype=dtypes.float32) unary_scores = math_ops.reduce_sum(unary_scores * masks, 1) return unary_scores @@ -268,7 +251,9 @@ def crf_binary_score(tag_indices, sequence_lengths, transition_params): binary_scores = array_ops.gather(flattened_transition_params, flattened_transition_indices) - masks = _lengths_to_masks(sequence_lengths, array_ops.shape(tag_indices)[1]) + masks = array_ops.sequence_mask(sequence_lengths, + maxlen=array_ops.shape(tag_indices)[1], + dtype=dtypes.float32) truncated_masks = array_ops.slice(masks, [0, 1], [-1, -1]) binary_scores = math_ops.reduce_sum(binary_scores * truncated_masks, 1) return binary_scores diff --git a/tensorflow/contrib/cudnn_rnn/BUILD b/tensorflow/contrib/cudnn_rnn/BUILD index fce2c03e69bc4b8b0ac46b8e081a33c43c9d41ab..fec358c4e1067dc8dc8173d1b9d05dc90b90ca05 100644 --- a/tensorflow/contrib/cudnn_rnn/BUILD +++ b/tensorflow/contrib/cudnn_rnn/BUILD @@ -25,6 +25,7 @@ tf_custom_op_library( ], deps = [ "//tensorflow/core/kernels:bounds_check_lib", + "@farmhash_archive//:farmhash", ], ) @@ -39,6 +40,7 @@ 
tf_kernel_library( "//tensorflow/core:stream_executor", "//tensorflow/core/kernels:bounds_check_lib", "//third_party/eigen3", + "@farmhash_archive//:farmhash", ], ) @@ -146,10 +148,10 @@ cuda_py_test( cuda_py_test( name = "cudnn_rnn_ops_benchmark", - size = "large", + size = "small", srcs = ["python/kernel_tests/cudnn_rnn_ops_benchmark.py"], additional_deps = [ - ":cudnn_rnn_ops_py", + ":cudnn_rnn_py", "//tensorflow/contrib/rnn:rnn_py", "//tensorflow/python:array_ops", "//tensorflow/python:client", @@ -164,7 +166,6 @@ cuda_py_test( "//tensorflow/python:variables", ], tags = [ - "manual", "noasan", # http://b/62067814 "nomsan", "notsan", diff --git a/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc b/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc index 5d5f593d016a3bb9f7b5ea8f5cd40c29268dc4f5..ba9686e94ee7072cc485c955decb2287bd4a56f3 100644 --- a/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc +++ b/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc @@ -39,6 +39,7 @@ limitations under the License. #include "tensorflow/core/lib/gtl/inlined_vector.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/fingerprint.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/env_var.h" @@ -369,6 +370,27 @@ struct CudnnModelShapes { } }; +// Utility class for using CudnnModelShapes as a hash table key. +struct CudnnModelShapesHasher { + uint64 operator()(const CudnnModelShapes& to_hash) const { + uint64 hash = static_cast(to_hash.num_layers); + hash = tensorflow::FingerprintCat64( + hash, static_cast(to_hash.input_size)); + hash = tensorflow::FingerprintCat64(hash, + static_cast(to_hash.num_units)); + return tensorflow::FingerprintCat64(hash, + static_cast(to_hash.dir_count)); + } +}; + +// Utility class for using CudnnModelShapes as a hash table key. 
+struct CudnnModelShapesComparator { + bool operator()(const CudnnModelShapes& first, + const CudnnModelShapes& second) const { + return first.IsCompatibleWith(second); + } +}; + // Extract and checks the forward input tensors, parameters, and shapes from the // OpKernelContext. Status ExtractForwardInput(OpKernelContext* context, @@ -627,7 +649,7 @@ class CudnnRNNParamsToCanonical : public CudnnRNNKernelCommon { } const int num_params_per_layer = num_params_ / num_layers / num_dirs; // Number of params applied on inputs. The rest are applied on recurrent - // hiddden states. + // hidden states. const int num_params_input_state = num_params_per_layer / 2; CHECK(num_params_ % (num_layers * num_dirs) == 0) << "Number of params is not a multiple of num_layers * num_dirs."; @@ -764,6 +786,13 @@ TF_CALL_float(REGISTER_GPU); TF_CALL_double(REGISTER_GPU); #undef REGISTER_GPU +// Pointers to RNN scratch space for a specific set of shape parameters (used as +// a hash table value in CudnnRNNForwardOp and CudnnRNNBackwardOp). +struct RnnScratchSpace { + std::unique_ptr rnn_desc; + std::unique_ptr dropout_state_allocator; +}; + // Run the forward operation of the RNN model. template class CudnnRNNForwardOp : public CudnnRNNKernelCommon { @@ -808,32 +837,7 @@ class CudnnRNNForwardOp : public CudnnRNNKernelCommon { OP_REQUIRES_OK(context, ToRNNInputMode(rnn_input_mode(), model_shapes.num_units, model_shapes.input_size, &input_mode)); - // TODO(zhengxq): cache the descriptor so we don't have to create them all - // the time. 
auto data_type = ToDataType::value; - { - mutex_lock l(mu_); - if (model_shapes_ == nullptr) { - model_shapes_.reset(new CudnnModelShapes(model_shapes)); - } else { - OP_REQUIRES(context, model_shapes_->IsCompatibleWith(model_shapes), - errors::InvalidArgument( - "Incompatible rnn model shapes inferred: expecting ", - model_shapes_->RnnDescDebugString(), ", getting ", - model_shapes.RnnDescDebugString(), ".")); - } - if (rnn_desc_ == nullptr || ResetRndGenState()) { - dropout_state_allocator_.reset( - new CudnnRNNPersistentSpaceAllocator(context)); - auto rnn_desc_s = executor->createRnnDescriptor( - model_shapes_->num_layers, model_shapes_->num_units, - model_shapes_->input_size, input_mode, rnn_direction_mode(), - rnn_mode(), data_type, dropout(), seed(), - dropout_state_allocator_.get()); - OP_REQUIRES_OK(context, FromExecutorStatus(rnn_desc_s)); - rnn_desc_ = std::move(rnn_desc_s.ConsumeValueOrDie()); - } - } auto input_desc_s = executor->createRnnSequenceTensorDescriptor( input_shape.dim_size(0), input_shape.dim_size(1), @@ -882,14 +886,27 @@ class CudnnRNNForwardOp : public CudnnRNNKernelCommon { bool launch_status = false; { mutex_lock l(mu_); + RnnScratchSpace& rnn_state = rnn_state_cache_[model_shapes]; + if (rnn_state.rnn_desc == nullptr || ResetRndGenState()) { + CudnnRNNPersistentSpaceAllocator* dropout_state_allocator = + new CudnnRNNPersistentSpaceAllocator(context); + rnn_state.dropout_state_allocator.reset(dropout_state_allocator); + auto rnn_desc_s = executor->createRnnDescriptor( + model_shapes.num_layers, model_shapes.num_units, + model_shapes.input_size, input_mode, rnn_direction_mode(), + rnn_mode(), data_type, dropout(), seed(), dropout_state_allocator); + OP_REQUIRES_OK(context, FromExecutorStatus(rnn_desc_s)); + rnn_state.rnn_desc = std::move(rnn_desc_s.ConsumeValueOrDie()); + } launch_status = stream - ->ThenRnnForward( - *rnn_desc_, *input_desc, input_data, *hidden_state_desc, - input_h_data, *hidden_state_desc, input_c_data, params_data, 
- *output_desc, &output_data, *hidden_state_desc, - &output_h_data, *hidden_state_desc, &output_c_data, - is_training_, &reserve_space_allocator, &workspace_allocator) + ->ThenRnnForward(*rnn_state.rnn_desc, *input_desc, input_data, + *hidden_state_desc, input_h_data, + *hidden_state_desc, input_c_data, params_data, + *output_desc, &output_data, *hidden_state_desc, + &output_h_data, *hidden_state_desc, + &output_c_data, is_training_, + &reserve_space_allocator, &workspace_allocator) .ok(); } OP_REQUIRES(context, launch_status, @@ -899,10 +916,9 @@ class CudnnRNNForwardOp : public CudnnRNNKernelCommon { private: mutex mu_; bool is_training_; - std::unique_ptr model_shapes_ GUARDED_BY(mu_); - std::unique_ptr rnn_desc_ GUARDED_BY(mu_); - std::unique_ptr dropout_state_allocator_ - GUARDED_BY(mu_); + std::unordered_map + rnn_state_cache_ GUARDED_BY(mu_); }; #define REGISTER_GPU(T) \ @@ -1022,32 +1038,6 @@ class CudnnRNNBackwardOp : public CudnnRNNKernelCommon { OP_REQUIRES_OK(context, ToRNNInputMode(rnn_input_mode(), model_shapes.num_units, model_shapes.input_size, &input_mode)); - // TODO(zhengxq): cache the descriptor so we don't have to create them all - // the time. 
- { - mutex_lock l(mu_); - if (model_shapes_ == nullptr) { - model_shapes_.reset(new CudnnModelShapes(model_shapes)); - } else { - OP_REQUIRES(context, model_shapes_->IsCompatibleWith(model_shapes), - errors::InvalidArgument( - "Incompatible rnn model shapes inferred: expecting ", - model_shapes_->RnnDescDebugString(), ", getting ", - model_shapes.RnnDescDebugString(), ".")); - } - - if (rnn_desc_ == nullptr || ResetRndGenState()) { - dropout_state_allocator_.reset( - new CudnnRNNPersistentSpaceAllocator(context)); - auto rnn_desc_s = executor->createRnnDescriptor( - model_shapes.num_layers, model_shapes.num_units, - model_shapes.input_size, input_mode, rnn_direction_mode(), - rnn_mode(), data_type, dropout(), seed(), - dropout_state_allocator_.get()); - OP_REQUIRES_OK(context, FromExecutorStatus(rnn_desc_s)); - rnn_desc_ = std::move(rnn_desc_s.ConsumeValueOrDie()); - } - } auto input_desc_s = executor->createRnnSequenceTensorDescriptor( input_shape.dim_size(0), input_shape.dim_size(1), @@ -1100,17 +1090,30 @@ class CudnnRNNBackwardOp : public CudnnRNNKernelCommon { bool launch_status = false; { mutex_lock l(mu_); + RnnScratchSpace& rnn_state = rnn_state_cache_[model_shapes]; + if (rnn_state.rnn_desc == nullptr || ResetRndGenState()) { + CudnnRNNPersistentSpaceAllocator* dropout_state_allocator = + new CudnnRNNPersistentSpaceAllocator(context); + rnn_state.dropout_state_allocator.reset(dropout_state_allocator); + auto rnn_desc_s = executor->createRnnDescriptor( + model_shapes.num_layers, model_shapes.num_units, + model_shapes.input_size, input_mode, rnn_direction_mode(), + rnn_mode(), data_type, dropout(), seed(), dropout_state_allocator); + OP_REQUIRES_OK(context, FromExecutorStatus(rnn_desc_s)); + rnn_state.rnn_desc = std::move(rnn_desc_s.ConsumeValueOrDie()); + } launch_status = stream - ->ThenRnnBackward( - *rnn_desc_, *input_desc, input_data, *hidden_state_desc, - input_h_data, *hidden_state_desc, input_c_data, params_data, - *output_desc, output_data, 
*hidden_state_desc, output_h_data, - *hidden_state_desc, output_c_data, output_backprop_data, - output_h_backprop_data, output_c_backprop_data, - &input_backprop_data, &input_h_backprop_data, - &input_c_backprop_data, &params_backprop_data, - &reserve_space_uint8, &workspace_allocator) + ->ThenRnnBackward(*rnn_state.rnn_desc, *input_desc, input_data, + *hidden_state_desc, input_h_data, + *hidden_state_desc, input_c_data, params_data, + *output_desc, output_data, *hidden_state_desc, + output_h_data, *hidden_state_desc, + output_c_data, output_backprop_data, + output_h_backprop_data, output_c_backprop_data, + &input_backprop_data, &input_h_backprop_data, + &input_c_backprop_data, &params_backprop_data, + &reserve_space_uint8, &workspace_allocator) .ok(); } OP_REQUIRES(context, launch_status, @@ -1119,10 +1122,9 @@ class CudnnRNNBackwardOp : public CudnnRNNKernelCommon { private: mutex mu_; - std::unique_ptr model_shapes_ GUARDED_BY(mu_); - std::unique_ptr rnn_desc_ GUARDED_BY(mu_); - std::unique_ptr dropout_state_allocator_ - GUARDED_BY(mu_); + std::unordered_map + rnn_state_cache_ GUARDED_BY(mu_); }; #define REGISTER_GPU(T) \ diff --git a/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc b/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc index 9e41e67857101534e8bfef8d5d0b8a45ed8f1f76..1a79bf066c3a27e040099729fb079ee963f59270 100644 --- a/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc +++ b/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc @@ -251,9 +251,8 @@ REGISTER_OP("CudnnRNNParamsToCanonical") TF_RETURN_IF_ERROR(c->GetAttr("num_params", &num_params)); // Set shape for weight matrices for (int i = 0; i < num_params; i++) { - c->set_output(i, - c->Matrix(InferenceContext::kUnknownDim, - InferenceContext::kUnknownDim)); + c->set_output(i, c->Matrix(InferenceContext::kUnknownDim, + InferenceContext::kUnknownDim)); } // Set shape for bias vectors for (int i = 0; i < num_params; i++) { @@ -300,6 +299,7 @@ upcoming training or inferences.
num_params: number of parameter sets for all layers. Each layer may contain multiple parameter sets, with each set consisting of a weight matrix and a bias vector. -)doc", kCudnnRNNCommonAttrs)); +)doc", + kCudnnRNNCommonAttrs)); } // namespace tensorflow diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py index ff409ac71826f1f0f57e9133d768003f849abc09..933df6d71dd7c972efe63d54fa7344ecfc39b0a7 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_ops_benchmark.py @@ -20,8 +20,9 @@ from __future__ import print_function import time +from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.contrib import rnn as contrib_rnn from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops -from tensorflow.contrib.rnn.python.ops import core_rnn from tensorflow.contrib.rnn.python.ops import lstm_ops from tensorflow.python.client import session from tensorflow.python.framework import dtypes @@ -29,8 +30,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import rnn_cell +from tensorflow.python.ops import rnn from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -44,19 +44,19 @@ class CudnnRNNBenchmark(test.Benchmark): "large": { "num_layers": 4, "num_units": 1024, - "seq_length": 40, + "seq_length": 50, "batch_size": 64, }, "medium": { "num_layers": 4, "num_units": 512, - "seq_length": 30, + "seq_length": 50, "batch_size": 64, }, "small": { "num_layers": 4, "num_units": 128, - "seq_length": 20, + "seq_length": 50, "batch_size": 64, }, } @@ -71,7 +71,7 @@ class 
CudnnRNNBenchmark(test.Benchmark): def _BenchmarkOp(self, op, desc): burn_in_steps = 10 - benchmark_steps = 40 + benchmark_steps = 20 with session.Session() as sess: sess.run(variables.global_variables_initializer()) for i in xrange(burn_in_steps + benchmark_steps): @@ -126,16 +126,12 @@ class CudnnRNNBenchmark(test.Benchmark): seq_length = config["seq_length"] with ops.Graph().as_default(), ops.device("/device:GPU:0"): - inputs = seq_length * [ - array_ops.zeros([batch_size, num_units], dtypes.float32) - ] - initializer = init_ops.random_uniform_initializer(-0.01, 0.01, seed=127) - - cell = rnn_cell.LSTMCell( - num_units=num_units, initializer=initializer, state_is_tuple=True) - multi_cell = rnn_cell.MultiRNNCell( - [cell() for _ in range(num_layers)]) - outputs, final_state = core_rnn.static_rnn( + inputs = array_ops.zeros([batch_size, seq_length, num_units], + dtypes.float32) + + multi_cell = contrib_rnn.MultiRNNCell( + [contrib_rnn.BasicLSTMCell(num_units) for _ in range(num_layers)]) + outputs, final_state = rnn.dynamic_rnn( multi_cell, inputs, dtype=dtypes.float32) trainable_variables = ops.get_collection( ops.GraphKeys.TRAINABLE_VARIABLES) @@ -154,14 +150,12 @@ class CudnnRNNBenchmark(test.Benchmark): seq_length = config["seq_length"] with ops.Graph().as_default(), ops.device("/device:GPU:0"): - inputs = seq_length * [ - array_ops.zeros([batch_size, num_units], dtypes.float32) - ] - cell = lambda: lstm_ops.LSTMBlockCell(num_units=num_units) # pylint: disable=cell-var-from-loop - - multi_cell = rnn_cell.MultiRNNCell( - [cell() for _ in range(num_layers)]) - outputs, final_state = core_rnn.static_rnn( + inputs = array_ops.zeros([batch_size, seq_length, num_units], + dtypes.float32) + + multi_cell = contrib_rnn.MultiRNNCell( + [lstm_ops.LSTMBlockCell(num_units) for _ in range(num_layers)]) + outputs, final_state = rnn.dynamic_rnn( multi_cell, inputs, dtype=dtypes.float32) trainable_variables = ops.get_collection( ops.GraphKeys.TRAINABLE_VARIABLES) diff --git 
a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py index e65394cba07574ed49398981f1cbd8bcb402e24f..9897c31a98e0b335c18a84825fc518ed1fc310a2 100644 --- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py @@ -29,6 +29,8 @@ import numpy as np from tensorflow.contrib.cudnn_rnn.python.layers import cudnn_rnn from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops from tensorflow.contrib.rnn.python.ops import rnn as contrib_rnn_lib +from tensorflow.python.eager import backprop +from tensorflow.python.eager import context from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed @@ -49,7 +51,11 @@ from tensorflow.python.ops.losses import losses from tensorflow.python.platform import googletest from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import adagrad +from tensorflow.python.training import adam from tensorflow.python.training import gradient_descent +from tensorflow.python.training import momentum +from tensorflow.python.training import rmsprop from tensorflow.python.training import saver as saver_lib @@ -314,6 +320,150 @@ class CudnnRNNTestBasic(TensorFlowTestCase): self.assertEqual(0, total_sum2_v) self.assertEqual(0, total_sum3_v) + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testOptimizersSupport(self): + for opt in ("adagrad", "adam", "rmsprop", "momentum", "sgd"): + self._TestOptimizerSupportHelper(opt) + + def _GetOptimizer(self, opt): + if opt == "adagrad": + return adagrad.AdagradOptimizer(learning_rate=1e-2) + elif opt == "adam": + return adam.AdamOptimizer(learning_rate=1e-2) + elif opt == "rmsprop": + return 
rmsprop.RMSPropOptimizer(learning_rate=1e-2) + elif opt == "momentum": + return momentum.MomentumOptimizer(learning_rate=1e-2, momentum=0.9) + elif opt == "sgd": + return gradient_descent.GradientDescentOptimizer(learning_rate=1e-2) + else: + raise ValueError("Unsupported optimizer: %s" % opt) + + def _TestOptimizerSupportHelper(self, opt): + num_layers = 4 + num_units = 2 + batch_size = 8 + direction = CUDNN_RNN_UNIDIRECTION + dir_count = 1 + + with ops.Graph().as_default() as g: + kernel_initializer = init_ops.constant_initializer(0.) + bias_initializer = init_ops.constant_initializer(0.) + inputs = random_ops.random_uniform([ + num_layers * dir_count, batch_size, num_units], dtype=dtypes.float32) + + lstm = cudnn_rnn.CudnnLSTM(num_layers, num_units, + direction=direction, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + name="awesome_lstm") + outputs, _ = lstm(inputs) + loss = math_ops.reduce_sum(outputs) + optimizer = self._GetOptimizer(opt) + train_op = optimizer.minimize(loss) + + with self.test_session(use_gpu=True, graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(train_op) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSaveableGraphDeviceAssignment(self): + num_layers = 4 + num_units = 2 + batch_size = 8 + direction = CUDNN_RNN_UNIDIRECTION + dir_count = 1 + + def DeviceFn(op): + if op.type in ("Variable", "VariableV2"): + return "/cpu:0" + else: + return "/gpu:0" + + with ops.Graph().as_default() as g: + with ops.device(DeviceFn): + with vs.variable_scope("main"): + kernel_initializer = init_ops.constant_initializer(3.14) + bias_initializer = init_ops.constant_initializer(1.59) + inputs = random_ops.random_uniform( + [num_layers * dir_count, batch_size, num_units], + dtype=dtypes.float32) + + lstm = cudnn_rnn.CudnnLSTM(num_layers, num_units, + direction=direction, + kernel_initializer=kernel_initializer, + 
bias_initializer=bias_initializer, + name="awesome_lstm") + outputs = lstm(inputs) + + # saver is created in the scope of DeviceFn. + saver = saver_lib.Saver() + + with self.test_session(use_gpu=True, graph=g) as sess: + save_path = os.path.join(self.get_temp_dir(), + "test-saveable-device-assignment") + sess.run(variables.global_variables_initializer()) + + saver.save(sess, save_path) + saver.restore(sess, save_path) + sess.run(outputs) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testDifferentShapesEager(self): + # Checks that kernel caching does not cause sharing of temporary storage + # across different input shapes when executing eagerly. + with context.eager_mode(): + with ops.device("gpu:0"): + first_output, _ = cudnn_rnn.CudnnGRU(1, 100)( + array_ops.zeros([28, 100, 28])) + second_output, _ = cudnn_rnn.CudnnGRU(1, 100)( + array_ops.zeros([28, 100, 100])) + self.assertAllEqual([28, 100, 100], first_output.shape) + self.assertAllEqual([28, 100, 100], second_output.shape) + + def _LossFunc(): + first_output, _ = cudnn_rnn.CudnnGRU(1, 100)( + array_ops.zeros([28, 100, 28])) + second_output, _ = cudnn_rnn.CudnnGRU(1, 100)( + array_ops.zeros([28, 100, 100])) + return (math_ops.reduce_sum(first_output) + + math_ops.reduce_sum(second_output)) + + backprop.implicit_grad(_LossFunc)() + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testDifferentShapesGraph(self): + # Tests that a single kernel instance presented with multiple input shapes + # does not crash with graph execution. 
+ with ops.device("gpu:0"): + layer = cudnn_rnn.CudnnGRU(1, 100) + layer(array_ops.zeros([28, 100, 100])) + + def _Cond(index, accumulation): + del accumulation # unused + return math_ops.less(index, 4) + + def _Body(index, accumulation): + layer_input = accumulation[:, :, 10 * (1 + index % 2):] + output, _ = layer(layer_input) + return index + 1, accumulation + output + + original_input = array_ops.zeros([28, 100, 100]) + _, accumulation = control_flow_ops.while_loop(_Cond, _Body, + [0, original_input]) + grad, = gradients.gradients( + math_ops.reduce_sum(accumulation), (original_input,)) + init_op = variables.global_variables_initializer() + with self.test_session() as sess: + sess.run(init_op) + accumulation_eval, grad_eval = sess.run((accumulation, grad)) + self.assertAllEqual([28, 100, 100], accumulation_eval.shape) + self.assertAllEqual([28, 100, 100], grad_eval.shape) + # TODO(jamesqin): Transform to parameterized test after it is included in the # TF open source codebase. diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py index 37c61a71a3bdac4fadef58ba8c24b853fb3638ef..36fba917a8f56c26fd5b4c3468d1d980a8ba2ba5 100644 --- a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py +++ b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py @@ -176,8 +176,9 @@ class _CudnnRNN(base_layer.Layer): otherwise, it implies 'linear_input'. direction: the direction model that the model operates. Can be either 'unidirectional' or 'bidirectional' - dropout: dropout rate, a number between [0, 1]. Dropout is applied on - inputs of each layer. When set to 0, dropout is disabled. + dropout: dropout rate, a number between [0, 1]. Dropout is applied between + each layer (no dropout is applied for a model with a single layer). + When set to 0, dropout is disabled. seed: the op seed used for initializing dropout. See @{tf.set_random_seed} for behavior. 
dtype: tf.float16, tf.float32 or tf.float64 @@ -358,7 +359,7 @@ class _CudnnRNN(base_layer.Layer): # Create saveable in the outer scope of the cudnn subgraph, such that # alternative subgraph with platform-independent rnn cells can load the # checkpoints directly. - if not (self.built or vs.get_variable_scope().reuse): + if not (self.built or vs.get_variable_scope().reuse is True): self._create_saveable() self.built = True @@ -450,17 +451,18 @@ class _CudnnRNN(base_layer.Layer): raise RuntimeError( "%s._canonical_to_opaque invoked before input shape is known" % type(self).__name__) - return cudnn_rnn_ops.cudnn_rnn_canonical_to_opaque_params( - rnn_mode=self._rnn_mode, - num_layers=self._num_layers, - num_units=self._num_units, - input_size=self._input_size, - weights=cu_weights, - biases=cu_biases, - input_mode=self._input_mode, - seed=self._seed, - dropout=self._dropout, - direction=self._direction) + with ops.device("/gpu:0"): + return cudnn_rnn_ops.cudnn_rnn_canonical_to_opaque_params( + rnn_mode=self._rnn_mode, + num_layers=self._num_layers, + num_units=self._num_units, + input_size=self._input_size, + weights=cu_weights, + biases=cu_biases, + input_mode=self._input_mode, + seed=self._seed, + dropout=self._dropout, + direction=self._direction) def _forward(self, inputs, h, c, opaque_params, training): output, output_h, output_c = cudnn_rnn_ops._cudnn_rnn( # pylint:disable=protected-access @@ -489,14 +491,14 @@ class _CudnnRNN(base_layer.Layer): if self._saveable is not None: raise RuntimeError("Cudnn saveable already created.") self._saveable = self._saveable_cls( # pylint:disable=not-callable - self.trainable_variables[0], - self.num_layers, - self.num_units, - self.input_size, - self.input_mode, - self.direction, + opaque_params=self.trainable_variables[0], + num_layers=self.num_layers, + num_units=self.num_units, + input_size=self.input_size, + input_mode=self.input_mode, + direction=self.direction, scope=vs.get_variable_scope(), - name="%s_saveable" % 
self.trainable_variables[0].op.name) + name="%s_saveable" % self.trainable_variables[0].name.split(":")[0]) ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable) diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index dcd3d4732a27ae4bec579ac12ac568dc4a53baaa..e87162f0ee9cc4eed795555171f55a93639e83cf 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -72,7 +72,7 @@ class CudnnCompatibleLSTMCell(lstm_ops.LSTMBlockCell): def __init__(self, num_units, reuse=None): super(CudnnCompatibleLSTMCell, self).__init__( num_units, forget_bias=0, cell_clip=None, use_peephole=False, - reuse=reuse) + reuse=reuse, name="cudnn_compatible_lstm_cell") self._names.update({"scope": "cudnn_compatible_lstm_cell"}) @@ -303,16 +303,17 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): Returns: 2 list for weights and biases respectively. """ - weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical( - num_layers=self._num_layers, - num_units=self._num_units, - input_size=self._input_size, - params=self._variables, - num_params=self._num_params, - rnn_mode=self._rnn_mode, - input_mode=self._input_mode, - direction=self._direction) - return (weights, biases) + with ops.device("/gpu:0"): + weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical( + num_layers=self._num_layers, + num_units=self._num_units, + input_size=self._input_size, + params=self._variables, + num_params=self._num_params, + rnn_mode=self._rnn_mode, + input_mode=self._input_mode, + direction=self._direction) + return (weights, biases) def _CanonicalToOpaqueParams(self, cu_weights, cu_biases): """Converts from Cudnn canonical format to opaque params. @@ -323,15 +324,16 @@ class CudnnOpaqueParamsSaveable(saver.BaseSaverBuilder.SaveableObject): Returns: a single opaque tensor. 
""" - return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params( - num_layers=self._num_layers, - num_units=self._num_units, - input_size=self._input_size, - weights=cu_weights, - biases=cu_biases, - rnn_mode=self._rnn_mode, - input_mode=self._input_mode, - direction=self._direction) + with ops.device("/gpu:0"): + return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params( + num_layers=self._num_layers, + num_units=self._num_units, + input_size=self._input_size, + weights=cu_weights, + biases=cu_biases, + rnn_mode=self._rnn_mode, + input_mode=self._input_mode, + direction=self._direction) def _TransformCanonical(self, cu_weights, cu_biases): r"""Transform from Cudnn canonical to tf canonical. @@ -1352,7 +1354,7 @@ class _CudnnRNN(object): params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference. Returns: - output: the output sequuence. + output: the output sequence. output_h: the final state for h. output_c: the final state for c. This is only relevant for LSTM. """ @@ -1470,7 +1472,7 @@ class CudnnLSTM(_CudnnRNN): params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference. Returns: - output: the output sequuence. + output: the output sequence. output_h: the final state for h. output_c: the final state for c. """ @@ -1540,7 +1542,7 @@ class _CudnnRNNNoInputC(_CudnnRNN): params: the parameter buffer created for this model. is_training: whether this operation will be used in training or inference. Returns: - output: the output sequuence. + output: the output sequence. output_h: the final state for h. 
""" return _cudnn_rnn_no_input_c( diff --git a/tensorflow/contrib/data/BUILD b/tensorflow/contrib/data/BUILD index f7d8a084d9c12c05c411ae0751854d1823a818ec..0458199ff771bc45603106411550a39448e515b8 100644 --- a/tensorflow/contrib/data/BUILD +++ b/tensorflow/contrib/data/BUILD @@ -18,20 +18,22 @@ py_library( "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/contrib/data/python/ops:iterator_ops", "//tensorflow/contrib/data/python/ops:readers", + "//tensorflow/contrib/data/python/ops:shuffle_ops", "//tensorflow/contrib/data/python/ops:transformation_ops", + "//tensorflow/python:parsing_ops", "//tensorflow/python:util", "//tensorflow/python/data/ops:iterator_ops", ], ) tf_custom_op_library( - name = "_prefetching_ops.so", - srcs = ["ops/prefetching_ops.cc"], - deps = ["//tensorflow/contrib/data/kernels:prefetching_kernels"], + name = "_dataset_ops.so", + srcs = ["ops/dataset_ops.cc"], + deps = ["//tensorflow/contrib/data/kernels:dataset_kernels"], ) tf_gen_op_libs( - op_lib_names = ["prefetching_ops"], + op_lib_names = ["dataset_ops"], ) filegroup( diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index 7c6244f22b0f41656369595d3e3e6c23b7088bcb..fcdccdd26ca1824bf13f1fd0cfd80b20ca8a10c3 100644 --- a/tensorflow/contrib/data/__init__.py +++ b/tensorflow/contrib/data/__init__.py @@ -12,30 +12,33 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""`tf.contrib.data.Dataset` API for input pipelines. +"""Experimental API for building input pipelines. + +This module contains experimental `Dataset` sources and transformations that can +be used in conjunction with the @{tf.data.Dataset} API. Note that the +`tf.contrib.data` API is not subject to the same backwards compatibility +guarantees as `tf.data`, but we will provide deprecation advice in advance of +removing existing functionality. 
See the @{$datasets$Importing Data} Programmer's Guide for an overview. -@@Dataset @@Counter -@@Iterator -@@TFRecordDataset -@@FixedLengthRecordDataset -@@TextLineDataset @@batch_and_drop_remainder -@@padded_batch_and_drop_remainder @@dense_to_sparse_batch @@enumerate_dataset @@group_by_window @@ignore_errors @@make_saveable_from_iterator -@@read_batch_features -@@unbatch +@@map_and_batch +@@padded_batch_and_drop_remainder @@parallel_interleave +@@read_batch_features @@rejection_resample @@scan +@@shuffle_and_repeat @@sloppy_interleave +@@unbatch @@get_single_element """ @@ -48,25 +51,22 @@ from __future__ import print_function from tensorflow.contrib.data.python.ops.batching import batch_and_drop_remainder from tensorflow.contrib.data.python.ops.batching import dense_to_sparse_batch +from tensorflow.contrib.data.python.ops.batching import map_and_batch from tensorflow.contrib.data.python.ops.batching import padded_batch_and_drop_remainder from tensorflow.contrib.data.python.ops.batching import unbatch from tensorflow.contrib.data.python.ops.counter import Counter -from tensorflow.contrib.data.python.ops.dataset_ops import Dataset -from tensorflow.contrib.data.python.ops.dataset_ops import get_single_element from tensorflow.contrib.data.python.ops.enumerate_ops import enumerate_dataset from tensorflow.contrib.data.python.ops.error_ops import ignore_errors +from tensorflow.contrib.data.python.ops.get_single_element import get_single_element from tensorflow.contrib.data.python.ops.grouping import group_by_window from tensorflow.contrib.data.python.ops.interleave_ops import parallel_interleave from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator -from tensorflow.contrib.data.python.ops.readers import FixedLengthRecordDataset from tensorflow.contrib.data.python.ops.readers import read_batch_features from tensorflow.contrib.data.python.ops.readers import 
SqlDataset -from tensorflow.contrib.data.python.ops.readers import TextLineDataset -from tensorflow.contrib.data.python.ops.readers import TFRecordDataset from tensorflow.contrib.data.python.ops.resampling import rejection_resample from tensorflow.contrib.data.python.ops.scan_ops import scan -from tensorflow.python.data.ops.iterator_ops import Iterator +from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat # pylint: enable=unused-import from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/data/kernels/BUILD b/tensorflow/contrib/data/kernels/BUILD index 4cb53741ebf8cd0db41b382c878bd2ccd1dcf7f1..56471911c5c0d1c1825955c67997b5bbc0786463 100644 --- a/tensorflow/contrib/data/kernels/BUILD +++ b/tensorflow/contrib/data/kernels/BUILD @@ -17,6 +17,28 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "ignore_errors_dataset_op", + srcs = ["ignore_errors_dataset_op.cc"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@protobuf_archive//:protobuf_headers", + ], + alwayslink = 1, +) + +cc_library( + name = "dataset_kernels", + deps = [ + ":ignore_errors_dataset_op", + ":prefetching_kernels", + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@protobuf_archive//:protobuf_headers", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/core/kernels/ignore_errors_dataset_op.cc b/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc similarity index 97% rename from tensorflow/core/kernels/ignore_errors_dataset_op.cc rename to tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc index 8cf263d87fed601ed987e5d13909dd433391e5bd..bb29df60e8f114aaa50f578c43e73874f72ab0a3 100644 --- a/tensorflow/core/kernels/ignore_errors_dataset_op.cc +++ b/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc @@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/kernels/dataset.h" - +#include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/random/random.h" @@ -109,7 +108,7 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } - Status RestoreInternal(OpKernelContext* ctx, + Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { mutex_lock l(mu_); if (reader->Contains(full_name("input_impls_empty"))) diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/contrib/data/kernels/prefetching_kernels.cc index c9a3537c70c711290fb1111a1594e6dea3bc07a9..d3df14bdd03476e9ee4015b374512e5bb9893a63 100644 --- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc +++ b/tensorflow/contrib/data/kernels/prefetching_kernels.cc @@ -83,11 +83,10 @@ class FunctionBufferingResource : public ResourceBase { return Status::OK(); } AttrValueMap attr_values = func_.attr(); - AttrValue v; - v.set_s(target_device_); - AddAttr("_target", v, &attr_values); - - return lib_->Instantiate(func_.name(), AttrSlice(&attr_values), &handle_); + FunctionLibraryRuntime::InstantiateOptions opts; + opts.target = target_device_; + return lib_->Instantiate(func_.name(), AttrSlice(&attr_values), opts, + &handle_); } // Returns true if we've got to the end of the sequence and exhausted the diff --git a/tensorflow/contrib/data/ops/prefetching_ops.cc b/tensorflow/contrib/data/ops/dataset_ops.cc similarity index 86% rename from tensorflow/contrib/data/ops/prefetching_ops.cc rename to tensorflow/contrib/data/ops/dataset_ops.cc index 23cb62b6f0dbfed15667dd00ae0039b33aa944d4..289ffa1d9c29092cdf434e86ed5553ff9644d43e 100644 --- 
a/tensorflow/contrib/data/ops/prefetching_ops.cc +++ b/tensorflow/contrib/data/ops/dataset_ops.cc @@ -17,6 +17,16 @@ limitations under the License. namespace tensorflow { +REGISTER_OP("IgnoreErrorsDataset") + .Input("input_dataset: variant") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that contains the elements of `input_dataset` ignoring errors. +)doc"); + REGISTER_OP("FunctionBufferingResource") .Input("string_arg: string") .Input("target_device: string") diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index 1d4817fa2670317f4f4e9e63c724a79e18aa35bc..e51d57cc896dc32d8e11912cd89f34a04a858c78 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -4,7 +4,7 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow:tensorflow.bzl", "py_test", "tf_py_test") py_test( name = "batch_dataset_op_test", @@ -36,6 +36,7 @@ py_test( srcs = ["bucketing_test.py"], srcs_version = "PY2AND3", deps = [ + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/contrib/data/python/ops:transformation_ops", "//tensorflow/python:array_ops", @@ -51,37 +52,17 @@ py_test( ], ) -py_test( - name = "cache_dataset_op_test", - size = "small", - srcs = ["cache_dataset_op_test.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:variables", - "//tensorflow/python/data/ops:iterator_ops", - "//third_party/py/numpy", - ], -) - py_test( name = "concatenate_dataset_op_test", 
size = "small", srcs = ["concatenate_dataset_op_test.py"], srcs_version = "PY2AND3", deps = [ + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:iterator_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", "//tensorflow/python:tensor_shape", - "//tensorflow/python:training", "//tensorflow/python/data/util:nest", "//third_party/py/numpy", ], @@ -89,7 +70,7 @@ py_test( py_test( name = "dataset_constructor_op_test", - size = "small", + size = "medium", srcs = ["dataset_constructor_op_test.py"], srcs_version = "PY2AND3", tags = [ @@ -118,7 +99,6 @@ py_test( py_library( name = "dataset_serialization_test", - testonly = 1, srcs = [ "dataset_serialization_test_base.py", ], @@ -128,6 +108,7 @@ py_library( "//tensorflow/python:client_testlib", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", + "//tensorflow/python:lookup_ops", "//tensorflow/python:platform", "//tensorflow/python:sparse_tensor", "//tensorflow/python:training", @@ -157,14 +138,13 @@ py_test( ], ) -py_test( +tf_py_test( name = "flat_map_dataset_op_test", - size = "small", + size = "medium", srcs = ["flat_map_dataset_op_test.py"], - srcs_version = "PY2AND3", - tags = ["no_pip"], - deps = [ + additional_deps = [ ":dataset_serialization_test", + "//third_party/py/numpy", "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -177,17 +157,19 @@ py_test( "//tensorflow/python:session", "//tensorflow/python:training", "//tensorflow/python:variable_scope", - "//third_party/py/numpy", ], + grpc_enabled = True, + tags = ["no_pip"], ) py_test( name = "interleave_dataset_op_test", - size = "small", + size = "medium", srcs = ["interleave_dataset_op_test.py"], srcs_version = "PY2AND3", tags = [ - "manual", # b/67958761 + "no_oss", + "no_pip", ], deps = [ ":dataset_serialization_test", @@ -207,77 +189,25 
@@ py_test( ], ) -py_test( - name = "iterator_ops_cluster_test", +tf_py_test( + name = "get_single_element_test", size = "small", - srcs = ["iterator_ops_cluster_test.py"], - srcs_version = "PY2AND3", - tags = ["no_windows"], - deps = [ - "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:function", - "//tensorflow/python:functional_ops", - "//tensorflow/python:session", - "//tensorflow/python/data/ops:iterator_ops", - ], -) - -py_test( - name = "iterator_ops_test", - size = "small", - srcs = ["iterator_ops_test.py"], - srcs_version = "PY2AND3", - deps = [ + srcs = ["get_single_element_test.py"], + additional_deps = [ + "//third_party/py/numpy", "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:readers", - "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", - "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:dtypes", "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", - "//tensorflow/python:function", - "//tensorflow/python:functional_ops", - "//tensorflow/python:gradients", - "//tensorflow/python:io_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:script_ops", - "//tensorflow/python:session", - "//tensorflow/python:training", - "//tensorflow/python/data/ops:iterator_ops", - "//third_party/py/numpy", - ], -) - -py_test( - name = "list_files_dataset_op_test", - size = "small", - srcs = ["list_files_dataset_op_test.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/data/python/ops:dataset_ops", - 
"//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", - "//tensorflow/python:util", ], ) py_test( name = "map_dataset_op_test", - size = "small", + size = "medium", srcs = ["map_dataset_op_test.py"], srcs_version = "PY2AND3", tags = ["no_pip"], @@ -304,6 +234,7 @@ py_test( "//tensorflow/python:string_ops", "//tensorflow/python:util", "//tensorflow/python:variable_scope", + "//tensorflow/python/data/ops:dataset_ops", "//third_party/py/numpy", ], ) @@ -327,8 +258,8 @@ py_test( srcs = ["range_dataset_op_test.py"], srcs_version = "PY2AND3", deps = [ + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:iterator_ops", "//tensorflow/contrib/data/python/ops:transformation_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -339,11 +270,8 @@ py_test( "//tensorflow/python:framework_ops", "//tensorflow/python:io_ops", "//tensorflow/python:parsing_ops", - "//tensorflow/python:platform", "//tensorflow/python:tensor_shape", - "//tensorflow/python:training", "//tensorflow/python:variables", - "//tensorflow/python/data/ops:iterator_ops", ], ) @@ -389,8 +317,27 @@ py_test( ) py_test( - name = "sequence_dataset_op_test", + name = "scan_dataset_op_test", size = "small", + srcs = ["scan_dataset_op_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test", + "//tensorflow/contrib/data/python/ops:transformation_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python/data/ops:dataset_ops", + "//third_party/py/numpy", + ], +) + +py_test( + name = "sequence_dataset_op_test", + size = "medium", srcs = ["sequence_dataset_op_test.py"], srcs_version = "PY2AND3", tags = ["no_pip"], @@ -406,33 +353,36 @@ py_test( ) py_test( - 
name = "shard_dataset_op_test", + name = "serialization_integration_test", size = "small", - srcs = ["shard_dataset_op_test.py"], + srcs = ["serialization_integration_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ - "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:iterator_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python/data/ops:dataset_ops", ], ) py_test( name = "shuffle_dataset_op_test", - size = "small", + size = "medium", srcs = ["shuffle_dataset_op_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:iterator_ops", + "//tensorflow/contrib/data/python/ops:shuffle_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", - "//tensorflow/python:platform", - "//tensorflow/python:training", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:iterator_ops", "//third_party/py/numpy", @@ -450,19 +400,41 @@ py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "@org_sqlite//:python", ], ) py_test( name = "stats_dataset_ops_test", - size = "small", + size = "medium", srcs = ["stats_dataset_ops_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":dataset_serialization_test", + "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:transformation_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + ], +) + +py_test( + name = "unique_dataset_op_test", + size = "small", + srcs = ["unique_dataset_op_test.py"], + srcs_version = "PY2AND3", + tags = 
["no_pip"], deps = [ + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/contrib/data/python/ops:transformation_ops", + "//tensorflow/contrib/stateless", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//third_party/py/numpy", ], ) @@ -493,7 +465,7 @@ py_test( "no_oss", # b/68785503 ], deps = [ - "//tensorflow/contrib/data/python/ops:prefetching_py", + "//tensorflow/contrib/data/python/ops:prefetching_ops", "//tensorflow/core:protos_all_py", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index a939b3c841286a3b5786268dc3a9c82fd7359bfb..71dc1c1172c9d515d4c85f85257c952135098329 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -23,292 +23,33 @@ import numpy as np from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import batching -from tensorflow.contrib.data.python.ops import dataset_ops +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import sparse_tensor -from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import string_ops from tensorflow.python.platform import test -from tensorflow.python.util import compat class BatchDatasetTest(test.TestCase): - def testBatchDataset(self): - """Test an dataset that maps a TF function across its input elements.""" - # The pipeline is 
TensorSliceDataset -> MapDataset(square_3) -> - # RepeatDataset(count) -> BatchDataset(batch_size). - components = (np.arange(7), - np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], - np.array(37.0) * np.arange(7)) - - count = array_ops.placeholder(dtypes.int64, shape=[]) - batch_size = array_ops.placeholder(dtypes.int64, shape=[]) - - def _map_fn(x, y, z): - return math_ops.square(x), math_ops.square(y), math_ops.square(z) - - iterator = ( - dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) - .repeat(count).batch(batch_size).make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - self.assertEqual([[None] + list(c.shape[1:]) for c in components], - [t.shape.as_list() for t in get_next]) - - with self.test_session() as sess: - # Batch of a finite input, where the batch_size divides the - # total number of elements. - sess.run(init_op, feed_dict={count: 28, batch_size: 14}) - num_batches = (28 * 7) // 14 - for i in range(num_batches): - result = sess.run(get_next) - for component, result_component in zip(components, result): - for j in range(14): - self.assertAllEqual(component[(i * 14 + j) % 7]**2, - result_component[j]) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Batch of a finite input, where the batch_size does not - # divide the total number of elements. - sess.run(init_op, feed_dict={count: 14, batch_size: 8}) - - # We expect (num_batches - 1) full-sized batches. 
- num_batches = int(math.ceil((14 * 7) / 8)) - for i in range(num_batches - 1): - result = sess.run(get_next) - for component, result_component in zip(components, result): - for j in range(8): - self.assertAllEqual(component[(i * 8 + j) % 7]**2, - result_component[j]) - result = sess.run(get_next) - for component, result_component in zip(components, result): - for j in range((14 * 7) % 8): - self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2, - result_component[j]) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Batch of an empty input should fail straight away. - sess.run(init_op, feed_dict={count: 0, batch_size: 8}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Empty batch should be an initialization time error. - with self.assertRaises(errors.InvalidArgumentError): - sess.run(init_op, feed_dict={count: 14, batch_size: 0}) - def assertSparseValuesEqual(self, a, b): self.assertAllEqual(a.indices, b.indices) self.assertAllEqual(a.values, b.values) self.assertAllEqual(a.dense_shape, b.dense_shape) - def testBatchSparse(self): - - def _sparse(i): - return sparse_tensor.SparseTensorValue( - indices=[[0]], values=(i * [1]), dense_shape=[1]) - - iterator = dataset_ops.Dataset.range(10).map(_sparse).batch( - 5).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - for i in range(2): - actual = sess.run(get_next) - expected = sparse_tensor.SparseTensorValue( - indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]], - values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4], - dense_shape=[5, 1]) - self.assertTrue(sparse_tensor.is_sparse(actual)) - self.assertSparseValuesEqual(actual, expected) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testNestedBatchSparse(self): - - def _sparse(i): - return sparse_tensor.SparseTensorValue( - indices=[[0]], values=(i * [1]), 
dense_shape=[1]) - - iterator = dataset_ops.Dataset.range(10).map(_sparse).batch(5).batch( - 2).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - actual = sess.run(get_next) - expected = sparse_tensor.SparseTensorValue( - indices=[[0, 0, 0], [0, 1, 0], [0, 2, 0], [0, 3, 0], [0, 4, 0], - [1, 0, 0], [1, 1, 0], [1, 2, 0], [1, 3, 0], [1, 4, 0]], - values=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], - dense_shape=[2, 5, 1]) - self.assertTrue(sparse_tensor.is_sparse(actual)) - self.assertSparseValuesEqual(actual, expected) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testPaddedBatchDataset(self): - seq_lens = array_ops.placeholder(dtypes.int32, shape=[None]) - padded_shape = array_ops.placeholder(dtypes.int64, shape=[1]) - - iterator = ( - dataset_ops.Dataset.from_tensor_slices(seq_lens) - .map(lambda x: array_ops.fill([x], x)).padded_batch( - 4, padded_shapes=padded_shape).make_initializable_iterator()) - - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - # Test with random sequence lengths, and max padding. - random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32) - sess.run( - init_op, feed_dict={ - padded_shape: [-1], - seq_lens: random_seq_lens - }) - for i in range(8): - result = sess.run(get_next) - padded_len = np.max(result) - self.assertEqual((4, padded_len), result.shape) - for j in range(4): - seq_len = random_seq_lens[(i * 4) + j] - self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len) - self.assertAllEqual(result[j, seq_len:], [0] * (padded_len - seq_len)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Test with random sequence lengths, and constant padding. 
- sess.run( - init_op, feed_dict={ - padded_shape: [25], - seq_lens: random_seq_lens - }) - for i in range(8): - result = sess.run(get_next) - self.assertEqual((4, 25), result.shape) - for j in range(4): - seq_len = random_seq_lens[(i * 4) + j] - self.assertAllEqual(result[j, :seq_len], [seq_len] * seq_len) - self.assertAllEqual(result[j, seq_len:], [0] * (25 - seq_len)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Test correct handling of empty tensors. - sess.run(init_op, feed_dict={padded_shape: [-1], seq_lens: [0, 0, 0, 0]}) - result = sess.run(get_next) - self.assertAllEqual([[], [], [], []], result) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Test error handling with constant sequence lengths, and - # too-short padding. - sess.run(init_op, feed_dict={padded_shape: [5], seq_lens: [6, 5, 5, 5]}) - with self.assertRaises(errors.DataLossError): - result = sess.run(get_next) - - def testPaddedBatchDatasetNonDefaultPadding(self): - seq_lens = array_ops.placeholder(dtypes.int32, shape=[None]) - padded_shape = array_ops.placeholder(dtypes.int64, shape=[1]) - - def fill_tuple(x): - filled = array_ops.fill([x], x) - return (filled, string_ops.as_string(filled)) - - iterator = ( - dataset_ops.Dataset.from_tensor_slices(seq_lens).map(fill_tuple) - .padded_batch( - 4, - padded_shapes=(padded_shape, padded_shape), - padding_values=(-1, "")).make_initializable_iterator()) - - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - # Test with random sequence lengths, and max padding. 
- random_seq_lens = np.random.randint(20, size=(32,)).astype(np.int32) - sess.run( - init_op, feed_dict={ - padded_shape: [-1], - seq_lens: random_seq_lens - }) - for i in range(8): - result = sess.run(get_next) - padded_len = np.max(result[0]) - self.assertEqual((4, padded_len), result[0].shape) - self.assertEqual((4, padded_len), result[1].shape) - for j in range(4): - seq_len = random_seq_lens[(i * 4) + j] - self.assertAllEqual(result[0][j, :seq_len], [seq_len] * seq_len) - self.assertAllEqual(result[0][j, seq_len:], - [-1] * (padded_len - seq_len)) - self.assertAllEqual(result[1][j, :seq_len], - [compat.as_bytes(str(seq_len))] * seq_len) - self.assertAllEqual(result[1][j, seq_len:], - [b""] * (padded_len - seq_len)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testPaddedBatchDatasetShapeSpecifications(self): - int_placeholder = array_ops.placeholder(dtypes.int32) - float_placeholder = array_ops.placeholder(dtypes.float32) - string_placeholder = array_ops.placeholder(dtypes.string) - input_dataset = dataset_ops.Dataset.from_tensors( - (int_placeholder, float_placeholder, string_placeholder)) - - # Test different ways of specifying the `padded_shapes` argument. 
- dynamic_padding_from_tensor_shapes = input_dataset.padded_batch( - 32, - padded_shapes=(tensor_shape.TensorShape([None]), - tensor_shape.TensorShape([None, None]), - tensor_shape.TensorShape([37]))) - dynamic_padding_from_lists = input_dataset.padded_batch( - 32, padded_shapes=([None], [None, None], [37])) - dynamic_padding_from_lists_with_minus_one = input_dataset.padded_batch( - 32, padded_shapes=([-1], [-1, -1], [37])) - dynamic_padding_from_tensors = input_dataset.padded_batch( - 32, - padded_shapes=(constant_op.constant([-1], dtype=dtypes.int64), - constant_op.constant([-1, -1], dtype=dtypes.int64), - constant_op.constant([37], dtype=dtypes.int64))) - - for dataset in [ - dynamic_padding_from_tensor_shapes, dynamic_padding_from_lists, - dynamic_padding_from_lists_with_minus_one, dynamic_padding_from_tensors - ]: - self.assertEqual([None, None], dataset.output_shapes[0].as_list()) - self.assertEqual([None, None, None], dataset.output_shapes[1].as_list()) - self.assertEqual([None, 37], dataset.output_shapes[2].as_list()) - - def testPaddedBatchSparseError(self): - - def _map_fn(i): - return sparse_tensor.SparseTensorValue( - indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i - - with self.assertRaises(TypeError): - _ = dataset_ops.Dataset.range(10).map(_map_fn).padded_batch(10) - def testDenseToSparseBatchDataset(self): components = np.random.randint(12, size=(100,)).astype(np.int32) iterator = ( dataset_ops.Dataset.from_tensor_slices(components) .map(lambda x: array_ops.fill([x], x)).apply( - batching.dense_to_sparse_batch(4, - [12])).make_initializable_iterator()) + batching.dense_to_sparse_batch(4, [12])) + .make_initializable_iterator()) init_op = iterator.initializer - get_next = sparse_tensor.SparseTensor(*iterator.get_next()) + get_next = iterator.get_next() with self.test_session() as sess: sess.run(init_op) @@ -334,9 +75,9 @@ class BatchDatasetTest(test.TestCase): dataset_ops.Dataset.from_tensor_slices(components) .map(lambda x: 
array_ops.fill([x, x], x)).apply( batching.dense_to_sparse_batch( - 4, [5, -1])).make_initializable_iterator()) + 4, [5, None])).make_initializable_iterator()) init_op = iterator.initializer - get_next = sparse_tensor.SparseTensor(*iterator.get_next()) + get_next = iterator.get_next() with self.test_session() as sess: sess.run(init_op) @@ -363,25 +104,18 @@ class BatchDatasetTest(test.TestCase): def testDenseToSparseBatchDatasetWithInvalidShape(self): input_tensor = array_ops.constant([[1]]) - iterator = ( - dataset_ops.Dataset.from_tensors(input_tensor).apply( - batching.dense_to_sparse_batch(4, [-2])) - .make_initializable_iterator()) - init_op = iterator.initializer - - with self.test_session() as sess: - with self.assertRaisesRegexp(errors.InvalidArgumentError, - "Dimension -2 must be >= -1"): - sess.run(init_op) + with self.assertRaisesRegexp(ValueError, "Dimension -2 must be >= 0"): + dataset_ops.Dataset.from_tensors(input_tensor).apply( + batching.dense_to_sparse_batch(4, [-2])).make_initializable_iterator() def testDenseToSparseBatchDatasetShapeErrors(self): input_tensor = array_ops.placeholder(dtypes.int32) iterator = ( dataset_ops.Dataset.from_tensors(input_tensor).apply( - batching.dense_to_sparse_batch(4, - [12])).make_initializable_iterator()) + batching.dense_to_sparse_batch(4, [12])) + .make_initializable_iterator()) init_op = iterator.initializer - get_next = sparse_tensor.SparseTensor(*iterator.get_next()) + get_next = iterator.get_next() with self.test_session() as sess: # Initialize with an input tensor of incompatible rank. 
@@ -577,7 +311,7 @@ class BatchDatasetTest(test.TestCase): self.assertEqual([None], dataset.output_shapes[1][0].as_list()) self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list()) - def testBatchAndMapDataset(self): + def _testBatchAndMapDatasetHelper(self, num_parallel_batches=1): """Test a dataset that maps a TF function across its input elements.""" # The pipeline is TensorSliceDataset -> # RepeatDataset(count) -> BatchAndMapDataset(square_3, batch_size). @@ -593,7 +327,10 @@ class BatchDatasetTest(test.TestCase): iterator = ( dataset_ops.Dataset.from_tensor_slices(components).repeat(count).apply( - batching.map_and_batch(_map_fn, batch_size)) + batching.map_and_batch( + map_func=_map_fn, + batch_size=batch_size, + num_parallel_batches=num_parallel_batches)) .make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -627,7 +364,11 @@ class BatchDatasetTest(test.TestCase): for j in range(8): self.assertAllEqual(component[(i * 8 + j) % 7]**2, result_component[j]) - # The last batch should fail with `OutOfRange`. 
+ result = sess.run(get_next) + for component, result_component in zip(components, result): + for j in range((14 * 7) % 8): + self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2, + result_component[j]) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) @@ -640,6 +381,12 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.InvalidArgumentError): sess.run(init_op, feed_dict={count: 14, batch_size: 0}) + def testBatchAndMapDataset(self): + return self._testBatchAndMapDatasetHelper() + + def testBatchAndMapDatasetWithParallelBatching(self): + return self._testBatchAndMapDatasetHelper(num_parallel_batches=10) + def testMapAndBatchSparse(self): def _sparse(i): @@ -722,6 +469,39 @@ class BatchDatasetSerializationTest( lambda: self.build_dataset(20.0, tensor_slice_len, batch_size), num_outputs) + def _build_dataset_dense_to_sparse(self, components): + return dataset_ops.Dataset.from_tensor_slices(components).map( + lambda x: array_ops.fill([x], x)).apply( + batching.dense_to_sparse_batch(4, [12])) + + # TODO(b/70988345): Re-enable when sparse tensors are properly supported by + # the DatasetSerializationTestBase. 
+ def _testDenseToSparseBatchDatasetCore(self): + components = np.random.randint(5, size=(40,)).astype(np.int32) + diff_comp = np.random.randint(2, size=(100,)).astype(np.int32) + + num_outputs = len(components) // 4 + self.run_core_tests(lambda: self._build_dataset_dense_to_sparse(components), + lambda: self._build_dataset_dense_to_sparse(diff_comp), + num_outputs) + + def _sparse(self, i): + return sparse_tensor.SparseTensorValue( + indices=[[0]], values=(i * [1]), dense_shape=[1]) + + def _build_dataset_sparse(self, batch_size=5): + return dataset_ops.Dataset.range(10).map(self._sparse).batch(batch_size) + + def testSparseCore(self): + self.run_core_tests(self._build_dataset_sparse, + lambda: self._build_dataset_sparse(2), 2) + + def _build_dataset_nested_sparse(self): + return dataset_ops.Dataset.range(10).map(self._sparse).batch(5).batch(2) + + def testNestedSparseCore(self): + self.run_core_tests(self._build_dataset_nested_sparse, None, 1) + class PaddedBatchDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py index 765ed53618958a8c49b26e416c57be28ea3bba73..f1b494e1a620992365ed75613b508e32f94b40a4 100644 --- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py @@ -19,8 +19,9 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib.data.python.ops import dataset_ops +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import grouping +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -40,8 +41,7 @@ class GroupByWindowTest(test.TestCase): 
dataset_ops.Dataset.from_tensor_slices(components).map(lambda x: x * x) .apply( grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4), - 4)) - .make_initializable_iterator()) + 4)).make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -52,7 +52,8 @@ class GroupByWindowTest(test.TestCase): while True: result = sess.run(get_next) self.assertTrue( - all(x % 2 == 0 for x in result) or all(x % 2 == 1) + all(x % 2 == 0 + for x in result) or all(x % 2 == 1) for x in result) counts.append(result.shape[0]) @@ -115,8 +116,8 @@ class GroupByWindowTest(test.TestCase): iterator = ( dataset_ops.Dataset.from_tensor_slices(components) .map(lambda x: (x, ops.convert_to_tensor([x * x]))).apply( - grouping.group_by_window(lambda x, _: x % 2, reduce_func, 32)) - .make_initializable_iterator()) + grouping.group_by_window(lambda x, _: x % 2, reduce_func, + 32)).make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -135,7 +136,8 @@ class GroupByWindowTest(test.TestCase): window.padded_batch( 4, padded_shapes=tensor_shape.TensorShape([None])), window.padded_batch( - 4, padded_shapes=ops.convert_to_tensor([(key + 1) * 10])),)) + 4, padded_shapes=ops.convert_to_tensor([(key + 1) * 10])), + )) iterator = ( dataset_ops.Dataset.from_tensor_slices(components) @@ -160,6 +162,34 @@ class GroupByWindowTest(test.TestCase): self.assertEqual(len(components), sum(counts)) +class GroupByWindowSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, components): + return dataset_ops.Dataset.from_tensor_slices(components).repeat(-1).apply( + grouping.group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4), 4)) + + def testCoreGroupByWindow(self): + components = np.array( + [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64) + self.verify_unused_iterator( + lambda: self._build_dataset(components), 12, verify_exhausted=False) + 
self.verify_init_before_restore( + lambda: self._build_dataset(components), 12, verify_exhausted=False) + self.verify_multiple_breaks( + lambda: self._build_dataset(components), 12, verify_exhausted=False) + self.verify_reset_restored_iterator( + lambda: self._build_dataset(components), 12, verify_exhausted=False) + self.verify_restore_in_empty_graph( + lambda: self._build_dataset(components), 12, verify_exhausted=False) + diff_components = np.array([0, 0, 0, 1, 1, 1], dtype=np.int64) + self.verify_restore_in_modified_graph( + lambda: self._build_dataset(components), + lambda: self._build_dataset(diff_components), + 12, + verify_exhausted=False) + + # NOTE(mrry): These tests are based on the tests in bucket_ops_test.py. # Currently, they use a constant batch size, though should be made to use a # different batch size per key. @@ -171,9 +201,10 @@ class BucketTest(test.TestCase): # dynamically and does not rely on static shape information about # the arguments. return dataset_ops.Dataset.zip( - (dataset_ops.Dataset.from_tensors(bucket), window.padded_batch( - 32, (tensor_shape.TensorShape([]), tensor_shape.TensorShape([None]), - tensor_shape.TensorShape([3]))))) + (dataset_ops.Dataset.from_tensors(bucket), + window.padded_batch( + 32, (tensor_shape.TensorShape([]), tensor_shape.TensorShape( + [None]), tensor_shape.TensorShape([3]))))) def testSingleBucket(self): @@ -278,12 +309,13 @@ class BucketTest(test.TestCase): def _dynamic_pad_fn(bucket, window, _): return dataset_ops.Dataset.zip( - (dataset_ops.Dataset.from_tensors(bucket), window.padded_batch( - 32, { - "x": tensor_shape.TensorShape([]), - "y": tensor_shape.TensorShape([None]), - "z": tensor_shape.TensorShape([3]) - }))) + (dataset_ops.Dataset.from_tensors(bucket), + window.padded_batch( + 32, { + "x": tensor_shape.TensorShape([]), + "y": tensor_shape.TensorShape([None]), + "z": tensor_shape.TensorShape([3]) + }))) input_dataset = ( dataset_ops.Dataset.from_tensor_slices(math_ops.range(128)).map(_map_fn) 
diff --git a/tensorflow/contrib/data/python/kernel_tests/cache_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/cache_dataset_op_test.py deleted file mode 100644 index 9818020680afb9d0f0197d272ec5339c6358db36..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/data/python/kernel_tests/cache_dataset_op_test.py +++ /dev/null @@ -1,300 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for the experimental input pipeline ops.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from os import path -import shutil -import tempfile - -import numpy as np - -from tensorflow.contrib.data.python.ops import dataset_ops -from tensorflow.python.data.ops import iterator_ops -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import variables -from tensorflow.python.platform import test - - -class FilesystemCacheDatasetTest(test.TestCase): - - def setUp(self): - self.tmp_dir = tempfile.mkdtemp() - self.cache_prefix = path.join(self.tmp_dir, "cache") - - def tearDown(self): - if self.tmp_dir: - shutil.rmtree(self.tmp_dir, ignore_errors=True) - - def 
testCacheDatasetPassthrough(self): - components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]), - np.array([9.0, 10.0, 11.0, 12.0])) - count_placeholder = array_ops.placeholder_with_default( - constant_op.constant(5, dtypes.int64), shape=[]) - filename_placeholder = array_ops.placeholder(dtypes.string, shape=[]) - - repeat_dataset = (dataset_ops.Dataset.from_tensor_slices(components) - .repeat(count_placeholder)) - - cache_dataset = repeat_dataset.cache(filename_placeholder) - - self.assertEqual( - tuple([c.shape[1:] for c in components]), cache_dataset.output_shapes) - - # Create initialization ops for iterators without and with - # caching, respectively. - iterator = iterator_ops.Iterator.from_structure(cache_dataset.output_types, - cache_dataset.output_shapes) - init_fifo_op = iterator.make_initializer(repeat_dataset) - init_cache_op = iterator.make_initializer(cache_dataset) - - get_next = iterator.get_next() - - with self.test_session() as sess: - # First run without caching to collect the "ground truth". - sess.run(init_fifo_op) - elements = [] - for _ in range(20): - elements.append(sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Assert that the cached dataset has the same elements as the - # "ground truth". - sess.run( - init_cache_op, feed_dict={filename_placeholder: self.cache_prefix}) - cached_elements = [] - for _ in range(20): - cached_elements.append(sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - self.assertAllEqual(elements, cached_elements) - - # Re-initialize with an empty upstream (to throw errors.OutOfRangeError - # if we didn't use the cache). 
- sess.run( - init_cache_op, - feed_dict={ - count_placeholder: 0, - filename_placeholder: self.cache_prefix - }) - replayed_elements = [] - for _ in range(20): - replayed_elements.append(sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - self.assertEqual(cached_elements, replayed_elements) - - # Re-initialize with an empty upstream and a missing cache file (should - # throw errors.OutOfRangeError immediately). - sess.run( - init_cache_op, - feed_dict={ - count_placeholder: 0, - filename_placeholder: self.cache_prefix + "nonsense" - }) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testConcurrentWriters(self): - components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]), - np.array([9.0, 10.0, 11.0, 12.0])) - filename_placeholder = array_ops.placeholder(dtypes.string, shape=[]) - - cache_dataset1 = (dataset_ops.Dataset.from_tensor_slices(components) - .cache(filename_placeholder)) - cache_dataset2 = (dataset_ops.Dataset.from_tensor_slices(components) - .cache(filename_placeholder)) - - iterator1 = cache_dataset1.make_initializable_iterator() - iterator2 = cache_dataset2.make_initializable_iterator() - init_cache_op1 = iterator1.initializer - init_cache_op2 = iterator2.initializer - - get_next1 = iterator1.get_next() - get_next2 = iterator2.get_next() - - with self.test_session() as sess: - sess.run( - init_cache_op1, feed_dict={filename_placeholder: self.cache_prefix}) - sess.run(get_next1) # this should succeed - - sess.run( - init_cache_op2, feed_dict={filename_placeholder: self.cache_prefix}) - with self.assertRaises(errors.AlreadyExistsError): - sess.run(get_next2) - - sess.run(get_next1) # this should continue to succeed - - def testConcurrentReaders(self): - components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]), - np.array([9.0, 10.0, 11.0, 12.0])) - filename_placeholder = array_ops.placeholder(dtypes.string, shape=[]) - - cache_dataset1 = 
(dataset_ops.Dataset.from_tensor_slices(components) - .cache(filename_placeholder)) - cache_dataset2 = (dataset_ops.Dataset.from_tensor_slices(components) - .cache(filename_placeholder)) - - iterator1 = cache_dataset1.make_initializable_iterator() - iterator2 = cache_dataset2.make_initializable_iterator() - init_cache_op1 = iterator1.initializer - init_cache_op2 = iterator2.initializer - - get_next1 = iterator1.get_next() - get_next2 = iterator2.get_next() - - with self.test_session() as sess: - sess.run( - init_cache_op1, feed_dict={filename_placeholder: self.cache_prefix}) - elements = [] - for _ in range(4): - elements.append(sess.run(get_next1)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next1) - - # Re-initialize - sess.run( - init_cache_op1, feed_dict={filename_placeholder: self.cache_prefix}) - sess.run( - init_cache_op2, feed_dict={filename_placeholder: self.cache_prefix}) - - # Reading concurrently should succeed. - elements_itr1 = [] - elements_itr2 = [] - elements_itr2.append(sess.run(get_next2)) - elements_itr1.append(sess.run(get_next1)) - elements_itr2.append(sess.run(get_next2)) - elements_itr1.append(sess.run(get_next1)) - # Intentionally reversing the order - elements_itr1.append(sess.run(get_next1)) - elements_itr2.append(sess.run(get_next2)) - elements_itr1.append(sess.run(get_next1)) - elements_itr2.append(sess.run(get_next2)) - - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next2) - - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next1) - - self.assertAllEqual(elements, elements_itr1) - self.assertAllEqual(elements, elements_itr2) - - -class MemoryCacheDatasetTest(test.TestCase): - - def testCacheDatasetPassthrough(self): - repeat_count = variables.Variable(constant_op.constant(10, dtypes.int64)) - dataset = dataset_ops.Dataset.range(3).flat_map( - lambda x: dataset_ops.Dataset.from_tensors(x).repeat(repeat_count)) - - cached_dataset = dataset.cache().repeat(2) - uncached_dataset = 
dataset.repeat(2) - - # Needs to be initializable to capture the variable. - cached_iterator = cached_dataset.make_initializable_iterator() - cached_next = cached_iterator.get_next() - uncached_iterator = uncached_dataset.make_initializable_iterator() - uncached_next = uncached_iterator.get_next() - - with self.test_session() as sess: - - sess.run(repeat_count.initializer) - sess.run(cached_iterator.initializer) - sess.run(uncached_iterator.initializer) - - for i in range(3): - for _ in range(10): - self.assertEqual(sess.run(cached_next), i) - self.assertEqual(sess.run(uncached_next), i) - - sess.run(repeat_count.assign(0)) - - # The uncached iterator should now be empty. - with self.assertRaises(errors.OutOfRangeError): - sess.run(uncached_next) - - # The cached iterator replays from cache. - for i in range(3): - for _ in range(10): - self.assertEqual(sess.run(cached_next), i) - - # The cached iterator should now be empty. - with self.assertRaises(errors.OutOfRangeError): - sess.run(cached_next) - - def testEmptyCacheReading(self): - components = (np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]), - np.array([9.0, 10.0, 11.0, 12.0])) - count_placeholder = array_ops.placeholder_with_default( - constant_op.constant(5, dtypes.int64), shape=[]) - - repeat_dataset = (dataset_ops.Dataset.from_tensor_slices(components) - .repeat(count_placeholder)) - - cache_dataset = repeat_dataset.cache() - - # Create initialization ops for iterators without and with - # caching, respectively. - iterator = cache_dataset.make_initializable_iterator() - init_cache_op = iterator.initializer - - get_next = iterator.get_next() - - with self.test_session() as sess: - # Initialize with an empty upstream and a missing cache file (should - # throw errors.OutOfRangeError immediately). 
- sess.run(init_cache_op, feed_dict={count_placeholder: 0}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testConcurrentReaders(self): - count_placeholder = array_ops.placeholder_with_default( - constant_op.constant(5, dtypes.int64), shape=[]) - dataset = dataset_ops.Dataset.range(count_placeholder).cache() - d1 = dataset.map(lambda x: x + 1) - d2 = dataset.map(lambda x: x + 6) - - i1 = d1.make_initializable_iterator() - i2 = d2.make_initializable_iterator() - - with self.test_session() as sess: - sess.run(i1.initializer) - - self.assertEqual(1, sess.run(i1.get_next())) - self.assertEqual(2, sess.run(i1.get_next())) - self.assertEqual(3, sess.run(i1.get_next())) - - sess.run(i2.initializer, feed_dict={count_placeholder: 3}) - - self.assertEqual(6, sess.run(i2.get_next())) - self.assertEqual(7, sess.run(i2.get_next())) - self.assertEqual(4, sess.run(i1.get_next())) # interleave execution - self.assertEqual([8, 5], sess.run([i2.get_next(), i1.get_next()])) - - with self.assertRaises(errors.OutOfRangeError): - sess.run(i1.get_next()) - with self.assertRaises(errors.OutOfRangeError): - sess.run(i2.get_next()) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/concatenate_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/concatenate_dataset_op_test.py index 870352209a08e6bc08bcca227ba455ad1851e8bf..17f2980157ddd0350dafd1d745cbb9b64e65f7c5 100644 --- a/tensorflow/contrib/data/python/kernel_tests/concatenate_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/concatenate_dataset_op_test.py @@ -17,255 +17,32 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os import numpy as np -from tensorflow.contrib.data.python.ops import dataset_ops -from tensorflow.contrib.data.python.ops import iterator_ops -from tensorflow.python.data.util import nest -from tensorflow.python.framework import errors 
-from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.platform import test -from tensorflow.python.training import saver as saver_lib -class ConcatenateDatasetTest(test.TestCase): +class ConcatenateDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): - def testConcatenateDataset(self): - input_components = ( - np.tile(np.array([[1], [2], [3], [4]]), 20), - np.tile(np.array([[12], [13], [14], [15]]), 15), - np.array([37.0, 38.0, 39.0, 40.0])) - to_concatenate_components = ( - np.tile(np.array([[1], [2], [3], [4], [5]]), 20), - np.tile(np.array([[12], [13], [14], [15], [16]]), 15), - np.array([37.0, 38.0, 39.0, 40.0, 41.0])) - - input_dataset = dataset_ops.Dataset.from_tensor_slices(input_components) - dataset_to_concatenate = dataset_ops.Dataset.from_tensor_slices( - to_concatenate_components) - concatenated = input_dataset.concatenate(dataset_to_concatenate) - self.assertEqual(concatenated.output_shapes, (tensor_shape.TensorShape( - [20]), tensor_shape.TensorShape([15]), tensor_shape.TensorShape([]))) - - iterator = concatenated.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - for i in range(9): - result = sess.run(get_next) - if i < 4: - for component, result_component in zip(input_components, result): - self.assertAllEqual(component[i], result_component) - else: - for component, result_component in zip(to_concatenate_components, - result): - self.assertAllEqual(component[i - 4], result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testConcatenateDatasetDifferentShape(self): - input_components = ( - np.tile(np.array([[1], [2], [3], [4]]), 20), - np.tile(np.array([[12], 
[13], [14], [15]]), 4)) - to_concatenate_components = ( - np.tile(np.array([[1], [2], [3], [4], [5]]), 20), - np.tile(np.array([[12], [13], [14], [15], [16]]), 15)) - - input_dataset = dataset_ops.Dataset.from_tensor_slices(input_components) - dataset_to_concatenate = dataset_ops.Dataset.from_tensor_slices( - to_concatenate_components) - concatenated = input_dataset.concatenate(dataset_to_concatenate) - self.assertEqual( - [ts.as_list() - for ts in nest.flatten(concatenated.output_shapes)], [[20], [None]]) - - iterator = concatenated.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - for i in range(9): - result = sess.run(get_next) - if i < 4: - for component, result_component in zip(input_components, result): - self.assertAllEqual(component[i], result_component) - else: - for component, result_component in zip(to_concatenate_components, - result): - self.assertAllEqual(component[i - 4], result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testConcatenateDatasetDifferentStructure(self): - input_components = ( - np.tile(np.array([[1], [2], [3], [4]]), 5), - np.tile(np.array([[12], [13], [14], [15]]), 4)) - to_concatenate_components = ( - np.tile(np.array([[1], [2], [3], [4], [5]]), 20), - np.tile(np.array([[12], [13], [14], [15], [16]]), 15), - np.array([37.0, 38.0, 39.0, 40.0, 41.0])) - - input_dataset = dataset_ops.Dataset.from_tensor_slices(input_components) - dataset_to_concatenate = dataset_ops.Dataset.from_tensor_slices( - to_concatenate_components) - - with self.assertRaisesRegexp(ValueError, - "don't have the same number of elements"): - input_dataset.concatenate(dataset_to_concatenate) - - def testConcatenateDatasetDifferentType(self): - input_components = ( - np.tile(np.array([[1], [2], [3], [4]]), 5), - np.tile(np.array([[12], [13], [14], [15]]), 4)) - to_concatenate_components = ( - np.tile(np.array([[1.0], 
[2.0], [3.0], [4.0]]), 5), - np.tile(np.array([[12], [13], [14], [15]]), 15)) - - input_dataset = dataset_ops.Dataset.from_tensor_slices(input_components) - dataset_to_concatenate = dataset_ops.Dataset.from_tensor_slices( - to_concatenate_components) - - with self.assertRaisesRegexp(TypeError, "have different types"): - input_dataset.concatenate(dataset_to_concatenate) - - def _iterator_checkpoint_prefix(self): - return os.path.join(self.get_temp_dir(), "iterator") - - def _build_graph(self, input_components, to_concatenate_components): - input_dataset = dataset_ops.Dataset.from_tensor_slices(input_components) - dataset_to_concatenate = dataset_ops.Dataset.from_tensor_slices( - to_concatenate_components) - iterator = input_dataset.concatenate( - dataset_to_concatenate).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - saveable = iterator_ops.make_saveable_from_iterator(iterator) - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) - # TODO(shivaniagrawal) : non-intuitive way, add support in mata_graph - for t in nest.flatten(get_next): - ops.add_to_collection("get_next", t) - return init_op, get_next - - def _testSaveRestoreUtility(self, start, break_range, stop): - path = self._iterator_checkpoint_prefix() - step = 0 - meta_filename = path + "-%d.meta" % step - - input_components = (np.tile(np.array([[1], [2], [3], [4]]), 20), np.tile( - np.array([[12], [13], [14], [15]]), 4)) - to_concatenate_components = (np.tile( - np.array([[5], [6], [7], [8], [9]]), 20), np.tile( - np.array([[16], [17], [18], [19], [20]]), 15)) - - with ops.Graph().as_default() as g: - init_op, get_next = self._build_graph(input_components, - to_concatenate_components) - saver = saver_lib.Saver() - with self.test_session(graph=g) as sess: - sess.run(init_op) - for i in range(start, break_range): - result = sess.run(get_next) - if i < 4: - for component, result_component in zip(input_components, result): - 
self.assertAllEqual(component[i], result_component) - else: - for component, result_component in zip(to_concatenate_components, - result): - self.assertAllEqual(component[i - 4], result_component) - saver.save(sess, path, step) - - with ops.Graph().as_default() as g: - saver = saver_lib.import_meta_graph(meta_filename) - with self.test_session(graph=g) as sess: - get_next = nest.pack_sequence_as(("a", "b"), - ops.get_collection("get_next")) - saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir())) - for i in range(break_range, stop): - result = sess.run(get_next) - if i < 4: - for component, result_component in zip(input_components, result): - self.assertAllEqual(component[i], result_component) - else: - for component, result_component in zip(to_concatenate_components, - result): - self.assertAllEqual(component[i - 4], result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testRestoreAtFirstDataset(self): - start = 0 - stop = 9 - break_range = 3 - self._testSaveRestoreUtility(start, break_range, stop) - - def testRestoreAtSecondDataset(self): - start = 0 - stop = 9 - break_range = 6 - self._testSaveRestoreUtility(start, break_range, stop) - - def testRestoreAtBetweenDatasets(self): - start = 0 - stop = 9 - break_range = 4 - self._testSaveRestoreUtility(start, break_range, stop) - - def testRestoreExhaustedIterator(self): - start = 0 - stop = 9 - break_range = 9 - self._testSaveRestoreUtility(start, break_range, stop) - - def testRestoreInModifiedGraph(self): - start = 0 - stop = 9 - break_range = 6 - path = self._iterator_checkpoint_prefix() - step = 0 - - input_components = (np.tile(np.array([[1], [2], [3], [4]]), 20), np.tile( - np.array([[12], [13], [14], [15]]), 4)) + def _build_concatenate_dataset(self, var_array): + input_components = (np.tile(np.array([[1], [2], [3], [4]]), 20), + np.tile(np.array([[12], [13], [14], [15]]), 4)) to_concatenate_components = (np.tile( - np.array([[5], [6], [7], [8], [9]]), 
20), np.tile( - np.array([[16], [17], [18], [19], [20]]), 15)) - - with ops.Graph().as_default() as g: - init_op, get_next = self._build_graph(input_components, - to_concatenate_components) - saver = saver_lib.Saver(allow_empty=True) - with self.test_session(graph=g) as sess: - sess.run(init_op) - for i in range(start, break_range): - result = sess.run(get_next) - if i < 4: - for component, result_component in zip(input_components, result): - self.assertAllEqual(component[i], result_component) - else: - for component, result_component in zip(to_concatenate_components, - result): - self.assertAllEqual(component[i - 4], result_component) - saver.save(sess, path, step) - - new_to_concatenate_components = (np.array([[5], [6], [7], [8], [9]]), - np.array([[16], [17], [18], [19], [20]])) - with ops.Graph().as_default() as g: - init_op, get_next = self._build_graph(input_components, - new_to_concatenate_components) - saver = saver_lib.Saver() - with self.test_session(graph=g) as sess: - saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir())) - for i in range(break_range, stop): - result = sess.run(get_next) - for component, result_component in zip(to_concatenate_components, - result): - self.assertAllEqual(component[i - 4], result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) + np.array([[5], [6], [7], [8], [9]]), 20), var_array) + + return dataset_ops.Dataset.from_tensor_slices(input_components).concatenate( + dataset_ops.Dataset.from_tensor_slices(to_concatenate_components)) + + def testConcatenateCore(self): + num_outputs = 9 + array = np.tile(np.array([[16], [17], [18], [19], [20]]), 15) + diff_array = np.array([[1], [2], [3], [4], [5]]) + self.run_core_tests(lambda: self._build_concatenate_dataset(array), + lambda: self._build_concatenate_dataset(diff_array), + num_outputs) if __name__ == "__main__": diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py 
b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py index 55a1d3b95b212466b262ad3c26f1efd7ed0e067e..a842502cc6fe3605dde0be5f50cf46e3e37d7ed4 100644 --- a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py @@ -17,712 +17,20 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import threading - import numpy as np from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import batching -from tensorflow.contrib.data.python.ops import dataset_ops -from tensorflow.core.protobuf import config_pb2 -from tensorflow.python.client import session +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor -from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test class DatasetConstructorTest(test.TestCase): - def testFromTensors(self): - """Test an dataset that represents a single tuple of tensors.""" - components = (np.array(1), np.array([1, 2, 3]), np.array(37.0)) - - iterator = (dataset_ops.Dataset.from_tensors(components) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - self.assertEqual([c.shape for c in components], - [t.shape for t in get_next]) - - with self.test_session() as sess: - sess.run(init_op) - results = sess.run(get_next) - for component, result_component in zip(components, results): - self.assertAllEqual(component, result_component) - with 
self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def assertSparseValuesEqual(self, a, b): - self.assertAllEqual(a.indices, b.indices) - self.assertAllEqual(a.values, b.values) - self.assertAllEqual(a.dense_shape, b.dense_shape) - - def testFromTensorsSparse(self): - """Test an dataset that represents a single tuple of tensors.""" - components = (sparse_tensor.SparseTensorValue( - indices=np.array([[0]]), - values=np.array([0]), - dense_shape=np.array([1])), - sparse_tensor.SparseTensorValue( - indices=np.array([[0, 0], [1, 1]]), - values=np.array([-1, 1]), - dense_shape=np.array([2, 2]))) - - iterator = ( - dataset_ops.Dataset.from_tensors(components) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - self.assertEqual( - [tensor_shape.TensorShape(c.dense_shape) for c in components], - [shape for shape in iterator.output_shapes]) - - with self.test_session() as sess: - sess.run(init_op) - results = sess.run(get_next) - for component, result_component in zip(components, results): - self.assertSparseValuesEqual(component, result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testFromTensorsMixed(self): - """Test an dataset that represents a single tuple of tensors.""" - components = (np.array(1), np.array([1, 2, 3]), np.array(37.0), - sparse_tensor.SparseTensorValue( - indices=np.array([[0]]), - values=np.array([0]), - dense_shape=np.array([1])), - sparse_tensor.SparseTensorValue( - indices=np.array([[0, 0], [1, 1]]), - values=np.array([-1, 1]), - dense_shape=np.array([2, 2]))) - - iterator = ( - dataset_ops.Dataset.from_tensors(components) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - self.assertEqual([ - tensor_shape.TensorShape(c.dense_shape) - if sparse_tensor.is_sparse(c) else c.shape for c in components - ], [shape for shape in iterator.output_shapes]) - - with self.test_session() as sess: - 
sess.run(init_op) - results = sess.run(get_next) - for component, result_component in zip(components, results): - if sparse_tensor.is_sparse(component): - self.assertSparseValuesEqual(component, result_component) - else: - self.assertAllEqual(component, result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testFromTensorSlices(self): - """Test an dataset that represents the slices from a tuple of tensors.""" - components = ( - np.tile(np.array([[1], [2], [3], [4]]), 20), np.tile( - np.array([[12], [13], [14], [15]]), 22), - np.array([37.0, 38.0, 39.0, 40.0]) - ) - - iterator = (dataset_ops.Dataset.from_tensor_slices(components) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - self.assertEqual([c.shape[1:] for c in components], - [t.shape for t in get_next]) - - with self.test_session() as sess: - sess.run(init_op) - for i in range(4): - results = sess.run(get_next) - for component, result_component in zip(components, results): - self.assertAllEqual(component[i], result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testFromTensorSlicesSparse(self): - """Test an dataset that represents the slices from a tuple of tensors.""" - components = (sparse_tensor.SparseTensorValue( - indices=np.array([[0, 0], [1, 0], [2, 0]]), - values=np.array([0, 0, 0]), - dense_shape=np.array([3, 1])), - sparse_tensor.SparseTensorValue( - indices=np.array([[0, 0], [1, 1], [2, 2]]), - values=np.array([1, 2, 3]), - dense_shape=np.array([3, 3]))) - - iterator = ( - dataset_ops.Dataset.from_tensor_slices(components) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - self.assertEqual( - [tensor_shape.TensorShape(c.dense_shape[1:]) for c in components], - [shape for shape in iterator.output_shapes]) - - with self.test_session() as sess: - sess.run(init_op) - expected = [ - 
(sparse_tensor.SparseTensorValue( - indices=np.array([[0]]), - values=np.array([0]), - dense_shape=np.array([1])), - sparse_tensor.SparseTensorValue( - indices=np.array([[0]]), - values=np.array([1]), - dense_shape=np.array([3]))), - (sparse_tensor.SparseTensorValue( - indices=np.array([[0]]), - values=np.array([0]), - dense_shape=np.array([1])), - sparse_tensor.SparseTensorValue( - indices=np.array([[1]]), - values=np.array([2]), - dense_shape=np.array([3]))), - (sparse_tensor.SparseTensorValue( - indices=np.array([[0]]), - values=np.array([0]), - dense_shape=np.array([1])), - sparse_tensor.SparseTensorValue( - indices=np.array([[2]]), - values=np.array([3]), - dense_shape=np.array([3]))), - ] - for i in range(3): - results = sess.run(get_next) - for component, result_component in zip(expected[i], results): - self.assertSparseValuesEqual(component, result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testFromTensorSlicesMixed(self): - """Test an dataset that represents the slices from a tuple of tensors.""" - components = (np.tile(np.array([[1], [2], [3]]), 20), - np.tile(np.array([[12], [13], [14]]), 22), - np.array([37.0, 38.0, 39.0]), - sparse_tensor.SparseTensorValue( - indices=np.array([[0, 0], [1, 0], [2, 0]]), - values=np.array([0, 0, 0]), - dense_shape=np.array([3, 1])), - sparse_tensor.SparseTensorValue( - indices=np.array([[0, 0], [1, 1], [2, 2]]), - values=np.array([1, 2, 3]), - dense_shape=np.array([3, 3]))) - - iterator = ( - dataset_ops.Dataset.from_tensor_slices(components) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - self.assertEqual([ - tensor_shape.TensorShape(c.dense_shape[1:]) - if sparse_tensor.is_sparse(c) else c.shape[1:] for c in components - ], [shape for shape in iterator.output_shapes]) - - with self.test_session() as sess: - sess.run(init_op) - expected = [ - (sparse_tensor.SparseTensorValue( - indices=np.array([[0]]), - 
values=np.array([0]), - dense_shape=np.array([1])), - sparse_tensor.SparseTensorValue( - indices=np.array([[0]]), - values=np.array([1]), - dense_shape=np.array([3]))), - (sparse_tensor.SparseTensorValue( - indices=np.array([[0]]), - values=np.array([0]), - dense_shape=np.array([1])), - sparse_tensor.SparseTensorValue( - indices=np.array([[1]]), - values=np.array([2]), - dense_shape=np.array([3]))), - (sparse_tensor.SparseTensorValue( - indices=np.array([[0]]), - values=np.array([0]), - dense_shape=np.array([1])), - sparse_tensor.SparseTensorValue( - indices=np.array([[2]]), - values=np.array([3]), - dense_shape=np.array([3]))), - ] - for i in range(3): - results = sess.run(get_next) - for component, result_component in zip( - (zip(*components[:3])[i] + expected[i]), results): - if sparse_tensor.is_sparse(component): - self.assertSparseValuesEqual(component, result_component) - else: - self.assertAllEqual(component, result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testFromTensorSlicesWithDict(self): - components = {"foo": [1, 2, 3], "bar": [[4.0], [5.0], [6.0]]} - iterator = (dataset_ops.Dataset.from_tensor_slices(components) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - self.assertEqual(dtypes.int32, iterator.output_types["foo"]) - self.assertEqual(dtypes.float32, iterator.output_types["bar"]) - self.assertEqual((), iterator.output_shapes["foo"]) - self.assertEqual((1,), iterator.output_shapes["bar"]) - - with self.test_session() as sess: - sess.run(init_op) - for i in range(3): - results = sess.run(get_next) - self.assertEqual(components["foo"][i], results["foo"]) - self.assertEqual(components["bar"][i], results["bar"]) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testFromSparseTensorSlices(self): - """Test a dataset based on slices of a `tf.SparseTensor`.""" - st = array_ops.sparse_placeholder(dtypes.float64) - iterator = 
(dataset_ops.Dataset.from_sparse_tensor_slices(st) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = sparse_tensor.SparseTensor(*iterator.get_next()) - - with self.test_session() as sess: - slices = [[1., 2., 3.], [1.], [1.], [1., 2.], [], [1., 2.], [], [], []] - - # Test with sparse tensor in the appropriate order. - indices = np.array( - [[i, j] for i in range(len(slices)) for j in range(len(slices[i]))]) - values = np.array([val for s in slices for val in s]) - dense_shape = np.array([len(slices), max(len(s) for s in slices) + 1]) - sparse_feed = sparse_tensor.SparseTensorValue(indices, values, - dense_shape) - sess.run(init_op, feed_dict={st: sparse_feed}) - for i, s in enumerate(slices): - results = sess.run(get_next) - self.assertAllEqual(s, results.values) - expected_indices = np.array( - [[j] for j in range(len(slices[i]))]).reshape([-1, 1]) - self.assertAllEqual(expected_indices, results.indices) - self.assertAllEqual(dense_shape[1:], results.dense_shape) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Test with sparse tensor in the reverse order, which is not - # currently supported. - reverse_order_indices = indices[::-1, :] - reverse_order_values = values[::-1] - sparse_feed = sparse_tensor.SparseTensorValue( - reverse_order_indices, reverse_order_values, dense_shape) - with self.assertRaises(errors.UnimplementedError): - sess.run(init_op, feed_dict={st: sparse_feed}) - - # Test with an empty sparse tensor. 
- empty_indices = np.empty((0, 4), dtype=np.int64) - empty_values = np.empty((0,), dtype=np.float64) - empty_dense_shape = [0, 4, 37, 9] - sparse_feed = sparse_tensor.SparseTensorValue(empty_indices, empty_values, - empty_dense_shape) - sess.run(init_op, feed_dict={st: sparse_feed}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # pylint: disable=g-long-lambda,unnecessary-lambda - def testNestedStructure(self): - components = (np.array([1, 2, 3]), (np.array([4., 5.]), np.array([6., 7.])), - np.array([8, 9, 10])) - - dataset = dataset_ops.Dataset.from_tensors(components) - self.assertEquals((dtypes.int64, (dtypes.float64, dtypes.float64), - dtypes.int64), dataset.output_types) - self.assertEquals(([3], ([2], [2]), [3]), dataset.output_shapes) - - dataset = dataset.shuffle(10, 10) - self.assertEquals((dtypes.int64, (dtypes.float64, dtypes.float64), - dtypes.int64), dataset.output_types) - self.assertEquals(([3], ([2], [2]), [3]), dataset.output_shapes) - - dataset = dataset.repeat(-1) - self.assertEquals((dtypes.int64, (dtypes.float64, dtypes.float64), - dtypes.int64), dataset.output_types) - self.assertEquals(([3], ([2], [2]), [3]), dataset.output_shapes) - - dataset = dataset.filter(lambda x, y, z: True) - self.assertEquals((dtypes.int64, (dtypes.float64, dtypes.float64), - dtypes.int64), dataset.output_types) - self.assertEquals(([3], ([2], [2]), [3]), dataset.output_shapes) - - dataset = dataset.take(5) - self.assertEquals((dtypes.int64, (dtypes.float64, dtypes.float64), - dtypes.int64), dataset.output_types) - self.assertEquals(([3], ([2], [2]), [3]), dataset.output_shapes) - - dataset = dataset.map(lambda x, y, z: ((x, z), (y[0], y[1]))) - self.assertEquals(((dtypes.int64, dtypes.int64), - (dtypes.float64, dtypes.float64)), dataset.output_types) - self.assertEquals((([3], [3]), ([2], [2])), dataset.output_shapes) - - dataset = dataset.flat_map( - lambda x, y: dataset_ops.Dataset.from_tensors(((x[0], x[1]), - (y[0], y[1]))) - ) - 
self.assertEquals(((dtypes.int64, dtypes.int64), - (dtypes.float64, dtypes.float64)), dataset.output_types) - self.assertEquals((([3], [3]), ([2], [2])), dataset.output_shapes) - - dataset = dataset.batch(32) - self.assertEquals(((dtypes.int64, dtypes.int64), - (dtypes.float64, dtypes.float64)), dataset.output_types) - self.assertEquals((([None, 3], [None, 3]), ([None, 2], [None, 2])), - nest.pack_sequence_as(dataset.output_shapes, [ - s.as_list() - for s in nest.flatten(dataset.output_shapes) - ])) - - iterator = dataset.make_one_shot_iterator() - (w, x), (y, z) = iterator.get_next() - self.assertEquals(dtypes.int64, w.dtype) - self.assertEquals(dtypes.int64, x.dtype) - self.assertEquals(dtypes.float64, y.dtype) - self.assertEquals(dtypes.float64, z.dtype) - self.assertEquals([None, 3], w.shape.as_list()) - self.assertEquals([None, 3], x.shape.as_list()) - self.assertEquals([None, 2], y.shape.as_list()) - self.assertEquals([None, 2], z.shape.as_list()) - - iterator = dataset.make_initializable_iterator() - (w, x), (y, z) = iterator.get_next() - self.assertEquals(dtypes.int64, w.dtype) - self.assertEquals(dtypes.int64, x.dtype) - self.assertEquals(dtypes.float64, y.dtype) - self.assertEquals(dtypes.float64, z.dtype) - self.assertEquals([None, 3], w.shape.as_list()) - self.assertEquals([None, 3], x.shape.as_list()) - self.assertEquals([None, 2], y.shape.as_list()) - self.assertEquals([None, 2], z.shape.as_list()) - - # Define a separate set of components with matching leading - # dimension for the from-slices constructor. 
- components_for_slices = (np.array([1, 2, 3]), (np.array( - [4., 5., 6.]), np.array([7., 8., 9.])), np.array([10, 11, 12])) - - dataset = dataset_ops.Dataset.from_tensor_slices(components_for_slices) - self.assertEquals((dtypes.int64, (dtypes.float64, dtypes.float64), - dtypes.int64), dataset.output_types) - self.assertEquals(([], ([], []), []), dataset.output_shapes) - - def testNestedDict(self): - components = {"a": {"aa": 1, "ab": [2.0, 2.0]}, "b": [3, 3, 3]} - dataset = dataset_ops.Dataset.from_tensors(components) - self.assertEquals(dtypes.int32, dataset.output_types["a"]["aa"]) - self.assertEquals(dtypes.float32, dataset.output_types["a"]["ab"]) - self.assertEquals(dtypes.int32, dataset.output_types["b"]) - self.assertEquals([], dataset.output_shapes["a"]["aa"]) - self.assertEquals([2], dataset.output_shapes["a"]["ab"]) - self.assertEquals([3], dataset.output_shapes["b"]) - - def testNonSequenceNestedStructure(self): - components = np.array([1, 2, 3]) - - dataset = dataset_ops.Dataset.from_tensors(components) - self.assertEquals(dtypes.int64, dataset.output_types) - self.assertEquals([3], dataset.output_shapes) - - dataset = dataset.filter( - lambda x: math_ops.reduce_all(math_ops.equal(x, components))) - self.assertEquals(dtypes.int64, dataset.output_types) - self.assertEquals([3], dataset.output_shapes) - - dataset = dataset.map(lambda x: array_ops.stack([x, x])) - self.assertEquals(dtypes.int64, dataset.output_types) - self.assertEquals([2, 3], dataset.output_shapes) - - dataset = dataset.flat_map( - lambda x: dataset_ops.Dataset.from_tensor_slices(x)) - self.assertEquals(dtypes.int64, dataset.output_types) - self.assertEquals([3], dataset.output_shapes) - - iterator = dataset.make_one_shot_iterator() - get_next = iterator.get_next() - self.assertEquals(dtypes.int64, get_next.dtype) - self.assertEquals([3], get_next.shape) - - def _testFromGenerator(self, generator, elem_sequence, num_repeats): - iterator = ( - 
dataset_ops.Dataset.from_generator(generator, output_types=dtypes.int64) - .repeat(num_repeats) - .prefetch(5) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - for _ in range(2): # Run twice to test reinitialization. - sess.run(init_op) - for _ in range(num_repeats): - for elem in elem_sequence: - self.assertAllEqual(elem, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def _testFromGeneratorOneShot(self, generator, elem_sequence, num_repeats): - iterator = ( - dataset_ops.Dataset.from_generator(generator, output_types=dtypes.int64) - .repeat(num_repeats) - .prefetch(5) - .make_one_shot_iterator()) - get_next = iterator.get_next() - - with self.test_session() as sess: - for _ in range(num_repeats): - for elem in elem_sequence: - self.assertAllEqual(elem, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testFromGeneratorUsingFunction(self): - def generator(): - for i in range(1, 100): - yield [i] * i - elem_sequence = list(generator()) - self._testFromGenerator(generator, elem_sequence, 1) - self._testFromGenerator(generator, elem_sequence, 5) - self._testFromGeneratorOneShot(generator, elem_sequence, 1) - self._testFromGeneratorOneShot(generator, elem_sequence, 5) - - def testFromGeneratorUsingList(self): - generator = lambda: [[i] * i for i in range(1, 100)] - elem_sequence = list(generator()) - self._testFromGenerator(generator, elem_sequence, 1) - self._testFromGenerator(generator, elem_sequence, 5) - - def testFromGeneratorUsingNdarray(self): - generator = lambda: np.arange(100, dtype=np.int64) - elem_sequence = list(generator()) - self._testFromGenerator(generator, elem_sequence, 1) - self._testFromGenerator(generator, elem_sequence, 5) - - def testFromGeneratorUsingGeneratorExpression(self): - # NOTE(mrry): Generator *expressions* are not repeatable (or in - # general 
reusable), because they eagerly evaluate the `for` - # expression as `iter(range(1, 100))` and discard the means of - # reconstructing `range(1, 100)`. Wrapping the generator - # expression in a `lambda` makes it repeatable. - generator = lambda: ([i] * i for i in range(1, 100)) - elem_sequence = list(generator()) - self._testFromGenerator(generator, elem_sequence, 1) - self._testFromGenerator(generator, elem_sequence, 5) - - def testFromMultipleConcurrentGenerators(self): - num_inner_repeats = 5 - num_outer_repeats = 100 - - def generator(): - for i in range(1, 10): - yield ([i] * i, [i, i ** 2, i ** 3]) - input_list = list(generator()) - - # The interleave transformation is essentially a flat map that - # draws from multiple input datasets concurrently (in a cyclic - # fashion). By placing `Datsaet.from_generator()` inside an - # interleave, we test its behavior when multiple iterators are - # active at the same time; by additionally prefetching inside the - # interleave, we create the possibility of parallel (modulo GIL) - # invocations to several iterators created by the same dataset. 
- def interleave_fn(_): - return (dataset_ops.Dataset.from_generator( - generator, output_types=(dtypes.int64, dtypes.int64), - output_shapes=([None], [3])) - .repeat(num_inner_repeats).prefetch(5)) - - iterator = ( - dataset_ops.Dataset.range(num_outer_repeats) - .interleave(interleave_fn, cycle_length=10, - block_length=len(input_list)) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - for _ in range(num_inner_repeats * num_outer_repeats): - for elem in input_list: - val0, val1 = sess.run(get_next) - self.assertAllEqual(elem[0], val0) - self.assertAllEqual(elem[1], val1) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testFromGeneratorsRunningInParallel(self): - num_parallel_iterators = 3 - - # Define shared state that multiple iterator instances will access to - # demonstrate their concurrent activity. - lock = threading.Lock() - condition = threading.Condition(lock) - next_ticket = [0] # GUARDED_BY(lock) - - def generator(): - # NOTE(mrry): We yield one element before the barrier, because - # the current implementation of `Dataset.interleave()` must - # fetch one element from each incoming dataset to start the - # prefetching. - yield 0 - - # Define a barrier that `num_parallel_iterators` iterators must enter - # before any can proceed. Demonstrates that multiple iterators may be - # active at the same time. - condition.acquire() - ticket = next_ticket[0] - next_ticket[0] += 1 - if ticket == num_parallel_iterators - 1: - # The last iterator to join the barrier notifies the others. - condition.notify_all() - else: - # Wait until the last iterator enters the barrier. 
- while next_ticket[0] < num_parallel_iterators: - condition.wait() - condition.release() - - yield 1 - - # As in `testFromMultipleConcurrentGenerators()`, we use a combination of - # `Dataset.interleave()` and `Dataset.prefetch()` to cause multiple - # iterators to be active concurrently. - def interleave_fn(_): - return dataset_ops.Dataset.from_generator( - generator, output_types=dtypes.int64, output_shapes=[]).prefetch(2) - - iterator = ( - dataset_ops.Dataset.range(num_parallel_iterators) - .interleave( - interleave_fn, cycle_length=num_parallel_iterators, block_length=1) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - for elem in [0, 1]: - for _ in range(num_parallel_iterators): - self.assertAllEqual(elem, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testFromGeneratorImplicitConversion(self): - def generator(): - yield [1] - yield [2] - yield [3] - - for dtype in [dtypes.int8, dtypes.int32, dtypes.int64]: - iterator = (dataset_ops.Dataset.from_generator( - generator, output_types=dtype, output_shapes=[1]) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - self.assertEqual(dtype, get_next.dtype) - - with self.test_session() as sess: - sess.run(init_op) - for expected in [[1], [2], [3]]: - next_val = sess.run(get_next) - self.assertEqual(dtype.as_numpy_dtype, next_val.dtype) - self.assertAllEqual(expected, next_val) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testFromGeneratorTypeError(self): - def generator(): - yield np.array([1, 2, 3], dtype=np.int64) - yield np.array([4, 5, 6], dtype=np.int64) - yield "ERROR" - yield np.array([7, 8, 9], dtype=np.int64) - - iterator = (dataset_ops.Dataset.from_generator( - generator, output_types=dtypes.int64, output_shapes=[3]) - .make_initializable_iterator()) - init_op = 
iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - self.assertAllEqual([1, 2, 3], sess.run(get_next)) - self.assertAllEqual([4, 5, 6], sess.run(get_next)) - with self.assertRaisesOpError(r"invalid literal for long\(\)"): - sess.run(get_next) - self.assertAllEqual([7, 8, 9], sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testFromGeneratorShapeError(self): - def generator(): - yield np.array([1, 2, 3], dtype=np.int64) - yield np.array([4, 5, 6], dtype=np.int64) - yield np.array([7, 8, 9, 10], dtype=np.int64) - yield np.array([11, 12, 13], dtype=np.int64) - - iterator = (dataset_ops.Dataset.from_generator( - generator, output_types=dtypes.int64, output_shapes=[3]) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - self.assertAllEqual([1, 2, 3], sess.run(get_next)) - self.assertAllEqual([4, 5, 6], sess.run(get_next)) - with self.assertRaisesOpError(r"element of shape \(3,\) was expected"): - sess.run(get_next) - self.assertAllEqual([11, 12, 13], sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testSplitPipelineFailsWithPlacementError(self): - with session.Session( - target="", - config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: - - dataset = dataset_ops.Dataset.from_tensors(0) - - # Define a pipeline that attempts to use variables on two - # different devices. - # - # Initialize the variables before creating to iterator, to avoid the - # placement algorithm overriding the DT_RESOURCE colocation constraints. 
- with ops.device("/cpu:0"): - var_0 = resource_variable_ops.ResourceVariable(initial_value=0) - dataset = dataset.map(lambda x: x + var_0.read_value()) - sess.run(var_0.initializer) - - with ops.device("/cpu:1"): - var_1 = resource_variable_ops.ResourceVariable(initial_value=0) - dataset = dataset.map(lambda x: x + var_1.read_value()) - sess.run(var_1.initializer) - - iterator = dataset.make_initializable_iterator() - - with self.assertRaisesRegexp( - errors.InvalidArgumentError, - "Trying to access resource located in device"): - sess.run(iterator.initializer) - def testRestructureDataset(self): components = (array_ops.placeholder(dtypes.int32), (array_ops.placeholder(dtypes.int32, shape=[None]), diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py index bf25cc60a1c0efc09bed6501fd2d6f4ccb07764b..dbc35097ddda9f0375060d43aeb43efa8107f929 100644 --- a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py +++ b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py @@ -24,9 +24,11 @@ import numpy as np from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import test @@ -34,12 +36,29 @@ from tensorflow.python.training import saver as saver_lib from tensorflow.python.util import nest +def remove_variants(get_next_op): + # TODO(b/72408568): Remove this once session.run can get + # variant tensors. 
+ """Remove variants from a nest structure, so sess.run will execute.""" + + def _remove_variant(x): + if isinstance(x, ops.Tensor) and x.dtype == dtypes.variant: + return () + else: + return x + + return nest.map_structure(_remove_variant, get_next_op) + + class DatasetSerializationTestBase(test.TestCase): """Base class for testing serializable datasets.""" def tearDown(self): self._delete_ckpt() + # TODO(b/72657739): Remove sparse_tensor argument, which is to test the + # (deprecated) saveable `SparseTensorSliceDataset`, once the API + # `from_sparse_tensor_slices()`and related tests are deleted. def run_core_tests(self, ds_fn1, ds_fn2, num_outputs, sparse_tensors=False): """Runs the core tests. @@ -231,10 +250,10 @@ class DatasetSerializationTestBase(test.TestCase): saver = self._import_meta_graph() init_op, get_next_op = self._get_iterator_ops_from_collection( ds_fn, sparse_tensors=sparse_tensors) + get_next_op = remove_variants(get_next_op) with self.test_session(graph=g) as sess: self._restore(saver, sess) - sess.run(variables.global_variables_initializer()) - sess.run(init_op) + self._initialize(init_op, sess) for _ in range(num_outputs): actual.append(sess.run(get_next_op)) if verify_exhausted: @@ -294,6 +313,7 @@ class DatasetSerializationTestBase(test.TestCase): with ops.Graph().as_default() as g: _, get_next_op, saver = self._build_graph( ds_fn2, sparse_tensors=sparse_tensors) + get_next_op = remove_variants(get_next_op) with self.test_session(graph=g) as sess: self._restore(saver, sess) for _ in range(num_outputs - break_point): @@ -354,6 +374,7 @@ class DatasetSerializationTestBase(test.TestCase): with ops.Graph().as_default() as g: get_next_op, saver = self._build_empty_graph( ds_fn, sparse_tensors=sparse_tensors) + get_next_op = remove_variants(get_next_op) with self.test_session(graph=g) as sess: self._restore(saver, sess) for _ in range(num_outputs - break_point): @@ -387,9 +408,9 @@ class DatasetSerializationTestBase(test.TestCase): with 
ops.Graph().as_default() as g: init_op, get_next_op, saver = self._build_graph( ds_fn, sparse_tensors=sparse_tensors) + get_next_op = remove_variants(get_next_op) with self.test_session(graph=g) as sess: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) + self._initialize(init_op, sess) for _ in range(break_point): sess.run(get_next_op) with self.assertRaises(error): @@ -483,20 +504,20 @@ class DatasetSerializationTestBase(test.TestCase): else: init_op, get_next_op, saver = self._build_graph( ds_fn, sparse_tensors=sparse_tensors) + get_next_op = remove_variants(get_next_op) return init_op, get_next_op, saver for i in range(len(break_points) + 1): with ops.Graph().as_default() as g: init_op, get_next_op, saver = get_ops() + get_next_op = remove_variants(get_next_op) with self.test_session(graph=g) as sess: if ckpt_saved: if init_before_restore: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) + self._initialize(init_op, sess) self._restore(saver, sess) else: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) + self._initialize(init_op, sess) start = break_points[i - 1] if i > 0 else 0 end = break_points[i] if i < len(break_points) else num_outputs num_iters = end - start @@ -560,13 +581,16 @@ class DatasetSerializationTestBase(test.TestCase): get_next = sparse_tensor.SparseTensor(*iterator.get_next()) else: get_next = iterator.get_next() - self._add_iterator_ops_to_collection(init_op, get_next, sparse_tensors) + self._add_iterator_ops_to_collection(init_op, get_next, ds_fn, + sparse_tensors) saver = saver_lib.Saver(allow_empty=True) return init_op, get_next, saver def _build_empty_graph(self, ds_fn, sparse_tensors=False): iterator = iterator_ops.Iterator.from_structure( - self._get_output_types(ds_fn), self._get_output_shapes(ds_fn)) + self._get_output_types(ds_fn), + output_shapes=self._get_output_shapes(ds_fn), + output_classes=self._get_output_classes(ds_fn)) saveable = 
contrib_iterator_ops.make_saveable_from_iterator(iterator) ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) if sparse_tensors: @@ -579,12 +603,19 @@ class DatasetSerializationTestBase(test.TestCase): def _add_iterator_ops_to_collection(self, init_op, get_next, + ds_fn, sparse_tensors=False): ops.add_to_collection("iterator_ops", init_op) # `get_next` may be a tuple e.g. in TensorSliceDataset. Since Collections # do not support tuples we flatten the tensors and restore the shape in # `_get_iterator_ops_from_collection`. - if sparse_tensors: + + # TODO(shivaniagrwal): `output_classes` is a nested structure of classes, + # this base class is specific to current test cases. Update when tests are + # added with `output_classes` as a nested structure with at least one of the + # component being `tf.SparseTensor`. + if (sparse_tensors or + self._get_output_classes(ds_fn) is sparse_tensor.SparseTensor): ops.add_to_collection("iterator_ops", get_next.indices) ops.add_to_collection("iterator_ops", get_next.values) ops.add_to_collection("iterator_ops", get_next.dense_shape) @@ -594,7 +625,8 @@ class DatasetSerializationTestBase(test.TestCase): def _get_iterator_ops_from_collection(self, ds_fn, sparse_tensors=False): all_ops = ops.get_collection("iterator_ops") - if sparse_tensors: + if (sparse_tensors or + self._get_output_classes(ds_fn) is sparse_tensor.SparseTensor): init_op, indices, values, dense_shape = all_ops return init_op, sparse_tensor.SparseTensor(indices, values, dense_shape) else: @@ -609,6 +641,10 @@ class DatasetSerializationTestBase(test.TestCase): with ops.Graph().as_default(): return ds_fn().output_shapes + def _get_output_classes(self, ds_fn): + with ops.Graph().as_default(): + return ds_fn().output_classes + def _ckpt_path(self): return os.path.join(self.get_temp_dir(), "iterator") @@ -619,8 +655,14 @@ class DatasetSerializationTestBase(test.TestCase): saver.save(sess, self._ckpt_path()) def _restore(self, saver, sess): + 
sess.run(lookup_ops.tables_initializer()) saver.restore(sess, self._latest_ckpt()) + def _initialize(self, init_op, sess): + sess.run(variables.global_variables_initializer()) + sess.run(lookup_ops.tables_initializer()) + sess.run(init_op) + def _import_meta_graph(self): meta_file_path = self._ckpt_path() + ".meta" return saver_lib.import_meta_graph(meta_file_path) diff --git a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py index 5921be2ae89ba1bbbb8d6e3a509cf49c65949544..b572d6ed770fc0fe0f852359baf343c55966eddd 100644 --- a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py @@ -20,144 +20,12 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base -from tensorflow.contrib.data.python.ops import dataset_ops -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import sparse_tensor -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import functional_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test -class FilterDatasetTest(test.TestCase): - - def testFilterDataset(self): - components = ( - np.arange(7, dtype=np.int64), - np.array([[1, 2, 3]], dtype=np.int64) * np.arange( - 7, dtype=np.int64)[:, np.newaxis], - np.array(37.0, dtype=np.float64) * np.arange(7) - ) - count = array_ops.placeholder(dtypes.int64, shape=[]) - modulus = array_ops.placeholder(dtypes.int64) - - def _map_fn(x, y, z): - return math_ops.square(x), math_ops.square(y), math_ops.square(z) - - iterator = ( - dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) - .repeat(count) - .filter(lambda x, _y, _z: 
math_ops.equal(math_ops.mod(x, modulus), 0)) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - self.assertEqual([c.shape[1:] for c in components], - [t.shape for t in get_next]) - - with self.test_session() as sess: - # Test that we can dynamically feed a different modulus value for each - # iterator. - def do_test(count_val, modulus_val): - sess.run(init_op, feed_dict={count: count_val, modulus: modulus_val}) - for _ in range(count_val): - for i in [x for x in range(7) if x**2 % modulus_val == 0]: - result = sess.run(get_next) - for component, result_component in zip(components, result): - self.assertAllEqual(component[i]**2, result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - do_test(14, 2) - do_test(4, 18) - - # Test an empty dataset. - do_test(0, 1) - - def testFilterRange(self): - dataset = dataset_ops.Dataset.range(100).filter( - lambda x: math_ops.not_equal(math_ops.mod(x, 3), 2)) - iterator = dataset.make_one_shot_iterator() - get_next = iterator.get_next() - - with self.test_session() as sess: - self.assertEqual(0, sess.run(get_next)) - self.assertEqual(1, sess.run(get_next)) - self.assertEqual(3, sess.run(get_next)) - - def testFilterDict(self): - iterator = (dataset_ops.Dataset.range(10) - .map(lambda x: {"foo": x * 2, "bar": x ** 2}) - .filter(lambda d: math_ops.equal(d["bar"] % 2, 0)) - .map(lambda d: d["foo"] + d["bar"]) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - for i in range(10): - if (i ** 2) % 2 == 0: - self.assertEqual(i * 2 + i ** 2, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testUseStepContainerInFilter(self): - input_data = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int64) - - # Define a predicate that returns true for the first element of - # the sequence and not the second, and 
uses `tf.map_fn()`. - def _predicate(xs): - squared_xs = functional_ops.map_fn(lambda x: x * x, xs) - summed = math_ops.reduce_sum(squared_xs) - return math_ops.equal(summed, 1 + 4 + 9) - - iterator = ( - dataset_ops.Dataset.from_tensor_slices([[1, 2, 3], [4, 5, 6]]) - .filter(_predicate) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - self.assertAllEqual(input_data[0], sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def assertSparseValuesEqual(self, a, b): - self.assertAllEqual(a.indices, b.indices) - self.assertAllEqual(a.values, b.values) - self.assertAllEqual(a.dense_shape, b.dense_shape) - - def testSparse(self): - - def _map_fn(i): - return sparse_tensor.SparseTensorValue( - indices=np.array([[0, 0]]), - values=(i * np.array([1])), - dense_shape=np.array([1, 1])), i - - def _filter_fn(_, i): - return math_ops.equal(i % 2, 0) - - iterator = ( - dataset_ops.Dataset.range(10).map(_map_fn).filter(_filter_fn).map( - lambda x, i: x).make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - for i in range(5): - actual = sess.run(get_next) - self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue)) - self.assertSparseValuesEqual(actual, _map_fn(i * 2)[0]) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - class FilterDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): @@ -194,6 +62,10 @@ class FilterDatasetSerializationTest( return dataset_ops.Dataset.range(10).map(_map_fn).filter(_filter_fn).map( lambda x, i: x) + def testSparseCore(self): + num_outputs = 5 + self.run_core_tests(self._build_sparse_filter, None, num_outputs) + if __name__ == "__main__": test.main() diff --git 
a/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py index d4fbaa5cdcdd315aa0524134b48eb0515169722c..f3feecef32e587045be25056815315136a883ca7 100644 --- a/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/flat_map_dataset_op_test.py @@ -17,13 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import random - -import numpy as np - from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base -from tensorflow.contrib.data.python.ops import dataset_ops -from tensorflow.python.client import session +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -34,124 +29,6 @@ from tensorflow.python.ops import random_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test -from tensorflow.python.training import server_lib - - -class FlatMapDatasetTest(test.TestCase): - - # pylint: disable=g-long-lambda - def testFlatMapDataset(self): - repeats = [1, 2, 3, 4, 5, 0, 1] - components = np.array(repeats, dtype=np.int64) - iterator = ( - dataset_ops.Dataset.from_tensor_slices(components) - .flat_map(lambda x: dataset_ops.Dataset.from_tensors([x]).repeat(x)) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - for i in repeats: - for _ in range(i): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testNestedFlatMapDataset(self): - repeats = [[1, 2], [3, 4], [5, 0], [1, 7]] - components = np.array(repeats, dtype=np.int64) - 
iterator = ( - dataset_ops.Dataset.from_tensor_slices(components) - .flat_map(lambda x: dataset_ops.Dataset.from_tensor_slices(x) - .flat_map(lambda y: dataset_ops.Dataset.from_tensors(y) - .repeat(y))).make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - for row in repeats: - for i in row: - for _ in range(i): - self.assertEqual(i, sess.run(get_next)) - - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testSharedResourceNestedFlatMapDataset(self): - repeats = [[1, 2], [3, 4], [5, 0], [1, 7]] - components = np.array(repeats, dtype=np.int64) - iterator = ( - dataset_ops.Dataset.from_tensor_slices(components) - .flat_map(lambda x: dataset_ops.Dataset.from_tensor_slices(x) - .flat_map(lambda y: dataset_ops.Dataset.from_tensors(y) - .repeat(y))).make_initializable_iterator( - shared_name="shared_flat_map_iterator")) - init_op = iterator.initializer - get_next = iterator.get_next() - - # Create two concurrent sessions that share the same iterator - # resource on the same server, and verify that a random - # interleaving of `Session.run(get_next)` calls on the two - # sessions yields the expected result. 
- server = server_lib.Server.create_local_server() - with session.Session(server.target) as sess1: - with session.Session(server.target) as sess2: - for _ in range(3): - sess = random.choice([sess1, sess2]) - sess.run(init_op) - for row in repeats: - for i in row: - for _ in range(i): - sess = random.choice([sess1, sess2]) - self.assertEqual(i, sess.run(get_next)) - - with self.assertRaises(errors.OutOfRangeError): - sess = random.choice([sess1, sess2]) - sess.run(get_next) - - def testMapDict(self): - iterator = (dataset_ops.Dataset.range(10) - .map(lambda x: {"foo": x * 2, "bar": x ** 2}) - .flat_map(lambda d: dataset_ops.Dataset.from_tensors(d["foo"]) - .repeat(d["bar"])) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - for i in range(10): - for _ in range(i ** 2): - self.assertEqual(i * 2, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - # pylint: enable=g-long-lambda - - def testSparse(self): - def _map_fn(i): - return sparse_tensor.SparseTensorValue( - indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) - - def _flat_map_fn(x): - return dataset_ops.Dataset.from_tensor_slices( - sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) - - iterator = ( - dataset_ops.Dataset.range(10).map(_map_fn).flat_map(_flat_map_fn) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - for i in range(10): - for j in range(2): - expected = [i, 0] if j % 2 == 0 else [0, -i] - self.assertAllEqual(expected, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) class FlatMapDatasetSerializationTest( @@ -225,6 +102,21 @@ class FlatMapDatasetSerializationTest( self.verify_error_on_save(build_ds, 500, errors.InvalidArgumentError) + def testSparseCore(self): + + def 
_map_fn(i): + return sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) + + def _flat_map_fn(x): + return dataset_ops.Dataset.from_tensor_slices( + sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) + + def _build_ds(): + return dataset_ops.Dataset.range(10).map(_map_fn).flat_map(_flat_map_fn) + + self.run_core_tests(_build_ds, None, 20) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py new file mode 100644 index 0000000000000000000000000000000000000000..32ea44f7c7ba329dc253bb9fbbcac0a1ed16aec7 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py @@ -0,0 +1,58 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.ops import get_single_element +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class GetSingleElementTest(test.TestCase): + + def testGetSingleElement(self): + skip_value = array_ops.placeholder(dtypes.int64, shape=[]) + take_value = array_ops.placeholder_with_default( + constant_op.constant(1, dtype=dtypes.int64), shape=[]) + + dataset = (dataset_ops.Dataset.range(100) + .skip(skip_value) + .map(lambda x: x * x) + .take(take_value)) + + element = get_single_element.get_single_element(dataset) + + with self.test_session() as sess: + self.assertEqual(0, sess.run(element, feed_dict={skip_value: 0})) + self.assertEqual(25, sess.run(element, feed_dict={skip_value: 5})) + self.assertEqual(100, sess.run(element, feed_dict={skip_value: 10})) + + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "Dataset was empty."): + sess.run(element, feed_dict={skip_value: 100}) + + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "Dataset had more than one element."): + sess.run(element, feed_dict={skip_value: 0, take_value: 2}) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py index e66ed3f7aa2a512813ef353d2d0744ae67005884..256ad8d94dc1a7c2b26df3f1ebf8e8e321882c15 100644 --- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py +++ 
b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py @@ -26,8 +26,8 @@ import numpy as np from six.moves import zip_longest from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base -from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.contrib.data.python.ops import interleave_ops +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import sparse_tensor @@ -38,181 +38,7 @@ from tensorflow.python.ops import sparse_ops from tensorflow.python.platform import test -class InterleaveDatasetTest(test.TestCase): - - def _interleave(self, lists, cycle_length, block_length): - num_open = 0 - - # `all_iterators` acts as a queue of iterators over each element of `lists`. - all_iterators = [iter(l) for l in lists] - - # `open_iterators` are the iterators whose elements are currently being - # interleaved. - open_iterators = [] - for i in range(cycle_length): - if all_iterators: - open_iterators.append(all_iterators.pop(0)) - num_open += 1 - else: - open_iterators.append(None) - - while num_open or all_iterators: - for i in range(cycle_length): - if open_iterators[i] is None: - if all_iterators: - open_iterators[i] = all_iterators.pop(0) - num_open += 1 - else: - continue - for _ in range(block_length): - try: - yield next(open_iterators[i]) - except StopIteration: - open_iterators[i] = None - num_open -= 1 - break - - def testPythonImplementation(self): - input_lists = [[4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6], - [4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]] - - # Cycle length 1 acts like `Dataset.flat_map()`. - expected_elements = itertools.chain(*input_lists) - for expected, produced in zip( - expected_elements, self._interleave(input_lists, 1, 1)): - self.assertEqual(expected, produced) - - # Cycle length > 1. 
- expected_elements = [4, 5, 4, 5, 4, 5, 4, - 5, 5, 6, 6, # NOTE(mrry): When we cycle back - # to a list and are already at - # the end of that list, we move - # on to the next element. - 4, 6, 4, 6, 4, 6, 4, 6, 5, 6, 5, 6, 5, 6, 5, 6, 5] - for expected, produced in zip( - expected_elements, self._interleave(input_lists, 2, 1)): - self.assertEqual(expected, produced) - - # Cycle length > 1 and block length > 1. - expected_elements = [4, 4, 4, 5, 5, 5, 4, 5, 5, 6, 6, 6, 4, 4, 4, 6, 6, 6, - 4, 5, 5, 5, 6, 6, 6, 5, 5, 6, 6, 6] - for expected, produced in zip( - expected_elements, self._interleave(input_lists, 2, 3)): - self.assertEqual(expected, produced) - - # Cycle length > len(input_values). - expected_elements = [4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, - 4, 4, 5, 5, 6, 6, 5, 6, 6, 5, 6, 6] - for expected, produced in zip( - expected_elements, self._interleave(input_lists, 7, 2)): - self.assertEqual(expected, produced) - - def testInterleaveDataset(self): - input_values = array_ops.placeholder(dtypes.int64, shape=[None]) - cycle_length = array_ops.placeholder(dtypes.int64, shape=[]) - block_length = array_ops.placeholder(dtypes.int64, shape=[]) - - repeat_count = 2 - - dataset = ( - dataset_ops.Dataset.from_tensor_slices(input_values) - .repeat(repeat_count) - .interleave(lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x), - cycle_length, block_length)) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - next_element = iterator.get_next() - - with self.test_session() as sess: - # Cycle length 1 acts like `Dataset.flat_map()`. - sess.run(init_op, feed_dict={input_values: [4, 5, 6], - cycle_length: 1, block_length: 3}) - - for expected_element in self._interleave( - [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 1, 3): - self.assertEqual(expected_element, sess.run(next_element)) - - # Cycle length > 1. 
- # expected: [4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 4, 6, 4, 6, 4, 6, 4, 6, 5, - # 6, 5, 6, 5, 6, 5, 6, 5] - sess.run(init_op, feed_dict={input_values: [4, 5, 6], - cycle_length: 2, block_length: 1}) - for expected_element in self._interleave( - [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 2, 1): - self.assertEqual(expected_element, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - # Cycle length > 1 and block length > 1. - # expected: [4, 4, 4, 5, 5, 5, 4, 5, 5, 6, 6, 6, 4, 4, 4, 6, 6, 6, 4, 5, - # 5, 5, 6, 6, 6, 5, 5, 6, 6, 6] - sess.run(init_op, feed_dict={input_values: [4, 5, 6], - cycle_length: 2, block_length: 3}) - for expected_element in self._interleave( - [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 2, 3): - self.assertEqual(expected_element, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - # Cycle length > len(input_values) * repeat_count. - # expected: [4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, - # 5, 5, 6, 6, 5, 6, 6, 5, 6, 6] - sess.run(init_op, feed_dict={input_values: [4, 5, 6], - cycle_length: 7, block_length: 2}) - for expected_element in self._interleave( - [[4] * 4, [5] * 5, [6] * 6] * repeat_count, 7, 2): - self.assertEqual(expected_element, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - # Empty input. - sess.run(init_op, feed_dict={input_values: [], - cycle_length: 2, block_length: 3}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - # Non-empty input leading to empty output. - sess.run(init_op, feed_dict={input_values: [0, 0, 0], - cycle_length: 2, block_length: 3}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - # Mixture of non-empty and empty interleaved datasets. 
- sess.run(init_op, feed_dict={input_values: [4, 0, 6], - cycle_length: 2, block_length: 3}) - for expected_element in self._interleave( - [[4] * 4, [], [6] * 6] * repeat_count, 2, 3): - self.assertEqual(expected_element, sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - def testSparse(self): - - def _map_fn(i): - return sparse_tensor.SparseTensorValue( - indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) - - def _interleave_fn(x): - return dataset_ops.Dataset.from_tensor_slices( - sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) - - iterator = ( - dataset_ops.Dataset.range(10).map(_map_fn).interleave( - _interleave_fn, cycle_length=1).make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - for i in range(10): - for j in range(2): - expected = [i, 0] if j % 2 == 0 else [0, -i] - self.assertAllEqual(expected, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - -class InterleaveDatasetSeriazationTest( +class InterleaveDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): def _build_iterator_graph(self, input_values, cycle_length, block_length): @@ -251,15 +77,35 @@ class InterleaveDatasetSeriazationTest( None, num_outputs) # pylint: enable=g-long-lambda + def testSparseCore(self): + + def _map_fn(i): + return sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2]) + + def _interleave_fn(x): + return dataset_ops.Dataset.from_tensor_slices( + sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values)) + + def _build_dataset(): + return dataset_ops.Dataset.range(10).map(_map_fn).interleave( + _interleave_fn, cycle_length=1) + + self.run_core_tests(_build_dataset, None, 20) + class ParallelInterleaveDatasetTest(test.TestCase): def setUp(self): + 
self.input_values = array_ops.placeholder(dtypes.int64, shape=[None]) self.cycle_length = array_ops.placeholder(dtypes.int64, shape=[]) self.block_length = array_ops.placeholder(dtypes.int64, shape=[]) self.sloppy = array_ops.placeholder(dtypes.bool, shape=[]) + self.buffer_output_elements = array_ops.placeholder(dtypes.int64, shape=[]) + self.prefetch_input_elements = array_ops.placeholder(dtypes.int64, shape=[]) + self.error = None self.repeat_count = 2 # Set up threading events used to sequence when items are produced that @@ -276,6 +122,10 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.write_coordination_events[x].wait() self.write_coordination_events[x].clear() self.read_coordination_events[x].release() + if self.error: + err = self.error + self.error = None + raise err # pylint: disable=raising-bad-type return x * x def map_fn(x): @@ -286,11 +136,13 @@ class ParallelInterleaveDatasetTest(test.TestCase): dataset = dataset.repeat(x) return dataset.map(map_fn) - self.dataset = (dataset_ops.Dataset.from_tensor_slices(self.input_values) - .repeat(self.repeat_count).apply( - interleave_ops.parallel_interleave( - interleave_fn, self.cycle_length, - self.block_length, self.sloppy))) + self.dataset = ( + dataset_ops.Dataset.from_tensor_slices(self.input_values) + .repeat(self.repeat_count).apply( + interleave_ops.parallel_interleave(interleave_fn, self.cycle_length, + self.block_length, self.sloppy, + self.buffer_output_elements, + self.prefetch_input_elements))) self.iterator = self.dataset.make_initializable_iterator() self.init_op = self.iterator.initializer self.next_element = self.iterator.get_next() @@ -380,7 +232,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): for i in range(4, 7): self.write_coordination_events[i].set() - def _testSingleThreaded(self, sloppy=False): + def _testSingleThreaded(self, sloppy=False, prefetch_input_elements=0): # cycle_length=1,block_length=1 acts like `Dataset.interleave()` and # `Dataset.flat_map()` and is 
single-threaded. No synchronization required. with self.test_session() as sess: @@ -391,7 +243,9 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.input_values: [4, 5, 6], self.cycle_length: 1, self.block_length: 1, - self.sloppy: sloppy + self.sloppy: sloppy, + self.buffer_output_elements: 1, + self.prefetch_input_elements: prefetch_input_elements, }) for expected_element in self._interleave( @@ -408,6 +262,41 @@ class ParallelInterleaveDatasetTest(test.TestCase): def testSingleThreadedSloppy(self): self._testSingleThreaded(sloppy=True) + def testSingleThreadedPrefetch1Itr(self): + self._testSingleThreaded(prefetch_input_elements=1) + + def testSingleThreadedPrefetch1ItrSloppy(self): + self._testSingleThreaded(prefetch_input_elements=1, sloppy=True) + + def testSingleThreadedRagged(self): + # Tests a sequence with wildly different elements per iterator. + with self.test_session() as sess: + self._clear_coordination_events() + sess.run( + self.init_op, + feed_dict={ + self.input_values: [3, 7, 4], + self.cycle_length: 2, + self.block_length: 1, + self.sloppy: False, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 1, + }) + + # Add coordination values for 3 and 7 + self.read_coordination_events[3] = threading.Semaphore(0) + self.write_coordination_events[3] = threading.Event() + self.read_coordination_events[7] = threading.Semaphore(0) + self.write_coordination_events[7] = threading.Event() + + for expected_element in self._interleave( + [[3] * 3, [7] * 7, [4] * 4] * self.repeat_count, 2, 1): + self.write_coordination_events[expected_element].set() + output = sess.run(self.next_element) + self.assertEqual(expected_element * expected_element, output) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + def _testTwoThreadsNoContention(self, sloppy=False): # num_threads > 1. 
# Explicit coordination should result in `Dataset.interleave()` behavior @@ -420,7 +309,9 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.input_values: [4, 5, 6], self.cycle_length: 2, self.block_length: 1, - self.sloppy: sloppy + self.sloppy: sloppy, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 1, }) for i, expected_element in enumerate( self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, @@ -463,6 +354,8 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.cycle_length: 2, self.block_length: 1, self.sloppy: sloppy, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 1, }) for i, expected_element in enumerate( self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, @@ -472,7 +365,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.read_coordination_events[expected_element].acquire() else: self.write_coordination_events[expected_element].set() - time.sleep(0.1) # Sleep to consistently "avoid" the race condition. + time.sleep(0.5) # Sleep to consistently "avoid" the race condition. 
actual_element = sess.run(self.next_element) if not done_first_event: done_first_event = True @@ -502,7 +395,9 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.input_values: [4, 5, 6], self.cycle_length: 2, self.block_length: 2, - self.sloppy: sloppy + self.sloppy: sloppy, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 1, }) for i, expected_element in enumerate( self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, @@ -545,7 +440,9 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.input_values: [4, 5, 6], self.cycle_length: 2, self.block_length: 2, - self.sloppy: sloppy + self.sloppy: sloppy, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 1, }) for i, expected_element in enumerate( self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, @@ -555,7 +452,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.read_coordination_events[expected_element].acquire() else: self.write_coordination_events[expected_element].set() - time.sleep(0.1) # Sleep to consistently "avoid" the race condition. + time.sleep(0.5) # Sleep to consistently "avoid" the race condition. 
actual_element = sess.run(self.next_element) if not done_first_event: done_first_event = True @@ -583,7 +480,9 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.input_values: [], self.cycle_length: 2, self.block_length: 3, - self.sloppy: sloppy + self.sloppy: sloppy, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 0, }) with self.assertRaises(errors.OutOfRangeError): sess.run(self.next_element) @@ -604,7 +503,9 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.input_values: [0, 0, 0], self.cycle_length: 2, self.block_length: 3, - self.sloppy: sloppy + self.sloppy: sloppy, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 0, }) with self.assertRaises(errors.OutOfRangeError): sess.run(self.next_element) @@ -615,7 +516,8 @@ class ParallelInterleaveDatasetTest(test.TestCase): def testNonEmptyInputIntoEmptyOutputsSloppy(self): self._testNonEmptyInputIntoEmptyOutputs(sloppy=True) - def _testPartiallyEmptyOutputs(self, sloppy=False): + def _testPartiallyEmptyOutputs(self, sloppy=False, prefetch_input_elements=1): + race_indices = {2, 8, 14} # Sequence points when sloppy mode has race conds # Mixture of non-empty and empty interleaved datasets. with self.test_session() as sess: self._clear_coordination_events() @@ -627,27 +529,31 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.cycle_length: 2, self.block_length: 1, self.sloppy: sloppy, + self.buffer_output_elements: 1, + self.prefetch_input_elements: prefetch_input_elements, }) for i, expected_element in enumerate( self._interleave([[4] * 4, [], [6] * 6] * self.repeat_count, 2, 1)): self.write_coordination_events[expected_element].set() - if done_first_event: # First event starts the worker threads + # First event starts the worker threads. Additionally, when running the + # sloppy case with prefetch_input_elements=0, we get stuck if we wait + # for the read coordination event for certain event orderings in the + # presence of finishing iterators. 
+ if done_first_event and not (sloppy and (i in race_indices)): self.read_coordination_events[expected_element].acquire() actual_element = sess.run(self.next_element) - if not done_first_event: + if not done_first_event or (sloppy and (i in race_indices)): done_first_event = True self.read_coordination_events[expected_element].acquire() self.assertEqual(expected_element * expected_element, actual_element, "At index %s: %s expected, got: %s" % (i, expected_element, actual_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(self.next_element) def testPartiallyEmptyOutputs(self): self._testPartiallyEmptyOutputs() def testPartiallyEmptyOutputsSloppy(self): - self._testPartiallyEmptyOutputs(sloppy=True) + self._testPartiallyEmptyOutputs(sloppy=True, prefetch_input_elements=0) def testDelayedOutputSloppy(self): # Explicitly control the sequence of events to ensure we correctly avoid @@ -661,6 +567,8 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.cycle_length: 2, self.block_length: 1, self.sloppy: True, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 0, }) mis_ordering = [ @@ -683,8 +591,10 @@ class ParallelInterleaveDatasetTest(test.TestCase): feed_dict={ self.input_values: [4, 5, 6], self.cycle_length: 2, - self.block_length: 3, - self.sloppy: True + self.block_length: 1, + self.sloppy: True, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 1, }) # Test against a generating sequence that differs from the uncontended # case, in order to prove sloppy correctness. @@ -692,7 +602,7 @@ class ParallelInterleaveDatasetTest(test.TestCase): self._interleave( [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, cycle_length=2, - block_length=2)): + block_length=3)): self.write_coordination_events[expected_element].set() if done_first_event: # First event starts the worker threads. 
self.read_coordination_events[expected_element].acquire() @@ -716,7 +626,9 @@ class ParallelInterleaveDatasetTest(test.TestCase): self.input_values: [4, 5, 6], self.cycle_length: 3, self.block_length: 2, - self.sloppy: sloppy + self.sloppy: sloppy, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 0, }) for i in range(4, 7): self.write_coordination_events[i].set() @@ -790,6 +702,139 @@ class ParallelInterleaveDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def testErrorsInOutputFn(self): + with self.test_session() as sess: + self._clear_coordination_events() + sess.run( + self.init_op, + feed_dict={ + self.input_values: [4, 5, 6], + self.cycle_length: 2, + self.block_length: 1, + self.sloppy: False, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 0, + }) + + except_on_element_indices = set([3]) + + for i, expected_element in enumerate( + self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, + 1)): + if i in except_on_element_indices: + self.error = ValueError() + self.write_coordination_events[expected_element].set() + with self.assertRaises(errors.InvalidArgumentError): + sess.run(self.next_element) + else: + self.write_coordination_events[expected_element].set() + actual_element = sess.run(self.next_element) + self.assertEqual(expected_element * expected_element, actual_element, + "At index %s: %s expected, got: %s" % + (i, expected_element, actual_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + + def testErrorsInInputFn(self): + + def map_py_fn(x): + if x == 5: + raise ValueError() + return x + + def map_fn(x): + return script_ops.py_func(map_py_fn, [x], x.dtype) + + def interleave_fn(x): + dataset = dataset_ops.Dataset.from_tensors(x) + dataset = dataset.repeat(x) + return dataset + + self.dataset = ( + dataset_ops.Dataset.from_tensor_slices(self.input_values).map(map_fn) + .repeat(self.repeat_count).apply( + 
interleave_ops.parallel_interleave(interleave_fn, self.cycle_length, + self.block_length, self.sloppy, + self.buffer_output_elements, + self.prefetch_input_elements))) + + self.iterator = self.dataset.make_initializable_iterator() + self.init_op = self.iterator.initializer + self.next_element = self.iterator.get_next() + + with self.test_session() as sess: + sess.run( + self.init_op, + feed_dict={ + self.input_values: [4, 5, 6], + self.cycle_length: 2, + self.block_length: 1, + self.sloppy: False, + self.buffer_output_elements: 1, + self.prefetch_input_elements: 0, + }) + for i, expected_element in enumerate( + self._interleave([[4] * 4, [5], [6] * 6] * self.repeat_count, 2, 1)): + if expected_element == 5: + with self.assertRaises(errors.InvalidArgumentError): + sess.run(self.next_element) + else: + actual_element = sess.run(self.next_element) + self.assertEqual(expected_element, actual_element, + "At index %s: %s expected, got: %s" % + (i, expected_element, actual_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + + def testErrorsInInterleaveFn(self): + + def map_py_fn(x): + if x == 5: + raise ValueError() + return x + + def interleave_fn(x): + dataset = dataset_ops.Dataset.from_tensors(x) + y = script_ops.py_func(map_py_fn, [x], x.dtype) + dataset = dataset.repeat(y) + return dataset + + self.dataset = ( + dataset_ops.Dataset.from_tensor_slices(self.input_values) + .repeat(self.repeat_count).apply( + interleave_ops.parallel_interleave(interleave_fn, self.cycle_length, + self.block_length, self.sloppy, + self.buffer_output_elements, + self.prefetch_input_elements))) + + self.iterator = self.dataset.make_initializable_iterator() + self.init_op = self.iterator.initializer + self.next_element = self.iterator.get_next() + + with self.test_session() as sess: + sess.run( + self.init_op, + feed_dict={ + self.input_values: [4, 5, 6], + self.cycle_length: 2, + self.block_length: 1, + self.sloppy: False, + 
self.buffer_output_elements: 1, + self.prefetch_input_elements: 0, + }) + for i, expected_element in enumerate( + self._interleave([[4] * 4, [5], [6] * 6] * self.repeat_count, 2, 1)): + if expected_element == 5: + with self.assertRaises(errors.InvalidArgumentError): + sess.run(self.next_element) + else: + actual_element = sess.run(self.next_element) + self.assertEqual(expected_element, actual_element, + "At index %s: %s expected, got: %s" % + (i, expected_element, actual_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(self.next_element) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py deleted file mode 100644 index 02379d064d4ab857ce9c7d13881a3ae37eea0980..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py +++ /dev/null @@ -1,108 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests for the experimental input pipeline ops that need test_util.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.data.python.ops import dataset_ops -from tensorflow.core.protobuf import config_pb2 -from tensorflow.python.client import session -from tensorflow.python.data.ops import iterator_ops -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.framework import function -from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import functional_ops -from tensorflow.python.platform import test - - -class IteratorClusterTest(test.TestCase): - - def testRemoteIteratorWithoutRemoteCallFail(self): - worker_config = config_pb2.ConfigProto() - worker_config.device_count["CPU"] = 2 - worker, _ = test_util.create_local_cluster( - 1, 1, worker_config=worker_config) - - with ops.device("/job:worker/replica:0/task:0/cpu:1"): - dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) - iterator_3 = dataset_3.make_one_shot_iterator() - iterator_3_handle = iterator_3.string_handle() - - with ops.device("/job:worker/replica:0/task:0/cpu:0"): - remote_it = iterator_ops.Iterator.from_string_handle( - iterator_3_handle, dataset_3.output_types, dataset_3.output_shapes) - get_next_op = remote_it.get_next() - - with session.Session(worker[0].target) as sess: - with self.assertRaises(errors.InvalidArgumentError): - sess.run(get_next_op) - - def _testRemoteIteratorHelper(self, device0, device1, target): - with ops.device(device1): - dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) - iterator_3 = dataset_3.make_one_shot_iterator() - iterator_3_handle = iterator_3.string_handle() - - @function.Defun(dtypes.string) - 
def _remote_fn(h): - remote_iterator = iterator_ops.Iterator.from_string_handle( - h, dataset_3.output_types, dataset_3.output_shapes) - return remote_iterator.get_next() - - with ops.device(device0): - target_placeholder = array_ops.placeholder(dtypes.string, shape=[]) - remote_op = functional_ops.remote_call( - args=[iterator_3_handle], - Tout=[dtypes.int32], - f=_remote_fn, - target=target_placeholder) - - with session.Session(target) as sess: - elem = sess.run(remote_op, feed_dict={target_placeholder: device1}) - self.assertEqual(elem, [1]) - # Fails when target is cpu:0 where the resource is not located. - with self.assertRaises(errors.InvalidArgumentError): - sess.run(remote_op, feed_dict={target_placeholder: device0}) - elem = sess.run(iterator_3.get_next()) - self.assertEqual(elem, [2]) - elem = sess.run(remote_op, feed_dict={target_placeholder: device1}) - self.assertEqual(elem, [3]) - with self.assertRaises(errors.OutOfRangeError): - sess.run(remote_op, feed_dict={target_placeholder: device1}) - - def testRemoteIteratorUsingRemoteCallOp(self): - worker_config = config_pb2.ConfigProto() - worker_config.device_count["CPU"] = 2 - worker, _ = test_util.create_local_cluster( - 1, 1, worker_config=worker_config) - - self._testRemoteIteratorHelper("/job:worker/replica:0/task:0/cpu:0", - "/job:worker/replica:0/task:0/cpu:1", - worker[0].target) - - def testRemoteIteratorUsingRemoteCallOpCrossProcess(self): - workers, _ = test_util.create_local_cluster(2, 1) - - self._testRemoteIteratorHelper("/job:worker/replica:0/task:0/cpu:0", - "/job:worker/replica:0/task:1/cpu:0", - workers[0].target) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py deleted file mode 100644 index bda9a2a4a37e9c3d35ff99041d1150ffc43f4c43..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py +++ 
/dev/null @@ -1,625 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for the experimental input pipeline ops.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import numpy as np - -from tensorflow.contrib.data.python.ops import dataset_ops -from tensorflow.contrib.data.python.ops import readers -from tensorflow.core.protobuf import config_pb2 -from tensorflow.python.client import session -from tensorflow.python.data.ops import iterator_ops -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.framework import function -from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import functional_ops -from tensorflow.python.ops import gen_dataset_ops -from tensorflow.python.ops import gradients_impl -from tensorflow.python.ops import io_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import parsing_ops -from tensorflow.python.ops import script_ops -from tensorflow.python.platform import test -from tensorflow.python.training import server_lib - - -class IteratorTest(test.TestCase): - 
- def testAttemptingGradientsRaiseExceptions(self): - component = constant_op.constant([1]) - side = constant_op.constant(0) - add = lambda x: x + side - dataset = dataset_ops.Dataset.from_tensor_slices(component).map(add) - value = dataset.make_one_shot_iterator().get_next() - with self.assertRaisesRegexp(LookupError, "No gradient defined"): - gradients_impl.gradients(value, component) - with self.assertRaisesRegexp(LookupError, "No gradient defined"): - gradients_impl.gradients(value, side) - with self.assertRaisesRegexp(LookupError, "No gradient defined"): - gradients_impl.gradients(value, [component, side]) - - def testOneShotIterator(self): - components = (np.arange(7), - np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], - np.array(37.0) * np.arange(7)) - - def _map_fn(x, y, z): - return math_ops.square(x), math_ops.square(y), math_ops.square(z) - - iterator = (dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) - .repeat(14).make_one_shot_iterator()) - get_next = iterator.get_next() - - self.assertEqual([c.shape[1:] for c in components], - [t.shape for t in get_next]) - - with self.test_session() as sess: - for _ in range(14): - for i in range(7): - result = sess.run(get_next) - for component, result_component in zip(components, result): - self.assertAllEqual(component[i]**2, result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testOneShotIteratorCaptureByValue(self): - components = (np.arange(7), - np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], - np.array(37.0) * np.arange(7)) - tensor_components = tuple([ops.convert_to_tensor(c) for c in components]) - - def _map_fn(x, y, z): - return math_ops.square(x), math_ops.square(y), math_ops.square(z) - - iterator = (dataset_ops.Dataset.from_tensor_slices(tensor_components) - .map(_map_fn).repeat(14).make_one_shot_iterator()) - get_next = iterator.get_next() - - self.assertEqual([c.shape[1:] for c in components], - [t.shape for t in get_next]) - - with 
self.test_session() as sess: - for _ in range(14): - for i in range(7): - result = sess.run(get_next) - for component, result_component in zip(components, result): - self.assertAllEqual(component[i]**2, result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testOneShotIteratorInsideContainer(self): - components = (np.arange(7), - np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], - np.array(37.0) * np.arange(7)) - - def within_container(): - def _map_fn(x, y, z): - return math_ops.square(x), math_ops.square(y), math_ops.square(z) - iterator = (dataset_ops.Dataset.from_tensor_slices(components) - .map(_map_fn).repeat(14).make_one_shot_iterator()) - return iterator.get_next() - - server = server_lib.Server.create_local_server() - - # Create two iterators within unique containers, and run them to - # make sure that the resources aren't shared. - # - # The test below would fail if cname were the same across both - # sessions. - for i in range(2): - with session.Session(server.target) as sess: - cname = "iteration%d" % i - with ops.container(cname): - get_next = within_container() - - for _ in range(14): - for i in range(7): - result = sess.run(get_next) - for component, result_component in zip(components, result): - self.assertAllEqual(component[i]**2, result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testOneShotIteratorNonBlocking(self): - dataset = dataset_ops.Dataset.from_tensors([1, 2, 3]).map(lambda x: x * x) - iterator = dataset.make_one_shot_iterator() - next_element = iterator.get_next() - - # Create a session with a single thread to ensure that the - # one-shot iterator initializer does not deadlock. 
- config = config_pb2.ConfigProto(inter_op_parallelism_threads=1, - use_per_session_threads=True) - with session.Session(config=config) as sess: - self.assertAllEqual([1, 4, 9], sess.run(next_element)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element) - - # Test with multiple threads invoking the one-shot iterator concurrently. - with session.Session(config=config) as sess: - results = [] - def consumer_thread(): - try: - results.append(sess.run(next_element)) - except errors.OutOfRangeError: - results.append(None) - - num_threads = 8 - threads = [ - self.checkedThread(consumer_thread) for _ in range(num_threads)] - for t in threads: - t.start() - for t in threads: - t.join() - - self.assertEqual(num_threads, len(results)) - self.assertEqual(num_threads - 1, - len([None for r in results if r is None])) - self.assertAllEqual([[1, 4, 9]], [r for r in results if r is not None]) - - def testOneShotIteratorInitializerFails(self): - # Define a dataset whose initialization will always fail. - dataset = dataset_ops.Dataset.from_tensors( - array_ops.check_numerics( - constant_op.constant(1.0) / constant_op.constant(0.0), "oops")) - iterator = dataset.make_one_shot_iterator() - next_element = iterator.get_next() - - with self.test_session() as sess: - with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): - sess.run(next_element) - - # Test that subsequent attempts to use the iterator also fail. 
- with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): - sess.run(next_element) - - with self.test_session() as sess: - def consumer_thread(): - with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): - sess.run(next_element) - - num_threads = 8 - threads = [ - self.checkedThread(consumer_thread) for _ in range(num_threads)] - for t in threads: - t.start() - for t in threads: - t.join() - - def testSimpleSharedResource(self): - components = ( - np.array(1, dtype=np.int64), - np.array([1, 2, 3], dtype=np.int64), - np.array(37.0, dtype=np.float64) - ) - - server = server_lib.Server.create_local_server() - - # Create two non-overlapping sessions that share the same iterator - # resource on the same server, and verify that an action of the - # first session (initializing the iterator) is visible in the - # second session. - with ops.Graph().as_default(): - iterator = (dataset_ops.Dataset.from_tensors(components) - .map(lambda x, y, z: (x, y, z)).make_initializable_iterator( - shared_name="shared_iterator")) - init_op = iterator.initializer - get_next = iterator.get_next() - - with session.Session(server.target) as sess: - sess.run(init_op) - results = sess.run(get_next) - for component, result_component in zip(components, results): - self.assertAllEqual(component, result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Re-initialize the iterator in the first session. - sess.run(init_op) - - with ops.Graph().as_default(): - # Re-define the iterator manually, without defining any of the - # functions in this graph, to ensure that we are not - # accidentally redefining functions with the same names in the - # new graph. 
- iterator = iterator_ops.Iterator.from_structure( - shared_name="shared_iterator", - output_types=(dtypes.int64, dtypes.int64, dtypes.float64), - output_shapes=([], [3], [])) - get_next = iterator.get_next() - - with session.Session(server.target) as sess: - # Use the iterator without re-initializing in the second session. - results = sess.run(get_next) - for component, result_component in zip(components, results): - self.assertAllEqual(component, result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testNotInitializedError(self): - components = (np.array(1), np.array([1, 2, 3]), np.array(37.0)) - iterator = (dataset_ops.Dataset.from_tensors(components) - .make_initializable_iterator()) - get_next = iterator.get_next() - - with self.test_session() as sess: - with self.assertRaisesRegexp(errors.FailedPreconditionError, - "iterator has not been initialized"): - sess.run(get_next) - - def testReinitializableIterator(self): - dataset_3 = dataset_ops.Dataset.from_tensors( - constant_op.constant([1, 2, 3])) - dataset_4 = dataset_ops.Dataset.from_tensors( - constant_op.constant([4, 5, 6, 7])) - iterator = iterator_ops.Iterator.from_structure(dataset_3.output_types, - [None]) - - dataset_3_init_op = iterator.make_initializer(dataset_3) - dataset_4_init_op = iterator.make_initializer(dataset_4) - get_next = iterator.get_next() - - self.assertEqual(dataset_3.output_types, iterator.output_types) - self.assertEqual(dataset_4.output_types, iterator.output_types) - self.assertEqual([None], iterator.output_shapes.as_list()) - - with self.test_session() as sess: - # The iterator is initially uninitialized. - with self.assertRaises(errors.FailedPreconditionError): - sess.run(get_next) - - # Initialize with one dataset. - sess.run(dataset_3_init_op) - self.assertAllEqual([1, 2, 3], sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Initialize with a different dataset. 
- sess.run(dataset_4_init_op) - self.assertAllEqual([4, 5, 6, 7], sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Reinitialize with the first dataset. - sess.run(dataset_3_init_op) - self.assertAllEqual([1, 2, 3], sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testReinitializableIteratorStaticErrors(self): - # Non-matching structure for types and shapes. - with self.assertRaises(TypeError): - iterator = iterator_ops.Iterator.from_structure((dtypes.int64, - dtypes.float64), [None]) - - # Test validation of dataset argument. - iterator = iterator_ops.Iterator.from_structure((dtypes.int64, - dtypes.float64)) - - # Incompatible structure. - with self.assertRaises(ValueError): - iterator.make_initializer( - dataset_ops.Dataset.from_tensors(((constant_op.constant( - [1, 2, 3], dtype=dtypes.int64),), (constant_op.constant( - [4., 5., 6., 7.], dtype=dtypes.float64),)))) - - # Incompatible types. - with self.assertRaises(TypeError): - iterator.make_initializer( - dataset_ops.Dataset.from_tensors((constant_op.constant( - [1, 2, 3], dtype=dtypes.int32), constant_op.constant( - [4., 5., 6., 7.], dtype=dtypes.float32)))) - - # Incompatible shapes. 
- iterator = iterator_ops.Iterator.from_structure( - (dtypes.int64, dtypes.float64), ([None], [])) - with self.assertRaises(TypeError): - iterator.make_initializer( - dataset_ops.Dataset.from_tensors((constant_op.constant( - [1, 2, 3], dtype=dtypes.int64), constant_op.constant( - [4., 5., 6., 7.], dtype=dtypes.float64)))) - - def testIteratorStringHandle(self): - dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) - dataset_4 = dataset_ops.Dataset.from_tensor_slices([10, 20, 30, 40]) - - iterator_3 = dataset_3.make_one_shot_iterator() - iterator_4 = dataset_4.make_one_shot_iterator() - - handle_placeholder = array_ops.placeholder(dtypes.string, shape=[]) - feedable_iterator = iterator_ops.Iterator.from_string_handle( - handle_placeholder, dataset_3.output_types, dataset_3.output_shapes) - next_element = feedable_iterator.get_next() - - self.assertEqual(dataset_3.output_types, feedable_iterator.output_types) - self.assertEqual(dataset_4.output_types, feedable_iterator.output_types) - self.assertEqual([], feedable_iterator.output_shapes) - - with self.test_session() as sess: - iterator_3_handle = sess.run(iterator_3.string_handle()) - iterator_4_handle = sess.run(iterator_4.string_handle()) - - self.assertEqual( - 10, sess.run(next_element, - feed_dict={handle_placeholder: iterator_4_handle})) - self.assertEqual( - 1, sess.run(next_element, - feed_dict={handle_placeholder: iterator_3_handle})) - self.assertEqual( - 20, sess.run(next_element, - feed_dict={handle_placeholder: iterator_4_handle})) - self.assertEqual( - 2, sess.run(next_element, - feed_dict={handle_placeholder: iterator_3_handle})) - self.assertEqual( - 30, sess.run(next_element, - feed_dict={handle_placeholder: iterator_4_handle})) - self.assertEqual( - 3, sess.run(next_element, - feed_dict={handle_placeholder: iterator_3_handle})) - self.assertEqual( - 40, sess.run(next_element, - feed_dict={handle_placeholder: iterator_4_handle})) - with self.assertRaises(errors.OutOfRangeError): - 
sess.run(next_element, - feed_dict={handle_placeholder: iterator_3_handle}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_element, - feed_dict={handle_placeholder: iterator_4_handle}) - - def testIteratorStringHandleError(self): - dataset_int_scalar = (dataset_ops.Dataset.from_tensor_slices([1, 2, - 3]).repeat()) - dataset_float_vector = (dataset_ops.Dataset.from_tensors([1.0, 2.0, 3.0])) - - handle_placeholder = array_ops.placeholder(dtypes.string, shape=[]) - - feedable_int_scalar = iterator_ops.Iterator.from_string_handle( - handle_placeholder, dtypes.int32, []) - feedable_int_vector = iterator_ops.Iterator.from_string_handle( - handle_placeholder, dtypes.int32, [None]) - feedable_int_any = iterator_ops.Iterator.from_string_handle( - handle_placeholder, dtypes.int32) - - with self.test_session() as sess: - handle_int_scalar = sess.run( - dataset_int_scalar.make_one_shot_iterator().string_handle()) - handle_float_vector = sess.run( - dataset_float_vector.make_one_shot_iterator().string_handle()) - - self.assertEqual(1, - sess.run( - feedable_int_scalar.get_next(), - feed_dict={handle_placeholder: handle_int_scalar})) - - self.assertEqual(2, - sess.run( - feedable_int_any.get_next(), - feed_dict={handle_placeholder: handle_int_scalar})) - - with self.assertRaises(errors.InvalidArgumentError): - print(sess.run( - feedable_int_vector.get_next(), - feed_dict={handle_placeholder: handle_int_scalar})) - - with self.assertRaises(errors.InvalidArgumentError): - print(sess.run( - feedable_int_vector.get_next(), - feed_dict={handle_placeholder: handle_float_vector})) - - def testRemoteIteratorUsingRemoteCallOpDirectSession(self): - worker_config = config_pb2.ConfigProto() - worker_config.device_count["CPU"] = 3 - - with ops.device("/job:localhost/replica:0/task:0/cpu:1"): - dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) - iterator_3 = dataset_3.make_one_shot_iterator() - iterator_3_handle = iterator_3.string_handle() - - 
@function.Defun(dtypes.string) - def _remote_fn(h): - remote_iterator = iterator_ops.Iterator.from_string_handle( - h, dataset_3.output_types, dataset_3.output_shapes) - return remote_iterator.get_next() - - with ops.device("/job:localhost/replica:0/task:0/cpu:0"): - target_placeholder = array_ops.placeholder(dtypes.string, shape=[]) - remote_op = functional_ops.remote_call( - args=[iterator_3_handle], - Tout=[dtypes.int32], - f=_remote_fn, - target=target_placeholder) - - with self.test_session(config=worker_config) as sess: - elem = sess.run( - remote_op, - feed_dict={ - target_placeholder: "/job:localhost/replica:0/task:0/cpu:1" - }) - self.assertEqual(elem, [1]) - # Fails when target is cpu:2 where the resource is not located. - with self.assertRaises(errors.InvalidArgumentError): - sess.run( - remote_op, - feed_dict={ - target_placeholder: "/job:localhost/replica:0/task:0/cpu:2" - }) - elem = sess.run( - remote_op, - feed_dict={ - target_placeholder: "/job:localhost/replica:0/task:0/cpu:1" - }) - self.assertEqual(elem, [2]) - elem = sess.run( - remote_op, - feed_dict={ - target_placeholder: "/job:localhost/replica:0/task:0/cpu:1" - }) - self.assertEqual(elem, [3]) - with self.assertRaises(errors.OutOfRangeError): - sess.run( - remote_op, - feed_dict={ - target_placeholder: "/job:localhost/replica:0/task:0/cpu:1" - }) - - def testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU(self): - if not test_util.is_gpu_available(): - self.skipTest("No GPU available") - - with ops.device("/job:localhost/replica:0/task:0/cpu:0"): - dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) - iterator_3 = dataset_3.make_one_shot_iterator() - iterator_3_handle = iterator_3.string_handle() - - def _encode_raw(byte_array): - return bytes(bytearray(byte_array)) - - @function.Defun(dtypes.uint8) - def _remote_fn(h): - handle = script_ops.py_func(_encode_raw, [h], dtypes.string) - remote_iterator = iterator_ops.Iterator.from_string_handle( - handle, dataset_3.output_types, 
dataset_3.output_shapes) - return remote_iterator.get_next() - - with ops.device("/job:localhost/replica:0/task:0/device:GPU:0"): - target_placeholder = array_ops.placeholder(dtypes.string, shape=[]) - iterator_3_handle_uint8 = parsing_ops.decode_raw( - bytes=iterator_3_handle, out_type=dtypes.uint8) - remote_op = functional_ops.remote_call( - args=[iterator_3_handle_uint8], - Tout=[dtypes.int32], - f=_remote_fn, - target=target_placeholder) - - with self.test_session() as sess: - elem = sess.run( - remote_op, - feed_dict={ - target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" - }) - self.assertEqual(elem, [1]) - elem = sess.run( - remote_op, - feed_dict={ - target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" - }) - self.assertEqual(elem, [2]) - elem = sess.run( - remote_op, - feed_dict={ - target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" - }) - self.assertEqual(elem, [3]) - with self.assertRaises(errors.OutOfRangeError): - sess.run( - remote_op, - feed_dict={ - target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" - }) - - def testIncorrectIteratorRestore(self): - - def _path(): - return os.path.join(self.get_temp_dir(), "iterator") - - def _save_op(iterator_resource): - iterator_state_variant = gen_dataset_ops.serialize_iterator( - iterator_resource) - save_op = io_ops.write_file( - _path(), parsing_ops.serialize_tensor(iterator_state_variant)) - return save_op - - def _restore_op(iterator_resource): - iterator_state_variant = parsing_ops.parse_tensor( - io_ops.read_file(_path()), dtypes.variant) - restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource, - iterator_state_variant) - return restore_op - - def _build_range_dataset_graph(): - start = 1 - stop = 10 - iterator = dataset_ops.Dataset.range(start, - stop).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - save_op = _save_op(iterator._iterator_resource) - restore_op = _restore_op(iterator._iterator_resource) - 
return init_op, get_next, save_op, restore_op - - def _build_reader_dataset_graph(): - filenames = ["test"] # Does not exist but we don't care in this test. - iterator = readers.FixedLengthRecordDataset( - filenames, 1, 0, 0).make_initializable_iterator() - init_op = iterator.initializer - get_next_op = iterator.get_next() - save_op = _save_op(iterator._iterator_resource) - restore_op = _restore_op(iterator._iterator_resource) - return init_op, get_next_op, save_op, restore_op - - # Saving iterator for RangeDataset graph. - with ops.Graph().as_default() as g: - init_op, _, save_op, _ = _build_range_dataset_graph() - with self.test_session(graph=g) as sess: - sess.run(init_op) - sess.run(save_op) - - # Attempt to restore the saved iterator into an IteratorResource of - # incompatible type. An iterator of RangeDataset has output type int64, - # while an iterator of FixedLengthRecordDataset has output type string. - # So an InvalidArgumentError should be raised by - # IteratorResource::set_iterator. 
- with ops.Graph().as_default() as g: - _, _, _, restore_op = _build_reader_dataset_graph() - with self.test_session(graph=g) as sess: - with self.assertRaises(errors.InvalidArgumentError): - sess.run(restore_op) - - def testToSingleElement(self): - skip_value = array_ops.placeholder(dtypes.int64, shape=[]) - take_value = array_ops.placeholder_with_default( - constant_op.constant(1, dtype=dtypes.int64), shape=[]) - - dataset = (dataset_ops.Dataset.range(100) - .skip(skip_value) - .map(lambda x: x * x) - .take(take_value)) - - element = dataset_ops.get_single_element(dataset) - - with self.test_session() as sess: - self.assertEqual(0, sess.run(element, feed_dict={skip_value: 0})) - self.assertEqual(25, sess.run(element, feed_dict={skip_value: 5})) - self.assertEqual(100, sess.run(element, feed_dict={skip_value: 10})) - - with self.assertRaisesRegexp(errors.InvalidArgumentError, - "Dataset was empty."): - sess.run(element, feed_dict={skip_value: 100}) - - with self.assertRaisesRegexp(errors.InvalidArgumentError, - "Dataset had more than one element."): - sess.run(element, feed_dict={skip_value: 0, take_value: 2}) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/list_files_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/list_files_dataset_op_test.py deleted file mode 100644 index 27298de65f90c627e5eb638385bfe0478ef74fca..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/data/python/kernel_tests/list_files_dataset_op_test.py +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for the experimental input pipeline ops.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from os import path -import shutil -import tempfile - -from tensorflow.contrib.data.python.ops import dataset_ops -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.ops import array_ops -from tensorflow.python.platform import test -from tensorflow.python.util import compat - - -class ListFilesDatasetOpTest(test.TestCase): - - def setUp(self): - self.tmp_dir = tempfile.mkdtemp() - - def tearDown(self): - shutil.rmtree(self.tmp_dir, ignore_errors=True) - - def _touchTempFiles(self, filenames): - for filename in filenames: - open(path.join(self.tmp_dir, filename), 'a').close() - - def testEmptyDirectory(self): - dataset = dataset_ops.Dataset.list_files(path.join(self.tmp_dir, '*')) - with self.test_session() as sess: - itr = dataset.make_one_shot_iterator() - with self.assertRaises(errors.OutOfRangeError): - sess.run(itr.get_next()) - - def testSimpleDirectory(self): - filenames = ['a', 'b', 'c'] - self._touchTempFiles(filenames) - - dataset = dataset_ops.Dataset.list_files(path.join(self.tmp_dir, '*')) - with self.test_session() as sess: - itr = dataset.make_one_shot_iterator() - - full_filenames = [] - produced_filenames = [] - for filename in filenames: - full_filenames.append( - compat.as_bytes(path.join(self.tmp_dir, filename))) - 
produced_filenames.append(compat.as_bytes(sess.run(itr.get_next()))) - self.assertItemsEqual(full_filenames, produced_filenames) - with self.assertRaises(errors.OutOfRangeError): - sess.run(itr.get_next()) - - def testEmptyDirectoryInitializer(self): - filename_placeholder = array_ops.placeholder(dtypes.string, shape=[]) - dataset = dataset_ops.Dataset.list_files(filename_placeholder) - - with self.test_session() as sess: - itr = dataset.make_initializable_iterator() - sess.run( - itr.initializer, - feed_dict={filename_placeholder: path.join(self.tmp_dir, '*')}) - - with self.assertRaises(errors.OutOfRangeError): - sess.run(itr.get_next()) - - def testSimpleDirectoryInitializer(self): - filenames = ['a', 'b', 'c'] - self._touchTempFiles(filenames) - - filename_placeholder = array_ops.placeholder(dtypes.string, shape=[]) - dataset = dataset_ops.Dataset.list_files(filename_placeholder) - - with self.test_session() as sess: - itr = dataset.make_initializable_iterator() - sess.run( - itr.initializer, - feed_dict={filename_placeholder: path.join(self.tmp_dir, '*')}) - - full_filenames = [] - produced_filenames = [] - for filename in filenames: - full_filenames.append( - compat.as_bytes(path.join(self.tmp_dir, filename))) - produced_filenames.append(compat.as_bytes(sess.run(itr.get_next()))) - - self.assertItemsEqual(full_filenames, produced_filenames) - - with self.assertRaises(errors.OutOfRangeError): - sess.run(itr.get_next()) - - def testFileSuffixes(self): - filenames = ['a.txt', 'b.py', 'c.py', 'd.pyc'] - self._touchTempFiles(filenames) - - filename_placeholder = array_ops.placeholder(dtypes.string, shape=[]) - dataset = dataset_ops.Dataset.list_files(filename_placeholder) - - with self.test_session() as sess: - itr = dataset.make_initializable_iterator() - sess.run( - itr.initializer, - feed_dict={filename_placeholder: path.join(self.tmp_dir, '*.py')}) - - full_filenames = [] - produced_filenames = [] - for filename in filenames[1:-1]: - full_filenames.append( - 
compat.as_bytes(path.join(self.tmp_dir, filename))) - produced_filenames.append(compat.as_bytes(sess.run(itr.get_next()))) - self.assertItemsEqual(full_filenames, produced_filenames) - - with self.assertRaises(errors.OutOfRangeError): - sess.run(itr.get_next()) - - def testFileMiddles(self): - filenames = ['a.txt', 'b.py', 'c.pyc'] - self._touchTempFiles(filenames) - - filename_placeholder = array_ops.placeholder(dtypes.string, shape=[]) - dataset = dataset_ops.Dataset.list_files(filename_placeholder) - - with self.test_session() as sess: - itr = dataset.make_initializable_iterator() - sess.run( - itr.initializer, - feed_dict={filename_placeholder: path.join(self.tmp_dir, '*.py*')}) - - full_filenames = [] - produced_filenames = [] - for filename in filenames[1:]: - full_filenames.append( - compat.as_bytes(path.join(self.tmp_dir, filename))) - produced_filenames.append(compat.as_bytes(sess.run(itr.get_next()))) - - self.assertItemsEqual(full_filenames, produced_filenames) - - with self.assertRaises(errors.OutOfRangeError): - sess.run(itr.get_next()) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py index e9a07da84a8c80c09ebd4dab0b1d69febe1c9790..8d4042927970cab2f5a518fc0da49b38444dbcdf 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py @@ -16,16 +16,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from collections import namedtuple import os -import threading import numpy as np from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base -from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.contrib.data.python.ops import error_ops +from tensorflow.python.data.ops import dataset_ops from 
tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -33,15 +31,9 @@ from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops -from tensorflow.python.ops import data_flow_ops -from tensorflow.python.ops import functional_ops from tensorflow.python.ops import io_ops -from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops -from tensorflow.python.ops import script_ops -from tensorflow.python.ops import sparse_ops -from tensorflow.python.ops import string_ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import test from tensorflow.python.util import compat @@ -49,227 +41,13 @@ from tensorflow.python.util import compat class MapDatasetTest(test.TestCase): - def _buildMapDataset(self, components, count): - def _map_fn(x, y, z): - return math_ops.square(x), math_ops.square(y), math_ops.square(z) - return (dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) - .repeat(count)) - - def testMapDataset(self): - """Test an dataset that maps a TF function across its input elements.""" - # The pipeline is TensorSliceDataset -> MapDataset(square_3) -> - # RepeatDataset(count). - components = (np.arange(7), - np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], - np.array(37.0) * np.arange(7)) - count = array_ops.placeholder(dtypes.int64, shape=[]) - - dataset = self._buildMapDataset(components, count) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - self.assertEqual([c.shape[1:] for c in components], - [t.shape for t in get_next]) - - with self.test_session() as sess: - # Test single-threaded access to the iterator. 
- sess.run(init_op, feed_dict={count: 14}) - for _ in range(14): - for i in range(7): - result = sess.run(get_next) - for component, result_component in zip(components, result): - self.assertAllEqual(component[i]**2, result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Test multi-threaded access to the same iterator. - sess.run(init_op, feed_dict={count: 18}) - results = [] - def iterator_thread(): - while True: - try: - results.append(sess.run(get_next)) - except errors.OutOfRangeError: - return - threads = [self.checkedThread(target=iterator_thread) for _ in range(8)] - for t in threads: - t.start() - for t in threads: - t.join() - - # `results` will contain the same elements components**2 - # repeated 18 times, but in a non-deterministic order. Sort the - # results, and assert that each element of components**2 is - # produced 18 times. - results.sort(key=lambda x: x[0]) - for i in range(7): - for j in range(18): - for component, result_component in zip(components, - results[i * 18 + j]): - self.assertAllEqual(component[i]**2, result_component) - - def _buildParallelMapDataset(self, components, count, num_threads, - output_buffer_size): - def _map_fn(x, y, z): - return math_ops.square(x), math_ops.square(y), math_ops.square(z) - return (dataset_ops.Dataset.from_tensor_slices(components).map( - _map_fn, num_threads=num_threads, output_buffer_size=output_buffer_size) - .repeat(count)) - - def testParallelMapDataset(self): - """Test an dataset that maps a TF function across its input elements.""" - # The pipeline is TensorSliceDataset -> ParallelMapDataset(square_3) -> - # RepeatDataset(count). 
- components = (np.arange(7), - np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], - np.array(37.0) * np.arange(7)) - count = array_ops.placeholder(dtypes.int64, shape=[]) - num_threads = array_ops.placeholder(dtypes.int32, shape=[]) - output_buffer_size = array_ops.placeholder(dtypes.int64, shape=[]) - - dataset = self._buildParallelMapDataset(components, count, num_threads, - output_buffer_size) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - self.assertEqual([c.shape[1:] for c in components], - [t.shape for t in get_next]) - - with self.test_session() as sess: - def do_test(num_threads_val, output_buffer_size_val): - # Test single-threaded access to the iterator. - sess.run(init_op, feed_dict={ - count: 14, - num_threads: num_threads_val, - output_buffer_size: output_buffer_size_val}) - for _ in range(14): - for i in range(7): - result = sess.run(get_next) - for component, result_component in zip(components, result): - self.assertAllEqual(component[i]**2, result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Test multi-threaded access to the same iterator. - sess.run(init_op, feed_dict={ - count: 18, - num_threads: num_threads_val, - output_buffer_size: output_buffer_size_val}) - results = [] - def iterator_thread(): - while True: - try: - results.append(sess.run(get_next)) - except errors.OutOfRangeError: - return - threads = [self.checkedThread(target=iterator_thread) - for _ in range(64)] - for t in threads: - t.start() - for t in threads: - t.join() - - # `results` will contain the same elements components**2 - # repeated 18 times, but in a non-deterministic order. Sort the - # results, and assert that each element of components**2 is - # produced 18 times. 
- results.sort(key=lambda x: x[0]) - for i in range(7): - for j in range(18): - for component, result_component in zip(components, - results[i * 18 + j]): - self.assertAllEqual(component[i]**2, result_component) - - for num_threads_val, output_buffer_size_val in [ - (1, 1), (1, 2), (2, 2), (2, 4), (8, 8), (8, 16)]: - do_test(num_threads_val, output_buffer_size_val) - - def testImplicitDisposeParallelMapDataset(self): - # Tests whether a parallel map dataset will be cleaned up correctly when - # the pipeline does not run it until exhaustion. - # The pipeline is TensorSliceDataset -> MapDataset(square_3) -> - # RepeatDataset(1000). - components = (np.arange(1000), - np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis], - np.array(37.0) * np.arange(1000)) - - dataset = self._buildParallelMapDataset(components, 1000, 100, 100) - # NOTE(mrry): Also test that the prefetching thread is cancelled correctly. - dataset = dataset.prefetch(100) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - for _ in range(3): - sess.run(get_next) - - def testParallelMapUnspecifiedOutputSize(self): - components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32) - - dataset = (dataset_ops.Dataset.from_tensor_slices(components) - .map(lambda x: array_ops.check_numerics(x, "message"), - num_threads=2)) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - for _ in range(3): - sess.run(get_next) - - def testParallelMapError(self): - components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32) - - dataset = (dataset_ops.Dataset.from_tensor_slices(components) - .map(lambda x: array_ops.check_numerics(x, "message"), - num_threads=2, output_buffer_size=2)) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - 
get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - for _ in range(3): - sess.run(get_next) - # The 4th element is NaN, so `array_ops.check_numerics()` should fail. - with self.assertRaises(errors.InvalidArgumentError): - sess.run(get_next) - sess.run(get_next) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testPrefetchError(self): - components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32) - - dataset = (dataset_ops.Dataset.from_tensor_slices(components) - .map(lambda x: array_ops.check_numerics(x, "message")) - .prefetch(2)) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - for _ in range(3): - sess.run(get_next) - # The 4th element is NaN, so `array_ops.check_numerics()` should fail. - with self.assertRaises(errors.InvalidArgumentError): - sess.run(get_next) - sess.run(get_next) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - def testMapIgnoreError(self): components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32) - dataset = (dataset_ops.Dataset.from_tensor_slices(components) - .map(lambda x: array_ops.check_numerics(x, "message")).apply( - error_ops.ignore_errors())) + dataset = ( + dataset_ops.Dataset.from_tensor_slices(components) + .map(lambda x: array_ops.check_numerics(x, "message")).apply( + error_ops.ignore_errors())) iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() @@ -284,10 +62,10 @@ class MapDatasetTest(test.TestCase): def testParallelMapIgnoreError(self): components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32) - dataset = (dataset_ops.Dataset.from_tensor_slices(components).map( - lambda x: array_ops.check_numerics(x, "message"), - num_threads=2, - output_buffer_size=2).apply(error_ops.ignore_errors())) + dataset = ( + 
dataset_ops.Dataset.from_tensor_slices(components).map( + lambda x: array_ops.check_numerics(x, "message"), + num_parallel_calls=2).prefetch(2).apply(error_ops.ignore_errors())) iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() @@ -308,9 +86,10 @@ class MapDatasetTest(test.TestCase): for filename in filenames: write_string_to_file(filename, filename) - dataset = (dataset_ops.Dataset.from_tensor_slices(filenames).map( - io_ops.read_file, num_threads=2, output_buffer_size=2).apply( - error_ops.ignore_errors())) + dataset = ( + dataset_ops.Dataset.from_tensor_slices(filenames).map( + io_ops.read_file, num_parallel_calls=2).prefetch(2).apply( + error_ops.ignore_errors())) iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() @@ -334,321 +113,125 @@ class MapDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - def testCaptureHashTable(self): - # NOTE(mrry): We must use the V2 variants of `HashTable` - # etc. because these produce a `tf.resource`-typed output that is - # compatible with the in-graph function implementation. 
- default_val = -1 - keys = constant_op.constant(["brain", "salad", "surgery"]) - values = constant_op.constant([0, 1, 2], dtypes.int64) - table = lookup_ops.HashTable( - lookup_ops.KeyValueTensorInitializer(keys, values), default_val) - - input_sentences = dataset_ops.Dataset.from_tensor_slices( - ["brain brain tank salad surgery", "surgery brain"]) - - iterator = (input_sentences - .map(lambda x: string_ops.string_split([x]).values) - .map(table.lookup) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(table.init) - sess.run(init_op) - - print(sess.run(get_next)) - print(sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testCaptureQueue(self): - elements = np.random.randint(100, size=[200]) - queue = data_flow_ops.FIFOQueue(200, dtypes.int64, shapes=[]) - enqueue_op = queue.enqueue_many(elements) - close_op = queue.close() - iterator = (dataset_ops.Dataset.from_tensors(0).repeat(-1) - .map(lambda _: queue.dequeue()).make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(enqueue_op) - sess.run(close_op) - sess.run(init_op) - for element in elements: - self.assertEqual(element, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testCaptureSameResourceMultipleTimes(self): - elements = np.random.randint(100, size=[200]) - queue = data_flow_ops.FIFOQueue( - 200, dtypes.int64, shapes=[], shared_name="shared_queue") - queue_2 = data_flow_ops.FIFOQueue( - 200, dtypes.int64, shapes=[], shared_name="shared_queue") - - enqueue_op = queue.enqueue_many(elements) - close_op = queue.close() + def testCaptureResourceInMapFn(self): - iterator = (dataset_ops.Dataset.from_tensors(0).repeat(-1) - .map(lambda _: (queue.dequeue(), queue_2.dequeue())) - .make_initializable_iterator()) - init_op = 
iterator.initializer - get_next = iterator.get_next() + def _build_ds(iterator): - with self.test_session() as sess: - sess.run(enqueue_op) - sess.run(close_op) - sess.run(init_op) - for i in range(100): - self.assertEqual(sorted([elements[i * 2], elements[i * 2 + 1]]), - sorted(sess.run(get_next))) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) + def _map_fn(x): + get_next = iterator.get_next() + return x * get_next - def testCaptureVariable(self): - counter_var = variable_scope.get_variable( - "counter", (), dtypes.int32, use_resource=True) - iterator = (dataset_ops.Dataset.from_tensors(0).repeat(10) - .map(lambda _: counter_var.assign_add(1)) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() + return dataset_ops.Dataset.range(10).map(_map_fn) - with self.test_session() as sess: - sess.run(counter_var.initializer) - sess.run(init_op) - for i in range(10): - self.assertEqual(i, sess.run(counter_var)) - self.assertEqual(i + 1, sess.run(get_next)) - self.assertEqual(10, sess.run(counter_var)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - self.assertEqual(10, sess.run(counter_var)) - - def testCaptureUninitializedVariableError(self): - counter_var = variable_scope.get_variable( - "counter", (), dtypes.int32, use_resource=True) - iterator = (dataset_ops.Dataset.from_tensors(0).repeat(10) - .map(lambda _: counter_var.assign_add(1)) - .make_initializable_iterator()) - init_op = iterator.initializer + def _build_graph(): + captured_iterator = dataset_ops.Dataset.range( + 10).make_initializable_iterator() + ds = _build_ds(captured_iterator) + iterator = ds.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + return captured_iterator.initializer, init_op, get_next - with self.test_session() as sess: - with self.assertRaisesRegexp(errors.FailedPreconditionError, - "Failed to capture resource"): + with ops.Graph().as_default() as 
g: + captured_init_op, init_op, get_next = _build_graph() + with self.test_session(graph=g) as sess: + sess.run(captured_init_op) sess.run(init_op) + for i in range(10): + self.assertEquals(i * i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) - def testSeededStatefulOperatorIsProperlyStateful(self): - iterator = (dataset_ops.Dataset.from_tensors(0).repeat(10) - .map(lambda _: random_ops.random_uniform((), seed=11)).batch(2) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - with self.test_session() as sess: - sess.run(init_op) - random_values = [] - with self.assertRaises(errors.OutOfRangeError): - while True: - random_values.extend(sess.run(get_next)) - self.assertEqual(10, len(random_values)) - self.assertGreater(np.abs(np.diff(random_values)).max(), 1e-6) - sess.run(init_op) - random_values_2 = [] - with self.assertRaises(errors.OutOfRangeError): - while True: - random_values_2.extend(sess.run(get_next)) +class MapDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): - # Randomness is repeatable given same seed - self.assertAllClose(random_values, random_values_2) + def setUp(self): + self._tensor_slice_len = 7 + self._num_epochs = 14 + self._num_outputs = self._tensor_slice_len * self._num_epochs - def testMapDict(self): - iterator = (dataset_ops.Dataset.range(10) - .map(lambda x: {"foo": x * 2, "bar": x ** 2}) - .map(lambda d: d["foo"] + d["bar"]) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() + def _build_ds(self, multiplier=37.0): + components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) * + np.arange(self._tensor_slice_len)[:, np.newaxis], + np.array(multiplier) * np.arange(self._tensor_slice_len)) - with self.test_session() as sess: - sess.run(init_op) - for i in range(10): - self.assertEqual(i * 2 + i ** 2, sess.run(get_next)) - with 
self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) + def _map_fn(x, y, z): + return math_ops.square(x), math_ops.square(y), math_ops.square(z) - def testMapNamedtuple(self, count=10): - # construct dataset of tuples - labels = dataset_ops.Dataset.range(count) - images = labels.map(lambda l: -l) - dataset_tuple = dataset_ops.Dataset.zip((labels, images)) + return ( + dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) + .repeat(self._num_epochs)) - # convert dataset of tuples to dataset of namedtuples - example = namedtuple("Example", ["label", "image"]) - dataset_namedtuple = dataset_tuple.map(example) + def testSaveRestoreCore(self): + self.run_core_tests( + self._build_ds, + lambda: self._build_ds(multiplier=15.0), + self._num_outputs) - def preprocess_tuple(label, image): - image = 2 * image - return label, image + def testSaveStatefulFunction(self): - def preprocess_namedtuple(example): - return example._replace(image=2 * example.image) + def _build_ds(): - # preprocess both datasets - dataset_tuple = dataset_tuple.map(preprocess_tuple) - dataset_namedtuple = dataset_namedtuple.map(preprocess_namedtuple) + def _map_fn(x): + return random_ops.random_uniform( + (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x) - next_tuple = dataset_tuple.make_one_shot_iterator().get_next() - next_namedtuple = dataset_namedtuple.make_one_shot_iterator().get_next() + return dataset_ops.Dataset.range(100).map(_map_fn) - # make sure both datasets contain the same data - with self.test_session() as sess: - for i in range(count): - tuple_, namedtuple_ = sess.run([next_tuple, next_namedtuple]) - self.assertEqual(tuple_, namedtuple_) - self.assertEqual(tuple_, (i, -2 * i)) + self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) - with self.assertRaises(errors.OutOfRangeError): - sess.run(next_namedtuple) - - def testUseStepContainerInMap(self): - row = np.arange(6) - iterator = ( - dataset_ops.Dataset.from_tensors(row) - .map(lambda elems: 
functional_ops.map_fn(lambda x: x * x, elems)) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() + def testCaptureVariableInMapFn(self): - with self.test_session() as sess: - sess.run(init_op) - self.assertAllEqual(row ** 2, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) + def _build_ds(): + counter_var = variable_scope.get_variable( + "counter", (), dtypes.int32, use_resource=True) + return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( + lambda _: counter_var.assign_add(1))) - def testPrefetch(self): - # We will use this event to test that `_map_py_func()` has been - # invoked a certain number of times (6 times, to be exact) after - # consuming fewer elements from the iterator. - ev = threading.Event() + self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) - set_event_during_invocation = 5 + def testCaptureConstantInMapFn(self): - def _map_py_func(x): - if x == set_event_during_invocation: - ev.set() - return x * x + def _build_ds(): + constant_var = constant_op.constant(5) + return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( + lambda x: x + constant_var)) - def _map_fn(x): - return script_ops.py_func(_map_py_func, [x], x.dtype) + self.run_core_tests(_build_ds, None, 10) - buffer_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = ( - dataset_ops.Dataset.range(100) - .map(_map_fn) - .prefetch(buffer_size_placeholder) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() + def testCaptureDefunInMapFn(self): + num_outputs = 100 - with self.test_session() as sess: - # Simple test that prefetch yields the expected values in the - # expected order. 
- for buffer_size in [1, 10, 100, 1000]: - sess.run(init_op, feed_dict={buffer_size_placeholder: buffer_size}) - for i in range(100): - self.assertEqual(i * i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) + def _build_ds(): - # We can indirectly observe that varying the buffer size has the - # intended effect by observing when `ev` is set (on the 6th - # invocation of `_map_py_func()`). - # NOTE(mrry): We do not test with `buffer_size == - # set_event_during_invocation`, because we must consume at least - # one element to start the prefetching. - for buffer_size in range(1, set_event_during_invocation): - event_will_be_set_after_consuming = ( - set_event_during_invocation - buffer_size + 1) - - ev.clear() - sess.run(init_op, feed_dict={buffer_size_placeholder: buffer_size}) - for i in range(event_will_be_set_after_consuming): - self.assertFalse(ev.is_set()) - self.assertEqual(i * i, sess.run(get_next)) - ev.wait() - for i in range(event_will_be_set_after_consuming, 100): - self.assertEqual(i * i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) + @function.Defun(dtypes.int64) + def defun_fn(x): + return constant_op.constant(1000) + math_ops.to_int32(x) - def testReturnList(self): - iterator = (dataset_ops.Dataset.range(10) - .map(lambda x: [x, constant_op.constant(37.0)]) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() + return dataset_ops.Dataset.range(num_outputs).map(defun_fn) - with self.test_session() as sess: - sess.run(init_op) - for i in range(10): - self.assertEqual((i, 37.0), sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) + self.run_core_tests(_build_ds, None, num_outputs) - def testMultiOutputPyFunc(self): - # The `tf.py_func()` op returns a list of tensors for its outputs. 
- def _map_fn(x_tensor): - def _map_py_func(x): - return x, np.array(37.0, dtype=np.float64) - return script_ops.py_func( - _map_py_func, [x_tensor], [dtypes.int64, dtypes.float64]) - - iterator = (dataset_ops.Dataset.range(10) - .map(_map_fn) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() + def testBuildDefunInMapFn(self): + num_outputs = 100 - with self.test_session() as sess: - sess.run(init_op) - for i in range(10): - self.assertEqual((i, 37.0), sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) + def _build_ds(): - def assertSparseValuesEqual(self, a, b): - self.assertAllEqual(a.indices, b.indices) - self.assertAllEqual(a.values, b.values) - self.assertAllEqual(a.dense_shape, b.dense_shape) + @function.Defun(dtypes.int64) + def defun_fn(x): - def testSparse(self): + @function.Defun(dtypes.int32) + def defun_fn_deep(x): + return constant_op.constant(1000) + math_ops.to_int32(x) - def _sparse(i): - return sparse_tensor.SparseTensorValue( - indices=np.array([[0, 0]]), - values=(i * np.array([1])), - dense_shape=np.array([1, 1])) + return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x)) - iterator = (dataset_ops.Dataset.range(10) - .map(_sparse) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() + return dataset_ops.Dataset.range(num_outputs).map(defun_fn) - with self.test_session() as sess: - sess.run(init_op) - for i in range(10): - actual = sess.run(get_next) - self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue)) - self.assertSparseValuesEqual(actual, _sparse(i)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) + self.run_core_tests(_build_ds, None, num_outputs) - def testSparseChain(self): + def testSparseCore(self): def _sparse(i): return sparse_tensor.SparseTensorValue( @@ -656,58 +239,20 @@ class MapDatasetTest(test.TestCase): values=(i * np.array([1])), 
dense_shape=np.array([1, 1])) - def _check(i): - self.assertTrue(sparse_tensor.is_sparse(i)) - return sparse_ops.sparse_concat(0, [i, i]) - - iterator = ( - dataset_ops.Dataset.range(10).map(_sparse).map(_check) - .make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - for i in range(10): - actual = sess.run(get_next) - self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue)) - self.assertSparseValuesEqual(actual, _check(_sparse(i)).eval()) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testCaptureResourceInMapFn(self): - - def _build_ds(iterator): - - def _map_fn(x): - get_next = iterator.get_next() - return x * get_next - - return dataset_ops.Dataset.range(10).map(_map_fn) - - def _build_graph(): - captured_iterator = dataset_ops.Dataset.range( - 10).make_initializable_iterator() - ds = _build_ds(captured_iterator) - iterator = ds.make_initializable_iterator() - init_op = iterator.initializer - return captured_iterator.initializer, init_op + def _build_ds(num_outputs): + return dataset_ops.Dataset.range(num_outputs).map(_sparse) - with ops.Graph().as_default() as g: - captured_init_op, init_op = _build_graph() - with self.test_session(graph=g) as sess: - sess.run(captured_init_op) - with self.assertRaises(errors.UnimplementedError): - # CapturedFunction does not support capturing IteratorResource. 
- sess.run(init_op) + num_outputs = 10 + self.run_core_tests(lambda: _build_ds(num_outputs), + lambda: _build_ds(int(num_outputs / 2)), num_outputs) -class MapDatasetSerializationTest( +class ParallelMapDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): def setUp(self): self._tensor_slice_len = 7 - self._num_epochs = 14 + self._num_epochs = 1 self._num_outputs = self._tensor_slice_len * self._num_epochs def _build_ds(self, multiplier=37.0): @@ -718,14 +263,26 @@ class MapDatasetSerializationTest( def _map_fn(x, y, z): return math_ops.square(x), math_ops.square(y), math_ops.square(z) - return (dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn) - .repeat(self._num_epochs)) + return (dataset_ops.Dataset.from_tensor_slices(components).map( + _map_fn, num_parallel_calls=3).repeat(self._num_epochs)) + + def _build_ds_with_prefetch(self, multiplier=37.0): + components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) * + np.arange(self._tensor_slice_len)[:, np.newaxis], + np.array(multiplier) * np.arange(self._tensor_slice_len)) + + def _map_fn(x, y, z): + return math_ops.square(x), math_ops.square(y), math_ops.square(z) + + return (dataset_ops.Dataset.from_tensor_slices(components).map( + _map_fn, num_parallel_calls=3).repeat(self._num_epochs).prefetch(5)) def testSaveRestoreCore(self): - self.run_core_tests( - self._build_ds, - lambda: self._build_ds(multiplier=15.0), - self._num_outputs) + for ds_fn in [self._build_ds, self._build_ds_with_prefetch]: + self.run_core_tests( + ds_fn, + lambda: ds_fn(multiplier=15.0), + self._num_outputs) def testSaveStatefulFunction(self): @@ -735,7 +292,8 @@ class MapDatasetSerializationTest( return random_ops.random_uniform( (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x) - return dataset_ops.Dataset.range(100).map(_map_fn) + return dataset_ops.Dataset.range(100).map( + _map_fn, num_parallel_calls=2).prefetch(2) self.verify_error_on_save(_build_ds, 15, 
errors.InvalidArgumentError) @@ -745,10 +303,20 @@ class MapDatasetSerializationTest( counter_var = variable_scope.get_variable( "counter", (), dtypes.int32, use_resource=True) return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( - lambda _: counter_var.assign_add(1))) + lambda _: counter_var.assign_add(1), + num_parallel_calls=2).prefetch(2)) self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError) + def testCaptureConstantInMapFn(self): + + def _build_ds(): + constant_var = constant_op.constant(5) + return (dataset_ops.Dataset.from_tensors(0).repeat(10).map( + lambda x: x + constant_var, num_parallel_calls=2).prefetch(2)) + + self.run_core_tests(_build_ds, None, 10) + def testCaptureDefunInMapFn(self): num_outputs = 100 @@ -758,7 +326,8 @@ class MapDatasetSerializationTest( def defun_fn(x): return constant_op.constant(1000) + math_ops.to_int32(x) - return dataset_ops.Dataset.range(num_outputs).map(defun_fn) + return dataset_ops.Dataset.range(num_outputs).map( + defun_fn, num_parallel_calls=2).prefetch(2) self.run_core_tests(_build_ds, None, num_outputs) @@ -776,7 +345,8 @@ class MapDatasetSerializationTest( return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x)) - return dataset_ops.Dataset.range(num_outputs).map(defun_fn) + return dataset_ops.Dataset.range(num_outputs).map( + defun_fn, num_parallel_calls=2).prefetch(2) self.run_core_tests(_build_ds, None, num_outputs) diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py index 8e6ad061a11752ab7b1ffc13c90b4fa52f67d6aa..80e1cb0041024b68bd5268b5de5d69c88c839896 100644 --- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py @@ -19,160 +19,24 @@ from __future__ import print_function import os +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from 
tensorflow.contrib.data.python.ops import counter -from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.contrib.data.python.ops import enumerate_ops -from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops -from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import io_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import variables -from tensorflow.python.platform import gfile from tensorflow.python.platform import test -from tensorflow.python.training import saver as saver_lib class RangeDatasetTest(test.TestCase): - def tearDown(self): - # Remove all checkpoint files. 
- prefix = self._iterator_checkpoint_prefix() - pattern = prefix + "*" - files = gfile.Glob(pattern) - map(gfile.Remove, files) - - def testStop(self): - stop = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = dataset_ops.Dataset.range(stop).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op, feed_dict={stop: 5}) - for i in range(5): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testStartStop(self): - start = array_ops.placeholder(dtypes.int64, shape=[]) - stop = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = dataset_ops.Dataset.range(start, - stop).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op, feed_dict={start: 2, stop: 5}) - for i in range(2, 5): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testStartStopStep(self): - start = array_ops.placeholder(dtypes.int64, shape=[]) - stop = array_ops.placeholder(dtypes.int64, shape=[]) - step = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = dataset_ops.Dataset.range(start, stop, - step).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op, feed_dict={start: 2, stop: 10, step: 2}) - for i in range(2, 10, 2): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testZeroStep(self): - start = array_ops.placeholder(dtypes.int64, shape=[]) - stop = array_ops.placeholder(dtypes.int64, shape=[]) - step = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = dataset_ops.Dataset.range(start, stop, - step).make_initializable_iterator() - init_op = iterator.initializer - - with 
self.test_session() as sess: - with self.assertRaises(errors.InvalidArgumentError): - sess.run(init_op, feed_dict={start: 2, stop: 10, step: 0}) - - def testNegativeStep(self): - start = array_ops.placeholder(dtypes.int64, shape=[]) - stop = array_ops.placeholder(dtypes.int64, shape=[]) - step = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = dataset_ops.Dataset.range(start, stop, - step).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op, feed_dict={start: 2, stop: 10, step: -1}) - # This for loop is a no-op but will ensure that the implementation is - # consistent with range if it ever changes. - for i in range(2, 10, -1): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testStopLessThanStart(self): - start = array_ops.placeholder(dtypes.int64, shape=[]) - stop = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = dataset_ops.Dataset.range(start, - stop).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op, feed_dict={start: 10, stop: 2}) - # This for loop is a no-op but will ensure that the implementation is - # consistent with range if it ever changes. 
- for i in range(10, 2): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testStopLessThanStartWithPositiveStep(self): - start = array_ops.placeholder(dtypes.int64, shape=[]) - stop = array_ops.placeholder(dtypes.int64, shape=[]) - step = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = dataset_ops.Dataset.range(start, stop, - step).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op, feed_dict={start: 10, stop: 2, step: 2}) - # This for loop is a no-op but will ensure that the implementation is - # consistent with range if it ever changes. - for i in range(10, 2, 2): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testStopLessThanStartWithNegativeStep(self): - start = array_ops.placeholder(dtypes.int64, shape=[]) - stop = array_ops.placeholder(dtypes.int64, shape=[]) - step = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = dataset_ops.Dataset.range(start, stop, - step).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op, feed_dict={start: 10, stop: 2, step: -1}) - for i in range(10, 2, -1): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - def testEnumerateDataset(self): components = (["a", "b"], [1, 2], [37.0, 38]) start = constant_op.constant(20, dtype=dtypes.int64) @@ -216,20 +80,25 @@ class RangeDatasetTest(test.TestCase): self.assertEqual(-1, sess.run(negative_get_next)) self.assertEqual(-2, sess.run(negative_get_next)) - def _iterator_checkpoint_prefix(self): + +class RangeDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _iterator_checkpoint_prefix_local(self): return 
os.path.join(self.get_temp_dir(), "iterator") def _save_op(self, iterator_resource): iterator_state_variant = gen_dataset_ops.serialize_iterator( iterator_resource) save_op = io_ops.write_file( - self._iterator_checkpoint_prefix(), + self._iterator_checkpoint_prefix_local(), parsing_ops.serialize_tensor(iterator_state_variant)) return save_op def _restore_op(self, iterator_resource): iterator_state_variant = parsing_ops.parse_tensor( - io_ops.read_file(self._iterator_checkpoint_prefix()), dtypes.variant) + io_ops.read_file(self._iterator_checkpoint_prefix_local()), + dtypes.variant) restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource, iterator_state_variant) return restore_op @@ -283,382 +152,16 @@ class RangeDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) - def testSaveRestoreUsingSaverFromMetaGraph(self): - - def _build_graph(start, stop): - iterator = dataset_ops.Dataset.range(start, - stop).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - ops.add_to_collection("iterator_ops", init_op) - ops.add_to_collection("iterator_ops", get_next) - saveable_obj = contrib_iterator_ops.make_saveable_from_iterator(iterator) - # Add the SaveableObject to the `SAVEABLE_OBJECTS` collection - # so that it can be automatically picked up by the Saver. - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable_obj) - saver = saver_lib.Saver() - return init_op, get_next, saver - - start = 2 - stop = 10 - break_point = 5 - path = self._iterator_checkpoint_prefix() - meta_filename = path + ".meta" - - # Execute input pipeline for a few steps and save iterator state. 
- with ops.Graph().as_default() as g: - init_op, get_next, saver = _build_graph(start, stop) - with self.test_session(graph=g) as sess: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) - for i in range(start, break_point): - self.assertEqual(i, sess.run(get_next)) - saver.save(sess, path) - - # Build the saver from the MetaGraph using import_meta_graph and - # check that the iterator state is restored. - with ops.Graph().as_default() as g: - saver = saver_lib.import_meta_graph(meta_filename) - init_op, get_next = ops.get_collection("iterator_ops") - with self.test_session(graph=g) as sess: - saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir())) - for i in range(break_point, stop): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testSaveRestoreUsingBuiltSaver(self): - - def _build_graph(start, stop): - iterator = dataset_ops.Dataset.range(start, - stop).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - ops.add_to_collection("iterator_ops", init_op) - ops.add_to_collection("iterator_ops", get_next) - # Add the SaveableObject to the `SAVEABLE_OBJECTS` collection - # so that it can be automatically picked up by the Saver. - saveable_obj = contrib_iterator_ops.make_saveable_from_iterator(iterator) - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable_obj) - saver = saver_lib.Saver() - return init_op, get_next, saver - - start = 2 - stop = 10 - stop_new = 15 - break_point = 5 - path = self._iterator_checkpoint_prefix() - - # Execute input pipeline for a few steps and save iterator state. 
- with ops.Graph().as_default() as g: - init_op, get_next, saver = _build_graph(start, stop) - with self.test_session(graph=g) as sess: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) - for i in range(start, break_point): - self.assertEqual(i, sess.run(get_next)) - saver.save(sess, path) - - # Manually build a modified Graph and Saver instead of importing - # MetaGraph and verify that original iterator state gets restored. - with ops.Graph().as_default() as g: - init_op, get_next, saver = _build_graph(start, stop_new) - with self.test_session(graph=g) as sess: - saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir())) - for i in range(break_point, stop): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testSaveRestoreUsingSaverThenInit(self): - - def _build_graph(start, stop): - iterator = dataset_ops.Dataset.range(start, - stop).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - ops.add_to_collection("iterator_ops", init_op) - ops.add_to_collection("iterator_ops", get_next) - # Add the SaveableObject to the `SAVEABLE_OBJECTS` collection - # so that it can be automatically picked up by the Saver. - saveable_obj = contrib_iterator_ops.make_saveable_from_iterator(iterator) - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable_obj) - saver = saver_lib.Saver() - return init_op, get_next, saver - - start = 2 - stop = 10 - stop_new = 15 - break_point = 5 - path = self._iterator_checkpoint_prefix() - - # Execute input pipeline for a few steps and save iterator state. 
- with ops.Graph().as_default() as g: - init_op, get_next, saver = _build_graph(start, stop) - with self.test_session(graph=g) as sess: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) - for i in range(start, break_point): - self.assertEqual(i, sess.run(get_next)) - saver.save(sess, path) - - # Restore iterator state call and then call init_op for the iterator and - # verify that the new iterator hides the restored iterator. - with ops.Graph().as_default() as g: - init_op, get_next, saver = _build_graph(start, stop_new) - with self.test_session(graph=g) as sess: - saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir())) - sess.run(init_op) - for i in range(start, stop_new): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testRestoreWithoutBuildingDatasetGraph(self): - - def _build_graph(start, stop, num_epochs): - dataset = dataset_ops.Dataset.range(start, stop).repeat(num_epochs) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - save_op = self._save_op(iterator._iterator_resource) - restore_op = self._restore_op(iterator._iterator_resource) - return init_op, get_next, save_op, restore_op - - # Saving and restoring in different sessions. - start = 2 - stop = 10 - num_epochs = 5 - break_point = 5 - break_epoch = 3 - with ops.Graph().as_default() as g: - init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs) - with self.test_session(graph=g) as sess: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) - for _ in range(break_epoch): - for i in range(start, stop): - self.assertEqual(i, sess.run(get_next)) - for i in range(start, break_point): - self.assertEqual(i, sess.run(get_next)) - sess.run(save_op) - - with ops.Graph().as_default() as g: - # Create an empty IteratorResource and restore the Iterator into it. 
- output_types = dtypes.int64 - output_shapes = tensor_shape.scalar() - iterator = iterator_ops.Iterator.from_structure(output_types, - output_shapes) - restore_op = self._restore_op(iterator._iterator_resource) - get_next = iterator.get_next() - with self.test_session(graph=g) as sess: - sess.run(restore_op) - for i in range(break_point, stop): - self.assertEqual(i, sess.run(get_next)) - for _ in range(break_epoch + 1, num_epochs): - for i in range(start, stop): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testRestoreInModifiedGraph(self): + def _build_range_dataset(self, start, stop): + return dataset_ops.Dataset.range(start, stop) - def _build_graph(start, stop): - dataset = dataset_ops.Dataset.range(start, stop) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - save_op = self._save_op(iterator._iterator_resource) - restore_op = self._restore_op(iterator._iterator_resource) - return init_op, get_next, save_op, restore_op - - # Saving and restoring in different sessions. + def testRangeCore(self): start = 2 stop = 10 stop_1 = 8 - break_point = 5 - with ops.Graph().as_default() as g: - init_op, get_next, save_op, _ = _build_graph(start, stop) - with self.test_session(graph=g) as sess: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) - for i in range(start, break_point): - self.assertEqual(i, sess.run(get_next)) - sess.run(save_op) - - with ops.Graph().as_default() as g: - # Intentionally build a graph with a different value for stop to make sure - # the original dataset graph is actually getting loaded. 
- init_op, get_next, _, restore_op = _build_graph(start, stop_1) - with self.test_session(graph=g) as sess: - sess.run(restore_op) - for i in range(break_point, stop): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testInitThenRestore(self): - # Note: Calling init_op before restore_op is redundant. This test just makes - # sure we do not fail if restore is called on an already initialized - # iterator resource. - - def _build_graph(start, stop): - dataset = dataset_ops.Dataset.range(start, stop) - iterator = dataset.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - save_op = self._save_op(iterator._iterator_resource) - restore_op = self._restore_op(iterator._iterator_resource) - return init_op, get_next, save_op, restore_op - - # Saving and restoring in different sessions. - start = 2 - stop = 10 - break_point = 5 - with ops.Graph().as_default() as g: - init_op, get_next, save_op, _ = _build_graph(start, stop) - with self.test_session(graph=g) as sess: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) - for i in range(start, break_point): - self.assertEqual(i, sess.run(get_next)) - sess.run(save_op) - - with ops.Graph().as_default() as g: - init_op, get_next, _, restore_op = _build_graph(start, stop) - with self.test_session(graph=g) as sess: - sess.run(init_op) - sess.run(restore_op) - for i in range(break_point, stop): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testMultipleSaves(self): - - def _build_graph(start, stop): - iterator = dataset_ops.Dataset.range(start, - stop).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - save_op = self._save_op(iterator._iterator_resource) - restore_op = self._restore_op(iterator._iterator_resource) - return init_op, get_next, save_op, restore_op - - start = 2 - 
stop = 10 - break_point1 = 5 - break_point2 = 7 - - with ops.Graph().as_default() as g: - init_op, get_next, save_op, _ = _build_graph(start, stop) - with self.test_session(graph=g) as sess: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) - for i in range(start, break_point1): - self.assertEqual(i, sess.run(get_next)) - sess.run(save_op) - - with ops.Graph().as_default() as g: - init_op, get_next, save_op, restore_op = _build_graph(start, stop) - with self.test_session(graph=g) as sess: - sess.run(restore_op) - for i in range(break_point1, break_point2): - self.assertEqual(i, sess.run(get_next)) - sess.run(save_op) - - break_point2 = 7 - with ops.Graph().as_default() as g: - init_op, get_next, save_op, restore_op = _build_graph(start, stop) - with self.test_session(graph=g) as sess: - sess.run(restore_op) - for i in range(break_point2, stop): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testSaveRestoreWithRepeat(self): - - def _build_graph(start, stop, num_epochs): - iterator = dataset_ops.Dataset.range( - start, stop).repeat(num_epochs).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - save_op = self._save_op(iterator._iterator_resource) - restore_op = self._restore_op(iterator._iterator_resource) - return init_op, get_next, save_op, restore_op - - start = 2 - stop = 10 - num_epochs = 5 - break_range = 5 - break_epoch = 3 - with ops.Graph().as_default() as g: - init_op, get_next, save_op, restore_op = _build_graph( - start, stop, num_epochs) - with self.test_session(graph=g) as sess: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) - # Note: There is no checkpoint saved currently so a NotFoundError is - # raised. 
- with self.assertRaises(errors.NotFoundError): - sess.run(restore_op) - for _ in range(break_epoch - 1): - for i in range(start, stop): - self.assertEqual(i, sess.run(get_next)) - for i in range(start, break_range): - self.assertEqual(i, sess.run(get_next)) - sess.run(save_op) - - with ops.Graph().as_default() as g: - init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs) - with self.test_session(graph=g) as sess: - sess.run(restore_op) - for i in range(break_range, stop): - self.assertEqual(i, sess.run(get_next)) - for _ in range(break_epoch, num_epochs): - for i in range(start, stop): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testSaveRestoreExhaustedIterator(self): - - def _build_graph(start, stop, num_epochs): - iterator = dataset_ops.Dataset.range( - start, stop).repeat(num_epochs).make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - save_op = self._save_op(iterator._iterator_resource) - restore_op = self._restore_op(iterator._iterator_resource) - return init_op, get_next, save_op, restore_op - - start = 2 - stop = 10 - num_epochs = 5 - with ops.Graph().as_default() as g: - init_op, get_next, save_op, restore_op = _build_graph( - start, stop, num_epochs) - with self.test_session(graph=g) as sess: - sess.run(variables.global_variables_initializer()) - sess.run(init_op) - # Note: There is no checkpoint saved currently so a NotFoundError is - # raised. 
- with self.assertRaises(errors.NotFoundError): - sess.run(restore_op) - for _ in range(num_epochs): - for i in range(start, stop): - self.assertEqual(i, sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - sess.run(save_op) - - with ops.Graph().as_default() as g: - init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs) - with self.test_session(graph=g) as sess: - sess.run(restore_op) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) + self.run_core_tests(lambda: self._build_range_dataset(start, stop), + lambda: self._build_range_dataset(start, stop_1), + stop - start) if __name__ == "__main__": diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py index 1c42a3d855bc16c21e385d7108c3106884ae4f5e..6efe97444a375febc550ff3a3ea04bcd9330a3a5 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py @@ -26,6 +26,7 @@ from tensorflow.contrib.data.python.ops import readers from tensorflow.core.example import example_pb2 from tensorflow.core.example import feature_pb2 from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.data.ops import readers as core_readers from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -76,101 +77,12 @@ class TextLineDatasetTestBase(test.TestCase): return filenames -class TextLineDatasetTest(TextLineDatasetTestBase): - - def _testTextLineDataset(self, compression_type=None): - test_filenames = self._createFiles( - 2, 5, crlf=True, compression_type=compression_type) - filenames = array_ops.placeholder(dtypes.string, shape=[None]) - num_epochs = array_ops.placeholder(dtypes.int64, shape=[]) - batch_size = array_ops.placeholder(dtypes.int64, shape=[]) - - 
repeat_dataset = readers.TextLineDataset( - filenames, compression_type=compression_type).repeat(num_epochs) - batch_dataset = repeat_dataset.batch(batch_size) - - iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types) - init_op = iterator.make_initializer(repeat_dataset) - init_batch_op = iterator.make_initializer(batch_dataset) - get_next = iterator.get_next() - - with self.test_session() as sess: - # Basic test: read from file 0. - sess.run( - init_op, feed_dict={filenames: [test_filenames[0]], - num_epochs: 1}) - for i in range(5): - self.assertEqual(self._lineText(0, i), sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Basic test: read from file 1. - sess.run( - init_op, feed_dict={filenames: [test_filenames[1]], - num_epochs: 1}) - for i in range(5): - self.assertEqual(self._lineText(1, i), sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Basic test: read from both files. - sess.run(init_op, feed_dict={filenames: test_filenames, num_epochs: 1}) - for j in range(2): - for i in range(5): - self.assertEqual(self._lineText(j, i), sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Test repeated iteration through both files. - sess.run(init_op, feed_dict={filenames: test_filenames, num_epochs: 10}) - for _ in range(10): - for j in range(2): - for i in range(5): - self.assertEqual(self._lineText(j, i), sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Test batched and repeated iteration through both files. 
- sess.run( - init_batch_op, - feed_dict={filenames: test_filenames, - num_epochs: 10, - batch_size: 5}) - for _ in range(10): - self.assertAllEqual([self._lineText(0, i) for i in range(5)], - sess.run(get_next)) - self.assertAllEqual([self._lineText(1, i) for i in range(5)], - sess.run(get_next)) - - def testTextLineDatasetNoCompression(self): - self._testTextLineDataset() - - def testTextLineDatasetGzipCompression(self): - self._testTextLineDataset(compression_type="GZIP") - - def testTextLineDatasetZlibCompression(self): - self._testTextLineDataset(compression_type="ZLIB") - - def testTextLineDatasetBuffering(self): - test_filenames = self._createFiles(2, 5, crlf=True) - - repeat_dataset = readers.TextLineDataset(test_filenames, buffer_size=10) - iterator = repeat_dataset.make_one_shot_iterator() - - with self.test_session() as sess: - for j in range(2): - for i in range(5): - self.assertEqual(self._lineText(j, i), sess.run(iterator.get_next())) - with self.assertRaises(errors.OutOfRangeError): - sess.run(iterator.get_next()) - - class TextLineDatasetSerializationTest( TextLineDatasetTestBase, dataset_serialization_test_base.DatasetSerializationTestBase): def _build_iterator_graph(self, test_filenames, compression_type=None): - return readers.TextLineDataset( + return core_readers.TextLineDataset( test_filenames, compression_type=compression_type, buffer_size=10) def testTextLineCore(self): @@ -217,101 +129,13 @@ class FixedLengthRecordReaderTestBase(test.TestCase): return filenames -class FixedLengthRecordReaderTest(FixedLengthRecordReaderTestBase): - - def testFixedLengthRecordDataset(self): - test_filenames = self._createFiles() - filenames = array_ops.placeholder(dtypes.string, shape=[None]) - num_epochs = array_ops.placeholder(dtypes.int64, shape=[]) - batch_size = array_ops.placeholder(dtypes.int64, shape=[]) - - repeat_dataset = (readers.FixedLengthRecordDataset( - filenames, self._record_bytes, self._header_bytes, self._footer_bytes) - 
.repeat(num_epochs)) - batch_dataset = repeat_dataset.batch(batch_size) - - iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types) - init_op = iterator.make_initializer(repeat_dataset) - init_batch_op = iterator.make_initializer(batch_dataset) - get_next = iterator.get_next() - - with self.test_session() as sess: - # Basic test: read from file 0. - sess.run( - init_op, feed_dict={filenames: [test_filenames[0]], - num_epochs: 1}) - for i in range(self._num_records): - self.assertEqual(self._record(0, i), sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Basic test: read from file 1. - sess.run( - init_op, feed_dict={filenames: [test_filenames[1]], - num_epochs: 1}) - for i in range(self._num_records): - self.assertEqual(self._record(1, i), sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Basic test: read from both files. - sess.run(init_op, feed_dict={filenames: test_filenames, num_epochs: 1}) - for j in range(self._num_files): - for i in range(self._num_records): - self.assertEqual(self._record(j, i), sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Test repeated iteration through both files. - sess.run(init_op, feed_dict={filenames: test_filenames, num_epochs: 10}) - for _ in range(10): - for j in range(self._num_files): - for i in range(self._num_records): - self.assertEqual(self._record(j, i), sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Test batched and repeated iteration through both files. 
- sess.run( - init_batch_op, - feed_dict={ - filenames: test_filenames, - num_epochs: 10, - batch_size: self._num_records - }) - for _ in range(10): - for j in range(self._num_files): - self.assertAllEqual( - [self._record(j, i) for i in range(self._num_records)], - sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testFixedLengthRecordDatasetBuffering(self): - test_filenames = self._createFiles() - dataset = readers.FixedLengthRecordDataset( - test_filenames, - self._record_bytes, - self._header_bytes, - self._footer_bytes, - buffer_size=10) - iterator = dataset.make_one_shot_iterator() - - with self.test_session() as sess: - for j in range(self._num_files): - for i in range(self._num_records): - self.assertEqual(self._record(j, i), sess.run(iterator.get_next())) - with self.assertRaises(errors.OutOfRangeError): - sess.run(iterator.get_next()) - - class FixedLengthRecordDatasetSerializationTest( FixedLengthRecordReaderTestBase, dataset_serialization_test_base.DatasetSerializationTestBase): def _build_iterator_graph(self, num_epochs, compression_type=None): filenames = self._createFiles() - return readers.FixedLengthRecordDataset( + return core_readers.FixedLengthRecordDataset( filenames, self._record_bytes, self._header_bytes, self._footer_bytes).repeat(num_epochs) @@ -338,9 +162,8 @@ class TFRecordDatasetTestBase(test.TestCase): self.compression_type = array_ops.placeholder_with_default("", shape=[]) self.batch_size = array_ops.placeholder(dtypes.int64, shape=[]) - repeat_dataset = readers.TFRecordDataset(self.filenames, - self.compression_type).repeat( - self.num_epochs) + repeat_dataset = core_readers.TFRecordDataset( + self.filenames, self.compression_type).repeat(self.num_epochs) batch_dataset = repeat_dataset.batch(self.batch_size) iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types) @@ -363,129 +186,6 @@ class TFRecordDatasetTestBase(test.TestCase): return filenames -class 
TFRecordDatasetTest(TFRecordDatasetTestBase): - - def testReadOneEpoch(self): - with self.test_session() as sess: - # Basic test: read from file 0. - sess.run( - self.init_op, - feed_dict={ - self.filenames: [self.test_filenames[0]], - self.num_epochs: 1 - }) - for i in range(self._num_records): - self.assertAllEqual(self._record(0, i), sess.run(self.get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(self.get_next) - - # Basic test: read from file 1. - sess.run( - self.init_op, - feed_dict={ - self.filenames: [self.test_filenames[1]], - self.num_epochs: 1 - }) - for i in range(self._num_records): - self.assertAllEqual(self._record(1, i), sess.run(self.get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(self.get_next) - - # Basic test: read from both files. - sess.run( - self.init_op, - feed_dict={self.filenames: self.test_filenames, - self.num_epochs: 1}) - for j in range(self._num_files): - for i in range(self._num_records): - self.assertAllEqual(self._record(j, i), sess.run(self.get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(self.get_next) - - def testReadTenEpochs(self): - with self.test_session() as sess: - sess.run( - self.init_op, - feed_dict={self.filenames: self.test_filenames, - self.num_epochs: 10}) - for _ in range(10): - for j in range(self._num_files): - for i in range(self._num_records): - self.assertAllEqual(self._record(j, i), sess.run(self.get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(self.get_next) - - def testReadTenEpochsOfBatches(self): - with self.test_session() as sess: - sess.run( - self.init_batch_op, - feed_dict={ - self.filenames: self.test_filenames, - self.num_epochs: 10, - self.batch_size: self._num_records - }) - for _ in range(10): - for j in range(self._num_files): - values = sess.run(self.get_next) - self.assertAllEqual( - [self._record(j, i) for i in range(self._num_records)], values) - with self.assertRaises(errors.OutOfRangeError): - 
sess.run(self.get_next) - - def testReadZlibFiles(self): - zlib_files = [] - for i, fn in enumerate(self.test_filenames): - with open(fn, "rb") as f: - cdata = zlib.compress(f.read()) - - zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.z" % i) - with open(zfn, "wb") as f: - f.write(cdata) - zlib_files.append(zfn) - - with self.test_session() as sess: - sess.run( - self.init_op, - feed_dict={self.filenames: zlib_files, - self.compression_type: "ZLIB"}) - for j in range(self._num_files): - for i in range(self._num_records): - self.assertAllEqual(self._record(j, i), sess.run(self.get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(self.get_next) - - def testReadGzipFiles(self): - gzip_files = [] - for i, fn in enumerate(self.test_filenames): - with open(fn, "rb") as f: - gzfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % i) - with gzip.GzipFile(gzfn, "wb") as gzf: - gzf.write(f.read()) - gzip_files.append(gzfn) - - with self.test_session() as sess: - sess.run( - self.init_op, - feed_dict={self.filenames: gzip_files, - self.compression_type: "GZIP"}) - for j in range(self._num_files): - for i in range(self._num_records): - self.assertAllEqual(self._record(j, i), sess.run(self.get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(self.get_next) - - def testReadWithBuffer(self): - one_mebibyte = 2**20 - d = readers.TFRecordDataset(self.test_filenames, buffer_size=one_mebibyte) - iterator = d.make_one_shot_iterator() - with self.test_session() as sess: - for j in range(self._num_files): - for i in range(self._num_records): - self.assertAllEqual(self._record(j, i), sess.run(iterator.get_next())) - with self.assertRaises(errors.OutOfRangeError): - sess.run(iterator.get_next()) - - class TFRecordDatasetSerializationTest( TFRecordDatasetTestBase, dataset_serialization_test_base.DatasetSerializationTestBase): @@ -517,7 +217,7 @@ class TFRecordDatasetSerializationTest( gzip_files.append(gzfn) filenames = gzip_files - return 
readers.TFRecordDataset( + return core_readers.TFRecordDataset( filenames, compression_type, buffer_size=buffer_size).repeat(num_epochs).batch(batch_size) @@ -575,7 +275,7 @@ class ReadBatchFeaturesTest(test.TestCase): "record": parsing_ops.FixedLenFeature([], dtypes.int64), "keywords": parsing_ops.VarLenFeature(dtypes.string) }, - reader=readers.TFRecordDataset, + reader=core_readers.TFRecordDataset, randomize_input=False, num_epochs=self.num_epochs) @@ -714,12 +414,11 @@ class ReadBatchFeaturesTest(test.TestCase): self._next_actual_batch(sess) def testReadWithEquivalentDataset(self): - # TODO(mrry): Add support for tf.SparseTensor as a Dataset component. features = { "file": parsing_ops.FixedLenFeature([], dtypes.int64), "record": parsing_ops.FixedLenFeature([], dtypes.int64), } - dataset = (readers.TFRecordDataset(self.test_filenames) + dataset = (core_readers.TFRecordDataset(self.test_filenames) .map(lambda x: parsing_ops.parse_single_example(x, features)) .repeat(10).batch(2)) iterator = dataset.make_initializable_iterator() diff --git a/tensorflow/contrib/data/python/kernel_tests/resample_test.py b/tensorflow/contrib/data/python/kernel_tests/resample_test.py index 0ac8d7359f7234d98167277724780bf31555e6fb..3c7b46629edb13459766b5ef3f392e8d00ad4db8 100644 --- a/tensorflow/contrib/data/python/kernel_tests/resample_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/resample_test.py @@ -19,8 +19,8 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.contrib.data.python.ops import resampling +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import errors from tensorflow.python.ops import string_ops from tensorflow.python.platform import test diff --git a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py index 
5338ec56bf275e481a984964e39aa0c1ade3a752..e0494736b72ae52f586cb80d42a5c1e50ac17a61 100644 --- a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py @@ -21,6 +21,7 @@ import itertools import numpy as np +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import scan_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import constant_op @@ -124,5 +125,18 @@ class ScanDatasetTest(test.TestCase): scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn)) +class ScanDatasetSerialzationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset(self, num_elements): + return dataset_ops.Dataset.from_tensors(1).repeat(num_elements).apply( + scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1]))) + + def testScanCore(self): + num_output = 5 + self.run_core_tests(lambda: self._build_dataset(num_output), + lambda: self._build_dataset(2), num_output) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py index 1a26da82e533ec01106ea10525c1cd96627c34fb..36ddf3004237ed042f21d691d83eafbaa20621e6 100644 --- a/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/sequence_dataset_op_test.py @@ -20,194 +20,10 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base -from tensorflow.contrib.data.python.ops import dataset_ops -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.ops import array_ops +from tensorflow.python.data.ops import dataset_ops from 
tensorflow.python.platform import test -class SequenceDatasetTest(test.TestCase): - - def testRepeatTensorDataset(self): - """Test a dataset that repeats its input multiple times.""" - components = (np.array(1), np.array([1, 2, 3]), np.array(37.0)) - # This placeholder can be fed when dataset-definition subgraph - # runs (i.e. `init_op` below) to configure the number of - # repetitions used in a particular iterator. - count_placeholder = array_ops.placeholder(dtypes.int64, shape=[]) - - iterator = (dataset_ops.Dataset.from_tensors(components) - .repeat(count_placeholder).make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - self.assertEqual([c.shape for c in components], - [t.shape for t in get_next]) - - with self.test_session() as sess: - # Test a finite repetition. - sess.run(init_op, feed_dict={count_placeholder: 3}) - for _ in range(3): - results = sess.run(get_next) - for component, result_component in zip(components, results): - self.assertAllEqual(component, result_component) - - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Test a different finite repetition. - sess.run(init_op, feed_dict={count_placeholder: 7}) - for _ in range(7): - results = sess.run(get_next) - for component, result_component in zip(components, results): - self.assertAllEqual(component, result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Test an empty repetition. - sess.run(init_op, feed_dict={count_placeholder: 0}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Test an infinite repetition. - # NOTE(mrry): There's not a good way to test that the sequence - # actually is infinite. 
- sess.run(init_op, feed_dict={count_placeholder: -1}) - for _ in range(17): - results = sess.run(get_next) - for component, result_component in zip(components, results): - self.assertAllEqual(component, result_component) - - def testTakeTensorDataset(self): - components = (np.arange(10),) - count_placeholder = array_ops.placeholder(dtypes.int64, shape=[]) - - iterator = (dataset_ops.Dataset.from_tensor_slices(components) - .take(count_placeholder).make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - self.assertEqual([c.shape[1:] for c in components], - [t.shape for t in get_next]) - - with self.test_session() as sess: - # Take fewer than input size - sess.run(init_op, feed_dict={count_placeholder: 4}) - for i in range(4): - results = sess.run(get_next) - self.assertAllEqual(results, components[0][i:i+1]) - - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Take more than input size - sess.run(init_op, feed_dict={count_placeholder: 25}) - for i in range(10): - results = sess.run(get_next) - self.assertAllEqual(results, components[0][i:i+1]) - - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Take all of input - sess.run(init_op, feed_dict={count_placeholder: -1}) - for i in range(10): - results = sess.run(get_next) - self.assertAllEqual(results, components[0][i:i+1]) - - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Take nothing - sess.run(init_op, feed_dict={count_placeholder: 0}) - - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testSkipTensorDataset(self): - components = (np.arange(10),) - count_placeholder = array_ops.placeholder(dtypes.int64, shape=[]) - - iterator = (dataset_ops.Dataset.from_tensor_slices(components) - .skip(count_placeholder).make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - self.assertEqual([c.shape[1:] for c in components], - 
[t.shape for t in get_next]) - - with self.test_session() as sess: - # Skip fewer than input size, we should skip - # the first 4 elements and then read the rest. - sess.run(init_op, feed_dict={count_placeholder: 4}) - for i in range(4, 10): - results = sess.run(get_next) - self.assertAllEqual(results, components[0][i:i+1]) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Skip more than input size: get nothing. - sess.run(init_op, feed_dict={count_placeholder: 25}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Skip exactly input size. - sess.run(init_op, feed_dict={count_placeholder: 10}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Set -1 for 'count': skip the entire dataset. - sess.run(init_op, feed_dict={count_placeholder: -1}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - # Skip nothing - sess.run(init_op, feed_dict={count_placeholder: 0}) - for i in range(0, 10): - results = sess.run(get_next) - self.assertAllEqual(results, components[0][i:i+1]) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testRepeatRepeatTensorDataset(self): - """Test the composition of repeat datasets.""" - components = (np.array(1), np.array([1, 2, 3]), np.array(37.0)) - inner_count = array_ops.placeholder(dtypes.int64, shape=[]) - outer_count = array_ops.placeholder(dtypes.int64, shape=[]) - - iterator = (dataset_ops.Dataset.from_tensors(components).repeat(inner_count) - .repeat(outer_count).make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - self.assertEqual([c.shape for c in components], - [t.shape for t in get_next]) - - with self.test_session() as sess: - sess.run(init_op, feed_dict={inner_count: 7, outer_count: 14}) - for _ in range(7 * 14): - results = sess.run(get_next) - for component, result_component in zip(components, results): - self.assertAllEqual(component, result_component) - 
with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testRepeatEmptyDataset(self): - """Test that repeating an empty dataset does not hang.""" - iterator = (dataset_ops.Dataset.from_tensors(0).repeat(10).skip(10) - .repeat(-1).make_initializable_iterator()) - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - sess.run(init_op) - with self.assertRaisesRegexp( - errors.OutOfRangeError, - "Attempted to repeat an empty dataset infinitely."): - sess.run(get_next) - - class SequenceDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization_integration_test.py b/tensorflow/contrib/data/python/kernel_tests/serialization_integration_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0a6b74dc3eb80a6168117beed06935737198cecb --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/serialization_integration_test.py @@ -0,0 +1,85 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Integration test for input pipeline serialization.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import ops +from tensorflow.python.platform import test +from tensorflow.python.training import saver as saver_lib + + +class MultipleInputPipelinesTest(test.TestCase): + + def _build_input_pipeline(self, name, num_outputs): + with ops.name_scope(name): + ds = dataset_ops.Dataset.range(num_outputs).shuffle( + 10, reshuffle_each_iteration=False).prefetch(10) + iterator = ds.make_initializable_iterator() + saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + return iterator.initializer, iterator.get_next() + + def _build_graph(self, num_pipelines, num_outputs): + init_ops = [] + get_next_ops = [] + for i in range(num_pipelines): + name = "input_pipeline_%d" % i + init_op, get_next_op = self._build_input_pipeline(name, num_outputs) + init_ops.append(init_op) + get_next_ops.append(get_next_op) + saver = saver_lib.Saver() + return init_ops, get_next_ops, saver + + def _ckpt_path(self): + return os.path.join(self.get_temp_dir(), "iterator") + + def testConcurrentSaves(self): + num_pipelines = 100 + num_outputs = 100 + break_point = 10 + all_outputs = [[] for _ in range(num_pipelines)] + with ops.Graph().as_default() as g: + init_ops, get_next_ops, saver = self._build_graph(num_pipelines, + num_outputs) + with self.test_session(graph=g) as sess: + sess.run(init_ops) + for _ in range(break_point): + output = sess.run(get_next_ops) + for i in range(num_pipelines): + all_outputs[i].append(output[i]) + saver.save(sess, self._ckpt_path()) + + with 
ops.Graph().as_default() as g: + init_ops, get_next_ops, saver = self._build_graph(num_pipelines, + num_outputs) + with self.test_session(graph=g) as sess: + saver.restore(sess, self._ckpt_path()) + for _ in range(num_outputs - break_point): + output = sess.run(get_next_ops) + for i in range(num_pipelines): + all_outputs[i].append(output[i]) + + for output in all_outputs: + self.assertSequenceEqual(sorted(output), range(num_outputs)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/shard_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/shard_dataset_op_test.py deleted file mode 100644 index 0b3c32c06eb1d69244c9a02ca4ba571769f13f40..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/data/python/kernel_tests/shard_dataset_op_test.py +++ /dev/null @@ -1,111 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests for the experimental input pipeline ops.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.data.python.ops import dataset_ops -from tensorflow.python.framework import errors -from tensorflow.python.platform import test - - -class ShardDatasetOpTest(test.TestCase): - - def testSimpleCase(self): - dataset = dataset_ops.Dataset.range(10).shard(5, 2) - iterator = dataset.make_one_shot_iterator() - - with self.test_session() as sess: - self.assertEqual(2, sess.run(iterator.get_next())) - self.assertEqual(7, sess.run(iterator.get_next())) - with self.assertRaises(errors.OutOfRangeError): - sess.run(iterator.get_next()) - - def testNestedData(self): - dataset_a = dataset_ops.Dataset.range(10) - dataset_b = dataset_ops.Dataset.range(10, 0, -1) - dataset = dataset_ops.Dataset.zip((dataset_a, dataset_b)).shard(5, 2) - iterator = dataset.make_one_shot_iterator() - - with self.test_session() as sess: - self.assertEqual((2, 8), sess.run(iterator.get_next())) - self.assertEqual((7, 3), sess.run(iterator.get_next())) - with self.assertRaises(errors.OutOfRangeError): - sess.run(iterator.get_next()) - - def testOffsetZero(self): - dataset = dataset_ops.Dataset.range(10).shard(5, 0) - iterator = dataset.make_one_shot_iterator() - - with self.test_session() as sess: - self.assertEqual(0, sess.run(iterator.get_next())) - self.assertEqual(5, sess.run(iterator.get_next())) - with self.assertRaises(errors.OutOfRangeError): - sess.run(iterator.get_next()) - - def testOffsetGreaterNumShards(self): - with self.assertRaises(ValueError): - dataset_ops.Dataset.range(10).shard(5, 7) - - def testNegativeOffset(self): - with self.assertRaises(ValueError): - dataset_ops.Dataset.range(10).shard(5, -3) - - def testNegativeNumShards(self): - with self.assertRaises(ValueError): - dataset_ops.Dataset.range(10).shard(-3, 
1) - - def testZeroNumShards(self): - with self.assertRaises(ValueError): - dataset_ops.Dataset.range(10).shard(0, 1) - - def testIteratorEndsBeforeFirstElem(self): - dataset = dataset_ops.Dataset.range(1).shard(5, 2) - iterator = dataset.make_one_shot_iterator() - - with self.test_session() as sess: - with self.assertRaises(errors.OutOfRangeError): - sess.run(iterator.get_next()) - - def testLargerWorkerPool(self): - dataset = dataset_ops.Dataset.range(10).shard(7, 5) - iterator = dataset.make_one_shot_iterator() - with self.test_session() as sess: - self.assertEqual(5, sess.run(iterator.get_next())) - with self.assertRaises(errors.OutOfRangeError): - sess.run(iterator.get_next()) - - def testIndexEqualsNumShards(self): - dataset = dataset_ops.Dataset.range(10).shard(5, 4) - iterator = dataset.make_one_shot_iterator() - with self.test_session() as sess: - self.assertEqual(4, sess.run(iterator.get_next())) - self.assertEqual(9, sess.run(iterator.get_next())) - with self.assertRaises(errors.OutOfRangeError): - sess.run(iterator.get_next()) - - def testIndexEqualsNumShards2(self): - dataset = dataset_ops.Dataset.range(10).shard(4, 3) - iterator = dataset.make_one_shot_iterator() - with self.test_session() as sess: - self.assertEqual(3, sess.run(iterator.get_next())) - self.assertEqual(7, sess.run(iterator.get_next())) - with self.assertRaises(errors.OutOfRangeError): - sess.run(iterator.get_next()) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py index 6b5b53cc0f8f2d1df5622a5bc5e2f8ef04c6342a..bcc644c0971854d948025009dc7add2fea214048 100644 --- a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py @@ -17,461 +17,145 @@ from __future__ import absolute_import from __future__ import division from __future__ import 
print_function -import collections -import os - import numpy as np -from tensorflow.contrib.data.python.ops import dataset_ops as contrib_dataset_ops -from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import shuffle_ops from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.ops import iterator_ops -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.platform import gfile from tensorflow.python.platform import test -from tensorflow.python.training import saver as saver_lib - - -class ShuffleDatasetTest(test.TestCase): - - def testShuffleDataset(self): - components = ( - np.array([1, 2, 3, 4]), np.array([5, 6, 7, 8]), - np.array([9.0, 10.0, 11.0, 12.0]) - ) - count_placeholder = array_ops.placeholder_with_default( - constant_op.constant(5, dtypes.int64), shape=[]) - buffer_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[]) - seed_placeholder = array_ops.placeholder(dtypes.int64, shape=[]) - - repeat_dataset = ( - contrib_dataset_ops.Dataset.from_tensor_slices(components) - .repeat(count_placeholder)) - - shuffle_dataset = repeat_dataset.shuffle(buffer_size_placeholder, - seed_placeholder) - - self.assertEqual(tuple([c.shape[1:] for c in components]), - shuffle_dataset.output_shapes) - - # Create initialization ops for iterators without and with - # shuffling, respectively. 
- iterator = iterator_ops.Iterator.from_structure( - shuffle_dataset.output_types, shuffle_dataset.output_shapes) - init_fifo_op = iterator.make_initializer(repeat_dataset) - init_shuffle_op = iterator.make_initializer(shuffle_dataset) - - get_next = iterator.get_next() - with self.test_session() as sess: - # First run without shuffling to collect the "ground truth". - sess.run(init_fifo_op) - unshuffled_elements = [] - for _ in range(20): - unshuffled_elements.append(sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - # Assert that the shuffled dataset has the same elements as the - # "ground truth". - sess.run( - init_shuffle_op, - feed_dict={buffer_size_placeholder: 100, - seed_placeholder: 37}) - shuffled_elements = [] - for _ in range(20): - shuffled_elements.append(sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - self.assertAllEqual( - sorted(unshuffled_elements), sorted(shuffled_elements)) - - # Assert that shuffling twice with the same seeds gives the same sequence. - sess.run( - init_shuffle_op, - feed_dict={buffer_size_placeholder: 100, - seed_placeholder: 37}) - reshuffled_elements_same_seed = [] - for _ in range(20): - reshuffled_elements_same_seed.append(sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - self.assertEqual(shuffled_elements, reshuffled_elements_same_seed) - - # Assert that shuffling twice with a different seed gives a different - # permutation of the same elements. 
- sess.run( - init_shuffle_op, - feed_dict={buffer_size_placeholder: 100, - seed_placeholder: 1037}) - reshuffled_elements_different_seed = [] - for _ in range(20): - reshuffled_elements_different_seed.append(sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - self.assertNotEqual(shuffled_elements, reshuffled_elements_different_seed) - self.assertAllEqual( - sorted(shuffled_elements), sorted(reshuffled_elements_different_seed)) - - # Assert that the shuffled dataset has the same elements as the - # "ground truth" when the buffer size is smaller than the input - # dataset. - sess.run( - init_shuffle_op, - feed_dict={buffer_size_placeholder: 2, - seed_placeholder: 37}) - reshuffled_elements_small_buffer = [] - for _ in range(20): - reshuffled_elements_small_buffer.append(sess.run(get_next)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - self.assertAllEqual( - sorted(unshuffled_elements), sorted(reshuffled_elements_small_buffer)) - - # Test the case of shuffling an empty dataset. - sess.run(init_shuffle_op, feed_dict={buffer_size_placeholder: 2, - seed_placeholder: 37, - count_placeholder: 0}) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testDefaultArguments(self): - components = [0, 1, 2, 3, 4] - iterator = ( - contrib_dataset_ops.Dataset.from_tensor_slices(components).shuffle(5) - .repeat().make_one_shot_iterator()) - - get_next = iterator.get_next() - - with self.test_session() as sess: - counts = collections.defaultdict(lambda: 0) - for _ in range(10): - for _ in range(5): - counts[sess.run(get_next)] += 1 - - for i in range(5): - self.assertEqual(10, counts[i]) +class ShuffleDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): - -class ShuffleDatasetSerializationTest(test.TestCase): - - def tearDown(self): - # Remove all checkpoint files. 
- prefix = self._ckpt_path() - pattern = prefix + "*" - files = gfile.Glob(pattern) - map(gfile.Remove, files) - - def _build_graph(self, - range_limit=10, - num_repeats=5, - buffer_size=5, - seed=None, - reshuffle_each_iteration=None, - build_saveable=True): - iterator = dataset_ops.Dataset.range(range_limit).shuffle( + def _build_shuffle_dataset( + self, + range_limit=10, + num_repeats=5, + buffer_size=5, + seed=None, + reshuffle_each_iteration=None, + ): + return dataset_ops.Dataset.range(range_limit).shuffle( buffer_size, seed=seed, - reshuffle_each_iteration=reshuffle_each_iteration).repeat( - num_repeats).make_initializable_iterator() - if build_saveable: - saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) - ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) - init_op = iterator.initializer - get_next = iterator.get_next() - ops.add_to_collection("iterator_ops", init_op) - ops.add_to_collection("iterator_ops", get_next) - saver = saver_lib.Saver(allow_empty=True) - return init_op, get_next, saver - - def _ckpt_path(self): - return os.path.join(self.get_temp_dir(), "iterator") - - def _latest_ckpt(self): - return saver_lib.latest_checkpoint(self.get_temp_dir()) - - def _save(self, sess, saver): - saver.save(sess, self._ckpt_path()) - - def _restore(self, saver, sess): - saver.restore(sess, self._latest_ckpt()) - - def _import_meta_graph(self): - meta_file_path = self._ckpt_path() + ".meta" - return saver_lib.import_meta_graph(meta_file_path) - - def _testReadWithBreaks(self, break_points, init_before_restore=False): - seed = 55 - range_limit = 10 - num_repeats = 5 - num_outputs = range_limit * num_repeats - buffer_sizes = [1, 3, 8, 10, 25, 50] - reshuffle_each_iteration = False - for buffer_size in buffer_sizes: - expected = [] - actual = [] - # Generate the ground truth. 
- with ops.Graph().as_default() as g: - g.seed = 10 - init_op, get_next_op, _ = self._build_graph( - range_limit=range_limit, - num_repeats=num_repeats, - buffer_size=buffer_size, - seed=seed, - reshuffle_each_iteration=reshuffle_each_iteration) - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(num_outputs): - expected.append(sess.run(get_next_op)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - # Run and checkpoint after first break_point. - with ops.Graph().as_default() as g: - g.seed = 10 - init_op, get_next_op, saver = self._build_graph( - range_limit=range_limit, - num_repeats=num_repeats, - buffer_size=buffer_size, - seed=seed, - reshuffle_each_iteration=reshuffle_each_iteration) - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(break_points[0]): - actual.append(sess.run(get_next_op)) - self._save(sess, saver) + reshuffle_each_iteration=reshuffle_each_iteration).repeat(num_repeats) - # Load from checkpoint and continue running while stopping at each - # subsequent checkpoint. 
- for i in range(len(break_points)): - with ops.Graph().as_default() as g: - saver = self._import_meta_graph() - init_op, get_next_op = ops.get_collection("iterator_ops") - with self.test_session(graph=g) as sess: - if init_before_restore: - sess.run(init_op) - self._restore(saver, sess) - start = break_points[i] - end = break_points[ - i + 1] if i < len(break_points) - 1 else num_outputs - for _ in range(end - start): - actual.append(sess.run(get_next_op)) - self._save(sess, saver) - if end == num_outputs: - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - self.assertEqual(expected, actual) + def testShuffleCore(self): - def testSaveRestore(self): - self._testReadWithBreaks([8]) # rng buffer_size: 0 - self._testReadWithBreaks([13]) # rng buffer_size: 1 - self._testReadWithBreaks([18]) # rng buffer_size: 2 - self._testReadWithBreaks([23]) # rng buffer_size: 3 - - def testSaveUnusedIterator(self): - self._testReadWithBreaks([0]) - - def testSaveFullyUsedIterator(self): - self._testReadWithBreaks([50]) - - def testMultipleBreaks(self): - self._testReadWithBreaks([0, 5, 9, 15, 25, 32]) - - def testIdempotence(self): - # Attempt to save iterator immediately after restoring. 
- self._testReadWithBreaks([1, 1, 5, 5, 5, 25, 32]) - - def testInitThenRestore(self): - self._testReadWithBreaks([0, 5, 9, 15, 25, 32], init_before_restore=True) - - def testRestoreExhaustedIterator(self): - seed = 55 - range_limit = 10 - num_repeats = 5 - num_outputs = range_limit * num_repeats - buffer_sizes = [1, 3, 8, 10, 25, 50] - reshuffle_each_iteration = False - for buffer_size in buffer_sizes: - with ops.Graph().as_default() as g: - g.seed = 10 - init_op, get_next_op, saver = self._build_graph( - range_limit=range_limit, - num_repeats=num_repeats, - buffer_size=buffer_size, - seed=seed, - reshuffle_each_iteration=reshuffle_each_iteration) - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(num_outputs): - sess.run(get_next_op) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - self._save(sess, saver) - - with ops.Graph().as_default() as g: - saver = self._import_meta_graph() - init_op, get_next_op = ops.get_collection("iterator_ops") - with self.test_session(graph=g) as sess: - self._restore(saver, sess) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - def testResetRestoredIterator(self): seed = 55 range_limit = 10 num_repeats = 5 num_outputs = range_limit * num_repeats buffer_sizes = [1, 3, 8, 10, 25, 50] reshuffle_each_iteration = False + # pylint: disable=cell-var-from-loop + # pylint: disable=g-long-lambda for buffer_size in buffer_sizes: - with ops.Graph().as_default() as g: - g.seed = 10 - init_op, get_next_op, saver = self._build_graph( - range_limit=range_limit, - num_repeats=num_repeats, - buffer_size=buffer_size, - seed=seed, - reshuffle_each_iteration=reshuffle_each_iteration) - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(num_outputs // 2): - sess.run(get_next_op) - self._save(sess, saver) - - outputs = [] - with ops.Graph().as_default() as g: - saver = self._import_meta_graph() - init_op, get_next_op = 
ops.get_collection("iterator_ops") - with self.test_session(graph=g) as sess: - self._restore(saver, sess) - sess.run(init_op) - for _ in range(num_outputs): - outputs.append(sess.run(get_next_op)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - expected_outputs_sorted = sorted( - np.array([range(range_limit) - for _ in range(num_repeats)]).flatten()) - self.assertEqual(expected_outputs_sorted, sorted(outputs)) - - def testRestoreInModifiedGraph(self): - seed = 55 - break_point = 25 - range_limit = 10 - num_repeats = 5 - num_outputs = range_limit * num_repeats - buffer_sizes = [3, 8, 10, 25, 50] - reshuffle_each_iteration = False - for buffer_size in buffer_sizes: - expected = [] - actual_without_restore = [] - actual = [] - with ops.Graph().as_default() as g: - g.seed = 10 - init_op, get_next_op, saver = self._build_graph( - range_limit=range_limit, - num_repeats=num_repeats, - buffer_size=buffer_size, - seed=seed, - reshuffle_each_iteration=reshuffle_each_iteration) - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(break_point): - expected.append(sess.run(get_next_op)) - actual.extend(expected) - self._save(sess, saver) - for _ in range(num_outputs - break_point): - expected.append(sess.run(get_next_op)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - with ops.Graph().as_default() as g: - g.seed = 20 # Different seed than previous graph for shuffle rngs. - init_op, get_next_op, saver = self._build_graph( - range_limit=range_limit, - num_repeats=num_repeats, - buffer_size=buffer_size, - seed=seed, - reshuffle_each_iteration=reshuffle_each_iteration) - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(num_outputs): - actual_without_restore.append(sess.run(get_next_op)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - with ops.Graph().as_default() as g: - g.seed = 20 # Different seed than previous graph for shuffle rngs. 
- init_op, get_next_op, saver = self._build_graph( - range_limit=range_limit, - num_repeats=num_repeats, - buffer_size=buffer_size, - seed=seed, - reshuffle_each_iteration=reshuffle_each_iteration) - with self.test_session(graph=g) as sess: - self._restore(saver, sess) - for _ in range(num_outputs - break_point): - actual.append(sess.run(get_next_op)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - - # Since the modified graph has a different random seed it produces a - # different order of examples. - self.assertNotEqual(expected, actual_without_restore) - self.assertEqual(sorted(expected), sorted(actual_without_restore)) - self.assertEqual(expected, actual) - - def testDoNotBuildSaveable(self): - seed = 55 - break_point = 25 - range_limit = 10 - num_repeats = 5 - num_outputs = range_limit * num_repeats - buffer_sizes = [3, 8, 10, 25, 50] - reshuffle_each_iteration = False - for buffer_size in buffer_sizes: - actual = [] - with ops.Graph().as_default() as g: - g.seed = 10 - init_op, get_next_op, saver = self._build_graph( - range_limit=range_limit, - num_repeats=num_repeats, - buffer_size=buffer_size, - seed=seed, - reshuffle_each_iteration=reshuffle_each_iteration) - with self.test_session(graph=g) as sess: - sess.run(init_op) - for _ in range(break_point): - sess.run(get_next_op) - self._save(sess, saver) - - with ops.Graph().as_default() as g: - g.seed = 20 # Different seed than previous graph for shuffle rngs. - init_op, get_next_op, saver = self._build_graph( - range_limit=range_limit, - num_repeats=num_repeats, - buffer_size=buffer_size, - seed=seed, - reshuffle_each_iteration=reshuffle_each_iteration, - build_saveable=False) - with self.test_session(graph=g) as sess: - # Since the SaveableObject was not added to Saver's list - # of saveables, iterator state is not restored by saver.restore(). 
- self._restore(saver, sess) - with self.assertRaises(errors.FailedPreconditionError): - sess.run(get_next_op) - sess.run(init_op) - for _ in range(num_outputs): - actual.append(sess.run(get_next_op)) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next_op) - expected_outputs_sorted = sorted( - np.array([range(range_limit) for _ in range(num_repeats)]).flatten()) - self.assertEqual(expected_outputs_sorted, sorted(actual)) + self.run_core_tests( + lambda: self._build_shuffle_dataset( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=seed, + reshuffle_each_iteration=reshuffle_each_iteration), + lambda: self._build_shuffle_dataset( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=10, + reshuffle_each_iteration=reshuffle_each_iteration), + num_outputs) + # pylint: enable=cell-var-from-loop + # pylint: enable=g-long-lambda + + +class ShuffleAndRepeatTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_ds(self, seed, count=5, num_elements=20): + return dataset_ops.Dataset.range(num_elements).apply( + shuffle_ops.shuffle_and_repeat(buffer_size=5, count=count, seed=seed)) + + def testCorrectOutput(self): + output = self.gen_outputs(lambda: self._build_ds(10), [], 100) + self.assertSequenceEqual( + sorted(output), sorted( + np.array([range(20) for _ in range(5)]).flatten())) + for i in range(5): + self.assertSequenceEqual(sorted(output[i * 20:(i + 1) * 20]), range(20)) + + def testReshuffling(self): + # Check that the output orders of different epochs are indeed different. 
+ output = self.gen_outputs(lambda: self._build_ds(10), [], 100) + for i in range(4): + epoch1 = output[i * 20:(i + 1) * 20] + epoch2 = output[(i + 1) * 20:(i + 2) * 20] + self.assertNotEqual(epoch1, epoch2) + + def testSameOrderForSameSeeds(self): + output1 = self.gen_outputs(lambda: self._build_ds(10), [], 100) + output2 = self.gen_outputs(lambda: self._build_ds(10), [], 100) + self.assertEqual(output1, output2) + + def testDifferentOrderForDifferentSeeds(self): + output1 = self.gen_outputs(lambda: self._build_ds(10), [], 100) + output2 = self.gen_outputs(lambda: self._build_ds(20), [], 100) + self.assertNotEqual(output1, output2) + self.assertEqual(sorted(output1), sorted(output2)) + + def testCountNone(self): + output1 = self.gen_outputs( + lambda: self._build_ds(10, count=None), [], 100, verify_exhausted=False) + output2 = self.gen_outputs( + lambda: self._build_ds(20, count=None), [], 100, verify_exhausted=False) + self.assertNotEqual(output1, output2) + self.assertEqual(sorted(output1), sorted(output2)) + + def testCountMinusOne(self): + output1 = self.gen_outputs( + lambda: self._build_ds(10, count=-1), [], 100, verify_exhausted=False) + output2 = self.gen_outputs( + lambda: self._build_ds(20, count=-1), [], 100, verify_exhausted=False) + self.assertNotEqual(output1, output2) + self.assertEqual(sorted(output1), sorted(output2)) + + def testInfiniteOutputs(self): + # Asserting the iterator is exhausted after producing 100 items should fail. 
+ with self.assertRaises(AssertionError): + self.gen_outputs(lambda: self._build_ds(10, count=None), [], 100) + with self.assertRaises(AssertionError): + self.gen_outputs(lambda: self._build_ds(10, count=-1), [], 100) + + def testInfiniteEmpty(self): + with self.assertRaises(errors.OutOfRangeError): + self.gen_outputs(lambda: self._build_ds(10, count=None, num_elements=0), + [], 100) + with self.assertRaises(errors.OutOfRangeError): + self.gen_outputs(lambda: self._build_ds(10, count=-1, num_elements=0), [], + 100) + + def testLargeBufferSize(self): + with ops.Graph().as_default() as g: + ds = dataset_ops.Dataset.range(20).apply( + shuffle_ops.shuffle_and_repeat(buffer_size=21)) + get_next_op = ds.make_one_shot_iterator().get_next() + with self.test_session(graph=g) as sess: + sess.run(get_next_op) + + +class ShuffleAndRepeatSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_ds(self, seed): + return dataset_ops.Dataset.range(20).apply( + shuffle_ops.shuffle_and_repeat(buffer_size=5, count=5, seed=seed)) + + def testCore(self): + self.run_core_tests(lambda: self._build_ds(10), lambda: self._build_ds(20), + 100) if __name__ == "__main__": diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py index efd864f866611bfd3bac1edcf98d84be852410fd..e26cef8ec522c7e69a0c19b2b30a969bbfc0ad78 100644 --- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function import os + import sqlite3 from tensorflow.contrib.data.python.ops import readers diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py index 
8f24d6b2f612cff662aa8a36085bc69a9ea1a290..07bdf920446e953c2a1abaf495d2e9e1256106fd 100644 --- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base from tensorflow.contrib.data.python.ops import stats_ops from tensorflow.core.framework import summary_pb2 from tensorflow.python.data.ops import dataset_ops @@ -209,5 +210,48 @@ class StatsDatasetTest(test.TestCase): sess.run(stats_aggregator_1.subscribe(iterator)) +class StatsDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def _build_dataset_bytes_stats(self, num_elements): + return dataset_ops.Dataset.range(num_elements).map( + lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply( + stats_ops.bytes_produced_stats("bytes_produced")) + + def testBytesStatsDatasetSaveableCore(self): + num_outputs = 100 + self.run_core_tests( + lambda: self._build_dataset_bytes_stats(num_outputs), + lambda: self._build_dataset_bytes_stats(num_outputs // 10), num_outputs) + + def _build_dataset_latency_stats(self, num_elements, tag="record_latency"): + return dataset_ops.Dataset.range(num_elements).apply( + stats_ops.latency_stats(tag)) + + def _build_dataset_multiple_tags(self, + num_elements, + tag1="record_latency", + tag2="record_latency_2"): + return dataset_ops.Dataset.range(num_elements).apply( + stats_ops.latency_stats(tag1)).apply(stats_ops.latency_stats(tag2)) + + def testLatencyStatsDatasetSaveableCore(self): + num_outputs = 100 + + self.run_core_tests( + lambda: self._build_dataset_latency_stats(num_outputs), + lambda: self._build_dataset_latency_stats(num_outputs // 10), + num_outputs) + + self.run_core_tests(lambda: self._build_dataset_multiple_tags(num_outputs), + None, num_outputs) + + tag1 = "record_latency" + 
tag2 = "record_latency" + self.run_core_tests( + lambda: self._build_dataset_multiple_tags(num_outputs, tag1, tag2), + None, num_outputs) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3c436f7a0b45a13109960e87dd97ca56b10bb871 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py @@ -0,0 +1,96 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import unique +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.platform import test +from tensorflow.python.util import compat + + +class UniqueDatasetTest(test.TestCase): + + def _testSimpleHelper(self, dtype, test_cases): + """Test the `unique()` transformation on a list of test cases. 
+ + Args: + dtype: The `dtype` of the elements in each test case. + test_cases: A list of pairs of lists. The first component is the test + input that will be passed to the transformation; the second component + is the expected sequence of outputs from the transformation. + """ + + # The `current_test_case` will be updated when we loop over `test_cases` + # below; declare it here so that the generator can capture it once. + current_test_case = [] + dataset = dataset_ops.Dataset.from_generator(lambda: current_test_case, + dtype).apply(unique.unique()) + iterator = dataset.make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + for test_case, expected in test_cases: + current_test_case = test_case + sess.run(iterator.initializer) + for element in expected: + if dtype == dtypes.string: + element = compat.as_bytes(element) + self.assertAllEqual(element, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testSimpleInt(self): + for dtype in [dtypes.int32, dtypes.int64]: + self._testSimpleHelper(dtype, [ + ([], []), + ([1], [1]), + ([1, 1, 1, 1, 1, 1, 1], [1]), + ([1, 2, 3, 4], [1, 2, 3, 4]), + ([1, 2, 4, 3, 2, 1, 2, 3, 4], [1, 2, 4, 3]), + ([[1], [1, 1], [1, 1, 1]], [[1], [1, 1], [1, 1, 1]]), + ([[1, 1], [1, 1], [2, 2], [3, 3], [1, 1]], [[1, 1], [2, 2], [3, 3]]), + ]) + + def testSimpleString(self): + self._testSimpleHelper(dtypes.string, [ + ([], []), + (["hello"], ["hello"]), + (["hello", "hello", "hello"], ["hello"]), + (["hello", "world"], ["hello", "world"]), + (["foo", "bar", "baz", "baz", "bar", "foo"], ["foo", "bar", "baz"]), + ]) + + +class UniqueSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def testUnique(self): + + def build_dataset(num_elements, unique_elem_range): + return dataset_ops.Dataset.range(num_elements).map( + lambda x: x % unique_elem_range).apply(unique.unique()) + + self.run_core_tests(lambda: 
build_dataset(200, 100), + lambda: build_dataset(40, 100), 100) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py index 5d34b0024c472d0393544ff3dad8acea7964345f..e39fa957f0bbb9d3671274d5f58b993e8399814b 100644 --- a/tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/zip_dataset_op_test.py @@ -20,97 +20,10 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base -from tensorflow.contrib.data.python.ops import dataset_ops -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import errors -from tensorflow.python.ops import array_ops +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.platform import test -class ZipDatasetTest(test.TestCase): - - def testZipDataset(self): - component_placeholders = [ - array_ops.placeholder(dtypes.int64), - array_ops.placeholder(dtypes.int64), - array_ops.placeholder(dtypes.float64) - ] - - datasets = tuple([ - dataset_ops.Dataset.from_tensor_slices(component_placeholder) - for component_placeholder in component_placeholders - ]) - zipped = dataset_ops.Dataset.zip(datasets) - - iterator = zipped.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - with self.test_session() as sess: - equal_length_components = [ - np.tile(np.array([[1], [2], [3], [4]]), 20), - np.tile(np.array([[12], [13], [14], [15]]), 22), - np.array([37.0, 38.0, 39.0, 40.0]) - ] - sess.run(init_op, feed_dict={ph: value for ph, value in zip( - component_placeholders, equal_length_components)}) - for i in range(4): - results = sess.run(get_next) - for component, result_component in zip( - equal_length_components, results): - self.assertAllEqual(component[i], result_component) 
- with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - variable_length_components = [[1, 2, 3, 4], [1, 2, 3, 4, 5], [1.0, 2.0]] - sess.run(init_op, feed_dict={ph: value for ph, value in zip( - component_placeholders, variable_length_components)}) - for i in range(2): - results = sess.run(get_next) - for component, result_component in zip( - variable_length_components, results): - self.assertAllEqual(component[i], result_component) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - def testNestedZipDataset(self): - component_placeholders = [ - array_ops.placeholder(dtypes.int64, shape=[4, 20]), - array_ops.placeholder(dtypes.int64, shape=[4, 22]), - array_ops.placeholder(dtypes.float64, shape=[4]) - ] - - datasets = [ - dataset_ops.Dataset.from_tensor_slices(component_placeholder) - for component_placeholder in component_placeholders - ] - zipped = dataset_ops.Dataset.zip((datasets[0], (datasets[1], datasets[2]))) - - iterator = zipped.make_initializable_iterator() - init_op = iterator.initializer - get_next = iterator.get_next() - - self.assertEqual([20], get_next[0].shape) - self.assertEqual([22], get_next[1][0].shape) - self.assertEqual([], get_next[1][1].shape) - - with self.test_session() as sess: - equal_length_components = [ - np.tile(np.array([[1], [2], [3], [4]]), 20), - np.tile(np.array([[12], [13], [14], [15]]), 22), - np.array([37.0, 38.0, 39.0, 40.0]) - ] - sess.run(init_op, feed_dict={ph: value for ph, value in zip( - component_placeholders, equal_length_components)}) - for i in range(4): - result1, (result2, result3) = sess.run(get_next) - self.assertAllEqual(equal_length_components[0][i], result1) - self.assertAllEqual(equal_length_components[1][i], result2) - self.assertAllEqual(equal_length_components[2][i], result3) - with self.assertRaises(errors.OutOfRangeError): - sess.run(get_next) - - class ZipDatasetSerializationTest( dataset_serialization_test_base.DatasetSerializationTestBase): diff --git 
a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index 25ed58cdf5833cd041582046bc1a358625e321e0..b488357f226d0922bba3799cc1f4b5c75e2e8328 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -15,7 +15,7 @@ py_library( name = "dataset_ops", srcs = [ "counter.py", - "dataset_ops.py", + "get_single_element.py", ], srcs_version = "PY2AND3", deps = [ @@ -40,6 +40,25 @@ py_library( ], ) +py_library( + name = "random_ops", + srcs = [ + "random_ops.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:constant_op", + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:random_seed", + "//tensorflow/python:tensor_shape", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + "//tensorflow/python/data/util:sparse", + ], +) + py_library( name = "readers", srcs = [ @@ -62,6 +81,19 @@ py_library( ], ) +py_library( + name = "shuffle_ops", + srcs = [ + "shuffle_ops.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":random_ops", + ":transformation_ops", + "//tensorflow/python/data/ops:dataset_ops", + ], +) + py_library( name = "transformation_ops", srcs = [ @@ -73,9 +105,12 @@ py_library( "resampling.py", "scan_ops.py", "stats_ops.py", + "unique.py", ], srcs_version = "PY2AND3", deps = [ + ":contrib_op_loader", + ":gen_dataset_ops", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", "//tensorflow/python:dataset_ops_gen", @@ -89,6 +124,7 @@ py_library( "//tensorflow/python:tensor_util", "//tensorflow/python:util", "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:convert", "//tensorflow/python/data/util:nest", "//tensorflow/python/data/util:sparse", "//third_party/py/numpy", @@ -96,36 +132,44 @@ py_library( ) tf_gen_op_wrapper_py( - name = "prefetching_ops", - out = "gen_prefetching_ops.py", - deps = 
["//tensorflow/contrib/data:prefetching_ops_op_lib"], + name = "gen_dataset_ops", + out = "gen_dataset_ops.py", + deps = ["//tensorflow/contrib/data:dataset_ops_op_lib"], ) tf_kernel_library( - name = "prefetching_ops_kernels", + name = "dataset_ops_kernels", deps = [ - "//tensorflow/contrib/data/kernels:prefetching_kernels", + "//tensorflow/contrib/data/kernels:dataset_kernels", "//tensorflow/core:framework", ], alwayslink = 1, ) tf_custom_op_py_library( - name = "prefetching_py", - srcs = ["prefetching_ops.py"], - dso = ["//tensorflow/contrib/data:_prefetching_ops.so"], + name = "contrib_op_loader", + srcs = ["contrib_op_loader.py"], + dso = ["//tensorflow/contrib/data:_dataset_ops.so"], kernels = [ - ":prefetching_ops_kernels", - "//tensorflow/contrib/data:prefetching_ops_op_lib", + ":dataset_ops_kernels", + "//tensorflow/contrib/data:dataset_ops_op_lib", ], srcs_version = "PY2AND3", deps = [ - ":prefetching_ops", + ":gen_dataset_ops", "//tensorflow/contrib/util:util_py", "//tensorflow/python:platform", ], ) +py_library( + name = "prefetching_ops", + srcs = ["prefetching_ops.py"], + deps = [ + ":contrib_op_loader", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py index 63782d229e1535892686f202ca1f0833dee6ed80..6eb512dec67cb7b9c8c4518d03aee0b436205f9a 100644 --- a/tensorflow/contrib/data/python/ops/batching.py +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -22,6 +22,7 @@ from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops @@ -231,32 +232,29 @@ class DenseToSparseBatchDataset(dataset_ops.Dataset): 
input_dataset.output_types) self._input_dataset = input_dataset self._batch_size = batch_size - # pylint: disable=protected-access - self._row_shape = dataset_ops._partial_shape_to_tensor(row_shape) - # pylint: enable=protected-access + self._row_shape = row_shape def _as_variant_tensor(self): return gen_dataset_ops.dense_to_sparse_batch_dataset( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access self._batch_size, - self._row_shape, - output_shapes=self.output_shapes, - output_types=self.output_types) + row_shape=dataset_ops._partial_shape_to_tensor(self._row_shape), # pylint: disable=protected-access + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes)), + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes))) @property def output_classes(self): - return (ops.Tensor, ops.Tensor, ops.Tensor) + return sparse_tensor.SparseTensor @property def output_shapes(self): - num_elements = tensor_shape.Dimension(None) - return (tensor_shape.matrix(num_elements, self._row_shape.shape[0] + 1), - tensor_shape.vector(num_elements), - tensor_shape.vector(self._row_shape.shape[0] + 1)) + return tensor_shape.vector(None).concatenate(self._row_shape) @property def output_types(self): - return (dtypes.int64, self._input_dataset.output_types, dtypes.int64) + return self._input_dataset.output_types class _RestructuredDataset(dataset_ops.Dataset): @@ -390,17 +388,12 @@ def map_and_batch(map_func, batch_size, num_parallel_batches=1): """Fused implementation of `map` and `batch`. Maps `map_func` across `batch_size` consecutive elements of this dataset - and then combines them into a batch. Similarly to `batch_and_drop_remainder`, - if the batch size does not evenly divide the input dataset size, this - transformation will drop the final smaller element. - - - Functionally, it is equivalent to `map` followed by - `batch_and_drop_remainder`. 
However, by fusing the two transformations - together, the implementation can be more efficient. This transformation is a - stop gap solution for performance critical workloads. Once automatic input - pipeline optimization are implemented, the fusing of map and batch will not - need to be exposed at the API level and this method will be removed. + and then combines them into a batch. Functionally, it is equivalent to `map` + followed by `batch`. However, by fusing the two transformations together, the + implementation can be more efficient. Surfacing this transformation in the API + is temporary. Once automatic input pipeline optimization is implemented, + the fusing of `map` and `batch` will happen automatically and this API will be + deprecated. Args: map_func: A function mapping a nested structure of tensors to another @@ -410,11 +403,11 @@ def map_and_batch(map_func, batch_size, num_parallel_batches=1): num_parallel_batches: A `tf.int64` scalar `tf.Tensor`, representing the number of batches to create in parallel. On one hand, higher values can help mitigate the effect of stragglers. On the other hand, higher values - can increasing contention if CPU is scarce. + can increase contention if CPU is scarce. Returns: A `Dataset` transformation function, which can be passed to - @{tf.contrib.data.Dataset.apply}. + @{tf.data.Dataset.apply}. """ def _apply_fn(dataset): diff --git a/tensorflow/contrib/data/python/ops/contrib_op_loader.py b/tensorflow/contrib/data/python/ops/contrib_op_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..8f495a9dc9c82311435e71d2ac9ed35fd9aea794 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/contrib_op_loader.py @@ -0,0 +1,24 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Python helper for loading contrib ops and kernels.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.util import loader +from tensorflow.python.platform import resource_loader + +_dataset_ops = loader.load_op_library( + resource_loader.get_path_to_datafile("../../_dataset_ops.so")) diff --git a/tensorflow/contrib/data/python/ops/dataset_ops.py b/tensorflow/contrib/data/python/ops/dataset_ops.py index 626a9e0edcea5928b1636c1a2a86e83657c966a5..ff15c4451ad987bcd77dbdd022a1c070056c47e1 100644 --- a/tensorflow/contrib/data/python/ops/dataset_ops.py +++ b/tensorflow/contrib/data/python/ops/dataset_ops.py @@ -364,7 +364,7 @@ class Dataset(dataset_ops.Dataset): When reading a single input file, you can skip elements as follows: ```python - d = tf.contrib.data.TFRecordDataset(FLAGS.input_file) + d = tf.data.TFRecordDataset(FLAGS.input_file) d = d.shard(FLAGS.num_workers, FLAGS.worker_index) d = d.repeat(FLAGS.num_epochs) d = d.shuffle(FLAGS.shuffle_buffer_size) @@ -382,12 +382,11 @@ class Dataset(dataset_ops.Dataset): sharding strategy within a complete pipeline: ```python - d = Dataset.list_files(FLAGS.pattern) + d = tf.data.Dataset.list_files(FLAGS.pattern) d = d.shard(FLAGS.num_workers, FLAGS.worker_index) d = d.repeat(FLAGS.num_epochs) d = d.shuffle(FLAGS.shuffle_buffer_size) - d = d.repeat() - d = d.interleave(tf.contrib.data.TFRecordDataset, + d = d.interleave(tf.data.TFRecordDataset, 
cycle_length=FLAGS.num_readers, block_length=1) d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads) ``` @@ -484,7 +483,7 @@ class Dataset(dataset_ops.Dataset): num_threads=None, output_buffer_size=None, num_parallel_calls=None): - """Maps `map_func` across this datset. + """Maps `map_func` across this dataset. Args: map_func: A function mapping a nested structure of tensors (having @@ -549,7 +548,7 @@ class Dataset(dataset_ops.Dataset): elements are produced. `cycle_length` controls the number of input elements that are processed concurrently. If you set `cycle_length` to 1, this transformation will handle one input element at a time, and will produce - identical results = to @{tf.contrib.data.Dataset.flat_map}. In general, + identical results = to @{tf.data.Dataset.flat_map}. In general, this transformation will apply `map_func` to `cycle_length` input elements, open iterators on the returned `Dataset` objects, and cycle through them producing `block_length` consecutive elements from each iterator, and diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py index aa629cba479102ee4244884e7c546615b28cf4e5..6c21e489f7c35484ebacd465e3b46d6920df5933 100644 --- a/tensorflow/contrib/data/python/ops/error_ops.py +++ b/tensorflow/contrib/data/python/ops/error_ops.py @@ -17,10 +17,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import +from tensorflow.contrib.data.python.ops import gen_dataset_ops from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse -from tensorflow.python.ops import gen_dataset_ops def ignore_errors(): diff --git a/tensorflow/contrib/data/python/ops/get_single_element.py b/tensorflow/contrib/data/python/ops/get_single_element.py new file mode 
100644 index 0000000000000000000000000000000000000000..a817b45b71b608810a9d7536ec123ab84f7cdc3b --- /dev/null +++ b/tensorflow/contrib/data/python/ops/get_single_element.py @@ -0,0 +1,67 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Python wrappers for Datasets and Iterators.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.ops import gen_dataset_ops + + +def get_single_element(dataset): + """Returns the single element in `dataset` as a nested structure of tensors. + + This function enables you to use a @{tf.data.Dataset} in a stateless + "tensor-in tensor-out" expression, without creating a @{tf.data.Iterator}. + This can be useful when your preprocessing transformations are expressed + as a `Dataset`, and you want to use the transformation at serving time. + For example: + + ```python + input_batch = tf.placeholder(tf.string, shape=[BATCH_SIZE]) + + def preprocessing_fn(input_str): + # ... 
+ return image, label + + dataset = (tf.data.Dataset.from_tensor_slices(input_batch) + .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE) + .batch(BATCH_SIZE)) + + image_batch, label_batch = tf.contrib.data.get_single_element(dataset) + ``` + + Args: + dataset: A @{tf.data.Dataset} object containing a single element. + + Returns: + A nested structure of @{tf.Tensor} objects, corresponding to the single + element of `dataset`. + + Raises: + TypeError: if `dataset` is not a `tf.data.Dataset` object. + InvalidArgumentError (at runtime): if `dataset` does not contain exactly + one element. + """ + if not isinstance(dataset, dataset_ops.Dataset): + raise TypeError("`dataset` must be a `tf.data.Dataset` object.") + return nest.pack_sequence_as( + dataset.output_types, + gen_dataset_ops.dataset_to_single_element( + dataset._as_variant_tensor(), # pylint: disable=protected-access + output_types=nest.flatten(dataset.output_types), + output_shapes=nest.flatten(dataset.output_shapes))) diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py index ef91c56726e969053fdad667dda3e89430045652..67b085002aa7797d858837fea4646fb968ad5d97 100644 --- a/tensorflow/contrib/data/python/ops/grouping.py +++ b/tensorflow/contrib/data/python/ops/grouping.py @@ -45,7 +45,7 @@ def group_by_window(key_func, key_func: A function mapping a nested structure of tensors (having shapes and types defined by `self.output_shapes` and `self.output_types`) to a scalar `tf.int64` tensor. - reduce_func: A function mapping a key and a dataset of up to `batch_size` + reduce_func: A function mapping a key and a dataset of up to `window_size` consecutive elements matching that key to another dataset. 
window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of consecutive elements matching the same key to combine in a single diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py index 53324e06e7f1dc249388410f0e14e42336630cd1..3124ca1d1540e12d949dded88ce1c66181be3595 100644 --- a/tensorflow/contrib/data/python/ops/interleave_ops.py +++ b/tensorflow/contrib/data/python/ops/interleave_ops.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import convert from tensorflow.python.data.util import nest from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes @@ -31,7 +32,7 @@ class ParallelInterleaveDataset(dataset_ops.Dataset): """A `Dataset` that maps a function over its input and flattens the result.""" def __init__(self, input_dataset, map_func, cycle_length, block_length, - sloppy): + sloppy, buffer_output_elements, prefetch_input_elements): """See `tf.contrib.data.parallel_interleave()` for details.""" super(ParallelInterleaveDataset, self).__init__() self._input_dataset = input_dataset @@ -74,6 +75,14 @@ class ParallelInterleaveDataset(dataset_ops.Dataset): block_length, dtype=dtypes.int64, name="block_length") self._sloppy = ops.convert_to_tensor( sloppy, dtype=dtypes.bool, name="sloppy") + self._buffer_output_elements = convert.optional_param_to_tensor( + "buffer_output_elements", + buffer_output_elements, + argument_default=2 * block_length) + self._prefetch_input_elements = convert.optional_param_to_tensor( + "prefetch_input_elements", + prefetch_input_elements, + argument_default=2 * cycle_length) def _as_variant_tensor(self): return gen_dataset_ops.parallel_interleave_dataset( @@ -82,6 +91,8 @@ class ParallelInterleaveDataset(dataset_ops.Dataset): self._cycle_length, self._block_length, self._sloppy, + 
self._buffer_output_elements, + self._prefetch_input_elements, f=self._map_func, output_types=nest.flatten( sparse.as_dense_types(self.output_types, self.output_classes)), @@ -101,7 +112,12 @@ class ParallelInterleaveDataset(dataset_ops.Dataset): return self._output_types -def parallel_interleave(map_func, cycle_length, block_length=1, sloppy=False): +def parallel_interleave(map_func, + cycle_length, + block_length=1, + sloppy=False, + buffer_output_elements=None, + prefetch_input_elements=None): """A parallel version of the `Dataset.interleave()` transformation. `parallel_interleave()` maps `map_func` across its input to produce nested @@ -129,12 +145,17 @@ def parallel_interleave(map_func, cycle_length, block_length=1, sloppy=False): Args: map_func: A function mapping a nested structure of tensors to a `Dataset`. - cycle_length: The number of threads to interleave from in parallel. - block_length: The number of consecutive elements to pull from a thread - before advancing to the next thread. + cycle_length: The number of input `Dataset`s to interleave from in parallel. + block_length: The number of consecutive elements to pull from an input + `Dataset` before advancing to the next input `Dataset`. sloppy: If false, elements are produced in deterministic order. Otherwise, the implementation is allowed, for the sake of expediency, to produce elements in a non-deterministic order. + buffer_output_elements: The number of elements each iterator being + interleaved should buffer (similar to the `.prefetch()` transformation for + each interleaved iterator). + prefetch_input_elements: The number of input elements to transform to + iterators before they are needed for interleaving. 
Returns: A `Dataset` transformation function, which can be passed to @@ -142,7 +163,9 @@ def parallel_interleave(map_func, cycle_length, block_length=1, sloppy=False): """ def _apply_fn(dataset): return ParallelInterleaveDataset( - dataset, map_func, cycle_length, block_length, sloppy) + dataset, map_func, cycle_length, block_length, sloppy, + buffer_output_elements, prefetch_input_elements) + return _apply_fn @@ -187,11 +210,11 @@ def sloppy_interleave(map_func, cycle_length, block_length=1): map_func: A function mapping a nested structure of tensors (having shapes and types defined by `self.output_shapes` and `self.output_types`) to a `Dataset`. - cycle_length: The number of threads to interleave from in parallel. - block_length: The number of consecutive elements to pull from a thread - before advancing to the next thread. Note: sloppy_interleave will - skip the remainder of elements in the block_length in order to avoid - blocking. + cycle_length: The number of input `Dataset`s to interleave from in parallel. + block_length: The number of consecutive elements to pull from an input + `Dataset` before advancing to the next input `Dataset`. Note: + `sloppy_interleave` will skip the remainder of elements in the + `block_length` in order to avoid blocking. 
Returns: A `Dataset` transformation function, which can be passed to @@ -199,5 +222,12 @@ def sloppy_interleave(map_func, cycle_length, block_length=1): """ def _apply_fn(dataset): return ParallelInterleaveDataset( - dataset, map_func, cycle_length, block_length, sloppy=True) + dataset, + map_func, + cycle_length, + block_length, + sloppy=True, + buffer_output_elements=None, + prefetch_input_elements=None) + return _apply_fn diff --git a/tensorflow/contrib/data/python/ops/prefetching_ops.py b/tensorflow/contrib/data/python/ops/prefetching_ops.py index cfe8012b5657995b78d701528ea35cbb3748adb9..96a9e9ed6649444dac5e56d7dd2fcdb62fc56459 100644 --- a/tensorflow/contrib/data/python/ops/prefetching_ops.py +++ b/tensorflow/contrib/data/python/ops/prefetching_ops.py @@ -17,12 +17,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.ops import gen_prefetching_ops -from tensorflow.contrib.util import loader -from tensorflow.python.platform import resource_loader - -_prefetching_ops = loader.load_op_library( - resource_loader.get_path_to_datafile("../../_prefetching_ops.so")) +from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import +from tensorflow.contrib.data.python.ops import gen_dataset_ops # TODO(rohanj): Add a python class that constructs resource in the __init__ @@ -35,7 +31,7 @@ def function_buffering_resource(string_arg, thread_pool_size=1, container="", name=None): - return gen_prefetching_ops.function_buffering_resource( + return gen_dataset_ops.function_buffering_resource( string_arg=string_arg, target_device=target_device, shared_name=shared_name, @@ -49,7 +45,7 @@ def function_buffering_resource(string_arg, def function_buffering_resource_get_next(function_buffer_resource, output_types, name=None): - return gen_prefetching_ops.function_buffering_resource_get_next( + return gen_dataset_ops.function_buffering_resource_get_next( 
function_buffer_resource=function_buffer_resource, output_types=output_types, name=name) diff --git a/tensorflow/contrib/data/python/ops/random_ops.py b/tensorflow/contrib/data/python/ops/random_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..7d727165feabb101549567f28a2dfa07083de244 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/random_ops.py @@ -0,0 +1,67 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Datasets for random number generators.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import random_seed +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import gen_dataset_ops + + +class RandomDataset(dataset_ops.Dataset): + """A `Dataset` of pseudorandom values.""" + + def __init__(self, seed=None): + """A `Dataset` of pseudorandom values.""" + super(RandomDataset, self).__init__() + seed, seed2 = random_seed.get_seed(seed) + if seed is None: + self._seed = constant_op.constant(0, dtype=dtypes.int64, name="seed") + else: + self._seed = ops.convert_to_tensor(seed, dtype=dtypes.int64, name="seed") + if seed2 is None: + self._seed2 = constant_op.constant(0, dtype=dtypes.int64, name="seed2") + else: + self._seed2 = ops.convert_to_tensor( + seed2, dtype=dtypes.int64, name="seed2") + + def _as_variant_tensor(self): + return gen_dataset_ops.random_dataset( + seed=self._seed, + seed2=self._seed2, + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes)), + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes))) + + @property + def output_classes(self): + return ops.Tensor + + @property + def output_shapes(self): + return tensor_shape.scalar() + + @property + def output_types(self): + return dtypes.int64 diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py index 
acb7a43211482f9cdeed66542abab5dbde78d60e..57f30102778f3bac47580f9bdf94e411dfe1b621 100644 --- a/tensorflow/contrib/data/python/ops/readers.py +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -17,9 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.ops import dataset_ops as contrib_dataset_ops from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.data.ops import readers from tensorflow.python.data.util import nest from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -27,74 +25,6 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import gfile -from tensorflow.python.util import deprecation - - -class TextLineDataset(contrib_dataset_ops.Dataset): - """A `Dataset` comprising lines from one or more text files.""" - - @deprecation.deprecated(None, "Use `tf.data.TextLineDataset`.") - def __init__(self, filenames, compression_type=None, buffer_size=None): - """Creates a `TextLineDataset`. - - Args: - filenames: A `tf.string` tensor containing one or more filenames. - compression_type: (Optional.) A `tf.string` scalar evaluating to one of - `""` (no compression), `"ZLIB"`, or `"GZIP"`. - buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes - to buffer. A value of 0 results in the default buffering values chosen - based on the compression type. 
- """ - dataset = readers.TextLineDataset(filenames, compression_type, - buffer_size) - super(TextLineDataset, self).__init__(dataset) - - -class TFRecordDataset(contrib_dataset_ops.Dataset): - """A `Dataset` comprising records from one or more TFRecord files.""" - - @deprecation.deprecated(None, "Use `tf.data.TFRecordDataset`.") - def __init__(self, filenames, compression_type=None, buffer_size=None): - """Creates a `TFRecordDataset`. - - Args: - filenames: A `tf.string` tensor containing one or more filenames. - compression_type: (Optional.) A `tf.string` scalar evaluating to one of - `""` (no compression), `"ZLIB"`, or `"GZIP"`. - buffer_size: (Optional.) A `tf.int64` scalar representing the number of - bytes in the read buffer. 0 means no buffering. - """ - dataset = readers.TFRecordDataset(filenames, compression_type, - buffer_size) - super(TFRecordDataset, self).__init__(dataset) - - -class FixedLengthRecordDataset(contrib_dataset_ops.Dataset): - """A `Dataset` of fixed-length records from one or more binary files.""" - - @deprecation.deprecated(None, "Use `tf.data.FixedLengthRecordDataset`.") - def __init__(self, - filenames, - record_bytes, - header_bytes=None, - footer_bytes=None, - buffer_size=None): - """Creates a `FixedLengthRecordDataset`. - - Args: - filenames: A `tf.string` tensor containing one or more filenames. - record_bytes: A `tf.int64` scalar representing the number of bytes in - each record. - header_bytes: (Optional.) A `tf.int64` scalar representing the number of - bytes to skip at the start of a file. - footer_bytes: (Optional.) A `tf.int64` scalar representing the number of - bytes to ignore at the end of a file. - buffer_size: (Optional.) A `tf.int64` scalar representing the number of - bytes to buffer when reading. 
- """ - dataset = readers.FixedLengthRecordDataset( - filenames, record_bytes, header_bytes, footer_bytes, buffer_size) - super(FixedLengthRecordDataset, self).__init__(dataset) def read_batch_features(file_pattern, @@ -179,6 +109,7 @@ def read_batch_features(file_pattern, dataset = dataset.shuffle(capacity) dataset = dataset.batch(batch_size) dataset = dataset.map(lambda x: parsing_ops.parse_example(x, features)) + dataset = dataset.prefetch(1) iterator = dataset.make_one_shot_iterator() outputs = iterator.get_next() return outputs @@ -215,14 +146,7 @@ def _get_file_names(file_pattern, randomize_input): return file_names -class SqlDataset(contrib_dataset_ops.Dataset): - - def __init__(self, driver_name, data_source_name, query, output_types): - dataset = _SqlDataset(driver_name, data_source_name, query, output_types) - super(SqlDataset, self).__init__(dataset) - - -class _SqlDataset(dataset_ops.Dataset): +class SqlDataset(dataset_ops.Dataset): """A `Dataset` consisting of the results from a SQL query.""" def __init__(self, driver_name, data_source_name, query, output_types): @@ -254,7 +178,7 @@ class _SqlDataset(dataset_ops.Dataset): output_types: A tuple of `tf.DType` objects representing the types of the columns returned by `query`. """ - super(_SqlDataset, self).__init__() + super(SqlDataset, self).__init__() self._driver_name = ops.convert_to_tensor( driver_name, dtype=dtypes.string, name="driver_name") self._data_source_name = ops.convert_to_tensor( diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py index 2744786e9eec4c9268ba854df6ea761339bb0b4e..1c88366273f5d186509454188e02350d4ea9f66b 100644 --- a/tensorflow/contrib/data/python/ops/scan_ops.py +++ b/tensorflow/contrib/data/python/ops/scan_ops.py @@ -188,7 +188,7 @@ def scan(initial_state, scan_func): Returns: A `Dataset` transformation function, which can be passed to - @{tf.contrib.data.Dataset.apply}. + @{tf.data.Dataset.apply}. 
""" def _apply_fn(dataset): return _ScanDataset(dataset, initial_state, scan_func) diff --git a/tensorflow/contrib/data/python/ops/shuffle_ops.py b/tensorflow/contrib/data/python/ops/shuffle_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..99bb79bc06a421f811869ca9169aaa11deaca2f3 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/shuffle_ops.py @@ -0,0 +1,120 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Experimental shuffle ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import random_seed +from tensorflow.python.ops import gen_dataset_ops + + +class _ShuffleAndRepeatDataset(dataset_ops.Dataset): + """A `Dataset` that fuses `shuffle` and `repeat`.""" + + def __init__(self, + input_dataset, + buffer_size, + count=None, + seed=None): + """See `Dataset.map()` for details.""" + super(_ShuffleAndRepeatDataset, self).__init__() + self._input_dataset = input_dataset + self._buffer_size = ops.convert_to_tensor( + buffer_size, dtype=dtypes.int64, name="buffer_size") + if count is None: + self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count") + else: + self._count = ops.convert_to_tensor( + count, dtype=dtypes.int64, name="count") + + seed, seed2 = random_seed.get_seed(seed) + if seed is None: + self._seed = constant_op.constant(0, dtype=dtypes.int64, name="seed") + else: + self._seed = ops.convert_to_tensor(seed, dtype=dtypes.int64, name="seed") + if seed2 is None: + self._seed2 = constant_op.constant(0, dtype=dtypes.int64, name="seed2") + else: + self._seed2 = ops.convert_to_tensor( + seed2, dtype=dtypes.int64, name="seed2") + + def _as_variant_tensor(self): + # pylint: disable=protected-access + input_resource = self._input_dataset._as_variant_tensor() + return gen_dataset_ops.shuffle_and_repeat_dataset( + input_resource, + buffer_size=self._buffer_size, + count=self._count, + seed=self._seed, + seed2=self._seed2, + output_types=nest.flatten( + 
sparse.as_dense_types(self.output_types, self.output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes))) + # pylint: enable=protected-access + + @property + def output_classes(self): + return self._input_dataset.output_classes + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types + + +def shuffle_and_repeat(buffer_size, count=None, seed=None): + """Shuffles and repeats a Dataset returning a new permutation for each epoch. + + `dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size, count))` + + is equivalent to + + `dataset.shuffle(buffer_size, reshuffle_each_iteration=True).repeat(count)` + + The difference is that the latter dataset is not serializable. So, + if you need to checkpoint an input pipeline with reshuffling you must use + this implementation. + + Args: + buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the + maximum number elements that will be buffered when prefetching. + count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the + number of times the dataset should be repeated. The default behavior + (if `count` is `None` or `-1`) is for the dataset be repeated + indefinitely. + seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the + random seed that will be used to create the distribution. See + @{tf.set_random_seed} for behavior. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. 
+ """ + + def _apply_fn(dataset): # pylint: disable=missing-docstring + return _ShuffleAndRepeatDataset(dataset, buffer_size, count, seed) + + return _apply_fn diff --git a/tensorflow/contrib/data/python/ops/stats_ops.py b/tensorflow/contrib/data/python/ops/stats_ops.py index b8875bd533ddc9e2c195646619dccf3aab5225e4..9cd1701c397b5a0bf5cc47c1bcab033704794d80 100644 --- a/tensorflow/contrib/data/python/ops/stats_ops.py +++ b/tensorflow/contrib/data/python/ops/stats_ops.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import gen_dataset_ops @@ -117,7 +118,7 @@ def bytes_produced_stats(tag): Returns: A `Dataset` transformation function, which can be passed to - @{tf.contrib.data.Dataset.apply}. + @{tf.data.Dataset.apply}. """ def _apply_fn(dataset): @@ -139,7 +140,7 @@ def latency_stats(tag): Returns: A `Dataset` transformation function, which can be passed to - @{tf.contrib.data.Dataset.apply}. + @{tf.data.Dataset.apply}. 
""" def _apply_fn(dataset): @@ -161,8 +162,10 @@ class _StatsDataset(dataset_ops.Dataset): return self._op_function( self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access self._tag, - output_shapes=nest.flatten(self.output_shapes), - output_types=nest.flatten(self.output_types)) + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes)), + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes))) @property def output_shapes(self): diff --git a/tensorflow/contrib/data/python/ops/unique.py b/tensorflow/contrib/data/python/ops/unique.py new file mode 100644 index 0000000000000000000000000000000000000000..133e17d20d0fc4c8d52cef3c95c132374e927a0b --- /dev/null +++ b/tensorflow/contrib/data/python/ops/unique.py @@ -0,0 +1,82 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Unique element dataset transformations.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import gen_dataset_ops + + +def unique(): + """Creates a `Dataset` from another `Dataset`, discarding duplicates. + + Use this transformation to produce a dataset that contains one instance of + each unique element in the input. For example: + + ```python + dataset = tf.data.Dataset.from_tensor_slices([1, 37, 2, 37, 2, 1]) + + # Using `unique()` will drop the duplicate elements. + dataset = dataset.apply(tf.contrib.data.unique()) # ==> { 1, 37, 2 } + ``` + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. 
+ """ + + def _apply_fn(dataset): + return UniqueDataset(dataset) + + return _apply_fn + + +class UniqueDataset(dataset_ops.Dataset): + """A `Dataset` contains the unique elements from its input.""" + + def __init__(self, input_dataset): + """See `unique()` for details.""" + super(UniqueDataset, self).__init__() + self._input_dataset = input_dataset + if input_dataset.output_types not in (dtypes.int32, dtypes.int64, + dtypes.string): + raise TypeError( + "`tf.contrib.data.unique()` only supports inputs with a single " + "`tf.int32`, `tf.int64`, or `tf.string` component.") + + def _as_variant_tensor(self): + return gen_dataset_ops.unique_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + output_shapes=nest.flatten( + sparse.as_dense_shapes(self.output_shapes, self.output_classes)), + output_types=nest.flatten( + sparse.as_dense_types(self.output_types, self.output_classes))) + + @property + def output_classes(self): + return self._input_dataset.output_classes + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types diff --git a/tensorflow/contrib/decision_trees/proto/BUILD b/tensorflow/contrib/decision_trees/proto/BUILD index 87c80740a8f0c0721394b5d832bc96e548e3a313..f6de5998d73a4869d2444cd90c9b64d1a2c889ac 100644 --- a/tensorflow/contrib/decision_trees/proto/BUILD +++ b/tensorflow/contrib/decision_trees/proto/BUILD @@ -7,7 +7,11 @@ exports_files([ "generic_tree_model_proto.swig", ]) -load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_proto_library", + "tf_pyclif_proto_library", +) filegroup( name = "all_files", @@ -34,3 +38,10 @@ tf_proto_library( protodeps = [":generic_tree_model"], visibility = ["//visibility:public"], ) + +tf_pyclif_proto_library( + name = "generic_tree_model_pyclif", + proto_lib = 
":generic_tree_model", + proto_srcfile = "generic_tree_model.proto", + visibility = ["//visibility:public"], +) diff --git a/tensorflow/contrib/decision_trees/proto/generic_tree_model_proto.swig b/tensorflow/contrib/decision_trees/proto/generic_tree_model_proto.swig index d3d201afd5761e7c5c136301c779222bedc68492..cafb9314caee1c4907786b8101e7c71bd7095306 100644 --- a/tensorflow/contrib/decision_trees/proto/generic_tree_model_proto.swig +++ b/tensorflow/contrib/decision_trees/proto/generic_tree_model_proto.swig @@ -2,7 +2,7 @@ %include "net/proto/swig/protofunc.swig" -#ifndef MUST_USE_RESULT +#ifndef ABSL_MUST_USE_RESULT #error Use this file only as a %include or %import after google.swig. #endif diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index b2c641f8ab3ea23c5135042e4b1223d487ae8cbc..7f510c42215f48a9e795eb81bd9f66b0a2108335 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -60,6 +60,7 @@ py_library( "//tensorflow/python:nn", "//tensorflow/python:nn_ops", "//tensorflow/python:random_ops", + "//tensorflow/python:spectral_ops", "//tensorflow/python:state_ops", "//tensorflow/python:tensor_util", "//tensorflow/python:util", @@ -127,6 +128,19 @@ cuda_py_test( tags = ["no_pip"], ) +cuda_py_test( + name = "autoregressive_test", + size = "small", + srcs = ["python/kernel_tests/autoregressive_test.py"], + additional_deps = [ + ":distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "binomial_test", size = "small", @@ -437,6 +451,7 @@ cuda_py_test( "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:spectral_ops_test_util", "//tensorflow/python:math_ops", "//tensorflow/python:nn_ops", 
"//tensorflow/python:platform_test", @@ -916,6 +931,22 @@ cuda_py_test( ], ) +cuda_py_test( + name = "real_nvp_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/real_nvp_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "permute_test", size = "small", diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index 66827179e9fa1bea852f55246c263c4696cf3bdc..61c411271d0bb8d7b4cc3b14992b82ec1e5674ed 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -23,6 +23,7 @@ from __future__ import print_function # pylint: disable=unused-import,wildcard-import,line-too-long,g-importing-member from tensorflow.contrib.distributions.python.ops import bijectors +from tensorflow.contrib.distributions.python.ops.autoregressive import * from tensorflow.contrib.distributions.python.ops.binomial import * from tensorflow.contrib.distributions.python.ops.cauchy import * from tensorflow.contrib.distributions.python.ops.chi2 import * @@ -39,6 +40,7 @@ from tensorflow.contrib.distributions.python.ops.geometric import * from tensorflow.contrib.distributions.python.ops.half_normal import * from tensorflow.contrib.distributions.python.ops.independent import * from tensorflow.contrib.distributions.python.ops.inverse_gamma import * +from tensorflow.contrib.distributions.python.ops.kumaraswamy import * from tensorflow.contrib.distributions.python.ops.logistic import * from tensorflow.contrib.distributions.python.ops.mixture import * from tensorflow.contrib.distributions.python.ops.mixture_same_family import * @@ -84,6 +86,7 @@ from tensorflow.python.ops.distributions.uniform import * 
from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ + 'auto_correlation', 'bijectors', 'Cauchy', 'ConditionalDistribution', @@ -92,9 +95,9 @@ _allowed_symbols = [ 'NOT_REPARAMETERIZED', 'ReparameterizationType', 'Distribution', + 'Autoregressive', 'Binomial', 'Bernoulli', - 'BernoulliWithSigmoidProbs', 'Beta', 'BetaWithSoftplusConcentration', 'Categorical', @@ -112,6 +115,7 @@ _allowed_symbols = [ 'Independent', 'InverseGamma', 'InverseGammaWithSoftplusConcentrationRate', + 'Kumaraswamy', 'Laplace', 'LaplaceWithSoftplusScale', 'Logistic', @@ -159,6 +163,10 @@ _allowed_symbols = [ 'assign_log_moving_mean_exp', 'moving_mean_variance', 'estimator_head_distribution_regression', + 'quadrature_scheme_softmaxnormal_gauss_hermite', + 'quadrature_scheme_softmaxnormal_quantiles', + 'quadrature_scheme_lognormal_gauss_hermite', + 'quadrature_scheme_lognormal_quantiles', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/autoregressive_test.py b/tensorflow/contrib/distributions/python/kernel_tests/autoregressive_test.py new file mode 100644 index 0000000000000000000000000000000000000000..0928dc3f358ede693865a8d1ff9257a0ecbe9499 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/autoregressive_test.py @@ -0,0 +1,94 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import autoregressive as autoregressive_lib +from tensorflow.contrib.distributions.python.ops import independent as independent_lib +from tensorflow.contrib.distributions.python.ops import test_util +from tensorflow.contrib.distributions.python.ops.bijectors.affine import Affine +from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import MaskedAutoregressiveFlow +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import normal as normal_lib +from tensorflow.python.ops.distributions import transformed_distribution as transformed_distribution_lib +from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.platform import test + + +class AutogressiveTest(test_util.VectorDistributionTestHelpers, test.TestCase): + """Tests the Autoregressive distribution.""" + + def setUp(self): + self._rng = np.random.RandomState(42) + + def _random_scale_tril(self, event_size): + n = np.int32(event_size * (event_size + 1) // 2) + p = 2. * self._rng.random_sample(n).astype(np.float32) - 1. 
+ return distribution_util.fill_triangular(0.25 * p) + + def _normal_fn(self, affine_bijector): + def _fn(samples): + scale = math_ops.exp(affine_bijector.forward(samples)) + return independent_lib.Independent( + normal_lib.Normal(loc=0., scale=scale, validate_args=True), + reinterpreted_batch_ndims=1) + return _fn + + def testSampleAndLogProbConsistency(self): + batch_shape = [] + event_size = 2 + with self.test_session() as sess: + batch_event_shape = np.concatenate([batch_shape, [event_size]], axis=0) + sample0 = array_ops.zeros(batch_event_shape) + affine = Affine(scale_tril=self._random_scale_tril(event_size)) + ar = autoregressive_lib.Autoregressive( + self._normal_fn(affine), sample0, validate_args=True) + self.run_test_sample_consistent_log_prob( + sess.run, ar, radius=1., center=0., rtol=0.01) + + def testCompareToBijector(self): + """Demonstrates equivalence between TD, Bijector approach and AR dist.""" + sample_shape = np.int32([4, 5]) + batch_shape = np.int32([]) + event_size = np.int32(2) + with self.test_session() as sess: + batch_event_shape = np.concatenate([batch_shape, [event_size]], axis=0) + sample0 = array_ops.zeros(batch_event_shape) + affine = Affine(scale_tril=self._random_scale_tril(event_size)) + ar = autoregressive_lib.Autoregressive( + self._normal_fn(affine), sample0, validate_args=True) + ar_flow = MaskedAutoregressiveFlow( + is_constant_jacobian=True, + shift_and_log_scale_fn=lambda x: [None, affine.forward(x)], + validate_args=True) + td = transformed_distribution_lib.TransformedDistribution( + distribution=normal_lib.Normal(loc=0., scale=1.), + bijector=ar_flow, + event_shape=[event_size], + batch_shape=batch_shape, + validate_args=True) + x_shape = np.concatenate( + [sample_shape, batch_shape, [event_size]], axis=0) + x = 2. * self._rng.random_sample(x_shape).astype(np.float32) - 1. 
+ td_log_prob_, ar_log_prob_ = sess.run([td.log_prob(x), ar.log_prob(x)]) + self.assertAllClose(td_log_prob_, ar_log_prob_, atol=0., rtol=1e-6) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py index 25a9b6f5fe2ed6d218d6b44650fce17fa89c0664..dcfb0eb05185d36d96947905c2eb91b2201aece1 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py @@ -22,9 +22,9 @@ import numpy as np from tensorflow.contrib.distributions.python.ops import test_util from tensorflow.contrib.distributions.python.ops.bijectors.invert import Invert +from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import _gen_mask from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import masked_autoregressive_default_template from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import MaskedAutoregressiveFlow -from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive_impl import _gen_mask from tensorflow.python.framework import constant_op from tensorflow.python.ops import array_ops from tensorflow.python.ops import variables @@ -149,5 +149,17 @@ class MaskedAutoregressiveFlowShiftOnlyTest(MaskedAutoregressiveFlowTest): } +class MaskedAutoregressiveFlowUnrollLoopTest(MaskedAutoregressiveFlowTest): + + @property + def _autoregressive_flow_kwargs(self): + return { + "shift_and_log_scale_fn": masked_autoregressive_default_template( + hidden_layers=[2], shift_only=False), + "is_constant_jacobian": False, + "unroll_loop": True, + } + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/real_nvp_test.py 
b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/real_nvp_test.py new file mode 100644 index 0000000000000000000000000000000000000000..46fe7797419a9906ecdad60dd0dfe1e9d7c743ed --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/real_nvp_test.py @@ -0,0 +1,144 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for MaskedAutoregressiveFlow.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from tensorflow.contrib.distributions.python.ops import test_util +from tensorflow.contrib.distributions.python.ops.bijectors.invert import Invert +from tensorflow.contrib.distributions.python.ops.bijectors.real_nvp import real_nvp_default_template +from tensorflow.contrib.distributions.python.ops.bijectors.real_nvp import RealNVP +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variables +from tensorflow.python.ops.distributions import normal as normal_lib +from tensorflow.python.ops.distributions import transformed_distribution as transformed_distribution_lib +from tensorflow.python.platform import test + + +class RealNVPTest(test_util.VectorDistributionTestHelpers, test.TestCase): + + @property + def 
_real_nvp_kwargs(self): + return { + "shift_and_log_scale_fn": real_nvp_default_template( + hidden_layers=[3], shift_only=False), + "is_constant_jacobian": False, + } + + def testBijector(self): + x_ = np.arange(3 * 4 * 2).astype(np.float32).reshape(3, 4 * 2) + with self.test_session() as sess: + nvp = RealNVP( + num_masked=4, + validate_args=True, + **self._real_nvp_kwargs) + x = constant_op.constant(x_) + forward_x = nvp.forward(x) + # Use identity to invalidate cache. + inverse_y = nvp.inverse(array_ops.identity(forward_x)) + fldj = nvp.forward_log_det_jacobian(x) + # Use identity to invalidate cache. + ildj = nvp.inverse_log_det_jacobian(array_ops.identity(forward_x)) + variables.global_variables_initializer().run() + [ + forward_x_, + inverse_y_, + ildj_, + fldj_, + ] = sess.run([ + forward_x, + inverse_y, + ildj, + fldj, + ]) + self.assertEqual("real_nvp", nvp.name) + self.assertAllClose(forward_x_, forward_x_, rtol=1e-6, atol=0.) + self.assertAllClose(x_, inverse_y_, rtol=1e-5, atol=0.) + self.assertAllClose(ildj_, -fldj_, rtol=1e-6, atol=0.) 
+ + def testMutuallyConsistent(self): + dims = 4 + with self.test_session() as sess: + nvp = RealNVP( + num_masked=3, + validate_args=True, + **self._real_nvp_kwargs) + dist = transformed_distribution_lib.TransformedDistribution( + distribution=normal_lib.Normal(loc=0., scale=1.), + bijector=nvp, + event_shape=[dims], + validate_args=True) + self.run_test_sample_consistent_log_prob( + sess_run_fn=sess.run, + dist=dist, + num_samples=int(1e5), + radius=1., + center=0., + rtol=0.02) + + def testInvertMutuallyConsistent(self): + dims = 4 + with self.test_session() as sess: + nvp = Invert(RealNVP( + num_masked=3, + validate_args=True, + **self._real_nvp_kwargs)) + dist = transformed_distribution_lib.TransformedDistribution( + distribution=normal_lib.Normal(loc=0., scale=1.), + bijector=nvp, + event_shape=[dims], + validate_args=True) + self.run_test_sample_consistent_log_prob( + sess_run_fn=sess.run, + dist=dist, + num_samples=int(1e5), + radius=1., + center=0., + rtol=0.02) + + +class NICETest(RealNVPTest): + + @property + def _real_nvp_kwargs(self): + return { + "shift_and_log_scale_fn": real_nvp_default_template( + hidden_layers=[2], shift_only=True), + "is_constant_jacobian": True, + } + + +class RealNVPConstantShiftScaleTest(RealNVPTest): + + @property + def _real_nvp_kwargs(self): + + def constant_shift_log_scale_fn(x0, output_units): + del x0, output_units + shift = constant_op.constant([0.1]) + log_scale = constant_op.constant([0.5]) + return shift, log_scale + + return { + "shift_and_log_scale_fn": constant_shift_log_scale_fn, + "is_constant_jacobian": True, + } + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py index 38b3a23c2d684a6f89b7c4be4a763c649bf4de15..e216d88cb190dc16fc0056186f80817d6f2d7c67 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py +++ 
b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py @@ -22,14 +22,28 @@ import numpy as np from tensorflow.contrib.distributions.python.ops.bijectors.reshape import Reshape from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite from tensorflow.python.platform import test -class ReshapeBijectorTest(test.TestCase): - """Tests correctness of the reshape transformation.""" +@test_util.with_c_api +class _ReshapeBijectorTest(object): + """Base class for testing the reshape transformation. + + Methods defined in this class call a method self.build_shapes() that + is implemented by subclasses defined below, returning respectively + ReshapeBijectorTestStatic: static shapes, + ReshapeBijectorTestDynamic: shape placeholders of known ndims, and + ReshapeBijectorTestDynamicNdims: shape placeholders of unspecified ndims, + so that each test in this base class is automatically run over all + three cases. The subclasses also implement assertRaisesError to test + for either Python exceptions (in the case of static shapes) or + TensorFlow op errors (dynamic shapes). 
+ """ def setUp(self): self._rng = np.random.RandomState(42) @@ -40,9 +54,10 @@ class ReshapeBijectorTest(test.TestCase): expected_y = np.reshape(expected_x, [4, 6]) with self.test_session() as sess: + shape_in, shape_out, feed_dict = self.build_shapes([3, 2], [6,]) bijector = Reshape( - event_shape_out=[6,], - event_shape_in=[3, 2], + event_shape_out=shape_out, + event_shape_in=shape_in, validate_args=True) (x_, y_, @@ -52,66 +67,23 @@ class ReshapeBijectorTest(test.TestCase): bijector.forward(expected_x), bijector.forward_log_det_jacobian(expected_x), bijector.inverse_log_det_jacobian(expected_y), - )) + ), feed_dict=feed_dict) self.assertEqual("reshape", bijector.name) self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0) self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0) self.assertAllClose(0., fldj_, rtol=1e-6, atol=0) self.assertAllClose(0., ildj_, rtol=1e-6, atol=0) - def testEventShapeDynamicNdims(self): - """Check forward/inverse shape methods with dynamic ndims.""" - - shape_in = tensor_shape.TensorShape([6,]) - shape_in_ph = array_ops.placeholder(dtype=dtypes.int32) - - shape_out = tensor_shape.TensorShape([2, 3]) - shape_out_ph = array_ops.placeholder(dtype=dtypes.int32) + def testEventShapeTensor(self): + """Test event_shape_tensor methods when even ndims may be dynamic.""" + shape_in_static = [2, 3] + shape_out_static = [6,] + shape_in, shape_out, feed_dict = self.build_shapes(shape_in_static, + shape_out_static) bijector = Reshape( - event_shape_out=shape_out_ph, - event_shape_in=shape_in_ph, validate_args=True) - - # using the _tensor methods, we should always get a fully-specified - # result since these are evaluated at graph runtime. 
- with self.test_session() as sess: - (shape_out_, - shape_in_) = sess.run(( - bijector.forward_event_shape_tensor(shape_in), - bijector.inverse_event_shape_tensor(shape_out), - ), feed_dict={ - shape_in_ph: shape_in, - shape_out_ph: shape_out, - }) - self.assertAllEqual(shape_out, shape_out_) - self.assertAllEqual(shape_in, shape_in_) - - def testEventShapeDynamic(self): - """Check shape methods with static ndims but dynamic shape.""" - - shape_in = tensor_shape.TensorShape([6,]) - shape_in_partial = tensor_shape.TensorShape([None,]) - shape_in_ph = array_ops.placeholder( - shape=[1,], dtype=dtypes.int32) - - shape_out = tensor_shape.TensorShape([2, 3]) - shape_out_partial = tensor_shape.TensorShape([None, None]) - shape_out_ph = array_ops.placeholder( - shape=[2,], dtype=dtypes.int32) - - bijector = Reshape( - event_shape_out=shape_out_ph, - event_shape_in=shape_in_ph, - validate_args=True) - - # if event shapes are not statically available, should - # return partially-specified TensorShapes. - self.assertAllEqual( - bijector.forward_event_shape(shape_in).as_list(), - shape_out_partial.as_list()) - self.assertAllEqual( - bijector.inverse_event_shape(shape_out).as_list(), - shape_in_partial.as_list()) + event_shape_out=shape_out, + event_shape_in=shape_in, validate_args=True) # using the _tensor methods, we should always get a fully-specified # result since these are evaluated at graph runtime. 
@@ -120,42 +92,9 @@ class ReshapeBijectorTest(test.TestCase): shape_in_) = sess.run(( bijector.forward_event_shape_tensor(shape_in), bijector.inverse_event_shape_tensor(shape_out), - ), feed_dict={ - shape_in_ph: shape_in, - shape_out_ph: shape_out, - }) - self.assertAllEqual(shape_out, shape_out_) - self.assertAllEqual(shape_in, shape_in_) - - def testEventShapeStatic(self): - """Check shape methods when shape is statically known.""" - - shape_in = tensor_shape.TensorShape([6,]) - shape_out = tensor_shape.TensorShape([2, 3]) - - bijector_static = Reshape( - event_shape_out=shape_out, - event_shape_in=shape_in, - validate_args=True) - - # test that forward_ and inverse_event_shape do sensible things - # when shapes are statically known. - self.assertEqual( - bijector_static.forward_event_shape(shape_in), - shape_out) - self.assertEqual( - bijector_static.inverse_event_shape(shape_out), - shape_in) - - with self.test_session() as sess: - (shape_out_static_, - shape_in_static_, - ) = sess.run(( - bijector_static.forward_event_shape_tensor(shape_in), - bijector_static.inverse_event_shape_tensor(shape_out), - )) - self.assertAllEqual(shape_out, shape_out_static_) - self.assertAllEqual(shape_in, shape_in_static_) + ), feed_dict=feed_dict) + self.assertAllEqual(shape_out_static, shape_out_) + self.assertAllEqual(shape_in_static, shape_in_) def testScalarReshape(self): """Test reshaping to and from a scalar shape ().""" @@ -166,11 +105,11 @@ class ReshapeBijectorTest(test.TestCase): expected_x_scalar = np.random.randn(1,) expected_y_scalar = expected_x_scalar[0] + shape_in, shape_out, feed_dict = self.build_shapes([], [1,]) with self.test_session() as sess: bijector = Reshape( - event_shape_out=[], - event_shape_in=[1,], validate_args=True) - + event_shape_out=shape_in, + event_shape_in=shape_out, validate_args=True) (x_, y_, x_scalar_, @@ -180,53 +119,179 @@ class ReshapeBijectorTest(test.TestCase): bijector.forward(expected_x), bijector.inverse(expected_y_scalar), 
bijector.forward(expected_x_scalar), - )) + ), feed_dict=feed_dict) self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0) self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0) self.assertAllClose(expected_y_scalar, y_scalar_, rtol=1e-6, atol=0) self.assertAllClose(expected_x_scalar, x_scalar_, rtol=1e-6, atol=0) - def testRaisesOpError(self): - x1 = np.random.randn(4, 2, 3) - x2 = np.random.randn(4, 3, 2) - x3 = np.random.randn(4, 5, 1, 1) + def testMultipleUnspecifiedDimensionsOpError(self): with self.test_session() as sess: - shape_in_ph = array_ops.placeholder(shape=[2,], dtype=dtypes.int32) - shape_out_ph = array_ops.placeholder(shape=[3,], dtype=dtypes.int32) + shape_in, shape_out, feed_dict = self.build_shapes([2, 3], [4, -1, -1,]) bijector = Reshape( - event_shape_out=shape_out_ph, - event_shape_in=shape_in_ph, + event_shape_out=shape_out, + event_shape_in=shape_in, validate_args=True) - with self.assertRaisesOpError( + with self.assertRaisesError( + "elements must have at most one `-1`."): + sess.run(bijector.forward_event_shape_tensor(shape_in), + feed_dict=feed_dict) + + # pylint: disable=invalid-name + def _testInvalidDimensionsOpError(self, expected_error_message): + + with self.test_session() as sess: + + shape_in, shape_out, feed_dict = self.build_shapes([2, 3], [1, 2, -2,]) + bijector = Reshape( + event_shape_out=shape_out, + event_shape_in=shape_in, + validate_args=True) + + with self.assertRaisesError(expected_error_message): + sess.run(bijector.forward_event_shape_tensor(shape_in), + feed_dict=feed_dict) + # pylint: enable=invalid-name + + def testValidButNonMatchingInputOpError(self): + x = np.random.randn(4, 3, 2) + + with self.test_session() as sess: + shape_in, shape_out, feed_dict = self.build_shapes([2, 3], [1, 6, 1,]) + bijector = Reshape( + event_shape_out=shape_out, + event_shape_in=shape_in, + validate_args=True) + + # Here we pass in a tensor (x) whose shape is compatible with + # the output shape, so tf.reshape will throw no error, 
but + # doesn't match the expected input shape. + with self.assertRaisesError( + "Input `event_shape` does not match `event_shape_in`."): + sess.run(bijector.forward(x), + feed_dict=feed_dict) + + def testValidButNonMatchingInputPartiallySpecifiedOpError(self): + x = np.random.randn(4, 3, 2) + + with self.test_session() as sess: + shape_in, shape_out, feed_dict = self.build_shapes([2, -1], [1, 6, 1,]) + bijector = Reshape( + event_shape_out=shape_out, + event_shape_in=shape_in, + validate_args=True) + + with self.assertRaisesError( "Input `event_shape` does not match `event_shape_in`."): - sess.run(bijector.forward(x2), - feed_dict={shape_out_ph: [1, 6, 1], - shape_in_ph: [2, 3]}) - - with self.assertRaisesOpError( - "event_shape_out entries must be positive."): - sess.run(bijector.forward(x1), - feed_dict={shape_out_ph: [-1, -1, 6], - shape_in_ph: [2, 3]}) - - # test that *all* methods check basic assertions - fd_mismatched = {shape_out_ph: [1, 1, 5], shape_in_ph: [2, 3]} - with self.assertRaisesOpError( - "Input/output `event_size`s do not match."): + sess.run(bijector.forward(x), + feed_dict=feed_dict) + + # pylint: disable=invalid-name + def _testInputOutputMismatchOpError(self, expected_error_message): + x1 = np.random.randn(4, 2, 3) + x2 = np.random.randn(4, 1, 1, 5) + + with self.test_session() as sess: + shape_in, shape_out, fd_mismatched = self.build_shapes([2, 3], + [1, 1, 5]) + bijector = Reshape( + event_shape_out=shape_out, + event_shape_in=shape_in, + validate_args=True) + + with self.assertRaisesError(expected_error_message): sess.run(bijector.forward(x1), feed_dict=fd_mismatched) - with self.assertRaisesOpError( - "Input/output `event_size`s do not match."): - sess.run(bijector.inverse(x3), feed_dict=fd_mismatched) - with self.assertRaisesOpError( - "Input/output `event_size`s do not match."): - sess.run(bijector.inverse_log_det_jacobian(x3), - feed_dict=fd_mismatched) - with self.assertRaisesOpError( - "Input/output `event_size`s do not match."): - 
sess.run(bijector.forward_log_det_jacobian(x1), - feed_dict=fd_mismatched) + with self.assertRaisesError(expected_error_message): + sess.run(bijector.inverse(x2), feed_dict=fd_mismatched) + # pylint: enable=invalid-name + + def testOneShapePartiallySpecified(self): + expected_x = np.random.randn(4, 6) + expected_y = np.reshape(expected_x, [4, 2, 3]) + + with self.test_session() as sess: + # one of input/output shapes is partially specified + shape_in, shape_out, feed_dict = self.build_shapes([-1,], [2, 3]) + bijector = Reshape( + event_shape_out=shape_out, + event_shape_in=shape_in, + validate_args=True) + (x_, + y_, + ) = sess.run(( + bijector.inverse(expected_y), + bijector.forward(expected_x), + ), feed_dict=feed_dict) + self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0) + self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0) + + def testBothShapesPartiallySpecified(self): + expected_x = np.random.randn(4, 2, 3) + expected_y = np.reshape(expected_x, [4, 3, 2]) + with self.test_session() as sess: + shape_in, shape_out, feed_dict = self.build_shapes([-1, 3], [-1, 2]) + bijector = Reshape( + event_shape_out=shape_out, + event_shape_in=shape_in, + validate_args=True) + (x_, + y_, + ) = sess.run(( + bijector.inverse(expected_y), + bijector.forward(expected_x), + ), feed_dict=feed_dict) + self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0) + self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0) + + def testDefaultVectorShape(self): + expected_x = np.random.randn(4, 4) + expected_y = np.reshape(expected_x, [4, 2, 2]) + with self.test_session() as sess: + _, shape_out, feed_dict = self.build_shapes([-1,], [-1, 2]) + bijector = Reshape(shape_out, + validate_args=True) + (x_, + y_, + ) = sess.run(( + bijector.inverse(expected_y), + bijector.forward(expected_x), + ), feed_dict=feed_dict) + self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0) + self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0) + + def build_shapes(self, *args, **kwargs): + raise 
NotImplementedError("Subclass failed to implement `build_shapes`.") + + +@test_util.with_c_api +class ReshapeBijectorTestStatic(test.TestCase, _ReshapeBijectorTest): + + def build_shapes(self, shape_in, shape_out): + shape_in_static = shape_in + shape_out_static = shape_out + feed_dict = {} + return shape_in_static, shape_out_static, feed_dict + + def assertRaisesError(self, msg): + return self.assertRaisesRegexp(Exception, msg) + + def testEventShape(self): + shape_in_static = tensor_shape.TensorShape([2, 3]) + shape_out_static = tensor_shape.TensorShape([6,]) + bijector = Reshape( + event_shape_out=shape_out_static, + event_shape_in=shape_in_static, validate_args=True) + + # test that forward_ and inverse_event_shape do sensible things + # when shapes are statically known. + self.assertEqual( + bijector.forward_event_shape(shape_in_static), + shape_out_static) + self.assertEqual( + bijector.inverse_event_shape(shape_out_static), + shape_in_static) def testBijectiveAndFinite(self): x = np.random.randn(4, 2, 3) @@ -238,5 +303,62 @@ class ReshapeBijectorTest(test.TestCase): validate_args=True) assert_bijective_and_finite(bijector, x, y, rtol=1e-6, atol=0) + def testInvalidDimensionsOpError(self): + if ops._USE_C_API: + error_message = "Invalid value in tensor used for shape: -2" + else: + error_message = "elements must be either positive integers or `-1`." 
+ self._testInvalidDimensionsOpError(error_message) + + def testInputOutputMismatchOpError(self): + if ops._USE_C_API: + error_message = "Cannot reshape a tensor with" + else: + error_message = "Input to reshape is a tensor with" + self._testInputOutputMismatchOpError(error_message) + + +@test_util.with_c_api +class ReshapeBijectorTestDynamic(test.TestCase, _ReshapeBijectorTest): + + def build_shapes(self, shape_in, shape_out): + shape_in_ph = array_ops.placeholder(shape=(len(shape_in),), + dtype=dtypes.int32) + shape_out_ph = array_ops.placeholder(shape=(len(shape_out),), + dtype=dtypes.int32) + feed_dict = {shape_in_ph: shape_in, shape_out_ph: shape_out} + return shape_in_ph, shape_out_ph, feed_dict + + def assertRaisesError(self, msg): + return self.assertRaisesOpError(msg) + + def testInvalidDimensionsOpError(self): + self._testInvalidDimensionsOpError( + "elements must be either positive integers or `-1`.") + + def testInputOutputMismatchOpError(self): + self._testInputOutputMismatchOpError("Input to reshape is a tensor with") + + +@test_util.with_c_api +class ReshapeBijectorTestDynamicNdims(test.TestCase, _ReshapeBijectorTest): + + def build_shapes(self, shape_in, shape_out): + shape_in_ph = array_ops.placeholder(shape=None, dtype=dtypes.int32) + shape_out_ph = array_ops.placeholder(shape=None, dtype=dtypes.int32) + feed_dict = {shape_in_ph: shape_in, shape_out_ph: shape_out} + return shape_in_ph, shape_out_ph, feed_dict + + def assertRaisesError(self, msg): + return self.assertRaisesOpError(msg) + + def testInvalidDimensionsOpError(self): + self._testInvalidDimensionsOpError( + "elements must be either positive integers or `-1`.") + + def testInputOutputMismatchOpError(self): + self._testInputOutputMismatchOpError("Input to reshape is a tensor with") + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py 
b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py index 2d74aa1f320149d0f7ef9e9c52b8c7053c2f74d7..31d24aa9ea09007b8db40e4869371b1f62639ac7 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py @@ -23,10 +23,15 @@ import itertools import numpy as np from tensorflow.contrib.distributions.python.ops import distribution_util +from tensorflow.contrib.distributions.python.ops import mixture +from tensorflow.contrib.distributions.python.ops import mixture_same_family +from tensorflow.contrib.distributions.python.ops import mvn_diag from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops +from tensorflow.python.ops.distributions import categorical +from tensorflow.python.ops.distributions import normal from tensorflow.python.ops.linalg import linear_operator_diag import tensorflow.python.ops.nn_grad # pylint: disable=unused-import from tensorflow.python.platform import test @@ -395,5 +400,145 @@ class MixtureStddevTest(test.TestCase): self.assertAllClose(actual_devs, expected_devs) +class PadMixtureDimensionsTest(test.TestCase): + + def test_pad_mixture_dimensions_mixture(self): + with self.test_session() as sess: + gm = mixture.Mixture( + cat=categorical.Categorical(probs=[[0.3, 0.7]]), + components=[ + normal.Normal(loc=[-1.0], scale=[1.0]), + normal.Normal(loc=[1.0], scale=[0.5]) + ]) + + x = array_ops.constant([[1.0, 2.0], [3.0, 4.0]]) + x_pad = distribution_util.pad_mixture_dimensions( + x, gm, gm.cat, gm.event_shape.ndims) + x_out, x_pad_out = sess.run([x, x_pad]) + + self.assertAllEqual(x_pad_out.shape, [2, 2]) + self.assertAllEqual(x_out.reshape([-1]), x_pad_out.reshape([-1])) + + def test_pad_mixture_dimensions_mixture_same_family(self): + with self.test_session() as sess: + 
gm = mixture_same_family.MixtureSameFamily( + mixture_distribution=categorical.Categorical(probs=[0.3, 0.7]), + components_distribution=mvn_diag.MultivariateNormalDiag( + loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1.0, 0.5])) + + x = array_ops.constant([[1.0, 2.0], [3.0, 4.0]]) + x_pad = distribution_util.pad_mixture_dimensions( + x, gm, gm.mixture_distribution, gm.event_shape.ndims) + x_out, x_pad_out = sess.run([x, x_pad]) + + self.assertAllEqual(x_pad_out.shape, [2, 2, 1]) + self.assertAllEqual(x_out.reshape([-1]), x_pad_out.reshape([-1])) + + +class _PadTest(object): + + def testNegAxisCorrectness(self): + x_ = np.float32([[1., 2, 3], + [4, 5, 6]]) + value_ = np.float32(0.25) + count_ = np.int32(2) + with self.test_session() as sess: + x = array_ops.placeholder_with_default( + x_, shape=x_.shape if self.is_static_shape else None) + value = (constant_op.constant(value_) if self.is_static_shape + else array_ops.placeholder_with_default(value_, shape=None)) + count = (constant_op.constant(count_) if self.is_static_shape + else array_ops.placeholder_with_default(count_, shape=None)) + + x0_front = distribution_util.pad( + x, axis=-2, value=value, count=count, front=True) + x0_back = distribution_util.pad( + x, axis=-2, count=count, back=True) + x0_both = distribution_util.pad( + x, axis=-2, value=value, front=True, back=True) + + if self.is_static_shape: + self.assertAllEqual([4, 3], x0_front.shape) + self.assertAllEqual([4, 3], x0_back.shape) + self.assertAllEqual([4, 3], x0_both.shape) + + [x0_front_, x0_back_, x0_both_] = sess.run([ + x0_front, x0_back, x0_both]) + + self.assertAllClose( + np.float32([[value_]*3, + [value_]*3, + [1, 2, 3], + [4, 5, 6]]), + x0_front_, atol=0., rtol=1e-6) + self.assertAllClose( + np.float32([[1, 2, 3], + [4, 5, 6], + [0.]*3, + [0.]*3]), + x0_back_, atol=0., rtol=1e-6) + self.assertAllClose( + np.float32([[value_]*3, + [1, 2, 3], + [4, 5, 6], + [value_]*3]), + x0_both_, atol=0., rtol=1e-6) + + def 
testPosAxisCorrectness(self): + x_ = np.float32([[1., 2, 3], + [4, 5, 6]]) + value_ = np.float32(0.25) + count_ = np.int32(2) + with self.test_session() as sess: + x = array_ops.placeholder_with_default( + x_, shape=x_.shape if self.is_static_shape else None) + value = (constant_op.constant(value_) if self.is_static_shape + else array_ops.placeholder_with_default(value_, shape=None)) + count = (constant_op.constant(count_) if self.is_static_shape + else array_ops.placeholder_with_default(count_, shape=None)) + + x1_front = distribution_util.pad( + x, axis=1, value=value, count=count, front=True) + x1_back = distribution_util.pad( + x, axis=1, count=count, back=True) + x1_both = distribution_util.pad( + x, axis=1, value=value, front=True, back=True) + + if self.is_static_shape: + self.assertAllEqual([2, 5], x1_front.shape) + self.assertAllEqual([2, 5], x1_back.shape) + self.assertAllEqual([2, 5], x1_both.shape) + + [x1_front_, x1_back_, x1_both_] = sess.run([ + x1_front, x1_back, x1_both]) + + self.assertAllClose( + np.float32([[value_]*2 + [1, 2, 3], + [value_]*2 + [4, 5, 6]]), + x1_front_, atol=0., rtol=1e-6) + self.assertAllClose( + np.float32([[1, 2, 3] + [0.]*2, + [4, 5, 6] + [0.]*2]), + x1_back_, atol=0., rtol=1e-6) + self.assertAllClose( + np.float32([[value_, 1, 2, 3, value_], + [value_, 4, 5, 6, value_]]), + x1_both_, atol=0., rtol=1e-6) + + +class PadStaticTest(_PadTest, test.TestCase): + + @property + def is_static_shape(self): + return True + + +class PadDynamicTest(_PadTest, test.TestCase): + + @property + def is_static_shape(self): + return False + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/half_normal_test.py b/tensorflow/contrib/distributions/python/kernel_tests/half_normal_test.py index a7571806f295af4566e57ac4a785bc8774fd31ab..a4e75660083dc2edd1759a3a54e221d9e8a268c3 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/half_normal_test.py +++ 
b/tensorflow/contrib/distributions/python/kernel_tests/half_normal_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import importlib import numpy as np +from tensorflow.contrib.distributions.python.ops import half_normal as hn_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -28,7 +29,6 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import variables -from tensorflow.contrib.distributions.python.ops import half_normal as hn_lib from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging @@ -200,7 +200,7 @@ class HalfNormalTest(test.TestCase): with self.test_session(): scale = np.array([[1.0, 2.0, 3.0]]) halfnorm = hn_lib.HalfNormal(scale=scale) - + # See https://en.wikipedia.org/wiki/Half-normal_distribution for the # entropy formula used here. expected_entropy = 0.5 * np.log(np.pi * scale ** 2.0 / 2.0) + 0.5 diff --git a/tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py b/tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ea3c86b5c0f42b64fc6e4e362cbcc162bccf74a2 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/kumaraswamy_test.py @@ -0,0 +1,388 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import importlib + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import kumaraswamy as kumaraswamy_lib +from tensorflow.python.client import session +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import random_seed +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging + + +def try_import(name): # pylint: disable=invalid-name + module = None + try: + module = importlib.import_module(name) + except ImportError as e: + tf_logging.warning("Could not import %s: %s" % (name, str(e))) + return module + + +special = try_import("scipy.special") +stats = try_import("scipy.stats") + + +def _kumaraswamy_mode(a, b): + a = np.asarray(a) + b = np.asarray(b) + return ((a - 1) / (a * b - 1))**(1 / a) + + +def _kumaraswamy_moment(a, b, n): + a = np.asarray(a) + b = np.asarray(b) + return b * special.beta(1.0 + n / a, b) + + +def _harmonic_number(b): + b = np.asarray(b) + return special.psi(b + 1) - special.psi(1) + + +def _kumaraswamy_cdf(a, b, x): + a = np.asarray(a) + b = np.asarray(b) + x = np.asarray(x) + return 1 - (1 - x**a)**b + + +def _kumaraswamy_pdf(a, b, x): + a = np.asarray(a) + b = np.asarray(b) + x = np.asarray(x) + return a * b * x ** (a - 1) * (1 - x ** a) ** (b - 1) + + +class KumaraswamyTest(test.TestCase): + + def testSimpleShapes(self): + with self.test_session(): + a = np.random.rand(3) + b = np.random.rand(3) + dist = kumaraswamy_lib.Kumaraswamy(a, b) + self.assertAllEqual([], dist.event_shape_tensor().eval()) + self.assertAllEqual([3], 
dist.batch_shape_tensor().eval()) + self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape) + self.assertEqual(tensor_shape.TensorShape([3]), dist.batch_shape) + + def testComplexShapes(self): + with self.test_session(): + a = np.random.rand(3, 2, 2) + b = np.random.rand(3, 2, 2) + dist = kumaraswamy_lib.Kumaraswamy(a, b) + self.assertAllEqual([], dist.event_shape_tensor().eval()) + self.assertAllEqual([3, 2, 2], dist.batch_shape_tensor().eval()) + self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape) + self.assertEqual(tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape) + + def testComplexShapesBroadcast(self): + with self.test_session(): + a = np.random.rand(3, 2, 2) + b = np.random.rand(2, 2) + dist = kumaraswamy_lib.Kumaraswamy(a, b) + self.assertAllEqual([], dist.event_shape_tensor().eval()) + self.assertAllEqual([3, 2, 2], dist.batch_shape_tensor().eval()) + self.assertEqual(tensor_shape.TensorShape([]), dist.event_shape) + self.assertEqual(tensor_shape.TensorShape([3, 2, 2]), dist.batch_shape) + + def testAProperty(self): + a = [[1., 2, 3]] + b = [[2., 4, 3]] + with self.test_session(): + dist = kumaraswamy_lib.Kumaraswamy(a, b) + self.assertEqual([1, 3], dist.concentration1.get_shape()) + self.assertAllClose(a, dist.concentration1.eval()) + + def testBProperty(self): + a = [[1., 2, 3]] + b = [[2., 4, 3]] + with self.test_session(): + dist = kumaraswamy_lib.Kumaraswamy(a, b) + self.assertEqual([1, 3], dist.concentration0.get_shape()) + self.assertAllClose(b, dist.concentration0.eval()) + + def testPdfXProper(self): + a = [[1., 2, 3]] + b = [[2., 4, 3]] + with self.test_session(): + dist = kumaraswamy_lib.Kumaraswamy(a, b, validate_args=True) + dist.prob([.1, .3, .6]).eval() + dist.prob([.2, .3, .5]).eval() + # Either condition can trigger. 
+ with self.assertRaisesOpError("sample must be positive"): + dist.prob([-1., 0.1, 0.5]).eval() + with self.assertRaisesOpError("sample must be positive"): + dist.prob([0., 0.1, 0.5]).eval() + with self.assertRaisesOpError("sample must be no larger than `1`"): + dist.prob([.1, .2, 1.2]).eval() + + def testPdfTwoBatches(self): + with self.test_session(): + a = [1., 2] + b = [1., 2] + x = [.5, .5] + dist = kumaraswamy_lib.Kumaraswamy(a, b) + pdf = dist.prob(x) + expected_pdf = _kumaraswamy_pdf(a, b, x) + self.assertAllClose(expected_pdf, pdf.eval()) + self.assertEqual((2,), pdf.get_shape()) + + def testPdfTwoBatchesNontrivialX(self): + with self.test_session(): + a = [1., 2] + b = [1., 2] + x = [.3, .7] + dist = kumaraswamy_lib.Kumaraswamy(a, b) + pdf = dist.prob(x) + expected_pdf = _kumaraswamy_pdf(a, b, x) + self.assertAllClose(expected_pdf, pdf.eval()) + self.assertEqual((2,), pdf.get_shape()) + + def testPdfUniformZeroBatch(self): + with self.test_session(): + # This is equivalent to a uniform distribution + a = 1. + b = 1. 
+ x = np.array([.1, .2, .3, .5, .8], dtype=np.float32) + dist = kumaraswamy_lib.Kumaraswamy(a, b) + pdf = dist.prob(x) + expected_pdf = _kumaraswamy_pdf(a, b, x) + self.assertAllClose(expected_pdf, pdf.eval()) + self.assertEqual((5,), pdf.get_shape()) + + def testPdfAStretchedInBroadcastWhenSameRank(self): + with self.test_session(): + a = [[1., 2]] + b = [[1., 2]] + x = [[.5, .5], [.3, .7]] + dist = kumaraswamy_lib.Kumaraswamy(a, b) + pdf = dist.prob(x) + expected_pdf = _kumaraswamy_pdf(a, b, x) + self.assertAllClose(expected_pdf, pdf.eval()) + self.assertEqual((2, 2), pdf.get_shape()) + + def testPdfAStretchedInBroadcastWhenLowerRank(self): + with self.test_session(): + a = [1., 2] + b = [1., 2] + x = [[.5, .5], [.2, .8]] + pdf = kumaraswamy_lib.Kumaraswamy(a, b).prob(x) + expected_pdf = _kumaraswamy_pdf(a, b, x) + self.assertAllClose(expected_pdf, pdf.eval()) + self.assertEqual((2, 2), pdf.get_shape()) + + def testPdfXStretchedInBroadcastWhenSameRank(self): + with self.test_session(): + a = [[1., 2], [2., 3]] + b = [[1., 2], [2., 3]] + x = [[.5, .5]] + pdf = kumaraswamy_lib.Kumaraswamy(a, b).prob(x) + expected_pdf = _kumaraswamy_pdf(a, b, x) + self.assertAllClose(expected_pdf, pdf.eval()) + self.assertEqual((2, 2), pdf.get_shape()) + + def testPdfXStretchedInBroadcastWhenLowerRank(self): + with self.test_session(): + a = [[1., 2], [2., 3]] + b = [[1., 2], [2., 3]] + x = [.5, .5] + pdf = kumaraswamy_lib.Kumaraswamy(a, b).prob(x) + expected_pdf = _kumaraswamy_pdf(a, b, x) + self.assertAllClose(expected_pdf, pdf.eval()) + self.assertEqual((2, 2), pdf.get_shape()) + + def testKumaraswamyMean(self): + with session.Session(): + a = [1., 2, 3] + b = [2., 4, 1.2] + dist = kumaraswamy_lib.Kumaraswamy(a, b) + self.assertEqual(dist.mean().get_shape(), (3,)) + if not stats: + return + expected_mean = _kumaraswamy_moment(a, b, 1) + self.assertAllClose(expected_mean, dist.mean().eval()) + + def testKumaraswamyVariance(self): + with session.Session(): + a = [1., 2, 3] + b = 
[2., 4, 1.2] + dist = kumaraswamy_lib.Kumaraswamy(a, b) + self.assertEqual(dist.variance().get_shape(), (3,)) + if not stats: + return + expected_variance = _kumaraswamy_moment(a, b, 2) - _kumaraswamy_moment( + a, b, 1)**2 + self.assertAllClose(expected_variance, dist.variance().eval()) + + def testKumaraswamyMode(self): + with session.Session(): + a = np.array([1.1, 2, 3]) + b = np.array([2., 4, 1.2]) + expected_mode = _kumaraswamy_mode(a, b) + dist = kumaraswamy_lib.Kumaraswamy(a, b) + self.assertEqual(dist.mode().get_shape(), (3,)) + self.assertAllClose(expected_mode, dist.mode().eval()) + + def testKumaraswamyModeInvalid(self): + with session.Session(): + a = np.array([1., 2, 3]) + b = np.array([2., 4, 1.2]) + dist = kumaraswamy_lib.Kumaraswamy(a, b, allow_nan_stats=False) + with self.assertRaisesOpError("Condition x < y.*"): + dist.mode().eval() + + a = np.array([2., 2, 3]) + b = np.array([1., 4, 1.2]) + dist = kumaraswamy_lib.Kumaraswamy(a, b, allow_nan_stats=False) + with self.assertRaisesOpError("Condition x < y.*"): + dist.mode().eval() + + def testKumaraswamyModeEnableAllowNanStats(self): + with session.Session(): + a = np.array([1., 2, 3]) + b = np.array([2., 4, 1.2]) + dist = kumaraswamy_lib.Kumaraswamy(a, b, allow_nan_stats=True) + + expected_mode = _kumaraswamy_mode(a, b) + expected_mode[0] = np.nan + self.assertEqual((3,), dist.mode().get_shape()) + self.assertAllClose(expected_mode, dist.mode().eval()) + + a = np.array([2., 2, 3]) + b = np.array([1., 4, 1.2]) + dist = kumaraswamy_lib.Kumaraswamy(a, b, allow_nan_stats=True) + + expected_mode = _kumaraswamy_mode(a, b) + expected_mode[0] = np.nan + self.assertEqual((3,), dist.mode().get_shape()) + self.assertAllClose(expected_mode, dist.mode().eval()) + + def testKumaraswamyEntropy(self): + with session.Session(): + a = np.array([1., 2, 3]) + b = np.array([2., 4, 1.2]) + dist = kumaraswamy_lib.Kumaraswamy(a, b) + self.assertEqual(dist.entropy().get_shape(), (3,)) + if not stats: + return + 
expected_entropy = (1 - 1. / a) + ( + 1 - 1. / b) * _harmonic_number(b) + np.log(a * b) + self.assertAllClose(expected_entropy, dist.entropy().eval()) + + def testKumaraswamySample(self): + with self.test_session(): + a = 1. + b = 2. + kumaraswamy = kumaraswamy_lib.Kumaraswamy(a, b) + n = constant_op.constant(100000) + samples = kumaraswamy.sample(n) + sample_values = samples.eval() + self.assertEqual(sample_values.shape, (100000,)) + self.assertFalse(np.any(sample_values < 0.0)) + if not stats: + return + self.assertLess( + stats.kstest( + # Kumaraswamy is a univariate distribution. + sample_values, + lambda x: _kumaraswamy_cdf(1., 2., x))[0], + 0.01) + # The standard error of the sample mean is 1 / (sqrt(18 * n)) + expected_mean = _kumaraswamy_moment(a, b, 1) + self.assertAllClose(sample_values.mean(axis=0), expected_mean, atol=1e-2) + expected_variance = _kumaraswamy_moment(a, b, 2) - _kumaraswamy_moment( + a, b, 1)**2 + self.assertAllClose( + np.cov(sample_values, rowvar=0), expected_variance, atol=1e-1) + + # Test that sampling with the same seed twice gives the same results. + def testKumaraswamySampleMultipleTimes(self): + with self.test_session(): + a_val = 1. + b_val = 2. 
+ n_val = 100 + + random_seed.set_random_seed(654321) + kumaraswamy1 = kumaraswamy_lib.Kumaraswamy( + concentration1=a_val, concentration0=b_val, name="kumaraswamy1") + samples1 = kumaraswamy1.sample(n_val, seed=123456).eval() + + random_seed.set_random_seed(654321) + kumaraswamy2 = kumaraswamy_lib.Kumaraswamy( + concentration1=a_val, concentration0=b_val, name="kumaraswamy2") + samples2 = kumaraswamy2.sample(n_val, seed=123456).eval() + + self.assertAllClose(samples1, samples2) + + def testKumaraswamySampleMultidimensional(self): + with self.test_session(): + a = np.random.rand(3, 2, 2).astype(np.float32) + b = np.random.rand(3, 2, 2).astype(np.float32) + kumaraswamy = kumaraswamy_lib.Kumaraswamy(a, b) + n = constant_op.constant(100000) + samples = kumaraswamy.sample(n) + sample_values = samples.eval() + self.assertEqual(sample_values.shape, (100000, 3, 2, 2)) + self.assertFalse(np.any(sample_values < 0.0)) + if not stats: + return + self.assertAllClose( + sample_values[:, 1, :].mean(axis=0), + _kumaraswamy_moment(a, b, 1)[1, :], + atol=1e-1) + + def testKumaraswamyCdf(self): + with self.test_session(): + shape = (30, 40, 50) + for dt in (np.float32, np.float64): + a = 10. * np.random.random(shape).astype(dt) + b = 10. * np.random.random(shape).astype(dt) + x = np.random.random(shape).astype(dt) + actual = kumaraswamy_lib.Kumaraswamy(a, b).cdf(x).eval() + self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x) + self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x) + if not stats: + return + self.assertAllClose( + _kumaraswamy_cdf(a, b, x), actual, rtol=1e-4, atol=0) + + def testKumaraswamyLogCdf(self): + with self.test_session(): + shape = (30, 40, 50) + for dt in (np.float32, np.float64): + a = 10. * np.random.random(shape).astype(dt) + b = 10. 
* np.random.random(shape).astype(dt) + x = np.random.random(shape).astype(dt) + actual = math_ops.exp(kumaraswamy_lib.Kumaraswamy(a, + b).log_cdf(x)).eval() + self.assertAllEqual(np.ones(shape, dtype=np.bool), 0. <= x) + self.assertAllEqual(np.ones(shape, dtype=np.bool), 1. >= x) + if not stats: + return + self.assertAllClose( + _kumaraswamy_cdf(a, b, x), actual, rtol=1e-4, atol=0) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py index ece6bc077d9e21502fdfd01300a9d3e9f2c9c380..ff6092fc260660b512e8123823c63e98a023af6d 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py @@ -45,6 +45,17 @@ class MixtureSameFamilyTest(test_util.VectorDistributionTestHelpers, self.assertEqual([4, 5], x.shape) self.assertEqual([4, 5], log_prob_x.shape) + def testSampleAndLogProbBatch(self): + with self.test_session(): + gm = mixture_same_family_lib.MixtureSameFamily( + mixture_distribution=categorical_lib.Categorical(probs=[[0.3, 0.7]]), + components_distribution=normal_lib.Normal( + loc=[[-1., 1]], scale=[[0.1, 0.5]])) + x = gm.sample([4, 5], seed=42) + log_prob_x = gm.log_prob(x) + self.assertEqual([4, 5, 1], x.shape) + self.assertEqual([4, 5, 1], log_prob_x.shape) + def testSampleAndLogProbShapesBroadcastMix(self): mix_probs = np.float32([.3, .7]) bern_probs = np.float32([[.4, .6], [.25, .75]]) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py index 1e514fe0ff21cd53c8c235da417890773db50c37..02064891758a86c5108e11da6a3666f2d5c56c64 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py 
@@ -107,7 +107,7 @@ def _test_capture_normal_sample_outputs(): ds.Normal._call_sample_n = true_normal_call_sample_n -def make_univariate_mixture(batch_shape, num_components): +def make_univariate_mixture(batch_shape, num_components, use_static_graph): batch_shape = ops.convert_to_tensor(batch_shape, dtypes.int32) logits = random_ops.random_uniform( array_ops.concat((batch_shape, [num_components]), axis=0), @@ -119,11 +119,11 @@ def make_univariate_mixture(batch_shape, num_components): for _ in range(num_components) ] cat = ds.Categorical(logits, dtype=dtypes.int32) - return ds.Mixture(cat, components) + return ds.Mixture(cat, components, use_static_graph=use_static_graph) def make_multivariate_mixture(batch_shape, num_components, event_shape, - batch_shape_tensor=None): + use_static_graph, batch_shape_tensor=None): if batch_shape_tensor is None: batch_shape_tensor = batch_shape batch_shape_tensor = ops.convert_to_tensor(batch_shape_tensor, dtypes.int32) @@ -145,15 +145,17 @@ def make_multivariate_mixture(batch_shape, num_components, event_shape, loc=loc, scale_diag=scale_diag) components = [create_component() for _ in range(num_components)] cat = ds.Categorical(logits, dtype=dtypes.int32) - return ds.Mixture(cat, components) + return ds.Mixture(cat, components, use_static_graph=use_static_graph) class MixtureTest(test.TestCase): + use_static_graph = False def testShapes(self): with self.test_session(): for batch_shape in ([], [1], [2, 3, 4]): - dist = make_univariate_mixture(batch_shape, num_components=10) + dist = make_univariate_mixture(batch_shape, num_components=10, + use_static_graph=self.use_static_graph) self.assertAllEqual(batch_shape, dist.batch_shape) self.assertAllEqual(batch_shape, dist.batch_shape_tensor().eval()) self.assertAllEqual([], dist.event_shape) @@ -161,7 +163,8 @@ class MixtureTest(test.TestCase): for event_shape in ([1], [2]): dist = make_multivariate_mixture( - batch_shape, num_components=10, event_shape=event_shape) + batch_shape, 
num_components=10, event_shape=event_shape, + use_static_graph=self.use_static_graph) self.assertAllEqual(batch_shape, dist.batch_shape) self.assertAllEqual(batch_shape, dist.batch_shape_tensor().eval()) self.assertAllEqual(event_shape, dist.event_shape) @@ -172,7 +175,8 @@ class MixtureTest(test.TestCase): r"cat.num_classes != len"): ds.Mixture( ds.Categorical([0.1, 0.5]), # 2 classes - [ds.Normal(loc=1.0, scale=2.0)]) + [ds.Normal(loc=1.0, scale=2.0)], + use_static_graph=self.use_static_graph) with self.assertRaisesWithPredicateMatch( ValueError, r"\(\) and \(2,\) are not compatible"): # The value error is raised because the batch shapes of the @@ -185,13 +189,15 @@ class MixtureTest(test.TestCase): loc=1.0, scale=2.0), # scalar dist ds.Normal( loc=[1.0, 1.0], scale=[2.0, 2.0]) - ]) + ], + use_static_graph=self.use_static_graph) with self.assertRaisesWithPredicateMatch(ValueError, r"Could not infer"): cat_logits = array_ops.placeholder(shape=[1, None], dtype=dtypes.float32) ds.Mixture( ds.Categorical(cat_logits), [ds.Normal( - loc=[1.0], scale=[2.0])]) + loc=[1.0], scale=[2.0])], + use_static_graph=self.use_static_graph) def testBrokenShapesDynamic(self): with self.test_session(): @@ -203,29 +209,37 @@ class MixtureTest(test.TestCase): loc=d0_param, scale=d0_param), ds.Normal( loc=d1_param, scale=d1_param) ], - validate_args=True) - with self.assertRaisesOpError(r"batch shape must match"): + validate_args=True, + use_static_graph=self.use_static_graph) + + if self.use_static_graph: + error_string = r"Shapes of all inputs must match" + else: + error_string = r"batch shape must match" + + with self.assertRaisesOpError(error_string): d.sample().eval(feed_dict={d0_param: [2.0, 3.0], d1_param: [1.0]}) - with self.assertRaisesOpError(r"batch shape must match"): + with self.assertRaisesOpError(error_string): d.sample().eval(feed_dict={d0_param: [2.0, 3.0], d1_param: 1.0}) def testBrokenTypes(self): with self.assertRaisesWithPredicateMatch(TypeError, "Categorical"): - 
ds.Mixture(None, []) + ds.Mixture(None, [], use_static_graph=self.use_static_graph) cat = ds.Categorical([0.3, 0.2]) # components must be a list of distributions with self.assertRaisesWithPredicateMatch( TypeError, "all .* must be Distribution instances"): - ds.Mixture(cat, [None]) + ds.Mixture(cat, [None], use_static_graph=self.use_static_graph) with self.assertRaisesWithPredicateMatch(TypeError, "same dtype"): ds.Mixture( cat, [ ds.Normal(loc=[1.0], scale=[2.0]), ds.Normal(loc=[np.float16(1.0)], scale=[np.float16(2.0)]), - ]) + ], use_static_graph=self.use_static_graph) with self.assertRaisesWithPredicateMatch(ValueError, "non-empty list"): - ds.Mixture(ds.Categorical([0.3, 0.2]), None) + ds.Mixture(ds.Categorical([0.3, 0.2]), None, + use_static_graph=self.use_static_graph) # TODO(ebrevdo): once distribution Domains have been added, add a # test to ensure that the domains of the distributions in a @@ -235,7 +249,8 @@ class MixtureTest(test.TestCase): with self.test_session() as sess: for batch_shape in ((), (2,), (2, 3)): dist = make_univariate_mixture( - batch_shape=batch_shape, num_components=2) + batch_shape=batch_shape, num_components=2, + use_static_graph=self.use_static_graph) mean = dist.mean() self.assertEqual(batch_shape, mean.get_shape()) @@ -256,7 +271,8 @@ class MixtureTest(test.TestCase): with self.test_session() as sess: for batch_shape in ((), (2,), (2, 3)): dist = make_multivariate_mixture( - batch_shape=batch_shape, num_components=2, event_shape=(4,)) + batch_shape=batch_shape, num_components=2, event_shape=(4,), + use_static_graph=self.use_static_graph) mean = dist.mean() self.assertEqual(batch_shape + (4,), mean.get_shape()) @@ -283,7 +299,8 @@ class MixtureTest(test.TestCase): with self.test_session() as sess: for batch_shape in ((), (2,), (2, 3)): dist = make_univariate_mixture( - batch_shape=batch_shape, num_components=num_components) + batch_shape=batch_shape, num_components=num_components, + use_static_graph=self.use_static_graph) dev = 
dist.stddev() self.assertEqual(batch_shape, dev.get_shape()) @@ -325,7 +342,8 @@ class MixtureTest(test.TestCase): dist = make_multivariate_mixture( batch_shape=batch_shape, num_components=num_components, - event_shape=(4,)) + event_shape=(4,), + use_static_graph=self.use_static_graph) dev = dist.stddev() self.assertEqual(batch_shape + (4,), dev.get_shape()) @@ -371,7 +389,8 @@ class MixtureTest(test.TestCase): scale=component_devs[0]), ds.Normal(loc=component_means[1], scale=component_devs[1]), - ]) + ], + use_static_graph=self.use_static_graph) mix_dev = mixture_dist.stddev() with self.test_session() as sess: actual_stddev = sess.run(mix_dev) @@ -379,7 +398,8 @@ class MixtureTest(test.TestCase): def testProbScalarUnivariate(self): with self.test_session() as sess: - dist = make_univariate_mixture(batch_shape=[], num_components=2) + dist = make_univariate_mixture(batch_shape=[], num_components=2, + use_static_graph=self.use_static_graph) for x in [ np.array( [1.0, 2.0], dtype=np.float32), np.array( @@ -405,7 +425,8 @@ class MixtureTest(test.TestCase): def testProbScalarMultivariate(self): with self.test_session() as sess: dist = make_multivariate_mixture( - batch_shape=[], num_components=2, event_shape=[3]) + batch_shape=[], num_components=2, event_shape=[3], + use_static_graph=self.use_static_graph) for x in [ np.array( [[-1.0, 0.0, 1.0], [0.5, 1.0, -0.3]], dtype=np.float32), np.array( @@ -432,7 +453,8 @@ class MixtureTest(test.TestCase): def testProbBatchUnivariate(self): with self.test_session() as sess: - dist = make_univariate_mixture(batch_shape=[2, 3], num_components=2) + dist = make_univariate_mixture(batch_shape=[2, 3], num_components=2, + use_static_graph=self.use_static_graph) for x in [ np.random.randn(2, 3).astype(np.float32), @@ -459,7 +481,8 @@ class MixtureTest(test.TestCase): def testProbBatchMultivariate(self): with self.test_session() as sess: dist = make_multivariate_mixture( - batch_shape=[2, 3], num_components=2, event_shape=[4]) + 
batch_shape=[2, 3], num_components=2, event_shape=[4], + use_static_graph=self.use_static_graph) for x in [ np.random.randn(2, 3, 4).astype(np.float32), @@ -487,7 +510,8 @@ class MixtureTest(test.TestCase): num_components = 3 batch_shape = [] dist = make_univariate_mixture( - batch_shape=batch_shape, num_components=num_components) + batch_shape=batch_shape, num_components=num_components, + use_static_graph=self.use_static_graph) n = 4 with _test_capture_normal_sample_outputs() as component_samples: samples = dist.sample(n, seed=123) @@ -502,7 +526,10 @@ class MixtureTest(test.TestCase): which_c = np.where(cat_sample_values == c)[0] size_c = which_c.size # Scalar Batch univariate case: batch_size == 1, rank 1 - which_dist_samples = dist_sample_values[c][:size_c] + if self.use_static_graph: + which_dist_samples = dist_sample_values[c][which_c] + else: + which_dist_samples = dist_sample_values[c][:size_c] self.assertAllClose(which_dist_samples, sample_values[which_c]) # Test that sampling with the same seed twice gives the same results. 
@@ -522,7 +549,8 @@ class MixtureTest(test.TestCase): ] cat = ds.Categorical( logits, dtype=dtypes.int32, name="cat1") - dist1 = ds.Mixture(cat, components, name="mixture1") + dist1 = ds.Mixture(cat, components, name="mixture1", + use_static_graph=self.use_static_graph) samples1 = dist1.sample(n, seed=123456).eval() random_seed.set_random_seed(654321) @@ -532,7 +560,8 @@ class MixtureTest(test.TestCase): ] cat2 = ds.Categorical( logits, dtype=dtypes.int32, name="cat2") - dist2 = ds.Mixture(cat2, components2, name="mixture2") + dist2 = ds.Mixture(cat2, components2, name="mixture2", + use_static_graph=self.use_static_graph) samples2 = dist2.sample(n, seed=123456).eval() self.assertAllClose(samples1, samples2) @@ -541,7 +570,8 @@ class MixtureTest(test.TestCase): with self.test_session() as sess: num_components = 3 dist = make_multivariate_mixture( - batch_shape=[], num_components=num_components, event_shape=[2]) + batch_shape=[], num_components=num_components, event_shape=[2], + use_static_graph=self.use_static_graph) n = 4 with _test_capture_mvndiag_sample_outputs() as component_samples: samples = dist.sample(n, seed=123) @@ -555,14 +585,18 @@ class MixtureTest(test.TestCase): which_c = np.where(cat_sample_values == c)[0] size_c = which_c.size # Scalar Batch multivariate case: batch_size == 1, rank 2 - which_dist_samples = dist_sample_values[c][:size_c, :] + if self.use_static_graph: + which_dist_samples = dist_sample_values[c][which_c, :] + else: + which_dist_samples = dist_sample_values[c][:size_c, :] self.assertAllClose(which_dist_samples, sample_values[which_c, :]) def testSampleBatchUnivariate(self): with self.test_session() as sess: num_components = 3 dist = make_univariate_mixture( - batch_shape=[2, 3], num_components=num_components) + batch_shape=[2, 3], num_components=num_components, + use_static_graph=self.use_static_graph) n = 4 with _test_capture_normal_sample_outputs() as component_samples: samples = dist.sample(n, seed=123) @@ -576,8 +610,12 @@ class 
MixtureTest(test.TestCase): which_c_s, which_c_b0, which_c_b1 = np.where(cat_sample_values == c) size_c = which_c_s.size # Batch univariate case: batch_size == [2, 3], rank 3 - which_dist_samples = dist_sample_values[c][range(size_c), which_c_b0, - which_c_b1] + if self.use_static_graph: + which_dist_samples = dist_sample_values[c][which_c_s, which_c_b0, + which_c_b1] + else: + which_dist_samples = dist_sample_values[c][range(size_c), which_c_b0, + which_c_b1] self.assertAllClose(which_dist_samples, sample_values[which_c_s, which_c_b0, which_c_b1]) @@ -594,7 +632,8 @@ class MixtureTest(test.TestCase): dist = make_multivariate_mixture( batch_shape=batch_shape, num_components=num_components, event_shape=[4], - batch_shape_tensor=batch_shape_tensor) + batch_shape_tensor=batch_shape_tensor, + use_static_graph=self.use_static_graph) n = 5 with _test_capture_mvndiag_sample_outputs() as component_samples: samples = dist.sample(n, seed=123) @@ -617,8 +656,12 @@ class MixtureTest(test.TestCase): which_c_s, which_c_b0, which_c_b1 = np.where(cat_sample_values == c) size_c = which_c_s.size # Batch univariate case: batch_size == [2, 3], rank 4 (multivariate) - which_dist_samples = dist_sample_values[c][range(size_c), which_c_b0, - which_c_b1, :] + if self.use_static_graph: + which_dist_samples = dist_sample_values[c][which_c_s, which_c_b0, + which_c_b1, :] + else: + which_dist_samples = dist_sample_values[c][range(size_c), which_c_b0, + which_c_b1, :] self.assertAllClose(which_dist_samples, sample_values[which_c_s, which_c_b0, which_c_b1, :]) @@ -632,7 +675,8 @@ class MixtureTest(test.TestCase): with self.test_session() as sess: for batch_shape in ((), (2,), (2, 3)): dist = make_multivariate_mixture( - batch_shape=batch_shape, num_components=2, event_shape=(4,)) + batch_shape=batch_shape, num_components=2, event_shape=(4,), + use_static_graph=self.use_static_graph) entropy_lower_bound = dist.entropy_lower_bound() self.assertEqual(batch_shape, entropy_lower_bound.get_shape()) @@ 
-673,7 +717,8 @@ class MixtureTest(test.TestCase): cat_tf = ds.Categorical(probs=mixture_weights) components_tf = [ds.Normal(loc=mu, scale=sigma) for (mu, sigma) in zip(means, sigmas)] - mixture_tf = ds.Mixture(cat=cat_tf, components=components_tf) + mixture_tf = ds.Mixture(cat=cat_tf, components=components_tf, + use_static_graph=self.use_static_graph) x_tensor = array_ops.placeholder(shape=(), dtype=dtypes.float32) @@ -721,7 +766,8 @@ class MixtureTest(test.TestCase): cat_tf = ds.Categorical(probs=mixture_weights) components_tf = [ds.Normal(loc=mu, scale=sigma) for (mu, sigma) in zip(means, sigmas)] - mixture_tf = ds.Mixture(cat=cat_tf, components=components_tf) + mixture_tf = ds.Mixture(cat=cat_tf, components=components_tf, + use_static_graph=self.use_static_graph) x_tensor = array_ops.placeholder(shape=psize, dtype=dtypes.float32) xs_to_check = [ @@ -760,12 +806,18 @@ class MixtureTest(test.TestCase): gm = ds.Mixture( cat=ds.Categorical(probs=[.3, .7]), components=[ds.Gamma(1., 2.), - ds.Gamma(2., 1.)]) + ds.Gamma(2., 1.)], + use_static_graph=self.use_static_graph) x_ = gm.sample().eval() self.assertAllEqual([], x_.shape) +class MixtureStaticSampleTest(MixtureTest): + use_static_graph = True + + class MixtureBenchmark(test.Benchmark): + use_static_graph = False def _runSamplingBenchmark(self, name, create_distribution, use_gpu, num_components, batch_size, num_features, @@ -811,7 +863,7 @@ class MixtureBenchmark(test.Benchmark): components = list( ds.MultivariateNormalDiag( loc=mu, scale_diag=sigma) for (mu, sigma) in zip(mus, sigmas)) - return ds.Mixture(cat, components) + return ds.Mixture(cat, components, use_static_graph=self.use_static_graph) for use_gpu in False, True: if use_gpu and not test.is_gpu_available(): @@ -853,7 +905,7 @@ class MixtureBenchmark(test.Benchmark): ds.MultivariateNormalTriL( loc=mu, scale_tril=linalg_ops.cholesky(sigma)) for (mu, sigma) in zip(mus, sigmas)) - return ds.Mixture(cat, components) + return ds.Mixture(cat, components, 
use_static_graph=self.use_static_graph) for use_gpu in False, True: if use_gpu and not test.is_gpu_available(): @@ -872,5 +924,9 @@ class MixtureBenchmark(test.Benchmark): sample_size=sample_size) +class MixtureStaticSampleBenchmark(MixtureBenchmark): + use_static_graph = True + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py b/tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py index 3c0147b8cf6e1b6a2791e85c0c0997992445fa7e..1035cb00f76d95c7c52c3e812e8bb2868d34b890 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py @@ -18,37 +18,40 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - from tensorflow.contrib.distributions.python.ops import poisson_lognormal from tensorflow.contrib.distributions.python.ops import test_util -from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.platform import test -class PoissonLogNormalQuadratureCompoundTest( - test_util.DiscreteScalarDistributionTestHelpers, test.TestCase): +class _PoissonLogNormalQuadratureCompoundTest( + test_util.DiscreteScalarDistributionTestHelpers): """Tests the PoissonLogNormalQuadratureCompoundTest distribution.""" def testSampleProbConsistent(self): with self.test_session() as sess: pln = poisson_lognormal.PoissonLogNormalQuadratureCompound( - loc=-2., - scale=1.1, - quadrature_grid_and_probs=( - np.polynomial.hermite.hermgauss(deg=10)), + loc=array_ops.placeholder_with_default( + -2., + shape=[] if self.static_shape else None), + scale=array_ops.placeholder_with_default( + 1.1, + shape=[] if self.static_shape else None), + quadrature_size=10, validate_args=True) self.run_test_sample_consistent_log_prob( - sess.run, pln, rtol=0.1) + 
sess.run, pln, batch_size=1, rtol=0.1) def testMeanVariance(self): with self.test_session() as sess: pln = poisson_lognormal.PoissonLogNormalQuadratureCompound( - loc=0., - scale=1., - quadrature_grid_and_probs=( - np.polynomial.hermite.hermgauss(deg=10)), + loc=array_ops.placeholder_with_default( + 0., + shape=[] if self.static_shape else None), + scale=array_ops.placeholder_with_default( + 1., + shape=[] if self.static_shape else None), + quadrature_size=10, validate_args=True) self.run_test_sample_consistent_mean_variance( sess.run, pln, rtol=0.02) @@ -56,21 +59,27 @@ class PoissonLogNormalQuadratureCompoundTest( def testSampleProbConsistentBroadcastScalar(self): with self.test_session() as sess: pln = poisson_lognormal.PoissonLogNormalQuadratureCompound( - loc=[0., -0.5], - scale=1., - quadrature_grid_and_probs=( - np.polynomial.hermite.hermgauss(deg=10)), + loc=array_ops.placeholder_with_default( + [0., -0.5], + shape=[2] if self.static_shape else None), + scale=array_ops.placeholder_with_default( + 1., + shape=[] if self.static_shape else None), + quadrature_size=10, validate_args=True) self.run_test_sample_consistent_log_prob( - sess.run, pln, rtol=0.1, atol=0.01) + sess.run, pln, batch_size=2, rtol=0.1, atol=0.01) def testMeanVarianceBroadcastScalar(self): with self.test_session() as sess: pln = poisson_lognormal.PoissonLogNormalQuadratureCompound( - loc=[0., -0.5], - scale=1., - quadrature_grid_and_probs=( - np.polynomial.hermite.hermgauss(deg=10)), + loc=array_ops.placeholder_with_default( + [0., -0.5], + shape=[2] if self.static_shape else None), + scale=array_ops.placeholder_with_default( + 1., + shape=[] if self.static_shape else None), + quadrature_size=10, validate_args=True) self.run_test_sample_consistent_mean_variance( sess.run, pln, rtol=0.1, atol=0.01) @@ -78,38 +87,46 @@ class PoissonLogNormalQuadratureCompoundTest( def testSampleProbConsistentBroadcastBoth(self): with self.test_session() as sess: pln = 
poisson_lognormal.PoissonLogNormalQuadratureCompound( - loc=[[0.], [-0.5]], - scale=[[1., 0.9]], - quadrature_grid_and_probs=( - np.polynomial.hermite.hermgauss(deg=10)), + loc=array_ops.placeholder_with_default( + [[0.], [-0.5]], + shape=[2, 1] if self.static_shape else None), + scale=array_ops.placeholder_with_default( + [[1., 0.9]], + shape=[1, 2] if self.static_shape else None), + quadrature_size=10, validate_args=True) self.run_test_sample_consistent_log_prob( - sess.run, pln, rtol=0.1, atol=0.08) + sess.run, pln, batch_size=4, rtol=0.1, atol=0.08) def testMeanVarianceBroadcastBoth(self): with self.test_session() as sess: pln = poisson_lognormal.PoissonLogNormalQuadratureCompound( - loc=[[0.], [-0.5]], - scale=[[1., 0.9]], - quadrature_grid_and_probs=( - np.polynomial.hermite.hermgauss(deg=10)), + loc=array_ops.placeholder_with_default( + [[0.], [-0.5]], + shape=[2, 1] if self.static_shape else None), + scale=array_ops.placeholder_with_default( + [[1., 0.9]], + shape=[1, 2] if self.static_shape else None), + quadrature_size=10, validate_args=True) self.run_test_sample_consistent_mean_variance( sess.run, pln, rtol=0.1, atol=0.01) - def testSampleProbConsistentDynamicQuadrature(self): - with self.test_session() as sess: - qgrid = array_ops.placeholder(dtype=dtypes.float32) - qprobs = array_ops.placeholder(dtype=dtypes.float32) - g, p = np.polynomial.hermite.hermgauss(deg=10) - pln = poisson_lognormal.PoissonLogNormalQuadratureCompound( - loc=-2., - scale=1.1, - quadrature_grid_and_probs=(g, p), - validate_args=True) - self.run_test_sample_consistent_log_prob( - lambda x: sess.run(x, feed_dict={qgrid: g, qprobs: p}), - pln, rtol=0.1) + +class PoissonLogNormalQuadratureCompoundStaticShapeTest( + _PoissonLogNormalQuadratureCompoundTest, test.TestCase): + + @property + def static_shape(self): + return True + + +class PoissonLogNormalQuadratureCompoundDynamicShapeTest( + _PoissonLogNormalQuadratureCompoundTest, test.TestCase): + + @property + def static_shape(self): 
+ return False if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/sample_stats_test.py b/tensorflow/contrib/distributions/python/kernel_tests/sample_stats_test.py index 595d9f5df755d7defa63d385039bafe4f87aa6ec..4186cf129dbf31724c84133734da3f226817c71a 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/sample_stats_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/sample_stats_test.py @@ -23,11 +23,244 @@ import numpy as np from tensorflow.contrib.distributions.python.ops import sample_stats from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops +from tensorflow.python.ops import spectral_ops_test_util from tensorflow.python.platform import test rng = np.random.RandomState(0) +class _AutoCorrelationTest(object): + + @property + def use_static_shape(self): + raise NotImplementedError("Subclass failed to implement `use_static_shape`") + + @property + def dtype(self): + raise NotImplementedError("Subclass failed to implement `dtype`.") + + def test_constant_sequence_axis_0_max_lags_none_center_false(self): + x_ = np.array([[0., 0., 0.], + [1., 1., 1.]]).astype(self.dtype) + x_ph = array_ops.placeholder_with_default( + input=x_, + shape=x_.shape if self.use_static_shape else None) + with spectral_ops_test_util.fft_kernel_label_map(): + with self.test_session() as sess: + # Setting normalize = True means we divide by zero. 
+ auto_corr = sample_stats.auto_correlation( + x_ph, axis=1, center=False, normalize=False) + if self.use_static_shape: + self.assertEqual((2, 3), auto_corr.shape) + auto_corr_ = sess.run(auto_corr) + self.assertAllClose( + [[0., 0., 0.], + [1., 1., 1.]], auto_corr_) + + def test_constant_sequence_axis_0_max_lags_none_center_true(self): + x_ = np.array([[0., 0., 0.], + [1., 1., 1.]]).astype(self.dtype) + x_ph = array_ops.placeholder_with_default( + input=x_, + shape=x_.shape if self.use_static_shape else None) + with spectral_ops_test_util.fft_kernel_label_map(): + with self.test_session() as sess: + # Setting normalize = True means we divide by zero. + auto_corr = sample_stats.auto_correlation( + x_ph, axis=1, normalize=False, center=True) + if self.use_static_shape: + self.assertEqual((2, 3), auto_corr.shape) + auto_corr_ = sess.run(auto_corr) + self.assertAllClose( + [[0., 0., 0.], + [0., 0., 0.]], auto_corr_) + + def check_results_versus_brute_force( + self, x, axis, max_lags, center, normalize): + """Compute auto-correlation by brute force, then compare to tf result.""" + # Brute for auto-corr -- avoiding fft and transpositions. 
+ axis_len = x.shape[axis] + if max_lags is None: + max_lags = axis_len - 1 + else: + max_lags = min(axis_len - 1, max_lags) + auto_corr_at_lag = [] + if center: + x -= x.mean(axis=axis, keepdims=True) + for m in range(max_lags + 1): + auto_corr_at_lag.append(( + np.take(x, indices=range(0, axis_len - m), axis=axis) * + np.conj(np.take(x, indices=range(m, axis_len), axis=axis)) + ).mean(axis=axis, keepdims=True)) + rxx = np.concatenate(auto_corr_at_lag, axis=axis) + if normalize: + rxx /= np.take(rxx, [0], axis=axis) + + x_ph = array_ops.placeholder_with_default( + x, shape=x.shape if self.use_static_shape else None) + with spectral_ops_test_util.fft_kernel_label_map(): + with self.test_session(): + auto_corr = sample_stats.auto_correlation( + x_ph, axis=axis, max_lags=max_lags, center=center, + normalize=normalize) + if self.use_static_shape: + output_shape = list(x.shape) + output_shape[axis] = max_lags + 1 + self.assertAllEqual(output_shape, auto_corr.shape) + self.assertAllClose(rxx, auto_corr.eval(), rtol=1e-5, atol=1e-5) + + def test_axis_n1_center_false_max_lags_none(self): + x = rng.randn(2, 3, 4).astype(self.dtype) + if self.dtype in [np.complex64]: + x = 1j * rng.randn(2, 3, 4).astype(self.dtype) + self.check_results_versus_brute_force( + x, axis=-1, max_lags=None, center=False, normalize=False) + + def test_axis_n2_center_false_max_lags_none(self): + x = rng.randn(3, 4, 5).astype(self.dtype) + if self.dtype in [np.complex64]: + x = 1j * rng.randn(3, 4, 5).astype(self.dtype) + self.check_results_versus_brute_force( + x, axis=-2, max_lags=None, center=False, normalize=False) + + def test_axis_n1_center_false_max_lags_none_normalize_true(self): + x = rng.randn(2, 3, 4).astype(self.dtype) + if self.dtype in [np.complex64]: + x = 1j * rng.randn(2, 3, 4).astype(self.dtype) + self.check_results_versus_brute_force( + x, axis=-1, max_lags=None, center=False, normalize=True) + + def test_axis_n2_center_false_max_lags_none_normalize_true(self): + x = rng.randn(3, 
4, 5).astype(self.dtype) + if self.dtype in [np.complex64]: + x = 1j * rng.randn(3, 4, 5).astype(self.dtype) + self.check_results_versus_brute_force( + x, axis=-2, max_lags=None, center=False, normalize=True) + + def test_axis_0_center_true_max_lags_none(self): + x = rng.randn(3, 4, 5).astype(self.dtype) + if self.dtype in [np.complex64]: + x = 1j * rng.randn(3, 4, 5).astype(self.dtype) + self.check_results_versus_brute_force( + x, axis=0, max_lags=None, center=True, normalize=False) + + def test_axis_2_center_true_max_lags_1(self): + x = rng.randn(3, 4, 5).astype(self.dtype) + if self.dtype in [np.complex64]: + x = 1j * rng.randn(3, 4, 5).astype(self.dtype) + self.check_results_versus_brute_force( + x, axis=2, max_lags=1, center=True, normalize=False) + + def test_axis_2_center_true_max_lags_100(self): + # There are less than 100 elements in axis 2, so expect we get back an array + # the same size as x, despite having asked for 100 lags. + x = rng.randn(3, 4, 5).astype(self.dtype) + if self.dtype in [np.complex64]: + x = 1j * rng.randn(3, 4, 5).astype(self.dtype) + self.check_results_versus_brute_force( + x, axis=2, max_lags=100, center=True, normalize=False) + + def test_long_orthonormal_sequence_has_corr_length_0(self): + l = 10000 + x = rng.randn(l).astype(self.dtype) + x_ph = array_ops.placeholder_with_default( + x, shape=(l,) if self.use_static_shape else None) + with spectral_ops_test_util.fft_kernel_label_map(): + with self.test_session(): + rxx = sample_stats.auto_correlation( + x_ph, max_lags=l // 2, center=True, normalize=False) + if self.use_static_shape: + self.assertAllEqual((l // 2 + 1,), rxx.shape) + rxx_ = rxx.eval() + # OSS CPU FFT has some accuracy issues is not the most accurate. + # So this tolerance is a bit bad. + self.assertAllClose(1., rxx_[0], rtol=0.05) + # The maximal error in the rest of the sequence is not great. 
+ self.assertAllClose(np.zeros(l // 2), rxx_[1:], atol=0.1) + # The mean error in the rest is ok, actually 0.008 when I tested it. + self.assertLess(np.abs(rxx_[1:]).mean(), 0.02) + + def test_step_function_sequence(self): + # x jumps to new random value every 10 steps. So correlation length = 10. + x = (rng.randint(-10, 10, size=(1000, 1)) + * np.ones((1, 10))).ravel().astype(self.dtype) + x_ph = array_ops.placeholder_with_default( + x, shape=(1000 * 10,) if self.use_static_shape else None) + with spectral_ops_test_util.fft_kernel_label_map(): + with self.test_session(): + rxx = sample_stats.auto_correlation( + x_ph, max_lags=1000 * 10 // 2, center=True, normalize=False) + if self.use_static_shape: + self.assertAllEqual((1000 * 10 // 2 + 1,), rxx.shape) + rxx_ = rxx.eval() + rxx_ /= rxx_[0] + # Expect positive correlation for the first 10 lags, then significantly + # smaller negative. + self.assertGreater(rxx_[:10].min(), 0) + self.assertGreater(rxx_[9], 5 * rxx_[10:20].mean()) + # RXX should be decreasing for the first 10 lags. + diff = np.diff(rxx_) + self.assertLess(diff[:10].max(), 0) + + def test_normalization(self): + l = 10000 + x = 3 * rng.randn(l).astype(self.dtype) + x_ph = array_ops.placeholder_with_default( + x, shape=(l,) if self.use_static_shape else None) + with spectral_ops_test_util.fft_kernel_label_map(): + with self.test_session(): + rxx = sample_stats.auto_correlation( + x_ph, max_lags=l // 2, center=True, normalize=True) + if self.use_static_shape: + self.assertAllEqual((l // 2 + 1,), rxx.shape) + rxx_ = rxx.eval() + # Note that RXX[0] = 1, despite the fact that E[X^2] = 9, and this is + # due to normalize=True. + # OSS CPU FFT has some accuracy issues is not the most accurate. + # So this tolerance is a bit bad. + self.assertAllClose(1., rxx_[0], rtol=0.05) + # The maximal error in the rest of the sequence is not great. 
+ self.assertAllClose(np.zeros(l // 2), rxx_[1:], atol=0.1) + # The mean error in the rest is ok, actually 0.008 when I tested it. + self.assertLess(np.abs(rxx_[1:]).mean(), 0.02) + + +class AutoCorrelationTestStaticShapeFloat32(test.TestCase, + _AutoCorrelationTest): + + @property + def dtype(self): + return np.float32 + + @property + def use_static_shape(self): + return True + + +class AutoCorrelationTestStaticShapeComplex64(test.TestCase, + _AutoCorrelationTest): + + @property + def dtype(self): + return np.complex64 + + @property + def use_static_shape(self): + return True + + +class AutoCorrelationTestDynamicShapeFloat32(test.TestCase, + _AutoCorrelationTest): + + @property + def dtype(self): + return np.float32 + + @property + def use_static_shape(self): + return False + + class PercentileTestWithLowerInterpolation(test.TestCase): _interpolation = "lower" diff --git a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py index 103d8e186221e879d1734a097114708429f725bd..cbaf74d3f66253ae5727e1ba579e2d49235b748e 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py @@ -200,6 +200,27 @@ class TransformedDistributionTest(test.TestCase): self.assertAllEqual([2], multi_logit_normal.event_shape) self.assertAllEqual([2], multi_logit_normal.event_shape_tensor().eval()) + def testCastLogDetJacobian(self): + """Test log_prob when Jacobian and log_prob dtypes do not match.""" + + with self.test_session(): + # Create an identity bijector whose jacobians have dtype int32 + int_identity = bs.Inline( + forward_fn=array_ops.identity, + inverse_fn=array_ops.identity, + inverse_log_det_jacobian_fn=lambda x: math_ops.cast(0, dtypes.int32), + forward_log_det_jacobian_fn=lambda x: math_ops.cast(0, dtypes.int32), + 
is_constant_jacobian=True) + normal = self._cls()( + distribution=ds.Normal(loc=0., scale=1.), + bijector=int_identity, + validate_args=True) + + y = normal.sample() + normal.log_prob(y).eval() + normal.prob(y).eval() + normal.entropy().eval() + def testEntropy(self): with self.test_session(): shift = np.array([[-1, 0, 1], [-1, -2, -3]], dtype=np.float32) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py b/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py index de4a221f7badca8267a81d612a57137c676ff052..04f047aa0c81b3f59b97f14554fb59cb1b3dd8af 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py @@ -21,14 +21,14 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.distributions.python.ops import test_util -from tensorflow.contrib.distributions.python.ops import vector_diffeomixture as vector_diffeomixture_lib -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops +from tensorflow.contrib.distributions.python.ops import vector_diffeomixture as vdm_lib from tensorflow.python.ops.distributions import normal as normal_lib from tensorflow.python.ops.linalg import linear_operator_diag as linop_diag_lib from tensorflow.python.ops.linalg import linear_operator_identity as linop_identity_lib from tensorflow.python.platform import test +rng = np.random.RandomState(0) + class VectorDiffeomixtureTest( test_util.VectorDistributionTestHelpers, test.TestCase): @@ -37,9 +37,9 @@ class VectorDiffeomixtureTest( def testSampleProbConsistentBroadcastMixNoBatch(self): with self.test_session() as sess: dims = 4 - vdm = vector_diffeomixture_lib.VectorDiffeomixture( + vdm = vdm_lib.VectorDiffeomixture( mix_loc=[[0.], [1.]], - mix_scale=[1.], + temperature=[1.], distribution=normal_lib.Normal(0., 1.), loc=[ None, @@ 
-54,20 +54,21 @@ class VectorDiffeomixtureTest( diag=np.linspace(2.5, 3.5, dims, dtype=np.float32), is_positive_definite=True), ], + quadrature_size=8, validate_args=True) # Ball centered at component0's mean. self.run_test_sample_consistent_log_prob( - sess.run, vdm, radius=2., center=0., rtol=0.005) + sess.run, vdm, radius=2., center=0., rtol=0.015) # Larger ball centered at component1's mean. self.run_test_sample_consistent_log_prob( - sess.run, vdm, radius=4., center=2., rtol=0.005) + sess.run, vdm, radius=4., center=2., rtol=0.015) def testSampleProbConsistentBroadcastMixNonStandardBase(self): with self.test_session() as sess: dims = 4 - vdm = vector_diffeomixture_lib.VectorDiffeomixture( + vdm = vdm_lib.VectorDiffeomixture( mix_loc=[[0.], [1.]], - mix_scale=[1.], + temperature=[1.], distribution=normal_lib.Normal(1., 1.5), loc=[ None, @@ -82,20 +83,21 @@ class VectorDiffeomixtureTest( diag=np.linspace(2.5, 3.5, dims, dtype=np.float32), is_positive_definite=True), ], + quadrature_size=8, validate_args=True) # Ball centered at component0's mean. self.run_test_sample_consistent_log_prob( - sess.run, vdm, radius=2., center=1., rtol=0.006) + sess.run, vdm, radius=2., center=1., rtol=0.015) # Larger ball centered at component1's mean. self.run_test_sample_consistent_log_prob( - sess.run, vdm, radius=4., center=3., rtol=0.009) + sess.run, vdm, radius=4., center=3., rtol=0.01) def testSampleProbConsistentBroadcastMixBatch(self): with self.test_session() as sess: dims = 4 - vdm = vector_diffeomixture_lib.VectorDiffeomixture( + vdm = vdm_lib.VectorDiffeomixture( mix_loc=[[0.], [1.]], - mix_scale=[1.], + temperature=[1.], distribution=normal_lib.Normal(0., 1.), loc=[ None, @@ -113,20 +115,48 @@ class VectorDiffeomixtureTest( ]), is_positive_definite=True), ], + quadrature_size=8, validate_args=True) # Ball centered at component0's mean. 
self.run_test_sample_consistent_log_prob( - sess.run, vdm, radius=2., center=0., rtol=0.005) + sess.run, vdm, radius=2., center=0., rtol=0.01) # Larger ball centered at component1's mean. self.run_test_sample_consistent_log_prob( - sess.run, vdm, radius=4., center=2., rtol=0.005) + sess.run, vdm, radius=4., center=2., rtol=0.01) + + def testSampleProbConsistentBroadcastMixTwoBatchDims(self): + dims = 4 + loc_1 = rng.randn(2, 3, dims).astype(np.float32) + + with self.test_session() as sess: + vdm = vdm_lib.VectorDiffeomixture( + mix_loc=(rng.rand(2, 3, 1) - 0.5).astype(np.float32), + temperature=[1.], + distribution=normal_lib.Normal(0., 1.), + loc=[ + None, + loc_1, + ], + scale=[ + linop_identity_lib.LinearOperatorScaledIdentity( + num_rows=dims, + multiplier=[np.float32(1.1)], + is_positive_definite=True), + ] * 2, + validate_args=True) + # Ball centered at component0's mean. + self.run_test_sample_consistent_log_prob( + sess.run, vdm, radius=2., center=0., rtol=0.01) + # Larger ball centered at component1's mean. + self.run_test_sample_consistent_log_prob( + sess.run, vdm, radius=3., center=loc_1, rtol=0.02) def testMeanCovarianceNoBatch(self): with self.test_session() as sess: dims = 3 - vdm = vector_diffeomixture_lib.VectorDiffeomixture( + vdm = vdm_lib.VectorDiffeomixture( mix_loc=[[0.], [4.]], - mix_scale=[10.], + temperature=[1 / 10.], distribution=normal_lib.Normal(0., 1.), loc=[ np.float32([-2.]), @@ -141,16 +171,99 @@ class VectorDiffeomixtureTest( diag=np.linspace(2.5, 3.5, dims, dtype=np.float32), is_positive_definite=True), ], + quadrature_size=8, validate_args=True) self.run_test_sample_consistent_mean_covariance( - sess.run, vdm, rtol=0.02, cov_rtol=0.06) + sess.run, vdm, rtol=0.02, cov_rtol=0.08) + + def testTemperatureControlsHowMuchThisLooksLikeDiscreteMixture(self): + # As temperature decreases, this should approach a mixture of normals, with + # components at -2, 2. 
+ with self.test_session() as sess: + dims = 1 + vdm = vdm_lib.VectorDiffeomixture( + mix_loc=[0.], + temperature=[[2.], [1.], [0.2]], + distribution=normal_lib.Normal(0., 1.), + loc=[ + np.float32([-2.]), + np.float32([2.]), + ], + scale=[ + linop_identity_lib.LinearOperatorScaledIdentity( + num_rows=dims, + multiplier=np.float32(0.5), + is_positive_definite=True), + ] * 2, # Use the same scale for each component. + quadrature_size=8, + validate_args=True) + + samps = vdm.sample(10000) + self.assertAllEqual((10000, 3, 1), samps.shape) + samps_ = sess.run(samps).reshape(10000, 3) # Make scalar event shape. + + # One characteristic of a discrete mixture (as opposed to a "smear") is + # that more weight is put near the component centers at -2, 2, and thus + # less weight is put near the origin. + prob_of_being_near_origin = (np.abs(samps_) < 1).mean(axis=0) + self.assertGreater( + prob_of_being_near_origin[0], prob_of_being_near_origin[1]) + self.assertGreater( + prob_of_being_near_origin[1], prob_of_being_near_origin[2]) + + # Run this test as well, just because we can. + self.run_test_sample_consistent_mean_covariance( + sess.run, vdm, rtol=0.02, cov_rtol=0.08) + + def testConcentrationLocControlsHowMuchWeightIsOnEachComponent(self): + with self.test_session() as sess: + dims = 1 + vdm = vdm_lib.VectorDiffeomixture( + mix_loc=[[-1.], [0.], [1.]], + temperature=[0.5], + distribution=normal_lib.Normal(0., 1.), + loc=[ + np.float32([-2.]), + np.float32([2.]), + ], + scale=[ + linop_identity_lib.LinearOperatorScaledIdentity( + num_rows=dims, + multiplier=np.float32(0.5), + is_positive_definite=True), + ] * 2, # Use the same scale for each component. + quadrature_size=8, + validate_args=True) + + samps = vdm.sample(10000) + self.assertAllEqual((10000, 3, 1), samps.shape) + samps_ = sess.run(samps).reshape(10000, 3) # Make scalar event shape. + + # One characteristic of putting more weight on a component is that the + # mean is closer to that component's mean. 
+ # Get the mean for each batch member, the names signify the value of + # concentration for that batch member. + mean_neg1, mean_0, mean_1 = samps_.mean(axis=0) + + # Since concentration is the concentration for component 0, + # concentration = -1 ==> more weight on component 1, which has mean = 2 + # concentration = 0 ==> equal weight + # concentration = 1 ==> more weight on component 0, which has mean = -2 + self.assertLess(-2, mean_1) + self.assertLess(mean_1, mean_0) + self.assertLess(mean_0, mean_neg1) + self.assertLess(mean_neg1, 2) + + # Run this test as well, just because we can. + self.run_test_sample_consistent_mean_covariance( + sess.run, vdm, rtol=0.02, cov_rtol=0.08) def testMeanCovarianceNoBatchUncenteredNonStandardBase(self): with self.test_session() as sess: dims = 3 - vdm = vector_diffeomixture_lib.VectorDiffeomixture( + vdm = vdm_lib.VectorDiffeomixture( mix_loc=[[0.], [4.]], - mix_scale=[10.], + temperature=[0.1], distribution=normal_lib.Normal(-1., 1.5), loc=[ np.float32([-2.]), @@ -165,6 +278,7 @@ class VectorDiffeomixtureTest( diag=np.linspace(2.5, 3.5, dims, dtype=np.float32), is_positive_definite=True), ], + quadrature_size=8, validate_args=True) self.run_test_sample_consistent_mean_covariance( sess.run, vdm, num_samples=int(1e6), rtol=0.01, cov_atol=0.025) @@ -172,9 +286,9 @@ class VectorDiffeomixtureTest( def testMeanCovarianceBatch(self): with self.test_session() as sess: dims = 3 - vdm = vector_diffeomixture_lib.VectorDiffeomixture( + vdm = vdm_lib.VectorDiffeomixture( mix_loc=[[0.], [4.]], - mix_scale=[10.], + temperature=[0.1], distribution=normal_lib.Normal(0., 1.), loc=[ np.float32([[-2.]]), @@ -192,19 +306,17 @@ class VectorDiffeomixtureTest( ]), is_positive_definite=True), ], + quadrature_size=8, validate_args=True) self.run_test_sample_consistent_mean_covariance( - sess.run, vdm, rtol=0.02, cov_rtol=0.06) + sess.run, vdm, rtol=0.02, cov_rtol=0.07) - def testSampleProbConsistentDynamicQuadrature(self): + def 
testSampleProbConsistentQuadrature(self): with self.test_session() as sess: - qgrid = array_ops.placeholder(dtype=dtypes.float32) - qprobs = array_ops.placeholder(dtype=dtypes.float32) - g, p = np.polynomial.hermite.hermgauss(deg=8) dims = 4 - vdm = vector_diffeomixture_lib.VectorDiffeomixture( - mix_loc=[[0.], [1.]], - mix_scale=[1.], + vdm = vdm_lib.VectorDiffeomixture( + mix_loc=[0.], + temperature=[0.1], distribution=normal_lib.Normal(0., 1.), loc=[ None, @@ -219,38 +331,14 @@ class VectorDiffeomixtureTest( diag=np.linspace(2.5, 3.5, dims, dtype=np.float32), is_positive_definite=True), ], - quadrature_grid_and_probs=(g, p), + quadrature_size=3, validate_args=True) # Ball centered at component0's mean. - sess_run_fn = lambda x: sess.run(x, feed_dict={qgrid: g, qprobs: p}) self.run_test_sample_consistent_log_prob( - sess_run_fn, vdm, radius=2., center=0., rtol=0.005) + sess.run, vdm, radius=2., center=0., rtol=0.015) # Larger ball centered at component1's mean. self.run_test_sample_consistent_log_prob( - sess_run_fn, vdm, radius=4., center=2., rtol=0.005) - - # TODO(jvdillon): We've tested that (i) .sample and .log_prob are consistent, - # (ii) .mean, .stddev etc... and .sample are consistent. However, we haven't - # tested that the quadrature approach well-approximates the integral. - # - # To that end, consider adding these tests: - # - # Test1: In the limit of high mix_scale, this approximates a discrete mixture, - # and there are many discrete mixtures where we can explicitly compute - # mean/var, etc... So test1 would choose one of those discrete mixtures and - # show our mean/var/etc... is close to that. 
- # - # Test2: In the limit of low mix_scale, the a diffeomixture of Normal(-5, 1), - # Normal(5, 1) should (I believe...must check) should look almost like - # Uniform(-5, 5), and thus (i) .prob(x) should be about 1/10 for x in (-5, 5), - # and (ii) the first few moments should approximately match that of - # Uniform(-5, 5) - # - # Test3: If mix_loc is symmetric, then for any mix_scale, our - # quadrature-based diffeomixture of Normal(-1, 1), Normal(1, 1) should have - # mean zero, exactly. - - # TODO(jvdillon): Add more tests which verify broadcasting. + sess.run, vdm, radius=4., center=2., rtol=0.005) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/ops/autoregressive.py b/tensorflow/contrib/distributions/python/ops/autoregressive.py new file mode 100644 index 0000000000000000000000000000000000000000..852298bf334666db003353d5fc8e172ffb738668 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/autoregressive.py @@ -0,0 +1,208 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""The Autoregressive distribution.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import ops +from tensorflow.python.ops.distributions import distribution as distribution_lib +from tensorflow.python.ops.distributions import util as distribution_util + + +class Autoregressive(distribution_lib.Distribution): + """Autoregressive distributions. + + The Autoregressive distribution enables learning (often) richer multivariate + distributions by repeatedly applying a [diffeomorphic]( + https://en.wikipedia.org/wiki/Diffeomorphism) transformation (such as + implemented by `Bijector`s). Regarding terminology, + + "Autoregressive models decompose the joint density as a product of + conditionals, and model each conditional in turn. Normalizing flows + transform a base density (e.g. a standard Gaussian) into the target density + by an invertible transformation with tractable Jacobian." [1] + + In other words, the "autoregressive property" is equivalent to the + decomposition, `p(x) = prod{ p(x[i] | x[0:i]) : i=0, ..., d }`. The provided + `shift_and_log_scale_fn`, `masked_autoregressive_default_template`, achieves + this property by zeroing out weights in its `masked_dense` layers. + + Practically speaking the autoregressive property means that there exists a + permutation of the event coordinates such that each coordinate is a + diffeomorphic function of only preceding coordinates. 
[2] + + #### Mathematical Details + + The probability function is, + + ```none + prob(x; fn, n) = fn(x).prob(x) + ``` + + And a sample is generated by, + + ```none + x = fn(...fn(fn(x0).sample()).sample()).sample() + ``` + + where the ellipses (`...`) represent `n-2` composed calls to `fn`, `fn` + constructs a `tf.distributions.Distribution`-like instance, and `x0` is a + fixed initializing `Tensor`. + + #### Examples + + ```python + tfd = tf.contrib.distributions + + def normal_fn(self, event_size): + n = event_size * (event_size + 1) / 2 + p = tf.Variable(tfd.Normal(loc=0., scale=1.).sample(n)) + affine = tfd.bijectors.Affine( + scale_tril=tfd.fill_triangular(0.25 * p)) + def _fn(samples): + scale = math_ops.exp(affine.forward(samples)).eval() + return independent_lib.Independent( + normal_lib.Normal(loc=0., scale=scale, validate_args=True), + reinterpreted_batch_ndims=1) + return _fn + + batch_and_event_shape = [3, 2, 4] + sample0 = array_ops.zeros(batch_and_event_shape) + ar = autoregressive_lib.Autoregressive( + self._normal_fn(batch_and_event_shape[-1]), sample0) + x = ar.sample([6, 5]) + # ==> x.shape = [6, 5, 3, 2, 4] + prob_x = ar.prob(x) + # ==> x.shape = [6, 5, 3, 2] + + ``` + + [1]: "Masked Autoregressive Flow for Density Estimation." + George Papamakarios, Theo Pavlakou, Iain Murray. Arxiv. 2017. + https://arxiv.org/abs/1705.07057 + + [2]: "Conditional Image Generation with PixelCNN Decoders." + Aaron van den Oord, Nal Kalchbrenner, Oriol Vinyals, Lasse Espeholt, Alex + Graves, Koray Kavukcuoglu. Arxiv, 2016. + https://arxiv.org/abs/1606.05328 + """ + + def __init__(self, + distribution_fn, + sample0=None, + num_steps=None, + validate_args=False, + allow_nan_stats=True, + name="Autoregressive"): + """Construct an `Autoregressive` distribution. + + Args: + distribution_fn: Python `callable` which constructs a + `tf.distributions.Distribution`-like instance from a `Tensor` (e.g., + `sample0`). 
The function must respect the "autoregressive property", + i.e., there exists a permutation of event such that each coordinate is a + diffeomorphic function of on preceding coordinates. + sample0: Initial input to `distribution_fn`; used to + build the distribution in `__init__` which in turn specifies this + distribution's properties, e.g., `event_shape`, `batch_shape`, `dtype`. + If unspecified, then `distribution_fn` should be default constructable. + num_steps: Number of times `distribution_fn` is composed from samples, + e.g., `num_steps=2` implies + `distribution_fn(distribution_fn(sample0).sample(n)).sample()`. + validate_args: Python `bool`. Whether to validate input with asserts. + If `validate_args` is `False`, and the inputs are invalid, + correct behavior is not guaranteed. + allow_nan_stats: Python `bool`, default `True`. When `True`, statistics + (e.g., mean, mode, variance) use the value "`NaN`" to indicate the + result is undefined. When `False`, an exception is raised if one or + more of the statistic's batch members are undefined. + name: Python `str` name prefixed to Ops created by this class. + Default value: "Autoregressive". + + Raises: + ValueError: if `num_steps` and + `distribution_fn(sample0).event_shape.num_elements()` are both `None`. + ValueError: if `num_steps < 1`. 
+ """ + parameters = locals() + with ops.name_scope(name): + self._distribution_fn = distribution_fn + self._sample0 = sample0 + self._distribution0 = (distribution_fn() if sample0 is None + else distribution_fn(sample0)) + if num_steps is None: + num_steps = self._distribution0.event_shape.num_elements() + if num_steps is None: + raise ValueError("distribution_fn must generate a distribution " + "with fully known `event_shape`.") + if num_steps < 1: + raise ValueError("num_steps ({}) must be at least 1.".format(num_steps)) + self._num_steps = num_steps + super(Autoregressive, self).__init__( + dtype=self._distribution0.dtype, + reparameterization_type=self._distribution0.reparameterization_type, + validate_args=validate_args, + allow_nan_stats=allow_nan_stats, + parameters=parameters, + graph_parents=self._distribution0._graph_parents, # pylint: disable=protected-access + name=name) + + @property + def distribution_fn(self): + return self._distribution_fn + + @property + def sample0(self): + return self._sample0 + + @property + def num_steps(self): + return self._num_steps + + @property + def distribution0(self): + return self._distribution0 + + def _batch_shape(self): + return self.distribution0.batch_shape + + def _batch_shape_tensor(self): + return self.distribution0.batch_shape_tensor() + + def _event_shape(self): + return self.distribution0.event_shape + + def _event_shape_tensor(self): + return self.distribution0.event_shape_tensor() + + def _sample_n(self, n, seed=None): + if seed is None: + seed = distribution_util.gen_new_seed( + seed=np.random.randint(2**32 - 1), + salt="autoregressive") + samples = self.distribution0.sample(n, seed=seed) + for _ in range(self._num_steps): + samples = self.distribution_fn(samples).sample(seed=seed) + return samples + + def _log_prob(self, value): + return self.distribution_fn(value).log_prob(value) + + def _prob(self, value): + return self.distribution_fn(value).prob(value) diff --git 
a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py index bc0ec7f195af009c87020ce8c4ea18f2e713759a..93923c3f083c7f5136b55e9021cbd6323684b976 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py @@ -29,6 +29,7 @@ @@MaskedAutoregressiveFlow @@Permute @@PowerTransform +@@RealNVP @@Reshape @@Sigmoid @@SigmoidCentered @@ -39,6 +40,7 @@ @@masked_autoregressive_default_template @@masked_dense +@@real_nvp_default_template """ from __future__ import absolute_import @@ -60,6 +62,7 @@ from tensorflow.contrib.distributions.python.ops.bijectors.invert import * from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import * from tensorflow.contrib.distributions.python.ops.bijectors.permute import * from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import * +from tensorflow.contrib.distributions.python.ops.bijectors.real_nvp import * from tensorflow.contrib.distributions.python.ops.bijectors.reshape import * from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import * from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid_centered import * diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py b/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py index 6049419818e18c54209f0be95d41fcecf6627b7e..0fe9f6aa78fbe845b99d0668f075b0162ec2a9f7 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py @@ -18,12 +18,117 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.absolute_value_impl import * -# 
pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector -_allowed_symbols = ["AbsoluteValue"] +__all__ = [ + "AbsoluteValue", +] -remove_undocumented(__name__, _allowed_symbols) + +class AbsoluteValue(bijector.Bijector): + """Computes `Y = g(X) = Abs(X)`, element-wise. + + This non-injective bijector allows for transformations of scalar distributions + with the absolute value function, which maps `(-inf, inf)` to `[0, inf)`. + + * For `y in (0, inf)`, `AbsoluteValue.inverse(y)` returns the set inverse + `{x in (-inf, inf) : |x| = y}` as a tuple, `-y, y`. + * `AbsoluteValue.inverse(0)` returns `0, 0`, which is not the set inverse + (the set inverse is the singleton `{0}`), but "works" in conjunction with + `TransformedDistribution` to produce a left semi-continuous pdf. + * For `y < 0`, `AbsoluteValue.inverse(y)` happily returns the + wrong thing, `-y, y`. This is done for efficiency. If + `validate_args == True`, `y < 0` will raise an exception. + + + ```python + tfd = tf.contrib.distributions + + abs = tfd.bijectors.AbsoluteValue() + + abs.forward([-1., 0., 1.]) + ==> [1., 0., 1.] + + abs.inverse(1.) + ==> [-1., 1.] + + # The |dX/dY| is constant, == 1. So Log|dX/dY| == 0. + abs.inverse_log_det_jacobian(1.) + ==> [0., 0.] + + # Special case handling of 0. + abs.inverse(0.) + ==> [0., 0.] + + abs.inverse_log_det_jacobian(0.) + ==> [0., 0.] + ``` + + """ + + def __init__(self, event_ndims=0, validate_args=False, name="absolute_value"): + """Instantiates the `AbsoluteValue` bijector. 
+ + Args: + event_ndims: Python scalar indicating the number of dimensions associated + with a particular draw from the distribution. Currently only zero is + supported. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness, in particular whether inputs to `inverse` and + `inverse_log_det_jacobian` are non-negative. + name: Python `str` name given to ops managed by this object. + + Raises: + ValueError: If `event_ndims` is not zero. + """ + self._graph_parents = [] + self._name = name + + event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") + event_ndims_const = tensor_util.constant_value(event_ndims) + if event_ndims_const is not None and event_ndims_const not in (0,): + raise ValueError("event_ndims(%s) was not 0" % event_ndims_const) + else: + if validate_args: + event_ndims = control_flow_ops.with_dependencies( + [check_ops.assert_equal( + event_ndims, 0, message="event_ndims was not 0")], + event_ndims) + + with self._name_scope("init"): + super(AbsoluteValue, self).__init__( + event_ndims=event_ndims, + validate_args=validate_args, + name=name) + + def _forward(self, x): + return math_ops.abs(x) + + def _inverse(self, y): + if self.validate_args: + y = control_flow_ops.with_dependencies( + [check_ops.assert_non_negative(y, message="Argument y was negative")], + y) + return -y, y + + def _inverse_log_det_jacobian(self, y): + # If event_ndims = 2, + # F^{-1}(y) = (-y, y), so DF^{-1}(y) = (-1, 1), + # so Log|DF^{-1}(y)| = Log[1, 1] = [0, 0]. 
+ batch_shape = array_ops.shape(y)[:array_ops.rank(y) - self.event_ndims] + zeros = array_ops.zeros(batch_shape, dtype=y.dtype) + if self.validate_args: + zeros = control_flow_ops.with_dependencies( + [check_ops.assert_non_negative(y, message="Argument y was negative")], + zeros) + return zeros, zeros + + @property + def _is_injective(self): + return False diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value_impl.py deleted file mode 100644 index b84502003ab6c0c4ffdda21eea162f441509e1fa..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value_impl.py +++ /dev/null @@ -1,132 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""AbsoluteValue bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import bijector - -__all__ = [ - "AbsoluteValue", -] - - -class AbsoluteValue(bijector.Bijector): - """Computes `Y = g(X) = Abs(X)`, element-wise. - - This non-injective bijector allows for transformations of scalar distributions - with the absolute value function, which maps `(-inf, inf)` to `[0, inf)`. - - * For `y in (0, inf)`, `AbsoluteValue.inverse(y)` returns the set inverse - `{x in (-inf, inf) : |x| = y}` as a tuple, `-y, y`. - * `AbsoluteValue.inverse(0)` returns `0, 0`, which is not the set inverse - (the set inverse is the singleton `{0}`), but "works" in conjunction with - `TransformedDistribution` to produce a left semi-continuous pdf. - * For `y < 0`, `AbsoluteValue.inverse(y)` happily returns the - wrong thing, `-y, y`. This is done for efficiency. If - `validate_args == True`, `y < 0` will raise an exception. - - - ```python - abs = ds.bijectors.AbsoluteValue() - - abs.forward([-1., 0., 1.]) - ==> [1., 0., 1.] - - abs.inverse(1.) - ==> [-1., 1.] - - # The |dX/dY| is constant, == 1. So Log|dX/dY| == 0. - abs.inverse_log_det_jacobian(1.) - ==> [0., 0.] - - # Special case handling of 0. - abs.inverse(0.) - ==> [0., 0.] - - abs.inverse_log_det_jacobian(0.) - ==> [0., 0.] - ``` - - """ - - def __init__(self, event_ndims=0, validate_args=False, name="absolute_value"): - """Instantiates the `AbsoluteValue` bijector. 
- - Args: - event_ndims: Python scalar indicating the number of dimensions associated - with a particular draw from the distribution. Currently only zero is - supported. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness, in particular whether inputs to `inverse` and - `inverse_log_det_jacobian` are non-negative. - name: Python `str` name given to ops managed by this object. - - Raises: - ValueError: If `event_ndims` is not zero. - """ - self._graph_parents = [] - self._name = name - - event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") - event_ndims_const = tensor_util.constant_value(event_ndims) - if event_ndims_const is not None and event_ndims_const not in (0,): - raise ValueError("event_ndims(%s) was not 0" % event_ndims_const) - else: - if validate_args: - event_ndims = control_flow_ops.with_dependencies( - [check_ops.assert_equal( - event_ndims, 0, message="event_ndims was not 0")], - event_ndims) - - with self._name_scope("init"): - super(AbsoluteValue, self).__init__( - event_ndims=event_ndims, - validate_args=validate_args, - name=name) - - def _forward(self, x): - return math_ops.abs(x) - - def _inverse(self, y): - if self.validate_args: - y = control_flow_ops.with_dependencies( - [check_ops.assert_non_negative(y, message="Argument y was negative")], - y) - return -y, y - - def _inverse_log_det_jacobian(self, y): - # If event_ndims = 2, - # F^{-1}(y) = (-y, y), so DF^{-1}(y) = (-1, 1), - # so Log|DF^{-1}(y)| = Log[1, 1] = [0, 0]. 
- batch_shape = array_ops.shape(y)[:array_ops.rank(y) - self.event_ndims] - zeros = array_ops.zeros(batch_shape, dtype=y.dtype) - if self.validate_args: - zeros = control_flow_ops.with_dependencies( - [check_ops.assert_non_negative(y, message="Argument y was negative")], - zeros) - return zeros, zeros - - @property - def _is_injective(self): - return False diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py index 940cceff04e77cfc2f7caae5a798d135f7601b95..05bb9c2f9bdf35e222c94db3491157893da64ebd 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine.py @@ -18,12 +18,386 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.affine_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.contrib import linalg +from tensorflow.contrib.distributions.python.ops import distribution_util +from tensorflow.contrib.distributions.python.ops.shape import _DistributionShape +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector -_allowed_symbols = ["Affine"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "Affine", +] + + +def _as_tensor(x, name): + """Convenience to convert to `Tensor` or leave as `None`.""" + return None if x is None else ops.convert_to_tensor(x, name=name) + + +class 
Affine(bijector.Bijector): + """Compute `Y = g(X; shift, scale) = scale @ X + shift`. + + Here `scale = c * I + diag(D1) + tril(L) + V @ diag(D2) @ V.T`. + + In TF parlance, the `scale` term is logically equivalent to: + + ```python + scale = ( + scale_identity_multiplier * tf.diag(tf.ones(d)) + + tf.diag(scale_diag) + + scale_tril + + scale_perturb_factor @ diag(scale_perturb_diag) @ + tf.transpose([scale_perturb_factor]) + ) + ``` + + The `scale` term is applied without necessarily materializing constituent + matrices, i.e., the matmul is [matrix-free]( + https://en.wikipedia.org/wiki/Matrix-free_methods) when possible. + + Examples: + + ```python + # Y = X + b = Affine() + + # Y = X + shift + b = Affine(shift=[1., 2, 3]) + + # Y = 2 * I @ X.T + shift + b = Affine(shift=[1., 2, 3], + scale_identity_multiplier=2.) + + # Y = tf.diag(d1) @ X.T + shift + b = Affine(shift=[1., 2, 3], + scale_diag=[-1., 2, 1]) # Implicitly 3x3. + + # Y = (I + v * v.T) @ X.T + shift + b = Affine(shift=[1., 2, 3], + scale_perturb_factor=[[1., 0], + [0, 1], + [1, 1]]) + + # Y = (diag(d1) + v * diag(d2) * v.T) @ X.T + shift + b = Affine(shift=[1., 2, 3], + scale_diag=[1., 3, 3], # Implicitly 3x3. + scale_perturb_diag=[2., 1], # Implicitly 2x2. + scale_perturb_factor=[[1., 0], + [0, 1], + [1, 1]]) + + ``` + + """ + + def __init__(self, + shift=None, + scale_identity_multiplier=None, + scale_diag=None, + scale_tril=None, + scale_perturb_factor=None, + scale_perturb_diag=None, + event_ndims=1, + validate_args=False, + name="affine"): + """Instantiates the `Affine` bijector. 
+ + This `Bijector` is initialized with `shift` `Tensor` and `scale` arguments, + giving the forward operation: + + ```none + Y = g(X) = scale @ X + shift + ``` + + where the `scale` term is logically equivalent to: + + ```python + scale = ( + scale_identity_multiplier * tf.diag(tf.ones(d)) + + tf.diag(scale_diag) + + scale_tril + + scale_perturb_factor @ diag(scale_perturb_diag) @ + tf.transpose([scale_perturb_factor]) + ) + ``` + + If none of `scale_identity_multiplier`, `scale_diag`, or `scale_tril` are + specified then `scale += IdentityMatrix`. Otherwise specifying a + `scale` argument has the semantics of `scale += Expand(arg)`, i.e., + `scale_diag != None` means `scale += tf.diag(scale_diag)`. + + Args: + shift: Floating-point `Tensor`. If this is set to `None`, no shift is + applied. + scale_identity_multiplier: floating point rank 0 `Tensor` representing a + scaling done to the identity matrix. + When `scale_identity_multiplier = scale_diag = scale_tril = None` then + `scale += IdentityMatrix`. Otherwise no scaled-identity-matrix is added + to `scale`. + scale_diag: Floating-point `Tensor` representing the diagonal matrix. + `scale_diag` has shape [N1, N2, ... k], which represents a k x k + diagonal matrix. + When `None` no diagonal term is added to `scale`. + scale_tril: Floating-point `Tensor` representing the lower triangular + matrix. `scale_tril` has shape [N1, N2, ... k, k], which represents a + k x k lower triangular matrix. + When `None` no `scale_tril` term is added to `scale`. + The upper triangular elements above the diagonal are ignored. + scale_perturb_factor: Floating-point `Tensor` representing factor matrix + with last two dimensions of shape `(k, r)`. When `None`, no rank-r + update is added to `scale`. + scale_perturb_diag: Floating-point `Tensor` representing the diagonal + matrix. `scale_perturb_diag` has shape [N1, N2, ... r], which + represents an `r x r` diagonal matrix.
When `None` low rank updates will + take the form `scale_perturb_factor * scale_perturb_factor.T`. + event_ndims: Scalar `int` `Tensor` indicating the number of dimensions + associated with a particular draw from the distribution. Must be 0 or 1. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + + Raises: + ValueError: if `perturb_diag` is specified but not `perturb_factor`. + TypeError: if `shift` has different `dtype` from `scale` arguments. + """ + self._graph_parents = [] + self._name = name + self._validate_args = validate_args + + # Ambiguous definition of low rank update. + if scale_perturb_diag is not None and scale_perturb_factor is None: + raise ValueError("When scale_perturb_diag is specified, " + "scale_perturb_factor must be specified.") + + # Special case, only handling a scaled identity matrix. We don't know its + # dimensions, so this is special cased. + # We don't check identity_multiplier, since below we set it to 1. if all + # other scale args are None. + self._is_only_identity_multiplier = (scale_tril is None and + scale_diag is None and + scale_perturb_factor is None) + + with self._name_scope("init", values=[ + shift, scale_identity_multiplier, scale_diag, scale_tril, + scale_perturb_diag, scale_perturb_factor]): + event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") + event_ndims_const = tensor_util.constant_value(event_ndims) + if event_ndims_const is not None and event_ndims_const not in (0, 1): + raise ValueError("event_ndims(%s) was not 0 or 1" % event_ndims_const) + else: + if validate_args: + # Shape tool will catch if event_ndims is negative. 
+ event_ndims = control_flow_ops.with_dependencies( + [check_ops.assert_less( + event_ndims, 2, message="event_ndims must be 0 or 1")], + event_ndims) + + if event_ndims_const == 0 and not self._is_only_identity_multiplier: + raise ValueError( + "If event_ndims == 0, the only scale argument you can pass is " + "scale_identity_multiplier. All others operate on vectors.") + + # In the absence of `loc` and `scale`, we'll assume `dtype` is `float32`. + dtype = dtypes.float32 + + if shift is not None: + shift = ops.convert_to_tensor(shift, name="shift") + dtype = shift.dtype.base_dtype + self._shift = shift + + # When no args are specified, pretend the scale matrix is the identity + # matrix. + if (self._is_only_identity_multiplier and + scale_identity_multiplier is None): + scale_identity_multiplier = ops.convert_to_tensor(1., dtype=dtype) + + # self._create_scale_operator returns a LinearOperator in all cases + # except if self._is_only_identity_multiplier; in which case it + # returns a scalar Tensor. + scale = self._create_scale_operator( + identity_multiplier=scale_identity_multiplier, + diag=scale_diag, + tril=scale_tril, + perturb_diag=scale_perturb_diag, + perturb_factor=scale_perturb_factor, + shift=shift, + validate_args=validate_args) + + if scale.dtype is not None: + dtype = scale.dtype.base_dtype + + if scale is not None and not self._is_only_identity_multiplier: + if (shift is not None and + shift.dtype.base_dtype != scale.dtype.base_dtype): + raise TypeError( + "shift.dtype({}) is incompatible with scale.dtype({}).".format( + shift.dtype, scale.dtype)) + + if scale.tensor_rank is not None: + batch_ndims = scale.tensor_rank - 2 + else: + batch_ndims = scale.tensor_rank_tensor() - 2 + else: + # We won't need shape inference when scale is None or when scale is a + # scalar. 
+ batch_ndims = 0 + self._scale = scale + self._shaper = _DistributionShape( + batch_ndims=batch_ndims, + event_ndims=event_ndims, + validate_args=validate_args) + super(Affine, self).__init__( + event_ndims=event_ndims, + graph_parents=( + [event_ndims] + + [self._scale] if tensor_util.is_tensor(self._scale) + else self._scale.graph_parents + + [self._shift] if self._shift is not None else []), + is_constant_jacobian=True, + dtype=dtype, + validate_args=validate_args, + name=name) + + def _create_scale_operator(self, identity_multiplier, diag, tril, + perturb_diag, perturb_factor, shift, + validate_args): + """Construct `scale` from various components. + + Args: + identity_multiplier: floating point rank 0 `Tensor` representing a scaling + done to the identity matrix. + diag: Floating-point `Tensor` representing the diagonal matrix. + `scale_diag` has shape [N1, N2, ... k], which represents a k x k + diagonal matrix. + tril: Floating-point `Tensor` representing the lower triangular matrix. + `scale_tril` has shape [N1, N2, ... k, k], which represents a k x k + lower triangular matrix. + perturb_diag: Floating-point `Tensor` representing the diagonal matrix of + the low rank update. + perturb_factor: Floating-point `Tensor` representing factor matrix. + shift: Floating-point `Tensor` representing `shift` in `scale @ X + shift`. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + + Returns: + scale. In the case of scaling by a constant, scale is a + floating point `Tensor`. Otherwise, scale is a `LinearOperator`. + + Raises: + ValueError: if all of `tril`, `diag` and `identity_multiplier` are `None`.
+ """ + identity_multiplier = _as_tensor(identity_multiplier, "identity_multiplier") + diag = _as_tensor(diag, "diag") + tril = _as_tensor(tril, "tril") + perturb_diag = _as_tensor(perturb_diag, "perturb_diag") + perturb_factor = _as_tensor(perturb_factor, "perturb_factor") + + # If possible, use the low rank update to infer the shape of + # the identity matrix, when scale represents a scaled identity matrix + # with a low rank update. + shape_hint = None + if perturb_factor is not None: + shape_hint = distribution_util.dimension_size(perturb_factor, axis=-2) + + if self._is_only_identity_multiplier: + if validate_args: + return control_flow_ops.with_dependencies( + [check_ops.assert_none_equal( + identity_multiplier, + array_ops.zeros([], identity_multiplier.dtype), + ["identity_multiplier should be non-zero."])], + identity_multiplier) + return identity_multiplier + + scale = distribution_util.make_tril_scale( + loc=shift, + scale_tril=tril, + scale_diag=diag, + scale_identity_multiplier=identity_multiplier, + validate_args=validate_args, + assert_positive=False, + shape_hint=shape_hint) + + if perturb_factor is not None: + return linalg.LinearOperatorLowRankUpdate( + scale, + u=perturb_factor, + diag_update=perturb_diag, + is_diag_update_positive=perturb_diag is None, + is_non_singular=True, # Implied by is_positive_definite=True. 
+ is_self_adjoint=True, + is_positive_definite=True, + is_square=True) + + return scale + + @property + def shift(self): + """The `shift` `Tensor` in `Y = scale @ X + shift`.""" + return self._shift + + @property + def scale(self): + """The `scale` `LinearOperator` in `Y = scale @ X + shift`.""" + return self._scale + + def _forward(self, x): + y = x + if self._is_only_identity_multiplier: + y *= self._scale + if self.shift is not None: + return y + self.shift + return y + y, sample_shape = self._shaper.make_batch_of_event_sample_matrices( + y, expand_batch_dim=False) + with ops.control_dependencies(self._maybe_check_scale() if + self.validate_args else []): + y = self.scale.matmul(y) + y = self._shaper.undo_make_batch_of_event_sample_matrices( + y, sample_shape, expand_batch_dim=False) + if self.shift is not None: + y += self.shift + return y + + def _inverse(self, y): + x = y + if self.shift is not None: + x -= self.shift + if self._is_only_identity_multiplier: + return x / self._scale + + x, sample_shape = self._shaper.make_batch_of_event_sample_matrices( + x, expand_batch_dim=False) + # Solve fails if the op is singular so we may safely skip this assertion. + x = self.scale.solve(x) + x = self._shaper.undo_make_batch_of_event_sample_matrices( + x, sample_shape, expand_batch_dim=False) + return x + + def _inverse_log_det_jacobian(self, y): + return -self._forward_log_det_jacobian(y) + + def _forward_log_det_jacobian(self, x): + if self._is_only_identity_multiplier: + # We don't pad in this case and instead let the fldj be applied + # via broadcast. 
+ event_size = distribution_util.pick_vector( + math_ops.equal(self._shaper.event_ndims, 0), + [1], array_ops.shape(x))[-1] + event_size = math_ops.cast(event_size, dtype=self._scale.dtype) + return math_ops.log(math_ops.abs(self._scale)) * event_size + return self.scale.log_abs_determinant() + + def _maybe_check_scale(self): + try: + return [self.scale.assert_non_singular()] + except NotImplementedError: + pass + return [] diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine_impl.py deleted file mode 100644 index 05bb9c2f9bdf35e222c94db3491157893da64ebd..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine_impl.py +++ /dev/null @@ -1,403 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Affine bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib import linalg -from tensorflow.contrib.distributions.python.ops import distribution_util -from tensorflow.contrib.distributions.python.ops.shape import _DistributionShape -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import bijector - - -__all__ = [ - "Affine", -] - - -def _as_tensor(x, name): - """Convenience to convert to `Tensor` or leave as `None`.""" - return None if x is None else ops.convert_to_tensor(x, name=name) - - -class Affine(bijector.Bijector): - """Compute `Y = g(X; shift, scale) = scale @ X + shift`. - - Here `scale = c * I + diag(D1) + tril(L) + V @ diag(D2) @ V.T`. - - In TF parlance, the `scale` term is logically equivalent to: - - ```python - scale = ( - scale_identity_multiplier * tf.diag(tf.ones(d)) + - tf.diag(scale_diag) + - scale_tril + - scale_perturb_factor @ diag(scale_perturb_diag) @ - tf.transpose([scale_perturb_factor]) - ) - ``` - - The `scale` term is applied without necessarily materializing constituent - matrices, i.e., the matmul is [matrix-free]( - https://en.wikipedia.org/wiki/Matrix-free_methods) when possible. - - Examples: - - ```python - # Y = X - b = Affine() - - # Y = X + shift - b = Affine(shift=[1., 2, 3]) - - # Y = 2 * I @ X.T + shift - b = Affine(shift=[1., 2, 3], - scale_identity_multiplier=2.) - - # Y = tf.diag(d1) @ X.T + shift - b = Affine(shift=[1., 2, 3], - scale_diag=[-1., 2, 1]) # Implicitly 3x3. 
- - # Y = (I + v * v.T) @ X.T + shift - b = Affine(shift=[1., 2, 3], - scale_perturb_factor=[[1., 0], - [0, 1], - [1, 1]]) - - # Y = (diag(d1) + v * diag(d2) * v.T) @ X.T + shift - b = Affine(shift=[1., 2, 3], - scale_diag=[1., 3, 3], # Implicitly 3x3. - scale_perturb_diag=[2., 1], # Implicitly 2x2. - scale_perturb_factor=[[1., 0], - [0, 1], - [1, 1]]) - - ``` - - """ - - def __init__(self, - shift=None, - scale_identity_multiplier=None, - scale_diag=None, - scale_tril=None, - scale_perturb_factor=None, - scale_perturb_diag=None, - event_ndims=1, - validate_args=False, - name="affine"): - """Instantiates the `Affine` bijector. - - This `Bijector` is initialized with `shift` `Tensor` and `scale` arguments, - giving the forward operation: - - ```none - Y = g(X) = scale @ X + shift - ``` - - where the `scale` term is logically equivalent to: - - ```python - scale = ( - scale_identity_multiplier * tf.diag(tf.ones(d)) + - tf.diag(scale_diag) + - scale_tril + - scale_perturb_factor @ diag(scale_perturb_diag) @ - tf.transpose([scale_perturb_factor]) - ) - ``` - - If none of `scale_identity_multiplier`, `scale_diag`, or `scale_tril` are - specified then `scale += IdentityMatrix`. Otherwise specifying a - `scale` argument has the semantics of `scale += Expand(arg)`, i.e., - `scale_diag != None` means `scale += tf.diag(scale_diag)`. - - Args: - shift: Floating-point `Tensor`. If this is set to `None`, no shift is - applied. - scale_identity_multiplier: floating point rank 0 `Tensor` representing a - scaling done to the identity matrix. - When `scale_identity_multiplier = scale_diag = scale_tril = None` then - `scale += IdentityMatrix`. Otherwise no scaled-identity-matrix is added - to `scale`. - scale_diag: Floating-point `Tensor` representing the diagonal matrix. - `scale_diag` has shape [N1, N2, ... k], which represents a k x k - diagonal matrix. - When `None` no diagonal term is added to `scale`. - scale_tril: Floating-point `Tensor` representing the diagonal matrix. 
- `scale_diag` has shape [N1, N2, ... k, k], which represents a k x k - lower triangular matrix. - When `None` no `scale_tril` term is added to `scale`. - The upper triangular elements above the diagonal are ignored. - scale_perturb_factor: Floating-point `Tensor` representing factor matrix - with last two dimensions of shape `(k, r)`. When `None`, no rank-r - update is added to `scale`. - scale_perturb_diag: Floating-point `Tensor` representing the diagonal - matrix. `scale_perturb_diag` has shape [N1, N2, ... r], which - represents an `r x r` diagonal matrix. When `None` low rank updates will - take the form `scale_perturb_factor * scale_perturb_factor.T`. - event_ndims: Scalar `int` `Tensor` indicating the number of dimensions - associated with a particular draw from the distribution. Must be 0 or 1. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str` name given to ops managed by this object. - - Raises: - ValueError: if `perturb_diag` is specified but not `perturb_factor`. - TypeError: if `shift` has different `dtype` from `scale` arguments. - """ - self._graph_parents = [] - self._name = name - self._validate_args = validate_args - - # Ambiguous definition of low rank update. - if scale_perturb_diag is not None and scale_perturb_factor is None: - raise ValueError("When scale_perturb_diag is specified, " - "scale_perturb_factor must be specified.") - - # Special case, only handling a scaled identity matrix. We don't know its - # dimensions, so this is special cased. - # We don't check identity_multiplier, since below we set it to 1. if all - # other scale args are None. 
- self._is_only_identity_multiplier = (scale_tril is None and - scale_diag is None and - scale_perturb_factor is None) - - with self._name_scope("init", values=[ - shift, scale_identity_multiplier, scale_diag, scale_tril, - scale_perturb_diag, scale_perturb_factor]): - event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") - event_ndims_const = tensor_util.constant_value(event_ndims) - if event_ndims_const is not None and event_ndims_const not in (0, 1): - raise ValueError("event_ndims(%s) was not 0 or 1" % event_ndims_const) - else: - if validate_args: - # Shape tool will catch if event_ndims is negative. - event_ndims = control_flow_ops.with_dependencies( - [check_ops.assert_less( - event_ndims, 2, message="event_ndims must be 0 or 1")], - event_ndims) - - if event_ndims_const == 0 and not self._is_only_identity_multiplier: - raise ValueError( - "If event_ndims == 0, the only scale argument you can pass is " - "scale_identity_multiplier. All others operate on vectors.") - - # In the absence of `loc` and `scale`, we'll assume `dtype` is `float32`. - dtype = dtypes.float32 - - if shift is not None: - shift = ops.convert_to_tensor(shift, name="shift") - dtype = shift.dtype.base_dtype - self._shift = shift - - # When no args are specified, pretend the scale matrix is the identity - # matrix. - if (self._is_only_identity_multiplier and - scale_identity_multiplier is None): - scale_identity_multiplier = ops.convert_to_tensor(1., dtype=dtype) - - # self._create_scale_operator returns a LinearOperator in all cases - # except if self._is_only_identity_multiplier; in which case it - # returns a scalar Tensor. 
- scale = self._create_scale_operator( - identity_multiplier=scale_identity_multiplier, - diag=scale_diag, - tril=scale_tril, - perturb_diag=scale_perturb_diag, - perturb_factor=scale_perturb_factor, - shift=shift, - validate_args=validate_args) - - if scale.dtype is not None: - dtype = scale.dtype.base_dtype - - if scale is not None and not self._is_only_identity_multiplier: - if (shift is not None and - shift.dtype.base_dtype != scale.dtype.base_dtype): - raise TypeError( - "shift.dtype({}) is incompatible with scale.dtype({}).".format( - shift.dtype, scale.dtype)) - - if scale.tensor_rank is not None: - batch_ndims = scale.tensor_rank - 2 - else: - batch_ndims = scale.tensor_rank_tensor() - 2 - else: - # We won't need shape inference when scale is None or when scale is a - # scalar. - batch_ndims = 0 - self._scale = scale - self._shaper = _DistributionShape( - batch_ndims=batch_ndims, - event_ndims=event_ndims, - validate_args=validate_args) - super(Affine, self).__init__( - event_ndims=event_ndims, - graph_parents=( - [event_ndims] + - [self._scale] if tensor_util.is_tensor(self._scale) - else self._scale.graph_parents + - [self._shift] if self._shift is not None else []), - is_constant_jacobian=True, - dtype=dtype, - validate_args=validate_args, - name=name) - - def _create_scale_operator(self, identity_multiplier, diag, tril, - perturb_diag, perturb_factor, shift, - validate_args): - """Construct `scale` from various components. - - Args: - identity_multiplier: floating point rank 0 `Tensor` representing a scaling - done to the identity matrix. - diag: Floating-point `Tensor` representing the diagonal matrix. - `scale_diag` has shape [N1, N2, ... k], which represents a k x k - diagonal matrix. - tril: Floating-point `Tensor` representing the diagonal matrix. - `scale_tril` has shape [N1, N2, ... k], which represents a k x k lower - triangular matrix. - perturb_diag: Floating-point `Tensor` representing the diagonal matrix of - the low rank update. 
- perturb_factor: Floating-point `Tensor` representing factor matrix. - shift: Floating-point `Tensor` representing `shift in `scale @ X + shift`. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - - Returns: - scale. In the case of scaling by a constant, scale is a - floating point `Tensor`. Otherwise, scale is a `LinearOperator`. - - Raises: - ValueError: if all of `tril`, `diag` and `identity_multiplier` are `None`. - """ - identity_multiplier = _as_tensor(identity_multiplier, "identity_multiplier") - diag = _as_tensor(diag, "diag") - tril = _as_tensor(tril, "tril") - perturb_diag = _as_tensor(perturb_diag, "perturb_diag") - perturb_factor = _as_tensor(perturb_factor, "perturb_factor") - - # If possible, use the low rank update to infer the shape of - # the identity matrix, when scale represents a scaled identity matrix - # with a low rank update. - shape_hint = None - if perturb_factor is not None: - shape_hint = distribution_util.dimension_size(perturb_factor, axis=-2) - - if self._is_only_identity_multiplier: - if validate_args: - return control_flow_ops.with_dependencies( - [check_ops.assert_none_equal( - identity_multiplier, - array_ops.zeros([], identity_multiplier.dtype), - ["identity_multiplier should be non-zero."])], - identity_multiplier) - return identity_multiplier - - scale = distribution_util.make_tril_scale( - loc=shift, - scale_tril=tril, - scale_diag=diag, - scale_identity_multiplier=identity_multiplier, - validate_args=validate_args, - assert_positive=False, - shape_hint=shape_hint) - - if perturb_factor is not None: - return linalg.LinearOperatorLowRankUpdate( - scale, - u=perturb_factor, - diag_update=perturb_diag, - is_diag_update_positive=perturb_diag is None, - is_non_singular=True, # Implied by is_positive_definite=True. 
- is_self_adjoint=True, - is_positive_definite=True, - is_square=True) - - return scale - - @property - def shift(self): - """The `shift` `Tensor` in `Y = scale @ X + shift`.""" - return self._shift - - @property - def scale(self): - """The `scale` `LinearOperator` in `Y = scale @ X + shift`.""" - return self._scale - - def _forward(self, x): - y = x - if self._is_only_identity_multiplier: - y *= self._scale - if self.shift is not None: - return y + self.shift - return y - y, sample_shape = self._shaper.make_batch_of_event_sample_matrices( - y, expand_batch_dim=False) - with ops.control_dependencies(self._maybe_check_scale() if - self.validate_args else []): - y = self.scale.matmul(y) - y = self._shaper.undo_make_batch_of_event_sample_matrices( - y, sample_shape, expand_batch_dim=False) - if self.shift is not None: - y += self.shift - return y - - def _inverse(self, y): - x = y - if self.shift is not None: - x -= self.shift - if self._is_only_identity_multiplier: - return x / self._scale - - x, sample_shape = self._shaper.make_batch_of_event_sample_matrices( - x, expand_batch_dim=False) - # Solve fails if the op is singular so we may safely skip this assertion. - x = self.scale.solve(x) - x = self._shaper.undo_make_batch_of_event_sample_matrices( - x, sample_shape, expand_batch_dim=False) - return x - - def _inverse_log_det_jacobian(self, y): - return -self._forward_log_det_jacobian(y) - - def _forward_log_det_jacobian(self, x): - if self._is_only_identity_multiplier: - # We don't pad in this case and instead let the fldj be applied - # via broadcast. 
- event_size = distribution_util.pick_vector( - math_ops.equal(self._shaper.event_ndims, 0), - [1], array_ops.shape(x))[-1] - event_size = math_ops.cast(event_size, dtype=self._scale.dtype) - return math_ops.log(math_ops.abs(self._scale)) * event_size - return self.scale.log_abs_determinant() - - def _maybe_check_scale(self): - try: - return [self.scale.assert_non_singular()] - except NotImplementedError: - pass - return [] diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py index aca04a89df7c3ee09d5f7cc10f6779e33fa7aa66..89043b1410370074f11f2cfa59b6b6663fa62521 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator.py @@ -18,12 +18,214 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.affine_linear_operator_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.contrib.distributions.python.ops.shape import _DistributionShape +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.ops.linalg import linear_operator -_allowed_symbols = ["AffineLinearOperator"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "AffineLinearOperator", +] + + +class AffineLinearOperator(bijector.Bijector): + """Compute `Y = g(X; shift, scale) = scale @ X + shift`. 
+ + `shift` is a numeric `Tensor` and `scale` is a `LinearOperator`. + + If `X` is a scalar then the forward transformation is: `scale * X + shift` + where `*` denotes the scalar product. + + Note: we don't always simply transpose `X` (but write it this way for + brevity). Actually the input `X` undergoes the following transformation + before being premultiplied by `scale`: + + 1. If there are no sample dims, we call `X = tf.expand_dims(X, 0)`, i.e., + `new_sample_shape = [1]`. Otherwise do nothing. + 2. The sample shape is flattened to have one dimension, i.e., + `new_sample_shape = [n]` where `n = tf.reduce_prod(old_sample_shape)`. + 3. The sample dim is cyclically rotated left by 1, i.e., + `new_shape = [B1,...,Bb, k, n]` where `n` is as above, `k` is the + event_shape, and `B1,...,Bb` are the batch shapes for each of `b` batch + dimensions. + + (For more details see `shape.make_batch_of_event_sample_matrices`.) + + The result of the above transformation is that `X` can be regarded as a batch + of matrices where each column is a draw from the distribution. After + premultiplying by `scale`, we take the inverse of this procedure. The input + `Y` also undergoes the same transformation before/after premultiplying by + `inv(scale)`. 
+ + Example Use: + + ```python + linalg = tf.linalg + + x = [1., 2, 3] + + shift = [-1., 0., 1] + diag = [1., 2, 3] + scale = linalg.LinearOperatorDiag(diag) + affine = AffineLinearOperator(shift, scale) + # In this case, `forward` is equivalent to: + # y = scale @ x + shift + y = affine.forward(x) # [0., 4, 10] + + shift = [2., 3, 1] + tril = [[1., 0, 0], + [2, 1, 0], + [3, 2, 1]] + scale = linalg.LinearOperatorLowerTriangular(tril) + affine = AffineLinearOperator(shift, scale) + # In this case, `forward` is equivalent to: + # np.squeeze(np.matmul(tril, np.expand_dims(x, -1)), -1) + shift + y = affine.forward(x) # [3., 7, 11] + ``` + + """ + + def __init__(self, + shift=None, + scale=None, + event_ndims=1, + validate_args=False, + name="affine_linear_operator"): + """Instantiates the `AffineLinearOperator` bijector. + + Args: + shift: Floating-point `Tensor`. + scale: Subclass of `LinearOperator`. Represents the (batch) positive + definite matrix `M` in `R^{k x k}`. + event_ndims: Scalar `integer` `Tensor` indicating the number of dimensions + associated with a particular draw from the distribution. Must be 0 or 1. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + + Raises: + ValueError: if `event_ndims` is not 0 or 1. + TypeError: if `scale` is not a `LinearOperator`. + TypeError: if `shift.dtype` does not match `scale.dtype`. + ValueError: if not `scale.is_non_singular`. 
+ """ + self._graph_parents = [] + self._name = name + self._validate_args = validate_args + graph_parents = [] + with self._name_scope("init", values=[shift]): + event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") + if tensor_util.constant_value(event_ndims) is not None: + event_ndims = tensor_util.constant_value(event_ndims) + if event_ndims not in (0, 1): + raise ValueError("event_ndims({}) was not 0 or 1".format(event_ndims)) + else: + if validate_args: + # Shape tool will catch if event_ndims is negative. + event_ndims = control_flow_ops.with_dependencies( + [check_ops.assert_less( + event_ndims, 2, message="event_ndims must be 0 or 1")], + event_ndims) + graph_parents += [event_ndims] + + # In the absence of `loc` and `scale`, we'll assume `dtype` is `float32`. + dtype = dtypes.float32 + + if shift is not None: + shift = ops.convert_to_tensor(shift, name="shift") + graph_parents += [shift] + dtype = shift.dtype.base_dtype + self._shift = shift + + if scale is not None: + if (shift is not None and + shift.dtype.base_dtype != scale.dtype.base_dtype): + raise TypeError( + "shift.dtype({}) is incompatible with scale.dtype({}).".format( + shift.dtype, scale.dtype)) + if not isinstance(scale, linear_operator.LinearOperator): + raise TypeError("scale is not an instance of tf.LinearOperator") + if validate_args and not scale.is_non_singular: + raise ValueError("Scale matrix must be non-singular.") + graph_parents += scale.graph_parents + if scale.tensor_rank is not None: + batch_ndims = scale.tensor_rank - 2 + else: + batch_ndims = scale.tensor_rank_tensor() - 2 + graph_parents += [batch_ndims] + if scale.dtype is not None: + dtype = scale.dtype.base_dtype + else: + batch_ndims = 0 # We won't need shape inference when scale is None. 
+ self._scale = scale + self._shaper = _DistributionShape( + batch_ndims=batch_ndims, + event_ndims=event_ndims, + validate_args=validate_args) + super(AffineLinearOperator, self).__init__( + event_ndims=event_ndims, + graph_parents=graph_parents, + is_constant_jacobian=True, + dtype=dtype, + validate_args=validate_args, + name=name) + + @property + def shift(self): + """The `shift` `Tensor` in `Y = scale @ X + shift`.""" + return self._shift + + @property + def scale(self): + """The `scale` `LinearOperator` in `Y = scale @ X + shift`.""" + return self._scale + + def _forward(self, x): + y = x + if self.scale is not None: + y, sample_shape = self._shaper.make_batch_of_event_sample_matrices( + y, expand_batch_dim=False) + with ops.control_dependencies(self._maybe_collect_assertions() if + self.validate_args else []): + y = self.scale.matmul(y) + y = self._shaper.undo_make_batch_of_event_sample_matrices( + y, sample_shape, expand_batch_dim=False) + if self.shift is not None: + y += self.shift + return y + + def _inverse(self, y): + x = y + if self.shift is not None: + x -= self.shift + if self.scale is not None: + x, sample_shape = self._shaper.make_batch_of_event_sample_matrices( + x, expand_batch_dim=False) + # Solve fails if the op is singular so we may safely skip this assertion. 
+ x = self.scale.solve(x) + x = self._shaper.undo_make_batch_of_event_sample_matrices( + x, sample_shape, expand_batch_dim=False) + return x + + def _inverse_log_det_jacobian(self, y): + return -self._forward_log_det_jacobian(y) + + def _forward_log_det_jacobian(self, x): # pylint: disable=unused-argument + if self.scale is None: + return constant_op.constant(0, dtype=x.dtype.base_dtype) + with ops.control_dependencies(self._maybe_collect_assertions() if + self.validate_args else []): + return self.scale.log_abs_determinant() + + def _maybe_collect_assertions(self): + try: + return [self.scale.assert_non_singular()] + except NotImplementedError: + pass + return [] diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator_impl.py deleted file mode 100644 index 89043b1410370074f11f2cfa59b6b6663fa62521..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator_impl.py +++ /dev/null @@ -1,231 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""AffineLinearOperator bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.distributions.python.ops.shape import _DistributionShape -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops.distributions import bijector -from tensorflow.python.ops.linalg import linear_operator - - -__all__ = [ - "AffineLinearOperator", -] - - -class AffineLinearOperator(bijector.Bijector): - """Compute `Y = g(X; shift, scale) = scale @ X + shift`. - - `shift` is a numeric `Tensor` and `scale` is a `LinearOperator`. - - If `X` is a scalar then the forward transformation is: `scale * X + shift` - where `*` denotes the scalar product. - - Note: we don't always simply transpose `X` (but write it this way for - brevity). Actually the input `X` undergoes the following transformation - before being premultiplied by `scale`: - - 1. If there are no sample dims, we call `X = tf.expand_dims(X, 0)`, i.e., - `new_sample_shape = [1]`. Otherwise do nothing. - 2. The sample shape is flattened to have one dimension, i.e., - `new_sample_shape = [n]` where `n = tf.reduce_prod(old_sample_shape)`. - 3. The sample dim is cyclically rotated left by 1, i.e., - `new_shape = [B1,...,Bb, k, n]` where `n` is as above, `k` is the - event_shape, and `B1,...,Bb` are the batch shapes for each of `b` batch - dimensions. - - (For more details see `shape.make_batch_of_event_sample_matrices`.) - - The result of the above transformation is that `X` can be regarded as a batch - of matrices where each column is a draw from the distribution. 
After - premultiplying by `scale`, we take the inverse of this procedure. The input - `Y` also undergoes the same transformation before/after premultiplying by - `inv(scale)`. - - Example Use: - - ```python - linalg = tf.linalg - - x = [1., 2, 3] - - shift = [-1., 0., 1] - diag = [1., 2, 3] - scale = linalg.LinearOperatorDiag(diag) - affine = AffineLinearOperator(shift, scale) - # In this case, `forward` is equivalent to: - # y = scale @ x + shift - y = affine.forward(x) # [0., 4, 10] - - shift = [2., 3, 1] - tril = [[1., 0, 0], - [2, 1, 0], - [3, 2, 1]] - scale = linalg.LinearOperatorLowerTriangular(tril) - affine = AffineLinearOperator(shift, scale) - # In this case, `forward` is equivalent to: - # np.squeeze(np.matmul(tril, np.expand_dims(x, -1)), -1) + shift - y = affine.forward(x) # [3., 7, 11] - ``` - - """ - - def __init__(self, - shift=None, - scale=None, - event_ndims=1, - validate_args=False, - name="affine_linear_operator"): - """Instantiates the `AffineLinearOperator` bijector. - - Args: - shift: Floating-point `Tensor`. - scale: Subclass of `LinearOperator`. Represents the (batch) positive - definite matrix `M` in `R^{k x k}`. - event_ndims: Scalar `integer` `Tensor` indicating the number of dimensions - associated with a particular draw from the distribution. Must be 0 or 1. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str` name given to ops managed by this object. - - Raises: - ValueError: if `event_ndims` is not 0 or 1. - TypeError: if `scale` is not a `LinearOperator`. - TypeError: if `shift.dtype` does not match `scale.dtype`. - ValueError: if not `scale.is_non_singular`. 
- """ - self._graph_parents = [] - self._name = name - self._validate_args = validate_args - graph_parents = [] - with self._name_scope("init", values=[shift]): - event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") - if tensor_util.constant_value(event_ndims) is not None: - event_ndims = tensor_util.constant_value(event_ndims) - if event_ndims not in (0, 1): - raise ValueError("event_ndims({}) was not 0 or 1".format(event_ndims)) - else: - if validate_args: - # Shape tool will catch if event_ndims is negative. - event_ndims = control_flow_ops.with_dependencies( - [check_ops.assert_less( - event_ndims, 2, message="event_ndims must be 0 or 1")], - event_ndims) - graph_parents += [event_ndims] - - # In the absence of `loc` and `scale`, we'll assume `dtype` is `float32`. - dtype = dtypes.float32 - - if shift is not None: - shift = ops.convert_to_tensor(shift, name="shift") - graph_parents += [shift] - dtype = shift.dtype.base_dtype - self._shift = shift - - if scale is not None: - if (shift is not None and - shift.dtype.base_dtype != scale.dtype.base_dtype): - raise TypeError( - "shift.dtype({}) is incompatible with scale.dtype({}).".format( - shift.dtype, scale.dtype)) - if not isinstance(scale, linear_operator.LinearOperator): - raise TypeError("scale is not an instance of tf.LinearOperator") - if validate_args and not scale.is_non_singular: - raise ValueError("Scale matrix must be non-singular.") - graph_parents += scale.graph_parents - if scale.tensor_rank is not None: - batch_ndims = scale.tensor_rank - 2 - else: - batch_ndims = scale.tensor_rank_tensor() - 2 - graph_parents += [batch_ndims] - if scale.dtype is not None: - dtype = scale.dtype.base_dtype - else: - batch_ndims = 0 # We won't need shape inference when scale is None. 
- self._scale = scale - self._shaper = _DistributionShape( - batch_ndims=batch_ndims, - event_ndims=event_ndims, - validate_args=validate_args) - super(AffineLinearOperator, self).__init__( - event_ndims=event_ndims, - graph_parents=graph_parents, - is_constant_jacobian=True, - dtype=dtype, - validate_args=validate_args, - name=name) - - @property - def shift(self): - """The `shift` `Tensor` in `Y = scale @ X + shift`.""" - return self._shift - - @property - def scale(self): - """The `scale` `LinearOperator` in `Y = scale @ X + shift`.""" - return self._scale - - def _forward(self, x): - y = x - if self.scale is not None: - y, sample_shape = self._shaper.make_batch_of_event_sample_matrices( - y, expand_batch_dim=False) - with ops.control_dependencies(self._maybe_collect_assertions() if - self.validate_args else []): - y = self.scale.matmul(y) - y = self._shaper.undo_make_batch_of_event_sample_matrices( - y, sample_shape, expand_batch_dim=False) - if self.shift is not None: - y += self.shift - return y - - def _inverse(self, y): - x = y - if self.shift is not None: - x -= self.shift - if self.scale is not None: - x, sample_shape = self._shaper.make_batch_of_event_sample_matrices( - x, expand_batch_dim=False) - # Solve fails if the op is singular so we may safely skip this assertion. 
- x = self.scale.solve(x) - x = self._shaper.undo_make_batch_of_event_sample_matrices( - x, sample_shape, expand_batch_dim=False) - return x - - def _inverse_log_det_jacobian(self, y): - return -self._forward_log_det_jacobian(y) - - def _forward_log_det_jacobian(self, x): # pylint: disable=unused-argument - if self.scale is None: - return constant_op.constant(0, dtype=x.dtype.base_dtype) - with ops.control_dependencies(self._maybe_collect_assertions() if - self.validate_args else []): - return self.scale.log_abs_determinant() - - def _maybe_collect_assertions(self): - try: - return [self.scale.assert_non_singular()] - except NotImplementedError: - pass - return [] diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py index 0db10fb75c8483a8209f39370362b05a03d047ca..3ce7c26213034c7345a20faa803c94a1bfa8d579 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/chain.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/chain.py @@ -18,12 +18,151 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.chain_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +import itertools -_allowed_symbols = ["Chain"] +from tensorflow.python.framework import constant_op +from tensorflow.python.ops.distributions import bijector -remove_undocumented(__name__, _allowed_symbols) + +__all__ = [ + "Chain", +] + + +class Chain(bijector.Bijector): + """Bijector which applies a sequence of bijectors. 
+ + Example Use: + + ```python + chain = Chain([Exp(), Softplus()], name="one_plus_exp") + ``` + + Results in: + + * Forward: + + ```python + exp = Exp() + softplus = Softplus() + Chain([exp, softplus]).forward(x) + = exp.forward(softplus.forward(x)) + = tf.exp(tf.log(1. + tf.exp(x))) + = 1. + tf.exp(x) + ``` + + * Inverse: + + ```python + exp = Exp() + softplus = Softplus() + Chain([exp, softplus]).inverse(y) + = softplus.inverse(exp.inverse(y)) + = tf.log(tf.exp(tf.log(y)) - 1.) + = tf.log(y - 1.) + ``` + + """ + + def __init__(self, bijectors=None, validate_args=False, name=None): + """Instantiates `Chain` bijector. + + Args: + bijectors: Python `list` of bijector instances. An empty list makes this + bijector equivalent to the `Identity` bijector. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str`, name given to ops managed by this object. Default: + E.g., `Chain([Exp(), Softplus()]).name == "chain_of_exp_of_softplus"`. + + Raises: + ValueError: if bijectors have different dtypes. 
+ """ + if bijectors is None: + bijectors = () + self._bijectors = bijectors + + for a_bijector in bijectors: + if not a_bijector._is_injective: # pylint: disable=protected-access + raise NotImplementedError( + "Invert is not implemented for non-injective bijector ({})".format( + a_bijector.name)) + + dtype = list(set([b.dtype for b in bijectors])) + if len(dtype) > 2: + raise ValueError("incompatible dtypes: %s" % dtype) + elif len(dtype) == 2: + dtype = dtype[1] if dtype[0] is None else dtype[0] + event_ndims = bijectors[0].event_ndims + elif len(dtype) == 1: + dtype = dtype[0] + event_ndims = bijectors[0].event_ndims + else: + dtype = None + event_ndims = None + + super(Chain, self).__init__( + graph_parents=list(itertools.chain.from_iterable( + b.graph_parents for b in bijectors)), + is_constant_jacobian=all(b.is_constant_jacobian for b in bijectors), + validate_args=validate_args, + dtype=dtype, + event_ndims=event_ndims, + name=name or ("identity" if not bijectors else + "_of_".join(["chain"] + [b.name for b in bijectors]))) + + @property + def bijectors(self): + return self._bijectors + + def _shape_helper(self, func_name, input_shape, reverse): + new_shape = input_shape + for b in reversed(self.bijectors) if reverse else self.bijectors: + func = getattr(b, func_name, None) + if func is None: + raise ValueError("unable to call %s on bijector %s (%s)" % + (func_name, b.name, func)) + new_shape = func(new_shape) + return new_shape + + def _forward_event_shape(self, input_shape): + return self._shape_helper("forward_event_shape", input_shape, + reverse=True) + + def _forward_event_shape_tensor(self, input_shape): + return self._shape_helper( + "forward_event_shape_tensor", input_shape, reverse=True) + + def _inverse_event_shape(self, output_shape): + return self._shape_helper("inverse_event_shape", output_shape, + reverse=False) + + def _inverse_event_shape_tensor(self, output_shape): + return self._shape_helper("inverse_event_shape_tensor", output_shape, + 
reverse=False) + + def _inverse(self, y, **kwargs): + for b in self.bijectors: + y = b.inverse(y, **kwargs.get(b.name, {})) + return y + + def _inverse_log_det_jacobian(self, y, **kwargs): + ildj = constant_op.constant(0., dtype=y.dtype, + name="inverse_log_det_jacobian") + for b in self.bijectors: + ildj += b.inverse_log_det_jacobian(y, **kwargs.get(b.name, {})) + y = b.inverse(y, **kwargs.get(b.name, {})) + return ildj + + def _forward(self, x, **kwargs): + for b in reversed(self.bijectors): + x = b.forward(x, **kwargs.get(b.name, {})) + return x + + def _forward_log_det_jacobian(self, x, **kwargs): + fldj = constant_op.constant(0., dtype=x.dtype, + name="forward_log_det_jacobian") + for b in reversed(self.bijectors): + fldj += b.forward_log_det_jacobian(x, **kwargs.get(b.name, {})) + x = b.forward(x, **kwargs.get(b.name, {})) + return fldj diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/chain_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/chain_impl.py deleted file mode 100644 index 3ce7c26213034c7345a20faa803c94a1bfa8d579..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/chain_impl.py +++ /dev/null @@ -1,168 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Chain bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import itertools - -from tensorflow.python.framework import constant_op -from tensorflow.python.ops.distributions import bijector - - -__all__ = [ - "Chain", -] - - -class Chain(bijector.Bijector): - """Bijector which applies a sequence of bijectors. - - Example Use: - - ```python - chain = Chain([Exp(), Softplus()], name="one_plus_exp") - ``` - - Results in: - - * Forward: - - ```python - exp = Exp() - softplus = Softplus() - Chain([exp, softplus]).forward(x) - = exp.forward(softplus.forward(x)) - = tf.exp(tf.log(1. + tf.exp(x))) - = 1. + tf.exp(x) - ``` - - * Inverse: - - ```python - exp = Exp() - softplus = Softplus() - Chain([exp, softplus]).inverse(y) - = softplus.inverse(exp.inverse(y)) - = tf.log(tf.exp(tf.log(y)) - 1.) - = tf.log(y - 1.) - ``` - - """ - - def __init__(self, bijectors=None, validate_args=False, name=None): - """Instantiates `Chain` bijector. - - Args: - bijectors: Python `list` of bijector instances. An empty list makes this - bijector equivalent to the `Identity` bijector. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str`, name given to ops managed by this object. Default: - E.g., `Chain([Exp(), Softplus()]).name == "chain_of_exp_of_softplus"`. - - Raises: - ValueError: if bijectors have different dtypes. 
- """ - if bijectors is None: - bijectors = () - self._bijectors = bijectors - - for a_bijector in bijectors: - if not a_bijector._is_injective: # pylint: disable=protected-access - raise NotImplementedError( - "Invert is not implemented for non-injective bijector ({})".format( - a_bijector.name)) - - dtype = list(set([b.dtype for b in bijectors])) - if len(dtype) > 2: - raise ValueError("incompatible dtypes: %s" % dtype) - elif len(dtype) == 2: - dtype = dtype[1] if dtype[0] is None else dtype[0] - event_ndims = bijectors[0].event_ndims - elif len(dtype) == 1: - dtype = dtype[0] - event_ndims = bijectors[0].event_ndims - else: - dtype = None - event_ndims = None - - super(Chain, self).__init__( - graph_parents=list(itertools.chain.from_iterable( - b.graph_parents for b in bijectors)), - is_constant_jacobian=all(b.is_constant_jacobian for b in bijectors), - validate_args=validate_args, - dtype=dtype, - event_ndims=event_ndims, - name=name or ("identity" if not bijectors else - "_of_".join(["chain"] + [b.name for b in bijectors]))) - - @property - def bijectors(self): - return self._bijectors - - def _shape_helper(self, func_name, input_shape, reverse): - new_shape = input_shape - for b in reversed(self.bijectors) if reverse else self.bijectors: - func = getattr(b, func_name, None) - if func is None: - raise ValueError("unable to call %s on bijector %s (%s)" % - (func_name, b.name, func)) - new_shape = func(new_shape) - return new_shape - - def _forward_event_shape(self, input_shape): - return self._shape_helper("forward_event_shape", input_shape, - reverse=True) - - def _forward_event_shape_tensor(self, input_shape): - return self._shape_helper( - "forward_event_shape_tensor", input_shape, reverse=True) - - def _inverse_event_shape(self, output_shape): - return self._shape_helper("inverse_event_shape", output_shape, - reverse=False) - - def _inverse_event_shape_tensor(self, output_shape): - return self._shape_helper("inverse_event_shape_tensor", output_shape, - 
reverse=False) - - def _inverse(self, y, **kwargs): - for b in self.bijectors: - y = b.inverse(y, **kwargs.get(b.name, {})) - return y - - def _inverse_log_det_jacobian(self, y, **kwargs): - ildj = constant_op.constant(0., dtype=y.dtype, - name="inverse_log_det_jacobian") - for b in self.bijectors: - ildj += b.inverse_log_det_jacobian(y, **kwargs.get(b.name, {})) - y = b.inverse(y, **kwargs.get(b.name, {})) - return ildj - - def _forward(self, x, **kwargs): - for b in reversed(self.bijectors): - x = b.forward(x, **kwargs.get(b.name, {})) - return x - - def _forward_log_det_jacobian(self, x, **kwargs): - fldj = constant_op.constant(0., dtype=x.dtype, - name="forward_log_det_jacobian") - for b in reversed(self.bijectors): - fldj += b.forward_log_det_jacobian(x, **kwargs.get(b.name, {})) - x = b.forward(x, **kwargs.get(b.name, {})) - return fldj diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py index 4686af8bc42a3232cb3a34f2cfcce8323c5896dd..cbd60f92a60612c6cf791b2c7708a3310c6e2b6b 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product.py @@ -18,12 +18,219 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.cholesky_outer_product_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +import numpy as np -_allowed_symbols = ["CholeskyOuterProduct"] +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from 
tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.ops.distributions import util as distribution_util -remove_undocumented(__name__, _allowed_symbols) + +__all__ = [ + "CholeskyOuterProduct", +] + + +class CholeskyOuterProduct(bijector.Bijector): + """Compute `g(X) = X @ X.T`; X is lower-triangular, positive-diagonal matrix. + + `event_ndims` must be 0 or 2, i.e., scalar or matrix. + + Note: the upper-triangular part of X is ignored (whether or not its zero). + + The surjectivity of g as a map from the set of n x n positive-diagonal + lower-triangular matrices to the set of SPD matrices follows immediately from + executing the Cholesky factorization algorithm on an SPD matrix A to produce a + positive-diagonal lower-triangular matrix L such that `A = L @ L.T`. + + To prove the injectivity of g, suppose that L_1 and L_2 are lower-triangular + with positive diagonals and satisfy `A = L_1 @ L_1.T = L_2 @ L_2.T`. Then + `inv(L_1) @ A @ inv(L_1).T = [inv(L_1) @ L_2] @ [inv(L_1) @ L_2].T = I`. + Setting `L_3 := inv(L_1) @ L_2`, that L_3 is a positive-diagonal + lower-triangular matrix follows from `inv(L_1)` being positive-diagonal + lower-triangular (which follows from the diagonal of a triangular matrix being + its spectrum), and that the product of two positive-diagonal lower-triangular + matrices is another positive-diagonal lower-triangular matrix. + + A simple inductive argument (proceding one column of L_3 at a time) shows + that, if `I = L_3 @ L_3.T`, with L_3 being lower-triangular with positive- + diagonal, then `L_3 = I`. Thus, `L_1 = L_2`, proving injectivity of g. + + Examples: + + ```python + bijector.CholeskyOuterProduct(event_ndims=2).forward(x=[[1., 0], [2, 1]]) + # Result: [[1., 2], [2, 5]], i.e., x @ x.T + + bijector.CholeskyOuterProduct(event_ndims=2).inverse(y=[[1., 2], [2, 5]]) + # Result: [[1., 0], [2, 1]], i.e., cholesky(y). 
+ ``` + + """ + + def __init__(self, event_ndims=2, validate_args=False, + name="cholesky_outer_product"): + """Instantiates the `CholeskyOuterProduct` bijector. + + Args: + event_ndims: `constant` `int32` scalar `Tensor` indicating the number of + dimensions associated with a particular draw from the distribution. Must + be 0 or 2. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + + Raises: + ValueError: if event_ndims is neither 0 or 2. + """ + self._graph_parents = [] + self._name = name + with self._name_scope("init", values=[event_ndims]): + event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") + event_ndims = tensor_util.constant_value(event_ndims) + if event_ndims is None or event_ndims not in [0, 2]: + raise ValueError("`event_ndims` must be a TF constant which is 0 or 2") + self._static_event_ndims = event_ndims + super(CholeskyOuterProduct, self).__init__( + event_ndims=event_ndims, + validate_args=validate_args, + name=name) + + def _forward(self, x): + if self._static_event_ndims == 0: + return math_ops.square(x) + if self.validate_args: + is_matrix = check_ops.assert_rank_at_least(x, 2) + shape = array_ops.shape(x) + is_square = check_ops.assert_equal(shape[-2], shape[-1]) + x = control_flow_ops.with_dependencies([is_matrix, is_square], x) + # For safety, explicitly zero-out the upper triangular part. + x = array_ops.matrix_band_part(x, -1, 0) + return math_ops.matmul(x, x, adjoint_b=True) + + def _inverse(self, y): + return (math_ops.sqrt(y) if self._static_event_ndims == 0 + else linalg_ops.cholesky(y)) + + def _inverse_log_det_jacobian(self, y): + return -self._forward_log_det_jacobian(x=self._inverse(y)) + + def _forward_log_det_jacobian(self, x): + # Let Y be a symmetric, positive definite matrix and write: + # Y = X X.T + # where X is lower-triangular. 
+ # + # Observe that, + # dY[i,j]/dX[a,b] + # = d/dX[a,b] { X[i,:] X[j,:] } + # = sum_{d=1}^p { I[i=a] I[d=b] X[j,d] + I[j=a] I[d=b] X[i,d] } + # + # To compute the Jacobian dX/dY we must represent X,Y as vectors. Since Y is + # symmetric and X is lower-triangular, we need vectors of dimension: + # d = p (p + 1) / 2 + # where X, Y are p x p matrices, p > 0. We use a row-major mapping, i.e., + # k = { i (i + 1) / 2 + j i>=j + # { undef ij thus i,j!=a. + # + # Since the Jacobian is lower-triangular, we need only compute the product + # of diagonal elements: + # d vec[Y] / d vec[X] @[k(i,j), k(i,j)] + # = X[j,j] + I[i=j] X[i,j] + # = 2 X[j,j]. + # Since there is a 2 X[j,j] term for every lower-triangular element of X we + # conclude: + # |Jac(d vec[Y]/d vec[X])| = 2^p prod_{j=0}^{p-1} X[j,j]^{p-j}. + if self._static_event_ndims == 0: + if self.validate_args: + is_positive = check_ops.assert_positive( + x, message="All elements must be positive.") + x = control_flow_ops.with_dependencies([is_positive], x) + return np.log(2.) + math_ops.log(x) + + diag = array_ops.matrix_diag_part(x) + + # We now ensure diag is columnar. Eg, if `diag = [1, 2, 3]` then the output + # is `[[1], [2], [3]]` and if `diag = [[1, 2, 3], [4, 5, 6]]` then the + # output is unchanged. + diag = self._make_columnar(diag) + + if self.validate_args: + is_matrix = check_ops.assert_rank_at_least( + x, 2, message="Input must be a (batch of) matrix.") + shape = array_ops.shape(x) + is_square = check_ops.assert_equal( + shape[-2], shape[-1], + message="Input must be a (batch of) square matrix.") + # Assuming lower-triangular means we only need check diag>0. + is_positive_definite = check_ops.assert_positive( + diag, message="Input must be positive definite.") + x = control_flow_ops.with_dependencies( + [is_matrix, is_square, is_positive_definite], x) + + # Create a vector equal to: [p, p-1, ..., 2, 1]. 
+ if x.get_shape().ndims is None or x.get_shape()[-1].value is None: + p_int = array_ops.shape(x)[-1] + p_float = math_ops.cast(p_int, dtype=x.dtype) + else: + p_int = x.get_shape()[-1].value + p_float = np.array(p_int, dtype=x.dtype.as_numpy_dtype) + exponents = math_ops.linspace(p_float, 1., p_int) + + sum_weighted_log_diag = array_ops.squeeze( + math_ops.matmul(math_ops.log(diag), + exponents[..., array_ops.newaxis]), + squeeze_dims=-1) + fldj = p_float * np.log(2.) + sum_weighted_log_diag + + return fldj + + def _make_columnar(self, x): + """Ensures non-scalar input has at least one column. + + Example: + If `x = [1, 2, 3]` then the output is `[[1], [2], [3]]`. + + If `x = [[1, 2, 3], [4, 5, 6]]` then the output is unchanged. + + If `x = 1` then the output is unchanged. + + Args: + x: `Tensor`. + + Returns: + columnar_x: `Tensor` with at least two dimensions. + """ + if x.get_shape().ndims is not None: + if x.get_shape().ndims == 1: + x = x[array_ops.newaxis, :] + return x + shape = array_ops.shape(x) + maybe_expanded_shape = array_ops.concat([ + shape[:-1], + distribution_util.pick_vector( + math_ops.equal(array_ops.rank(x), 1), + [1], np.array([], dtype=np.int32)), + shape[-1:], + ], 0) + return array_ops.reshape(x, maybe_expanded_shape) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product_impl.py deleted file mode 100644 index cbd60f92a60612c6cf791b2c7708a3310c6e2b6b..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product_impl.py +++ /dev/null @@ -1,236 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""CholeskyOuterProduct bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import bijector -from tensorflow.python.ops.distributions import util as distribution_util - - -__all__ = [ - "CholeskyOuterProduct", -] - - -class CholeskyOuterProduct(bijector.Bijector): - """Compute `g(X) = X @ X.T`; X is lower-triangular, positive-diagonal matrix. - - `event_ndims` must be 0 or 2, i.e., scalar or matrix. - - Note: the upper-triangular part of X is ignored (whether or not its zero). - - The surjectivity of g as a map from the set of n x n positive-diagonal - lower-triangular matrices to the set of SPD matrices follows immediately from - executing the Cholesky factorization algorithm on an SPD matrix A to produce a - positive-diagonal lower-triangular matrix L such that `A = L @ L.T`. - - To prove the injectivity of g, suppose that L_1 and L_2 are lower-triangular - with positive diagonals and satisfy `A = L_1 @ L_1.T = L_2 @ L_2.T`. Then - `inv(L_1) @ A @ inv(L_1).T = [inv(L_1) @ L_2] @ [inv(L_1) @ L_2].T = I`. 
- Setting `L_3 := inv(L_1) @ L_2`, that L_3 is a positive-diagonal - lower-triangular matrix follows from `inv(L_1)` being positive-diagonal - lower-triangular (which follows from the diagonal of a triangular matrix being - its spectrum), and that the product of two positive-diagonal lower-triangular - matrices is another positive-diagonal lower-triangular matrix. - - A simple inductive argument (proceding one column of L_3 at a time) shows - that, if `I = L_3 @ L_3.T`, with L_3 being lower-triangular with positive- - diagonal, then `L_3 = I`. Thus, `L_1 = L_2`, proving injectivity of g. - - Examples: - - ```python - bijector.CholeskyOuterProduct(event_ndims=2).forward(x=[[1., 0], [2, 1]]) - # Result: [[1., 2], [2, 5]], i.e., x @ x.T - - bijector.CholeskyOuterProduct(event_ndims=2).inverse(y=[[1., 2], [2, 5]]) - # Result: [[1., 0], [2, 1]], i.e., cholesky(y). - ``` - - """ - - def __init__(self, event_ndims=2, validate_args=False, - name="cholesky_outer_product"): - """Instantiates the `CholeskyOuterProduct` bijector. - - Args: - event_ndims: `constant` `int32` scalar `Tensor` indicating the number of - dimensions associated with a particular draw from the distribution. Must - be 0 or 2. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str` name given to ops managed by this object. - - Raises: - ValueError: if event_ndims is neither 0 or 2. 
- """ - self._graph_parents = [] - self._name = name - with self._name_scope("init", values=[event_ndims]): - event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") - event_ndims = tensor_util.constant_value(event_ndims) - if event_ndims is None or event_ndims not in [0, 2]: - raise ValueError("`event_ndims` must be a TF constant which is 0 or 2") - self._static_event_ndims = event_ndims - super(CholeskyOuterProduct, self).__init__( - event_ndims=event_ndims, - validate_args=validate_args, - name=name) - - def _forward(self, x): - if self._static_event_ndims == 0: - return math_ops.square(x) - if self.validate_args: - is_matrix = check_ops.assert_rank_at_least(x, 2) - shape = array_ops.shape(x) - is_square = check_ops.assert_equal(shape[-2], shape[-1]) - x = control_flow_ops.with_dependencies([is_matrix, is_square], x) - # For safety, explicitly zero-out the upper triangular part. - x = array_ops.matrix_band_part(x, -1, 0) - return math_ops.matmul(x, x, adjoint_b=True) - - def _inverse(self, y): - return (math_ops.sqrt(y) if self._static_event_ndims == 0 - else linalg_ops.cholesky(y)) - - def _inverse_log_det_jacobian(self, y): - return -self._forward_log_det_jacobian(x=self._inverse(y)) - - def _forward_log_det_jacobian(self, x): - # Let Y be a symmetric, positive definite matrix and write: - # Y = X X.T - # where X is lower-triangular. - # - # Observe that, - # dY[i,j]/dX[a,b] - # = d/dX[a,b] { X[i,:] X[j,:] } - # = sum_{d=1}^p { I[i=a] I[d=b] X[j,d] + I[j=a] I[d=b] X[i,d] } - # - # To compute the Jacobian dX/dY we must represent X,Y as vectors. Since Y is - # symmetric and X is lower-triangular, we need vectors of dimension: - # d = p (p + 1) / 2 - # where X, Y are p x p matrices, p > 0. We use a row-major mapping, i.e., - # k = { i (i + 1) / 2 + j i>=j - # { undef ij thus i,j!=a. 
- # - # Since the Jacobian is lower-triangular, we need only compute the product - # of diagonal elements: - # d vec[Y] / d vec[X] @[k(i,j), k(i,j)] - # = X[j,j] + I[i=j] X[i,j] - # = 2 X[j,j]. - # Since there is a 2 X[j,j] term for every lower-triangular element of X we - # conclude: - # |Jac(d vec[Y]/d vec[X])| = 2^p prod_{j=0}^{p-1} X[j,j]^{p-j}. - if self._static_event_ndims == 0: - if self.validate_args: - is_positive = check_ops.assert_positive( - x, message="All elements must be positive.") - x = control_flow_ops.with_dependencies([is_positive], x) - return np.log(2.) + math_ops.log(x) - - diag = array_ops.matrix_diag_part(x) - - # We now ensure diag is columnar. Eg, if `diag = [1, 2, 3]` then the output - # is `[[1], [2], [3]]` and if `diag = [[1, 2, 3], [4, 5, 6]]` then the - # output is unchanged. - diag = self._make_columnar(diag) - - if self.validate_args: - is_matrix = check_ops.assert_rank_at_least( - x, 2, message="Input must be a (batch of) matrix.") - shape = array_ops.shape(x) - is_square = check_ops.assert_equal( - shape[-2], shape[-1], - message="Input must be a (batch of) square matrix.") - # Assuming lower-triangular means we only need check diag>0. - is_positive_definite = check_ops.assert_positive( - diag, message="Input must be positive definite.") - x = control_flow_ops.with_dependencies( - [is_matrix, is_square, is_positive_definite], x) - - # Create a vector equal to: [p, p-1, ..., 2, 1]. - if x.get_shape().ndims is None or x.get_shape()[-1].value is None: - p_int = array_ops.shape(x)[-1] - p_float = math_ops.cast(p_int, dtype=x.dtype) - else: - p_int = x.get_shape()[-1].value - p_float = np.array(p_int, dtype=x.dtype.as_numpy_dtype) - exponents = math_ops.linspace(p_float, 1., p_int) - - sum_weighted_log_diag = array_ops.squeeze( - math_ops.matmul(math_ops.log(diag), - exponents[..., array_ops.newaxis]), - squeeze_dims=-1) - fldj = p_float * np.log(2.) 
+ sum_weighted_log_diag - - return fldj - - def _make_columnar(self, x): - """Ensures non-scalar input has at least one column. - - Example: - If `x = [1, 2, 3]` then the output is `[[1], [2], [3]]`. - - If `x = [[1, 2, 3], [4, 5, 6]]` then the output is unchanged. - - If `x = 1` then the output is unchanged. - - Args: - x: `Tensor`. - - Returns: - columnar_x: `Tensor` with at least two dimensions. - """ - if x.get_shape().ndims is not None: - if x.get_shape().ndims == 1: - x = x[array_ops.newaxis, :] - return x - shape = array_ops.shape(x) - maybe_expanded_shape = array_ops.concat([ - shape[:-1], - distribution_util.pick_vector( - math_ops.equal(array_ops.rank(x), 1), - [1], np.array([], dtype=np.int32)), - shape[-1:], - ], 0) - return array_ops.reshape(x, maybe_expanded_shape) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector.py b/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector.py index d254b635d28099a09a2054536f04ffee3a355b2f..ccb1f029277bc07011df7be047a075274f2b3a27 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector.py @@ -18,12 +18,38 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.conditional_bijector_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.ops.distributions import util as distribution_util -_allowed_symbols = ["ConditionalBijector"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = ["ConditionalBijector"] + + +class ConditionalBijector(bijector.Bijector): + """Conditional Bijector is a Bijector that allows intrinsic conditioning.""" 
+ + @distribution_util.AppendDocstring(kwargs_dict={ + "**condition_kwargs": + "Named arguments forwarded to subclass implementation."}) + def forward(self, x, name="forward", **condition_kwargs): + return self._call_forward(x, name, **condition_kwargs) + + @distribution_util.AppendDocstring(kwargs_dict={ + "**condition_kwargs": + "Named arguments forwarded to subclass implementation."}) + def inverse(self, y, name="inverse", **condition_kwargs): + return self._call_inverse(y, name, **condition_kwargs) + + @distribution_util.AppendDocstring(kwargs_dict={ + "**condition_kwargs": + "Named arguments forwarded to subclass implementation."}) + def inverse_log_det_jacobian( + self, y, name="inverse_log_det_jacobian", **condition_kwargs): + return self._call_inverse_log_det_jacobian(y, name, **condition_kwargs) + + @distribution_util.AppendDocstring(kwargs_dict={ + "**condition_kwargs": + "Named arguments forwarded to subclass implementation."}) + def forward_log_det_jacobian( + self, x, name="forward_log_det_jacobian", **condition_kwargs): + return self._call_forward_log_det_jacobian(x, name, **condition_kwargs) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector_impl.py deleted file mode 100644 index ccb1f029277bc07011df7be047a075274f2b3a27..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/conditional_bijector_impl.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""ConditionalBijector base.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.ops.distributions import bijector -from tensorflow.python.ops.distributions import util as distribution_util - - -__all__ = ["ConditionalBijector"] - - -class ConditionalBijector(bijector.Bijector): - """Conditional Bijector is a Bijector that allows intrinsic conditioning.""" - - @distribution_util.AppendDocstring(kwargs_dict={ - "**condition_kwargs": - "Named arguments forwarded to subclass implementation."}) - def forward(self, x, name="forward", **condition_kwargs): - return self._call_forward(x, name, **condition_kwargs) - - @distribution_util.AppendDocstring(kwargs_dict={ - "**condition_kwargs": - "Named arguments forwarded to subclass implementation."}) - def inverse(self, y, name="inverse", **condition_kwargs): - return self._call_inverse(y, name, **condition_kwargs) - - @distribution_util.AppendDocstring(kwargs_dict={ - "**condition_kwargs": - "Named arguments forwarded to subclass implementation."}) - def inverse_log_det_jacobian( - self, y, name="inverse_log_det_jacobian", **condition_kwargs): - return self._call_inverse_log_det_jacobian(y, name, **condition_kwargs) - - @distribution_util.AppendDocstring(kwargs_dict={ - "**condition_kwargs": - "Named arguments forwarded to subclass implementation."}) - def forward_log_det_jacobian( - self, x, name="forward_log_det_jacobian", **condition_kwargs): 
- return self._call_forward_log_det_jacobian(x, name, **condition_kwargs) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/exp.py b/tensorflow/contrib/distributions/python/ops/bijectors/exp.py index 399d713098eb7223601beb9518dc51dd6160ad64..b1ff840d62a73c941a4d67dec73b5c9f4d5353f9 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/exp.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/exp.py @@ -18,12 +18,49 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.exp_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.contrib.distributions.python.ops.bijectors import power_transform -_allowed_symbols = ["Exp"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "Exp", +] + + +class Exp(power_transform.PowerTransform): + """Compute `Y = g(X) = exp(X)`. + + Example Use: + + ```python + # Create the Y=g(X)=exp(X) transform which works only on Tensors with 1 + # batch ndim and 2 event ndims (i.e., vector of matrices). + exp = Exp(event_ndims=2) + x = [[[1., 2], + [3, 4]], + [[5, 6], + [7, 8]]] + exp(x) == exp.forward(x) + log(x) == exp.inverse(x) + ``` + + Note: the exp(.) is applied element-wise but the Jacobian is a reduction + over the event space. + """ + + def __init__(self, + event_ndims=0, + validate_args=False, + name="exp"): + """Instantiates the `Exp` bijector. + + Args: + event_ndims: Scalar `int32` `Tensor` indicating the number of dimensions + associated with a particular draw from the distribution. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. 
+ """ + super(Exp, self).__init__( + event_ndims=event_ndims, + validate_args=validate_args, + name=name) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/exp_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/exp_impl.py deleted file mode 100644 index b1ff840d62a73c941a4d67dec73b5c9f4d5353f9..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/exp_impl.py +++ /dev/null @@ -1,66 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Exp bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.distributions.python.ops.bijectors import power_transform - - -__all__ = [ - "Exp", -] - - -class Exp(power_transform.PowerTransform): - """Compute `Y = g(X) = exp(X)`. - - Example Use: - - ```python - # Create the Y=g(X)=exp(X) transform which works only on Tensors with 1 - # batch ndim and 2 event ndims (i.e., vector of matrices). - exp = Exp(event_ndims=2) - x = [[[1., 2], - [3, 4]], - [[5, 6], - [7, 8]]] - exp(x) == exp.forward(x) - log(x) == exp.inverse(x) - ``` - - Note: the exp(.) is applied element-wise but the Jacobian is a reduction - over the event space. 
- """ - - def __init__(self, - event_ndims=0, - validate_args=False, - name="exp"): - """Instantiates the `Exp` bijector. - - Args: - event_ndims: Scalar `int32` `Tensor` indicating the number of dimensions - associated with a particular draw from the distribution. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str` name given to ops managed by this object. - """ - super(Exp, self).__init__( - event_ndims=event_ndims, - validate_args=validate_args, - name=name) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py index cf37aa51115ed98ab263bc03bcb297a03432a7ae..67f39785563255be0fe154aca3cbcf01c6a01e73 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py @@ -18,12 +18,107 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.gumbel_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector -_allowed_symbols = ["Gumbel"] +__all__ = [ + "Gumbel", +] -remove_undocumented(__name__, _allowed_symbols) + +class Gumbel(bijector.Bijector): + """Compute `Y = g(X) = exp(-exp(-(X - loc) / scale))`. + + This bijector maps inputs from `[-inf, inf]` to [0, 1]`. 
The inverse of the + bijector applied to a uniform random variable `X ~ U(0, 1) gives back a + random variable with the + [Gumbel distribution](https://en.wikipedia.org/wiki/Gumbel_distribution): + + ```none + Y ~ Gumbel(loc, scale) + pdf(y; loc, scale) = exp( + -( (y - loc) / scale + exp(- (y - loc) / scale) ) ) / scale + ``` + """ + + def __init__(self, + loc=0., + scale=1., + event_ndims=0, + validate_args=False, + name="gumbel"): + """Instantiates the `Gumbel` bijector. + + Args: + loc: Float-like `Tensor` that is the same dtype and is + broadcastable with `scale`. + This is `loc` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`. + scale: Positive Float-like `Tensor` that is the same dtype and is + broadcastable with `loc`. + This is `scale` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`. + event_ndims: Python scalar indicating the number of dimensions associated + with a particular draw from the distribution. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. 
+ """ + self._graph_parents = [] + self._name = name + self._validate_args = validate_args + with self._name_scope("init", values=[loc, scale]): + self._loc = ops.convert_to_tensor(loc, name="loc") + self._scale = ops.convert_to_tensor(scale, name="scale") + check_ops.assert_same_float_dtype([self._loc, self._scale]) + if validate_args: + self._scale = control_flow_ops.with_dependencies([ + check_ops.assert_positive( + self._scale, message="Argument scale was not positive") + ], self._scale) + + super(Gumbel, self).__init__( + event_ndims=event_ndims, validate_args=validate_args, name=name) + + @property + def loc(self): + """The `loc` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`.""" + return self._loc + + @property + def scale(self): + """This is `scale` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`.""" + return self._scale + + def _forward(self, x): + z = (x - self.loc) / self.scale + return math_ops.exp(-math_ops.exp(-z)) + + def _inverse(self, y): + y = self._maybe_assert_valid_y(y) + return self.loc - self.scale * math_ops.log(-math_ops.log(y)) + + def _inverse_log_det_jacobian(self, y): + y = self._maybe_assert_valid_y(y) + event_dims = self._event_dims_tensor(y) + return math_ops.reduce_sum( + math_ops.log(self.scale / (-math_ops.log(y) * y)), axis=event_dims) + + def _forward_log_det_jacobian(self, x): + event_dims = self._event_dims_tensor(x) + z = (x - self.loc) / self.scale + return math_ops.reduce_sum( + -z - math_ops.exp(-z) - math_ops.log(self.scale), axis=event_dims) + + def _maybe_assert_valid_y(self, y): + if not self.validate_args: + return y + is_positive = check_ops.assert_non_negative( + y, message="Inverse transformation input must be greater than 0.") + less_than_one = check_ops.assert_less_equal( + y, + constant_op.constant(1., y.dtype), + message="Inverse transformation input must be less than or equal to 1.") + return control_flow_ops.with_dependencies([is_positive, less_than_one], y) diff --git 
a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel_impl.py deleted file mode 100644 index 67f39785563255be0fe154aca3cbcf01c6a01e73..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel_impl.py +++ /dev/null @@ -1,124 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Gumbel bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import bijector - -__all__ = [ - "Gumbel", -] - - -class Gumbel(bijector.Bijector): - """Compute `Y = g(X) = exp(-exp(-(X - loc) / scale))`. - - This bijector maps inputs from `[-inf, inf]` to [0, 1]`. 
The inverse of the - bijector applied to a uniform random variable `X ~ U(0, 1) gives back a - random variable with the - [Gumbel distribution](https://en.wikipedia.org/wiki/Gumbel_distribution): - - ```none - Y ~ Gumbel(loc, scale) - pdf(y; loc, scale) = exp( - -( (y - loc) / scale + exp(- (y - loc) / scale) ) ) / scale - ``` - """ - - def __init__(self, - loc=0., - scale=1., - event_ndims=0, - validate_args=False, - name="gumbel"): - """Instantiates the `Gumbel` bijector. - - Args: - loc: Float-like `Tensor` that is the same dtype and is - broadcastable with `scale`. - This is `loc` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`. - scale: Positive Float-like `Tensor` that is the same dtype and is - broadcastable with `loc`. - This is `scale` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`. - event_ndims: Python scalar indicating the number of dimensions associated - with a particular draw from the distribution. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str` name given to ops managed by this object. 
- """ - self._graph_parents = [] - self._name = name - self._validate_args = validate_args - with self._name_scope("init", values=[loc, scale]): - self._loc = ops.convert_to_tensor(loc, name="loc") - self._scale = ops.convert_to_tensor(scale, name="scale") - check_ops.assert_same_float_dtype([self._loc, self._scale]) - if validate_args: - self._scale = control_flow_ops.with_dependencies([ - check_ops.assert_positive( - self._scale, message="Argument scale was not positive") - ], self._scale) - - super(Gumbel, self).__init__( - event_ndims=event_ndims, validate_args=validate_args, name=name) - - @property - def loc(self): - """The `loc` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`.""" - return self._loc - - @property - def scale(self): - """This is `scale` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`.""" - return self._scale - - def _forward(self, x): - z = (x - self.loc) / self.scale - return math_ops.exp(-math_ops.exp(-z)) - - def _inverse(self, y): - y = self._maybe_assert_valid_y(y) - return self.loc - self.scale * math_ops.log(-math_ops.log(y)) - - def _inverse_log_det_jacobian(self, y): - y = self._maybe_assert_valid_y(y) - event_dims = self._event_dims_tensor(y) - return math_ops.reduce_sum( - math_ops.log(self.scale / (-math_ops.log(y) * y)), axis=event_dims) - - def _forward_log_det_jacobian(self, x): - event_dims = self._event_dims_tensor(x) - z = (x - self.loc) / self.scale - return math_ops.reduce_sum( - -z - math_ops.exp(-z) - math_ops.log(self.scale), axis=event_dims) - - def _maybe_assert_valid_y(self, y): - if not self.validate_args: - return y - is_positive = check_ops.assert_non_negative( - y, message="Inverse transformation input must be greater than 0.") - less_than_one = check_ops.assert_less_equal( - y, - constant_op.constant(1., y.dtype), - message="Inverse transformation input must be less than or equal to 1.") - return control_flow_ops.with_dependencies([is_positive, less_than_one], y) diff --git 
a/tensorflow/contrib/distributions/python/ops/bijectors/inline.py b/tensorflow/contrib/distributions/python/ops/bijectors/inline.py index db10c3fc3a9135b4c408ada74622ba9b360f9ec1..fab1b22fbf92e7b92a5ec86ec62d66bec71a8c94 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/inline.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/inline.py @@ -18,12 +18,124 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.inline_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.ops.distributions import bijector -_allowed_symbols = ["Inline"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "Inline", +] + + +class Inline(bijector.Bijector): + """Bijector constructed from custom callables. + + Example Use: + + ```python + exp = Inline( + forward_fn=tf.exp, + inverse_fn=tf.log, + inverse_log_det_jacobian_fn=( + lambda y: -tf.reduce_sum(tf.log(y), axis=-1)), + name="exp") + ``` + + The above example is equivalent to the `Bijector` `Exp(event_ndims=1)`. + """ + + def __init__(self, + forward_fn=None, + inverse_fn=None, + inverse_log_det_jacobian_fn=None, + forward_log_det_jacobian_fn=None, + forward_event_shape_fn=None, + forward_event_shape_tensor_fn=None, + inverse_event_shape_fn=None, + inverse_event_shape_tensor_fn=None, + is_constant_jacobian=False, + validate_args=False, + name="inline"): + """Creates a `Bijector` from callables. + + Args: + forward_fn: Python callable implementing the forward transformation. + inverse_fn: Python callable implementing the inverse transformation. + inverse_log_det_jacobian_fn: Python callable implementing the + log o det o jacobian of the inverse transformation. 
+ forward_log_det_jacobian_fn: Python callable implementing the + log o det o jacobian of the forward transformation. + forward_event_shape_fn: Python callable implementing non-identical + static event shape changes. Default: shape is assumed unchanged. + forward_event_shape_tensor_fn: Python callable implementing non-identical + event shape changes. Default: shape is assumed unchanged. + inverse_event_shape_fn: Python callable implementing non-identical + static event shape changes. Default: shape is assumed unchanged. + inverse_event_shape_tensor_fn: Python callable implementing non-identical + event shape changes. Default: shape is assumed unchanged. + is_constant_jacobian: Python `bool` indicating that the Jacobian is + constant for all input arguments. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str`, name given to ops managed by this object. + """ + super(Inline, self).__init__( + event_ndims=0, + is_constant_jacobian=is_constant_jacobian, + validate_args=validate_args, + name=name) + self._forward_fn = forward_fn + self._inverse_fn = inverse_fn + self._inverse_log_det_jacobian_fn = inverse_log_det_jacobian_fn + self._forward_log_det_jacobian_fn = forward_log_det_jacobian_fn + self._forward_event_shape_fn = forward_event_shape_fn + self._forward_event_shape_tensor_fn = forward_event_shape_tensor_fn + self._inverse_event_shape_fn = inverse_event_shape_fn + self._inverse_event_shape_tensor_fn = inverse_event_shape_tensor_fn + + def _forward_event_shape(self, input_shape): + if self._forward_event_shape_fn is None: + # By default assume shape doesn't change. + return input_shape + return self._forward_event_shape_fn(input_shape) + + def _forward_event_shape_tensor(self, input_shape): + if self._forward_event_shape_tensor_fn is None: + # By default assume shape doesn't change. 
+ return input_shape + return self._forward_event_shape_tensor_fn(input_shape) + + def _inverse_event_shape(self, output_shape): + if self._inverse_event_shape_fn is None: + # By default assume shape doesn't change. + return output_shape + return self._inverse_event_shape_fn(output_shape) + + def _inverse_event_shape_tensor(self, output_shape): + if self._inverse_event_shape_tensor_fn is None: + # By default assume shape doesn't change. + return output_shape + return self._inverse_event_shape_tensor_fn(output_shape) + + def _forward(self, x, **kwargs): + if not callable(self._forward_fn): + raise NotImplementedError( + "forward_fn is not a callable function.") + return self._forward_fn(x, **kwargs) + + def _inverse(self, y, **kwargs): + if not callable(self._inverse_fn): + raise NotImplementedError( + "inverse_fn is not a callable function.") + return self._inverse_fn(y, **kwargs) + + def _inverse_log_det_jacobian(self, y, **kwargs): + if not callable(self._inverse_log_det_jacobian_fn): + raise NotImplementedError( + "inverse_log_det_jacobian_fn is not a callable function.") + return self._inverse_log_det_jacobian_fn(y, **kwargs) + + def _forward_log_det_jacobian(self, y, **kwargs): + if not callable(self._forward_log_det_jacobian_fn): + raise NotImplementedError( + "forward_log_det_jacobian_fn is not a callable function.") + return self._forward_log_det_jacobian_fn(y, **kwargs) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/inline_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/inline_impl.py deleted file mode 100644 index fab1b22fbf92e7b92a5ec86ec62d66bec71a8c94..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/inline_impl.py +++ /dev/null @@ -1,141 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Inline bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.ops.distributions import bijector - - -__all__ = [ - "Inline", -] - - -class Inline(bijector.Bijector): - """Bijector constructed from custom callables. - - Example Use: - - ```python - exp = Inline( - forward_fn=tf.exp, - inverse_fn=tf.log, - inverse_log_det_jacobian_fn=( - lambda y: -tf.reduce_sum(tf.log(y), axis=-1)), - name="exp") - ``` - - The above example is equivalent to the `Bijector` `Exp(event_ndims=1)`. - """ - - def __init__(self, - forward_fn=None, - inverse_fn=None, - inverse_log_det_jacobian_fn=None, - forward_log_det_jacobian_fn=None, - forward_event_shape_fn=None, - forward_event_shape_tensor_fn=None, - inverse_event_shape_fn=None, - inverse_event_shape_tensor_fn=None, - is_constant_jacobian=False, - validate_args=False, - name="inline"): - """Creates a `Bijector` from callables. - - Args: - forward_fn: Python callable implementing the forward transformation. - inverse_fn: Python callable implementing the inverse transformation. - inverse_log_det_jacobian_fn: Python callable implementing the - log o det o jacobian of the inverse transformation. - forward_log_det_jacobian_fn: Python callable implementing the - log o det o jacobian of the forward transformation. - forward_event_shape_fn: Python callable implementing non-identical - static event shape changes. Default: shape is assumed unchanged. 
- forward_event_shape_tensor_fn: Python callable implementing non-identical - event shape changes. Default: shape is assumed unchanged. - inverse_event_shape_fn: Python callable implementing non-identical - static event shape changes. Default: shape is assumed unchanged. - inverse_event_shape_tensor_fn: Python callable implementing non-identical - event shape changes. Default: shape is assumed unchanged. - is_constant_jacobian: Python `bool` indicating that the Jacobian is - constant for all input arguments. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str`, name given to ops managed by this object. - """ - super(Inline, self).__init__( - event_ndims=0, - is_constant_jacobian=is_constant_jacobian, - validate_args=validate_args, - name=name) - self._forward_fn = forward_fn - self._inverse_fn = inverse_fn - self._inverse_log_det_jacobian_fn = inverse_log_det_jacobian_fn - self._forward_log_det_jacobian_fn = forward_log_det_jacobian_fn - self._forward_event_shape_fn = forward_event_shape_fn - self._forward_event_shape_tensor_fn = forward_event_shape_tensor_fn - self._inverse_event_shape_fn = inverse_event_shape_fn - self._inverse_event_shape_tensor_fn = inverse_event_shape_tensor_fn - - def _forward_event_shape(self, input_shape): - if self._forward_event_shape_fn is None: - # By default assume shape doesn't change. - return input_shape - return self._forward_event_shape_fn(input_shape) - - def _forward_event_shape_tensor(self, input_shape): - if self._forward_event_shape_tensor_fn is None: - # By default assume shape doesn't change. - return input_shape - return self._forward_event_shape_tensor_fn(input_shape) - - def _inverse_event_shape(self, output_shape): - if self._inverse_event_shape_fn is None: - # By default assume shape doesn't change. 
- return output_shape - return self._inverse_event_shape_fn(output_shape) - - def _inverse_event_shape_tensor(self, output_shape): - if self._inverse_event_shape_tensor_fn is None: - # By default assume shape doesn't change. - return output_shape - return self._inverse_event_shape_tensor_fn(output_shape) - - def _forward(self, x, **kwargs): - if not callable(self._forward_fn): - raise NotImplementedError( - "forward_fn is not a callable function.") - return self._forward_fn(x, **kwargs) - - def _inverse(self, y, **kwargs): - if not callable(self._inverse_fn): - raise NotImplementedError( - "inverse_fn is not a callable function.") - return self._inverse_fn(y, **kwargs) - - def _inverse_log_det_jacobian(self, y, **kwargs): - if not callable(self._inverse_log_det_jacobian_fn): - raise NotImplementedError( - "inverse_log_det_jacobian_fn is not a callable function.") - return self._inverse_log_det_jacobian_fn(y, **kwargs) - - def _forward_log_det_jacobian(self, y, **kwargs): - if not callable(self._forward_log_det_jacobian_fn): - raise NotImplementedError( - "forward_log_det_jacobian_fn is not a callable function.") - return self._forward_log_det_jacobian_fn(y, **kwargs) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/invert.py b/tensorflow/contrib/distributions/python/ops/bijectors/invert.py index c134e10109ce5065eb58de1d847e3c487258954c..2c603fe61f36dd27f4984fe6c13c11f2fb534321 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/invert.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/invert.py @@ -18,12 +18,85 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.invert_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.ops.distributions import bijector as 
bijector_lib -_allowed_symbols = ["Invert"] +__all__ = [ + "Invert", +] -remove_undocumented(__name__, _allowed_symbols) + +class Invert(bijector_lib.Bijector): + """Bijector which inverts another Bijector. + + Example Use: [ExpGammaDistribution (see Background & Context)]( + https://reference.wolfram.com/language/ref/ExpGammaDistribution.html) + models `Y=log(X)` where `X ~ Gamma`. + + ```python + exp_gamma_distribution = TransformedDistribution( + distribution=Gamma(concentration=1., rate=2.), + bijector=bijector.Invert(bijector.Exp()) + ``` + + """ + + def __init__(self, bijector, validate_args=False, name=None): + """Creates a `Bijector` which swaps the meaning of `inverse` and `forward`. + + Note: An inverted bijector's `inverse_log_det_jacobian` is often more + efficient if the base bijector implements `_forward_log_det_jacobian`. If + `_forward_log_det_jacobian` is not implemented then the following code is + used: + + ```python + y = self.inverse(x, **kwargs) + return -self.inverse_log_det_jacobian(y, **kwargs) + ``` + + Args: + bijector: Bijector instance. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str`, name given to ops managed by this object. 
+ """ + + if not bijector._is_injective: # pylint: disable=protected-access + raise NotImplementedError( + "Invert is not implemented for non-injective bijectors.") + + self._bijector = bijector + super(Invert, self).__init__( + event_ndims=bijector.event_ndims, + graph_parents=bijector.graph_parents, + is_constant_jacobian=bijector.is_constant_jacobian, + validate_args=validate_args, + dtype=bijector.dtype, + name=name or "_".join(["invert", bijector.name])) + + def _forward_event_shape(self, input_shape): + return self.bijector._inverse_event_shape(input_shape) # pylint: disable=protected-access + + def _forward_event_shape_tensor(self, input_shape): + return self.bijector._inverse_event_shape_tensor(input_shape) # pylint: disable=protected-access + + def _inverse_event_shape(self, output_shape): + return self.bijector._forward_event_shape(output_shape) # pylint: disable=protected-access + + def _inverse_event_shape_tensor(self, output_shape): + return self.bijector._forward_event_shape_tensor(output_shape) # pylint: disable=protected-access + + @property + def bijector(self): + return self._bijector + + def _forward(self, x, **kwargs): + return self.bijector._inverse(x, **kwargs) # pylint: disable=protected-access + + def _inverse(self, y, **kwargs): + return self.bijector._forward(y, **kwargs) # pylint: disable=protected-access + + def _inverse_log_det_jacobian(self, y, **kwargs): + return self.bijector._forward_log_det_jacobian(y, **kwargs) # pylint: disable=protected-access + + def _forward_log_det_jacobian(self, x, **kwargs): + return self.bijector._inverse_log_det_jacobian(x, **kwargs) # pylint: disable=protected-access diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/invert_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/invert_impl.py deleted file mode 100644 index 2c603fe61f36dd27f4984fe6c13c11f2fb534321..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/invert_impl.py 
+++ /dev/null @@ -1,102 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Invert bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.ops.distributions import bijector as bijector_lib - -__all__ = [ - "Invert", -] - - -class Invert(bijector_lib.Bijector): - """Bijector which inverts another Bijector. - - Example Use: [ExpGammaDistribution (see Background & Context)]( - https://reference.wolfram.com/language/ref/ExpGammaDistribution.html) - models `Y=log(X)` where `X ~ Gamma`. - - ```python - exp_gamma_distribution = TransformedDistribution( - distribution=Gamma(concentration=1., rate=2.), - bijector=bijector.Invert(bijector.Exp()) - ``` - - """ - - def __init__(self, bijector, validate_args=False, name=None): - """Creates a `Bijector` which swaps the meaning of `inverse` and `forward`. - - Note: An inverted bijector's `inverse_log_det_jacobian` is often more - efficient if the base bijector implements `_forward_log_det_jacobian`. If - `_forward_log_det_jacobian` is not implemented then the following code is - used: - - ```python - y = self.inverse(x, **kwargs) - return -self.inverse_log_det_jacobian(y, **kwargs) - ``` - - Args: - bijector: Bijector instance. 
- validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str`, name given to ops managed by this object. - """ - - if not bijector._is_injective: # pylint: disable=protected-access - raise NotImplementedError( - "Invert is not implemented for non-injective bijectors.") - - self._bijector = bijector - super(Invert, self).__init__( - event_ndims=bijector.event_ndims, - graph_parents=bijector.graph_parents, - is_constant_jacobian=bijector.is_constant_jacobian, - validate_args=validate_args, - dtype=bijector.dtype, - name=name or "_".join(["invert", bijector.name])) - - def _forward_event_shape(self, input_shape): - return self.bijector._inverse_event_shape(input_shape) # pylint: disable=protected-access - - def _forward_event_shape_tensor(self, input_shape): - return self.bijector._inverse_event_shape_tensor(input_shape) # pylint: disable=protected-access - - def _inverse_event_shape(self, output_shape): - return self.bijector._forward_event_shape(output_shape) # pylint: disable=protected-access - - def _inverse_event_shape_tensor(self, output_shape): - return self.bijector._forward_event_shape_tensor(output_shape) # pylint: disable=protected-access - - @property - def bijector(self): - return self._bijector - - def _forward(self, x, **kwargs): - return self.bijector._inverse(x, **kwargs) # pylint: disable=protected-access - - def _inverse(self, y, **kwargs): - return self.bijector._forward(y, **kwargs) # pylint: disable=protected-access - - def _inverse_log_det_jacobian(self, y, **kwargs): - return self.bijector._forward_log_det_jacobian(y, **kwargs) # pylint: disable=protected-access - - def _forward_log_det_jacobian(self, x, **kwargs): - return self.bijector._inverse_log_det_jacobian(x, **kwargs) # pylint: disable=protected-access diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py index 
132dc570f94719b6c71fb269866c943774481b7e..5251dbcb5748f75688aa43ce6e4e9dbd76be78bb 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py @@ -18,16 +18,490 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +import numpy as np -_allowed_symbols = [ +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.layers import core as layers +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import template as template_ops +from tensorflow.python.ops import variable_scope as variable_scope_lib +from tensorflow.python.ops.distributions import bijector as bijector_lib + + +__all__ = [ "MaskedAutoregressiveFlow", - "masked_dense", "masked_autoregressive_default_template", + "masked_dense", ] -remove_undocumented(__name__, _allowed_symbols) + +class MaskedAutoregressiveFlow(bijector_lib.Bijector): + """Affine MaskedAutoregressiveFlow bijector for vector-valued events. + + The affine autoregressive flow [1] provides a relatively simple framework for + user-specified (deep) architectures to learn a distribution over vector-valued + events. Regarding terminology, + + "Autoregressive models decompose the joint density as a product of + conditionals, and model each conditional in turn. 
Normalizing flows + transform a base density (e.g. a standard Gaussian) into the target density + by an invertible transformation with tractable Jacobian." [1] + + In other words, the "autoregressive property" is equivalent to the + decomposition, `p(x) = prod{ p(x[i] | x[0:i]) : i=0, ..., d }`. The provided + `shift_and_log_scale_fn`, `masked_autoregressive_default_template`, achieves + this property by zeroing out weights in its `masked_dense` layers. + + In the `tf.distributions` framework, a "normalizing flow" is implemented as a + `tf.distributions.bijectors.Bijector`. The `forward` "autoregression" + is implemented using a `tf.while_loop` and a deep neural network (DNN) with + masked weights such that the autoregressive property is automatically met in + the `inverse`. + + A `TransformedDistribution` using `MaskedAutoregressiveFlow(...)` uses the + (expensive) forward-mode calculation to draw samples and the (cheap) + reverse-mode calculation to compute log-probabilities. Conversely, a + `TransformedDistribution` using `Invert(MaskedAutoregressiveFlow(...))` uses + the (expensive) forward-mode calculation to compute log-probabilities and the + (cheap) reverse-mode calculation to compute samples. See "Example Use" + [below] for more details. + + Given a `shift_and_log_scale_fn`, the forward and inverse transformations are + (a sequence of) affine transformations. A "valid" `shift_and_log_scale_fn` + must compute each `shift` (aka `loc` or "mu" [2]) and `log(scale)` (aka + "alpha" [2]) such that each are broadcastable with the arguments to `forward` + and `inverse`, i.e., such that the calculations in `forward`, `inverse` + [below] are possible. + + For convenience, `masked_autoregressive_default_template` is offered as a + possible `shift_and_log_scale_fn` function. It implements the MADE + architecture [2]. MADE is a feed-forward network that computes a `shift` and + `log(scale)` using `masked_dense` layers in a deep neural network. 
Weights are + masked to ensure the autoregressive property. It is possible that this + architecture is suboptimal for your task. To build alternative networks, + either change the arguments to `masked_autoregressive_default_template`, use + the `masked_dense` function to roll-out your own, or use some other + architecture, e.g., using `tf.layers`. + + Warning: no attempt is made to validate that the `shift_and_log_scale_fn` + enforces the "autoregressive property". + + Assuming `shift_and_log_scale_fn` has valid shape and autoregressive + semantics, the forward transformation is, + + ```python + def forward(x): + y = zeros_like(x) + event_size = x.shape[-1] + for _ in range(event_size): + shift, log_scale = shift_and_log_scale_fn(y) + y = x * math_ops.exp(log_scale) + shift + return y + ``` + + and the inverse transformation is, + + ```python + def inverse(y): + shift, log_scale = shift_and_log_scale_fn(y) + return (y - shift) / math_ops.exp(log_scale) + ``` + + Notice that the `inverse` does not need a for-loop. This is because in the + forward pass each calculation of `shift` and `log_scale` is based on the `y` + calculated so far (not `x`). In the `inverse`, the `y` is fully known, thus is + equivalent to the scaling used in `forward` after `event_size` passes, i.e., + the "last" `y` used to compute `shift`, `log_scale`. (Roughly speaking, this + also proves the transform is bijective.) + + #### Example Use + + ```python + tfd = tf.contrib.distributions + tfb = tfd.bijectors + + dims = 5 + + # A common choice for a normalizing flow is to use a Gaussian for the base + # distribution. (However, any continuous distribution would work.) E.g., + maf = tfd.TransformedDistribution( + distribution=tfd.Normal(loc=0., scale=1.), + bijector=tfb.MaskedAutoregressiveFlow( + shift_and_log_scale_fn=tfb.masked_autoregressive_default_template( + hidden_layers=[512, 512])), + event_shape=[dims]) + + x = maf.sample() # Expensive; uses `tf.while_loop`, no Bijector caching. 
+ maf.log_prob(x) # Almost free; uses Bijector caching. + maf.log_prob(0.) # Cheap; no `tf.while_loop` despite no Bijector caching. + + # [1] also describes an "Inverse Autoregressive Flow", e.g., + iaf = tfd.TransformedDistribution( + distribution=tfd.Normal(loc=0., scale=1.), + bijector=tfb.Invert(tfb.MaskedAutoregressiveFlow( + shift_and_log_scale_fn=tfb.masked_autoregressive_default_template( + hidden_layers=[512, 512]))), + event_shape=[dims]) + + x = iaf.sample() # Cheap; no `tf.while_loop` despite no Bijector caching. + iaf.log_prob(x) # Almost free; uses Bijector caching. + iaf.log_prob(0.) # Expensive; uses `tf.while_loop`, no Bijector caching. + + # In many (if not most) cases the default `shift_and_log_scale_fn` will be a + # poor choice. Here's an example of using a "shift only" version and with a + # different number/depth of hidden layers. + shift_only = True + maf_no_scale_hidden2 = tfd.TransformedDistribution( + distribution=tfd.Normal(loc=0., scale=1.), + bijector=tfb.MaskedAutoregressiveFlow( + tfb.masked_autoregressive_default_template( + hidden_layers=[32], + shift_only=shift_only), + is_constant_jacobian=shift_only), + event_shape=[dims]) + ``` + + [1]: "Masked Autoregressive Flow for Density Estimation." + George Papamakarios, Theo Pavlakou, Iain Murray. Arxiv. 2017. + https://arxiv.org/abs/1705.07057 + + [2]: "MADE: Masked Autoencoder for Distribution Estimation." + Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. + https://arxiv.org/abs/1502.03509 + + """ + + def __init__(self, + shift_and_log_scale_fn, + is_constant_jacobian=False, + validate_args=False, + unroll_loop=False, + name=None): + """Creates the MaskedAutoregressiveFlow bijector. + + Args: + shift_and_log_scale_fn: Python `callable` which computes `shift` and + `log_scale` from both the forward domain (`x`) and the inverse domain + (`y`). Calculation must respect the "autoregressive property" (see class + docstring). 
Suggested default + `masked_autoregressive_default_template(hidden_layers=...)`. + Typically the function contains `tf.Variables` and is wrapped using + `tf.make_template`. Returning `None` for either (both) `shift`, + `log_scale` is equivalent to (but more efficient than) returning zero. + is_constant_jacobian: Python `bool`. Default: `False`. When `True` the + implementation assumes `log_scale` does not depend on the forward domain + (`x`) or inverse domain (`y`) values. (No validation is made; + `is_constant_jacobian=False` is always safe but possibly computationally + inefficient.) + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + unroll_loop: Python `bool` indicating whether the `tf.while_loop` in + `_forward` should be replaced with a static for loop. Requires that + the final dimension of `x` be known at graph construction time. Defaults + to `False`. + name: Python `str`, name given to ops managed by this object. + """ + name = name or "masked_autoregressive_flow" + self._shift_and_log_scale_fn = shift_and_log_scale_fn + self._unroll_loop = unroll_loop + super(MaskedAutoregressiveFlow, self).__init__( + is_constant_jacobian=is_constant_jacobian, + validate_args=validate_args, + name=name) + + def _forward(self, x): + if self._unroll_loop: + event_size = x.shape.with_rank_at_least(1)[-1].value + if event_size is None: + raise ValueError( + "The final dimension of `x` must be known at graph construction " + "time if `unroll_loop=True`. 
`x.shape: %r`" % x.shape) + y = array_ops.zeros_like(x, name="y0") + + for _ in range(event_size): + shift, log_scale = self._shift_and_log_scale_fn(y) + # next_y = scale * x + shift + next_y = x + if log_scale is not None: + next_y *= math_ops.exp(log_scale) + if shift is not None: + next_y += shift + y = next_y + return y + + event_size = array_ops.shape(x)[-1] + # If the event size is available at graph construction time, we can inform + # the graph compiler of the maximum number of steps. If not, + # static_event_size will be None, and the maximum_iterations argument will + # have no effect. + static_event_size = x.shape.with_rank_at_least(1)[-1].value + y0 = array_ops.zeros_like(x, name="y0") + # call the template once to ensure creation + _ = self._shift_and_log_scale_fn(y0) + def _loop_body(index, y0): + """While-loop body for autoregression calculation.""" + # Set caching device to avoid re-getting the tf.Variable for every while + # loop iteration. + with variable_scope_lib.variable_scope( + variable_scope_lib.get_variable_scope()) as vs: + if vs.caching_device is None: + vs.set_caching_device(lambda op: op.device) + shift, log_scale = self._shift_and_log_scale_fn(y0) + y = x + if log_scale is not None: + y *= math_ops.exp(log_scale) + if shift is not None: + y += shift + return index + 1, y + _, y = control_flow_ops.while_loop( + cond=lambda index, _: index < event_size, + body=_loop_body, + loop_vars=(0, y0), + maximum_iterations=static_event_size) + return y + + def _inverse(self, y): + shift, log_scale = self._shift_and_log_scale_fn(y) + x = y + if shift is not None: + x -= shift + if log_scale is not None: + x *= math_ops.exp(-log_scale) + return x + + def _inverse_log_det_jacobian(self, y): + _, log_scale = self._shift_and_log_scale_fn(y) + if log_scale is None: + return constant_op.constant(0., dtype=y.dtype, name="ildj") + return -math_ops.reduce_sum(log_scale, axis=-1) + + +MASK_INCLUSIVE = "inclusive" +MASK_EXCLUSIVE = "exclusive" + + +def 
_gen_slices(num_blocks, n_in, n_out, mask_type=MASK_EXCLUSIVE): + """Generate the slices for building an autoregressive mask.""" + # TODO(b/67594795): Better support of dynamic shape. + slices = [] + col = 0 + d_in = n_in // num_blocks + d_out = n_out // num_blocks + row = d_out if mask_type == MASK_EXCLUSIVE else 0 + for _ in range(num_blocks): + row_slice = slice(row, None) + col_slice = slice(col, col + d_in) + slices.append([row_slice, col_slice]) + col += d_in + row += d_out + return slices + + +def _gen_mask(num_blocks, + n_in, + n_out, + mask_type=MASK_EXCLUSIVE, + dtype=dtypes.float32): + """Generate the mask for building an autoregressive dense layer.""" + # TODO(b/67594795): Better support of dynamic shape. + mask = np.zeros([n_out, n_in], dtype=dtype.as_numpy_dtype()) + slices = _gen_slices(num_blocks, n_in, n_out, mask_type=mask_type) + for [row_slice, col_slice] in slices: + mask[row_slice, col_slice] = 1 + return mask + + +def masked_dense(inputs, + units, + num_blocks=None, + exclusive=False, + kernel_initializer=None, + reuse=None, + name=None, + *args, + **kwargs): + """A autoregressively masked dense layer. Analogous to `tf.layers.dense`. + + See [1] for detailed explanation. + + [1]: "MADE: Masked Autoencoder for Distribution Estimation." + Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. + https://arxiv.org/abs/1502.03509 + + Arguments: + inputs: Tensor input. + units: Python `int` scalar representing the dimensionality of the output + space. + num_blocks: Python `int` scalar representing the number of blocks for the + MADE masks. + exclusive: Python `bool` scalar representing whether to zero the diagonal of + the mask, used for the first layer of a MADE. + kernel_initializer: Initializer function for the weight matrix. + If `None` (default), weights are initialized using the + `tf.glorot_random_initializer`. + reuse: Python `bool` scalar representing whether to reuse the weights of a + previous layer by the same name. 
+ name: Python `str` used to describe ops managed by this function. + *args: `tf.layers.dense` arguments. + **kwargs: `tf.layers.dense` keyword arguments. + + Returns: + Output tensor. + + Raises: + NotImplementedError: if rightmost dimension of `inputs` is unknown prior to + graph execution. + """ + # TODO(b/67594795): Better support of dynamic shape. + input_depth = inputs.shape.with_rank_at_least(1)[-1].value + if input_depth is None: + raise NotImplementedError( + "Rightmost dimension must be known prior to graph execution.") + + mask = _gen_mask(num_blocks, input_depth, units, + MASK_EXCLUSIVE if exclusive else MASK_INCLUSIVE).T + + if kernel_initializer is None: + kernel_initializer = init_ops.glorot_normal_initializer() + + def masked_initializer(shape, dtype=None, partition_info=None): + return mask * kernel_initializer(shape, dtype, partition_info) + + with ops.name_scope(name, "masked_dense", [inputs, units, num_blocks]): + layer = layers.Dense( + units, + kernel_initializer=masked_initializer, + kernel_constraint=lambda x: mask * x, + name=name, + dtype=inputs.dtype.base_dtype, + _scope=name, + _reuse=reuse, + *args, + **kwargs) + return layer.apply(inputs) + + +def masked_autoregressive_default_template( + hidden_layers, + shift_only=False, + activation=nn_ops.relu, + log_scale_min_clip=-5., + log_scale_max_clip=3., + log_scale_clip_gradient=False, + name=None, + *args, + **kwargs): + """Build the MADE Model [1]. + + This will be wrapped in a make_template to ensure the variables are only + created once. It takes the input and returns the `loc` ("mu" [1]) and + `log_scale` ("alpha" [1]) from the MADE network. + + Warning: This function uses `masked_dense` to create randomly initialized + `tf.Variables`. It is presumed that these will be fit, just as you would any + other neural architecture which uses `tf.layers.dense`. 
+ + #### About Hidden Layers: + + Each element of `hidden_layers` should be greater than the `input_depth` + (i.e., `input_depth = tf.shape(input)[-1]` where `input` is the input to the + neural network). This is necessary to ensure the autoregressivity property. + + #### About Clipping: + + This function also optionally clips the `log_scale` (but possibly not its + gradient). This is useful because if `log_scale` is too small/large it might + underflow/overflow making it impossible for the `MaskedAutoregressiveFlow` + bijector to implement a bijection. Additionally, the `log_scale_clip_gradient` + `bool` indicates whether the gradient should also be clipped. The default does + not clip the gradient; this is useful because it still provides gradient + information (for fitting) yet solves the numerical stability problem. I.e., + `log_scale_clip_gradient = False` means + `grad[exp(clip(x))] = grad[x] exp(clip(x))` rather than the usual + `grad[clip(x)] exp(clip(x))`. + + [1]: "MADE: Masked Autoencoder for Distribution Estimation." + Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. + https://arxiv.org/abs/1502.03509 + + Arguments: + hidden_layers: Python `list`-like of non-negative integer, scalars + indicating the number of units in each hidden layer. Default: `[512, 512]. + shift_only: Python `bool` indicating if only the `shift` term shall be + computed. Default: `False`. + activation: Activation function (callable). Explicitly setting to `None` + implies a linear activation. + log_scale_min_clip: `float`-like scalar `Tensor`, or a `Tensor` with the + same shape as `log_scale`. The minimum value to clip by. Default: -5. + log_scale_max_clip: `float`-like scalar `Tensor`, or a `Tensor` with the + same shape as `log_scale`. The maximum value to clip by. Default: 3. + log_scale_clip_gradient: Python `bool` indicating that the gradient of + `tf.clip_by_value` should be preserved. Default: `False`. + name: A name for ops managed by this function. 
Default: + "masked_autoregressive_default_template". + *args: `tf.layers.dense` arguments. + **kwargs: `tf.layers.dense` keyword arguments. + + Returns: + shift: `Float`-like `Tensor` of shift terms (the "mu" in [2]). + log_scale: `Float`-like `Tensor` of log(scale) terms (the "alpha" in [2]). + + Raises: + NotImplementedError: if rightmost dimension of `inputs` is unknown prior to + graph execution. + """ + + with ops.name_scope(name, "masked_autoregressive_default_template", + values=[log_scale_min_clip, log_scale_max_clip]): + def _fn(x): + """MADE parameterized via `masked_autoregressive_default_template`.""" + # TODO(b/67594795): Better support of dynamic shape. + input_depth = x.shape.with_rank_at_least(1)[-1].value + if input_depth is None: + raise NotImplementedError( + "Rightmost dimension must be known prior to graph execution.") + input_shape = (np.int32(x.shape.as_list()) if x.shape.is_fully_defined() + else array_ops.shape(x)) + for i, units in enumerate(hidden_layers): + x = masked_dense( + inputs=x, + units=units, + num_blocks=input_depth, + exclusive=True if i == 0 else False, + activation=activation, + *args, + **kwargs) + x = masked_dense( + inputs=x, + units=(1 if shift_only else 2) * input_depth, + num_blocks=input_depth, + activation=None, + *args, + **kwargs) + if shift_only: + x = array_ops.reshape(x, shape=input_shape) + return x, None + x = array_ops.reshape( + x, shape=array_ops.concat([input_shape, [2]], axis=0)) + shift, log_scale = array_ops.unstack(x, num=2, axis=-1) + which_clip = (math_ops.clip_by_value if log_scale_clip_gradient + else _clip_by_value_preserve_grad) + log_scale = which_clip(log_scale, log_scale_min_clip, log_scale_max_clip) + return shift, log_scale + return template_ops.make_template( + "masked_autoregressive_default_template", _fn) + + +def _clip_by_value_preserve_grad(x, clip_value_min, clip_value_max, name=None): + """Clips input while leaving gradient unaltered.""" + with ops.name_scope(name, 
"clip_by_value_preserve_grad", + [x, clip_value_min, clip_value_max]): + clip_x = clip_ops.clip_by_value(x, clip_value_min, clip_value_max) + return x + array_ops.stop_gradient(clip_x - x) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive_impl.py deleted file mode 100644 index ae142883931274b594dbbafbe86bd71e75c621bc..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive_impl.py +++ /dev/null @@ -1,473 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""MaskedAutoregressiveFlow bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.layers import core as layers -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import clip_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import template as template_ops -from tensorflow.python.ops import variable_scope as variable_scope_lib -from tensorflow.python.ops.distributions import bijector as bijector_lib - - -__all__ = [ - "MaskedAutoregressiveFlow", - "masked_autoregressive_default_template", - "masked_dense", -] - - -class MaskedAutoregressiveFlow(bijector_lib.Bijector): - """Affine MaskedAutoregressiveFlow bijector for vector-valued events. - - The affine autoregressive flow [1] provides a relatively simple framework for - user-specified (deep) architectures to learn a distribution over vector-valued - events. Regarding terminology, - - "Autoregressive models decompose the joint density as a product of - conditionals, and model each conditional in turn. Normalizing flows - transform a base density (e.g. a standard Gaussian) into the target density - by an invertible transformation with tractable Jacobian." [1] - - In other words, the "autoregressive property" is equivalent to the - decomposition, `p(x) = prod{ p(x[i] | x[0:i]) : i=0, ..., d }`. The provided - `shift_and_log_scale_fn`, `masked_autoregressive_default_template`, achieves - this property by zeroing out weights in its `masked_dense` layers. 
- - In the `tf.distributions` framework, a "normalizing flow" is implemented as a - `tf.distributions.bijectors.Bijector`. The `forward` "autoregression" - is implemented using a `tf.while_loop` and a deep neural network (DNN) with - masked weights such that the autoregressive property is automatically met in - the `inverse`. - - A `TransformedDistribution` using `MaskedAutoregressiveFlow(...)` uses the - (expensive) forward-mode calculation to draw samples and the (cheap) - reverse-mode calculation to compute log-probabilities. Conversely, a - `TransformedDistribution` using `Invert(MaskedAutoregressiveFlow(...))` uses - the (expensive) forward-mode calculation to compute log-probabilities and the - (cheap) reverse-mode calculation to compute samples. See "Example Use" - [below] for more details. - - Given a `shift_and_log_scale_fn`, the forward and inverse transformations are - (a sequence of) affine transformations. A "valid" `shift_and_log_scale_fn` - must compute each `shift` (aka `loc` or "mu" [2]) and `log(scale)` (aka - "alpha" [2]) such that each are broadcastable with the arguments to `forward` - and `inverse`, i.e., such that the calculations in `forward`, `inverse` - [below] are possible. - - For convenience, `masked_autoregressive_default_template` is offered as a - possible `shift_and_log_scale_fn` function. It implements the MADE - architecture [2]. MADE is a feed-forward network that computes a `shift` and - `log(scale)` using `masked_dense` layers in a deep neural network. Weights are - masked to ensure the autoregressive property. It is possible that this - architecture is suboptimal for your task. To build alternative networks, - either change the arguments to `masked_autoregressive_default_template`, use - the `masked_dense` function to roll-out your own, or use some other - architecture, e.g., using `tf.layers`. - - Warning: no attempt is made to validate that the `shift_and_log_scale_fn` - enforces the "autoregressive property". 
- - Assuming `shift_and_log_scale_fn` has valid shape and autoregressive - semantics, the forward transformation is, - - ```python - def forward(x): - y = zeros_like(x) - event_size = x.shape[-1] - for _ in range(event_size): - shift, log_scale = shift_and_log_scale_fn(y) - y = x * math_ops.exp(log_scale) + shift - return y - ``` - - and the inverse transformation is, - - ```python - def inverse(y): - shift, log_scale = shift_and_log_scale_fn(y) - return (y - shift) / math_ops.exp(log_scale) - ``` - - Notice that the `inverse` does not need a for-loop. This is because in the - forward pass each calculation of `shift` and `log_scale` is based on the `y` - calculated so far (not `x`). In the `inverse`, the `y` is fully known, thus is - equivalent to the scaling used in `forward` after `event_size` passes, i.e., - the "last" `y` used to compute `shift`, `log_scale`. (Roughly speaking, this - also proves the transform is bijective.) - - #### Example Use - - ```python - ds = tf.contrib.distributions - bs = tf.contrib.distributions.bijectors - - dims = 5 - - # A common choice for a normalizing flow is to use a Gaussian for the base - # distribution. (However, any continuous distribution would work.) E.g., - maf = ds.TransformedDistribution( - distribution=ds.Normal(loc=0., scale=1.), - bijector=bs.MaskedAutoregressiveFlow( - shift_and_log_scale_fn=bs.masked_autoregressive_default_template( - hidden_layers=[512, 512])), - event_shape=[dims]) - - x = maf.sample() # Expensive; uses `tf.while_loop`, no Bijector caching. - maf.log_prob(x) # Almost free; uses Bijector caching. - maf.log_prob(0.) # Cheap; no `tf.while_loop` despite no Bijector caching. 
- - # [1] also describes an "Inverse Autoregressive Flow", e.g., - iaf = ds.TransformedDistribution( - distribution=ds.Normal(loc=0., scale=1.), - bijector=bs.Invert(bs.MaskedAutoregressiveFlow( - shift_and_log_scale_fn=bs.masked_autoregressive_default_template( - hidden_layers=[512, 512]))), - event_shape=[dims]) - - x = iaf.sample() # Cheap; no `tf.while_loop` despite no Bijector caching. - iaf.log_prob(x) # Almost free; uses Bijector caching. - iaf.log_prob(0.) # Expensive; uses `tf.while_loop`, no Bijector caching. - - # In many (if not most) cases the default `shift_and_log_scale_fn` will be a - # poor choice. Here's an example of using a "shift only" version and with a - # different number/depth of hidden layers. - shift_only = True - maf_no_scale_hidden2 = ds.TransformedDistribution( - distribution=ds.Normal(loc=0., scale=1.), - bijector=bs.MaskedAutoregressiveFlow( - bs.masked_autoregressive_default_template( - hidden_layers=[32], - shift_only=shift_only), - is_constant_jacobian=shift_only), - event_shape=[dims]) - ``` - - [1]: "Masked Autoregressive Flow for Density Estimation." - George Papamakarios, Theo Pavlakou, Iain Murray. Arxiv. 2017. - https://arxiv.org/abs/1705.07057 - - [2]: "MADE: Masked Autoencoder for Distribution Estimation." - Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. - https://arxiv.org/abs/1502.03509 - - """ - - def __init__(self, - shift_and_log_scale_fn, - is_constant_jacobian=False, - validate_args=False, - name=None): - """Creates the MaskedAutoregressiveFlow bijector. - - Args: - shift_and_log_scale_fn: Python `callable` which computes `shift` and - `log_scale` from both the forward domain (`x`) and the inverse domain - (`y`). Calculation must respect the "autoregressive property" (see class - docstring). Suggested default - `masked_autoregressive_default_template(hidden_layers=...)`. - Typically the function contains `tf.Variables` and is wrapped using - `tf.make_template`. 
Returning `None` for either (both) `shift`, - `log_scale` is equivalent to (but more efficient than) returning zero. - is_constant_jacobian: Python `bool`. Default: `False`. When `True` the - implementation assumes `log_scale` does not depend on the forward domain - (`x`) or inverse domain (`y`) values. (No validation is made; - `is_constant_jacobian=False` is always safe but possibly computationally - inefficient.) - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str`, name given to ops managed by this object. - """ - name = name or "masked_autoregressive_flow" - self._shift_and_log_scale_fn = shift_and_log_scale_fn - super(MaskedAutoregressiveFlow, self).__init__( - is_constant_jacobian=is_constant_jacobian, - validate_args=validate_args, - name=name) - - def _forward(self, x): - event_size = array_ops.shape(x)[-1] - def _loop_body(index, y0): - """While-loop body for autoregression calculation.""" - # Set caching device to avoid re-getting the tf.Variable for every while - # loop iteration. 
- with variable_scope_lib.variable_scope( - variable_scope_lib.get_variable_scope()) as vs: - if vs.caching_device is None: - vs.set_caching_device(lambda op: op.device) - shift, log_scale = self._shift_and_log_scale_fn(y0) - y = x - if log_scale is not None: - y *= math_ops.exp(log_scale) - if shift is not None: - y += shift - return index + 1, y - _, y = control_flow_ops.while_loop( - cond=lambda index, _: index < event_size, - body=_loop_body, - loop_vars=[0, array_ops.zeros_like(x, name="y0")]) - return y - - def _inverse(self, y): - shift, log_scale = self._shift_and_log_scale_fn(y) - x = y - if shift is not None: - x -= shift - if log_scale is not None: - x *= math_ops.exp(-log_scale) - return x - - def _inverse_log_det_jacobian(self, y): - _, log_scale = self._shift_and_log_scale_fn(y) - if log_scale is None: - return constant_op.constant(0., dtype=y.dtype, name="ildj") - return -math_ops.reduce_sum(log_scale, axis=-1) - - -MASK_INCLUSIVE = "inclusive" -MASK_EXCLUSIVE = "exclusive" - - -def _gen_slices(num_blocks, n_in, n_out, mask_type=MASK_EXCLUSIVE): - """Generate the slices for building an autoregressive mask.""" - # TODO(b/67594795): Better support of dynamic shape. - slices = [] - col = 0 - d_in = n_in // num_blocks - d_out = n_out // num_blocks - row = d_out if mask_type == MASK_EXCLUSIVE else 0 - for _ in range(num_blocks): - row_slice = slice(row, None) - col_slice = slice(col, col + d_in) - slices.append([row_slice, col_slice]) - col += d_in - row += d_out - return slices - - -def _gen_mask(num_blocks, - n_in, - n_out, - mask_type=MASK_EXCLUSIVE, - dtype=dtypes.float32): - """Generate the mask for building an autoregressive dense layer.""" - # TODO(b/67594795): Better support of dynamic shape. 
- mask = np.zeros([n_out, n_in], dtype=dtype.as_numpy_dtype()) - slices = _gen_slices(num_blocks, n_in, n_out, mask_type=mask_type) - for [row_slice, col_slice] in slices: - mask[row_slice, col_slice] = 1 - return mask - - -def masked_dense(inputs, - units, - num_blocks=None, - exclusive=False, - kernel_initializer=None, - reuse=None, - name=None, - *args, - **kwargs): - """A autoregressively masked dense layer. Analogous to `tf.layers.dense`. - - See [1] for detailed explanation. - - [1]: "MADE: Masked Autoencoder for Distribution Estimation." - Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. - https://arxiv.org/abs/1502.03509 - - Arguments: - inputs: Tensor input. - units: Python `int` scalar representing the dimensionality of the output - space. - num_blocks: Python `int` scalar representing the number of blocks for the - MADE masks. - exclusive: Python `bool` scalar representing whether to zero the diagonal of - the mask, used for the first layer of a MADE. - kernel_initializer: Initializer function for the weight matrix. - If `None` (default), weights are initialized using the - `tf.glorot_random_initializer`. - reuse: Python `bool` scalar representing whether to reuse the weights of a - previous layer by the same name. - name: Python `str` used to describe ops managed by this function. - *args: `tf.layers.dense` arguments. - **kwargs: `tf.layers.dense` keyword arguments. - - Returns: - Output tensor. - - Raises: - NotImplementedError: if rightmost dimension of `inputs` is unknown prior to - graph execution. - """ - # TODO(b/67594795): Better support of dynamic shape. 
- input_depth = inputs.shape.with_rank_at_least(1)[-1].value - if input_depth is None: - raise NotImplementedError( - "Rightmost dimension must be known prior to graph execution.") - - mask = _gen_mask(num_blocks, input_depth, units, - MASK_EXCLUSIVE if exclusive else MASK_INCLUSIVE).T - - if kernel_initializer is None: - kernel_initializer = init_ops.glorot_normal_initializer() - - def masked_initializer(shape, dtype=None, partition_info=None): - return mask * kernel_initializer(shape, dtype, partition_info) - - with ops.name_scope(name, "masked_dense", [inputs, units, num_blocks]): - layer = layers.Dense( - units, - kernel_initializer=masked_initializer, - kernel_constraint=lambda x: mask * x, - name=name, - dtype=inputs.dtype.base_dtype, - _scope=name, - _reuse=reuse, - *args, - **kwargs) - return layer.apply(inputs) - - -def masked_autoregressive_default_template( - hidden_layers, - shift_only=False, - activation=nn_ops.relu, - log_scale_min_clip=-5., - log_scale_max_clip=3., - log_scale_clip_gradient=False, - name=None, - *args, - **kwargs): - """Build the MADE Model [1]. - - This will be wrapped in a make_template to ensure the variables are only - created once. It takes the input and returns the `loc` ("mu" [1]) and - `log_scale` ("alpha" [1]) from the MADE network. - - Warning: This function uses `masked_dense` to create randomly initialized - `tf.Variables`. It is presumed that these will be fit, just as you would any - other neural architecture which uses `tf.layers.dense`. - - #### About Hidden Layers: - - Each element of `hidden_layers` should be greater than the `input_depth` - (i.e., `input_depth = tf.shape(input)[-1]` where `input` is the input to the - neural network). This is necessary to ensure the autoregressivity property. - - #### About Clipping: - - This function also optionally clips the `log_scale` (but possibly not its - gradient). 
This is useful because if `log_scale` is too small/large it might - underflow/overflow making it impossible for the `MaskedAutoregressiveFlow` - bijector to implement a bijection. Additionally, the `log_scale_clip_gradient` - `bool` indicates whether the gradient should also be clipped. The default does - not clip the gradient; this is useful because it still provides gradient - information (for fitting) yet solves the numerical stability problem. I.e., - `log_scale_clip_gradient = False` means - `grad[exp(clip(x))] = grad[x] exp(clip(x))` rather than the usual - `grad[clip(x)] exp(clip(x))`. - - [1]: "MADE: Masked Autoencoder for Distribution Estimation." - Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. - https://arxiv.org/abs/1502.03509 - - Arguments: - hidden_layers: Python `list`-like of non-negative integer, scalars - indicating the number of units in each hidden layer. Default: `[512, 512]. - shift_only: Python `bool` indicating if only the `shift` term shall be - computed. Default: `False`. - activation: Activation function (callable). Explicitly setting to `None` - implies a linear activation. - log_scale_min_clip: `float`-like scalar `Tensor`, or a `Tensor` with the - same shape as `log_scale`. The minimum value to clip by. Default: -5. - log_scale_max_clip: `float`-like scalar `Tensor`, or a `Tensor` with the - same shape as `log_scale`. The maximum value to clip by. Default: 3. - log_scale_clip_gradient: Python `bool` indicating that the gradient of - `tf.clip_by_value` should be preserved. Default: `False`. - name: A name for ops managed by this function. Default: - "masked_autoregressive_default_template". - *args: `tf.layers.dense` arguments. - **kwargs: `tf.layers.dense` keyword arguments. - - Returns: - shift: `Float`-like `Tensor` of shift terms (the "mu" in [2]). - log_scale: `Float`-like `Tensor` of log(scale) terms (the "alpha" in [2]). 
- - Raises: - NotImplementedError: if rightmost dimension of `inputs` is unknown prior to - graph execution. - """ - - with ops.name_scope(name, "masked_autoregressive_default_template", - values=[log_scale_min_clip, log_scale_max_clip]): - def _fn(x): - """MADE parameterized via `masked_autoregressive_default_template`.""" - # TODO(b/67594795): Better support of dynamic shape. - input_depth = x.shape.with_rank_at_least(1)[-1].value - if input_depth is None: - raise NotImplementedError( - "Rightmost dimension must be known prior to graph execution.") - input_shape = (np.int32(x.shape.as_list()) if x.shape.is_fully_defined() - else array_ops.shape(x)) - for i, units in enumerate(hidden_layers): - x = masked_dense( - inputs=x, - units=units, - num_blocks=input_depth, - exclusive=True if i == 0 else False, - activation=activation, - *args, - **kwargs) - x = masked_dense( - inputs=x, - units=(1 if shift_only else 2) * input_depth, - num_blocks=input_depth, - activation=None, - *args, - **kwargs) - if shift_only: - x = array_ops.reshape(x, shape=input_shape) - return x, None - x = array_ops.reshape( - x, shape=array_ops.concat([input_shape, [2]], axis=0)) - shift, log_scale = array_ops.unstack(x, num=2, axis=-1) - which_clip = (math_ops.clip_by_value if log_scale_clip_gradient - else _clip_by_value_preserve_grad) - log_scale = which_clip(log_scale, log_scale_min_clip, log_scale_max_clip) - return shift, log_scale - return template_ops.make_template( - "masked_autoregressive_default_template", _fn) - - -def _clip_by_value_preserve_grad(x, clip_value_min, clip_value_max, name=None): - """Clips input while leaving gradient unaltered.""" - with ops.name_scope(name, "clip_by_value_preserve_grad", - [x, clip_value_min, clip_value_max]): - clip_x = clip_ops.clip_by_value(x, clip_value_min, clip_value_max) - return x + array_ops.stop_gradient(clip_x - x) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/permute.py 
b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py index a187ce22d686ee1203802ae2bfe64b0e1a3ea850..8654cc39d0c41ec4f1b85cd5fc4366ceaf4b224d 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/permute.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py @@ -12,18 +12,127 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Permute bijector.""" +"""Permutation bijectors.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.permute_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +import numpy as np -_allowed_symbols = ["Permute"] +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops.distributions import bijector as bijector_lib -remove_undocumented(__name__, _allowed_symbols) + +__all__ = [ + "Permute", +] + + +class Permute(bijector_lib.Bijector): + """Permutes the rightmost dimension of a `Tensor`. + + ```python + tfd = tf.contrib.distributions + + reverse = tfd.bijectors.Permute(permutation=[2, 1, 0]) + + reverse.forward([-1., 0., 1.]) + # ==> [1., 0., -1] + + reverse.inverse([1., 0., -1]) + # ==> [-1., 0., 1.] + + reverse.forward_log_det_jacobian(any_value) + # ==> 0. + + reverse.inverse_log_det_jacobian(any_value) + # ==> 0. 
+ ``` + + Warning: `tf.estimator` may repeatedly build the graph thus + `Permute(np.random.permutation(event_size)).astype("int32"))` is not a + reliable parameterization (nor would it be even if using `tf.constant`). A + safe alternative is to use `tf.get_variable` to achieve "init once" behavior, + i.e., + + ```python + def init_once(x, name): + return tf.get_variable(name, initializer=x, trainable=False) + + Permute(permutation=init_once( + np.random.permutation(event_size).astype("int32"), + name="permutation")) + ``` + + """ + + def __init__(self, permutation, validate_args=False, name=None): + """Creates the `Permute` bijector. + + Args: + permutation: An `int`-like vector-shaped `Tensor` representing the + permutation to apply to the rightmost dimension of the transformed + `Tensor`. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str`, name given to ops managed by this object. + + Raises: + TypeError: if `not permutation.dtype.is_integer`. + ValueError: if `permutation` does not contain exactly one of each of + `{0, 1, ..., d}`. 
+ """ + with ops.name_scope(name, "permute", values=[permutation]): + permutation = ops.convert_to_tensor( + permutation, + name="permutation") + if not permutation.dtype.is_integer: + raise TypeError("permutation.dtype ({}) should be `int`-like.".format( + permutation.dtype.name)) + p = tensor_util.constant_value(permutation) + if p is not None: + if set(p) != set(np.arange(p.size)): + raise ValueError("Permutation over `d` must contain exactly one of " + "each of `{0, 1, ..., d}`.") + elif validate_args: + p, _ = nn_ops.top_k(-permutation, + k=array_ops.shape(permutation)[-1], + sorted=True) + permutation = control_flow_ops.with_dependencies([ + check_ops.assert_equal( + -p, math_ops.range(array_ops.size(p)), + message=("Permutation over `d` must contain exactly one of " + "each of `{0, 1, ..., d}`.")), + ], permutation) + self._permutation = permutation + super(Permute, self).__init__( + is_constant_jacobian=True, + validate_args=validate_args, + name=name or "permute") + + @property + def permutation(self): + return self._permutation + + def _forward(self, x): + return array_ops.gather(x, self.permutation, axis=-1) + + def _inverse(self, y): + return array_ops.gather( + y, + array_ops.invert_permutation(self.permutation), + axis=-1) + + def _inverse_log_det_jacobian(self, y): + return constant_op.constant(0., dtype=y.dtype) + + def _forward_log_det_jacobian(self, x): + return constant_op.constant(0., dtype=x.dtype) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/permute_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/permute_impl.py deleted file mode 100644 index b1d8f2f41b28a88208a19824377f93882b767f03..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/permute_impl.py +++ /dev/null @@ -1,138 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Permutation bijectors.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops.distributions import bijector as bijector_lib - - -__all__ = [ - "Permute", -] - - -class Permute(bijector_lib.Bijector): - """Permutes the rightmost dimension of a `Tensor`. - - ```python - bs = tf.contrib.distributions.bijectors - - reverse = bs.Permute(permutation=[2, 1, 0]) - - reverse.forward([-1., 0., 1.]) - # ==> [1., 0., -1] - - reverse.inverse([1., 0., -1]) - # ==> [-1., 0., 1.] - - reverse.forward_log_det_jacobian(any_value) - # ==> 0. - - reverse.inverse_log_det_jacobian(any_value) - # ==> 0. - ``` - - Warning: `tf.estimator` may repeatedly build the graph thus - `Permute(np.random.permutation(event_size)).astype("int32"))` is not a - reliable parameterization (nor would it be even if using `tf.constant`). 
A - safe alternative is to use `tf.get_variable` to achieve "init once" behavior, - i.e., - - ```python - def init_once(x, name): - return tf.get_variable(name, initializer=x, trainable=False) - - Permute(permutation=init_once( - np.random.permutation(event_size).astype("int32"), - name="permutation")) - ``` - - """ - - def __init__(self, permutation, validate_args=False, name=None): - """Creates the `Permute` bijector. - - Args: - permutation: An `int`-like vector-shaped `Tensor` representing the - permutation to apply to the rightmost dimension of the transformed - `Tensor`. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str`, name given to ops managed by this object. - - Raises: - TypeError: if `not permutation.dtype.is_integer`. - ValueError: if `permutation` does not contain exactly one of each of - `{0, 1, ..., d}`. - """ - with ops.name_scope(name, "permute", values=[permutation]): - permutation = ops.convert_to_tensor( - permutation, - name="permutation") - if not permutation.dtype.is_integer: - raise TypeError("permutation.dtype ({}) should be `int`-like.".format( - permutation.dtype.name)) - p = tensor_util.constant_value(permutation) - if p is not None: - if set(p) != set(np.arange(p.size)): - raise ValueError("Permutation over `d` must contain exactly one of " - "each of `{0, 1, ..., d}`.") - elif validate_args: - p, _ = nn_ops.top_k(-permutation, - k=array_ops.shape(permutation)[-1], - sorted=True) - permutation = control_flow_ops.with_dependencies([ - check_ops.assert_equal( - -p, math_ops.range(array_ops.size(p)), - message=("Permutation over `d` must contain exactly one of " - "each of `{0, 1, ..., d}`.")), - ], permutation) - self._permutation = permutation - super(Permute, self).__init__( - is_constant_jacobian=True, - validate_args=validate_args, - name=name or "permute") - - @property - def permutation(self): - return self._permutation - - def _forward(self, x): - return 
array_ops.gather(x, self.permutation, axis=-1) - - def _inverse(self, y): - return array_ops.gather( - y, - array_ops.invert_permutation(self.permutation), - axis=-1) - - def _inverse_log_det_jacobian(self, y): - return constant_op.constant(0., dtype=y.dtype) - - def _forward_log_det_jacobian(self, x): - return constant_op.constant(0., dtype=x.dtype) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py b/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py index a83199549cd16101ab7b39b43d19a17bc66f03df..c37db61720d10949f294ff7b2e9778ba6efa57f0 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/power_transform.py @@ -18,12 +18,110 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.power_transform_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector -_allowed_symbols = ["PowerTransform"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "PowerTransform", +] + + +class PowerTransform(bijector.Bijector): + """Compute `Y = g(X) = (1 + X * c)**(1 / c), X >= -1 / c`. + + The [power transform](https://en.wikipedia.org/wiki/Power_transform) maps + inputs from `[0, inf]` to `[-1/c, inf]`; this is equivalent to the `inverse` + of this bijector. + + This bijector is equivalent to the `Exp` bijector when `c=0`. 
+ """ + + def __init__(self, + power=0., + event_ndims=0, + validate_args=False, + name="power_transform"): + """Instantiates the `PowerTransform` bijector. + + Args: + power: Python `float` scalar indicating the transform power, i.e., + `Y = g(X) = (1 + X * c)**(1 / c)` where `c` is the `power`. + event_ndims: Python scalar indicating the number of dimensions associated + with a particular draw from the distribution. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + + Raises: + ValueError: if `power < 0` or is not known statically. + """ + self._graph_parents = [] + self._name = name + self._validate_args = validate_args + with self._name_scope("init", values=[power]): + power = tensor_util.constant_value( + ops.convert_to_tensor(power, name="power")) + if power is None or power < 0: + raise ValueError("`power` must be a non-negative TF constant.") + self._power = power + super(PowerTransform, self).__init__( + event_ndims=event_ndims, + validate_args=validate_args, + name=name) + + @property + def power(self): + """The `c` in: `Y = g(X) = (1 + X * c)**(1 / c)`.""" + return self._power + + def _forward(self, x): + x = self._maybe_assert_valid_x(x) + if self.power == 0.: + return math_ops.exp(x) + # If large x accuracy is an issue, consider using: + # (1. + x * self.power)**(1. / self.power) when x >> 1. + return math_ops.exp(math_ops.log1p(x * self.power) / self.power) + + def _inverse(self, y): + y = self._maybe_assert_valid_y(y) + if self.power == 0.: + return math_ops.log(y) + # If large y accuracy is an issue, consider using: + # (y**self.power - 1.) / self.power when y >> 1. + return math_ops.expm1(math_ops.log(y) * self.power) / self.power + + def _inverse_log_det_jacobian(self, y): + y = self._maybe_assert_valid_y(y) + event_dims = self._event_dims_tensor(y) + return (self.power - 1.) 
* math_ops.reduce_sum( + math_ops.log(y), axis=event_dims) + + def _forward_log_det_jacobian(self, x): + x = self._maybe_assert_valid_x(x) + event_dims = self._event_dims_tensor(x) + if self.power == 0.: + return math_ops.reduce_sum(x, axis=event_dims) + return (1. / self.power - 1.) * math_ops.reduce_sum( + math_ops.log1p(x * self.power), + axis=event_dims) + + def _maybe_assert_valid_x(self, x): + if not self.validate_args or self.power == 0.: + return x + is_valid = check_ops.assert_non_negative( + 1. + self.power * x, + message="Forward transformation input must be at least {}.".format( + -1. / self.power)) + return control_flow_ops.with_dependencies([is_valid], x) + + def _maybe_assert_valid_y(self, y): + if not self.validate_args: + return y + is_valid = check_ops.assert_positive( + y, message="Inverse transformation input must be greater than 0.") + return control_flow_ops.with_dependencies([is_valid], y) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/power_transform_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/power_transform_impl.py deleted file mode 100644 index c37db61720d10949f294ff7b2e9778ba6efa57f0..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/power_transform_impl.py +++ /dev/null @@ -1,127 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""PowerTransform bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import bijector - - -__all__ = [ - "PowerTransform", -] - - -class PowerTransform(bijector.Bijector): - """Compute `Y = g(X) = (1 + X * c)**(1 / c), X >= -1 / c`. - - The [power transform](https://en.wikipedia.org/wiki/Power_transform) maps - inputs from `[0, inf]` to `[-1/c, inf]`; this is equivalent to the `inverse` - of this bijector. - - This bijector is equivalent to the `Exp` bijector when `c=0`. - """ - - def __init__(self, - power=0., - event_ndims=0, - validate_args=False, - name="power_transform"): - """Instantiates the `PowerTransform` bijector. - - Args: - power: Python `float` scalar indicating the transform power, i.e., - `Y = g(X) = (1 + X * c)**(1 / c)` where `c` is the `power`. - event_ndims: Python scalar indicating the number of dimensions associated - with a particular draw from the distribution. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str` name given to ops managed by this object. - - Raises: - ValueError: if `power < 0` or is not known statically. 
- """ - self._graph_parents = [] - self._name = name - self._validate_args = validate_args - with self._name_scope("init", values=[power]): - power = tensor_util.constant_value( - ops.convert_to_tensor(power, name="power")) - if power is None or power < 0: - raise ValueError("`power` must be a non-negative TF constant.") - self._power = power - super(PowerTransform, self).__init__( - event_ndims=event_ndims, - validate_args=validate_args, - name=name) - - @property - def power(self): - """The `c` in: `Y = g(X) = (1 + X * c)**(1 / c)`.""" - return self._power - - def _forward(self, x): - x = self._maybe_assert_valid_x(x) - if self.power == 0.: - return math_ops.exp(x) - # If large x accuracy is an issue, consider using: - # (1. + x * self.power)**(1. / self.power) when x >> 1. - return math_ops.exp(math_ops.log1p(x * self.power) / self.power) - - def _inverse(self, y): - y = self._maybe_assert_valid_y(y) - if self.power == 0.: - return math_ops.log(y) - # If large y accuracy is an issue, consider using: - # (y**self.power - 1.) / self.power when y >> 1. - return math_ops.expm1(math_ops.log(y) * self.power) / self.power - - def _inverse_log_det_jacobian(self, y): - y = self._maybe_assert_valid_y(y) - event_dims = self._event_dims_tensor(y) - return (self.power - 1.) * math_ops.reduce_sum( - math_ops.log(y), axis=event_dims) - - def _forward_log_det_jacobian(self, x): - x = self._maybe_assert_valid_x(x) - event_dims = self._event_dims_tensor(x) - if self.power == 0.: - return math_ops.reduce_sum(x, axis=event_dims) - return (1. / self.power - 1.) * math_ops.reduce_sum( - math_ops.log1p(x * self.power), - axis=event_dims) - - def _maybe_assert_valid_x(self, x): - if not self.validate_args or self.power == 0.: - return x - is_valid = check_ops.assert_non_negative( - 1. + self.power * x, - message="Forward transformation input must be at least {}.".format( - -1. 
/ self.power)) - return control_flow_ops.with_dependencies([is_valid], x) - - def _maybe_assert_valid_y(self, y): - if not self.validate_args: - return y - is_valid = check_ops.assert_positive( - y, message="Inverse transformation input must be greater than 0.") - return control_flow_ops.with_dependencies([is_valid], y) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py b/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py new file mode 100644 index 0000000000000000000000000000000000000000..2840f52e742eac5e9e37a576bf7f6d6f05a07a35 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/real_nvp.py @@ -0,0 +1,282 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Real NVP bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.layers import core as layers +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import template as template_ops +from tensorflow.python.ops.distributions import bijector as bijector_lib + + +__all__ = [ + "RealNVP", + "real_nvp_default_template" +] + + +class RealNVP(bijector_lib.Bijector): + """RealNVP "affine coupling layer" for vector-valued events. + + Real NVP models a normalizing flow on a `D`-dimensional distribution via a + single `D-d`-dimensional conditional distribution [1]: + + `y[d:D] = x[d:D] * math_ops.exp(log_scale_fn(x[0:d])) + shift_fn(x[0:d])` + `y[0:d] = x[0:d]` + + The last `D-d` units are scaled and shifted based on the first `d` units only, + while the first `d` units are 'masked' and left unchanged. Real NVP's + `shift_and_log_scale_fn` computes vector-valued quantities. For + scale-and-shift transforms that do not depend on any masked units, i.e. + `d=0`, use the `tfb.Affine` bijector with learned parameters instead. + + Masking is currently only supported for base distributions with + `event_ndims=1`. For more sophisticated masking schemes like checkerboard or + channel-wise masking [2], use the `tfb.Permute` bijector to re-order desired + masked units into the first `d` units. For base distributions with + `event_ndims > 1`, use the `tfb.Reshape` bijector to flatten the event shape. + + Recall that the MAF bijector [2] implements a normalizing flow via an + autoregressive transformation.
MAF and IAF have opposite computational + tradeoffs - MAF can train all units in parallel but must sample units + sequentially, while IAF must train units sequentially but can sample in + parallel. In contrast, Real NVP can compute both forward and inverse + computations in parallel. However, the lack of an autoregressive + transformation makes it less expressive on a per-bijector basis. + + A "valid" `shift_and_log_scale_fn` must compute each `shift` (aka `loc` or + "mu" [2]) and `log(scale)` (aka "alpha" [2]) such that each are broadcastable + with the arguments to `forward` and `inverse`, i.e., such that the + calculations in `forward`, `inverse` [below] are possible. For convenience, + `real_nvp_default_template` is offered as a possible `shift_and_log_scale_fn` + function. + + NICE [3] is a special case of the Real NVP bijector which discards the scale + transformation, resulting in a constant-time inverse-log-determinant-Jacobian. + To use a NICE bijector instead of Real NVP, `shift_and_log_scale_fn` should + return `(shift, None)`, and `is_constant_jacobian` should be set to `True` in + the `RealNVP` constructor. Calling `real_nvp_default_template` with + `shift_only=True` returns one such NICE-compatible `shift_and_log_scale_fn`. + + Caching: the scalar input depth `D` of the base distribution is not known at + construction time. The first call to any of `forward(x)`, `inverse(x)`, + `inverse_log_det_jacobian(x)`, or `forward_log_det_jacobian(x)` memoizes + `D`, which is re-used in subsequent calls. This shape must be known prior to + graph execution (which is the case if using tf.layers). + + #### Example Use + + ```python + tfd = tf.contrib.distributions + tfb = tfd.bijectors + + # A common choice for a normalizing flow is to use a Gaussian for the base + # distribution. (However, any continuous distribution would work.)
E.g., + nvp = tfd.TransformedDistribution( + distribution=tfd.MultivariateNormalDiag(loc=[0., 0., 0.]), + bijector=tfb.RealNVP( + num_masked=2, + shift_and_log_scale_fn=tfb.real_nvp_default_template( + hidden_layers=[512, 512]))) + + x = nvp.sample() + nvp.log_prob(x) + nvp.log_prob(0.) + ``` + + For more examples, see [4]. + + [1]: "Density Estimation using Real NVP." + Laurent Dinh, Jascha Sohl-Dickstein, Samy Bengio. ICLR. 2017. + https://arxiv.org/abs/1605.08803 + + [2]: "Masked Autoregressive Flow for Density Estimation." + George Papamakarios, Theo Pavlakou, Iain Murray. Arxiv. 2017. + https://arxiv.org/abs/1705.07057 + + [3]: "NICE: Non-linear Independent Components Estimation." + Laurent Dinh, David Krueger, Yoshua Bengio. ICLR. 2015. + https://arxiv.org/abs/1410.8516 + + [4]: "Normalizing Flows Tutorial, Part 2: Modern Normalizing Flows." + Eric Jang. Blog post. January 2018. + http://blog.evjang.com/2018/01/nf2.html + """ + + def __init__(self, + num_masked, + shift_and_log_scale_fn, + is_constant_jacobian=False, + validate_args=False, + name=None): + """Creates the Real NVP or NICE bijector. + + Args: + num_masked: Python `int` indicating that the first `d` units of the event + should be masked. Must be in the closed interval `[1, D-1]`, where `D` + is the event size of the base distribution. + shift_and_log_scale_fn: Python `callable` which computes `shift` and + `log_scale` from both the forward domain (`x`) and the inverse domain + (`y`). Calculation must respect the "autoregressive property" (see class + docstring). Suggested default + `real_nvp_default_template(hidden_layers=...)`. + Typically the function contains `tf.Variables` and is wrapped using + `tf.make_template`. Returning `None` for either (both) `shift`, + `log_scale` is equivalent to (but more efficient than) returning zero. + is_constant_jacobian: Python `bool`. Default: `False`.
When `True` the + implementation assumes `log_scale` does not depend on the forward domain + (`x`) or inverse domain (`y`) values. (No validation is made; + `is_constant_jacobian=False` is always safe but possibly computationally + inefficient.) + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str`, name given to ops managed by this object. + + Raises: + ValueError: If num_masked < 1. + """ + name = name or "real_nvp" + if num_masked <= 0: + raise ValueError("num_masked must be a positive integer.") + self._num_masked = num_masked + # At construction time, we don't know input_depth. + self._input_depth = None + self._shift_and_log_scale_fn = shift_and_log_scale_fn + super(RealNVP, self).__init__( + event_ndims=1, + is_constant_jacobian=is_constant_jacobian, + validate_args=validate_args, + name=name) + + def _cache_input_depth(self, x): + if self._input_depth is None: + self._input_depth = x.shape.with_rank_at_least(1)[-1].value + if self._input_depth is None: + raise NotImplementedError( + "Rightmost dimension must be known prior to graph execution.") + if self._num_masked >= self._input_depth: + raise ValueError( + "Number of masked units must be smaller than the event size.") + + def _forward(self, x): + self._cache_input_depth(x) + # Performs scale and shift. + x0, x1 = x[:, :self._num_masked], x[:, self._num_masked:] + shift, log_scale = self._shift_and_log_scale_fn( + x0, self._input_depth - self._num_masked) + y1 = x1 + if log_scale is not None: + y1 *= math_ops.exp(log_scale) + if shift is not None: + y1 += shift + y = array_ops.concat([x0, y1], axis=-1) + return y + + def _inverse(self, y): + self._cache_input_depth(y) + # Performs un-shift and un-scale. 
+ y0, y1 = y[:, :self._num_masked], y[:, self._num_masked:] + shift, log_scale = self._shift_and_log_scale_fn( + y0, self._input_depth - self._num_masked) + x1 = y1 + if shift is not None: + x1 -= shift + if log_scale is not None: + x1 *= math_ops.exp(-log_scale) + x = array_ops.concat([y0, x1], axis=-1) + return x + + def _inverse_log_det_jacobian(self, y): + self._cache_input_depth(y) + y0 = y[:, :self._num_masked] + _, log_scale = self._shift_and_log_scale_fn( + y0, self._input_depth - self._num_masked) + if log_scale is None: + return constant_op.constant(0., dtype=y.dtype, name="ildj") + return -math_ops.reduce_sum(log_scale, axis=-1) + + def _forward_log_det_jacobian(self, x): + self._cache_input_depth(x) + x0 = x[:, :self._num_masked] + _, log_scale = self._shift_and_log_scale_fn( + x0, self._input_depth - self._num_masked) + if log_scale is None: + return constant_op.constant(0., dtype=x.dtype, name="ildj") + return math_ops.reduce_sum(log_scale, axis=-1) + + +def real_nvp_default_template( + hidden_layers, + shift_only=False, + activation=nn_ops.relu, + name=None, + *args, + **kwargs): + """Build a scale-and-shift function using a multi-layer neural network. + + This will be wrapped in a make_template to ensure the variables are only + created once. It takes the `d`-dimensional input x[0:d] and returns the `D-d` + dimensional outputs `loc` ("mu") and `log_scale` ("alpha"). + + Arguments: + hidden_layers: Python `list`-like of non-negative integer, scalars + indicating the number of units in each hidden layer. Default: `[512, 512]. + shift_only: Python `bool` indicating if only the `shift` term shall be + computed (i.e. NICE bijector). Default: `False`. + activation: Activation function (callable). Explicitly setting to `None` + implies a linear activation. + name: A name for ops managed by this function. Default: + "real_nvp_default_template". + *args: `tf.layers.dense` arguments. + **kwargs: `tf.layers.dense` keyword arguments. 
+ + Returns: + shift: `Float`-like `Tensor` of shift terms (the "mu" in [2]). + log_scale: `Float`-like `Tensor` of log(scale) terms (the "alpha" in [2]). + + Raises: + NotImplementedError: if rightmost dimension of `inputs` is unknown prior to + graph execution. + """ + + with ops.name_scope(name, "real_nvp_default_template"): + def _fn(x, output_units): + """Fully connected MLP parameterized via `real_nvp_template`.""" + for units in hidden_layers: + x = layers.dense( + inputs=x, + units=units, + activation=activation, + *args, + **kwargs) + x = layers.dense( + inputs=x, + units=(1 if shift_only else 2) * output_units, + activation=None, + *args, + **kwargs) + if shift_only: + return x, None + shift, log_scale = array_ops.split(x, 2, axis=-1) + return shift, log_scale + return template_ops.make_template( + "real_nvp_default_template", _fn) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py index 8997f7ab6929745275edb38712a5bbb0a9b25ddb..55eca063126797d577653f0d6bcdfddf8192bdb5 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py @@ -12,18 +12,303 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Reshape bijector.""" +"""Reshape bijectors.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.reshape_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +import numpy as np -_allowed_symbols = ["Reshape"] +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector as bijector_lib -remove_undocumented(__name__, _allowed_symbols) + +__all__ = [ + "Reshape", +] + + +def _static_ndims_from_shape(shape): + return shape.shape.with_rank_at_least(1)[0].value + + +def _ndims_from_shape(shape): + return array_ops.shape(shape)[0] + + +class Reshape(bijector_lib.Bijector): + """Reshapes the `event_shape` of a `Tensor`. + + The semantics generally follow that of `tf.reshape()`, with + a few differences: + + * The user must provide both the input and output shape, so that + the transformation can be inverted. If an input shape is not + specified, the default assumes a vector-shaped input, i.e., + event_shape_in = (-1,). + * The `Reshape` bijector automatically broadcasts over the leftmost + dimensions of its input (`sample_shape` and `batch_shape`); only + the rightmost `event_ndims_in` dimensions are reshaped. The + number of dimensions to reshape is inferred from the provided + `event_shape_in` (`event_ndims_in = len(event_shape_in)`). 
+ + Example usage: + ```python + + tfd = tf.contrib.distributions + + r = tfd.bijectors.Reshape(event_shape_out=[1, -1]) + + r.forward([3., 4.]) # shape [2] + # ==> [[3., 4.]] # shape [1, 2] + + r.forward([[1., 2.], [3., 4.]]) # shape [2, 2] + # ==> [[[1., 2.]], + # [[3., 4.]]] # shape [2, 1, 2] + + r.inverse([[3., 4.]]) # shape [1,2] + # ==> [3., 4.] # shape [2] + + r.forward_log_det_jacobian(any_value) + # ==> 0. + + r.inverse_log_det_jacobian(any_value) + # ==> 0. + ``` + + """ + + def __init__(self, event_shape_out, event_shape_in=(-1,), + validate_args=False, name=None): + """Creates a `Reshape` bijector. + + Args: + event_shape_out: An `int`-like vector-shaped `Tensor` + representing the event shape of the transformed output. + event_shape_in: An optional `int`-like vector-shape `Tensor` + representing the event shape of the input. This is required in + order to define inverse operations; the default of (-1,) + assumes a vector-shaped input. + validate_args: Python `bool` indicating whether arguments should + be checked for correctness. + name: Python `str`, name given to ops managed by this object. + + Raises: + TypeError: if either `event_shape_in` or `event_shape_out` has + non-integer `dtype`. + ValueError: if either of `event_shape_in` or `event_shape_out` + has non-vector shape (`rank > 1`), or if their sizes do not + match. 
+ """ + with ops.name_scope(name, "reshape", + values=[event_shape_out, event_shape_in]): + + event_shape_out = ops.convert_to_tensor(event_shape_out, + name="event_shape_out", + preferred_dtype=dtypes.int32) + event_shape_in = ops.convert_to_tensor(event_shape_in, + name="event_shape_in", + preferred_dtype=dtypes.int32) + + assertions = [] + assertions.extend(self._maybe_check_valid_shape( + event_shape_out, validate_args)) + assertions.extend(self._maybe_check_valid_shape( + event_shape_in, validate_args)) + + self._assertions = assertions + self._event_shape_in = event_shape_in + self._event_shape_out = event_shape_out + + super(Reshape, self).__init__(is_constant_jacobian=True, + validate_args=validate_args, + name=name or "reshape") + + def _maybe_check_valid_shape(self, shape, validate_args): + """Check that a shape Tensor is int-type and otherwise sane.""" + if not shape.dtype.is_integer: + raise TypeError("{} dtype ({}) should be `int`-like.".format( + shape.op.name, shape.dtype.name)) + + assertions = [] + + ndims = array_ops.rank(shape) + ndims_ = tensor_util.constant_value(ndims) + if ndims_ is not None and ndims_ > 1: + raise ValueError("`{}` rank ({}) should be <= 1.".format( + shape.op.name, ndims_)) + elif validate_args: + assertions.append(check_ops.assert_less_equal( + ndims, 1, message="`{}` rank should be <= 1.".format(shape.op.name))) + + shape_ = tensor_util.constant_value_as_shape(shape) + if shape_.is_fully_defined(): + es = np.int32(shape_.as_list()) + if sum(es == -1) > 1: + raise ValueError( + "`{}` must have at most one `-1` (given {})" + .format(shape.op.name, es)) + if np.any(es < -1): + raise ValueError( + "`{}` elements must be either positive integers or `-1`" + "(given {})." + .format(shape.op.name, es)) + elif validate_args: + assertions.extend([ + check_ops.assert_less_equal( + math_ops.reduce_sum( + math_ops.cast(math_ops.equal(shape, -1), dtypes.int32)), + 1, + message="`{}` elements must have at most one `-1`." 
+ .format(shape.op.name)), + check_ops.assert_greater_equal( + shape, -1, + message="`{}` elements must be either positive integers or `-1`." + .format(shape.op.name)), + ]) + return assertions + + def _reshape_helper(self, x, event_shape_in, event_shape_out): + """Reshape only the event_shape of an input `Tensor`.""" + + event_ndims_in_ = _static_ndims_from_shape(event_shape_in) + event_ndims_in = _ndims_from_shape(event_shape_in) + x_ndims_, x_ndims = x.shape.ndims, array_ops.rank(x) + + assertions = [] + + # Ensure x.event_shape is compatible with event_shape_in. + if (event_ndims_in_ is not None + and x_ndims_ is not None + and x.shape.with_rank_at_least(event_ndims_in_)[ + x_ndims_-event_ndims_in_:].is_fully_defined()): + x_event_shape_, x_event_shape = [ # pylint: disable=unbalanced-tuple-unpacking + np.int32(x.shape[x_ndims_-event_ndims_in_:])]*2 + else: + x_event_shape_, x_event_shape = ( + None, array_ops.shape(x)[x_ndims-event_ndims_in:]) + + event_shape_in_ = tensor_util.constant_value(event_shape_in) + + if x_event_shape_ is not None and event_shape_in_ is not None: + # Compare the shape dimensions that are fully specified in the + # input (i.e., for which event_shape_in is not -1). If x_event_shape + # matches along all of these dimensions, it is compatible with + # the desired input shape and any further mismatches (i.e., + # incompatibility with the desired *output* shape) will be + # caught inside of array_ops.reshape() below. + x_event_shape_specified_ = x_event_shape_[event_shape_in_ >= 0] + event_shape_in_specified_ = event_shape_in_[event_shape_in_ >= 0] + if not np.equal(x_event_shape_specified_, + event_shape_in_specified_).all(): + raise ValueError( + "Input `event_shape` does not match `event_shape_in` ({} vs {}).". + format(x_event_shape_, event_shape_in_)) + elif self.validate_args: + # Similarly to the static case, we compare the shape dimensions + # that are fully specified in the input.
We extract these + # dimensions using boolean_mask(), which requires that the mask + # have known ndims. We can assume that shape Tensors always have + # ndims==1 (this assumption is verified inside of + # _maybe_check_valid_shape), so the reshape operation is just a + # no-op that formally encodes this fact to make boolean_mask() + # happy. + event_shape_mask = array_ops.reshape(event_shape_in >= 0, [-1]) + x_event_shape_specified = array_ops.boolean_mask(x_event_shape, + event_shape_mask) + event_shape_in_specified = array_ops.boolean_mask(event_shape_in, + event_shape_mask) + assertions.append(check_ops.assert_equal( + x_event_shape_specified, event_shape_in_specified, + message="Input `event_shape` does not match `event_shape_in`.")) + + if assertions: + x = control_flow_ops.with_dependencies(assertions, x) + + # get the parts of shape(x) that will not change + sample_and_batch_shape = array_ops.shape(x) + + ndims = (x.shape.ndims if x.shape.ndims is not None + else array_ops.rank(x)) + sample_and_batch_shape = sample_and_batch_shape[ + :(ndims - math_ops.abs(event_ndims_in))] + + if (event_ndims_in_ is not None + and x_ndims_ is not None + and event_ndims_in_ == x_ndims_): + # Hack to allow forward/inverse_event_shape to do shape + # inference by calling this helper method with a dummy Tensor of + # shape event_shape_in. In this special case, + # sample_and_batch_shape will be empty so we can preserve static + # shape information by avoiding the concat operation below + # (which would be a no-op). 
+ new_shape = event_shape_out + else: + new_shape = array_ops.concat( + [sample_and_batch_shape, event_shape_out], axis=0) + + return array_ops.reshape(x, new_shape) + + def _forward(self, x): + with ops.control_dependencies(self._assertions): + return self._reshape_helper(x, + self._event_shape_in, + self._event_shape_out) + + def _inverse(self, y): + with ops.control_dependencies(self._assertions): + return self._reshape_helper(y, + self._event_shape_out, + self._event_shape_in) + + def _inverse_log_det_jacobian(self, y): + with ops.control_dependencies(self._assertions): + return constant_op.constant(0., dtype=y.dtype) + + def _forward_log_det_jacobian(self, x): + with ops.control_dependencies(self._assertions): + return constant_op.constant(0., dtype=x.dtype) + + def _forward_event_shape(self, input_shape): + # NOTE: this method and the other *_event_shape* methods + # compute shape by explicit transformation of a dummy + # variable. This approach is not generally recommended because it + # bloats the graph and could in general trigger side effects. + # + # In this particular case of the Reshape bijector, the + # forward and inverse transforms have no side effects, and we + # believe the reduction in code complexity from delegating the + # heavy lifting to tf.reshape() is worth the added graph ops. + # However, you should think hard before implementing this approach + # in other Bijectors; it is strongly preferred to compute + # shapes explicitly whenever it's feasible to do so. 
+ with ops.control_dependencies(self._assertions): + dummy = array_ops.zeros(dtype=dtypes.float32, shape=input_shape) + dummy_reshaped = self.forward(dummy) + return dummy_reshaped.shape + + def _inverse_event_shape(self, output_shape): + with ops.control_dependencies(self._assertions): + dummy = array_ops.zeros(dtype=dtypes.float32, shape=output_shape) + dummy_reshaped = self.inverse(dummy) + return dummy_reshaped.shape + + def _forward_event_shape_tensor(self, input_shape): + with ops.control_dependencies(self._assertions): + dummy = array_ops.zeros(dtype=dtypes.float32, shape=input_shape) + dummy_reshaped = self.forward(dummy) + return array_ops.shape(dummy_reshaped) + + def _inverse_event_shape_tensor(self, output_shape): + with ops.control_dependencies(self._assertions): + dummy = array_ops.zeros(dtype=dtypes.float32, shape=output_shape) + dummy_reshaped = self.inverse(dummy) + return array_ops.shape(dummy_reshaped) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/reshape_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/reshape_impl.py deleted file mode 100644 index 93682639aa3be3b8f59a369dedb6ee773c468130..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/reshape_impl.py +++ /dev/null @@ -1,297 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Reshape bijectors.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import bijector as bijector_lib - - -__all__ = [ - "Reshape", -] - - -class Reshape(bijector_lib.Bijector): - """Reshapes the `event_shape` of a `Tensor`. - - The semantics generally follow that of `tf.reshape()`, with - a few differences: - * The user must provide both the input and output shape, so that - the transformation can be inverted. - * The `Reshape` bijector automatically broadcasts over the leftmost - dimensions of its input (`sample_shape` and `batch_shape`); only - the rightmost `event_ndims_in` dimensions are reshaped. The - number of dimensions to reshape is inferred from the provided - `event_shape_in` (`event_ndims_in = len(event_shape_in)`). - * The `Reshape` bijector does not currently support - partially-specified shapes, i.e., those with a dimension - implicitly specified by `-1`. - - Example usage: - ```python - - bs = tf.contrib.distributions.bijectors - - reverse = bs.Reshape(event_shape_out=[1,2], - event_shape_in=[2,]) - - reverse.forward([1., 2.]) # shape [2,] - # ==> [[1., 2.]] # shape [1,2] - - reverse.forward([[1., 2.], [3., 4.]]) # shape [2, 2] - # ==> [[[1., 2.]], [[3., 4.]]] # shape [2, 1, 2] - - reverse.inverse([[1., 2.]]) # shape [1,2] - # ==> [1., 2.] # shape [2,] - - reverse.forward_log_det_jacobian(any_value) - # ==> 0. 
- - reverse.inverse_log_det_jacobian(any_value) - # ==> 0. - ``` - - """ - - def __init__(self, event_shape_out, event_shape_in, - validate_args=False, name=None): - """Creates a `Reshape` bijector. - - Args: - event_shape_out: An `int`-like vector-shaped `Tensor` - representing the fully specified (no -1's) event shape of the - transformed output. - event_shape_in: An `int`-like vector-shaped `Tensor` - representing the fully specified (no -1's) event shape of the - input. - validate_args: Python `bool` indicating whether arguments should - be checked for correctness. - name: Python `str`, name given to ops managed by this object. - - Raises: - TypeError: if either `event_shape_in` or `event_shape_out` has - non-vector shape (`rank > 1`), or non-integer `dtype`. - ValueError: if either `event_shape_in` or `event_shape_out` - contains non-positive entries, or if their sizes do not match - (`prod(event_shape_in)` != `prod(event_shape_out)`), or if - their dimensionality(s) cannot be statically inferred. 
- """ - with ops.name_scope(name, "reshape", - values=[event_shape_out, event_shape_in]): - - event_shape_out = ops.convert_to_tensor(event_shape_out, - name="event_shape_out", - preferred_dtype=dtypes.int32) - event_shape_in = ops.convert_to_tensor(event_shape_in, - name="event_shape_in", - preferred_dtype=dtypes.int32) - - # check that input shapes are positive integers - assertions = [] - assertions += self._maybe_check_valid_shape( - event_shape_out, "event_shape_out", - validate_args=validate_args) - assertions += self._maybe_check_valid_shape( - event_shape_in, "event_shape_in", validate_args=validate_args) - - # check that prod(event_shape_in) = prod(event_shape_out) - assertions += self._maybe_check_matching_sizes( - event_shape_in, event_shape_out, validate_args=validate_args) - - self._assertions = assertions - self._event_shape_in = event_shape_in - self._event_shape_out = event_shape_out - self._event_shape_in_static = tensor_util.constant_value_as_shape( - event_shape_in) - self._event_shape_out_static = tensor_util.constant_value_as_shape( - event_shape_out) - - super(Reshape, self).__init__(is_constant_jacobian=True, - validate_args=validate_args, - name=name or "reshape") - - def _maybe_check_valid_shape(self, shape_tensor, label, - validate_args=False): - """Check that a shape Tensor is int-type and positive.""" - - assertions = [] - - if not shape_tensor.dtype.is_integer: - raise TypeError("{} dtype ({}) should be `int`-like.".format( - label, shape_tensor.dtype.name)) - - shape_rank = tensor_util.constant_value(array_ops.rank(shape_tensor)) - if shape_rank is not None and shape_rank > 1: - raise ValueError("{} rank should be <= 1.".format(label)) - - s = tensor_util.constant_value(shape_tensor) - if s is not None: - if (s <= 0).any(): - raise ValueError("{} entries must be positive, but found {}".format( - label, s)) - elif validate_args: - assertions.append(check_ops.assert_positive( - shape_tensor, message="{} entries must be 
positive".format(label))) - - return assertions - - def _maybe_check_matching_sizes(self, event_shape_in, event_shape_out, - validate_args=False): - """Check that prod(event_shape_in)==prod(event_shape_out).""" - - def _get_size_from_shape(shape): - """Computes size from a shape `Tensor`, statically if possible.""" - s = tensor_util.constant_value(shape) - if s is not None: - return [np.int32(np.prod(s))]*2 - return None, math_ops.reduce_prod(shape, name="size") - - # Ensure `event_shape_in` is compatible with `event_shape_out`. - event_size_in_, event_size_in = _get_size_from_shape( # pylint: disable=unbalanced-tuple-unpacking - event_shape_in) - event_size_out_, event_size_out = _get_size_from_shape( # pylint: disable=unbalanced-tuple-unpacking - event_shape_out) - - assertions = [] - if event_size_in_ is not None and event_size_out_ is not None: - if event_size_in_ != event_size_out_: - raise ValueError( - "Input `event_size` ({}) does not match output `event_size` ({}).". - format(event_size_in, event_size_out_)) - elif validate_args: - assertions.append(check_ops.assert_equal( - event_size_in, event_size_out, - message="Input/output `event_size`s do not match.")) - - return assertions - - def _reshape_helper(self, x, event_shape_in, event_shape_out): - """Reshape only the event_shape of an input `Tensor`.""" - - def _get_rank_from_shape(shape): - """Computes rank from a shape `Tensor`, statically if possible.""" - # Uses fact that rank is "shape of shape". - ndims = shape.shape.with_rank_at_least(1)[0].value - if ndims is not None: - return ndims, ndims - return None, array_ops.shape(shape)[0] - - event_ndims_in_, event_ndims_in = _get_rank_from_shape(event_shape_in) - - assertions = [] - # Ensure x.event_shape is compatible with event_shape_in. 
- if x.shape.ndims is not None: - x_ndims_, x_ndims = [x.shape.ndims]*2 - else: - x_ndims_, x_ndims = None, array_ops.rank(x) - - if (event_ndims_in_ is not None - and x_ndims_ is not None - and x.shape.with_rank_at_least(event_ndims_in_)[ - x_ndims_-event_ndims_in_:].is_fully_defined()): - x_event_shape_, x_event_shape = [ # pylint: disable=unbalanced-tuple-unpacking - np.int32(x.shape[x_ndims_-event_ndims_in_:])]*2 - else: - x_event_shape_, x_event_shape = ( - None, array_ops.shape(x)[x_ndims-event_ndims_in:]) - - event_shape_in_ = tensor_util.constant_value(event_shape_in) - - if x_event_shape_ is not None and event_shape_in_ is not None: - if not np.equal(x_event_shape_, event_shape_in_).all(): - raise ValueError( - "Input `event_shape` ({}) does not match `event_shape_in` ({}).". - format(x_event_shape_, event_shape_in_)) - elif self.validate_args: - assertions.append(check_ops.assert_equal( - x_event_shape, event_shape_in, - message="Input `event_shape` does not match `event_shape_in`.")) - - if assertions: - x = control_flow_ops.with_dependencies(assertions, x) - - # get the parts of shape(x) that will not change - sample_and_batch_shape = array_ops.shape(x) - - ndims = (x.shape.ndims if x.shape.ndims is not None - else array_ops.rank(x)) - sample_and_batch_shape = sample_and_batch_shape[ - :(ndims - math_ops.abs(event_ndims_in))] - - new_shape = array_ops.concat( - [sample_and_batch_shape, event_shape_out], axis=0) - - return array_ops.reshape(x, new_shape) - - def _forward(self, x): - with ops.control_dependencies(self._assertions): - return self._reshape_helper(x, - self._event_shape_in, - self._event_shape_out) - - def _inverse(self, y): - with ops.control_dependencies(self._assertions): - return self._reshape_helper(y, - self._event_shape_out, - self._event_shape_in) - - def _inverse_log_det_jacobian(self, y): - with ops.control_dependencies(self._assertions): - return constant_op.constant(0., dtype=y.dtype) - - def _forward_log_det_jacobian(self, x): - 
with ops.control_dependencies(self._assertions): - return constant_op.constant(0., dtype=x.dtype) - - def _forward_event_shape(self, input_shape): - self._event_shape_in_static.assert_is_compatible_with(input_shape) - return self._event_shape_out_static - - def _inverse_event_shape(self, output_shape): - self._event_shape_out_static.assert_is_compatible_with(output_shape) - return self._event_shape_in_static - - def _forward_event_shape_tensor(self, input_shape): - input_assertions = self._maybe_check_valid_shape( - input_shape, "input event shape", validate_args=self.validate_args) - input_assertions += self._maybe_check_matching_sizes( - input_shape, self._event_shape_out, - validate_args=self.validate_args) - - return control_flow_ops.with_dependencies( - input_assertions + self._assertions, self._event_shape_out) - - def _inverse_event_shape_tensor(self, output_shape): - - output_assertions = self._maybe_check_valid_shape( - output_shape, "output event shape", validate_args=self.validate_args) - output_assertions += self._maybe_check_matching_sizes( - output_shape, self._event_shape_in, validate_args=self.validate_args) - - return control_flow_ops.with_dependencies( - output_assertions + self._assertions, self._event_shape_in) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py b/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py index c20e76c0b7367369865faf973377201c8b8b17e6..a640dfe7dfbcce96261589c7fc49107deaefdd54 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid.py @@ -18,12 +18,31 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import 
remove_undocumented +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops.distributions import bijector -_allowed_symbols = ["Sigmoid"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "Sigmoid", +] + + +class Sigmoid(bijector.Bijector): + """Bijector which computes `Y = g(X) = 1 / (1 + exp(-X))`.""" + + def __init__(self, validate_args=False, name="sigmoid"): + super(Sigmoid, self).__init__( + event_ndims=0, validate_args=validate_args, name=name) + + def _forward(self, x): + return math_ops.sigmoid(x) + + def _inverse(self, y): + return math_ops.log(y) - math_ops.log1p(-y) + + def _inverse_log_det_jacobian(self, y): + return -math_ops.log(y) - math_ops.log1p(-y) + + def _forward_log_det_jacobian(self, x): + return -nn_ops.softplus(-x) - nn_ops.softplus(x) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_centered.py b/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_centered.py index 448125230d24066697624bce03fed71a2c2f00b1..223bc9d042c69be05b0e578835a31ed6e83c0c97 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_centered.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/sigmoid_centered.py @@ -18,12 +18,22 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid_centered_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.contrib.distributions.python.ops.bijectors import softmax_centered -_allowed_symbols = ["SigmoidCentered"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "SigmoidCentered", +] + + +class SigmoidCentered(softmax_centered.SoftmaxCentered): + """Bijector which computes Y = g(X) = exp([X 0]) / (1 + exp(X)). 
+ + Equivalent to: `bijector.SoftmaxCentered(event_ndims=0)`. + + See `bijector.SoftmaxCentered` for more details. + """ + + def __init__(self, validate_args=False, name="sigmoid_centered"): + super(SigmoidCentered, self).__init__( + event_ndims=0, validate_args=validate_args, name=name) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py index b3cf03c24612f5c618c71c0a8615f272acdf2d10..3a75e4ae9495793901b0da91a5aa3982aab35852 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh.py @@ -18,12 +18,162 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.sinh_arcsinh_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +import numpy as np -_allowed_symbols = ["SinhArcsinh"] +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "SinhArcsinh", +] + + +def _sqrtx2p1(x): + """Implementation of `sqrt(1 + x**2)` which is stable despite large `x`.""" + return array_ops.where( + math_ops.abs(x) * np.sqrt(np.finfo(x.dtype.as_numpy_dtype).eps) <= 1., + math_ops.sqrt(x**2. + 1.), + # For large x, calculating x**2 can overflow. 
This can be alleviated by + # considering: + # sqrt(1 + x**2) + # = exp(0.5 log(1 + x**2)) + # = exp(0.5 log(x**2 * (1 + x**-2))) + # = exp(log(|x|) + 0.5 * log(1 + x**-2)) + # = |x| * exp(0.5 log(1 + x**-2)) + # = |x| * sqrt(1 + x**-2) + # We omit the last term in this approximation. + # When |x| > 1 / sqrt(machineepsilon), the second term will be 1, + # due to sqrt(1 + x**-2) = 1. This is also true with the gradient term, + # and higher order gradients, since the first order derivative of + # sqrt(1 + x**-2) is -x**-3 / sqrt(1 + x**-2) = O(x**-3), + # and all nth-order derivatives will be O(x**-(n + 2)). This makes any + # gradient terms that contain any derivatives of sqrt(1 + x**-2) vanish. + math_ops.abs(x)) + + +class SinhArcsinh(bijector.Bijector): + """Compute `Y = g(X) = Sinh( (Arcsinh(X) + skewness) * tailweight )`. + + For `skewness in (-inf, inf)` and `tailweight in (0, inf)`, this + transformation is a + diffeomorphism of the real line `(-inf, inf)`. The inverse transform is + `X = g^{-1}(Y) = Sinh( ArcSinh(Y) / tailweight - skewness )`. + + The `SinhArcsinh` transformation of the Normal is described in + [Sinh-arcsinh distributions](https://www.jstor.org/stable/27798865). + This Bijector allows a similar transformation of any distribution supported on + `(-inf, inf)`. + + #### Meaning of the parameters + + * If `skewness = 0` and `tailweight = 1`, this transform is the identity. + * Positive (negative) `skewness` leads to positive (negative) skew. + * positive skew means, for unimodal `X` centered at zero, the mode of `Y` is + "tilted" to the right. + * positive skew means positive values of `Y` become more likely, and + negative values become less likely. + * Larger (smaller) `tailweight` leads to fatter (thinner) tails. + * Fatter tails mean larger values of `|Y|` become more likely. + * If `X` is a unit Normal, `tailweight < 1` leads to a distribution that is + "flat" around `Y = 0`, and a very steep drop-off in the tails. 
+ * If `X` is a unit Normal, `tailweight > 1` leads to a distribution more + peaked at the mode with heavier tails. + + To see the argument about the tails, note that for `|X| >> 1` and + `|X| >> (|skewness| * tailweight)**tailweight`, we have + `Y approx 0.5 X**tailweight e**(sign(X) skewness * tailweight)`. + """ + + def __init__(self, + skewness=None, + tailweight=None, + event_ndims=0, + validate_args=False, + name="SinhArcsinh"): + """Instantiates the `SinhArcsinh` bijector. + + Args: + skewness: Skewness parameter. Float-type `Tensor`. Default is `0` + of type `float32`. + tailweight: Tailweight parameter. Positive `Tensor` of same `dtype` as + `skewness` and broadcastable `shape`. Default is `1` of type `float32`. + event_ndims: Python scalar indicating the number of dimensions associated + with a particular draw from the distribution. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + self._graph_parents = [] + self._name = name + self._validate_args = validate_args + with self._name_scope("init", values=[skewness, tailweight]): + tailweight = 1. if tailweight is None else tailweight + skewness = 0. 
if skewness is None else skewness + self._skewness = ops.convert_to_tensor( + skewness, name="skewness") + self._tailweight = ops.convert_to_tensor( + tailweight, name="tailweight", dtype=self._skewness.dtype) + check_ops.assert_same_float_dtype([self._skewness, self._tailweight]) + if validate_args: + self._tailweight = control_flow_ops.with_dependencies([ + check_ops.assert_positive( + self._tailweight, + message="Argument tailweight was not positive") + ], self._tailweight) + super(SinhArcsinh, self).__init__( + event_ndims=event_ndims, validate_args=validate_args, name=name) + + @property + def skewness(self): + """The `skewness` in: `Y = Sinh((Arcsinh(X) + skewness) * tailweight)`.""" + return self._skewness + + @property + def tailweight(self): + """The `tailweight` in: `Y = Sinh((Arcsinh(X) + skewness) * tailweight)`.""" + return self._tailweight + + def _forward(self, x): + return math_ops.sinh((math_ops.asinh(x) + self.skewness) * self.tailweight) + + def _inverse(self, y): + return math_ops.sinh(math_ops.asinh(y) / self.tailweight - self.skewness) + + def _inverse_log_det_jacobian(self, y): + # x = sinh(arcsinh(y) / tailweight - skewness) + # Using sinh' = cosh, arcsinh'(y) = 1 / sqrt(y**2 + 1), + # dx/dy + # = cosh(arcsinh(y) / tailweight - skewness) + # / (tailweight * sqrt(y**2 + 1)) + event_dims = self._event_dims_tensor(y) + return math_ops.reduce_sum( + # This is computed inside the log to avoid catastrophic cancellations + # from cosh((arcsinh(y) / tailweight) - skewness) and sqrt(y**2 + 1). + math_ops.log(math_ops.cosh( + math_ops.asinh(y) / self.tailweight - self.skewness) + # TODO(srvasude): Consider using cosh(arcsinh(y)) in cases + # where (arcsinh(y) / tailweight) - skewness ~= arcsinh(y). 
+ / _sqrtx2p1(y)) + - math_ops.log(self.tailweight), + axis=event_dims) + + def _forward_log_det_jacobian(self, x): + # y = sinh((arcsinh(x) + skewness) * tailweight) + # Using sinh' = cosh, arcsinh'(x) = 1 / sqrt(x**2 + 1), + # dy/dx + # = cosh((arcsinh(x) + skewness) * tailweight) * tailweight / sqrt(x**2 + 1) + event_dims = self._event_dims_tensor(x) + return math_ops.reduce_sum( + # This is computed inside the log to avoid catastrophic cancellations + # from cosh((arcsinh(x) + skewness) * tailweight) and sqrt(x**2 + 1). + math_ops.log(math_ops.cosh( + (math_ops.asinh(x) + self.skewness) * self.tailweight) + # TODO(srvasude): Consider using cosh(arcsinh(x)) in cases + # where (arcsinh(x) + skewness) * tailweight ~= arcsinh(x). + / _sqrtx2p1(x)) + + math_ops.log(self.tailweight), + axis=event_dims) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh_impl.py deleted file mode 100644 index 3a75e4ae9495793901b0da91a5aa3982aab35852..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh_impl.py +++ /dev/null @@ -1,179 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""SinhArcsinh bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import bijector - -__all__ = [ - "SinhArcsinh", -] - - -def _sqrtx2p1(x): - """Implementation of `sqrt(1 + x**2)` which is stable despite large `x`.""" - return array_ops.where( - math_ops.abs(x) * np.sqrt(np.finfo(x.dtype.as_numpy_dtype).eps) <= 1., - math_ops.sqrt(x**2. + 1.), - # For large x, calculating x**2 can overflow. This can be alleviated by - # considering: - # sqrt(1 + x**2) - # = exp(0.5 log(1 + x**2)) - # = exp(0.5 log(x**2 * (1 + x**-2))) - # = exp(log(x) + 0.5 * log(1 + x**-2)) - # = |x| * exp(0.5 log(1 + x**-2)) - # = |x| * sqrt(1 + x**-2) - # We omit the last term in this approximation. - # When |x| > 1 / sqrt(machineepsilon), the second term will be 1, - # due to sqrt(1 + x**-2) = 1. This is also true with the gradient term, - # and higher order gradients, since the first order derivative of - # sqrt(1 + x**-2) is -2 * x**-3 / (1 + x**-2) = -2 / (x**3 + x), - # and all nth-order derivatives will be O(x**-(n + 2)). This makes any - # gradient terms that contain any derivatives of sqrt(1 + x**-2) vanish. - math_ops.abs(x)) - - -class SinhArcsinh(bijector.Bijector): - """Compute `Y = g(X) = Sinh( (Arcsinh(X) + skewness) * tailweight )`. - - For `skewness in (-inf, inf)` and `tailweight in (0, inf)`, this - transformation is a - diffeomorphism of the real line `(-inf, inf)`. The inverse transform is - `X = g^{-1}(Y) = Sinh( ArcSinh(Y) / tailweight - skewness )`. 
- - The `SinhArcsinh` transformation of the Normal is described in - [Sinh-arcsinh distributions](https://www.jstor.org/stable/27798865) - This Bijector allows a similar transformation of any distribution supported on - `(-inf, inf)`. - - #### Meaning of the parameters - - * If `skewness = 0` and `tailweight = 1`, this transform is the identity. - * Positive (negative) `skewness` leads to positive (negative) skew. - * positive skew means, for unimodal `X` centered at zero, the mode of `Y` is - "tilted" to the right. - * positive skew means positive values of `Y` become more likely, and - negative values become less likely. - * Larger (smaller) `tailweight` leads to fatter (thinner) tails. - * Fatter tails mean larger values of `|Y|` become more likely. - * If `X` is a unit Normal, `tailweight < 1` leads to a distribution that is - "flat" around `Y = 0`, and a very steep drop-off in the tails. - * If `X` is a unit Normal, `tailweight > 1` leads to a distribution more - peaked at the mode with heavier tails. - - To see the argument about the tails, note that for `|X| >> 1` and - `|X| >> (|skewness| * tailweight)**tailweight`, we have - `Y approx 0.5 X**tailweight e**(sign(X) skewness * tailweight)`. - """ - - def __init__(self, - skewness=None, - tailweight=None, - event_ndims=0, - validate_args=False, - name="SinhArcsinh"): - """Instantiates the `SinhArcsinh` bijector. - - Args: - skewness: Skewness parameter. Float-type `Tensor`. Default is `0` - of type `float32`. - tailweight: Tailweight parameter. Positive `Tensor` of same `dtype` as - `skewness` and broadcastable `shape`. Default is `1` of type `float32`. - event_ndims: Python scalar indicating the number of dimensions associated - with a particular draw from the distribution. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str` name given to ops managed by this object. 
- """ - self._graph_parents = [] - self._name = name - self._validate_args = validate_args - with self._name_scope("init", values=[skewness, tailweight]): - tailweight = 1. if tailweight is None else tailweight - skewness = 0. if skewness is None else skewness - self._skewness = ops.convert_to_tensor( - skewness, name="skewness") - self._tailweight = ops.convert_to_tensor( - tailweight, name="tailweight", dtype=self._skewness.dtype) - check_ops.assert_same_float_dtype([self._skewness, self._tailweight]) - if validate_args: - self._tailweight = control_flow_ops.with_dependencies([ - check_ops.assert_positive( - self._tailweight, - message="Argument tailweight was not positive") - ], self._tailweight) - super(SinhArcsinh, self).__init__( - event_ndims=event_ndims, validate_args=validate_args, name=name) - - @property - def skewness(self): - """The `skewness` in: `Y = Sinh((Arcsinh(X) + skewness) * tailweight)`.""" - return self._skewness - - @property - def tailweight(self): - """The `tailweight` in: `Y = Sinh((Arcsinh(X) + skewness) * tailweight)`.""" - return self._tailweight - - def _forward(self, x): - return math_ops.sinh((math_ops.asinh(x) + self.skewness) * self.tailweight) - - def _inverse(self, y): - return math_ops.sinh(math_ops.asinh(y) / self.tailweight - self.skewness) - - def _inverse_log_det_jacobian(self, y): - # x = sinh(arcsinh(y) / tailweight - skewness) - # Using sinh' = cosh, arcsinh'(y) = 1 / sqrt(y**2 + 1), - # dx/dy - # = cosh(arcsinh(y) / tailweight - skewness) - # / (tailweight * sqrt(y**2 + 1)) - event_dims = self._event_dims_tensor(y) - return math_ops.reduce_sum( - # This is computed inside the log to avoid catastrophic cancellations - # from cosh((arcsinh(y) / tailweight) - skewness) and sqrt(x**2 + 1). - math_ops.log(math_ops.cosh( - math_ops.asinh(y) / self.tailweight - self.skewness) - # TODO(srvasude): Consider using cosh(arcsinh(x)) in cases - # where (arcsinh(x) / tailweight) - skewness ~= arcsinh(x). 
- / _sqrtx2p1(y)) - - math_ops.log(self.tailweight), - axis=event_dims) - - def _forward_log_det_jacobian(self, x): - # y = sinh((arcsinh(x) + skewness) * tailweight) - # Using sinh' = cosh, arcsinh'(x) = 1 / sqrt(x**2 + 1), - # dy/dx - # = cosh((arcsinh(x) + skewness) * tailweight) * tailweight / sqrt(x**2 + 1) - event_dims = self._event_dims_tensor(x) - return math_ops.reduce_sum( - # This is computed inside the log to avoid catastrophic cancellations - # from cosh((arcsinh(x) + skewness) * tailweight) and sqrt(x**2 + 1). - math_ops.log(math_ops.cosh( - (math_ops.asinh(x) + self.skewness) * self.tailweight) - # TODO(srvasude): Consider using cosh(arcsinh(x)) in cases - # where (arcsinh(x) + skewness) * tailweight ~= arcsinh(x). - / _sqrtx2p1(x)) - + math_ops.log(self.tailweight), - axis=event_dims) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py index be6608f97880ae68e10b17c815bf2d8438293261..a9dcce6c526600f3b26c6bceb730417000917ce7 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered.py @@ -18,12 +18,223 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +import numpy as np -_allowed_symbols = ["SoftmaxCentered"] +from tensorflow.contrib.distributions.python.ops import distribution_util +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from 
tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops.distributions import bijector -remove_undocumented(__name__, _allowed_symbols) + +__all__ = [ + "SoftmaxCentered", +] + + +class SoftmaxCentered(bijector.Bijector): + """Bijector which computes `Y = g(X) = exp([X 0]) / sum(exp([X 0]))`. + + To implement [softmax](https://en.wikipedia.org/wiki/Softmax_function) as a + bijection, the forward transformation appends a value to the input and the + inverse removes this coordinate. The appended coordinate represents a pivot, + e.g., `softmax(x) = exp(x-c) / sum(exp(x-c))` where `c` is the implicit last + coordinate. + + Because we append a coordinate, this bijector only supports `event_ndim in [0, + 1]`, i.e., scalars and vectors. + + Example Use: + + ```python + bijector.SoftmaxCentered(event_ndims=1).forward(tf.log([2, 3, 4])) + # Result: [0.2, 0.3, 0.4, 0.1] + # Extra result: 0.1 + + bijector.SoftmaxCentered(event_ndims=1).inverse([0.2, 0.3, 0.4, 0.1]) + # Result: tf.log([2, 3, 4]) + # Extra coordinate removed. + ``` + + At first blush it may seem like the [Invariance of domain]( + https://en.wikipedia.org/wiki/Invariance_of_domain) theorem implies this + implementation is not a bijection. However, the appended dimension + makes the (forward) image non-open and the theorem does not directly apply. 
+ """ + + def __init__(self, + event_ndims=0, + validate_args=False, + name="softmax_centered"): + self._graph_parents = [] + self._name = name + with self._name_scope("init", values=[event_ndims]): + event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") + event_ndims = tensor_util.constant_value(event_ndims) + if event_ndims is None or event_ndims not in [0, 1]: + raise ValueError("`event_ndims` must be a TF constant which is 0 or 1") + self._static_event_ndims = event_ndims + super(SoftmaxCentered, self).__init__( + event_ndims=event_ndims, + validate_args=validate_args, + name=name) + + def _forward_event_shape(self, input_shape): + if input_shape.ndims is None: + return input_shape + if input_shape.ndims != self._static_event_ndims: + raise ValueError("input_shape.dims = %d != %d" % + (input_shape.ndims, self._static_event_ndims)) + if input_shape.ndims == 0: + return tensor_shape.TensorShape([2]) + if input_shape.ndims == 1: + return tensor_shape.TensorShape(input_shape[0] + 1) + # Unreachable code: + raise ValueError("event_ndims = %d must be 0 or 1" % input_shape.ndims) + + def _forward_event_shape_tensor(self, input_shape): + ndims = array_ops.shape(input_shape) + if self.validate_args: + # It is not possible for a negative shape so we need only check <= 1. 
+ is_zero_or_one = check_ops.assert_equal( + ndims, 0 if self._static_event_ndims == 0 else 1, + message="event_ndims must be 0 or 1") + ndims = control_flow_ops.with_dependencies([is_zero_or_one], ndims) + if self._static_event_ndims == 0: + return ops.convert_to_tensor( + [2], dtype=dtypes.int32, name="output_shape") + return input_shape + 1 + + def _inverse_event_shape(self, output_shape): + if output_shape.ndims is None: + return output_shape + if output_shape.ndims != 1: + raise ValueError("output_shape.ndims = %d != 1" % output_shape.ndims) + if self._static_event_ndims == 0: + return tensor_shape.TensorShape([]) + return tensor_shape.TensorShape(output_shape[0] - 1) + + def _inverse_event_shape_tensor(self, output_shape): + ndims = array_ops.shape(output_shape)[0] + if self.validate_args: + # It is not possible for a negative shape so we need only check <= 1. + is_one = check_ops.assert_equal( + ndims, 1, message="event_ndims must be 1") + ndims = control_flow_ops.with_dependencies([is_one], ndims) + if self._static_event_ndims == 0: + return ops.convert_to_tensor([], dtype=dtypes.int32, name="output_shape") + return array_ops.expand_dims(output_shape[0] - 1, dim=0) + + def _forward(self, x): + # Pad the last dim with a zeros vector. We need this because it lets us + # infer the scale in the inverse function. + y = array_ops.expand_dims(x, dim=-1) if self._static_event_ndims == 0 else x + y = distribution_util.pad(y, axis=-1, back=True) + + # Set shape hints. + if x.shape.ndims is not None: + shape = x.shape.as_list() + if self._static_event_ndims == 0: + shape += [2] + elif shape[-1] is not None: + shape[-1] += 1 + shape = tensor_shape.TensorShape(shape) + y.shape.assert_is_compatible_with(shape) + y.set_shape(shape) + + # Since we only support event_ndims in [0, 1] and we do padding, we always + # reduce over the last dimension, i.e., dim=-1 (which is the default). 
+ return nn_ops.softmax(y) + + def _inverse(self, y): + # To derive the inverse mapping note that: + # y[i] = exp(x[i]) / normalization + # and + # y[end] = 1 / normalization. + # Thus: + # x[i] = log(exp(x[i])) - log(y[end]) - log(normalization) + # = log(exp(x[i])/normalization) - log(y[end]) + # = log(y[i]) - log(y[end]) + shape = (np.asarray(y.shape.as_list(), dtype=np.int32) + if y.shape.is_fully_defined() + else array_ops.shape(y, name="shape")) + ndims = distribution_util.prefer_static_rank(y) + + # Do this first to make sure CSE catches that it'll happen again in + # _inverse_log_det_jacobian. + x = math_ops.log(y) + + # We now extract the last coordinate of the rightmost dimension. + # Our trick is to slice from [0,0,...,shape[-1]-1] to shape[:-1]+[1]. + begin = array_ops.one_hot(indices=ndims-1, + depth=ndims, + on_value=shape[-1]-np.array(1, dtype=shape.dtype), + dtype=shape.dtype) + size = array_ops.concat([shape[:-1], np.asarray([1], dtype=shape.dtype)], 0) + log_normalization = -array_ops.strided_slice(x, begin, begin + size) + + # Here we slice out all but the last coordinate; see above for idea. + begin = array_ops.zeros_like(shape) + size = array_ops.concat([shape[:-1], [shape[-1] - 1]], 0) + x = array_ops.strided_slice(x, begin, begin + size) + + x += log_normalization + + if self._static_event_ndims == 0: + x = array_ops.squeeze(x, squeeze_dims=[ndims-1]) + + # Set shape hints. + if y.shape.ndims is not None: + shape = y.shape.as_list() + if self._static_event_ndims == 0: + shape = shape[:-1] + elif shape[-1] is not None: + shape[-1] -= 1 + shape = tensor_shape.TensorShape(shape) + x.shape.assert_is_compatible_with(shape) + x.set_shape(shape) + + return x + + def _inverse_log_det_jacobian(self, y): + # WLOG, consider the vector case: + # x = log(y[:-1]) - log(y[-1]) + # where, + # y[-1] = 1 - sum(y[:-1]). 
+ # We have: + # det{ dX/dY } = det{ diag(1 ./ y[:-1]) + 1 / y[-1] } + # = det{ inv{ diag(y[:-1]) - y[:-1]' y[:-1] } } (1) + # = 1 / det{ diag(y[:-1]) - y[:-1]' y[:-1] } + # = 1 / { (1 + y[:-1]' inv(diag(y[:-1])) y[:-1]) * + # det(diag(y[:-1])) } (2) + # = 1 / { y[-1] prod(y[:-1]) } + # = 1 / prod(y) + # (1) - https://en.wikipedia.org/wiki/Sherman%E2%80%93Morrison_formula + # or by noting that det{ dX/dY } = 1 / det{ dY/dX } from Bijector + # docstring "Tip". + # (2) - https://en.wikipedia.org/wiki/Matrix_determinant_lemma + return -math_ops.reduce_sum(math_ops.log(y), axis=-1) + + def _forward_log_det_jacobian(self, x): + if self._static_event_ndims == 0: + return x - 2. * nn_ops.softplus(x) + else: + # This code is similar to nn_ops.log_softmax but different because we have + # an implicit zero column to handle. I.e., instead of: + # reduce_sum(logits - reduce_sum(exp(logits), dim)) + # we must do: + # log_normalization = 1 + reduce_sum(exp(logits)) + # -log_normalization + reduce_sum(logits - log_normalization) + log_normalization = nn_ops.softplus( + math_ops.reduce_logsumexp(x, axis=-1, keep_dims=True)) + fldj = (-log_normalization + + math_ops.reduce_sum(x - log_normalization, + axis=-1, + keep_dims=True)) + return array_ops.squeeze(fldj, squeeze_dims=-1) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered_impl.py deleted file mode 100644 index 8645cc1b6b04be75a419342591272f07a4a1711c..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softmax_centered_impl.py +++ /dev/null @@ -1,245 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""SoftmaxCentered bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops.distributions import bijector - - -__all__ = [ - "SoftmaxCentered", -] - - -class SoftmaxCentered(bijector.Bijector): - """Bijector which computes `Y = g(X) = exp([X 0]) / sum(exp([X 0]))`. - - To implement [softmax](https://en.wikipedia.org/wiki/Softmax_function) as a - bijection, the forward transformation appends a value to the input and the - inverse removes this coordinate. The appended coordinate represents a pivot, - e.g., `softmax(x) = exp(x-c) / sum(exp(x-c))` where `c` is the implicit last - coordinate. - - Because we append a coordinate, this bijector only supports `event_ndim in [0, - 1]`, i.e., scalars and vectors. 
- - Example Use: - - ```python - bijector.SoftmaxCentered(event_ndims=1).forward(tf.log([2, 3, 4])) - # Result: [0.2, 0.3, 0.4, 0.1] - # Extra result: 0.1 - - bijector.SoftmaxCentered(event_ndims=1).inverse([0.2, 0.3, 0.4, 0.1]) - # Result: tf.log([2, 3, 4]) - # Extra coordinate removed. - ``` - - At first blush it may seem like the [Invariance of domain]( - https://en.wikipedia.org/wiki/Invariance_of_domain) theorem implies this - implementation is not a bijection. However, the appended dimension - makes the (forward) image non-open and the theorem does not directly apply. - """ - - def __init__(self, - event_ndims=0, - validate_args=False, - name="softmax_centered"): - self._graph_parents = [] - self._name = name - with self._name_scope("init", values=[event_ndims]): - event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") - event_ndims = tensor_util.constant_value(event_ndims) - if event_ndims is None or event_ndims not in [0, 1]: - raise ValueError("`event_ndims` must be a TF constant which is 0 or 1") - self._static_event_ndims = event_ndims - super(SoftmaxCentered, self).__init__( - event_ndims=event_ndims, - validate_args=validate_args, - name=name) - - def _forward_event_shape(self, input_shape): - if input_shape.ndims is None: - return input_shape - if input_shape.ndims != self._static_event_ndims: - raise ValueError("input_shape.dims = %d != %d" % - (input_shape.ndims, self._static_event_ndims)) - if input_shape.ndims == 0: - return tensor_shape.TensorShape([2]) - if input_shape.ndims == 1: - return tensor_shape.TensorShape(input_shape[0] + 1) - # Unreachable code: - raise ValueError("event_ndims = %d must be 0 or 1" % input_shape.ndims) - - def _forward_event_shape_tensor(self, input_shape): - ndims = array_ops.shape(input_shape) - if self.validate_args: - # It is not possible for a negative shape so we need only check <= 1. 
- is_zero_or_one = check_ops.assert_equal( - ndims, 0 if self._static_event_ndims == 0 else 1, - message="event_ndims must be 0 or 1") - ndims = control_flow_ops.with_dependencies([is_zero_or_one], ndims) - if self._static_event_ndims == 0: - return ops.convert_to_tensor( - [2], dtype=dtypes.int32, name="output_shape") - return input_shape + 1 - - def _inverse_event_shape(self, output_shape): - if output_shape.ndims is None: - return output_shape - if output_shape.ndims != 1: - raise ValueError("output_shape.ndims = %d != 1" % output_shape.ndims) - if self._static_event_ndims == 0: - return tensor_shape.TensorShape([]) - return tensor_shape.TensorShape(output_shape[0] - 1) - - def _inverse_event_shape_tensor(self, output_shape): - ndims = array_ops.shape(output_shape)[0] - if self.validate_args: - # It is not possible for a negative shape so we need only check <= 1. - is_one = check_ops.assert_equal( - ndims, 1, message="event_ndims must be 1") - ndims = control_flow_ops.with_dependencies([is_one], ndims) - if self._static_event_ndims == 0: - return ops.convert_to_tensor([], dtype=dtypes.int32, name="output_shape") - return array_ops.expand_dims(output_shape[0] - 1, dim=0) - - def _forward(self, x): - # Pad the last dim with a zeros vector. We need this because it lets us - # infer the scale in the inverse function. - y = array_ops.expand_dims(x, dim=-1) if self._static_event_ndims == 0 else x - ndims = (y.get_shape().ndims if y.get_shape().ndims is not None - else array_ops.rank(y)) - y = array_ops.pad(y, - paddings=array_ops.concat( - (array_ops.zeros( - (ndims - 1, 2), dtype=dtypes.int32), [[0, 1]]), - 0)) - - # Set shape hints. 
- if x.get_shape().ndims is not None: - shape = x.get_shape().as_list() - if self._static_event_ndims == 0: - shape += [2] - elif shape[-1] is not None: - shape[-1] += 1 - shape = tensor_shape.TensorShape(shape) - y.get_shape().assert_is_compatible_with(shape) - y.set_shape(shape) - - # Since we only support event_ndims in [0, 1] and we do padding, we always - # reduce over the last dimension, i.e., dim=-1 (which is the default). - return nn_ops.softmax(y) - - def _inverse(self, y): - # To derive the inverse mapping note that: - # y[i] = exp(x[i]) / normalization - # and - # y[end] = 1 / normalization. - # Thus: - # x[i] = log(exp(x[i])) - log(y[end]) - log(normalization) - # = log(exp(x[i])/normalization) - log(y[end]) - # = log(y[i]) - log(y[end]) - shape = (np.asarray(y.get_shape().as_list(), dtype=np.int32) - if y.get_shape().is_fully_defined() - else array_ops.shape(y, name="shape")) - ndims = y.get_shape().ndims or math_ops.rank(y, name="ndims") - - # Do this first to make sure CSE catches that it'll happen again in - # _inverse_log_det_jacobian. - x = math_ops.log(y) - - # We now extract the last coordinate of the rightmost dimension. - # Our trick is to slice from [0,0,...,shape[-1]-1] to shape[:-1]+[1]. - begin = array_ops.one_hot(indices=ndims-1, - depth=ndims, - on_value=shape[-1]-np.array(1, dtype=shape.dtype), - dtype=shape.dtype) - size = array_ops.concat([shape[:-1], np.asarray([1], dtype=shape.dtype)], 0) - log_normalization = -array_ops.strided_slice(x, begin, begin + size) - - # Here we slice out all but the last coordinate; see above for idea. - begin = array_ops.zeros_like(shape) - size = array_ops.concat([shape[:-1], [shape[-1] - 1]], 0) - x = array_ops.strided_slice(x, begin, begin + size) - - x += log_normalization - - if self._static_event_ndims == 0: - x = array_ops.squeeze(x, squeeze_dims=[ndims-1]) - - # Set shape hints. 
- if y.get_shape().ndims is not None: - shape = y.get_shape().as_list() - if self._static_event_ndims == 0: - shape = shape[:-1] - elif shape[-1] is not None: - shape[-1] -= 1 - shape = tensor_shape.TensorShape(shape) - x.get_shape().assert_is_compatible_with(shape) - x.set_shape(shape) - - return x - - def _inverse_log_det_jacobian(self, y): - # WLOG, consider the vector case: - # x = log(y[:-1]) - log(y[-1]) - # where, - # y[-1] = 1 - sum(y[:-1]). - # We have: - # det{ dX/dY } = det{ diag(1 ./ y[:-1]) + 1 / y[-1] } - # = det{ inv{ diag(y[:-1]) - y[:-1]' y[:-1] } } (1) - # = 1 / det{ diag(y[:-1]) - y[:-1]' y[:-1] } - # = 1 / { (1 + y[:-1]' inv(diag(y[:-1])) y[:-1]) * - # det(diag(y[:-1])) } (2) - # = 1 / { y[-1] prod(y[:-1]) } - # = 1 / prod(y) - # (1) - https://en.wikipedia.org/wiki/Sherman%E2%80%93Morrison_formula - # or by noting that det{ dX/dY } = 1 / det{ dY/dX } from Bijector - # docstring "Tip". - # (2) - https://en.wikipedia.org/wiki/Matrix_determinant_lemma - return -math_ops.reduce_sum(math_ops.log(y), axis=-1) - - def _forward_log_det_jacobian(self, x): - if self._static_event_ndims == 0: - return x - 2. * nn_ops.softplus(x) - else: - # This code is similar to nn_ops.log_softmax but different because we have - # an implicit zero column to handle. 
I.e., instead of: - # reduce_sum(logits - reduce_sum(exp(logits), dim)) - # we must do: - # log_normalization = 1 + reduce_sum(exp(logits)) - # -log_normalization + reduce_sum(logits - log_normalization) - log_normalization = nn_ops.softplus( - math_ops.reduce_logsumexp(x, axis=-1, keep_dims=True)) - fldj = (-log_normalization + - math_ops.reduce_sum(x - log_normalization, - axis=-1, - keep_dims=True)) - return array_ops.squeeze(fldj, squeeze_dims=-1) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py b/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py index 250a1144b53bb43271ff7ee494604d9bae6feda8..81957fcf78922fa15fd20a25d144071f431161ae 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/softplus.py @@ -18,12 +18,127 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.softplus_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.framework import ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops.distributions import bijector +from tensorflow.python.ops.distributions import util as distribution_util -_allowed_symbols = ["Softplus"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "Softplus", +] + + +class Softplus(bijector.Bijector): + """Bijector which computes `Y = g(X) = Log[1 + exp(X)]`. 
+ + The softplus `Bijector` has the following two useful properties: + + * The domain is the positive real numbers + * `softplus(x) approx x`, for large `x`, so it does not overflow as easily as + the `Exp` `Bijector`. + + The optional nonzero `hinge_softness` parameter changes the transition at + zero. With `hinge_softness = c`, the bijector is: + + ```f_c(x) := c * g(x / c) = c * Log[1 + exp(x / c)].``` + + For large `x >> 1`, `c * Log[1 + exp(x / c)] approx c * Log[exp(x / c)] = x`, + so the behavior for large `x` is the same as the standard softplus. + + As `c > 0` approaches 0 from the right, `f_c(x)` becomes less and less soft, + approaching `max(0, x)`. + + * `c = 1` is the default. + * `c > 0` but small means `f(x) approx ReLu(x) = max(0, x)`. + * `c < 0` flips sign and reflects around the `y-axis`: `f_{-c}(x) = -f_c(-x)`. + * `c = 0` results in a non-bijective transformation and triggers an exception. + + Example Use: + + ```python + # Create the Y=g(X)=softplus(X) transform which works only on Tensors with 1 + # batch ndim and 2 event ndims (i.e., vector of matrices). + softplus = Softplus(event_ndims=2) + x = [[[1., 2], + [3, 4]], + [[5, 6], + [7, 8]]] + log(1 + exp(x)) == softplus.forward(x) + log(exp(x) - 1) == softplus.inverse(x) + ``` + + Note: log(.) and exp(.) are applied element-wise but the Jacobian is a + reduction over the event space. + """ + + @distribution_util.AppendDocstring( + kwargs_dict={ + "hinge_softness": ( + "Nonzero floating point `Tensor`. Controls the softness of what " + "would otherwise be a kink at the origin. 
Default is 1.0")}) + def __init__(self, + event_ndims=0, + hinge_softness=None, + validate_args=False, + name="softplus"): + with ops.name_scope(name, values=[hinge_softness]): + if hinge_softness is not None: + self._hinge_softness = ops.convert_to_tensor( + hinge_softness, name="hinge_softness") + else: + self._hinge_softness = None + if validate_args: + nonzero_check = check_ops.assert_none_equal( + ops.convert_to_tensor( + 0, dtype=self.hinge_softness.dtype), + self.hinge_softness, + message="hinge_softness must be non-zero") + self._hinge_softness = control_flow_ops.with_dependencies( + [nonzero_check], self.hinge_softness) + + super(Softplus, self).__init__( + event_ndims=event_ndims, + validate_args=validate_args, + name=name) + + def _forward(self, x): + if self.hinge_softness is None: + return nn_ops.softplus(x) + hinge_softness = math_ops.cast(self.hinge_softness, x.dtype) + return hinge_softness * nn_ops.softplus(x / hinge_softness) + + def _inverse(self, y): + if self.hinge_softness is None: + return distribution_util.softplus_inverse(y) + hinge_softness = math_ops.cast(self.hinge_softness, y.dtype) + return hinge_softness * distribution_util.softplus_inverse( + y / hinge_softness) + + def _inverse_log_det_jacobian(self, y): + # Could also do: + # ildj = math_ops.reduce_sum(y - distribution_util.softplus_inverse(y), + # axis=event_dims) + # but the following is more numerically stable. Ie, + # Y = Log[1 + exp{X}] ==> X = Log[exp{Y} - 1] + # ==> dX/dY = exp{Y} / (exp{Y} - 1) + # = 1 / (1 - exp{-Y}), + # which is the most stable for large Y > 0. For small Y, we use + # 1 - exp{-Y} approx Y. 
+ if self.hinge_softness is not None: + y /= math_ops.cast(self.hinge_softness, y.dtype) + return -math_ops.reduce_sum(math_ops.log(-math_ops.expm1(-y)), + axis=self._event_dims_tensor(y)) + + def _forward_log_det_jacobian(self, x): + if self.hinge_softness is not None: + x /= math_ops.cast(self.hinge_softness, x.dtype) + return -math_ops.reduce_sum(nn_ops.softplus(-x), + axis=self._event_dims_tensor(x)) + + @property + def hinge_softness(self): + return self._hinge_softness diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/softplus_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/softplus_impl.py deleted file mode 100644 index 81957fcf78922fa15fd20a25d144071f431161ae..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/softplus_impl.py +++ /dev/null @@ -1,144 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Softplus bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.framework import ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops.distributions import bijector -from tensorflow.python.ops.distributions import util as distribution_util - - -__all__ = [ - "Softplus", -] - - -class Softplus(bijector.Bijector): - """Bijector which computes `Y = g(X) = Log[1 + exp(X)]`. - - The softplus `Bijector` has the following two useful properties: - - * The domain is the positive real numbers - * `softplus(x) approx x`, for large `x`, so it does not overflow as easily as - the `Exp` `Bijector`. - - The optional nonzero `hinge_softness` parameter changes the transition at - zero. With `hinge_softness = c`, the bijector is: - - ```f_c(x) := c * g(x / c) = c * Log[1 + exp(x / c)].``` - - For large `x >> 1`, `c * Log[1 + exp(x / c)] approx c * Log[exp(x / c)] = x`, - so the behavior for large `x` is the same as the standard softplus. - - As `c > 0` approaches 0 from the right, `f_c(x)` becomes less and less soft, - approaching `max(0, x)`. - - * `c = 1` is the default. - * `c > 0` but small means `f(x) approx ReLu(x) = max(0, x)`. - * `c < 0` flips sign and reflects around the `y-axis`: `f_{-c}(x) = -f_c(-x)`. - * `c = 0` results in a non-bijective transformation and triggers an exception. - - Example Use: - - ```python - # Create the Y=g(X)=softplus(X) transform which works only on Tensors with 1 - # batch ndim and 2 event ndims (i.e., vector of matrices). 
- softplus = Softplus(event_ndims=2) - x = [[[1., 2], - [3, 4]], - [[5, 6], - [7, 8]]] - log(1 + exp(x)) == softplus.forward(x) - log(exp(x) - 1) == softplus.inverse(x) - ``` - - Note: log(.) and exp(.) are applied element-wise but the Jacobian is a - reduction over the event space. - """ - - @distribution_util.AppendDocstring( - kwargs_dict={ - "hinge_softness": ( - "Nonzero floating point `Tensor`. Controls the softness of what " - "would otherwise be a kink at the origin. Default is 1.0")}) - def __init__(self, - event_ndims=0, - hinge_softness=None, - validate_args=False, - name="softplus"): - with ops.name_scope(name, values=[hinge_softness]): - if hinge_softness is not None: - self._hinge_softness = ops.convert_to_tensor( - hinge_softness, name="hinge_softness") - else: - self._hinge_softness = None - if validate_args: - nonzero_check = check_ops.assert_none_equal( - ops.convert_to_tensor( - 0, dtype=self.hinge_softness.dtype), - self.hinge_softness, - message="hinge_softness must be non-zero") - self._hinge_softness = control_flow_ops.with_dependencies( - [nonzero_check], self.hinge_softness) - - super(Softplus, self).__init__( - event_ndims=event_ndims, - validate_args=validate_args, - name=name) - - def _forward(self, x): - if self.hinge_softness is None: - return nn_ops.softplus(x) - hinge_softness = math_ops.cast(self.hinge_softness, x.dtype) - return hinge_softness * nn_ops.softplus(x / hinge_softness) - - def _inverse(self, y): - if self.hinge_softness is None: - return distribution_util.softplus_inverse(y) - hinge_softness = math_ops.cast(self.hinge_softness, y.dtype) - return hinge_softness * distribution_util.softplus_inverse( - y / hinge_softness) - - def _inverse_log_det_jacobian(self, y): - # Could also do: - # ildj = math_ops.reduce_sum(y - distribution_util.softplus_inverse(y), - # axis=event_dims) - # but the following is more numerically stable. 
Ie, - # Y = Log[1 + exp{X}] ==> X = Log[exp{Y} - 1] - # ==> dX/dY = exp{Y} / (exp{Y} - 1) - # = 1 / (1 - exp{-Y}), - # which is the most stable for large Y > 0. For small Y, we use - # 1 - exp{-Y} approx Y. - if self.hinge_softness is not None: - y /= math_ops.cast(self.hinge_softness, y.dtype) - return -math_ops.reduce_sum(math_ops.log(-math_ops.expm1(-y)), - axis=self._event_dims_tensor(y)) - - def _forward_log_det_jacobian(self, x): - if self.hinge_softness is not None: - x /= math_ops.cast(self.hinge_softness, x.dtype) - return -math_ops.reduce_sum(nn_ops.softplus(-x), - axis=self._event_dims_tensor(x)) - - @property - def hinge_softness(self): - return self._hinge_softness diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py b/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py index d439f28884d8bd7f2b808317e10c5b5e44bfcfa2..00520bcda85e9527767e6342bf75f10667c264a8 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/weibull.py @@ -18,12 +18,132 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.distributions.python.ops.bijectors.weibull_impl import * -# pylint: enable=wildcard-import -from tensorflow.python.util.all_util import remove_undocumented +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector -_allowed_symbols = ["Weibull"] -remove_undocumented(__name__, _allowed_symbols) +__all__ = [ + "Weibull", +] + + +class Weibull(bijector.Bijector): + """Compute `Y = g(X) = 1 - exp((-X / scale) ** concentration), X >= 0`. 
+ + This bijector maps inputs from `[0, inf]` to [0, 1]`. The inverse of the + bijector applied to a uniform random variable `X ~ U(0, 1) gives back a + random variable with the + [Weibull distribution](https://en.wikipedia.org/wiki/Weibull_distribution): + + ```none + Y ~ Weibull(scale, concentration) + pdf(y; scale, concentration, y >= 0) = (scale / concentration) * ( + scale / concentration) ** (concentration - 1) * exp( + -(y / scale) ** concentration) + ``` + """ + + def __init__(self, + scale=1., + concentration=1., + event_ndims=0, + validate_args=False, + name="weibull"): + """Instantiates the `Weibull` bijector. + + Args: + scale: Positive Float-type `Tensor` that is the same dtype and is + broadcastable with `concentration`. + This is `l` in `Y = g(X) = 1 - exp((-x / l) ** k)`. + concentration: Positive Float-type `Tensor` that is the same dtype and is + broadcastable with `scale`. + This is `k` in `Y = g(X) = 1 - exp((-x / l) ** k)`. + event_ndims: Python scalar indicating the number of dimensions associated + with a particular draw from the distribution. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. 
+ """ + self._graph_parents = [] + self._name = name + self._validate_args = validate_args + with self._name_scope("init", values=[scale, concentration]): + self._scale = ops.convert_to_tensor(scale, name="scale") + self._concentration = ops.convert_to_tensor( + concentration, name="concentration") + check_ops.assert_same_float_dtype([self._scale, self._concentration]) + if validate_args: + self._scale = control_flow_ops.with_dependencies([ + check_ops.assert_positive( + self._scale, + message="Argument scale was not positive") + ], self._scale) + self._concentration = control_flow_ops.with_dependencies([ + check_ops.assert_positive( + self._concentration, + message="Argument concentration was not positive") + ], self._concentration) + + super(Weibull, self).__init__( + event_ndims=event_ndims, + validate_args=validate_args, + name=name) + + @property + def scale(self): + """The `l` in `Y = g(X) = 1 - exp((-x / l) ** k)`.""" + return self._scale + + @property + def concentration(self): + """The `k` in `Y = g(X) = 1 - exp((-x / l) ** k)`.""" + return self._concentration + + def _forward(self, x): + x = self._maybe_assert_valid_x(x) + return -math_ops.expm1(-((x / self.scale) ** self.concentration)) + + def _inverse(self, y): + y = self._maybe_assert_valid_y(y) + return self.scale * (-math_ops.log1p(-y)) ** (1 / self.concentration) + + def _inverse_log_det_jacobian(self, y): + y = self._maybe_assert_valid_y(y) + event_dims = self._event_dims_tensor(y) + return math_ops.reduce_sum( + -math_ops.log1p(-y) + + (1 / self.concentration - 1) * math_ops.log(-math_ops.log1p(-y)) + + math_ops.log(self.scale / self.concentration), + axis=event_dims) + + def _forward_log_det_jacobian(self, x): + x = self._maybe_assert_valid_x(x) + event_dims = self._event_dims_tensor(x) + return math_ops.reduce_sum( + -(x / self.scale) ** self.concentration + + (self.concentration - 1) * math_ops.log(x) + + math_ops.log(self.concentration) + + -self.concentration * math_ops.log(self.scale), + 
axis=event_dims) + + def _maybe_assert_valid_x(self, x): + if not self.validate_args: + return x + is_valid = check_ops.assert_non_negative( + x, + message="Forward transformation input must be at least {}.".format(0)) + return control_flow_ops.with_dependencies([is_valid], x) + + def _maybe_assert_valid_y(self, y): + if not self.validate_args: + return y + is_positive = check_ops.assert_non_negative( + y, message="Inverse transformation input must be greater than 0.") + less_than_one = check_ops.assert_less_equal( + y, constant_op.constant(1., y.dtype), + message="Inverse transformation input must be less than or equal to 1.") + return control_flow_ops.with_dependencies([is_positive, less_than_one], y) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/weibull_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/weibull_impl.py deleted file mode 100644 index 00520bcda85e9527767e6342bf75f10667c264a8..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/distributions/python/ops/bijectors/weibull_impl.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Weibull bijector.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import bijector - - -__all__ = [ - "Weibull", -] - - -class Weibull(bijector.Bijector): - """Compute `Y = g(X) = 1 - exp((-X / scale) ** concentration), X >= 0`. - - This bijector maps inputs from `[0, inf]` to [0, 1]`. The inverse of the - bijector applied to a uniform random variable `X ~ U(0, 1) gives back a - random variable with the - [Weibull distribution](https://en.wikipedia.org/wiki/Weibull_distribution): - - ```none - Y ~ Weibull(scale, concentration) - pdf(y; scale, concentration, y >= 0) = (scale / concentration) * ( - scale / concentration) ** (concentration - 1) * exp( - -(y / scale) ** concentration) - ``` - """ - - def __init__(self, - scale=1., - concentration=1., - event_ndims=0, - validate_args=False, - name="weibull"): - """Instantiates the `Weibull` bijector. - - Args: - scale: Positive Float-type `Tensor` that is the same dtype and is - broadcastable with `concentration`. - This is `l` in `Y = g(X) = 1 - exp((-x / l) ** k)`. - concentration: Positive Float-type `Tensor` that is the same dtype and is - broadcastable with `scale`. - This is `k` in `Y = g(X) = 1 - exp((-x / l) ** k)`. - event_ndims: Python scalar indicating the number of dimensions associated - with a particular draw from the distribution. - validate_args: Python `bool` indicating whether arguments should be - checked for correctness. - name: Python `str` name given to ops managed by this object. 
- """ - self._graph_parents = [] - self._name = name - self._validate_args = validate_args - with self._name_scope("init", values=[scale, concentration]): - self._scale = ops.convert_to_tensor(scale, name="scale") - self._concentration = ops.convert_to_tensor( - concentration, name="concentration") - check_ops.assert_same_float_dtype([self._scale, self._concentration]) - if validate_args: - self._scale = control_flow_ops.with_dependencies([ - check_ops.assert_positive( - self._scale, - message="Argument scale was not positive") - ], self._scale) - self._concentration = control_flow_ops.with_dependencies([ - check_ops.assert_positive( - self._concentration, - message="Argument concentration was not positive") - ], self._concentration) - - super(Weibull, self).__init__( - event_ndims=event_ndims, - validate_args=validate_args, - name=name) - - @property - def scale(self): - """The `l` in `Y = g(X) = 1 - exp((-x / l) ** k)`.""" - return self._scale - - @property - def concentration(self): - """The `k` in `Y = g(X) = 1 - exp((-x / l) ** k)`.""" - return self._concentration - - def _forward(self, x): - x = self._maybe_assert_valid_x(x) - return -math_ops.expm1(-((x / self.scale) ** self.concentration)) - - def _inverse(self, y): - y = self._maybe_assert_valid_y(y) - return self.scale * (-math_ops.log1p(-y)) ** (1 / self.concentration) - - def _inverse_log_det_jacobian(self, y): - y = self._maybe_assert_valid_y(y) - event_dims = self._event_dims_tensor(y) - return math_ops.reduce_sum( - -math_ops.log1p(-y) + - (1 / self.concentration - 1) * math_ops.log(-math_ops.log1p(-y)) + - math_ops.log(self.scale / self.concentration), - axis=event_dims) - - def _forward_log_det_jacobian(self, x): - x = self._maybe_assert_valid_x(x) - event_dims = self._event_dims_tensor(x) - return math_ops.reduce_sum( - -(x / self.scale) ** self.concentration + - (self.concentration - 1) * math_ops.log(x) + - math_ops.log(self.concentration) + - -self.concentration * math_ops.log(self.scale), - 
axis=event_dims) - - def _maybe_assert_valid_x(self, x): - if not self.validate_args: - return x - is_valid = check_ops.assert_non_negative( - x, - message="Forward transformation input must be at least {}.".format(0)) - return control_flow_ops.with_dependencies([is_valid], x) - - def _maybe_assert_valid_y(self, y): - if not self.validate_args: - return y - is_positive = check_ops.assert_non_negative( - y, message="Inverse transformation input must be greater than 0.") - less_than_one = check_ops.assert_less_equal( - y, constant_op.constant(1., y.dtype), - message="Inverse transformation input must be less than or equal to 1.") - return control_flow_ops.with_dependencies([is_positive, less_than_one], y) diff --git a/tensorflow/contrib/distributions/python/ops/cauchy.py b/tensorflow/contrib/distributions/python/ops/cauchy.py index 8d59c1abfbc607c67b2bbca21f880743a43e5b2a..6f5d724a2a945ed8f9c159d8314327c6f994d1db 100644 --- a/tensorflow/contrib/distributions/python/ops/cauchy.py +++ b/tensorflow/contrib/distributions/python/ops/cauchy.py @@ -43,16 +43,17 @@ class Cauchy(distribution.Distribution): The probability density function (pdf) is, ```none - pdf(x; loc, scale) = 1 / (pi * scale * (1 + ((x - loc) / scale)**2)) + pdf(x; loc, scale) = 1 / (pi scale (1 + z**2)) + z = (x - loc) / scale ``` where `loc` is the location, and `scale` is the scale. The Cauchy distribution is a member of the [location-scale family]( https://en.wikipedia.org/wiki/Location-scale_family), i.e. + `Y ~ Cauchy(loc, scale)` is equivalent to, ```none X ~ Cauchy(loc=0, scale=1) - Y ~ Cauchy(loc=loc, scale=scale) Y = loc + scale * X ``` @@ -61,14 +62,16 @@ class Cauchy(distribution.Distribution): Examples of initialization of one or a batch of distributions. ```python + tfd = tf.contrib.distributions + # Define a single scalar Cauchy distribution. - dist = Cauchy(loc=0., scale=3.) + dist = tfd.Cauchy(loc=0., scale=3.) # Evaluate the cdf at 1, returning a scalar. dist.cdf(1.) 
# Define a batch of two scalar valued Cauchy distributions. - dist = Cauchy(loc=[1, 2.], scale=[11, 22.]) + dist = tfd.Cauchy(loc=[1, 2.], scale=[11, 22.]) # Evaluate the pdf of the first distribution on 0, and the second on 1.5, # returning a length two tensor. @@ -76,18 +79,17 @@ class Cauchy(distribution.Distribution): # Get 3 samples, returning a 3 x 2 tensor. dist.sample([3]) - ``` - - Arguments are broadcast when possible. - ```python + # Arguments are broadcast when possible. # Define a batch of two scalar valued Cauchy distributions. # Both have median 1, but different scales. - dist = tf.contrib.distributions.Cauchy(loc=1., scale=[11, 22.]) + dist = tfd.Cauchy(loc=1., scale=[11, 22.]) + # Evaluate the pdf of both distributions on the same point, 3.0, # returning a length 2 tensor. - dist.prob(3.0) + dist.prob(3.) ``` + """ def __init__(self, diff --git a/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py b/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py index 599c855cda434d9249187d5d154d50a8a8c49a6c..1d4c5660d8d73b7b6a7e758fc834ccfddeb5c8ea 100644 --- a/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py +++ b/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py @@ -121,7 +121,7 @@ class ConditionalTransformedDistribution( log_prob = self.distribution.log_prob(x, **distribution_kwargs) if self._is_maybe_event_override: log_prob = math_ops.reduce_sum(log_prob, self._reduce_event_indices) - return ildj + log_prob + return math_ops.cast(ildj, log_prob.dtype) + log_prob @distribution_util.AppendDocstring(kwargs_dict=_condition_kwargs_dict) def _prob(self, y, bijector_kwargs=None, distribution_kwargs=None): @@ -143,7 +143,7 @@ class ConditionalTransformedDistribution( prob = self.distribution.prob(x, **distribution_kwargs) if self._is_maybe_event_override: prob = math_ops.reduce_prod(prob, self._reduce_event_indices) - return 
math_ops.exp(ildj) * prob + return math_ops.exp(math_ops.cast(ildj, prob.dtype)) * prob @distribution_util.AppendDocstring(kwargs_dict=_condition_kwargs_dict) def _log_cdf(self, y, bijector_kwargs=None, distribution_kwargs=None): diff --git a/tensorflow/contrib/distributions/python/ops/deterministic.py b/tensorflow/contrib/distributions/python/ops/deterministic.py index 850d08d1bd69ebc7661557d648e2bffe77e6a908..8049522e9f5dc26b244b7e710a9ae8b981efd6b6 100644 --- a/tensorflow/contrib/distributions/python/ops/deterministic.py +++ b/tensorflow/contrib/distributions/python/ops/deterministic.py @@ -290,8 +290,10 @@ class VectorDeterministic(_BaseDeterministic): #### Examples ```python + tfd = tf.contrib.distributions + # Initialize a single VectorDeterministic supported at [0., 2.] in R^2. - constant = tf.contrib.distributions.Deterministic([0., 2.]) + constant = tfd.VectorDeterministic([0., 2.]) constant.prob([0., 2.]) ==> 1. constant.prob([0., 3.]) @@ -299,7 +301,7 @@ class VectorDeterministic(_BaseDeterministic): # Initialize a [3] batch of constants on R^2. loc = [[0., 1.], [2., 3.], [4., 5.]] - constant = constant_lib.VectorDeterministic(loc) + constant = tfd.VectorDeterministic(loc) constant.prob([[0., 1.], [1.9, 3.], [3.99, 5.]]) ==> [1., 0., 0.] 
``` diff --git a/tensorflow/contrib/distributions/python/ops/distribution_util.py b/tensorflow/contrib/distributions/python/ops/distribution_util.py index 869b5698e57d199755ce1686a74a1eafe3b73e7d..289e1d50e1146a641c0cc433ece3465aed73b1c2 100644 --- a/tensorflow/contrib/distributions/python/ops/distribution_util.py +++ b/tensorflow/contrib/distributions/python/ops/distribution_util.py @@ -21,7 +21,6 @@ from __future__ import print_function from tensorflow.contrib import linalg from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops @@ -330,54 +329,14 @@ def shapes_from_loc_and_scale(loc, scale, name="shapes_from_loc_and_scale"): else: loc_batch_shape = ops.convert_to_tensor(loc_batch_shape, name="loc_batch_shape") + # This is defined in the core util module. + # pylint: disable=undefined-variable batch_shape = prefer_static_broadcast_shape(batch_shape, loc_batch_shape) + # pylint: enable=undefined-variable return batch_shape, event_shape -def prefer_static_broadcast_shape( - shape1, shape2, name="prefer_static_broadcast_shape"): - """Convenience function which statically broadcasts shape when possible. - - Args: - shape1: `1-D` integer `Tensor`. Already converted to tensor! - shape2: `1-D` integer `Tensor`. Already converted to tensor! - name: A string name to prepend to created ops. - - Returns: - The broadcast shape, either as `TensorShape` (if broadcast can be done - statically), or as a `Tensor`. 
- """ - with ops.name_scope(name, values=[shape1, shape2]): - def make_shape_tensor(x): - return ops.convert_to_tensor(x, name="shape", dtype=dtypes.int32) - - def get_tensor_shape(s): - if isinstance(s, tensor_shape.TensorShape): - return s - s_ = tensor_util.constant_value(make_shape_tensor(s)) - if s_ is not None: - return tensor_shape.TensorShape(s_) - return None - - def get_shape_tensor(s): - if not isinstance(s, tensor_shape.TensorShape): - return make_shape_tensor(s) - if s.is_fully_defined(): - return make_shape_tensor(s.as_list()) - raise ValueError("Cannot broadcast from partially " - "defined `TensorShape`.") - - shape1_ = get_tensor_shape(shape1) - shape2_ = get_tensor_shape(shape2) - if shape1_ is not None and shape2_ is not None: - return array_ops.broadcast_static_shape(shape1_, shape2_) - - shape1_ = get_shape_tensor(shape1) - shape2_ = get_shape_tensor(shape2) - return array_ops.broadcast_dynamic_shape(shape1_, shape2_) - - def get_broadcast_shape(*tensors): """Get broadcast shape as a Python list of integers (preferred) or `Tensor`. @@ -484,6 +443,44 @@ def maybe_check_scalar_distribution( return assertions +def pad_mixture_dimensions(x, mixture_distribution, categorical_distribution, + event_ndims): + """Pad dimensions of event tensors for mixture distributions. + + See `Mixture._sample_n` and `MixtureSameFamily._sample_n` for usage examples. + + Args: + x: event tensor to pad. + mixture_distribution: Base distribution of the mixture. + categorical_distribution: `Categorical` distribution that mixes the base + distribution. + event_ndims: Integer specifying the number of event dimensions in the event + tensor. + + Returns: + A padded version of `x` that can broadcast with `categorical_distribution`. 
+ """ + with ops.name_scope("pad_mix_dims", values=[x]): + def _get_ndims(d): + if d.batch_shape.ndims is not None: + return d.batch_shape.ndims + return array_ops.shape(d.batch_shape_tensor())[0] + dist_batch_ndims = _get_ndims(mixture_distribution) + cat_batch_ndims = _get_ndims(categorical_distribution) + pad_ndims = array_ops.where( + categorical_distribution.is_scalar_batch(), + dist_batch_ndims, + dist_batch_ndims - cat_batch_ndims) + s = array_ops.shape(x) + x = array_ops.reshape(x, shape=array_ops.concat([ + s[:-1], + array_ops.ones([pad_ndims], dtype=dtypes.int32), + s[-1:], + array_ops.ones([event_ndims], dtype=dtypes.int32), + ], axis=0)) + return x + + def static_value(x): """Returns the static value of a `Tensor` or `None`.""" return tensor_util.constant_value(ops.convert_to_tensor(x)) diff --git a/tensorflow/contrib/distributions/python/ops/gumbel.py b/tensorflow/contrib/distributions/python/ops/gumbel.py index ba8d3c639b397422f0f6210ba9f48650f0da1e3e..d0efaefb8e78ddf4436e9e5a112d2c1cdddaf3b5 100644 --- a/tensorflow/contrib/distributions/python/ops/gumbel.py +++ b/tensorflow/contrib/distributions/python/ops/gumbel.py @@ -62,15 +62,17 @@ class _Gumbel(distribution.Distribution): Examples of initialization of one or a batch of distributions. ```python + tfd = tf.contrib.distributions + # Define a single scalar Gumbel distribution. - dist = tf.contrib.distributions.Gumbel(loc=0., scale=3.) + dist = tfd.Gumbel(loc=0., scale=3.) # Evaluate the cdf at 1, returning a scalar. dist.cdf(1.) # Define a batch of two scalar valued Gumbels. # The first has mean 1 and scale 11, the second 2 and 22. - dist = tf.contrib.distributions.Gumbel(loc=[1, 2.], scale=[11, 22.]) + dist = tfd.Gumbel(loc=[1, 2.], scale=[11, 22.]) # Evaluate the pdf of the first distribution on 0, and the second on 1.5, # returning a length two tensor. @@ -85,7 +87,7 @@ class _Gumbel(distribution.Distribution): ```python # Define a batch of two scalar valued Logistics. 
# Both have mean 1, but different scales. - dist = tf.contrib.distributions.Gumbel(loc=1., scale=[11, 22.]) + dist = tfd.Gumbel(loc=1., scale=[11, 22.]) # Evaluate the pdf of both distributions on the same point, 3.0, # returning a length 2 tensor. diff --git a/tensorflow/contrib/distributions/python/ops/half_normal.py b/tensorflow/contrib/distributions/python/ops/half_normal.py index 12059b6a9e199dc3ae00ac47a62ece9c9a147000..fc0751a6e0b78cb3d79bd3478e740bb05cd26428 100644 --- a/tensorflow/contrib/distributions/python/ops/half_normal.py +++ b/tensorflow/contrib/distributions/python/ops/half_normal.py @@ -84,6 +84,7 @@ class HalfNormal(distribution.Distribution): ``` """ + def __init__(self, scale, validate_args=False, @@ -120,7 +121,7 @@ class HalfNormal(distribution.Distribution): @staticmethod def _param_shapes(sample_shape): - return {'scale': ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)} + return {"scale": ops.convert_to_tensor(sample_shape, dtype=dtypes.int32)} @property def scale(self): diff --git a/tensorflow/contrib/distributions/python/ops/independent.py b/tensorflow/contrib/distributions/python/ops/independent.py index 6a74ca9a0ae1ad30081d21cc15a65be052a99e2a..cbce005013281ff3c58c94d525d5ce7a865d725a 100644 --- a/tensorflow/contrib/distributions/python/ops/independent.py +++ b/tensorflow/contrib/distributions/python/ops/independent.py @@ -68,11 +68,11 @@ class Independent(distribution_lib.Distribution): #### Examples ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions # Make independent distribution from a 2-batch Normal. - ind = ds.Independent( - distribution=ds.Normal(loc=[-1., 1], scale=[0.1, 0.5]), + ind = tfd.Independent( + distribution=tfd.Normal(loc=[-1., 1], scale=[0.1, 0.5]), reinterpreted_batch_ndims=1) # All batch dims have been "absorbed" into event dims. 
@@ -80,8 +80,8 @@ class Independent(distribution_lib.Distribution): ind.event_shape # ==> [2] # Make independent distribution from a 2-batch bivariate Normal. - ind = ds.Independent( - distribution=ds.MultivariateNormalDiag( + ind = tfd.Independent( + distribution=tfd.MultivariateNormalDiag( loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5]), reinterpreted_batch_ndims=1) diff --git a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py index 956dee38a378813434656a28a69c89b6ec1e8b72..ee4d86867d48b20e97757bcec57d452085814b80 100644 --- a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py +++ b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py @@ -88,8 +88,9 @@ class InverseGamma(distribution.Distribution): #### Examples ```python - dist = InverseGamma(concentration=3.0, rate=2.0) - dist2 = InverseGamma(concentration=[3.0, 4.0], rate=[2.0, 3.0]) + tfd = tf.contrib.distributions + dist = tfd.InverseGamma(concentration=3.0, rate=2.0) + dist2 = tfd.InverseGamma(concentration=[3.0, 4.0], rate=[2.0, 3.0]) ``` """ diff --git a/tensorflow/contrib/distributions/python/ops/kumaraswamy.py b/tensorflow/contrib/distributions/python/ops/kumaraswamy.py new file mode 100644 index 0000000000000000000000000000000000000000..74d5d8773cf3e69a52554c87d656fea2835c8354 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/kumaraswamy.py @@ -0,0 +1,258 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The Kumaraswamy distribution class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import special_math_ops +from tensorflow.python.ops.distributions import beta +from tensorflow.python.ops.distributions import distribution +from tensorflow.python.ops.distributions import util as distribution_util +from tensorflow.python.util.tf_export import tf_export + +__all__ = [ + "Kumaraswamy", +] + +_kumaraswamy_sample_note = """Note: `x` must have dtype `self.dtype` and be in +`[0, 1].` It must have a shape compatible with `self.batch_shape()`.""" + + +def _harmonic_number(x): + """Compute the harmonic number from its analytic continuation. + + Derivation from [1] and Euler's constant [2]. + [1] - + https://en.wikipedia.org/wiki/Digamma_function#Relation_to_harmonic_numbers + [2] - https://en.wikipedia.org/wiki/Euler%E2%80%93Mascheroni_constant + + + Args: + x: input float. + + Returns: + z: The analytic continuation of the harmonic number for the input. + + """ + one = array_ops.ones([], dtype=x.dtype) + return math_ops.digamma(x + one) - math_ops.digamma(one) + + +@tf_export("distributions.Kumaraswamy") +class Kumaraswamy(beta.Beta): + """Kumaraswamy distribution. + + The Kumaraswamy distribution is defined over the `(0, 1)` interval using + parameters + `concentration1` (aka "alpha") and `concentration0` (aka "beta"). It has a + shape similar to the Beta distribution, but is reparameterizeable. 
+ + #### Mathematical Details + + The probability density function (pdf) is, + + ```none + pdf(x; alpha, beta) = alpha * beta * x**(alpha - 1) * (1 - x**alpha)**(beta - + 1) + ``` + + where: + + * `concentration1 = alpha`, + * `concentration0 = beta`, + + Distribution parameters are automatically broadcast in all functions; see + examples for details. + + #### Examples + + ```python + # Create a batch of three Kumaraswamy distributions. + alpha = [1, 2, 3] + beta = [1, 2, 3] + dist = Kumaraswamy(alpha, beta) + + dist.sample([4, 5]) # Shape [4, 5, 3] + + # `x` has three batch entries, each with two samples. + x = [[.1, .4, .5], + [.2, .3, .5]] + # Calculate the probability of each pair of samples under the corresponding + # distribution in `dist`. + dist.prob(x) # Shape [2, 3] + ``` + + ```python + # Create batch_shape=[2, 3] via parameter broadcast: + alpha = [[1.], [2]] # Shape [2, 1] + beta = [3., 4, 5] # Shape [3] + dist = Kumaraswamy(alpha, beta) + + # alpha broadcast as: [[1., 1, 1,], + # [2, 2, 2]] + # beta broadcast as: [[3., 4, 5], + # [3, 4, 5]] + # batch_Shape [2, 3] + dist.sample([4, 5]) # Shape [4, 5, 2, 3] + + x = [.2, .3, .5] + # x will be broadcast as [[.2, .3, .5], + # [.2, .3, .5]], + # thus matching batch_shape [2, 3]. + dist.prob(x) # Shape [2, 3] + ``` + + """ + + def __init__(self, + concentration1=None, + concentration0=None, + validate_args=False, + allow_nan_stats=True, + name="Kumaraswamy"): + """Initialize a batch of Kumaraswamy distributions. + + Args: + concentration1: Positive floating-point `Tensor` indicating mean + number of successes; aka "alpha". Implies `self.dtype` and + `self.batch_shape`, i.e., + `concentration1.shape = [N1, N2, ..., Nm] = self.batch_shape`. + concentration0: Positive floating-point `Tensor` indicating mean + number of failures; aka "beta". Otherwise has same semantics as + `concentration1`. + validate_args: Python `bool`, default `False`. 
When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + allow_nan_stats: Python `bool`, default `True`. When `True`, statistics + (e.g., mean, mode, variance) use the value "`NaN`" to indicate the + result is undefined. When `False`, an exception is raised if one or + more of the statistic's batch members are undefined. + name: Python `str` name prefixed to Ops created by this class. + """ + super(Kumaraswamy, self).__init__( + concentration1=concentration1, + concentration0=concentration0, + validate_args=validate_args, + allow_nan_stats=allow_nan_stats, + name=name) + self._reparameterization_type = distribution.FULLY_REPARAMETERIZED + + def _sample_n(self, n, seed=None): + expanded_concentration1 = array_ops.ones_like( + self.total_concentration, dtype=self.dtype) * self.concentration1 + expanded_concentration0 = array_ops.ones_like( + self.total_concentration, dtype=self.dtype) * self.concentration0 + shape = array_ops.concat([[n], self.batch_shape_tensor()], 0) + uniform_sample = random_ops.random_uniform( + shape=shape, minval=0.0, maxval=1.0, dtype=self.dtype, seed=seed) + + kumaraswamy_sample = (1 - uniform_sample**(1. / expanded_concentration0))**( + 1. 
/ expanded_concentration1) + return kumaraswamy_sample + + @distribution_util.AppendDocstring(_kumaraswamy_sample_note) + def _log_cdf(self, x): + a = self.concentration1 + b = self.concentration0 + return math_ops.log1p(-(1 - x**a)**b) + + @distribution_util.AppendDocstring(_kumaraswamy_sample_note) + def _cdf(self, x): + a = self.concentration1 + b = self.concentration0 + return 1 - (1 - x**a)**b + + def _survival_function(self, x): + a = self.concentration1 + b = self.concentration0 + return (1 - x**a)**b + + def _log_survival_function(self, x): + a = self.concentration1 + b = self.concentration0 + return b * math_ops.log1p(-x**a) + + def _log_unnormalized_prob(self, x): + x = self._maybe_assert_valid_sample(x) + a = self.concentration1 + b = self.concentration0 + return (a - 1) * math_ops.log(x) + (b - 1) * math_ops.log1p(-x**a) + + def _log_normalization(self): + a = self.concentration1 + b = self.concentration0 + return -(math_ops.log(a) + math_ops.log(b)) + + def _entropy(self): + a = self.concentration1 + b = self.concentration0 + return (1 - 1. / b) + ( + 1 - 1. / a) * _harmonic_number(b) - math_ops.log(a) - math_ops.log(b) + + def _moment(self, n): + """Compute the n'th (uncentered) moment.""" + expanded_concentration1 = array_ops.ones_like( + self.total_concentration, dtype=self.dtype) * self.concentration1 + expanded_concentration0 = array_ops.ones_like( + self.total_concentration, dtype=self.dtype) * self.concentration0 + beta_arg0 = 1 + n / expanded_concentration1 + beta_arg = array_ops.stack([beta_arg0, expanded_concentration0], -1) + log_moment = math_ops.log(expanded_concentration0) + special_math_ops.lbeta( + beta_arg) + return math_ops.exp(log_moment) + + def _mean(self): + return self._moment(1) + + def _variance(self): + # TODO(b/72696533): Investigate a more numerically stable version.
+ return self._moment(2) - math_ops.square(self._moment(1)) + + @distribution_util.AppendDocstring( + """Note: The mode is undefined when `concentration1 <= 1` or + `concentration0 <= 1`. If `self.allow_nan_stats` is `True`, `NaN` + is used for undefined modes. If `self.allow_nan_stats` is `False` an + exception is raised when one or more modes are undefined.""") + def _mode(self): + a = self.concentration1 + b = self.concentration0 + mode = ((a - 1) / (a * b - 1))**(1. / a) + if self.allow_nan_stats: + nan = array_ops.fill( + self.batch_shape_tensor(), + np.array(np.nan, dtype=self.dtype.as_numpy_dtype), + name="nan") + is_defined = (self.concentration1 > 1.) & (self.concentration0 > 1.) + return array_ops.where(is_defined, mode, nan) + return control_flow_ops.with_dependencies([ + check_ops.assert_less( + array_ops.ones([], dtype=self.dtype), + self.concentration1, + message="Mode undefined for concentration1 <= 1."), + check_ops.assert_less( + array_ops.ones([], dtype=self.dtype), + self.concentration0, + message="Mode undefined for concentration0 <= 1.") + ], mode) diff --git a/tensorflow/contrib/distributions/python/ops/logistic.py b/tensorflow/contrib/distributions/python/ops/logistic.py index 48794a48828fe796e233e968d8c755136ce166ad..473677f8d91b184e029f345bb05f5c5d63df7a40 100644 --- a/tensorflow/contrib/distributions/python/ops/logistic.py +++ b/tensorflow/contrib/distributions/python/ops/logistic.py @@ -60,15 +60,17 @@ class Logistic(distribution.Distribution): Examples of initialization of one or a batch of distributions. ```python + tfd = tf.contrib.distributions + # Define a single scalar Logistic distribution. - dist = tf.contrib.distributions.Logistic(loc=0., scale=3.) + dist = tfd.Logistic(loc=0., scale=3.) # Evaluate the cdf at 1, returning a scalar. dist.cdf(1.) # Define a batch of two scalar valued Logistics. # The first has mean 1 and scale 11, the second 2 and 22. 
- dist = tf.contrib.distributions.Logistic(loc=[1, 2.], scale=[11, 22.]) + dist = tfd.Logistic(loc=[1, 2.], scale=[11, 22.]) # Evaluate the pdf of the first distribution on 0, and the second on 1.5, # returning a length two tensor. @@ -76,14 +78,11 @@ class Logistic(distribution.Distribution): # Get 3 samples, returning a 3 x 2 tensor. dist.sample([3]) - ``` - Arguments are broadcast when possible. - - ```python + # Arguments are broadcast when possible. # Define a batch of two scalar valued Logistics. # Both have mean 1, but different scales. - dist = tf.contrib.distributions.Logistic(loc=1., scale=[11, 22.]) + dist = tfd.Logistic(loc=1., scale=[11, 22.]) # Evaluate the pdf of both distributions on the same point, 3.0, # returning a length 2 tensor. diff --git a/tensorflow/contrib/distributions/python/ops/mixture.py b/tensorflow/contrib/distributions/python/ops/mixture.py index e676931d9145e72907d990148ee2d180e0da0258..cef6a143fc615901315a3780bf4ed53b8c7cd177 100644 --- a/tensorflow/contrib/distributions/python/ops/mixture.py +++ b/tensorflow/contrib/distributions/python/ops/mixture.py @@ -49,13 +49,13 @@ class Mixture(distribution.Distribution): ```python # Create a mixture of two Gaussians: - ds = tf.contrib.distributions + tfd = tf.contrib.distributions mix = 0.3 - bimix_gauss = ds.Mixture( - cat=ds.Categorical(probs=[mix, 1.-mix]), + bimix_gauss = tfd.Mixture( + cat=tfd.Categorical(probs=[mix, 1.-mix]), components=[ - ds.Normal(loc=-1., scale=0.1), - ds.Normal(loc=+1., scale=0.5), + tfd.Normal(loc=-1., scale=0.1), + tfd.Normal(loc=+1., scale=0.5), ]) # Plot the PDF. @@ -71,6 +71,7 @@ class Mixture(distribution.Distribution): components, validate_args=False, allow_nan_stats=True, + use_static_graph=False, name="Mixture"): """Initialize a Mixture distribution. @@ -96,6 +97,11 @@ class Mixture(distribution.Distribution): exception if a statistic (e.g. mean/mode/etc...) is undefined for any batch member. 
If `True`, batch members with valid parameters leading to undefined statistics will return NaN for this statistic. + use_static_graph: Calls to `sample` will not rely on dynamic tensor + indexing, allowing for some static graph compilation optimizations, but + at the expense of sampling all underlying distributions in the mixture. + (Possibly useful when running on TPUs). + Default value: `False` (i.e., use dynamic indexing). name: A name for this distribution (optional). Raises: @@ -178,6 +184,10 @@ class Mixture(distribution.Distribution): self._static_event_shape = static_event_shape self._static_batch_shape = static_batch_shape + self._use_static_graph = use_static_graph + if use_static_graph and static_num_components is None: + raise ValueError("Number of categories must be known statically when " + "`use_static_graph=True`.") # We let the Mixture distribution access _graph_parents since its arguably # more like a baseclass. graph_parents = self._cat._graph_parents # pylint: disable=protected-access @@ -292,6 +302,31 @@ class Mixture(distribution.Distribution): return mixture_log_cdf def _sample_n(self, n, seed=None): + if self._use_static_graph: + # This sampling approach is almost the same as the approach used by + # `MixtureSameFamily`. The differences are due to having a list of + # `Distribution` objects rather than a single object, and maintaining + # random seed management that is consistent with the non-static code path. 
+ samples = [] + cat_samples = self.cat.sample(n, seed=seed) + for c in range(self.num_components): + seed = distribution_util.gen_new_seed(seed, "mixture") + samples.append(self.components[c].sample(n, seed=seed)) + x = array_ops.stack( + samples, -self._static_event_shape.ndims - 1) # [n, B, k, E] + npdt = x.dtype.as_numpy_dtype + mask = array_ops.one_hot( + indices=cat_samples, # [n, B] + depth=self._num_components, # == k + on_value=np.ones([], dtype=npdt), + off_value=np.zeros([], dtype=npdt)) # [n, B, k] + mask = distribution_utils.pad_mixture_dimensions( + mask, self, self._cat, + self._static_event_shape.ndims) # [n, B, k, [1]*e] + return math_ops.reduce_sum( + x * mask, + axis=-1 - self._static_event_shape.ndims) # [n, B, E] + with ops.control_dependencies(self._assertions): n = ops.convert_to_tensor(n, name="n") static_n = tensor_util.constant_value(n) diff --git a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py index 5558ef0f255db684b229d129666634e50c625887..b93bdc5ab4010663baddda1410b302644853648b 100644 --- a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py +++ b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.python.framework import dtypes +from tensorflow.contrib.distributions.python.ops import distribution_util as distribution_utils from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -43,15 +43,14 @@ class MixtureSameFamily(distribution.Distribution): #### Examples ```python - import matplotlib.pyplot as plt - ds = tf.contrib.distributions + tfd = tf.contrib.distributions ### Create a mixture of two scalar Gaussians: - gm = ds.MixtureSameFamily( - mixture_distribution=ds.Categorical( + gm = tfd.MixtureSameFamily( + 
mixture_distribution=tfd.Categorical( probs=[0.3, 0.7]), - components_distribution=ds.Normal( + components_distribution=tfd.Normal( loc=[-1., 1], # One for each component. scale=[0.1, 0.5])) # And same here. @@ -63,14 +62,15 @@ class MixtureSameFamily(distribution.Distribution): # Plot PDF. x = np.linspace(-2., 3., int(1e4), dtype=np.float32) + import matplotlib.pyplot as plt plt.plot(x, gm.prob(x).eval()); ### Create a mixture of two Bivariate Gaussians: - gm = ds.MixtureSameFamily( - mixture_distribution=ds.Categorical( + gm = tfd.MixtureSameFamily( + mixture_distribution=tfd.Categorical( probs=[0.3, 0.7]), - components_distribution=ds.MultivariateNormalDiag( + components_distribution=tfd.MultivariateNormalDiag( loc=[[-1., 1], # component 1 [1, -1]], # component 2 scale_identity_multiplier=[.3, .6])) @@ -239,7 +239,9 @@ class MixtureSameFamily(distribution.Distribution): depth=self._num_components, # == k on_value=np.ones([], dtype=npdt), off_value=np.zeros([], dtype=npdt)) # [n, B, k] - mask = self._pad_mix_dims(mask) # [n, B, k, [1]*e] + mask = distribution_utils.pad_mixture_dimensions( + mask, self, self.mixture_distribution, + self._event_shape().ndims) # [n, B, k, [1]*e] return math_ops.reduce_sum( x * mask, axis=-1 - self._event_ndims) # [n, B, E] @@ -248,14 +250,15 @@ class MixtureSameFamily(distribution.Distribution): x = self._pad_sample_dims(x) log_prob_x = self.components_distribution.log_prob(x) # [S, B, k] log_mix_prob = nn_ops.log_softmax( - self.mixture_distribution.logits, dim=-1) # [B, k] + self.mixture_distribution.logits, axis=-1) # [B, k] return math_ops.reduce_logsumexp( log_prob_x + log_mix_prob, axis=-1) # [S, B] def _mean(self): with ops.control_dependencies(self._runtime_assertions): - probs = self._pad_mix_dims( - self.mixture_distribution.probs) # [B, k, [1]*e] + probs = distribution_utils.pad_mixture_dimensions( + self.mixture_distribution.probs, self, self.mixture_distribution, + self._event_shape().ndims) # [B, k, [1]*e] return 
math_ops.reduce_sum( probs * self.components_distribution.mean(), axis=-1 - self._event_ndims) # [B, E] @@ -264,15 +267,16 @@ class MixtureSameFamily(distribution.Distribution): x = self._pad_sample_dims(x) log_cdf_x = self.components_distribution.log_cdf(x) # [S, B, k] log_mix_prob = nn_ops.log_softmax( - self.mixture_distribution.logits, dim=-1) # [B, k] + self.mixture_distribution.logits, axis=-1) # [B, k] return math_ops.reduce_logsumexp( log_cdf_x + log_mix_prob, axis=-1) # [S, B] def _variance(self): with ops.control_dependencies(self._runtime_assertions): # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X]) - probs = self._pad_mix_dims( - self.mixture_distribution.probs) # [B, k, [1]*e] + probs = distribution_utils.pad_mixture_dimensions( + self.mixture_distribution.probs, self, self.mixture_distribution, + self._event_shape().ndims) # [B, k, [1]*e] mean_cond_var = math_ops.reduce_sum( probs * self.components_distribution.variance(), axis=-1 - self._event_ndims) # [B, E] @@ -291,8 +295,12 @@ class MixtureSameFamily(distribution.Distribution): with ops.control_dependencies(self._runtime_assertions): # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X]) - probs = self._pad_mix_dims(self._pad_mix_dims( - self.mixture_distribution.probs)) # [B, k, 1, 1] + probs = distribution_utils.pad_mixture_dimensions( + distribution_utils.pad_mixture_dimensions( + self.mixture_distribution.probs, self, self.mixture_distribution, + self._event_shape().ndims), + self, self.mixture_distribution, + self._event_shape().ndims) # [B, k, 1, 1] mean_cond_var = math_ops.reduce_sum( probs * self.components_distribution.covariance(), axis=-3) # [B, e, e] @@ -312,26 +320,6 @@ class MixtureSameFamily(distribution.Distribution): shape[:d], [1], shape[d:]], axis=0)) return x - def _pad_mix_dims(self, x): - with ops.name_scope("pad_mix_dims", values=[x]): - def _get_ndims(d): - if d.batch_shape.ndims is not None: - return d.batch_shape.ndims - return 
array_ops.shape(d.batch_shape_tensor())[0] - dist_batch_ndims = _get_ndims(self) - cat_batch_ndims = _get_ndims(self.mixture_distribution) - bnd = distribution_util.pick_vector( - self.mixture_distribution.is_scalar_batch(), - [dist_batch_ndims], [cat_batch_ndims])[0] - s = array_ops.shape(x) - x = array_ops.reshape(x, shape=array_ops.concat([ - s[:-1], - array_ops.ones([bnd], dtype=dtypes.int32), - s[-1:], - array_ops.ones([self._event_ndims], dtype=dtypes.int32), - ], axis=0)) - return x - def _outer_squared_difference(x, y): """Convenience function analogous to tf.squared_difference.""" diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag.py b/tensorflow/contrib/distributions/python/ops/mvn_diag.py index 163cf75d990d5fe7ec1e3aaf0040fc71f61774a7..e862552880f4073c8fa8e90134d0633e7484b0bf 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_diag.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_diag.py @@ -84,10 +84,10 @@ class MultivariateNormalDiag( #### Examples ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions # Initialize a single 2-variate Gaussian. - mvn = ds.MultivariateNormalDiag( + mvn = tfd.MultivariateNormalDiag( loc=[1., -1], scale_diag=[1, 2.]) @@ -101,7 +101,7 @@ class MultivariateNormalDiag( mvn.prob([-1., 0]).eval() # shape: [] # Initialize a 3-batch, 2-variate scaled-identity Gaussian. - mvn = ds.MultivariateNormalDiag( + mvn = tfd.MultivariateNormalDiag( loc=[1., -1], scale_identity_multiplier=[1, 2., 3]) @@ -119,7 +119,7 @@ class MultivariateNormalDiag( mvn.prob([-1., 0]).eval() # shape: [3] # Initialize a 2-batch of 3-variate Gaussians. 
- mvn = ds.MultivariateNormalDiag( + mvn = tfd.MultivariateNormalDiag( loc=[[1., 2, 3], [11, 22, 33]] # shape: [2, 3] scale_diag=[[1., 2, 3], diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py index 040bc230722194316b8a74627344e315a2578281..413e88f03ae0286c294f3404549a73e1a47dcff7 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py @@ -86,7 +86,7 @@ class MultivariateNormalDiagPlusLowRank( #### Examples ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions # Initialize a single 3-variate Gaussian with covariance `cov = S @ S.T`, # `S = diag(d) + U @ diag(m) @ U.T`. The perturbation, `U @ diag(m) @ U.T`, is @@ -97,7 +97,7 @@ class MultivariateNormalDiagPlusLowRank( [-1, 1], [2, -0.5]] # shape: [3, 2] m = [4., 5] # shape: [2] - mvn = ds.MultivariateNormalDiagPlusLowRank( + mvn = tfd.MultivariateNormalDiagPlusLowRank( loc=mu scale_diag=d scale_perturb_factor=U, @@ -118,7 +118,7 @@ class MultivariateNormalDiagPlusLowRank( m = [[0.1, 0.2], [0.4, 0.5]] # shape: [b, r] = [2, 2] - mvn = ds.MultivariateNormalDiagPlusLowRank( + mvn = tfd.MultivariateNormalDiagPlusLowRank( loc=mu, scale_perturb_factor=U, scale_perturb_diag=m) diff --git a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py index f9952b2069d6dfd2593e6bd71ede0badf44cdf98..4bea99fbb75349f97fde473cb5716fe6c426ce90 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py @@ -73,14 +73,14 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL): #### Examples ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions # Initialize a single 3-variate Gaussian. 
mu = [1., 2, 3] cov = [[ 0.36, 0.12, 0.06], [ 0.12, 0.29, -0.13], [ 0.06, -0.13, 0.26]] - mvn = ds.MultivariateNormalFullCovariance( + mvn = tfd.MultivariateNormalFullCovariance( loc=mu, covariance_matrix=cov) @@ -100,7 +100,7 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL): mu = [[1., 2, 3], [11, 22, 33]] # shape: [2, 3] covariance_matrix = ... # shape: [2, 3, 3], symmetric, positive definite. - mvn = ds.MultivariateNormalFullCovariance( + mvn = tfd.MultivariateNormalFullCovariance( loc=mu, covariance=covariance_matrix) @@ -167,12 +167,11 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL): covariance_matrix = ops.convert_to_tensor( covariance_matrix, name="covariance_matrix") if validate_args: - assert_symmetric = check_ops.assert_equal( - covariance_matrix, - array_ops.matrix_transpose(covariance_matrix), - message="Matrix was not symmetric.") - covariance_matrix = control_flow_ops.with_dependencies( - [assert_symmetric], covariance_matrix) + covariance_matrix = control_flow_ops.with_dependencies([ + check_ops.assert_near( + covariance_matrix, + array_ops.matrix_transpose(covariance_matrix), + message="Matrix was not symmetric")], covariance_matrix) # No need to validate that covariance_matrix is non-singular. # LinearOperatorLowerTriangular has an assert_non_singular method that # is called by the Bijector. diff --git a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py index 300bdd5f6064a1cc9c336689ac4fae04338edb30..a7399792892f4c179c05168184d76ec95c168b51 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py @@ -90,8 +90,7 @@ class MultivariateNormalLinearOperator( #### Examples ```python - ds = tf.contrib.distributions - la = tf.linalg + tfd = tf.contrib.distributions # Initialize a single 3-variate Gaussian. 
mu = [1., 2, 3] @@ -103,9 +102,9 @@ class MultivariateNormalLinearOperator( # [ 0.2, 0.5, 0. ], # [ 0.1, -0.3, 0.4]]) - mvn = ds.MultivariateNormalLinearOperator( + mvn = tfd.MultivariateNormalLinearOperator( loc=mu, - scale=la.LinearOperatorLowerTriangular(scale)) + scale=tf.linalg.LinearOperatorLowerTriangular(scale)) # Covariance agrees with cholesky(cov) parameterization. mvn.covariance().eval() @@ -122,9 +121,9 @@ class MultivariateNormalLinearOperator( scale_diag = [[1., 2, 3], [0.5, 1, 1.5]] # shape: [2, 3] - mvn = ds.MultivariateNormalLinearOperator( + mvn = tfd.MultivariateNormalLinearOperator( loc=mu, - scale=la.LinearOperatorDiag(scale_diag)) + scale=tf.linalg.LinearOperatorDiag(scale_diag)) # Compute the pdf of two `R^3` observations; return a length-2 vector. x = [[-0.9, 0, 0.1], diff --git a/tensorflow/contrib/distributions/python/ops/mvn_tril.py b/tensorflow/contrib/distributions/python/ops/mvn_tril.py index 260dcc18f513d5440d3d39368539274c03faa72a..6c7dc4ca7aaf5b3a20b072e9360d15528ad10556 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_tril.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_tril.py @@ -76,12 +76,13 @@ class MultivariateNormalTriL( ``` Trainable (batch) lower-triangular matrices can be created with - `ds.matrix_diag_transform()` and/or `ds.fill_triangular()` + `tf.contrib.distributions.matrix_diag_transform()` and/or + `tf.contrib.distributions.fill_triangular()` #### Examples ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions # Initialize a single 3-variate Gaussian. mu = [1., 2, 3] @@ -92,7 +93,7 @@ class MultivariateNormalTriL( # ==> [[ 0.6, 0. , 0. ], # [ 0.2, 0.5, 0. ], # [ 0.1, -0.3, 0.4]]) - mvn = ds.MultivariateNormalTriL( + mvn = tfd.MultivariateNormalTriL( loc=mu, scale_tril=scale) @@ -112,7 +113,7 @@ class MultivariateNormalTriL( mu = [[1., 2, 3], [11, 22, 33]] # shape: [2, 3] tril = ... # shape: [2, 3, 3], lower triangular, non-zero diagonal. 
- mvn = ds.MultivariateNormalTriL( + mvn = tfd.MultivariateNormalTriL( loc=mu, scale_tril=tril) @@ -124,9 +125,9 @@ class MultivariateNormalTriL( # Instantiate a "learnable" MVN. dims = 4 with tf.variable_scope("model"): - mvn = ds.MultivariateNormalTriL( + mvn = tfd.MultivariateNormalTriL( loc=tf.get_variable(shape=[dims], dtype=tf.float32, name="mu"), - scale_tril=ds.fill_triangular( + scale_tril=tfd.fill_triangular( tf.get_variable(shape=[dims * (dims + 1) / 2], dtype=tf.float32, name="chol_Sigma"))) ``` diff --git a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py index e1118ed4312ca2ed678a05a298110e2669d0a27e..92f2bba1828696248c9d9460566a08ba372c3358 100644 --- a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py +++ b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py @@ -22,21 +22,135 @@ import numpy as np from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops import poisson as poisson_lib +from tensorflow.contrib.distributions.python.ops.bijectors.exp import Exp +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import categorical as categorical_lib from tensorflow.python.ops.distributions import distribution as distribution_lib +from tensorflow.python.ops.distributions import normal as normal_lib +from tensorflow.python.ops.distributions import transformed_distribution as transformed_lib __all__ = [ "PoissonLogNormalQuadratureCompound", + "quadrature_scheme_lognormal_gauss_hermite", + "quadrature_scheme_lognormal_quantiles", ] +def 
quadrature_scheme_lognormal_gauss_hermite( + loc, scale, quadrature_size, + validate_args=False, name=None): # pylint: disable=unused-argument + """Use Gauss-Hermite quadrature to form quadrature on positive-reals. + + Note: for a given `quadrature_size`, this method is generally less accurate + than `quadrature_scheme_lognormal_quantiles`. + + Args: + loc: `float`-like (batch of) scalar `Tensor`; the location parameter of + the LogNormal prior. + scale: `float`-like (batch of) scalar `Tensor`; the scale parameter of + the LogNormal prior. + quadrature_size: Python `int` scalar representing the number of quadrature + points. + validate_args: Python `bool`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + name: Python `str` name prefixed to Ops created by this class. + + Returns: + grid: (Batch of) length-`quadrature_size` vectors representing the + `log_rate` parameters of a `Poisson`. + probs: (Batch of) length-`quadrature_size` vectors representing the + weight associate with each `grid` value. + """ + with ops.name_scope(name, "vector_diffeomixture_quadrature_gauss_hermite", + [loc, scale]): + grid, probs = np.polynomial.hermite.hermgauss(deg=quadrature_size) + grid = grid.astype(loc.dtype.as_numpy_dtype) + probs = probs.astype(loc.dtype.as_numpy_dtype) + probs /= np.linalg.norm(probs, ord=1, keepdims=True) + probs = ops.convert_to_tensor(probs, name="probs", dtype=loc.dtype) + # The following maps the broadcast of `loc` and `scale` to each grid + # point, i.e., we are creating several log-rates that correspond to the + # different Gauss-Hermite quadrature points and (possible) batches of + # `loc` and `scale`. + grid = (loc[..., array_ops.newaxis] + + np.sqrt(2.) 
* scale[..., array_ops.newaxis] * grid) + return grid, probs + + +def quadrature_scheme_lognormal_quantiles( + loc, scale, quadrature_size, + validate_args=False, name=None): + """Use LogNormal quantiles to form quadrature on positive-reals. + + Args: + loc: `float`-like (batch of) scalar `Tensor`; the location parameter of + the LogNormal prior. + scale: `float`-like (batch of) scalar `Tensor`; the scale parameter of + the LogNormal prior. + quadrature_size: Python `int` scalar representing the number of quadrature + points. + validate_args: Python `bool`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + name: Python `str` name prefixed to Ops created by this class. + + Returns: + grid: (Batch of) length-`quadrature_size` vectors representing the + `log_rate` parameters of a `Poisson`. + probs: (Batch of) length-`quadrature_size` vectors representing the + weight associate with each `grid` value. + """ + with ops.name_scope(name, "quadrature_scheme_lognormal_quantiles", + [loc, scale]): + # Create a LogNormal distribution. + dist = transformed_lib.TransformedDistribution( + distribution=normal_lib.Normal(loc=loc, scale=scale), + bijector=Exp(event_ndims=0), + validate_args=validate_args) + batch_ndims = dist.batch_shape.ndims + if batch_ndims is None: + batch_ndims = array_ops.shape(dist.batch_shape_tensor())[0] + + def _compute_quantiles(): + """Helper to build quantiles.""" + # Omit {0, 1} since they might lead to Inf/NaN. + zero = array_ops.zeros([], dtype=dist.dtype) + edges = math_ops.linspace(zero, 1., quadrature_size + 3)[1:-1] + # Expand edges so its broadcast across batch dims. + edges = array_ops.reshape(edges, shape=array_ops.concat([ + [-1], array_ops.ones([batch_ndims], dtype=dtypes.int32)], axis=0)) + quantiles = dist.quantile(edges) + # Cyclically permute left by one. 
+ perm = array_ops.concat([ + math_ops.range(1, 1 + batch_ndims), [0]], axis=0) + quantiles = array_ops.transpose(quantiles, perm) + return quantiles + quantiles = _compute_quantiles() + + # Compute grid as quantile midpoints. + grid = (quantiles[..., :-1] + quantiles[..., 1:]) / 2. + # Set shape hints. + grid.set_shape(dist.batch_shape.concatenate([quadrature_size])) + + # By construction probs is constant, i.e., `1 / quadrature_size`. This is + # important, because non-constant probs leads to non-reparameterizable + # samples. + probs = array_ops.fill( + dims=[quadrature_size], + value=1. / math_ops.cast(quadrature_size, dist.dtype)) + + return grid, probs + + class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): """`PoissonLogNormalQuadratureCompound` distribution. @@ -47,30 +161,18 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): ```none p(k|loc, scale) = int_{R_+} dl LogNormal(l | loc, scale) Poisson(k | l) - = int_{R} dz ((lambda(z) sqrt(2) scale) - * exp(-z**2) / (lambda(z) sqrt(2 pi) sigma) - * Poisson(k | lambda(z))) - = int_{R} dz exp(-z**2) / sqrt(pi) Poisson(k | lambda(z)) approx= sum{ prob[d] Poisson(k | lambda(grid[d])) : d=0, ..., deg-1 } ``` - where `lambda(z) = exp(sqrt(2) scale z + loc)` and the `prob,grid` terms - are from [numerical quadrature]( - https://en.wikipedia.org/wiki/Numerical_integration) (default: - [Gauss--Hermite quadrature]( - https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)). Note that - the second line made the substitution: - `z(l) = (log(l) - loc) / (sqrt(2) scale)` which implies `lambda(z)` [above] - and `dl = sqrt(2) scale lambda(z) dz` + By default, the `grid` is chosen as quantiles of the `LogNormal` distribution + parameterized by `loc`, `scale` and the `prob` vector is + `[1. / quadrature_size]*quadrature_size`. In the non-approximation case, a draw from the LogNormal prior represents the Poisson rate parameter. 
Unfortunately, the non-approximate distribution lacks an analytical probability density function (pdf). Therefore the `PoissonLogNormalQuadratureCompound` class implements an approximation based - on [numerical quadrature]( - https://en.wikipedia.org/wiki/Numerical_integration) (default: - [Gauss--Hermite quadrature]( - https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)). + on [quadrature](https://en.wikipedia.org/wiki/Numerical_integration). Note: although the `PoissonLogNormalQuadratureCompound` is approximately the Poisson-LogNormal compound distribution, it is itself a valid distribution. @@ -84,10 +186,8 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): https://en.wikipedia.org/wiki/Compound_probability_distribution). Using variable-substitution and [numerical quadrature]( https://en.wikipedia.org/wiki/Numerical_integration) (default: - [Gauss--Hermite quadrature]( - https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)) we can - redefine the distribution to be a parameter-less convex combination of `deg` - different Poisson samples. + based on `LogNormal` quantiles) we can redefine the distribution to be a + parameter-less convex combination of `deg` different Poisson samples. That is, defined over positive integers, this distribution is parameterized by a (batch of) `loc` and `scale` scalars. @@ -96,46 +196,51 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): ```none pdf(k | loc, scale, deg) - = sum{ prob[d] Poisson(k | lambda=exp(sqrt(2) scale grid[d] + loc)) + = sum{ prob[d] Poisson(k | lambda=exp(grid[d])) : d=0, ..., deg-1 } ``` - where, [e.g., `grid, w = numpy.polynomial.hermite.hermgauss(deg)`]( - https://docs.scipy.org/doc/numpy-1.10.0/reference/generated/numpy.polynomial.hermite.hermgauss.html) - and `prob = w / sqrt(pi)`. 
- #### Examples ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions + # Create two batches of PoissonLogNormalQuadratureCompounds, one with # prior `loc = 0.` and another with `loc = 1.` In both cases `scale = 1.` - pln = ds.PoissonLogNormalQuadratureCompound( + pln = tfd.PoissonLogNormalQuadratureCompound( loc=[0., -0.5], scale=1., - quadrature_grid_and_probs=( - np.polynomial.hermite.hermgauss(deg=10)), + quadrature_size=10, validate_args=True) """ def __init__(self, loc, scale, - quadrature_grid_and_probs=None, + quadrature_size=8, + quadrature_fn=quadrature_scheme_lognormal_quantiles, validate_args=False, allow_nan_stats=True, name="PoissonLogNormalQuadratureCompound"): - """Constructs the PoissonLogNormalQuadratureCompound on `R**k`. + """Constructs the PoissonLogNormalQuadratureCompound`. + + Note: `probs` returned by (optional) `quadrature_fn` are presumed to be + either a length-`quadrature_size` vector or a batch of vectors in 1-to-1 + correspondence with the returned `grid`. (I.e., broadcasting is only + partially supported.) Args: loc: `float`-like (batch of) scalar `Tensor`; the location parameter of the LogNormal prior. scale: `float`-like (batch of) scalar `Tensor`; the scale parameter of the LogNormal prior. - quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s - representing the sample points and the corresponding (possibly - normalized) weight. When `None`, defaults to: - `np.polynomial.hermite.hermgauss(deg=8)`. + quadrature_size: Python `int` scalar representing the number of quadrature + points. + quadrature_fn: Python callable taking `loc`, `scale`, + `quadrature_size`, `validate_args` and returning `tuple(grid, probs)` + representing the LogNormal grid and corresponding normalized weight. + normalized) weight. + Default value: `quadrature_scheme_lognormal_quantiles`. validate_args: Python `bool`, default `False`. 
When `True` distribution parameters are checked for validity despite possibly degrading runtime performance. When `False` invalid inputs may silently render incorrect @@ -147,47 +252,41 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): name: Python `str` name prefixed to Ops created by this class. Raises: - TypeError: if `loc.dtype != scale[0].dtype`. + TypeError: if `quadrature_grid` and `quadrature_probs` have different base + `dtype`. """ parameters = locals() with ops.name_scope(name, values=[loc, scale]): - loc = ops.convert_to_tensor(loc, name="loc") - self._loc = loc + if loc is not None: + loc = ops.convert_to_tensor(loc, name="loc") + if scale is not None: + scale = ops.convert_to_tensor( + scale, dtype=None if loc is None else loc.dtype, name="scale") + self._quadrature_grid, self._quadrature_probs = tuple(quadrature_fn( + loc, scale, quadrature_size, validate_args)) + + dt = self._quadrature_grid.dtype + if dt.base_dtype != self._quadrature_probs.dtype.base_dtype: + raise TypeError("Quadrature grid dtype ({}) does not match quadrature " + "probs dtype ({}).".format( + dt.name, self._quadrature_probs.dtype.name)) - scale = ops.convert_to_tensor(scale, name="scale") - self._scale = scale - - dtype = loc.dtype.base_dtype - if dtype != scale.dtype.base_dtype: - raise TypeError( - "loc.dtype(\"{}\") does not match scale.dtype(\"{}\")".format( - loc.dtype.name, scale.dtype.name)) - - grid, probs = distribution_util.process_quadrature_grid_and_probs( - quadrature_grid_and_probs, dtype, validate_args) - self._quadrature_grid = grid - self._quadrature_probs = probs - self._quadrature_size = distribution_util.dimension_size(probs, axis=0) + self._distribution = poisson_lib.Poisson( + log_rate=self._quadrature_grid, + validate_args=validate_args, + allow_nan_stats=allow_nan_stats) self._mixture_distribution = categorical_lib.Categorical( logits=math_ops.log(self._quadrature_probs), validate_args=validate_args, 
allow_nan_stats=allow_nan_stats) - # The following maps the broadcast of `loc` and `scale` to each grid - # point, i.e., we are creating several log-rates that correspond to the - # different Gauss-Hermite quadrature points and (possible) batches of - # `loc` and `scale`. - self._log_rate = (loc[..., array_ops.newaxis] - + np.sqrt(2.) * scale[..., array_ops.newaxis] * grid) - - self._distribution = poisson_lib.Poisson( - log_rate=self._log_rate, - validate_args=validate_args, - allow_nan_stats=allow_nan_stats) + self._loc = loc + self._scale = scale + self._quadrature_size = quadrature_size super(PoissonLogNormalQuadratureCompound, self).__init__( - dtype=dtype, + dtype=dt, reparameterization_type=distribution_lib.NOT_REPARAMETERIZED, validate_args=validate_args, allow_nan_stats=allow_nan_stats, @@ -197,12 +296,12 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): @property def mixture_distribution(self): - """Distribution which randomly selects a Poisson with Gauss-Hermite rate.""" + """Distribution which randomly selects a Poisson with quadrature param.""" return self._mixture_distribution @property def distribution(self): - """Base Poisson parameterized by a Gauss-Hermite grid of rates.""" + """Base Poisson parameterized by a quadrature grid.""" return self._distribution @property @@ -216,24 +315,18 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): return self._scale @property - def quadrature_grid(self): - """Quadrature grid points.""" - return self._quadrature_grid - - @property - def quadrature_probs(self): - """Quadrature normalized weights.""" - return self._quadrature_probs + def quadrature_size(self): + return self._quadrature_size def _batch_shape_tensor(self): return array_ops.broadcast_dynamic_shape( - array_ops.shape(self.loc), - array_ops.shape(self.scale)) + self.distribution.batch_shape_tensor(), + array_ops.shape(self.mixture_distribution.logits))[:-1] def _batch_shape(self): return 
array_ops.broadcast_static_shape( - self.loc.shape, - self.scale.shape) + self.distribution.batch_shape, + self.mixture_distribution.logits.shape)[:-1] def _event_shape(self): return tensor_shape.scalar() @@ -241,18 +334,31 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): def _sample_n(self, n, seed=None): # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get # ids as a [n]-shaped vector. - batch_size = (np.prod(self.batch_shape.as_list(), dtype=np.int32) - if self.batch_shape.is_fully_defined() - else math_ops.reduce_prod(self.batch_shape_tensor())) + batch_size = self.batch_shape.num_elements() + if batch_size is None: + batch_size = math_ops.reduce_prod(self.batch_shape_tensor()) + # We need to "sample extra" from the mixture distribution if it doesn't + # already specify a probs vector for each batch coordinate. + # We only support this kind of reduced broadcasting, i.e., there is exactly + # one probs vector for all batch dims or one for each. ids = self._mixture_distribution.sample( sample_shape=concat_vectors( [n], distribution_util.pick_vector( - self.is_scalar_batch(), - np.int32([]), - [batch_size])), + self.mixture_distribution.is_scalar_batch(), + [batch_size], + np.int32([]))), seed=distribution_util.gen_new_seed( seed, "poisson_lognormal_quadrature_compound")) + # We need to flatten batch dims in case mixture_distribution has its own + # batch dims. + ids = array_ops.reshape(ids, shape=concat_vectors( + [n], + distribution_util.pick_vector( + self.is_scalar_batch(), + np.int32([]), + np.int32([-1])))) + # Stride `quadrature_size` for `batch_size` number of times. 
offset = math_ops.range(start=0, limit=batch_size * self._quadrature_size, @@ -275,7 +381,7 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): def _mean(self): return math_ops.exp( math_ops.reduce_logsumexp( - self.mixture_distribution.logits + self._log_rate, + self.mixture_distribution.logits + self.distribution.log_rate, axis=-1)) def _variance(self): @@ -300,7 +406,7 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): # Var[E[Z | V]] = sum{ prob[d] (Mean[d] - Mean)**2 : d=0, ..., deg-1 } v = array_ops.stack([ # log(self.distribution.variance()) = log(Var[d]) = log(rate[d]) - self._log_rate, + self.distribution.log_rate, # log((Mean[d] - Mean)**2) 2. * math_ops.log( math_ops.abs(self.distribution.mean() @@ -311,14 +417,9 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): axis=[-2, -1]) -def static_value(x): - """Returns the static value of a `Tensor` or `None`.""" - return tensor_util.constant_value(ops.convert_to_tensor(x)) - - def concat_vectors(*args): """Concatenates input vectors, statically if possible.""" - args_ = [static_value(x) for x in args] + args_ = [distribution_util.static_value(x) for x in args] if any(vec is None for vec in args_): return array_ops.concat(args, axis=0) return [val for vec in args_ for val in vec] diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py index b6becfa9fc93f189a1a7bf7b2a7af8dc1f2e9720..2aa771a71efe52c8d86d459f090ea8ee137c4487 100644 --- a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py +++ b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py @@ -278,7 +278,7 @@ class ExpRelaxedOneHotCategorical(distribution.Distribution): * math_ops.log(self.temperature)) # compute the unnormalized density log_softmax = nn_ops.log_softmax(logits_2d - x_2d * self._temperature_2d) - log_unnorm_prob = 
math_ops.reduce_sum(log_softmax, [-1], keep_dims=False) + log_unnorm_prob = math_ops.reduce_sum(log_softmax, [-1], keepdims=False) # combine unnormalized density with normalization constant log_prob = log_norm_const + log_unnorm_prob # Reshapes log_prob to be consistent with shape of user-supplied logits diff --git a/tensorflow/contrib/distributions/python/ops/sample_stats.py b/tensorflow/contrib/distributions/python/ops/sample_stats.py index 2a4b92c72900f79785e7e34b77179d3decbace5b..dfc813361977c159d8d48f9d5b9ff03db5b4acdc 100644 --- a/tensorflow/contrib/distributions/python/ops/sample_stats.py +++ b/tensorflow/contrib/distributions/python/ops/sample_stats.py @@ -28,12 +28,190 @@ from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import spectral_ops +from tensorflow.python.ops.distributions import util __all__ = [ + "auto_correlation", "percentile", ] +# TODO(langmore) Write separate versions of this for real/complex dtype, taking +# advantage of optimized real-fft ops. +def auto_correlation( + x, + axis=-1, + max_lags=None, + center=True, + normalize=True, + name="auto_correlation"): + """Auto correlation along one axis. + + Given a `1-D` wide sense stationary (WSS) sequence `X`, the auto correlation + `RXX` may be defined as (with `E` expectation and `Conj` complex conjugate) + + ``` + RXX[m] := E{ W[m] Conj(W[0]) } = E{ W[0] Conj(W[-m]) }, + W[n] := (X[n] - MU) / S, + MU := E{ X[0] }, + S**2 := E{ (X[0] - MU) Conj(X[0] - MU) }. 
+ ``` + + This function takes the viewpoint that `x` is (along one axis) a finite + sub-sequence of a realization of (WSS) `X`, and then uses `x` to produce an + estimate of `RXX[m]` as follows: + + After extending `x` from length `L` to `inf` by zero padding, the auto + correlation estimate `rxx[m]` is computed for `m = 0, 1, ..., max_lags` as + + ``` + rxx[m] := (L - m)**-1 sum_n w[n + m] Conj(w[n]), + w[n] := (x[n] - mu) / s, + mu := L**-1 sum_n x[n], + s**2 := L**-1 sum_n (x[n] - mu) Conj(x[n] - mu) + ``` + + The error in this estimate is proportional to `1 / sqrt(len(x) - m)`, so users + often set `max_lags` small enough so that the entire output is meaningful. + + Note that since `mu` is an imperfect estimate of `E{ X[0] }`, and we divide by + `len(x) - m` rather than `len(x) - m - 1`, our estimate of auto correlation + contains a slight bias, which goes to zero as `len(x) - m --> infinity`. + + Args: + x: `float32` or `complex64` `Tensor`. + axis: Python `int`. The axis number along which to compute correlation. + Other dimensions index different batch members. + max_lags: Positive `int` tensor. The maximum value of `m` to consider + (in equation above). If `max_lags >= x.shape[axis]`, we effectively + re-set `max_lags` to `x.shape[axis] - 1`. + center: Python `bool`. If `False`, do not subtract the mean estimate `mu` + from `x[n]` when forming `w[n]`. + normalize: Python `bool`. If `False`, do not divide by the variance + estimate `s**2` when forming `w[n]`. + name: `String` name to prepend to created ops. + + Returns: + `rxx`: `Tensor` of same `dtype` as `x`. `rxx.shape[i] = x.shape[i]` for + `i != axis`, and `rxx.shape[axis] = max_lags + 1`. + + Raises: + TypeError: If `x` is not a supported type. + """ + # Implementation details: + # Extend length N / 2 1-D array x to length N by zero padding onto the end. + # Then, set + # F[x]_k := sum_n x_n exp{-i 2 pi k n / N }. 
+ # It is not hard to see that + # F[x]_k Conj(F[x]_k) = F[R]_k, where + # R_m := sum_n x_n Conj(x_{(n - m) mod N}). + # One can also check that R_m / (N / 2 - m) is an unbiased estimate of RXX[m]. + + # Since F[x] is the DFT of x, this leads us to a zero-padding and FFT/IFFT + # based version of estimating RXX. + # Note that this is a special case of the Wiener-Khinchin Theorem. + with ops.name_scope(name, values=[x]): + x = ops.convert_to_tensor(x, name="x") + + # Rotate dimensions of x in order to put axis at the rightmost dim. + # FFT op requires this. + rank = util.prefer_static_rank(x) + if axis < 0: + axis = rank + axis + shift = rank - 1 - axis + # Suppose x.shape[axis] = T, so there are T "time" steps. + # ==> x_rotated.shape = B + [T], + # where B is x_rotated's batch shape. + x_rotated = util.rotate_transpose(x, shift) + + if center: + x_rotated -= math_ops.reduce_mean(x_rotated, axis=-1, keepdims=True) + + # x_len = N / 2 from above explanation. The length of x along axis. + # Get a value for x_len that works in all cases. + x_len = util.prefer_static_shape(x_rotated)[-1] + + # TODO(langmore) Investigate whether this zero padding helps or hurts. At + # the moment is is necessary so that all FFT implementations work. + # Zero pad to the next power of 2 greater than 2 * x_len, which equals + # 2**(ceil(Log_2(2 * x_len))). Note: Log_2(X) = Log_e(X) / Log_e(2). 
+ x_len_float64 = math_ops.cast(x_len, np.float64) + target_length = math_ops.pow( + np.float64(2.), + math_ops.ceil(math_ops.log(x_len_float64 * 2) / np.log(2.))) + pad_length = math_ops.cast(target_length - x_len_float64, np.int32) + + # We should have: + # x_rotated_pad.shape = x_rotated.shape[:-1] + [T + pad_length] + # = B + [T + pad_length] + x_rotated_pad = util.pad(x_rotated, axis=-1, back=True, count=pad_length) + + dtype = x.dtype + if not dtype.is_complex: + if not dtype.is_floating: + raise TypeError("Argument x must have either float or complex dtype" + " found: {}".format(dtype)) + x_rotated_pad = math_ops.complex(x_rotated_pad, + dtype.real_dtype.as_numpy_dtype(0.)) + + # Autocorrelation is IFFT of power-spectral density (up to some scaling). + fft_x_rotated_pad = spectral_ops.fft(x_rotated_pad) + spectral_density = fft_x_rotated_pad * math_ops.conj(fft_x_rotated_pad) + # shifted_product is R[m] from above detailed explanation. + # It is the inner product sum_n X[n] * Conj(X[n - m]). + shifted_product = spectral_ops.ifft(spectral_density) + + # Cast back to real-valued if x was real to begin with. + shifted_product = math_ops.cast(shifted_product, dtype) + + # Figure out if we can deduce the final static shape, and set max_lags. + # Use x_rotated as a reference, because it has the time dimension in the far + # right, and was created before we performed all sorts of crazy shape + # manipulations. + know_static_shape = True + if not x_rotated.shape.is_fully_defined(): + know_static_shape = False + if max_lags is None: + max_lags = x_len - 1 + else: + max_lags = ops.convert_to_tensor(max_lags, name="max_lags") + max_lags_ = tensor_util.constant_value(max_lags) + if max_lags_ is None or not know_static_shape: + know_static_shape = False + max_lags = math_ops.minimum(x_len - 1, max_lags) + else: + max_lags = min(x_len - 1, max_lags_) + + # Chop off the padding. + # We allow users to provide a huge max_lags, but cut it off here. 
+ # shifted_product_chopped.shape = x_rotated.shape[:-1] + [max_lags] + shifted_product_chopped = shifted_product[..., :max_lags + 1] + + # If possible, set shape. + if know_static_shape: + chopped_shape = x_rotated.shape.as_list() + chopped_shape[-1] = min(x_len, max_lags + 1) + shifted_product_chopped.set_shape(chopped_shape) + + # Recall R[m] is a sum of N / 2 - m nonzero terms x[n] Conj(x[n - m]). The + # other terms were zeros arising only due to zero padding. + # `denominator = (N / 2 - m)` (defined below) is the proper term to + # divide by by to make this an unbiased estimate of the expectation + # E[X[n] Conj(X[n - m])]. + x_len = math_ops.cast(x_len, dtype.real_dtype) + max_lags = math_ops.cast(max_lags, dtype.real_dtype) + denominator = x_len - math_ops.range(0., max_lags + 1.) + denominator = math_ops.cast(denominator, dtype) + shifted_product_rotated = shifted_product_chopped / denominator + + if normalize: + shifted_product_rotated /= shifted_product_rotated[..., :1] + + # Transpose dimensions back to those of x. + return util.rotate_transpose(shifted_product_rotated, -shift) + + # TODO(langmore) To make equivalent to numpy.percentile: # Make work with a sequence of floats or single float for 'q'. # Make work with "linear", "midpoint" interpolation. (linear should be default) diff --git a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py index b05f15771a3a94779ffddea8f16ad2fa4ea2fdd1..c4b8f055b7fbc3f0835b503eddd7617610326d8c 100644 --- a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py +++ b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py @@ -115,7 +115,7 @@ class SinhArcsinh(transformed_distribution.TransformedDistribution): tailweight: Tailweight parameter. Default is `1.0` (unchanged tailweight) distribution: `tf.Distribution`-like instance. Distribution that is transformed to produce this distribution. - Default is `ds.Normal(0., 1.)`. 
+ Default is `tf.distributions.Normal(0., 1.)`. Must be a scalar-batch, scalar-event distribution. Typically `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is a function of non-trainable parameters. WARNING: If you backprop through diff --git a/tensorflow/contrib/distributions/python/ops/test_util.py b/tensorflow/contrib/distributions/python/ops/test_util.py index 77f2a39273dc365a4ac202d846dd2bc364655c86..15b0820cbdf560e04a304c40a47e541006523b6d 100644 --- a/tensorflow/contrib/distributions/python/ops/test_util.py +++ b/tensorflow/contrib/distributions/python/ops/test_util.py @@ -40,6 +40,7 @@ class DiscreteScalarDistributionTestHelpers(object): def run_test_sample_consistent_log_prob( self, sess_run_fn, dist, num_samples=int(1e5), num_threshold=int(1e3), seed=42, + batch_size=None, rtol=1e-2, atol=0.): """Tests that sample/log_prob are consistent with each other. @@ -66,6 +67,8 @@ class DiscreteScalarDistributionTestHelpers(object): seed: Python `int` indicating the seed to use when sampling from `dist`. In general it is not recommended to use `None` during a test as this increases the likelihood of spurious test failure. + batch_size: Hint for unpacking result of samples. Default: `None` means + batch_size is inferred. rtol: Python `float`-type indicating the admissible relative error between analytical and sample statistics. atol: Python `float`-type indicating the admissible absolute error between @@ -80,10 +83,11 @@ class DiscreteScalarDistributionTestHelpers(object): # Histogram only supports vectors so we call it once per batch coordinate. 
y = dist.sample(num_samples, seed=seed) y = array_ops.reshape(y, shape=[num_samples, -1]) - batch_size = math_ops.reduce_prod(dist.batch_shape_tensor()) + if batch_size is None: + batch_size = math_ops.reduce_prod(dist.batch_shape_tensor()) batch_dims = array_ops.shape(dist.batch_shape_tensor())[0] edges_expanded_shape = 1 + array_ops.pad([-2], paddings=[[0, batch_dims]]) - for b, x in enumerate(array_ops.unstack(y, axis=1)): + for b, x in enumerate(array_ops.unstack(y, num=batch_size, axis=1)): counts, edges = self.histogram(x) edges = array_ops.reshape(edges, edges_expanded_shape) probs = math_ops.exp(dist.log_prob(edges)) @@ -323,7 +327,7 @@ class VectorDistributionTestHelpers(object): num_samples=int(1e5), seed=24, rtol=1e-2, - atol=0., + atol=0.1, cov_rtol=None, cov_atol=None): """Tests that sample/mean/covariance are consistent with each other. diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py index 92043d6a08833888c36009261addca0d14949ea8..0c747f8e68529484ae6f695b8500cde74857bb11 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py +++ b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py @@ -22,141 +22,237 @@ import numpy as np from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops.bijectors.affine_linear_operator import AffineLinearOperator +from tensorflow.contrib.distributions.python.ops.bijectors.softmax_centered import SoftmaxCentered from tensorflow.contrib.linalg.python.ops import linear_operator_addition as linop_add_lib -from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops 
import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops from tensorflow.python.ops.distributions import categorical as categorical_lib from tensorflow.python.ops.distributions import distribution as distribution_lib +from tensorflow.python.ops.distributions import normal as normal_lib from tensorflow.python.ops.linalg import linear_operator_diag as linop_diag_lib from tensorflow.python.ops.linalg import linear_operator_full_matrix as linop_full_lib from tensorflow.python.ops.linalg import linear_operator_identity as linop_identity_lib from tensorflow.python.ops.linalg import linear_operator_lower_triangular as linop_tril_lib -static_value = distribution_util.static_value - __all__ = [ "VectorDiffeomixture", + "quadrature_scheme_softmaxnormal_gauss_hermite", + "quadrature_scheme_softmaxnormal_quantiles", ] -class VectorDiffeomixture(distribution_lib.Distribution): - """VectorDiffeomixture distribution. +def quadrature_scheme_softmaxnormal_gauss_hermite( + normal_loc, normal_scale, quadrature_size, + validate_args=False, name=None): + """Use Gauss-Hermite quadrature to form quadrature on `K - 1` simplex. - The VectorDiffeomixture is an approximation to a [compound distribution]( - https://en.wikipedia.org/wiki/Compound_probability_distribution), i.e., + A `SoftmaxNormal` random variable `Y` may be generated via - ```none - p(x) = int_{X} q(x | v) p(v) dv - = lim_{Q->infty} sum{ prob[i] q(x | loc=sum_k^K lambda[k;i] loc[k], - scale=sum_k^K lambda[k;i] scale[k]) - : i=0, ..., Q-1 } + ``` + Y = SoftmaxCentered(X), + X = Normal(normal_loc, normal_scale) ``` - where `q(x | v)` is a vector version of the `distribution` argument and `p(v)` - is a SoftmaxNormal parameterized by `mix_loc` and `mix_scale`. The - vector-ization of `distribution` entails an affine transformation of iid - samples from `distribution`. 
The `prob` term is from quadrature and - `lambda[k] = sigmoid(mix_loc[k] + sqrt(2) mix_scale[k] grid[k])` where the - `grid` points correspond to the `prob`s. - - In the non-approximation case, a draw from the mixture distribution (the - "prior") represents the convex weights for different affine transformations. - I.e., draw a mixing vector `v` (from the `K-1`-simplex) and let the final - sample be: `y = (sum_k^K v[k] scale[k]) @ x + (sum_k^K v[k] loc[k])` where `@` - denotes matrix multiplication. However, the non-approximate distribution does - not have an analytical probability density function (pdf). Therefore the - `VectorDiffeomixture` class implements an approximation based on - [numerical quadrature]( - https://en.wikipedia.org/wiki/Numerical_integration) (default: - [Gauss--Hermite quadrature]( - https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)). I.e., in - Note: although the `VectorDiffeomixture` is approximately the - `SoftmaxNormal-Distribution` compound distribution, it is itself a valid - distribution. It possesses a `sample`, `log_prob`, `mean`, `covariance` which - are all mutually consistent. - - #### Intended Use - - This distribution is noteworthy because it implements a mixture of - `Vector`-ized distributions yet has samples differentiable in the - distribution's parameters (aka "reparameterized"). It has an analytical - density function with `O(dKQ)` complexity. `d` is the vector dimensionality, - `K` is the number of components, and `Q` is the number of quadrature points. - These properties make it well-suited for Bayesian Variational Inference, i.e., - as a surrogate family for the posterior. - - For large values of `mix_scale`, the `VectorDistribution` behaves increasingly - like a discrete mixture. (In most cases this limit is only achievable by also - increasing the quadrature polynomial degree, `Q`.) - - The term `Vector` is consistent with similar named Tensorflow `Distribution`s. 
- For more details, see the "About `Vector` distributions in Tensorflow." - section. - - The term `Diffeomixture` is a portmanteau of - [diffeomorphism](https://en.wikipedia.org/wiki/Diffeomorphism) and [compound - mixture](https://en.wikipedia.org/wiki/Compound_probability_distribution). For - more details, see the "About `Diffeomixture`s and reparametrization.`" - section. - - #### Mathematical Details - - The `VectorDiffeomixture` approximates a SoftmaxNormal-mixed ("prior") - [compound distribution]( - https://en.wikipedia.org/wiki/Compound_probability_distribution). - Using variable-substitution and [numerical quadrature]( - https://en.wikipedia.org/wiki/Numerical_integration) (default: - [Gauss--Hermite quadrature]( - https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)) we can - redefine the distribution to be a parameter-less convex combination of `K` - different affine combinations of a `d` iid samples from `distribution`. - - That is, defined over `R**d` this distribution is parameterized by a - (batch of) length-`K` `mix_loc` and `mix_scale` vectors, a length-`K` list of - (a batch of) length-`d` `loc` vectors, and a length-`K` list of `scale` - `LinearOperator`s each operating on a (batch of) length-`d` vector space. - Finally, a `distribution` parameter specifies the underlying base distribution - which is "lifted" to become multivariate ("lifting" is the same concept as in - `TransformedDistribution`). - - The probability density function (pdf) is, + Note: for a given `quadrature_size`, this method is generally less accurate + than `quadrature_scheme_softmaxnormal_quantiles`. + + Args: + normal_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0. + The location parameter of the Normal used to construct the SoftmaxNormal. + normal_scale: `float`-like `Tensor`. Broadcastable with `normal_loc`. + The scale parameter of the Normal used to construct the SoftmaxNormal. 
+ quadrature_size: Python `int` scalar representing the number of quadrature + points. + validate_args: Python `bool`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + name: Python `str` name prefixed to Ops created by this class. + + Returns: + grid: Shape `[b1, ..., bB, K, quadrature_size]` `Tensor` representing the + convex combination of affine parameters for `K` components. + `grid[..., :, n]` is the `n`-th grid point, living in the `K - 1` simplex. + probs: Shape `[b1, ..., bB, K, quadrature_size]` `Tensor` representing the + associated with each grid point. + """ + with ops.name_scope(name, "quadrature_scheme_softmaxnormal_gauss_hermite", + [normal_loc, normal_scale]): + normal_loc = ops.convert_to_tensor(normal_loc, name="normal_loc") + dt = normal_loc.dtype.base_dtype + normal_scale = ops.convert_to_tensor( + normal_scale, dtype=dt, name="normal_scale") - ```none - pdf(y; mix_loc, mix_scale, loc, scale, phi) - = sum{ prob[i] phi(f_inverse(x; i)) / abs(det(interp_scale[i])) - : i=0, ..., Q-1 } + normal_scale = maybe_check_quadrature_param( + normal_scale, "normal_scale", validate_args) + + grid, probs = np.polynomial.hermite.hermgauss(deg=quadrature_size) + grid = grid.astype(dt.dtype.as_numpy_dtype) + probs = probs.astype(dt.dtype.as_numpy_dtype) + probs /= np.linalg.norm(probs, ord=1, keepdims=True) + probs = ops.convert_to_tensor(probs, name="probs", dtype=dt) + + grid = softmax( + -distribution_util.pad( + (normal_loc[..., array_ops.newaxis] + + np.sqrt(2.) * normal_scale[..., array_ops.newaxis] * grid), + axis=-2, + front=True), + axis=-2) # shape: [B, components, deg] + + return grid, probs + + +def quadrature_scheme_softmaxnormal_quantiles( + normal_loc, normal_scale, quadrature_size, + validate_args=False, name=None): + """Use SoftmaxNormal quantiles to form quadrature on `K - 1` simplex. 
+ + A `SoftmaxNormal` random variable `Y` may be generated via + + ``` + Y = SoftmaxCentered(X), + X = Normal(normal_loc, normal_scale) ``` - where, `phi` is the base distribution pdf, and, + Args: + normal_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`, B>=0. + The location parameter of the Normal used to construct the SoftmaxNormal. + normal_scale: `float`-like `Tensor`. Broadcastable with `normal_loc`. + The scale parameter of the Normal used to construct the SoftmaxNormal. + quadrature_size: Python `int` scalar representing the number of quadrature + points. + validate_args: Python `bool`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + name: Python `str` name prefixed to Ops created by this class. + + Returns: + grid: Shape `[b1, ..., bB, K, quadrature_size]` `Tensor` representing the + convex combination of affine parameters for `K` components. + `grid[..., :, n]` is the `n`-th grid point, living in the `K - 1` simplex. + probs: Shape `[b1, ..., bB, K, quadrature_size]` `Tensor` representing the + associated with each grid point. 
+ """ + with ops.name_scope(name, "softmax_normal_grid_and_probs", + [normal_loc, normal_scale]): + normal_loc = ops.convert_to_tensor(normal_loc, name="normal_loc") + dt = normal_loc.dtype.base_dtype + normal_scale = ops.convert_to_tensor( + normal_scale, dtype=dt, name="normal_scale") + + normal_scale = maybe_check_quadrature_param( + normal_scale, "normal_scale", validate_args) + + dist = normal_lib.Normal(loc=normal_loc, scale=normal_scale) + + def _get_batch_ndims(): + """Helper to get dist.batch_shape.ndims, statically if possible.""" + ndims = dist.batch_shape.ndims + if ndims is None: + ndims = array_ops.shape(dist.batch_shape_tensor())[0] + return ndims + batch_ndims = _get_batch_ndims() + + def _get_final_shape(qs): + """Helper to build `TensorShape`.""" + bs = dist.batch_shape.with_rank_at_least(1) + num_components = bs[-1].value + if num_components is not None: + num_components += 1 + tail = tensor_shape.TensorShape([num_components, qs]) + return bs[:-1].concatenate(tail) + + def _compute_quantiles(): + """Helper to build quantiles.""" + # Omit {0, 1} since they might lead to Inf/NaN. + zero = array_ops.zeros([], dtype=dist.dtype) + edges = math_ops.linspace(zero, 1., quadrature_size + 3)[1:-1] + # Expand edges so its broadcast across batch dims. + edges = array_ops.reshape(edges, shape=array_ops.concat([ + [-1], array_ops.ones([batch_ndims], dtype=dtypes.int32)], axis=0)) + quantiles = dist.quantile(edges) + quantiles = SoftmaxCentered(event_ndims=1).forward(quantiles) + # Cyclically permute left by one. + perm = array_ops.concat([ + math_ops.range(1, 1 + batch_ndims), [0]], axis=0) + quantiles = array_ops.transpose(quantiles, perm) + quantiles.set_shape(_get_final_shape(quadrature_size + 1)) + return quantiles + quantiles = _compute_quantiles() + + # Compute grid as quantile midpoints. + grid = (quantiles[..., :-1] + quantiles[..., 1:]) / 2. + # Set shape hints. 
+ grid.set_shape(_get_final_shape(quadrature_size)) + + # By construction probs is constant, i.e., `1 / quadrature_size`. This is + # important, because non-constant probs leads to non-reparameterizable + # samples. + probs = array_ops.fill( + dims=[quadrature_size], + value=1. / math_ops.cast(quadrature_size, dist.dtype)) + + return grid, probs + + +class VectorDiffeomixture(distribution_lib.Distribution): + """VectorDiffeomixture distribution. + + A vector diffeomixture (VDM) is a distribution parameterized by a convex + combination of `K` component `loc` vectors, `loc[k], k = 0,...,K-1`, and `K` + `scale` matrices `scale[k], k = 0,..., K-1`. It approximates the following + [compound distribution] + (https://en.wikipedia.org/wiki/Compound_probability_distribution) ```none - f_inverse(x; i) = inv(interp_scale[i]) @ (x - interp_loc[i]) - interp_loc[i] = sum{ lambda[k; i] loc[k] : k=0, ..., K-1 } - interp_scale[i] = sum{ lambda[k; i] scale[k] : k=0, ..., K-1 } + p(x) = int p(x | z) p(z) dz, + where z is in the K-simplex, and + p(x | z) := p(x | loc=sum_k z[k] loc[k], scale=sum_k z[k] scale[k]) ``` - and, + The integral `int p(x | z) p(z) dz` is approximated with a quadrature scheme + adapted to the mixture density `p(z)`. The `N` quadrature points `z_{N, n}` + and weights `w_{N, n}` (which are non-negative and sum to 1) are chosen + such that - ```none - grid, weight = np.polynomial.hermite.hermgauss(quadrature_size) - prob[k] = weight[k] / sqrt(pi) - lambda[k; i] = sigmoid(mix_loc[k] + sqrt(2) mix_scale[k] grid[i]) + ```q_N(x) := sum_{n=1}^N w_{n, N} p(x | z_{N, n}) --> p(x)``` + + as `N --> infinity`. + + Since `q_N(x)` is in fact a mixture (of `N` points), we may sample from + `q_N` exactly. It is important to note that the VDM is *defined* as `q_N` + above, and *not* `p(x)`. Therefore, sampling and pdf may be implemented as + exact (up to floating point error) methods. + + A common choice for the conditional `p(x | z)` is a multivariate Normal. 
+ + The implemented marginal `p(z)` is the `SoftmaxNormal`, which is a + `K-1` dimensional Normal transformed by a `SoftmaxCentered` bijector, making + it a density on the `K`-simplex. That is, + + ``` + Z = SoftmaxCentered(X), + X = Normal(mix_loc / temperature, 1 / temperature) ``` - The distribution corresponding to `phi` must be a scalar-batch, scalar-event - distribution. Typically it is reparameterized. If not, it must be a function - of non-trainable parameters. + The default quadrature scheme chooses `z_{N, n}` as `N` midpoints of + the quantiles of `p(z)` (generalized quantiles if `K > 2`). - WARNING: If you backprop through a VectorDiffeomixture sample and the "base" - distribution is both: not `FULLY_REPARAMETERIZED` and a function of trainable - variables, then the gradient is not guaranteed correct! + See [1] for more details. + + [1]. "Quadrature Compound: An approximating family of distributions" + Joshua Dillon, Ian Langmore, arXiv preprints + https://arxiv.org/abs/1801.03080 #### About `Vector` distributions in TensorFlow. @@ -164,12 +260,11 @@ class VectorDiffeomixture(distribution_lib.Distribution): particularly useful in [variational Bayesian methods](https://en.wikipedia.org/wiki/Variational_Bayesian_methods). - Conditioned on a draw from the SoftmaxNormal, `Y|v` is a vector whose + Conditioned on a draw from the SoftmaxNormal, `X|z` is a vector whose components are linear combinations of affine transformations, thus is itself - an affine transformation. Therefore `Y|v` lives in the vector space generated - by vectors of affine-transformed distributions. + an affine transformation. - Note: The marginals `Y_1|v, ..., Y_d|v` are *not* generally identical to some + Note: The marginals `X_1|v, ..., X_d|v` are *not* generally identical to some parameterization of `distribution`. This is due to the fact that the sum of draws from `distribution` are not generally itself the same `distribution`. 
@@ -185,32 +280,35 @@ class VectorDiffeomixture(distribution_lib.Distribution): optimize Monte-Carlo objectives. Such objectives are a finite-sample approximation of an expectation and arise throughout scientific computing. + WARNING: If you backprop through a VectorDiffeomixture sample and the "base" + distribution is both: not `FULLY_REPARAMETERIZED` and a function of trainable + variables, then the gradient is not guaranteed correct! + #### Examples ```python - ds = tf.contrib.distributions - la = tf.linalg + tfd = tf.contrib.distributions - # Create two batches of VectorDiffeomixtures, one with mix_loc=[0.] and + # Create two batches of VectorDiffeomixtures, one with mix_loc=[0.], # another with mix_loc=[1]. In both cases, `K=2` and the affine # transformations involve: # k=0: loc=zeros(dims) scale=LinearOperatorScaledIdentity # k=1: loc=[2.]*dims scale=LinOpDiag dims = 5 - vdm = ds.VectorDiffeomixture( + vdm = tfd.VectorDiffeomixture( mix_loc=[[0.], [1]], - mix_scale=[1.], - distribution=ds.Normal(loc=0., scale=1.), + temperature=[1.], + distribution=tfd.Normal(loc=0., scale=1.), loc=[ None, # Equivalent to `np.zeros(dims, dtype=np.float32)`. np.float32([2.]*dims), ], scale=[ - la.LinearOperatorScaledIdentity( + tf.linalg.LinearOperatorScaledIdentity( num_rows=dims, multiplier=np.float32(1.1), is_positive_definite=True), - la.LinearOperatorDiag( + tf.linalg.LinearOperatorDiag( diag=np.linspace(2.5, 3.5, dims, dtype=np.float32), is_positive_definite=True), ], @@ -219,21 +317,33 @@ class VectorDiffeomixture(distribution_lib.Distribution): def __init__(self, mix_loc, - mix_scale, + temperature, distribution, loc=None, scale=None, - quadrature_grid_and_probs=None, + quadrature_size=8, + quadrature_fn=quadrature_scheme_softmaxnormal_quantiles, validate_args=False, allow_nan_stats=True, name="VectorDiffeomixture"): - """Constructs the VectorDiffeomixture on `R**k`. + """Constructs the VectorDiffeomixture on `R^d`. 
+ + The vector diffeomixture (VDM) approximates the compound distribution + + ```none + p(x) = int p(x | z) p(z) dz, + where z is in the K-simplex, and + p(x | z) := p(x | loc=sum_k z[k] loc[k], scale=sum_k z[k] scale[k]) + ``` Args: - mix_loc: `float`-like `Tensor`. Represents the `location` parameter of the - SoftmaxNormal used for selecting one of the `K` affine transformations. - mix_scale: `float`-like `Tensor`. Represents the `scale` parameter of the - SoftmaxNormal used for selecting one of the `K` affine transformations. + mix_loc: `float`-like `Tensor` with shape `[b1, ..., bB, K-1]`. + In terms of samples, larger `mix_loc[..., k]` ==> + `Z` is more likely to put more weight on its `kth` component. + temperature: `float`-like `Tensor`. Broadcastable with `mix_loc`. + In terms of samples, smaller `temperature` means one component is more + likely to dominate. I.e., smaller `temperature` makes the VDM look more + like a standard mixture of `K` components. distribution: `tf.Distribution`-like instance. Distribution from which `d` iid samples are used as input to the selected affine transformation. Must be a scalar-batch, scalar-event distribution. Typically @@ -252,10 +362,14 @@ class VectorDiffeomixture(distribution_lib.Distribution): `k`-th element represents the `scale` used for the `k`-th affine transformation. `LinearOperator`s must have shape `[B1, ..., Bb, d, d]`, `b >= 0`, i.e., characterizes `b`-batches of `d x d` matrices - quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s - representing the sample points and the corresponding (possibly - normalized) weight. When `None`, defaults to: - `np.polynomial.hermite.hermgauss(deg=8)`. + quadrature_size: Python `int` scalar representing number of + quadrature points. Larger `quadrature_size` means `q_N(x)` better + approximates `p(x)`. 
+ quadrature_fn: Python callable taking `normal_loc`, `normal_scale`, + `quadrature_size`, `validate_args` and returning `tuple(grid, probs)` + representing the SoftmaxNormal grid and corresponding normalized weight. + normalized) weight. + Default value: `quadrature_scheme_softmaxnormal_quantiles`. validate_args: Python `bool`, default `False`. When `True` distribution parameters are checked for validity despite possibly degrading runtime performance. When `False` invalid inputs may silently render incorrect @@ -279,7 +393,7 @@ class VectorDiffeomixture(distribution_lib.Distribution): ValueError: if `not distribution.is_scalar_event`. """ parameters = locals() - with ops.name_scope(name, values=[mix_loc, mix_scale]): + with ops.name_scope(name, values=[mix_loc, temperature]): if not scale or len(scale) < 2: raise ValueError("Must specify list (or list-like object) of scale " "LinearOperators, one for each component with " @@ -322,11 +436,15 @@ class VectorDiffeomixture(distribution_lib.Distribution): raise NotImplementedError("Currently only bimixtures are supported; " "len(scale)={} is not 2.".format(len(scale))) - grid, probs = distribution_util.process_quadrature_grid_and_probs( - quadrature_grid_and_probs, dtype, validate_args) - self._quadrature_grid = grid - self._quadrature_probs = probs - self._quadrature_size = distribution_util.dimension_size(probs, axis=0) + mix_loc = ops.convert_to_tensor( + mix_loc, dtype=dtype, name="mix_loc") + temperature = ops.convert_to_tensor( + temperature, dtype=dtype, name="temperature") + self._grid, probs = tuple(quadrature_fn( + mix_loc / temperature, + 1. 
/ temperature, + quadrature_size, + validate_args)) # Note: by creating the logits as `log(prob)` we ensure that # `self.mixture_distribution.logits` is equivalent to @@ -336,22 +454,13 @@ class VectorDiffeomixture(distribution_lib.Distribution): validate_args=validate_args, allow_nan_stats=allow_nan_stats) - mix_loc = maybe_check_mix_param( - mix_loc, "mix_loc", dtype, validate_args) - mix_scale = maybe_check_mix_param( - mix_scale, "mix_scale", dtype, validate_args) - asserts = distribution_util.maybe_check_scalar_distribution( distribution, dtype, validate_args) if asserts: - mix_loc = control_flow_ops.with_dependencies(asserts, mix_loc) + self._grid = control_flow_ops.with_dependencies( + asserts, self._grid) self._distribution = distribution - # shape: [B, deg] - self._interpolate_weight = math_ops.sigmoid( - mix_loc - + np.sqrt(2.) * mix_scale * grid) - self._interpolated_affine = [ AffineLinearOperator(shift=loc_, scale=scale_, @@ -359,15 +468,16 @@ class VectorDiffeomixture(distribution_lib.Distribution): validate_args=validate_args, name="interpolated_affine_{}".format(k)) for k, (loc_, scale_) in enumerate(zip( - interpolate_loc(self._quadrature_size, - self._interpolate_weight, - loc), - interpolate_scale(self._quadrature_size, - self._interpolate_weight, - scale)))] + interpolate_loc(self._grid, loc), + interpolate_scale(self._grid, scale)))] - self._batch_shape_, self._event_shape_ = determine_batch_event_shapes( - mix_loc, mix_scale, self._endpoint_affine) + [ + self._batch_shape_, + self._batch_shape_tensor_, + self._event_shape_, + self._event_shape_tensor_, + ] = determine_batch_event_shapes(self._grid, + self._endpoint_affine) super(VectorDiffeomixture, self).__init__( dtype=dtype, @@ -386,8 +496,7 @@ class VectorDiffeomixture(distribution_lib.Distribution): allow_nan_stats=allow_nan_stats, parameters=parameters, graph_parents=( - [mix_loc, mix_scale] - + distribution._graph_parents # pylint: disable=protected-access + distribution._graph_parents 
# pylint: disable=protected-access + [loc_ for loc_ in loc if loc_ is not None] + [p for scale_ in scale for p in scale_.graph_parents]), name=name) @@ -403,9 +512,9 @@ class VectorDiffeomixture(distribution_lib.Distribution): return self._distribution @property - def interpolate_weight(self): + def grid(self): """Grid of mixing probabilities, one for each grid point.""" - return self._interpolate_weight + return self._grid @property def endpoint_affine(self): @@ -417,27 +526,17 @@ class VectorDiffeomixture(distribution_lib.Distribution): """Affine transformation for each convex combination of `K` components.""" return self._interpolated_affine - @property - def quadrature_grid(self): - """Quadrature grid points.""" - return self._quadrature_grid - - @property - def quadrature_probs(self): - """Quadrature normalized weights.""" - return self._quadrature_probs - def _batch_shape_tensor(self): - return self._batch_shape_ + return self._batch_shape_tensor_ def _batch_shape(self): - return tensor_shape.TensorShape(static_value(self._batch_shape_)) + return self._batch_shape_ def _event_shape_tensor(self): - return self._event_shape_ + return self._event_shape_tensor_ def _event_shape(self): - return tensor_shape.TensorShape(static_value(self._event_shape_)) + return self._event_shape_ def _sample_n(self, n, seed=None): x = self.distribution.sample( @@ -450,27 +549,53 @@ class VectorDiffeomixture(distribution_lib.Distribution): # Get ids as a [n, batch_size]-shaped matrix, unless batch_shape=[] then get # ids as a [n]-shaped vector. 
- batch_size = reduce_prod(self.batch_shape_tensor()) - ids = self._mixture_distribution.sample( + batch_size = self.batch_shape.num_elements() + if batch_size is None: + batch_size = array_ops.reduce_prod(self.batch_shape_tensor()) + mix_batch_size = self.mixture_distribution.batch_shape.num_elements() + if mix_batch_size is None: + mix_batch_size = math_ops.reduce_prod( + self.mixture_distribution.batch_shape_tensor()) + ids = self.mixture_distribution.sample( sample_shape=concat_vectors( [n], distribution_util.pick_vector( self.is_scalar_batch(), np.int32([]), - [batch_size])), + [batch_size // mix_batch_size])), seed=distribution_util.gen_new_seed( seed, "vector_diffeomixture")) - - # Stride `quadrature_size` for `batch_size` number of times. + # We need to flatten batch dims in case mixture_distribution has its own + # batch dims. + ids = array_ops.reshape(ids, shape=concat_vectors( + [n], + distribution_util.pick_vector( + self.is_scalar_batch(), + np.int32([]), + np.int32([-1])))) + + # Stride `components * quadrature_size` for `batch_size` number of times. + stride = self.grid.shape.with_rank_at_least( + 2)[-2:].num_elements() + if stride is None: + stride = array_ops.reduce_prod( + array_ops.shape(self.grid)[-2:]) offset = math_ops.range(start=0, - limit=batch_size * self._quadrature_size, - delta=self._quadrature_size, + limit=batch_size * stride, + delta=stride, dtype=ids.dtype) weight = array_ops.gather( - array_ops.reshape(self.interpolate_weight, shape=[-1]), + array_ops.reshape(self.grid, shape=[-1]), ids + offset) - weight = weight[..., array_ops.newaxis] + # At this point, weight flattened all batch dims into one. + # We also need to append a singleton to broadcast with event dims. 
+ if self.batch_shape.is_fully_defined(): + new_shape = [-1] + self.batch_shape.as_list() + [1] + else: + new_shape = array_ops.concat( + ([-1], self.batch_shape_tensor(), [1]), axis=0) + weight = array_ops.reshape(weight, shape=new_shape) if len(x) != 2: # We actually should have already triggered this exception. However as a @@ -500,10 +625,7 @@ class VectorDiffeomixture(distribution_lib.Distribution): self.mixture_distribution.logits - fldj + log_prob, axis=-1) def _mean(self): - # Since we created logits to already be scaled, we can use exp which is - # slightly cheaper than `self.mixture_distribution.probs`. - p = math_ops.exp(self.mixture_distribution.logits) - + p = self._expand_mix_distribution_probs() m = self._expand_base_distribution_mean() mean = None for k, aff in enumerate(self.interpolated_affine): @@ -537,13 +659,11 @@ class VectorDiffeomixture(distribution_lib.Distribution): self._covariance_of_mean_given_quadrature_component(diag_only=True)) def _mean_of_covariance_given_quadrature_component(self, diag_only): - # Since we created logits to already be scaled, we can use exp which is - # slightly cheaper than `self.mixture_distribution.probs`. - p = math_ops.exp(self.mixture_distribution.logits) + p = self.mixture_distribution.probs # To compute E[Cov(Z|V)], we'll add matrices within three categories: # scaled-identity, diagonal, and full. Then we'll combine these at the end. - scaled_identity = None + scale_identity_multiplier = None diag = None full = None @@ -551,10 +671,12 @@ class VectorDiffeomixture(distribution_lib.Distribution): s = aff.scale # Just in case aff.scale has side-effects, we'll call once. 
if (s is None or isinstance(s, linop_identity_lib.LinearOperatorIdentity)): - scaled_identity = add(scaled_identity, p[..., k, array_ops.newaxis]) + scale_identity_multiplier = add(scale_identity_multiplier, + p[..., k, array_ops.newaxis]) elif isinstance(s, linop_identity_lib.LinearOperatorScaledIdentity): - scaled_identity = add(scaled_identity, (p[..., k, array_ops.newaxis] * - math_ops.square(s.multiplier))) + scale_identity_multiplier = add( + scale_identity_multiplier, + (p[..., k, array_ops.newaxis] * math_ops.square(s.multiplier))) elif isinstance(s, linop_diag_lib.LinearOperatorDiag): diag = add(diag, (p[..., k, array_ops.newaxis] * math_ops.square(s.diag_part()))) @@ -566,12 +688,13 @@ class VectorDiffeomixture(distribution_lib.Distribution): full = add(full, x) # We must now account for the fact that the base distribution might have a - # non-unity variance. Recall that `Cov(SX+m) = S.T Cov(X) S = S.T S Var(X)`. + # non-unity variance. Recall that, since X ~ iid Law(X_0), + # `Cov(SX+m) = S Cov(X) S.T = S S.T Diag(Var(X_0))`. # We can scale by `Var(X)` (vs `Cov(X)`) since X corresponds to `d` iid # samples from a scalar-event distribution. v = self.distribution.variance() - if scaled_identity is not None: - scaled_identity *= v + if scale_identity_multiplier is not None: + scale_identity_multiplier *= v if diag is not None: diag *= v[..., array_ops.newaxis] if full is not None: @@ -580,10 +703,10 @@ class VectorDiffeomixture(distribution_lib.Distribution): if diag_only: # Apparently we don't need the full matrix, just the diagonal. 
r = add(diag, full) - if r is None and scaled_identity is not None: + if r is None and scale_identity_multiplier is not None: ones = array_ops.ones(self.event_shape_tensor(), dtype=self.dtype) - return scaled_identity * ones - return add(r, scaled_identity) + return scale_identity_multiplier[..., array_ops.newaxis] * ones + return add(r, scale_identity_multiplier) # `None` indicates we don't know if the result is positive-definite. is_positive_definite = (True if all(aff.scale.is_positive_definite @@ -599,10 +722,10 @@ class VectorDiffeomixture(distribution_lib.Distribution): to_add.append(linop_full_lib.LinearOperatorFullMatrix( matrix=full, is_positive_definite=is_positive_definite)) - if scaled_identity is not None: + if scale_identity_multiplier is not None: to_add.append(linop_identity_lib.LinearOperatorScaledIdentity( num_rows=self.event_shape_tensor()[0], - multiplier=scaled_identity, + multiplier=scale_identity_multiplier, is_positive_definite=is_positive_definite)) return (linop_add_lib.add_operators(to_add)[0].to_dense() @@ -611,10 +734,9 @@ class VectorDiffeomixture(distribution_lib.Distribution): def _covariance_of_mean_given_quadrature_component(self, diag_only): square = math_ops.square if diag_only else vec_osquare - # Since we created logits to already be scaled, we can use exp which is - # slightly cheaper than `self.mixture_distribution.probs`. - p = math_ops.exp(self.mixture_distribution.logits) - + p = self._expand_mix_distribution_probs() + if not diag_only: + p = p[..., array_ops.newaxis, :] # Assuming event.ndims=1. 
m = self._expand_base_distribution_mean() cov_e_z_given_v = None @@ -638,17 +760,25 @@ class VectorDiffeomixture(distribution_lib.Distribution): m.set_shape(self.batch_shape.concatenate(self.event_shape)) return m - -def maybe_check_mix_param(param, name, expected_base_dtype, validate_args): - """Helper which checks validity of `mix_loc` and `mix_scale` init args.""" + def _expand_mix_distribution_probs(self): + p = self.mixture_distribution.probs # [B, deg] + deg = p.shape.with_rank_at_least(1)[-1].value + if deg is None: + deg = array_ops.shape(p)[-1] + event_ndims = self.event_shape.ndims + if event_ndims is None: + event_ndims = array_ops.shape(self.event_shape_tensor())[0] + expand_shape = array_ops.concat([ + self.mixture_distribution.batch_shape_tensor(), + array_ops.ones([event_ndims], dtype=dtypes.int32), + [deg], + ], axis=0) + return array_ops.reshape(p, shape=expand_shape) + + +def maybe_check_quadrature_param(param, name, validate_args): + """Helper which checks validity of `loc` and `scale` init args.""" with ops.name_scope(name="check_" + name, values=[param]): - param = ops.convert_to_tensor(param, dtype=expected_base_dtype, name=name) - - if param.dtype.base_dtype != expected_base_dtype: - raise TypeError( - "dtype mismatch; {}.base_dtype=\"{}\" is not \"{}\".".format( - name, param.dtype.base_dtype.name, expected_base_dtype.name)) - assertions = [] if param.shape.ndims is not None: if param.shape.ndims == 0: @@ -679,79 +809,84 @@ def maybe_check_mix_param(param, name, expected_base_dtype, validate_args): return param -def determine_batch_event_shapes(mix_loc, mix_scale, endpoint_affine): +def determine_batch_event_shapes(grid, endpoint_affine): """Helper to infer batch_shape and event_shape.""" with ops.name_scope(name="determine_batch_event_shapes"): - mix_batch_shape = distribution_util.prefer_static_broadcast_shape( - array_ops.shape(mix_loc, name="mix_loc_shape"), - array_ops.shape(mix_scale, name="mix_scale_shape")) - if 
isinstance(mix_batch_shape, tensor_shape.TensorShape): - mix_batch_shape = mix_batch_shape.with_rank_at_least(1)[:-1] - else: - s = static_value(mix_batch_shape) - if s is not None: - mix_batch_shape = ops.convert_to_tensor( - s[:-1], dtype=dtypes.int32, name="mix_batch_shape") - else: - mix_batch_shape = mix_batch_shape[:-1] - - # We broadcast with a 1D constant to automatically make the result a - # TensorShape if possible. - batch_shape = distribution_util.prefer_static_broadcast_shape( - mix_batch_shape, - constant_op.constant([], dtype=dtypes.int32, name="batch_shape")) - event_shape = constant_op.constant( - [], dtype=dtypes.int32, name="event_shape") + # grid # shape: [B, k, q] + # endpoint_affine # len=k, shape: [B, d, d] + batch_shape = grid.shape[:-2] + batch_shape_tensor = array_ops.shape(grid)[:-2] + event_shape = None + event_shape_tensor = None + + def _set_event_shape(shape, shape_tensor): + if event_shape is None: + return shape, shape_tensor + return (array_ops.broadcast_static_shape(event_shape, shape), + array_ops.broadcast_dynamic_shape( + event_shape_tensor, shape_tensor)) + for aff in endpoint_affine: - b, e = distribution_util.shapes_from_loc_and_scale(aff.shift, aff.scale) - if batch_shape is None: - batch_shape = distribution_util.prefer_static_broadcast_shape( - mix_batch_shape, b) - else: - batch_shape = distribution_util.prefer_static_broadcast_shape( - batch_shape, b) - event_shape = distribution_util.prefer_static_broadcast_shape( - event_shape, e) - if isinstance(batch_shape, tensor_shape.TensorShape): - batch_shape = ops.convert_to_tensor( - batch_shape.as_list(), dtype=dtypes.int32, name="batch_shape") - if isinstance(event_shape, tensor_shape.TensorShape): - event_shape = ops.convert_to_tensor( - event_shape.as_list(), dtype=dtypes.int32, name="event_shape") - return batch_shape, event_shape - - -def interpolate_loc(deg, interpolate_weight, loc): + if aff.shift is not None: + batch_shape = array_ops.broadcast_static_shape( + 
batch_shape, aff.shift.shape[:-1]) + batch_shape_tensor = array_ops.broadcast_dynamic_shape( + batch_shape_tensor, array_ops.shape(aff.shift)[:-1]) + event_shape, event_shape_tensor = _set_event_shape( + aff.shift.shape[-1:], array_ops.shape(aff.shift)[-1:]) + + if aff.scale is not None: + batch_shape = array_ops.broadcast_static_shape( + batch_shape, aff.scale.batch_shape) + batch_shape_tensor = array_ops.broadcast_dynamic_shape( + batch_shape_tensor, aff.scale.batch_shape_tensor()) + event_shape, event_shape_tensor = _set_event_shape( + tensor_shape.TensorShape([aff.scale.range_dimension]), + aff.scale.range_dimension_tensor()[array_ops.newaxis]) + + return batch_shape, batch_shape_tensor, event_shape, event_shape_tensor + + +def interpolate_loc(grid, loc): """Helper which interpolates between two locs.""" if len(loc) != 2: raise NotImplementedError("Currently only bimixtures are supported; " "len(scale)={} is not 2.".format(len(loc))) - with ops.name_scope("interpolate_loc", values=[interpolate_weight, loc]): + deg = grid.shape.with_rank_at_least(1)[-1].value + if deg is None: + raise ValueError("Num quadrature grid points must be known prior " + "to graph execution.") + with ops.name_scope("interpolate_loc", values=[grid, loc]): if loc is None or loc[0] is None and loc[1] is None: return [None]*deg - w = interpolate_weight[..., array_ops.newaxis, :] # shape: [B, 1, deg] + # shape: [B, 1, k, deg] + w = grid[..., array_ops.newaxis, :, :] loc = [x[..., array_ops.newaxis] # shape: [B, e, 1] if x is not None else None for x in loc] if loc[0] is None: - x = (1. 
- w) * loc[1] # shape: [B, e, deg] + x = w[..., 1, :] * loc[1] # shape: [B, e, deg] elif loc[1] is None: - x = w * loc[0] # shape: [B, e, deg] + x = w[..., 0, :] * loc[0] # shape: [B, e, deg] else: delta = loc[0] - loc[1] - x = w * delta + loc[1] # shape: [B, e, deg] + x = w[..., 0, :] * delta + loc[1] # shape: [B, e, deg] return [x[..., k] for k in range(deg)] # list(shape:[B, e]) -def interpolate_scale(deg, interpolate_weight, scale): +def interpolate_scale(grid, scale): """Helper which interpolates between two scales.""" if len(scale) != 2: raise NotImplementedError("Currently only bimixtures are supported; " "len(scale)={} is not 2.".format(len(scale))) - with ops.name_scope("interpolate_scale", values=[interpolate_weight]): + deg = grid.shape.with_rank_at_least(1)[-1].value + if deg is None: + raise ValueError("Num quadrature grid points must be known prior " + "to graph execution.") + with ops.name_scope("interpolate_scale", values=[grid]): return [linop_add_lib.add_operators([ - linop_scale(interpolate_weight[..., k], scale[0]), - linop_scale(1. 
- interpolate_weight[..., k], scale[1]), - ])[0] for k in range(deg)] + linop_scale(grid[..., k, q], s) + for k, s in enumerate(scale) + ])[0] for q in range(deg)] def linop_scale(w, op): @@ -791,39 +926,12 @@ def linop_scale(w, op): def concat_vectors(*args): """Concatenates input vectors, statically if possible.""" - args_ = [static_value(x) for x in args] + args_ = [distribution_util.static_value(x) for x in args] if any(vec is None for vec in args_): return array_ops.concat(args, axis=0) return [val for vec in args_ for val in vec] -def reduce_prod(x): - """Same as `math_ops.reduce_prod` but statically if possible.""" - x_ = static_value(x) - if x_ is not None: - return np.prod(x_, dtype=x.dtype.as_numpy_dtype) - return array_ops.reduce_prod(x) - - -def ndims_from_shape(shape): - """Returns `Tensor`'s `rank` implied by a `Tensor` shape.""" - if shape.shape.ndims not in (None, 1): - raise ValueError("input is not a valid shape: not 1D") - if not shape.dtype.is_integer: - raise TypeError("input is not a valid shape: wrong dtype") - if shape.shape.is_fully_defined(): - return shape.shape.as_list()[0] - return array_ops.shape(shape)[0] - - -def ndims(x): - """Returns rank, statically if possible.""" - x = ops.convert_to_tensor(x) - if x.shape.ndims is not None: - return x.shape.ndims - return array_ops.rank(x) - - def add(x, y): """Adds inputs; interprets `None` as zero.""" if x is None: @@ -836,3 +944,18 @@ def add(x, y): def vec_osquare(x): """Computes the outer-product of a (batch of) vector, i.e., x.T x.""" return x[..., :, array_ops.newaxis] * x[..., array_ops.newaxis, :] + + +def softmax(x, axis, name=None): + """Equivalent to tf.nn.softmax but works around b/70297725.""" + with ops.name_scope(name, "softmax", [x, axis]): + x = ops.convert_to_tensor(x, name="x") + ndims = (x.shape.ndims if x.shape.ndims is not None + else array_ops.rank(x, name="ndims")) + axis = ops.convert_to_tensor(axis, dtype=dtypes.int32, name="axis") + axis_ = 
tensor_util.constant_value(axis) + if axis_ is not None: + axis = np.int(ndims + axis_ if axis_ < 0 else axis_) + else: + axis = array_ops.where(axis < 0, ndims + axis, axis) + return nn_ops.softmax(x, axis=axis) diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py index 356d78b67a8107750f68f7f84d73d1231f5b2b03..526fe2d39aef9aed833b889de80e849c469435e7 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py @@ -89,14 +89,13 @@ class VectorExponentialDiag( #### Examples ```python - ds = tf.contrib.distributions - la = tf.linalg + tfd = tf.contrib.distributions # Initialize a single 2-variate VectorExponential, supported on # {(x, y) in R^2 : x > 0, y > 0}. # The first component has pdf exp{-x}, the second 0.5 exp{-x / 2} - vex = ds.VectorExponentialDiag(scale_diag=[1., 2.]) + vex = tfd.VectorExponentialDiag(scale_diag=[1., 2.]) # Compute the pdf of an`R^2` observation; return a scalar. vex.prob([3., 4.]).eval() # shape: [] @@ -107,7 +106,7 @@ class VectorExponentialDiag( scale_diag = [[1., 2, 3], [0.5, 1, 1.5]] # shape: [2, 3] - vex = ds.VectorExponentialDiag(loc, scale_diag) + vex = tfd.VectorExponentialDiag(loc, scale_diag) # Compute the pdf of two `R^3` observations; return a length-2 vector. 
x = [[1.9, 2.2, 3.1], diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py index b313a851b381e5b3a057fd17e6c2ef4eb0fc34f1..9d5fd9ac4178a1ae29b1ce32f304b22fd3d234dc 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py @@ -107,16 +107,15 @@ class VectorExponentialLinearOperator( #### Examples ```python - ds = tf.contrib.distributions - la = tf.linalg + tfd = tf.contrib.distributions # Initialize a single 2-variate VectorExponential, supported on # {(x, y) in R^2 : x > 0, y > 0}. mat = [[1.0, 0.1], [0.1, 1.0]] - vex = ds.VectorExponentialLinearOperator( - scale=la.LinearOperatorFullMatrix(mat)) + vex = tfd.VectorExponentialLinearOperator( + scale=tf.linalg.LinearOperatorFullMatrix(mat)) # Compute the pdf of an`R^2` observation; return a scalar. vex.prob([1., 2.]).eval() # shape: [] @@ -127,9 +126,9 @@ class VectorExponentialLinearOperator( scale_diag = [[1., 2, 3], [0.5, 1, 1.5]] # shape: [2, 3] - vex = ds.VectorExponentialLinearOperator( + vex = tfd.VectorExponentialLinearOperator( loc=mu, - scale=la.LinearOperatorDiag(scale_diag)) + scale=tf.linalg.LinearOperatorDiag(scale_diag)) # Compute the pdf of two `R^3` observations; return a length-2 vector. 
x = [[1.9, 2.2, 3.1], diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py index 0e3867809a820f49cfa7f5282c47f786626481a6..8dd983b750d9b39775e570800006011f4968f7f3 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py @@ -101,10 +101,10 @@ class VectorLaplaceDiag( #### Examples ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions # Initialize a single 2-variate VectorLaplace. - vla = ds.VectorLaplaceDiag( + vla = tfd.VectorLaplaceDiag( loc=[1., -1], scale_diag=[1, 2.]) @@ -118,7 +118,7 @@ class VectorLaplaceDiag( vla.prob([-1., 0]).eval() # shape: [] # Initialize a 3-batch, 2-variate scaled-identity VectorLaplace. - vla = ds.VectorLaplaceDiag( + vla = tfd.VectorLaplaceDiag( loc=[1., -1], scale_identity_multiplier=[1, 2., 3]) @@ -136,7 +136,7 @@ class VectorLaplaceDiag( vla.prob([-1., 0]).eval() # shape: [3] # Initialize a 2-batch of 3-variate VectorLaplace's. - vla = ds.VectorLaplaceDiag( + vla = tfd.VectorLaplaceDiag( loc=[[1., 2, 3], [11, 22, 33]] # shape: [2, 3] scale_diag=[[1., 2, 3], diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py index c7abdbb4caf9bee4cbd5991eb5d652f20dd0f8d1..ec485c95c15da2794b67d2699d2bdd9db97bb6c4 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py @@ -109,8 +109,7 @@ class VectorLaplaceLinearOperator( #### Examples ```python - ds = tf.contrib.distributions - la = tf.linalg + tfd = tf.contrib.distributions # Initialize a single 3-variate VectorLaplace with some desired covariance. 
mu = [1., 2, 3] @@ -124,9 +123,9 @@ class VectorLaplaceLinearOperator( # [ 0.1, -0.3, 0.4]]) # Divide scale by sqrt(2) so that the final covariance will be what we want. - vla = ds.VectorLaplaceLinearOperator( + vla = tfd.VectorLaplaceLinearOperator( loc=mu, - scale=la.LinearOperatorLowerTriangular(scale / tf.sqrt(2))) + scale=tf.linalg.LinearOperatorLowerTriangular(scale / tf.sqrt(2.))) # Covariance agrees with cholesky(cov) parameterization. vla.covariance().eval() @@ -143,9 +142,9 @@ class VectorLaplaceLinearOperator( scale_diag = [[1., 2, 3], [0.5, 1, 1.5]] # shape: [2, 3] - vla = ds.VectorLaplaceLinearOperator( + vla = tfd.VectorLaplaceLinearOperator( loc=mu, - scale=la.LinearOperatorDiag(scale_diag)) + scale=tf.linalg.LinearOperatorDiag(scale_diag)) # Compute the pdf of two `R^3` observations; return a length-2 vector. x = [[-0.9, 0, 0.1], diff --git a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py index 544a8710709a0afb56c6ae6f36d35de892e8e420..e1ccf116457a97261b9ce3965552764771d3bdd2 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py @@ -143,7 +143,7 @@ class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution): broadcastable with `event_shape`. distribution: `tf.Distribution`-like instance. Distribution from which `k` iid samples are used as input to transformation `F`. Default is - `ds.Normal(0., 1.)`. + `tf.distributions.Normal(loc=0., scale=1.)`. Must be a scalar-batch, scalar-event distribution. Typically `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is a function of non-trainable parameters. 
WARNING: If you backprop through diff --git a/tensorflow/contrib/distributions/python/ops/vector_student_t.py b/tensorflow/contrib/distributions/python/ops/vector_student_t.py index 29d41ab81c62d621c3c3533e1449341e9a085645..8c67647a618d22a58428d78865c4ebf7d98bdf9e 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_student_t.py +++ b/tensorflow/contrib/distributions/python/ops/vector_student_t.py @@ -91,14 +91,14 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution): Extra leading dimensions, if provided, allow for batches. ```python - ds = tf.contrib.distributions + tfd = tf.contrib.distributions # Initialize a single 3-variate vector Student's t-distribution. mu = [1., 2, 3] chol = [[1., 0, 0.], [1, 3, 0], [1, 2, 3]] - vt = ds.VectorStudentT(df=2, loc=mu, scale_tril=chol) + vt = tfd.VectorStudentT(df=2, loc=mu, scale_tril=chol) # Evaluate this on an observation in R^3, returning a scalar. vt.prob([-1., 0, 1]) @@ -107,7 +107,7 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution): mu = [[1., 2, 3], [11, 22, 33]] chol = ... # shape 2 x 3 x 3, lower triangular, positive diagonal. - vt = ds.VectorStudentT(loc=mu, scale_tril=chol) + vt = tfd.VectorStudentT(loc=mu, scale_tril=chol) # Evaluate this on a two observations, each in R^3, returning a length two # tensor. diff --git a/tensorflow/contrib/eager/README.md b/tensorflow/contrib/eager/README.md index dcc370cd00d5f93cd5b145a31fd58ef5041a86a8..9d2ca07c3a25fa7acb9b0f5806b763d9a57b51fa 100644 --- a/tensorflow/contrib/eager/README.md +++ b/tensorflow/contrib/eager/README.md @@ -41,28 +41,8 @@ support for distributed and multi-GPU training and CPU performance. ## Installation -Since eager execution is not yet part of a TensorFlow release, using it requires -either [building from source](https://www.tensorflow.org/install/install_sources) -or the latest nightly builds. 
The nightly builds are available as: - -- [`pip` packages](https://github.com/tensorflow/tensorflow/blob/master/README.md#installation) and - -- [docker](https://hub.docker.com/r/tensorflow/tensorflow/) images. - -For example, to run the latest nightly docker image: - -```sh -# If you have a GPU, use https://github.com/NVIDIA/nvidia-docker -nvidia-docker pull tensorflow/tensorflow:nightly-gpu -nvidia-docker run -it -p 8888:8888 tensorflow/tensorflow:nightly-gpu - -# If you do not have a GPU, use the CPU-only image -docker pull tensorflow/tensorflow:nightly -docker run -it -p 8888:8888 tensorflow/tensorflow:nightly -``` - -And then visit http://localhost:8888 in your browser for a Jupyter notebook -environment. Try out the notebooks below. +Eager execution is included in TensorFlow versions 1.5 and above. +Installation instructions at https://www.tensorflow.org/install/ ## Documentation @@ -76,3 +56,6 @@ For an introduction to eager execution in TensorFlow, see: ## Changelog - 2017/10/31: Initial preview release. +- 2017/12/01: Example of dynamic neural network: + [SPINN: Stack-augmented Parser-Interpreter Neural Network](https://arxiv.org/abs/1603.06021). + See [README.md](python/examples/spinn/README.md) for details. 
diff --git a/tensorflow/contrib/eager/proto/BUILD b/tensorflow/contrib/eager/proto/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..aedfec8924e7314addd22349c0576a84a58d9aa3 --- /dev/null +++ b/tensorflow/contrib/eager/proto/BUILD @@ -0,0 +1,24 @@ +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +tf_proto_library( + name = "checkpointable_object_graph_proto", + srcs = [ + "checkpointable_object_graph.proto", + ], + visibility = ["//tensorflow/contrib/eager/python:__subpackages__"], +) diff --git a/tensorflow/contrib/eager/proto/checkpointable_object_graph.proto b/tensorflow/contrib/eager/proto/checkpointable_object_graph.proto new file mode 100644 index 0000000000000000000000000000000000000000..4f71aec96a2c3edee8a32b4e14584bd56ef3d439 --- /dev/null +++ b/tensorflow/contrib/eager/proto/checkpointable_object_graph.proto @@ -0,0 +1,57 @@ +syntax = "proto3"; + +option cc_enable_arenas = true; + +package tensorflow.contrib.eager; + +// Prototype for an addition to BundleHeaderProto which saves extra information +// about the objects which own variables, allowing for more robust checkpoint +// loading into modified programs. + +message CheckpointableObjectGraph { + message Object { + message ObjectReference { + // An index into `CheckpointableObjectGraph.nodes`, indicating the object + // being referenced. + int32 node_id = 1; + // A user-provided name for the edge. + string local_name = 2; + } + + message VariableReference { + // A name for the variable which is unique within the object which owns + // it. Does not include a name_scope or variable_scope prefix. + string local_name = 1; + // The full name of the variable. 
Used to allow name-based loading of + // checkpoints which were saved using an object-based API. + string full_name = 2; + // The generated name of the variable in the checkpoint. + string checkpoint_key = 3; + } + + message SlotVariableReference { + // An index into `CheckpointableObjectGraph.nodes`, indicating the object + // which created the variable that this variable is slotting for. + int32 original_variable_node_id = 1; + // The local name of the variable being slotted for within the object that + // owns it. + string original_variable_local_name = 2; + // The name of the slot (e.g. "m"/"v"). + string slot_name = 3; + // The full name of the slot variable. Used to allow name-based loading of + // checkpoints which were saved using an object-based API. + string full_name = 4; + // The generated name of the variable in the checkpoint. + string checkpoint_key = 5; + } + + // Objects which this object depends on. + repeated ObjectReference children = 1; + // Non-slot variables owned by this object. + repeated VariableReference variables = 2; + // Slot variables owned by this object. 
+ repeated SlotVariableReference slot_variables = 3; + } + + repeated Object nodes = 1; +} diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index bf2e883bc53c3281ef89d1200f5a089305ef3e72..cfb38a1d26c41a3923da7c989244a3d53b6a496b 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -19,6 +19,8 @@ py_library( "//tensorflow/python:framework_test_lib", "//tensorflow/python:numerics", "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:script_ops", + "//tensorflow/python:template", "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//tensorflow/python/eager:backprop", @@ -50,7 +52,7 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:internal"], deps = [ - "//tensorflow/contrib/data/python/ops:prefetching_py", + "//tensorflow/contrib/data/python/ops:prefetching_ops", "//tensorflow/python:array_ops", "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:errors", @@ -67,6 +69,7 @@ cuda_py_test( srcs = ["datasets_test.py"], additional_deps = [ ":datasets", + "//tensorflow/contrib/lookup:lookup_py", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", @@ -103,37 +106,6 @@ cuda_py_test( ], ) -py_library( - name = "summary_writer", - srcs = ["summary_writer.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/summary:gen_summary_ops", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:init_ops", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:state_ops", - "//tensorflow/python:summary_op_util", - "//tensorflow/python:variable_scope", - "//tensorflow/python/eager:context", - ], -) - -cuda_py_test( - name = "summary_writer_test", - srcs = ["summary_writer_test.py"], - additional_deps = [ - ":summary_writer", - "//third_party/py/numpy", - "//tensorflow/core:protos_all_py", - 
"//tensorflow/python:constant_op", - "//tensorflow/python/eager:context", - "//tensorflow/python/eager:test", - ], -) - py_library( name = "metrics", srcs = [ @@ -232,6 +204,7 @@ py_test( srcs_version = "PY2AND3", deps = [ ":network", + "//tensorflow/contrib/layers:layers_py", "//tensorflow/python:constant_op", "//tensorflow/python:errors", "//tensorflow/python:framework_test_lib", @@ -246,6 +219,51 @@ py_test( ], ) +py_library( + name = "checkpointable", + srcs = ["checkpointable.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/contrib/eager/proto:checkpointable_object_graph_proto_py", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:io_ops", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:state_ops", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python/eager:context", + ], +) + +py_test( + name = "checkpointable_test", + srcs = ["checkpointable_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":checkpointable", + ":network", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:init_ops", + "//tensorflow/python:layers", + "//tensorflow/python:layers_base", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + "@six_archive//:six", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/eager/python/checkpointable.py b/tensorflow/contrib/eager/python/checkpointable.py new file mode 100644 index 
0000000000000000000000000000000000000000..896b38a7348e1fdd5a13b197e3ee34f5c4c5a22c --- /dev/null +++ b/tensorflow/contrib/eager/python/checkpointable.py @@ -0,0 +1,773 @@ +"""An object-local variable management scheme.""" +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import re +import weakref + +from tensorflow.contrib.eager.proto import checkpointable_object_graph_pb2 +from tensorflow.python.eager import context +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import io_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.training import optimizer as optimizer_lib +from tensorflow.python.training import saver as saver_lib +from tensorflow.python.training import slot_creator +from tensorflow.python.training import training + +_CheckpointableReference = collections.namedtuple( + "_CheckpointableReference", + [ + # The local name if explicitly specified, else None. 
+ "name", + # The Checkpointable object being referenced. + "ref" + ]) + +# Validation regular expression for the local names of Checkpointable +# objects. In particular, disallows "/" in names, and reserves dash-prefixed +# names (which are not valid Python identifiers, so we're not restricting the +# __setattr__ syntax that way). +_VALID_LOCAL_NAME = re.compile(r"^[A-Za-z0-9_.][A-Za-z0-9_.-]*$") + +# Keyword for identifying that the next bit of a checkpoint variable name is a +# slot name. May not be the local name of a checkpointable. Checkpoint names for +# slot variables look like: +# +# /<_OPTIMIZER_SLOTS_NAME>// +# +# Where is a full path from the checkpoint root to the +# variable being slotted for. +_OPTIMIZER_SLOTS_NAME = "-OPTIMIZER_SLOT" + + +def _assign_existing_variable(variable_to_restore, value_pointer): + """Set a variable from a _ValuePointer object.""" + base_type = variable_to_restore.dtype.base_dtype + with ops.colocate_with(variable_to_restore): + # TODO(allenl): Handle partitioned variables + value_to_restore, = io_ops.restore_v2( + prefix=value_pointer.save_path, + tensor_names=[value_pointer.checkpoint_key], + shape_and_slices=[""], + dtypes=[base_type], + name="checkpoint_initializer") + initializer_op = state_ops.assign(variable_to_restore, value_to_restore) + variable_to_restore._initializer_op = initializer_op # pylint:disable=protected-access + if value_pointer.session is not None: + value_pointer.session.run(initializer_op) + + +def _default_getter(name, shape, dtype, initializer=None, + partition_info=None, **kwargs): + """A pared-down version of get_variable which does not reuse variables.""" + dtype = dtypes.as_dtype(dtype) + shape_object = tensor_shape.as_shape(shape) + with ops.init_scope(): + if initializer is None: + initializer, initializing_from_value = ( + variable_scope._get_default_variable_store()._get_default_initializer( # pylint: disable=protected-access + name=name, shape=shape_object, dtype=dtype)) + else: + 
initializing_from_value = not callable(initializer) + # Same logic as get_variable + if initializing_from_value: + if shape is not None: + raise ValueError("If initializer is a constant, do not specify shape.") + initial_value = initializer + variable_dtype = None + else: + # Instantiate initializer if provided initializer is a type object. + if isinstance(initializer, type(init_ops.Initializer)): + initializer = initializer(dtype=dtype) + def initial_value(): + return initializer( + shape_object.as_list(), dtype=dtype, partition_info=partition_info) + variable_dtype = dtype.base_dtype + return resource_variable_ops.ResourceVariable( + initial_value=initial_value, + name=name, + dtype=variable_dtype, + **kwargs + ) + + +class Checkpointable(object): + """Manages variables and dependencies on other objects. + + To make reliable checkpoints, all `Checkpointable`s on which this object + depends must be registered in the constructor using `track_checkpointable` in + a deterministic order, and if possible they should be named. Variables may be + created using `add_variable` outside of the constructor and in any order, but + only these variables will be saved. + """ + + def __init__(self): + # A list of _CheckpointableReference objects. + self._checkpoint_dependencies = [] + # Maps names -> Checkpointable objects for named dependencies + self._dependency_names = {} + # Set of all tracked Checkpointables + self._already_tracked = set() + self._owned_variables = {} # local name -> variable object + self._deferred_restorations = {} # local name -> _VariableRestoration + # object + + def __setattr__(self, name, value): + """Support self.foo = checkpointable syntax. + + `self.foo = checkpointable` is equivalent to + `self.foo = self.track_checkpointable(checkpointable, name='foo')`. + + No new tracking if `value` is not a `Checkpointable`, or if `value` is + already being tracked (either because of an explicit `track_checkpointable` + or a previous `__setattr__`). 
+ + Args: + name: The name of the property being set. + value: The new value for the property. + """ + # Give child classes (e.g. Network) priority, then track only if the object + # hasn't been added to _already_tracked. + super(Checkpointable, self).__setattr__(name, value) + if (isinstance(value, Checkpointable) + and value not in self._already_tracked): + self.track_checkpointable(value, name=name) + + def add_variable(self, name, shape=None, dtype=dtypes.float32, + initializer=None, **kwargs): + """Create a new variable object to be saved with this `Checkpointable`. + + If the user has requested that this object or another `Checkpointable` which + depends on this object be restored from a checkpoint (deferred loading + before variable object creation), `initializer` may be ignored and the value + from the checkpoint used instead. + + Args: + name: A name for the variable. Must be unique within this object. + shape: The shape of the variable. + dtype: The data type of the variable. + initializer: The initializer to use. Ignored if deferred loading has been + requested. + **kwargs: Passed to the ResourceVariable constructor. + + Returns: + The new variable object. + + Raises: + ValueError: If the variable name is not unique. + RuntimeError: If __init__ has not been called. + """ + if not hasattr(self, "_owned_variables"): + raise RuntimeError("Need to call Checkpointable.__init__ before adding " + "variables.") + if name in self._owned_variables: + raise ValueError( + ("A variable named '%s' already exists in this Checkpointable, but " + "Checkpointable.add_variable called to create another with " + "that name. Variable names must be unique within a Checkpointable " + "object.") % (name,)) + if "getter" in kwargs: + # Allow the getter to be overridden, typically because there is a need for + # compatibility with some other variable creation mechanism. This should + # be relatively uncommon in user code. 
+ getter = kwargs.pop("getter") + else: + getter = _default_getter + deferred_restoration = self._deferred_restorations.pop(name, None) + if deferred_restoration is not None: + dtype = deferred_restoration.value_pointer.dtype + base_type = dtype.base_dtype + # TODO(allenl): Handle partitioned variables here too + with ops.init_scope(): + initializer, = io_ops.restore_v2( + prefix=deferred_restoration.value_pointer.save_path, + tensor_names=[deferred_restoration.value_pointer.checkpoint_key], + shape_and_slices=[""], + dtypes=[base_type], + name="checkpoint_initializer") + # We need to un-set the shape so get_variable doesn't complain, but we + # also need to set the static shape information on the initializer if + # possible so we don't get a variable with an unknown shape. + initializer.set_shape(shape) + # Un-set shape since we're using a constant initializer + shape = None + + new_variable = getter( + name=name, shape=shape, dtype=dtype, initializer=initializer, **kwargs) + if deferred_restoration is not None: + if deferred_restoration.value_pointer.session is not None: + deferred_restoration.value_pointer.session.run(new_variable.initializer) + for slot_restoration in deferred_restoration.slot_restorations: + strong_ref = slot_restoration.optimizer_ref() + if strong_ref is None: + # If the optimizer object has been garbage collected, there's no need + # to create the slot variable. + continue + strong_ref._process_slot_restoration( # pylint: disable=protected-access + slot_restoration, new_variable) + self._owned_variables[name] = new_variable + return new_variable + + def track_checkpointable(self, checkpointable, name): + """Declare a dependency on another `Checkpointable` object. + + Indicates that checkpoints for this object should include variables from + `checkpointable`. + + Variables in a checkpoint are mapped to `Checkpointable`s based on names. 
To + avoid breaking existing checkpoints when modifying a class, neither variable + names nor dependency names (the names passed to `track_checkpointable`) may + change. + + Args: + checkpointable: A `Checkpointable` which this object depends on. + name: A local name for `checkpointable`, used for loading checkpoints into + the correct objects. Python 2 identifiers are valid names, with the + addition of leading numerals, periods anywhere, and non-leading dashes. + Specifically names must match the regular expression + `^[A-Za-z0-9_.][A-Za-z0-9_.-]*$`. + + Returns: + `checkpointable`, for convenience when declaring a dependency and + assigning to a member variable in one statement. + + Raises: + RuntimeError: If __init__ was not called. + TypeError: If `checkpointable` does not inherit from `Checkpointable`. + ValueError: For invalid names. + """ + if not hasattr(self, "_checkpoint_dependencies"): + raise RuntimeError("Need to call Checkpointable.__init__ before calling " + "Checkpointable.track_checkpointable().") + if not isinstance(checkpointable, Checkpointable): + raise TypeError( + ("Checkpointable.track_checkpointable() passed type %s, not a " + "Checkpointable.") % (type(checkpointable),)) + if not _VALID_LOCAL_NAME.match(name): + raise ValueError( + ("Checkpointable names must match the regular expression '%s', but " + "got an invalid name '%s' instead.") % (_VALID_LOCAL_NAME.pattern, + name)) + if (name in self._dependency_names + and self._dependency_names[name] is not checkpointable): + raise ValueError( + ("Called Checkpointable.track_checkpointable() with name='%s', but " + "a Checkpointable with this name is already declared as a " + "dependency. 
Names must be unique.") % (name,)) + self._dependency_names[name] = checkpointable + self._checkpoint_dependencies.append( + _CheckpointableReference(name=name, ref=checkpointable)) + self._already_tracked.add(checkpointable) + return checkpointable + + def _process_restoration(self, restoration): + """Restore a variable and its slot variables (may be deferred).""" + variable_to_restore = self._owned_variables.get(restoration.name, None) + if variable_to_restore is not None: + # This variable already exists, so just do an assignment for this and any + # slot variables which depend on it. + _assign_existing_variable( + variable_to_restore, value_pointer=restoration.value_pointer) + for slot_restoration in restoration.slot_restorations: + strong_ref = slot_restoration.optimizer_ref() + if strong_ref is None: + continue + strong_ref._process_slot_restoration( # pylint: disable=protected-access + slot_restoration, variable_to_restore) + else: + # Save this restoration for later. This intentionally overwrites any + # previous deferred restorations, since that gives the same semantics as + # direct assignment. 
+ self._deferred_restorations[restoration.name] = restoration + + def _process_slot_restoration(self, slot_restoration, variable): + """Restore a slot variable's value (creating it if necessary).""" + # TODO(allenl): Move this to Optimizer + assert isinstance(self, optimizer_lib.Optimizer) + named_slots = self._slot_dict(slot_restoration.slot_name) + variable_key = optimizer_lib._var_key(variable) # pylint: disable=protected-access + existing_slot_variable = named_slots.get(variable_key, None) + if existing_slot_variable is None: + base_dtype = slot_restoration.value_pointer.dtype.base_dtype + initializer, = io_ops.restore_v2( + prefix=slot_restoration.value_pointer.save_path, + tensor_names=[slot_restoration.value_pointer.checkpoint_key], + shape_and_slices=[""], + dtypes=[base_dtype], + name="checkpoint_initializer") + new_slot_variable = slot_creator.create_slot(variable, initializer, + slot_restoration.slot_name) + if slot_restoration.value_pointer.session is not None: + slot_restoration.value_pointer.session.run( + new_slot_variable.initializer) + named_slots[variable_key] = new_slot_variable + else: + _assign_existing_variable( + existing_slot_variable, value_pointer=slot_restoration.value_pointer) + + @property + def checkpoint_dependencies(self): + """Other `Checkpointable` objects on which this object depends.""" + return self._checkpoint_dependencies + + +def _breadth_first_checkpointable_traversal(root_checkpointable): + """Find shortest paths to all variables owned by dependencies of root.""" + bfs_sorted = [] + root_checkpointable_reference = _CheckpointableReference( + name=None, ref=root_checkpointable) + to_visit = collections.deque([root_checkpointable_reference]) + path_to_root = {root_checkpointable_reference: ()} + while to_visit: + current_checkpointable = to_visit.popleft() + bfs_sorted.append(current_checkpointable) + for child_checkpointable in ( + current_checkpointable.ref.checkpoint_dependencies): + if child_checkpointable not in 
path_to_root: + path_to_root[child_checkpointable] = ( + path_to_root[current_checkpointable] + (child_checkpointable,)) + to_visit.append(child_checkpointable) + return bfs_sorted, path_to_root + + +def _object_prefix_from_path(path_to_root): + return "/".join( + (checkpointable.name for checkpointable in path_to_root)) + + +def _escape_variable_name(variable_name): + # We need to support slashes in variable names for compatibility, since this + # naming scheme is being patched in to things like Layer.add_variable where + # slashes were previously accepted. We also want to use slashes to indicate + # edges traversed to reach the variable, so we escape forward slashes in + # variable names. + return variable_name.replace("_S_", "_S_.").replace(r"/", r"_S__") + + +def _variable_naming_for_object(path_to_root): + """Make a function for naming variables in an object.""" + # Name non-slot variables: + # + # / + # + # is not necessarily unique, but this is fine since we also + # save the graph of `Checkpointable`s with the checkpoint. Even if this path + # no longer exists because of a change in the Python program, we can look up + # the `Checkpointable` which owns the variable in the checkpoint's graph and + # use another path if one still exists. + + object_prefix = _object_prefix_from_path(path_to_root) + if object_prefix: + object_prefix += "/" + + def _name_single_variable(local_name): + """Names a variable within an object.""" + return object_prefix + _escape_variable_name(local_name) + + return _name_single_variable + + +def _slot_variable_naming_for_optimizer(optimizer, path_to_root): + """Make a function for naming slot variables in an optimizer.""" + # Name slot variables: + # + # /<_OPTIMIZER_SLOTS_NAME>// + # + # where is exactly the checkpoint name used for the original + # variable, including the path from the checkpoint root and the local name in + # the object which owns it. 
Note that we only save slot variables if the + # variable it's slotting for is also being saved. + + optimizer_identifier = "/%s/%s/" % (_OPTIMIZER_SLOTS_NAME, + _object_prefix_from_path(path_to_root)) + + def _name_slot_variable(variable_path, slot_name): + """With an optimizer specified, name a slot variable.""" + + if not _VALID_LOCAL_NAME.match(slot_name): + # Slot variable names include the name of the slot. We need to + # validate that part of the name to be sure that the checkpoint name + # is a valid name scope name. + raise ValueError( + ("Could not save slot variables for optimizer %s, because its " + "slot name has invalid characters (got '%s', was expecting it " + "to match the regular expression '%s').") % + (optimizer, slot_name, _VALID_LOCAL_NAME.pattern)) + + return variable_path + optimizer_identifier + slot_name + + return _name_slot_variable + + +def _serialize_non_slot_variables(checkpointable_objects, path_to_root, + object_graph_proto): + """Name non-slot variables and add them to `object_graph_proto`.""" + named_variables = {} + non_slot_variables = [] + checkpoint_node_ids = {} + + for checkpoint_id, checkpointable in enumerate(checkpointable_objects): + checkpoint_node_ids[checkpointable] = checkpoint_id + + for checkpoint_id, checkpointable in enumerate(checkpointable_objects): + naming_scheme = _variable_naming_for_object(path_to_root[checkpointable]) + object_proto = object_graph_proto.nodes.add() + for (local_name, owned_variable) in sorted( + checkpointable.ref._owned_variables.items(), # pylint: disable=protected-access + key=lambda x: x[0]): + variable_name = naming_scheme(local_name) + named_variables[variable_name] = owned_variable + non_slot_variables.append(( + variable_name, # The variable's full checkpoint name + owned_variable, # The variable object + local_name, # The variable's local name + checkpoint_id)) # The checkpoint ID of the node which owns this + # variable. 
+ variable_proto = object_proto.variables.add() + variable_proto.local_name = local_name + variable_proto.checkpoint_key = variable_name + # Figure out the name-based Saver's name for this variable. + saver_dict = saver_lib.BaseSaverBuilder.OpListToDict( + [owned_variable], convert_variable_to_tensor=False) + variable_full_name, = saver_dict.keys() + variable_proto.full_name = variable_full_name + + for child in checkpointable.ref.checkpoint_dependencies: + child_proto = object_proto.children.add() + child_proto.node_id = checkpoint_node_ids[child] + child_proto.local_name = child.name + return named_variables, non_slot_variables + + +def _serialize_slot_variables(checkpointable_objects, path_to_root, + non_slot_variables, object_graph_proto): + """Name slot variables and add them to `object_graph_proto`.""" + named_slot_variables = {} + for optimizer_checkpoint_id, checkpointable_ref in enumerate( + checkpointable_objects): + if isinstance(checkpointable_ref.ref, optimizer_lib.Optimizer): + optimizer_object_proto = object_graph_proto.nodes[optimizer_checkpoint_id] + naming_scheme = _slot_variable_naming_for_optimizer( + optimizer=checkpointable_ref.ref, + path_to_root=path_to_root[checkpointable_ref]) + slot_names = checkpointable_ref.ref.get_slot_names() + for (variable_path, original_variable, original_variable_local_name, + original_node_checkpoint_id) in non_slot_variables: + for slot_name in slot_names: + slot_variable = checkpointable_ref.ref.get_slot( + original_variable, slot_name) + if slot_variable is not None: + checkpoint_name = naming_scheme( + variable_path=variable_path, slot_name=slot_name) + named_slot_variables[checkpoint_name] = slot_variable + slot_variable_proto = optimizer_object_proto.slot_variables.add() + slot_variable_proto.slot_name = slot_name + slot_variable_proto.checkpoint_key = checkpoint_name + # Figure out the name-based Saver's name for this variable. 
+ saver_dict = saver_lib.BaseSaverBuilder.OpListToDict( + [slot_variable], convert_variable_to_tensor=False) + slot_variable_full_name, = saver_dict.keys() + slot_variable_proto.full_name = slot_variable_full_name + slot_variable_proto.original_variable_local_name = ( + original_variable_local_name) + slot_variable_proto.original_variable_node_id = ( + original_node_checkpoint_id) + return named_slot_variables + + +# TODO(allenl): Convenience utility for saving multiple objects (i.e. construct +# a root Checkpointable if passed a list of Checkpointables). +def _serialize_object_graph(root_checkpointable): + """Determine checkpoint keys for variables and build a serialized graph. + + Non-slot variables are keyed based on a shortest path from the root saveable + to the object which owns the variable (i.e. the one which called + `Checkpointable.add_variable` to create it). + + Slot variables are keyed based on a shortest path to the variable being + slotted for, a shortest path to their optimizer, and the slot name. + + Args: + root_checkpointable: A `Checkpointable` object whose variables (including + the variables of dependencies, recursively) should be saved. + + Returns: + A tuple of (named_variables, object_graph_proto): + named_variables: A dictionary mapping names to variable objects. + object_graph_proto: A CheckpointableObjectGraph protocol buffer containing + the serialized object graph and variable references. + + Raises: + ValueError: If there are invalid characters in an optimizer's slot names. + """ + checkpointable_objects, path_to_root = ( + _breadth_first_checkpointable_traversal(root_checkpointable)) + object_graph_proto = ( + checkpointable_object_graph_pb2.CheckpointableObjectGraph()) + + # Gather non-slot variables. + named_variables, non_slot_variables = _serialize_non_slot_variables( + checkpointable_objects, path_to_root, object_graph_proto) + + # Gather slot variables which are associated with variables gathered above. 
+ named_slot_variables = _serialize_slot_variables( + checkpointable_objects, path_to_root, non_slot_variables, + object_graph_proto) + + named_variables.update(named_slot_variables) + return named_variables, object_graph_proto + + +def _set_reference(reference_proto_table, key, checkpointable, parent, + object_id_map): + """Record a checkpoint<->object correspondence, with error checking. + + Args: + reference_proto_table: Map from names or numbers to `ObjectReference` protos + within the parent object. + key: Either a numeric or string identifier for the reference. + checkpointable: The object to record a correspondence for. + parent: The parent Python object, for creating a useful error message. + object_id_map: The map from `node_id` to Python object in which to record + the reference. + Returns: + The `node_id` of the Object proto corresponding to the specified Python + object. + Raises: + AssertionError: If another object is already bound to the `Object` proto. + """ + reference_proto = reference_proto_table[key] + set_reference = object_id_map.setdefault(reference_proto.node_id, + checkpointable) + if set_reference is not checkpointable: + raise AssertionError( + ("Unable to load the checkpoint into this object graph. Either " + "the Checkpointable object references in the Python program " + "have changed in an incompatible way, or the checkpoint was " + "generated in an incompatible program.\n\nTwo checkpoint " + "references (one being '%s' in %s) resolved to different " + "objects (%s and %s).") % (key, parent, set_reference, + checkpointable)) + return reference_proto.node_id + + +def _checkpoint_object_id_map(root_checkpointable, object_graph_proto): + """Match a checkpointed object graph to a Python object graph. + + Args: + root_checkpointable: A Checkpointable object. + object_graph_proto: A CheckpointableObjectGraph protocol buffer representing + a serialized object graph. 
+ Returns: + A dictionary mapping from checkpoint node ids (indices into + `object_graph_proto.nodes`) to `Checkpointable` objects which are + dependencies of `root_checkpointable`. + """ + node_list = object_graph_proto.nodes + # Queue of (checkpointable object, node id) + to_visit = collections.deque([(root_checkpointable, 0)]) + object_id_map = {0: root_checkpointable} + seen = set() + while to_visit: + checkpointable, node_id = to_visit.popleft() + object_proto = node_list[node_id] + named_children = {} + for child_reference in object_proto.children: + if child_reference.local_name: + named_children[child_reference.local_name] = child_reference + else: + raise AssertionError( + ("The checkpointed object graph contains a reference without " + "a name (corrupted?). The reference was from the node %s.") + % (object_proto,)) + + for checkpointable_reference in checkpointable._checkpoint_dependencies: # pylint: disable=protected-access + child_node_id = _set_reference( + reference_proto_table=named_children, + key=checkpointable_reference.name, + checkpointable=checkpointable_reference.ref, + parent=checkpointable, + object_id_map=object_id_map) + if child_node_id not in seen: + seen.add(child_node_id) + to_visit.append((checkpointable_reference.ref, child_node_id)) + + return object_id_map + + +_ValuePointer = collections.namedtuple( + "_ValuePointer", + [ + # Information needed to look up the value to restore. + "save_path", + "checkpoint_key", + "dtype", + # The session to use when restoring (None when executing eagerly) + "session", + ]) + +_SlotVariableRestoration = collections.namedtuple( + "_SlotVariableRestoration", + [ + # A weak reference to the Optimizer object + "optimizer_ref", + # The slot name + "slot_name", + # The _ValuePointer to use when restoring + "value_pointer", + ]) + +_VariableRestoration = collections.namedtuple( + "_VariableRestoration", + [ + # The variable's (local) name. 
+ "name", + # _SlotVariableRestoration objects indicating slot variables which + # should be created once this variable has been restored. + "slot_restorations", + # The _ValuePointer to use when restoring + "value_pointer", + ]) + + +def _gather_restorations(object_graph_proto, save_path, object_id_map, + dtype_map, session): + """Iterate over variables to restore, matching with Checkpointable objects.""" + variable_to_slot_restorations = {} + for node_id, node in enumerate(object_graph_proto.nodes): + for slot_variable in node.slot_variables: + original_variable_key = (slot_variable.original_variable_node_id, + slot_variable.original_variable_local_name) + variable_to_slot_restorations.setdefault( + original_variable_key, []).append( + _SlotVariableRestoration( + optimizer_ref=weakref.ref(object_id_map[node_id]), + slot_name=slot_variable.slot_name, + value_pointer=_ValuePointer( + save_path=save_path, + checkpoint_key=slot_variable.checkpoint_key, + dtype=dtype_map[slot_variable.checkpoint_key], + session=session))) + + for node_id, node in enumerate(object_graph_proto.nodes): + for variable in node.variables: + slots_key = (node_id, variable.local_name) + variable_restore = _VariableRestoration( + name=variable.local_name, + slot_restorations=variable_to_slot_restorations.get(slots_key, []), + value_pointer=_ValuePointer( + save_path=save_path, + checkpoint_key=variable.checkpoint_key, + dtype=dtype_map[variable.checkpoint_key], + session=session)) + yield variable_restore, object_id_map[node_id] + + +def save(file_prefix, root_checkpointable, global_step=None, session=None): + """Save a training checkpoint. + + Args: + file_prefix: A prefix to use for the checkpoint filenames + (/path/to/directory/and_a_prefix). Names are generated based on this + prefix and the global step, if provided. + root_checkpointable: A Checkpointable object to save. The checkpoint + includes variables created by this object and any Checkpointable objects + it depends on. 
+ global_step: An integer variable or Tensor, used to number + checkpoints. Typically this value is saved along with other variables in + training checkpoints, which will happen automatically if it was created by + `root_checkpointable` or one of its dependencies (via + `Checkpointable.add_variable`). + session: The session to evaluate variables in. Ignored when executing + eagerly. If not provided when graph building, the default session is used. + + Returns: + The full path to the checkpoint. + + Currently also returns the serialized object graph proto, but that will go + away once it's saved with the checkpoint. + """ + named_variables, serialized_graph = _serialize_object_graph( + root_checkpointable) + if context.in_graph_mode(): + if session is None: + session = ops.get_default_session() + else: + session = None + with ops.device("/device:CPU:0"): + save_path = saver_lib.Saver(var_list=named_variables).save( + sess=session, + save_path=file_prefix, + write_meta_graph=False, + global_step=global_step) + # TODO(allenl): Save the graph with the checkpoint, then returning it and + # taking it as an argument to restore won't be necessary. + return serialized_graph, save_path + + +# NOTE: Will be restore(file_prefix, root_checkpointable) once the object graph +# is saved with the checkpoint. +def restore(save_path, root_checkpointable, object_graph_proto, session=None): + """Restore a training checkpoint. + + Restores the values of variables created with `Checkpointable.add_variable` in + the dependency graph of `root_checkpointable`. Either assigns values + immediately (if variables to restore have been created already), or defers + restoration until the variables are created. + + When building a graph, restorations are executed in the default session if + `session` is `None`. Variable initializers read checkpointed values. + + Args: + save_path: The path to the checkpoint, as returned by `save` or + `tf.train.latest_checkpoint`. 
If None (as when there is no latest + checkpoint for `tf.train.latest_checkpoint` to return), does nothing. + root_checkpointable: The root of the object graph to restore. Variables to + restore need not have been created yet, but all dependencies on other + Checkpointable objects should already be declared. Objects in the + dependency graph are matched to objects in the checkpointed graph, and + matching objects have their variables restored (or the checkpointed values + saved for eventual restoration when the variable is created). + object_graph_proto: (Temporary) the checkpointed object graph. This will + eventually be saved with the checkpoint, and will not be part of the final + API. + session: The session to evaluate assignment ops in. Ignored when executing + eagerly. If not provided when graph building, the default session is used. + """ + if save_path is None: + return + object_id_map = _checkpoint_object_id_map(root_checkpointable, + object_graph_proto) + reader = training.NewCheckpointReader(save_path) + dtype_map = reader.get_variable_to_dtype_map() + if context.in_graph_mode(): + if session is None: + session = ops.get_default_session() + else: + session = None + for restoration, checkpointable in _gather_restorations( + object_graph_proto, save_path, object_id_map, dtype_map, session=session): + checkpointable._process_restoration(restoration) # pylint: disable=protected-access + diff --git a/tensorflow/contrib/eager/python/checkpointable_test.py b/tensorflow/contrib/eager/python/checkpointable_test.py new file mode 100644 index 0000000000000000000000000000000000000000..f7bc155decbb574ddd4b53190da3c3b3ee9b6a4e --- /dev/null +++ b/tensorflow/contrib/eager/python/checkpointable_test.py @@ -0,0 +1,497 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools +import os + +import six + +from tensorflow.contrib.eager.python import checkpointable +from tensorflow.contrib.eager.python import network as network_lib +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.layers import base +from tensorflow.python.layers import core +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.training import adam +from tensorflow.python.training import saver as core_saver +from tensorflow.python.training import training_util + + +class CheckpointableDenseLayer(core.Dense, checkpointable.Checkpointable): + + def __init__(self, *args, **kwargs): + checkpointable.Checkpointable.__init__(self) + core.Dense.__init__(self, *args, **kwargs) + + def add_variable(self, name, shape, **kwargs): + # Calls both Checkpointable.add_variable and Layer.add_variable. 
Eventually + # Layer.add_variable should inherit from Checkpointable and simply call + # super and then do post-processing. + return checkpointable.Checkpointable.add_variable( + self, + name=name, + shape=shape, + getter=functools.partial(core.Dense.add_variable, self), + **kwargs) + + +# pylint: disable=not-callable +class CheckpointableNetwork(network_lib.Network, checkpointable.Checkpointable): + + def __init__(self): + network_lib.Network.__init__(self) + checkpointable.Checkpointable.__init__(self) + + def __setattr__(self, name, value): + if isinstance(value, base.Layer) and value not in self._already_tracked: + self.track_layer(value, name=name) + # Checkpointable is next in the method resolution order, so this will catch + # Checkpointable objects which aren't Layers. + super(CheckpointableNetwork, self).__setattr__(name, value) + + def track_layer(self, layer, name): + self.track_checkpointable(layer, name=name) + return super(CheckpointableNetwork, self).track_layer(layer) + + +class CheckpointableAdam(adam.AdamOptimizer, checkpointable.Checkpointable): + + def __init__(self, *args, **kwargs): + checkpointable.Checkpointable.__init__(self) + adam.AdamOptimizer.__init__(self, *args, **kwargs) + + # NOTE: Copied from Optimizer with modifications to use add_variable + # for non-slot variables. These contortions are necessary to maintain + # checkpoint compatibility with variable.name based saving. + # TODO(allenl): Make this cleaner. 
+ def _create_non_slot_variable(self, initial_value, name, colocate_with): + """Add an extra variable, not associated with a slot.""" + if context.in_graph_mode(): + graph = colocate_with.graph + else: + graph = None + + key = (name, graph) + v = self._non_slot_dict.get(key, None) + if v is None: + with ops.colocate_with(colocate_with): + def _variable_getter(name, shape, dtype, initializer): + del shape, dtype # not used, but there for compatibility + return variable_scope.variable( + name=name, initial_value=initializer, trainable=False) + + initial_value = ops.convert_to_tensor(initial_value) + v = self.add_variable( + name=name, + shape=initial_value.get_shape(), + initializer=initial_value, + getter=_variable_getter) + + self._non_slot_dict[key] = v + + return v + + +class NonLayerCheckpointable(checkpointable.Checkpointable): + + def __init__(self): + super(NonLayerCheckpointable, self).__init__() + self.a_variable = self.add_variable(name="a_variable", shape=[]) + + +class MyNetwork(CheckpointableNetwork): + """A concrete Network for testing.""" + + def __init__(self): + super(MyNetwork, self).__init__() + self._named_dense = CheckpointableDenseLayer(1, use_bias=True) + self._via_track_layer = self.track_layer( + CheckpointableDenseLayer(1, use_bias=False), name="via_track_layer") + # We can still track Checkpointables which aren't Layers. + self._non_layer = NonLayerCheckpointable() + + def call(self, values): + return self._via_track_layer(self._named_dense(values)) + + +class Root(checkpointable.Checkpointable): + """A stand-in for a Trainer class.""" + + def __init__(self, optimizer, network): + super(Root, self).__init__() + self._optimizer = optimizer + self._network = self.track_checkpointable(network, "network") + self._global_step = None + + @property + def global_step(self): + if self._global_step is None: + # Get the default create_global_step utility to actually call + # self.add_variable, by setting a custom creator. 
+ def _owned_variable_as_creator( + next_creator, initial_value, **kwargs): + def _creator_as_getter(initializer, **kwargs): + return next_creator(initial_value=initializer, **kwargs) + return self.add_variable( + getter=_creator_as_getter, initializer=initial_value, shape=[], + **kwargs) + + with variable_scope.variable_creator_scope( + _owned_variable_as_creator): + self._global_step = training_util.create_global_step() + return self._global_step + + +class InterfaceTests(test.TestCase): + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testAddVariable(self): + obj = NonLayerCheckpointable() + with self.assertRaisesRegexp(ValueError, "do not specify shape"): + obj.add_variable( + name="shape_specified_twice", shape=[], initializer=1) + constant_initializer = obj.add_variable( + name="constant_initializer", initializer=1) + with variable_scope.variable_scope("some_variable_scope"): + ones_initializer = obj.add_variable( + name="ones_initializer", + shape=[2], + initializer=init_ops.ones_initializer(dtype=dtypes.float32)) + bare_initializer = obj.add_variable( + name="bare_initializer", + shape=[2, 2], + dtype=dtypes.float64, + initializer=init_ops.zeros_initializer) + + # Even in graph mode, there are no naming conflicts between objects, only + # naming conflicts within an object. + other_duplicate = resource_variable_ops.ResourceVariable( + name="duplicate", initial_value=1.) 
+ duplicate = obj.add_variable(name="duplicate", shape=[]) + with self.assertRaisesRegexp(ValueError, "'duplicate' already exists"): + obj.add_variable(name="duplicate", shape=[]) + + if context.in_graph_mode(): + self.evaluate(variables.global_variables_initializer()) + self.assertEqual("constant_initializer:0", constant_initializer.name) + self.assertEqual(1, self.evaluate(constant_initializer)) + self.assertEqual("some_variable_scope/ones_initializer:0", + ones_initializer.name) + self.assertAllEqual([1, 1], self.evaluate(ones_initializer)) + self.assertAllEqual([[0., 0.], + [0., 0.]], self.evaluate(bare_initializer)) + self.assertEqual("a_variable:0", obj.a_variable.name) + self.assertEqual("duplicate:0", other_duplicate.name) + if context.in_graph_mode(): + # The .name attribute may be globally influenced, but the checkpoint name + # won't be (tested below). + self.assertEqual("duplicate_1:0", duplicate.name) + else: + # When executing eagerly, there's no uniquification of variable names. The + # checkpoint name will be the same. + self.assertEqual("duplicate:0", duplicate.name) + named_variables, _ = checkpointable._serialize_object_graph(obj) + expected_checkpoint_names = ( + "a_variable", + "bare_initializer", + "constant_initializer", + "duplicate", + "ones_initializer", + ) + six.assertCountEqual( + self, expected_checkpoint_names, named_variables.keys()) + + def testInitNotCalled(self): + + class NoInit(checkpointable.Checkpointable): + + def __init__(self): + pass + + with self.assertRaisesRegexp(RuntimeError, "__init__"): + NoInit().add_variable("var", shape=[]) + + +class CheckpointingTests(test.TestCase): + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testNamingWithOptimizer(self): + input_value = constant_op.constant([[3.]]) + network = MyNetwork() + # A nuisance Network using the same optimizer. Its slot variables should not + # go in the checkpoint, since it is never depended on. 
+ other_network = MyNetwork() + optimizer = CheckpointableAdam(0.001) + root_checkpointable = Root(optimizer=optimizer, network=network) + if context.in_eager_mode(): + optimizer.minimize( + lambda: network(input_value), + global_step=root_checkpointable.global_step) + optimizer.minimize( + lambda: other_network(input_value), + global_step=root_checkpointable.global_step) + else: + train_op = optimizer.minimize( + network(input_value), global_step=root_checkpointable.global_step) + optimizer.minimize( + other_network(input_value), + global_step=root_checkpointable.global_step) + self.evaluate(variables.global_variables_initializer()) + self.evaluate(train_op) + named_variables, serialized_graph = checkpointable._serialize_object_graph( + root_checkpointable) + expected_checkpoint_names = ( + # Created in the root node, so no prefix. + "global_step", + # No name provided to track_checkpointable(), so the position is used + # instead (one-based). + "network/via_track_layer/kernel", + # track_checkpointable() with a name provided, so that's used + "network/_named_dense/kernel", + "network/_named_dense/bias", + # non-Layer dependency of the network + "network/_non_layer/a_variable", + # The optimizer creates two non-slot variables + "_optimizer/beta1_power", + "_optimizer/beta2_power", + # Slot variables + "network/via_track_layer/kernel/-OPTIMIZER_SLOT/_optimizer/m", + "network/via_track_layer/kernel/-OPTIMIZER_SLOT/_optimizer/v", + "network/_named_dense/kernel/-OPTIMIZER_SLOT/_optimizer/m", + "network/_named_dense/kernel/-OPTIMIZER_SLOT/_optimizer/v", + "network/_named_dense/bias/-OPTIMIZER_SLOT/_optimizer/m", + "network/_named_dense/bias/-OPTIMIZER_SLOT/_optimizer/v", + ) + six.assertCountEqual(self, expected_checkpoint_names, + named_variables.keys()) + # Check that we've mapped to the right variable objects (not exhaustive) + self.assertEqual("global_step:0", named_variables["global_step"].name) + 
self.assertEqual("my_network/checkpointable_dense_layer_1/kernel:0", + named_variables["network/via_track_layer/kernel"].name) + self.assertEqual("my_network/checkpointable_dense_layer/kernel:0", + named_variables["network/_named_dense/kernel"].name) + self.assertEqual("beta1_power:0", + named_variables["_optimizer/beta1_power"].name) + self.assertEqual("beta2_power:0", + named_variables["_optimizer/beta2_power"].name) + # Spot check the generated protocol buffers. + self.assertEqual("_optimizer", + serialized_graph.nodes[0].children[0].local_name) + optimizer_node = serialized_graph.nodes[serialized_graph.nodes[0].children[ + 0].node_id] + self.assertEqual("beta1_power", optimizer_node.variables[0].local_name) + self.assertEqual("beta1_power", optimizer_node.variables[0].full_name) + # Variable ordering is arbitrary but deterministic (alphabetized) + self.assertEqual( + "bias", optimizer_node.slot_variables[0].original_variable_local_name) + original_variable_owner = serialized_graph.nodes[ + optimizer_node.slot_variables[0].original_variable_node_id] + self.assertEqual("network/_named_dense/bias", + original_variable_owner.variables[0].checkpoint_key) + self.assertEqual("bias", original_variable_owner.variables[0].local_name) + self.assertEqual("m", optimizer_node.slot_variables[0].slot_name) + self.assertEqual("network/_named_dense/bias/-OPTIMIZER_SLOT/_optimizer/m", + optimizer_node.slot_variables[0].checkpoint_key) + # We strip off the :0 suffix, as variable.name-based saving does. 
+ self.assertEqual("my_network/checkpointable_dense_layer/bias/Adam", + optimizer_node.slot_variables[0].full_name) + self.assertEqual("my_network/checkpointable_dense_layer/bias/Adam:0", + optimizer.get_slot( + var=named_variables["network/_named_dense/bias"], + name="m").name) + + @test_util.run_in_graph_and_eager_modes() + def testSaveRestore(self): + network = MyNetwork() + optimizer = CheckpointableAdam(0.001) + root_checkpointable = Root(optimizer=optimizer, network=network) + input_value = constant_op.constant([[3.]]) + if context.in_eager_mode(): + optimizer.minimize( + lambda: network(input_value), + global_step=root_checkpointable.global_step) + else: + train_op = optimizer.minimize( + network(input_value), global_step=root_checkpointable.global_step) + self.evaluate(variables.global_variables_initializer()) + self.evaluate(train_op) + prefix = os.path.join(self.get_temp_dir(), "ckpt") + self.evaluate(state_ops.assign(network._named_dense.variables[1], [42.])) + m_bias_slot = optimizer.get_slot(network._named_dense.variables[1], "m") + self.evaluate(state_ops.assign(m_bias_slot, [1.5])) + serialized_graph, save_path = checkpointable.save( + file_prefix=prefix, + root_checkpointable=root_checkpointable, + global_step=root_checkpointable.global_step) + self.evaluate(state_ops.assign(network._named_dense.variables[1], [43.])) + self.evaluate(state_ops.assign(root_checkpointable.global_step, 3)) + optimizer_variables = self.evaluate(optimizer.variables()) + self.evaluate(state_ops.assign(m_bias_slot, [-2.])) + # Immediate restoration + checkpointable.restore( + save_path=save_path, + root_checkpointable=root_checkpointable, + object_graph_proto=serialized_graph) + self.assertAllEqual([42.], self.evaluate(network._named_dense.variables[1])) + self.assertAllEqual(1, self.evaluate(root_checkpointable.global_step)) + self.assertAllEqual([1.5], self.evaluate(m_bias_slot)) + with ops.Graph().as_default(): + on_create_network = MyNetwork() + on_create_optimizer = 
CheckpointableAdam(0.001) + on_create_root = Root( + optimizer=on_create_optimizer, network=on_create_network) + with self.test_session(graph=ops.get_default_graph()): + # Deferred restoration + checkpointable.restore( + save_path=save_path, + root_checkpointable=on_create_root, + object_graph_proto=serialized_graph) + on_create_network(constant_op.constant([[3.]])) # create variables + self.assertAllEqual(1, self.evaluate(on_create_root.global_step)) + self.assertAllEqual([42.], + self.evaluate( + on_create_network._named_dense.variables[1])) + on_create_m_bias_slot = on_create_optimizer.get_slot( + on_create_network._named_dense.variables[1], "m") + # Optimizer slot variables are created when the original variable is + # restored. + self.assertAllEqual([1.5], self.evaluate(on_create_m_bias_slot)) + # beta1_power and beta2_power haven't been created yet, but everything + # else matches. + self.assertAllEqual(optimizer_variables[2:], + self.evaluate(on_create_optimizer.variables())) + on_create_optimizer._create_slots( + [resource_variable_ops.ResourceVariable([1.])]) + beta1_power, beta2_power = on_create_optimizer._get_beta_accumulators() + self.assertAllEqual(optimizer_variables[0], self.evaluate(beta1_power)) + self.assertAllEqual(optimizer_variables[1], self.evaluate(beta2_power)) + + def testDeferredRestorationUsageEager(self): + """An idiomatic eager execution example.""" + num_training_steps = 10 + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + latest_object_graph = None # Will be saved with the checkpoint eventually. 
+ for training_continuation in range(3): + with ops.Graph().as_default(): + network = MyNetwork() + optimizer = CheckpointableAdam(0.001) + root = Root(optimizer=optimizer, network=network) + checkpointable.restore( + save_path=core_saver.latest_checkpoint(checkpoint_directory), + root_checkpointable=root, + object_graph_proto=latest_object_graph) + for _ in range(num_training_steps): + # TODO(allenl): Use a Dataset and serialize/checkpoint it. + input_value = constant_op.constant([[3.]]) + optimizer.minimize( + lambda: network(input_value), # pylint: disable=cell-var-from-loop + global_step=root.global_step) + latest_object_graph, _ = checkpointable.save( + file_prefix=checkpoint_prefix, + root_checkpointable=root) + self.assertEqual((training_continuation + 1) * num_training_steps, + root.global_step.numpy()) + + def testUsageGraph(self): + """Expected usage when graph building.""" + with context.graph_mode(): + num_training_steps = 10 + checkpoint_directory = self.get_temp_dir() + checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt") + latest_object_graph = None + for training_continuation in range(3): + with ops.Graph().as_default(): + network = MyNetwork() + optimizer = CheckpointableAdam(0.001) + root = Root(optimizer=optimizer, network=network) + input_value = constant_op.constant([[3.]]) + train_op = optimizer.minimize( + network(input_value), + global_step=root.global_step) + init_op = variables.global_variables_initializer() + checkpoint_path = core_saver.latest_checkpoint(checkpoint_directory) + with self.test_session(graph=ops.get_default_graph()) as session: + if checkpoint_path is None: + self.assertEqual(0, training_continuation) + session.run(init_op) + # Another alternative would be to run initializers automatically + # if no checkpoint is being loaded. This would make deferred + # loading a bit more useful with graph execution. 
+ else: + checkpointable.restore( + save_path=checkpoint_path, + root_checkpointable=root, + object_graph_proto=latest_object_graph, + session=session) + for _ in range(num_training_steps): + session.run(train_op) + latest_object_graph, _ = checkpointable.save( + file_prefix=checkpoint_prefix, + root_checkpointable=root, + session=session) + self.assertEqual((training_continuation + 1) * num_training_steps, + session.run(root.global_step)) + + def _get_checkpoint_name(self, name): + root = checkpointable.Checkpointable() + root.add_variable(name=name, shape=[1, 2], dtype=dtypes.float64) + named_variables, _ = checkpointable._serialize_object_graph(root) + checkpoint_name, = named_variables.keys() + with ops.name_scope("root/" + checkpoint_name): + pass # Make sure we can use this as an op name if we prefix it. + return checkpoint_name + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testVariableNameEscaping(self): + self.assertEqual(r"a_S__b_S__c", self._get_checkpoint_name(r"a/b/c")) + self.assertEqual(r"b", self._get_checkpoint_name(r"b")) + self.assertEqual(r"c_S__", self._get_checkpoint_name(r"c/")) + self.assertEqual(r"d_S___S_._", self._get_checkpoint_name(r"d/_S__")) + + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testNumberedPath(self): + root = checkpointable.Checkpointable() + leaf = checkpointable.Checkpointable() + root.track_checkpointable(leaf, name="leaf") + leaf.add_variable(name="v", shape=[]) + named_variables, _ = checkpointable._serialize_object_graph(root) + variable_name, = named_variables.keys() + self.assertEqual(r"leaf/v", variable_name) + + @test_util.run_in_graph_and_eager_modes() + def testLocalNameValidation(self): + root = checkpointable.Checkpointable() + leaf = checkpointable.Checkpointable() + with self.assertRaisesRegexp(ValueError, "invalid name"): + # Leading dashes are reserved, which avoids conflicts with un-named edges + # in paths and the optimizer slots 
identifier. + root.track_checkpointable(leaf, name="-unnamed-12") + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py index b559cce6b12a809d671ce7855680063f02a4ac22..d177bfeab2d1fdc05d7ced54df8723fae2c77fdb 100644 --- a/tensorflow/contrib/eager/python/datasets.py +++ b/tensorflow/contrib/eager/python/datasets.py @@ -23,6 +23,7 @@ import threading from tensorflow.contrib.data.python.ops import prefetching_ops from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.util import nest +from tensorflow.python.data.util import sparse from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -41,7 +42,7 @@ def _generate_shared_name(prefix): global _uid_counter uid = _uid_counter _uid_counter += 1 - return "{}_{}".format(prefix, uid) + return "{}{}".format(prefix, uid) class Iterator(object): @@ -75,13 +76,16 @@ class Iterator(object): format(type(self))) with ops.device("/device:CPU:0"): ds_variant = dataset._as_variant_tensor() # pylint: disable=protected-access + self._output_classes = dataset.output_classes self._output_types = dataset.output_types self._output_shapes = dataset.output_shapes - self._flat_output_types = nest.flatten(dataset.output_types) - self._flat_output_shapes = nest.flatten(dataset.output_shapes) + self._flat_output_types = nest.flatten( + sparse.as_dense_types(self._output_types, self._output_classes)) + self._flat_output_shapes = nest.flatten( + sparse.as_dense_shapes(self._output_shapes, self._output_classes)) self._resource = gen_dataset_ops.iterator( - container="", - shared_name=_generate_shared_name("eager_iterator"), + shared_name="", + container=_generate_shared_name("eageriterator"), output_types=self._flat_output_types, output_shapes=self._flat_output_shapes) gen_dataset_ops.make_iterator(ds_variant, self._resource) @@ -108,7 +112,7 
@@ class Iterator(object): remote_fn.add_to_graph(None) target = constant_op.constant("/device:CPU:0") with ops.device(self._device): - self._buffer_resource_handle = prefetching_ops.function_buffering_resource( + self._buffer_resource_handle = prefetching_ops.function_buffering_resource( # pylint: disable=line-too-long string_arg=iter_string_handle, f=remote_fn, target_device=target, @@ -116,8 +120,9 @@ class Iterator(object): thread_pool_size=1, container="", shared_name=_generate_shared_name("function_buffer_resource")) - self._buffer_resource_deleter = resource_variable_ops.EagerResourceDeleter( - handle=self._buffer_resource_handle, handle_device=self._device) + self._buffer_resource_deleter = resource_variable_ops.EagerResourceDeleter( # pylint: disable=line-too-long + handle=self._buffer_resource_handle, + handle_device=self._device) def __iter__(self): return self @@ -125,22 +130,83 @@ class Iterator(object): def __next__(self): # For Python 3 compatibility return self.next() - def next(self): - """Return the next tf.Tensor from the dataset.""" + def _next_internal(self): + """Returns a nested structure of `tf.Tensor`s containing the next element. + """ with ops.device(self._device): - try: - if self._buffer_resource_handle is not None: - ret = prefetching_ops.function_buffering_resource_get_next( - function_buffer_resource=self._buffer_resource_handle, - output_types=self._flat_output_types) - else: - # TODO(ashankar): Consider removing this ops.device() contextmanager - # and instead mimic ops placement in graphs: Operations on resource - # handles execute on the same device as where the resource is placed. 
- ret = gen_dataset_ops.iterator_get_next( - self._resource, - output_types=self._flat_output_types, - output_shapes=self._flat_output_shapes) - except errors.OutOfRangeError: - raise StopIteration - return nest.pack_sequence_as(self._output_types, ret) + if self._buffer_resource_handle is not None: + ret = prefetching_ops.function_buffering_resource_get_next( + function_buffer_resource=self._buffer_resource_handle, + output_types=self._flat_output_types) + else: + # TODO(ashankar): Consider removing this ops.device() contextmanager + # and instead mimic ops placement in graphs: Operations on resource + # handles execute on the same device as where the resource is placed. + # NOTE(mrry): Here we use the "_sync" variant of `iterator_get_next` + # because in eager mode this code will run synchronously on the calling + # thread. Therefore we do not need to make a defensive context switch + # to a background thread, and can achieve a small constant performance + # boost by invoking the iterator synchronously. + ret = gen_dataset_ops.iterator_get_next_sync( + self._resource, + output_types=self._flat_output_types, + output_shapes=self._flat_output_shapes) + + return sparse.deserialize_sparse_tensors( + nest.pack_sequence_as(self._output_types, ret), self._output_types, + self._output_shapes, self._output_classes) + + def next(self): + """Returns a nested structure of `tf.Tensor`s containing the next element. + """ + try: + return self._next_internal() + except errors.OutOfRangeError: + raise StopIteration + + @property + def output_classes(self): + """Returns the class of each component of an element of this iterator. + + The expected values are `tf.Tensor` and `tf.SparseTensor`. + + Returns: + A nested structure of Python `type` objects corresponding to each + component of an element of this dataset. + """ + return self._output_classes + + @property + def output_shapes(self): + """Returns the shape of each component of an element of this iterator. 
+ + Returns: + A nested structure of `tf.TensorShape` objects corresponding to each + component of an element of this dataset. + """ + return self._output_shapes + + @property + def output_types(self): + """Returns the type of each component of an element of this iterator. + + Returns: + A nested structure of `tf.DType` objects corresponding to each component + of an element of this dataset. + """ + return self._output_types + + def get_next(self, name=None): + """Returns a nested structure of `tf.Tensor`s containing the next element. + + Args: + name: (Optional.) A name for the created operation. Currently unused. + + Returns: + A nested structure of `tf.Tensor` objects. + + Raises: + `tf.errors.OutOfRangeError`: If the end of the dataset has been reached. + """ + del name + return self._next_internal() diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py index c924d81c9d85e638e4f35f260664c0ee7d03257e..a1611e92b113839c2dd2a3b2560b0ba90c0a7ef0 100644 --- a/tensorflow/contrib/eager/python/datasets_test.py +++ b/tensorflow/contrib/eager/python/datasets_test.py @@ -16,11 +16,19 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import time + +import numpy as np + +from tensorflow.contrib import lookup from tensorflow.contrib.eager.python import datasets from tensorflow.python.data import Dataset from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import math_ops from tensorflow.python.ops import script_ops @@ -33,6 +41,15 @@ class IteratorTest(test.TestCase): got.append(t.numpy()) self.assertAllEqual([0, 1, 2, 3], got) + def testGetNext(self): + iterator = 
datasets.Iterator(Dataset.range(4)) + self.assertEqual(0, iterator.get_next().numpy()) + self.assertEqual(1, iterator.get_next().numpy()) + self.assertEqual(2, iterator.get_next().numpy()) + self.assertEqual(3, iterator.get_next().numpy()) + with self.assertRaises(errors.OutOfRangeError): + iterator.get_next() + def testMultipleIteratorsOnTheSameDataset(self): ds = Dataset.range(4) it1 = datasets.Iterator(ds) @@ -64,6 +81,18 @@ class IteratorTest(test.TestCase): got = [x.numpy() for x in it] self.assertAllEqual([0, 4, 16, 36], got) + def testMapCaptureLookupTable(self): + default_val = -1 + keys = constant_op.constant(['brain', 'salad', 'surgery']) + values = constant_op.constant([0, 1, 2], dtypes.int64) + table = lookup.HashTable( + lookup.KeyValueTensorInitializer(keys, values), default_val) + dataset = Dataset.from_tensor_slices(['brain', 'salad', 'surgery']) + dataset = dataset.map(table.lookup) + it = datasets.Iterator(dataset) + got = [x.numpy() for x in it] + self.assertAllEqual([0, 1, 2], got) + def testMultipleIteratorsOnADatasetThatUsesFunctions(self): ds = Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6]).map(math_ops.square) @@ -72,6 +101,53 @@ class IteratorTest(test.TestCase): got2 = [x.numpy() for x in datasets.Iterator(ds)] self.assertAllEqual(got1, got2) + def assertSparseValuesEqual(self, a, b): + self.assertAllEqual(a.indices, b.indices) + self.assertAllEqual(a.values, b.values) + self.assertAllEqual(a.dense_shape, b.dense_shape) + + def testSparseTensorElements(self): + components = (sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0], [1, 0], [2, 0]]), + values=np.array([0, 0, 0]), + dense_shape=np.array([3, 1])), + sparse_tensor.SparseTensorValue( + indices=np.array([[0, 0], [1, 1], [2, 2]]), + values=np.array([1, 2, 3]), + dense_shape=np.array([3, 3]))) + + expected = [ + (sparse_tensor.SparseTensorValue( + indices=np.array([[0]]), + values=np.array([0]), + dense_shape=np.array([1])), + sparse_tensor.SparseTensorValue( + 
indices=np.array([[0]]), + values=np.array([1]), + dense_shape=np.array([3]))), + (sparse_tensor.SparseTensorValue( + indices=np.array([[0]]), + values=np.array([0]), + dense_shape=np.array([1])), + sparse_tensor.SparseTensorValue( + indices=np.array([[1]]), + values=np.array([2]), + dense_shape=np.array([3]))), + (sparse_tensor.SparseTensorValue( + indices=np.array([[0]]), + values=np.array([0]), + dense_shape=np.array([1])), + sparse_tensor.SparseTensorValue( + indices=np.array([[2]]), + values=np.array([3]), + dense_shape=np.array([3]))), + ] + + for i, result in enumerate( + datasets.Iterator(Dataset.from_tensor_slices(components))): + self.assertSparseValuesEqual(expected[i][0], result[0]) + self.assertSparseValuesEqual(expected[i][1], result[1]) + def testPyFunc(self): def my_map(inp): @@ -90,5 +166,64 @@ class IteratorTest(test.TestCase): self.assertAllEqual([0., 2.], x.numpy()) +class DatasetConstructorBenchmark(test.Benchmark): + + def benchmarkSliceRepeatBatchEager(self): + input_size = 10000 + batch_size = 100 + num_epochs = 100 + + input_data = np.random.randn(input_size) + + dataset = ( + Dataset.from_tensor_slices(input_data).repeat(num_epochs) + .batch(batch_size)) + iterator = datasets.Iterator(dataset) + + ends = [time.time()] + for _ in iterator: + ends.append(time.time()) + + deltas = np.ediff1d(ends) + median_wall_time = np.median(deltas) + print( + 'Slice/repeat/batch eager input size: %d batch size: %d Median wall ' + 'time per element: %f' + % (input_size, batch_size, median_wall_time)) + self.report_benchmark( + iters=len(deltas), + wall_time=median_wall_time, + name='benchmark_slice_repeat_batch_eager_input_%d_batch_%d' % + (input_size, batch_size)) + + def benchmarkSliceBatchCacheRepeatCallable(self): + input_size = 10000 + batch_size = 100 + num_epochs = 100 + + input_data = np.random.randn(input_size) + + dataset = ( + Dataset.from_tensor_slices(input_data).batch(batch_size).cache() + .repeat(num_epochs)) + iterator = 
datasets.Iterator(dataset) + + ends = [time.time()] + for _ in iterator: + ends.append(time.time()) + + deltas = np.ediff1d(ends) + median_wall_time = np.median(deltas) + print( + 'Slice/batch/cache/repeat eager input size: %d batch size: %d Median ' + 'wall time per element: %f' + % (input_size, batch_size, median_wall_time)) + self.report_benchmark( + iters=len(deltas), + wall_time=median_wall_time, + name='benchmark_slice_batch_cache_repeat_eager_input_%d_batch_%d' % + (input_size, batch_size)) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/eager/python/evaluator.py b/tensorflow/contrib/eager/python/evaluator.py index bd0ab02ecf7ae6025e08dde1c3ddc634db9255c1..68e7b5421fec7f73f10e381ca45f9d900de299d7 100644 --- a/tensorflow/contrib/eager/python/evaluator.py +++ b/tensorflow/contrib/eager/python/evaluator.py @@ -110,7 +110,7 @@ class Evaluator(object): return self._all_metric_results() else: def f(): - with summary_ops.create_summary_file_writer( + with summary_ops.create_file_writer( summary_logdir).as_default(), summary_ops.always_record_summaries(): return self._all_metric_results() if context.in_eager_mode(): @@ -178,7 +178,7 @@ class Evaluator(object): call_op: An op that updates evaluation state on a mini-batch of examples. Must generate an tf.errors.OutOfRangeError when done. results_op: A dictionary of tensors that compute the final evaluation - results from the evaulation state. + results from the evaluation state. sess: The Session to run the evaluation in. Defaults to the default Session. 
diff --git a/tensorflow/contrib/eager/python/examples/BUILD b/tensorflow/contrib/eager/python/examples/BUILD index aa21a6ab994acf929890ecebc07a86cf7ebf97db..15a21885f66eface291a39fa0ee1ff28bc297548 100644 --- a/tensorflow/contrib/eager/python/examples/BUILD +++ b/tensorflow/contrib/eager/python/examples/BUILD @@ -6,10 +6,12 @@ package(default_visibility = ["//tensorflow:internal"]) py_library( name = "examples_pip", deps = [ + "//tensorflow/contrib/eager/python/examples/gan:mnist", "//tensorflow/contrib/eager/python/examples/linear_regression", "//tensorflow/contrib/eager/python/examples/mnist", "//tensorflow/contrib/eager/python/examples/resnet50", "//tensorflow/contrib/eager/python/examples/rnn_colorbot", "//tensorflow/contrib/eager/python/examples/rnn_ptb", + "//tensorflow/contrib/eager/python/examples/spinn:data", ], ) diff --git a/tensorflow/contrib/eager/python/examples/gan/BUILD b/tensorflow/contrib/eager/python/examples/gan/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..c61ec2dbae60a782c0e6589701554b045dcb92ae --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/gan/BUILD @@ -0,0 +1,36 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +py_binary( + name = "mnist", + srcs = ["mnist.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/eager/python:tfe", + "//tensorflow/examples/tutorials/mnist:input_data", + ], +) + +cuda_py_test( + name = "mnist_test", + srcs = ["mnist_test.py"], + additional_deps = [ + ":mnist", + "//tensorflow/contrib/eager/python:tfe", + "//tensorflow:tensorflow_py", + ], +) + +cuda_py_test( + name = "mnist_graph_test", + srcs = ["mnist_graph_test.py"], + additional_deps = [ + ":mnist", + "//third_party/py/numpy", + "//tensorflow:tensorflow_py", + ], +) diff --git a/tensorflow/contrib/eager/python/examples/gan/README.md 
b/tensorflow/contrib/eager/python/examples/gan/README.md new file mode 100644 index 0000000000000000000000000000000000000000..208a64b05d47eea10b49a1bf967a5453677bfd21 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/gan/README.md @@ -0,0 +1,38 @@ +# GAN with TensorFlow eager execution + +A simple Generative Adversarial Network (GAN) example using eager execution. +The discriminator and generator networks each contain a few convolution and +fully connected layers. + +Other eager execution examples can be found under the parent directory. + +## Content + +- `mnist.py`: Model definitions and training routines. +- `mnist_test.py`: Benchmarks for training and using the models using eager +execution. +- `mnist_graph_test.py`: Benchmarks for training and using the models using +graph execution. The same model definitions and loss functions are used in +all benchmarks. + + +## To run + +- Make sure you have installed TensorFlow 1.5+ or the latest `tf-nightly` +or `tf-nightly-gpu` pip package in order to access the eager execution feature. + +- Train model. E.g., + + ```bash + python mnist.py + ``` + + Use `--output_dir=<DIR>` to direct the script to save TensorBoard summaries + during training. Disabled by default. + + Use `--checkpoint_dir=<DIR>` to direct the script to save checkpoints to + `<DIR>` during training. DIR defaults to /tmp/tensorflow/mnist/checkpoints/. + The script will load the latest saved checkpoint from this directory if + one exists. + + Use `-h` for other options. diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist.py b/tensorflow/contrib/eager/python/examples/gan/mnist.py new file mode 100644 index 0000000000000000000000000000000000000000..b9ac79f46c83bb709918e3b72830b90ddcfd71b4 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/gan/mnist.py @@ -0,0 +1,368 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A deep MNIST classifier using convolutional layers. + +Sample usage: + python mnist.py --help +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import os +import sys +import time + +import tensorflow as tf + +import tensorflow.contrib.eager as tfe +from tensorflow.examples.tutorials.mnist import input_data + +FLAGS = None + + +class Discriminator(tfe.Network): + """GAN Discriminator. + + A network to differentiate between generated and real handwritten digits. + """ + + def __init__(self, data_format): + """Creates a model for discriminating between real and generated digits. + + Args: + data_format: Either 'channels_first' or 'channels_last'. + 'channels_first' is typically faster on GPUs while 'channels_last' is + typically faster on CPUs. 
See + https://www.tensorflow.org/performance/performance_guide#data_formats + """ + super(Discriminator, self).__init__(name='') + if data_format == 'channels_first': + self._input_shape = [-1, 1, 28, 28] + else: + assert data_format == 'channels_last' + self._input_shape = [-1, 28, 28, 1] + self.conv1 = self.track_layer(tf.layers.Conv2D(64, 5, padding='SAME', + data_format=data_format, + activation=tf.tanh)) + self.pool1 = self.track_layer( + tf.layers.AveragePooling2D(2, 2, data_format=data_format)) + self.conv2 = self.track_layer(tf.layers.Conv2D(128, 5, + data_format=data_format, + activation=tf.tanh)) + self.pool2 = self.track_layer( + tf.layers.AveragePooling2D(2, 2, data_format=data_format)) + self.flatten = self.track_layer(tf.layers.Flatten()) + self.fc1 = self.track_layer(tf.layers.Dense(1024, activation=tf.tanh)) + self.fc2 = self.track_layer(tf.layers.Dense(1, activation=None)) + + def call(self, inputs): + """Return two logits per image estimating input authenticity. + + Users should invoke __call__ to run the network, which delegates to this + method (and not call this method directly). + + Args: + inputs: A batch of images as a Tensor with shape [batch_size, 28, 28, 1] + or [batch_size, 1, 28, 28] + + Returns: + A Tensor with shape [batch_size] containing logits estimating + the probability that corresponding digit is real. + """ + x = tf.reshape(inputs, self._input_shape) + x = self.conv1(x) + x = self.pool1(x) + x = self.conv2(x) + x = self.pool2(x) + x = self.flatten(x) + x = self.fc1(x) + x = self.fc2(x) + return x + + +class Generator(tfe.Network): + """Generator of handwritten digits similar to the ones in the MNIST dataset. + """ + + def __init__(self, data_format): + """Creates a model for discriminating between real and generated digits. + + Args: + data_format: Either 'channels_first' or 'channels_last'. + 'channels_first' is typically faster on GPUs while 'channels_last' is + typically faster on CPUs. 
See + https://www.tensorflow.org/performance/performance_guide#data_formats + """ + super(Generator, self).__init__(name='') + self.data_format = data_format + # We are using 128 6x6 channels as input to the first deconvolution layer + if data_format == 'channels_first': + self._pre_conv_shape = [-1, 128, 6, 6] + else: + assert data_format == 'channels_last' + self._pre_conv_shape = [-1, 6, 6, 128] + self.fc1 = self.track_layer(tf.layers.Dense(6 * 6 * 128, + activation=tf.tanh)) + + # In call(), we reshape the output of fc1 to _pre_conv_shape + + # Deconvolution layer. Resulting image shape: (batch, 14, 14, 64) + self.conv1 = self.track_layer(tf.layers.Conv2DTranspose( + 64, 4, strides=2, activation=None, data_format=data_format)) + + # Deconvolution layer. Resulting image shape: (batch, 28, 28, 1) + self.conv2 = self.track_layer(tf.layers.Conv2DTranspose( + 1, 2, strides=2, activation=tf.nn.sigmoid, data_format=data_format)) + + def call(self, inputs): + """Return a batch of generated images. + + Users should invoke __call__ to run the network, which delegates to this + method (and not call this method directly). + + Args: + inputs: A batch of noise vectors as a Tensor with shape + [batch_size, length of noise vectors]. + + Returns: + A Tensor containing generated images. If data_format is 'channels_last', + the shape of returned images is [batch_size, 28, 28, 1], else + [batch_size, 1, 28, 28] + """ + + x = self.fc1(inputs) + x = tf.reshape(x, shape=self._pre_conv_shape) + x = self.conv1(x) + x = self.conv2(x) + return x + + +def discriminator_loss(discriminator_real_outputs, discriminator_gen_outputs): + """Original discriminator loss for GANs, with label smoothing. + + See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661) for more + details. + + Args: + discriminator_real_outputs: Discriminator output on real data. + discriminator_gen_outputs: Discriminator output on generated data. Expected + to be in the range of (-inf, inf). 
+ + Returns: + A scalar loss Tensor. + """ + + loss_on_real = tf.losses.sigmoid_cross_entropy( + tf.ones_like(discriminator_real_outputs), discriminator_real_outputs, + label_smoothing=0.25) + loss_on_generated = tf.losses.sigmoid_cross_entropy( + tf.zeros_like(discriminator_gen_outputs), discriminator_gen_outputs) + loss = loss_on_real + loss_on_generated + tf.contrib.summary.scalar('discriminator_loss', loss) + return loss + + +def generator_loss(discriminator_gen_outputs): + """Original generator loss for GANs. + + L = -log(sigmoid(D(G(z)))) + + See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661) + for more details. + + Args: + discriminator_gen_outputs: Discriminator output on generated data. Expected + to be in the range of (-inf, inf). + + Returns: + A scalar loss Tensor. + """ + loss = tf.losses.sigmoid_cross_entropy( + tf.ones_like(discriminator_gen_outputs), discriminator_gen_outputs) + tf.contrib.summary.scalar('generator_loss', loss) + return loss + + +def train_one_epoch(generator, discriminator, + generator_optimizer, discriminator_optimizer, + dataset, log_interval, noise_dim): + """Trains `generator` and `discriminator` models on `dataset`. + + Args: + generator: Generator model. + discriminator: Discriminator model. + generator_optimizer: Optimizer to use for generator. + discriminator_optimizer: Optimizer to use for discriminator. + dataset: Dataset of images to train on. + log_interval: How many global steps to wait between logging and collecting + summaries. + noise_dim: Dimension of noise vector to use. 
+ """ + + total_generator_loss = 0.0 + total_discriminator_loss = 0.0 + for (batch_index, images) in enumerate(tfe.Iterator(dataset)): + with tf.device('/cpu:0'): + tf.assign_add(tf.train.get_global_step(), 1) + + with tf.contrib.summary.record_summaries_every_n_global_steps(log_interval): + current_batch_size = images.shape[0] + noise = tf.random_uniform(shape=[current_batch_size, noise_dim], + minval=-1., maxval=1., seed=batch_index) + + with tfe.GradientTape(persistent=True) as g: + generated_images = generator(noise) + tf.contrib.summary.image('generated_images', + tf.reshape(generated_images, [-1, 28, 28, 1]), + max_images=10) + + discriminator_gen_outputs = discriminator(generated_images) + discriminator_real_outputs = discriminator(images) + discriminator_loss_val = discriminator_loss(discriminator_real_outputs, + discriminator_gen_outputs) + total_discriminator_loss += discriminator_loss_val + + generator_loss_val = generator_loss(discriminator_gen_outputs) + total_generator_loss += generator_loss_val + + generator_grad = g.gradient(generator_loss_val, generator.variables) + discriminator_grad = g.gradient(discriminator_loss_val, + discriminator.variables) + + with tf.variable_scope('generator'): + generator_optimizer.apply_gradients(zip(generator_grad, + generator.variables)) + with tf.variable_scope('discriminator'): + discriminator_optimizer.apply_gradients(zip(discriminator_grad, + discriminator.variables)) + + if log_interval and batch_index > 0 and batch_index % log_interval == 0: + print('Batch #%d\tAverage Generator Loss: %.6f\t' + 'Average Discriminator Loss: %.6f' % ( + batch_index, total_generator_loss/batch_index, + total_discriminator_loss/batch_index)) + + +def main(_): + (device, data_format) = ('/gpu:0', 'channels_first') + if FLAGS.no_gpu or tfe.num_gpus() <= 0: + (device, data_format) = ('/cpu:0', 'channels_last') + print('Using device %s, and data format %s.' 
% (device, data_format)) + + # Load the datasets + data = input_data.read_data_sets(FLAGS.data_dir) + dataset = (tf.data.Dataset + .from_tensor_slices(data.train.images) + .shuffle(60000) + .batch(FLAGS.batch_size)) + + # Create the models and optimizers + generator = Generator(data_format) + discriminator = Discriminator(data_format) + with tf.variable_scope('generator'): + generator_optimizer = tf.train.AdamOptimizer(FLAGS.lr) + with tf.variable_scope('discriminator'): + discriminator_optimizer = tf.train.AdamOptimizer(FLAGS.lr) + + # Prepare summary writer and checkpoint info + summary_writer = tf.contrib.summary.create_summary_file_writer( + FLAGS.output_dir, flush_millis=1000) + checkpoint_prefix = os.path.join(FLAGS.checkpoint_dir, 'ckpt') + latest_cpkt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir) + if latest_cpkt: + print('Using latest checkpoint at ' + latest_cpkt) + + with tf.device(device): + for epoch in range(1, 101): + with tfe.restore_variables_on_create(latest_cpkt): + global_step = tf.train.get_or_create_global_step() + start = time.time() + with summary_writer.as_default(): + train_one_epoch(generator, discriminator, generator_optimizer, + discriminator_optimizer, + dataset, FLAGS.log_interval, FLAGS.noise) + end = time.time() + print('\nTrain time for epoch #%d (global step %d): %f' % ( + epoch, global_step.numpy(), end - start)) + + all_variables = ( + generator.variables + + discriminator.variables + + generator_optimizer.variables() + + discriminator_optimizer.variables() + + [global_step]) + tfe.Saver(all_variables).save( + checkpoint_prefix, global_step=global_step) + + +if __name__ == '__main__': + tfe.enable_eager_execution() + + parser = argparse.ArgumentParser() + parser.add_argument( + '--data-dir', + type=str, + default='/tmp/tensorflow/mnist/input_data', + help=('Directory for storing input data (default ' + '/tmp/tensorflow/mnist/input_data)')) + parser.add_argument( + '--batch-size', + type=int, + default=128, + metavar='N', + 
help='input batch size for training (default: 128)') + parser.add_argument( + '--log-interval', + type=int, + default=100, + metavar='N', + help=('number of batches between logging and writing summaries ' + '(default: 100)')) + parser.add_argument( + '--output_dir', + type=str, + default=None, + metavar='DIR', + help='Directory to write TensorBoard summaries (defaults to none)') + parser.add_argument( + '--checkpoint_dir', + type=str, + default='/tmp/tensorflow/mnist/checkpoints/', + metavar='DIR', + help=('Directory to save checkpoints in (once per epoch) (default ' + '/tmp/tensorflow/mnist/checkpoints/)')) + parser.add_argument( + '--lr', + type=float, + default=0.001, + metavar='LR', + help='learning rate (default: 0.001)') + parser.add_argument( + '--noise', + type=int, + default=100, + metavar='N', + help='Length of noise vector for generator input (default: 100)') + parser.add_argument( + '--no-gpu', + action='store_true', + default=False, + help='disables GPU usage even if a GPU is available') + + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist_graph_test.py b/tensorflow/contrib/eager/python/examples/gan/mnist_graph_test.py new file mode 100644 index 0000000000000000000000000000000000000000..12b39b0cde49d4c017acfa74572c725036c54eff --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/gan/mnist_graph_test.py @@ -0,0 +1,151 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tempfile +import time + +import numpy as np +import tensorflow as tf + +from tensorflow.contrib.eager.python.examples.gan import mnist + +NOISE_DIM = 100 +# Big enough so that summaries are never recorded. +# Lower this value if you would like to benchmark with some summaries. +SUMMARY_INTERVAL = 10000 +SUMMARY_FLUSH_MS = 100 # Flush summaries every 100ms + + +def data_format(): + return 'channels_first' if tf.test.is_gpu_available() else 'channels_last' + + +class MnistGraphGanBenchmark(tf.test.Benchmark): + + def _create_graph(self, batch_size): + # Generate some random data. + images_data = np.random.randn(batch_size, 784).astype(np.float32) + dataset = tf.data.Dataset.from_tensors(images_data) + images = dataset.repeat().make_one_shot_iterator().get_next() + + # Create the models and optimizers + generator = mnist.Generator(data_format()) + discriminator = mnist.Discriminator(data_format()) + with tf.variable_scope('generator'): + generator_optimizer = tf.train.AdamOptimizer(0.001) + with tf.variable_scope('discriminator'): + discriminator_optimizer = tf.train.AdamOptimizer(0.001) + + # Run models and compute loss + noise_placeholder = tf.placeholder(tf.float32, + shape=[batch_size, NOISE_DIM]) + generated_images = generator(noise_placeholder) + tf.contrib.summary.image('generated_images', + tf.reshape(generated_images, [-1, 28, 28, 1]), + max_images=10) + discriminator_gen_outputs = discriminator(generated_images) + discriminator_real_outputs = discriminator(images) + generator_loss = mnist.generator_loss(discriminator_gen_outputs) + discriminator_loss = mnist.discriminator_loss(discriminator_real_outputs, + discriminator_gen_outputs) + # Get train ops + with
tf.variable_scope('generator'): + generator_train = generator_optimizer.minimize( + generator_loss, var_list=generator.variables) + with tf.variable_scope('discriminator'): + discriminator_train = discriminator_optimizer.minimize( + discriminator_loss, var_list=discriminator.variables) + + return (generator_train, discriminator_train, noise_placeholder) + + def _report(self, test_name, start, num_iters, batch_size): + avg_time = (time.time() - start) / num_iters + dev = 'gpu' if tf.test.is_gpu_available() else 'cpu' + name = 'graph_%s_%s_batch_%d_%s' % (test_name, dev, batch_size, + data_format()) + extras = {'examples_per_sec': batch_size / avg_time} + self.report_benchmark( + iters=num_iters, wall_time=avg_time, name=name, extras=extras) + + def benchmark_train(self): + for batch_size in [64, 128, 256]: + with tf.Graph().as_default(): + global_step = tf.train.get_or_create_global_step() + increment_global_step = tf.assign_add(global_step, 1) + with tf.contrib.summary.create_file_writer( + tempfile.mkdtemp(), flush_millis=SUMMARY_FLUSH_MS).as_default(), ( + tf.contrib.summary.record_summaries_every_n_global_steps( + SUMMARY_INTERVAL)): + (generator_train, discriminator_train, noise_placeholder + ) = self._create_graph(batch_size) + + with tf.Session() as sess: + tf.contrib.summary.initialize(graph=tf.get_default_graph(), + session=sess) + + sess.run(tf.global_variables_initializer()) + + num_burn, num_iters = (3, 100) + for _ in range(num_burn): + noise = np.random.uniform(-1.0, 1.0, size=[batch_size, NOISE_DIM]) + # Increment global step before evaluating summary ops to avoid + # race condition. 
+ sess.run(increment_global_step) + sess.run([generator_train, discriminator_train, + tf.contrib.summary.all_summary_ops()], + feed_dict={noise_placeholder: noise}) + + # Run and benchmark num_iters training steps + start = time.time() + for _ in range(num_iters): + noise = np.random.uniform(-1.0, 1.0, size=[batch_size, NOISE_DIM]) + sess.run(increment_global_step) + sess.run([generator_train, discriminator_train, + tf.contrib.summary.all_summary_ops()], + feed_dict={noise_placeholder: noise}) + self._report('train', start, num_iters, batch_size) + + def benchmark_generate(self): + for batch_size in [64, 128, 256]: + with tf.Graph().as_default(): + # Using random weights. This will generate garbage. + generator = mnist.Generator(data_format()) + noise_placeholder = tf.placeholder(tf.float32, + shape=[batch_size, NOISE_DIM]) + generated_images = generator(noise_placeholder) + + init = tf.global_variables_initializer() + with tf.Session() as sess: + sess.run(init) + noise = np.random.uniform(-1.0, 1.0, size=[batch_size, NOISE_DIM]) + num_burn, num_iters = (30, 1000) + for _ in range(num_burn): + sess.run(generated_images, feed_dict={noise_placeholder: noise}) + + start = time.time() + for _ in range(num_iters): + # Comparison with the eager execution benchmark in mnist_test.py + # isn't entirely fair as the time here includes the cost of copying + # the feeds from CPU memory to GPU. + sess.run(generated_images, feed_dict={noise_placeholder: noise}) + self._report('generate', start, num_iters, batch_size) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/gan/mnist_test.py b/tensorflow/contrib/eager/python/examples/gan/mnist_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4a3ca8d82bc2619b05a734f6d2e58431c1a45995 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/gan/mnist_test.py @@ -0,0 +1,113 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tempfile +import time + +import tensorflow as tf + +import tensorflow.contrib.eager as tfe +from tensorflow.contrib.eager.python.examples.gan import mnist + +NOISE_DIM = 100 +# Big enough so that summaries are never recorded. +# Lower this value if you would like to benchmark with some summaries. +SUMMARY_INTERVAL = 10000 +SUMMARY_FLUSH_MS = 100 # Flush summaries every 100ms + + +def data_format(): + return 'channels_first' if tf.test.is_gpu_available() else 'channels_last' + + +def device(): + return '/gpu:0' if tfe.num_gpus() else '/cpu:0' + + +class MnistEagerGanBenchmark(tf.test.Benchmark): + + def _report(self, test_name, start, num_iters, batch_size): + avg_time = (time.time() - start) / num_iters + dev = 'gpu' if tfe.num_gpus() else 'cpu' + name = 'eager_%s_%s_batch_%d_%s' % (test_name, dev, batch_size, + data_format()) + extras = {'examples_per_sec': batch_size / avg_time} + self.report_benchmark( + iters=num_iters, wall_time=avg_time, name=name, extras=extras) + + def benchmark_train(self): + for batch_size in [64, 128, 256]: + # Generate some random data.
+ burn_batches, measure_batches = (3, 100) + burn_images = [tf.random_normal([batch_size, 784]) + for _ in range(burn_batches)] + burn_dataset = tf.data.Dataset.from_tensor_slices(burn_images) + measure_images = [tf.random_normal([batch_size, 784]) + for _ in range(measure_batches)] + measure_dataset = tf.data.Dataset.from_tensor_slices(measure_images) + + tf.train.get_or_create_global_step() + with tf.device(device()): + # Create the models and optimizers + generator = mnist.Generator(data_format()) + discriminator = mnist.Discriminator(data_format()) + with tf.variable_scope('generator'): + generator_optimizer = tf.train.AdamOptimizer(0.001) + with tf.variable_scope('discriminator'): + discriminator_optimizer = tf.train.AdamOptimizer(0.001) + + with tf.contrib.summary.create_file_writer( + tempfile.mkdtemp(), flush_millis=SUMMARY_FLUSH_MS).as_default(): + + # warm up + mnist.train_one_epoch(generator, discriminator, generator_optimizer, + discriminator_optimizer, + burn_dataset, log_interval=SUMMARY_INTERVAL, + noise_dim=NOISE_DIM) + # measure + start = time.time() + mnist.train_one_epoch(generator, discriminator, generator_optimizer, + discriminator_optimizer, + measure_dataset, log_interval=SUMMARY_INTERVAL, + noise_dim=NOISE_DIM) + self._report('train', start, measure_batches, batch_size) + + def benchmark_generate(self): + for batch_size in [64, 128, 256]: + with tf.device(device()): + # Using random weights. This will generate garbage. + generator = mnist.Generator(data_format()) + + num_burn, num_iters = (30, 1000) + for _ in range(num_burn): + noise = tf.random_uniform(shape=[batch_size, NOISE_DIM], + minval=-1., maxval=1.) + generator(noise) + + start = time.time() + for _ in range(num_iters): + noise = tf.random_uniform(shape=[batch_size, NOISE_DIM], + minval=-1., maxval=1.) 
+ generator(noise) + self._report('generate', start, num_iters, batch_size) + + +if __name__ == '__main__': + tfe.enable_eager_execution() + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/BUILD b/tensorflow/contrib/eager/python/examples/linear_regression/BUILD index bab7ad0c701b2110fda9a8d27792fd361a5fc1c0..f86331af6f7928f0f86c888e22706c6e0a5978b2 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/BUILD +++ b/tensorflow/contrib/eager/python/examples/linear_regression/BUILD @@ -23,3 +23,13 @@ cuda_py_test( "//tensorflow:tensorflow_py", ], ) + +cuda_py_test( + name = "linear_regression_graph_test", + size = "small", + srcs = ["linear_regression_graph_test.py"], + additional_deps = [ + ":linear_regression", + "//tensorflow:tensorflow_py", + ], +) diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py index d0130ebd118dbaff4f0161c8b2528764c6103e02..6ce4de6ee0bf50400eff339ac04e132252a2b53e 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py +++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py @@ -41,7 +41,7 @@ class LinearModel(tfe.Network): For those familiar with TensorFlow graphs, notice the absence of `tf.Session`. The `forward()` method here immediately executes and returns output values. The `loss()` method immediately compares the - output of `forward()` with the target adn returns the MSE loss value. + output of `forward()` with the target and returns the MSE loss value. The `fit()` performs gradient-descent training on the model's weights and bias. 
""" @@ -63,6 +63,10 @@ class LinearModel(tfe.Network): return self._hidden_layer(xs) +def mean_square_loss(model, xs, ys): + return tf.reduce_mean(tf.square(model(xs) - ys)) + + def fit(model, dataset, optimizer, verbose=False, logdir=None): """Fit the linear-regression model. @@ -76,16 +80,14 @@ def fit(model, dataset, optimizer, verbose=False, logdir=None): """ # The loss function to optimize. - def mean_square_loss(xs, ys): - return tf.reduce_mean(tf.square(model(xs) - ys)) - - loss_and_grads = tfe.implicit_value_and_gradients(mean_square_loss) + mse = lambda xs, ys: mean_square_loss(model, xs, ys) + loss_and_grads = tfe.implicit_value_and_gradients(mse) tf.train.get_or_create_global_step() if logdir: # Support for TensorBoard summaries. Once training has started, use: # tensorboard --logdir= - summary_writer = tf.contrib.summary.create_summary_file_writer(logdir) + summary_writer = tf.contrib.summary.create_file_writer(logdir) # Training loop. for i, (xs, ys) in enumerate(tfe.Iterator(dataset)): @@ -103,14 +105,20 @@ def fit(model, dataset, optimizer, verbose=False, logdir=None): def synthetic_dataset(w, b, noise_level, batch_size, num_batches): """tf.data.Dataset that yields synthetic data for linear regression.""" + return synthetic_dataset_helper(w, b, + tf.shape(w)[0], noise_level, batch_size, + num_batches) + +def synthetic_dataset_helper(w, b, num_features, noise_level, batch_size, + num_batches): # w is a matrix with shape [N, M] # b is a vector with shape [M] # So: # - Generate x's as vectors with shape [batch_size N] # - y = tf.matmul(x, W) + b + noise def batch(_): - x = tf.random_normal([batch_size, tf.shape(w)[0]]) + x = tf.random_normal([batch_size, num_features]) y = tf.matmul(x, w) + b + noise_level * tf.random_normal([]) return x, y diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_graph_test.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_graph_test.py new file mode 
100644 index 0000000000000000000000000000000000000000..557ad42752144243ae3da61b955b31398cba846e --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_graph_test.py @@ -0,0 +1,85 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Graph benchmark for linear regression, to contrast with eager execution.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +import tensorflow as tf +from tensorflow.contrib.eager.python.examples.linear_regression import linear_regression + + +class GraphLinearRegressionBenchmark(tf.test.Benchmark): + + def benchmarkGraphLinearRegression(self): + num_epochs = 10 + num_batches = 200 + batch_size = 64 + dataset = linear_regression.synthetic_dataset_helper( + w=tf.random_uniform([3, 1]), + b=tf.random_uniform([1]), + num_features=3, + noise_level=0.01, + batch_size=batch_size, + num_batches=num_batches) + iterator = dataset.make_initializable_iterator() + x, y = iterator.get_next() + + model = linear_regression.LinearModel() + + if tf.test.is_gpu_available(): + use_gpu = True + device = "/device:GPU:0" + else: + use_gpu = False + device = "/device:CPU:0" + + with tf.device(device): + loss = linear_regression.mean_square_loss(model, x, y) + optimization_step = tf.train.GradientDescentOptimizer( + learning_rate=0.1).minimize(loss) + + with tf.Session() as 
sess: + sess.run(tf.global_variables_initializer()) + + def train(num_epochs): + for _ in range(num_epochs): + sess.run(iterator.initializer) + try: + while True: + _, _ = sess.run([optimization_step, loss]) + except tf.errors.OutOfRangeError: + pass + + # Warmup: a single epoch. + train(1) + + start_time = time.time() + train(num_epochs) + wall_time = time.time() - start_time + + examples_per_sec = num_epochs * num_batches * batch_size / wall_time + self.report_benchmark( + name="graph_train_%s" % + ("gpu" if use_gpu else "cpu"), + iters=num_epochs * num_batches, + extras={"examples_per_sec": examples_per_sec}, + wall_time=wall_time) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_test.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_test.py index 39e7aabd7be04ba36a786a4c08d0df6c2ce916d0..e53234b51a7dccc11e548ac81a7ef070c628aa52 100644 --- a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_test.py +++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_test.py @@ -83,6 +83,7 @@ class LinearRegressionTest(tf.test.TestCase): class EagerLinearRegressionBenchmark(tf.test.Benchmark): def benchmarkEagerLinearRegression(self): + num_epochs = 10 num_batches = 200 batch_size = 64 dataset = linear_regression.synthetic_dataset( @@ -102,14 +103,15 @@ class EagerLinearRegressionBenchmark(tf.test.Benchmark): linear_regression.fit(model, burn_in_dataset, optimizer) start_time = time.time() - linear_regression.fit(model, dataset, optimizer) + for _ in range(num_epochs): + linear_regression.fit(model, dataset, optimizer) wall_time = time.time() - start_time - examples_per_sec = num_batches * batch_size / wall_time + examples_per_sec = num_epochs * num_batches * batch_size / wall_time self.report_benchmark( name="eager_train_%s" % ("gpu" if tfe.num_gpus() > 0 else "cpu"), - iters=num_batches, + 
iters=num_epochs * num_batches, extras={"examples_per_sec": examples_per_sec}, wall_time=wall_time) diff --git a/tensorflow/contrib/eager/python/examples/mnist/mnist.py b/tensorflow/contrib/eager/python/examples/mnist/mnist.py index bfb7d5a9002787f6544d383de58150661ac2bde3..772f59562ba27cce510c82681f491d005298f44c 100644 --- a/tensorflow/contrib/eager/python/examples/mnist/mnist.py +++ b/tensorflow/contrib/eager/python/examples/mnist/mnist.py @@ -23,7 +23,6 @@ from __future__ import division from __future__ import print_function import argparse -import functools import os import sys import time @@ -40,7 +39,7 @@ class MNISTModel(tfe.Network): """MNIST Network. Network structure is equivalent to: - https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/examples/tutorials/mnist/mnist_deep.py + https://github.com/tensorflow/tensorflow/blob/r1.6/tensorflow/examples/tutorials/mnist/mnist_deep.py and https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py @@ -96,8 +95,7 @@ class MNISTModel(tfe.Network): x = self.max_pool2d(x) x = tf.layers.flatten(x) x = self.fc1(x) - if training: - x = self.dropout(x) + x = self.dropout(x, training=training) x = self.fc2(x) return x @@ -124,21 +122,18 @@ def train_one_epoch(model, optimizer, dataset, log_interval=None): tf.train.get_or_create_global_step() - def model_loss(labels, images): - prediction = model(images, training=True) - loss_value = loss(prediction, labels) - tf.contrib.summary.scalar('loss', loss_value) - tf.contrib.summary.scalar('accuracy', - compute_accuracy(prediction, labels)) - return loss_value - for (batch, (images, labels)) in enumerate(tfe.Iterator(dataset)): with tf.contrib.summary.record_summaries_every_n_global_steps(10): - batch_model_loss = functools.partial(model_loss, labels, images) - optimizer.minimize( - batch_model_loss, global_step=tf.train.get_global_step()) + with tfe.GradientTape() as tape: + prediction = model(images, training=True) + loss_value = 
loss(prediction, labels) + tf.contrib.summary.scalar('loss', loss_value) + tf.contrib.summary.scalar('accuracy', + compute_accuracy(prediction, labels)) + grads = tape.gradient(loss_value, model.variables) + optimizer.apply_gradients(zip(grads, model.variables)) if log_interval and batch % log_interval == 0: - print('Batch #%d\tLoss: %.6f' % (batch, batch_model_loss())) + print('Batch #%d\tLoss: %.6f' % (batch, loss_value)) def test(model, dataset): @@ -190,9 +185,9 @@ def main(_): else: train_dir = None test_dir = None - summary_writer = tf.contrib.summary.create_summary_file_writer( + summary_writer = tf.contrib.summary.create_file_writer( train_dir, flush_millis=10000) - test_summary_writer = tf.contrib.summary.create_summary_file_writer( + test_summary_writer = tf.contrib.summary.create_file_writer( test_dir, flush_millis=10000, name='test') checkpoint_prefix = os.path.join(FLAGS.checkpoint_dir, 'ckpt') diff --git a/tensorflow/contrib/eager/python/examples/mnist/mnist_test.py b/tensorflow/contrib/eager/python/examples/mnist/mnist_test.py index 205709fe2edd3c260c30a84b624e322e120edf8e..136085eba21284a42282395e54f32c33bf63b5c3 100644 --- a/tensorflow/contrib/eager/python/examples/mnist/mnist_test.py +++ b/tensorflow/contrib/eager/python/examples/mnist/mnist_test.py @@ -39,22 +39,40 @@ def random_dataset(): return tf.data.Dataset.from_tensors((images, labels)) +def train_one_epoch(defun=False): + model = mnist.MNISTModel(data_format()) + if defun: + model.call = tfe.defun(model.call) + optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01) + dataset = random_dataset() + with tf.device(device()): + tf.train.get_or_create_global_step() + mnist.train_one_epoch(model, optimizer, dataset) + + +def evaluate(defun=False): + model = mnist.MNISTModel(data_format()) + dataset = random_dataset() + if defun: + model.call = tfe.defun(model.call) + with tf.device(device()): + tf.train.get_or_create_global_step() + mnist.test(model, dataset) + + class 
MNISTTest(tf.test.TestCase): def testTrainOneEpoch(self): - model = mnist.MNISTModel(data_format()) - optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01) - dataset = random_dataset() - with tf.device(device()): - tf.train.get_or_create_global_step() - mnist.train_one_epoch(model, optimizer, dataset) + train_one_epoch(defun=False) def testTest(self): - model = mnist.MNISTModel(data_format()) - dataset = random_dataset() - with tf.device(device()): - tf.train.get_or_create_global_step() - mnist.test(model, dataset) + evaluate(defun=False) + + def testTrainOneEpochWithDefunCall(self): + train_one_epoch(defun=True) + + def testTestWithDefunCall(self): + evaluate(defun=True) if __name__ == "__main__": diff --git a/tensorflow/contrib/eager/python/examples/resnet50/README.md b/tensorflow/contrib/eager/python/examples/resnet50/README.md index db023e6c976c8eda09ef0dee7eecb144678773c4..79e460052945718eac194653015d60d900998e2d 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/README.md +++ b/tensorflow/contrib/eager/python/examples/resnet50/README.md @@ -34,7 +34,7 @@ bazel run -c opt --config=cuda :resnet50_graph_test -- --benchmarks=. (Or remove the `--config=cuda` flag for running on CPU instead of GPU). -On October 31, 2017, the benchmarks demostrated comparable performance +On October 31, 2017, the benchmarks demonstrated comparable performance for eager and graph execution of this particular model when using a single NVIDIA Titan X (Pascal) GPU on a host with an Intel Xeon E5-1650 CPU @ 3.50GHz and a batch size of 32. 
diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py index b302a87e0e8a61d2456db1eba847f31bd70f552e..9982fdb07eefa665379e7be095f4f8017d92cf97 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py @@ -97,7 +97,7 @@ class _ConvBlock(tfe.Network): Args: kernel_size: the kernel size of middle conv layer at main path - filters: list of integers, the filterss of 3 conv layer at main path + filters: list of integers, the filters of 3 conv layer at main path stage: integer, current stage label, used for generating layer names block: 'a','b'..., current block label, used for generating layer names data_format: data_format for the input ('channels_first' or diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py index 14c82c87a72457d414c4a1d3c53d4d1a68a400e6..23317886e712323f4b520000e0fd372734fc53a1 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py @@ -73,7 +73,7 @@ class ResNet50GraphTest(tf.test.TestCase): tf.train.get_or_create_global_step() logdir = tempfile.mkdtemp() with tf.contrib.summary.always_record_summaries(): - with tf.contrib.summary.create_summary_file_writer( + with tf.contrib.summary.create_file_writer( logdir, max_queue=0, name='t0').as_default(): model = resnet50.ResNet50(data_format()) diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py index 582f4837c6f3197081cb558063e963866d173f29..0ff8746884c288f824f5f22ab4c550370d0e0302 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py +++ 
b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py @@ -22,6 +22,7 @@ import gc import tempfile import time +from six.moves import xrange # pylint: disable=redefined-builtin import tensorflow as tf import tensorflow.contrib.eager as tfe @@ -52,26 +53,33 @@ def random_batch(batch_size): def train_one_step(model, images, labels, optimizer): - def model_loss(): + with tfe.GradientTape() as tape: logits = model(images, training=True) loss = tf.losses.softmax_cross_entropy( logits=logits, onehot_labels=labels) tf.contrib.summary.scalar(name='loss', tensor=loss) - return loss - - optimizer.minimize(model_loss) + grads = tape.gradient(loss, model.variables) + optimizer.apply_gradients(zip(grads, model.variables)) class ResNet50Test(tf.test.TestCase): - def test_apply(self): + def _apply(self, defun=False): device, data_format = device_and_data_format() model = resnet50.ResNet50(data_format) + if defun: + model.call = tfe.defun(model.call) with tf.device(device): images, _ = random_batch(2) output = model(images) self.assertEqual((2, 1000), output.shape) + def test_apply(self): + self._apply(defun=False) + + def test_apply_with_defun(self): + self._apply(defun=True) + def test_apply_no_top(self): device, data_format = device_and_data_format() model = resnet50.ResNet50(data_format, include_top=False) @@ -95,7 +103,7 @@ class ResNet50Test(tf.test.TestCase): model = resnet50.ResNet50(data_format) tf.train.get_or_create_global_step() logdir = tempfile.mkdtemp() - with tf.contrib.summary.create_summary_file_writer( + with tf.contrib.summary.create_file_writer( logdir, max_queue=0, name='t0').as_default(), tf.contrib.summary.always_record_summaries(): with tf.device(device): @@ -175,9 +183,11 @@ class ResNet50Benchmarks(tf.test.Benchmark): # a sync. This is a roundabout way, yes. 
tf.constant(1.).cpu() - def benchmark_eager_apply(self): + def _benchmark_eager_apply(self, label, defun=False): device, data_format = device_and_data_format() model = resnet50.ResNet50(data_format) + if defun: + model.call = tfe.defun(model.call) batch_size = 64 num_burn = 5 num_iters = 30 @@ -189,16 +199,23 @@ class ResNet50Benchmarks(tf.test.Benchmark): start = time.time() for _ in xrange(num_iters): model(images).cpu() - self._report('eager_apply', start, num_iters, device, batch_size, - data_format) + self._report(label, start, num_iters, device, batch_size, data_format) + + def benchmark_eager_apply(self): + self._benchmark_eager_apply('eager_apply', defun=False) + + def benchmark_eager_apply_with_defun(self): + self._benchmark_eager_apply('eager_apply_with_defun', defun=True) - def _benchmark_eager_train(self, label, make_iterator): + def _benchmark_eager_train(self, label, make_iterator, defun=False): device, data_format = device_and_data_format() for batch_size in self._train_batch_sizes(): (images, labels) = random_batch(batch_size) num_burn = 3 num_iters = 10 model = resnet50.ResNet50(data_format) + if defun: + model.call = tfe.defun(model.call) optimizer = tf.train.GradientDescentOptimizer(0.1) with tf.device(device): @@ -217,7 +234,11 @@ class ResNet50Benchmarks(tf.test.Benchmark): self._report(label, start, num_iters, device, batch_size, data_format) def benchmark_eager_train(self): - self._benchmark_eager_train('eager_train', MockIterator) + self._benchmark_eager_train('eager_train', MockIterator, defun=False) + + def benchmark_eager_train_with_defun(self): + self._benchmark_eager_train( + 'eager_train_with_defun', MockIterator, defun=True) def benchmark_eager_train_datasets(self): @@ -226,7 +247,18 @@ class ResNet50Benchmarks(tf.test.Benchmark): ds = tf.data.Dataset.from_tensors(tensors).repeat() return tfe.Iterator(ds) - self._benchmark_eager_train('eager_train_dataset', make_iterator) + self._benchmark_eager_train( + 'eager_train_dataset', 
make_iterator, defun=False) + + def benchmark_eager_train_datasets_with_defun(self): + + def make_iterator(tensors): + with tf.device('/device:CPU:0'): + ds = tf.data.Dataset.from_tensors(tensors).repeat() + return tfe.Iterator(ds) + + self._benchmark_eager_train( + 'eager_train_dataset_with_defun', make_iterator, defun=True) if __name__ == '__main__': diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py index 609cbd28772c3ae8da70648ca5b1b264a8a255e2..aa87b94e7b0876e65405f6bcb2d6aabde36582bf 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py @@ -65,7 +65,6 @@ import six import tensorflow as tf from tensorflow.contrib.eager.python import tfe -from tensorflow.python.eager import context try: import matplotlib.pyplot as plt # pylint: disable=g-import-not-at-top @@ -247,9 +246,9 @@ def main(_): log_dir = os.path.join(FLAGS.dir, "summaries") tf.gfile.MakeDirs(log_dir) - train_summary_writer = tf.contrib.summary.create_summary_file_writer( + train_summary_writer = tf.contrib.summary.create_file_writer( os.path.join(log_dir, "train"), flush_millis=10000) - test_summary_writer = tf.contrib.summary.create_summary_file_writer( + test_summary_writer = tf.contrib.summary.create_file_writer( os.path.join(log_dir, "eval"), flush_millis=10000, name="eval") with tf.device(device): diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/README.md b/tensorflow/contrib/eager/python/examples/rnn_ptb/README.md index 743ebb68ee5bba5635899267cc4839828f7e4e2f..966177e91c212c1aa132fe3af6f7dc9a50fb984e 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/README.md +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/README.md @@ -40,7 +40,7 @@ bazel run -c opt --config=cuda :rnn_ptb_graph_test -- --benchmarks=. 
(Or remove the `--config=cuda` flag for running on CPU instead of GPU). -On October 31, 2017, the benchmarks demostrated slightly better performance +On October 31, 2017, the benchmarks demonstrated slightly better performance (3-6%) for graph execution over eager execution for this particular model when using a single NVIDIA Titan X (Pascal) GPU on a host with an Intel Xeon E5-1650 CPU @ 3.50GHz and a batch size of 32. diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py index 30bb3c8ad33d38453bd96a76c7770071e24bb034..5c5c59c87744f4ffa6db90e5d8d3aa3bc8132756 100644 --- a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py @@ -22,6 +22,11 @@ Usage: python ./rnn_ptb.py --data-path= Penn Treebank (PTB) dataset from: http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz """ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + import argparse import os import sys @@ -83,7 +88,7 @@ class Embedding(tf.layers.Layer): class PTBModel(tfe.Network): - """LSTM for word language modelling. + """LSTM for word language modeling. Model described in: (Zaremba, et. al.) Recurrent Neural Network Regularization @@ -209,7 +214,7 @@ class Datasets(object): """Load the Penn Treebank dataset. 
Args: - path: Path to the data/ directory of the dataset from from Tomas Mikolov's + path: Path to the data/ directory of the dataset from Tomas Mikolov's webpage - http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz """ @@ -334,8 +339,7 @@ if __name__ == "__main__": "http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz") parser.add_argument( "--logdir", type=str, default="", help="Directory for checkpoint.") - parser.add_argument( - "--epoch", type=int, default=20, help="Number of epoches.") + parser.add_argument("--epoch", type=int, default=20, help="Number of epochs.") parser.add_argument("--batch-size", type=int, default=20, help="Batch size.") parser.add_argument( "--seq-len", type=int, default=35, help="Sequence length.") diff --git a/tensorflow/contrib/eager/python/examples/spinn/BUILD b/tensorflow/contrib/eager/python/examples/spinn/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..a1f8a759e2a556bc219f0aa13942f293c4f34cfa --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/spinn/BUILD @@ -0,0 +1,42 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "cuda_py_test") +load("//tensorflow:tensorflow.bzl", "py_test") + +py_library( + name = "data", + srcs = ["data.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = ["//third_party/py/numpy"], +) + +py_test( + name = "data_test", + size = "small", + srcs = ["data_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":data", + "//tensorflow:tensorflow_py", + ], +) + +cuda_py_test( + name = "spinn_test", + size = "medium", + srcs = ["spinn_test.py"], + additional_deps = [ + ":data", + "//third_party/examples/eager/spinn", + "//third_party/py/numpy", + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/summary:summary_test_util", + "//tensorflow/python/eager:test", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_test_lib", + ], 
+ tags = ["no_pip"], # because spinn.py is under third_party/. +) diff --git a/tensorflow/contrib/eager/python/examples/spinn/README.md b/tensorflow/contrib/eager/python/examples/spinn/README.md new file mode 100644 index 0000000000000000000000000000000000000000..eb0637df473e22e5d39ca1b0816464cb2b7c6435 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/spinn/README.md @@ -0,0 +1,13 @@ +# SPINN: Dynamic neural network with TensorFlow eager execution + +This directory contains files supporting the +[spinn.py model in third_party/examples/eager/spinn/](../../../../../../third_party/examples/eager/spinn/spinn.py), +including + +- `data.py`: Utility library for loading and preprocessing the SNLI and GloVe + data. +- `data_test.py` and `spinn_test.py`: Unit tests for the data and model modules. + +See the [README.md in third_party/examples/eager/spinn/](../../../../../../third_party/examples/eager/spinn/README.md) +for detailed background, license and usage information regarding the SPINN code. + diff --git a/tensorflow/contrib/eager/python/examples/spinn/data.py b/tensorflow/contrib/eager/python/examples/spinn/data.py new file mode 100644 index 0000000000000000000000000000000000000000..3bc3bb49bcbbc26f7a3134a8bfc385ec080dde1e --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/spinn/data.py @@ -0,0 +1,373 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Utilities of SNLI data and GloVe word vectors for SPINN model. + +See more details about the SNLI data set at: + https://nlp.stanford.edu/projects/snli/ + +See more details about the GloVe pretrained word embeddings at: + https://nlp.stanford.edu/projects/glove/ +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import glob +import math +import os +import random + +import numpy as np + +POSSIBLE_LABELS = ("entailment", "contradiction", "neutral") + +UNK_CODE = 0 # Code for unknown word tokens. +PAD_CODE = 1 # Code for padding tokens. + +SHIFT_CODE = 3 +REDUCE_CODE = 2 + +WORD_VECTOR_LEN = 300 # Embedding dimensions. + +LEFT_PAREN = "(" +RIGHT_PAREN = ")" +PARENTHESES = (LEFT_PAREN, RIGHT_PAREN) + + +def get_non_parenthesis_words(items): + """Get the non-parenthesis items from a SNLI parsed sentence. + + Args: + items: Data items from a parsed SNLI sentence, with parentheses. E.g., + ["(", "Man", "(", "(", "(", "(", "(", "wearing", "pass", ")", ... + + Returns: + A list of non-parentheses word items, all converted to lower case. E.g., + ["man", "wearing", "pass", ... + """ + return [x.lower() for x in items if x not in PARENTHESES and x] + + +def get_shift_reduce(items): + """Obtain shift-reduce vector from a list of items from the SNLI data. + + Args: + items: Data items as a list of str, e.g., + ["(", "Man", "(", "(", "(", "(", "(", "wearing", "pass", ")", ... + + Returns: + A list of shift-reduce transitions, encoded as `SHIFT_CODE` for shift and + `REDUCE_CODE` for reduce. See code above for the values of `SHIFT_CODE` + and `REDUCE_CODE`. 
+ """ + trans = [] + for item in items: + if item == LEFT_PAREN: + continue + elif item == RIGHT_PAREN: + trans.append(REDUCE_CODE) + else: + trans.append(SHIFT_CODE) + return trans + + +def pad_and_reverse_word_ids(sentences): + """Pad a list of sentences to the common maximum length + 1. + + Args: + sentences: A list of sentences as a list of list of integers. Each integer + is a word ID. Each list of integer corresponds to one sentence. + + Returns: + A numpy.ndarray of shape (num_sentences, max_length + 1), wherein max_length + is the maximum sentence length (in # of words). Each sentence is reversed + and then padded with an extra one at head, as required by the model. + """ + max_len = max(len(sent) for sent in sentences) + for sent in sentences: + if len(sent) < max_len: + sent.extend([PAD_CODE] * (max_len - len(sent))) + # Reverse in time order and pad an extra one. + sentences = np.fliplr(np.array(sentences, dtype=np.int64)) + sentences = np.concatenate( + [np.ones([sentences.shape[0], 1], dtype=np.int64), sentences], axis=1) + return sentences + + +def pad_transitions(sentences_transitions): + """Pad a list of shift-reduce transitions to the maximum length.""" + max_len = max(len(transitions) for transitions in sentences_transitions) + for transitions in sentences_transitions: + if len(transitions) < max_len: + transitions.extend([PAD_CODE] * (max_len - len(transitions))) + return np.array(sentences_transitions, dtype=np.int64) + + +def load_vocabulary(data_root): + """Load vocabulary from SNLI data files. + + Args: + data_root: Root directory of the data. It is assumed that the SNLI data + files have been downloaded and extracted to the "snli/snli_1.0" + subdirectory of it. + + Returns: + Vocabulary as a set of strings. + + Raises: + ValueError: If SNLI data files cannot be found. 
+ """ + snli_path = os.path.join(data_root, "snli") + snli_glob_pattern = os.path.join(snli_path, "snli_1.0/snli_1.0_*.txt") + file_names = glob.glob(snli_glob_pattern) + if not file_names: + raise ValueError( + "Cannot find SNLI data files at %s. " + "Please download and extract SNLI data first." % snli_glob_pattern) + + print("Loading vocabulary...") + vocab = set() + for file_name in file_names: + with open(os.path.join(snli_path, file_name), "rt") as f: + for i, line in enumerate(f): + if i == 0: + continue + items = line.split("\t") + premise_words = get_non_parenthesis_words(items[1].split(" ")) + hypothesis_words = get_non_parenthesis_words(items[2].split(" ")) + vocab.update(premise_words) + vocab.update(hypothesis_words) + return vocab + + +def load_word_vectors(data_root, vocab): + """Load GloVe word vectors for words present in the vocabulary. + + Args: + data_root: Data root directory. It is assumed that the GloVe file + has been downloaded and extracted at the "glove/" subdirectory of it. + vocab: A `set` of words, representing the vocabulary. + + Returns: + 1. word2index: A dict from lower-case word to row index in the embedding + matrix, i.e, `embed` below. + 2. embed: The embedding matrix as a float32 numpy array. Its shape is + [vocabulary_size, WORD_VECTOR_LEN]. vocabulary_size is len(vocab). + WORD_VECTOR_LEN is the embedding dimension (300). + + Raises: + ValueError: If GloVe embedding file cannot be found. + """ + glove_path = os.path.join(data_root, "glove/glove.42B.300d.txt") + if not os.path.isfile(glove_path): + raise ValueError( + "Cannot find GloVe embedding file at %s. " + "Please download and extract GloVe embeddings first." 
% glove_path) + + print("Loading word vectors...") + + word2index = dict() + embed = [] + + embed.append([0] * WORD_VECTOR_LEN) # + embed.append([0] * WORD_VECTOR_LEN) # + word2index[""] = UNK_CODE + word2index[""] = PAD_CODE + + with open(glove_path, "rt") as f: + for line in f: + items = line.split(" ") + word = items[0] + if word in vocab and word not in word2index: + word2index[word] = len(embed) + vector = np.array([float(item) for item in items[1:]]) + assert (WORD_VECTOR_LEN,) == vector.shape + embed.append(vector) + embed = np.array(embed, dtype=np.float32) + return word2index, embed + + +def calculate_bins(length2count, min_bin_size): + """Calculate bin boundaries given a histogram of lengths and minimum bin size. + + Args: + length2count: A `dict` mapping length to sentence count. + min_bin_size: Minimum bin size in terms of total number of sentence pairs + in the bin. + + Returns: + A `list` representing the right bin boundaries, starting from the inclusive + right boundary of the first bin. For example, if the output is + [10, 20, 35], + it means there are three bins: [1, 10], [11, 20] and [21, 35]. + """ + bounds = [] + lengths = sorted(length2count.keys()) + cum_count = 0 + for length in lengths: + cum_count += length2count[length] + if cum_count >= min_bin_size: + bounds.append(length) + cum_count = 0 + if bounds[-1] != lengths[-1]: + bounds.append(lengths[-1]) + return bounds + + +def encode_sentence(sentence, word2index): + """Encode a single sentence as word indices and shift-reduce code. + + Args: + sentence: The sentence with added binary parse information, represented as + a string, with all the word items and parentheses separated by spaces. + E.g., '( ( The dog ) ( ( is ( playing toys ) ) . ) )'. + word2index: A `dict` mapping words to their word indices. + + Returns: + 1. Word indices as a numpy array, with shape `(sequence_len, 1)`. + 2. Shift-reduce sequence as a numpy array, with shape + `(sequence_len * 2 - 3, 1)`. 
+ """ + items = [w for w in sentence.split(" ") if w] + words = get_non_parenthesis_words(items) + shift_reduce = get_shift_reduce(items) + word_indices = pad_and_reverse_word_ids( + [[word2index.get(word, UNK_CODE) for word in words]]).T + return (word_indices, + np.expand_dims(np.array(shift_reduce, dtype=np.int64), -1)) + + +class SnliData(object): + """A split of SNLI data.""" + + def __init__(self, data_file, word2index, sentence_len_limit=-1): + """SnliData constructor. + + Args: + data_file: Full path to the data file, e.g., + "/tmp/spinn-data/snli/snli_1.0/snli_1.0.train.txt" + word2index: A dict from lower-case word to row index in the embedding + matrix (see `load_word_vectors()` for details). + sentence_len_limit: Maximum allowed sentence length (# of words). + A value of <= 0 means unlimited. Sentences longer than this limit + are currently discarded, not truncated. + """ + + self._labels = [] + self._premises = [] + self._premise_transitions = [] + self._hypotheses = [] + self._hypothesis_transitions = [] + + with open(data_file, "rt") as f: + for i, line in enumerate(f): + if i == 0: + # Skip header line. + continue + items = line.split("\t") + if items[0] not in POSSIBLE_LABELS: + continue + + premise_items = items[1].split(" ") + hypothesis_items = items[2].split(" ") + premise_words = get_non_parenthesis_words(premise_items) + hypothesis_words = get_non_parenthesis_words(hypothesis_items) + + if (sentence_len_limit > 0 and + (len(premise_words) > sentence_len_limit or + len(hypothesis_words) > sentence_len_limit)): + # TODO(cais): Maybe truncate; do not discard. 
+ continue + + premise_ids = [ + word2index.get(word, UNK_CODE) for word in premise_words] + hypothesis_ids = [ + word2index.get(word, UNK_CODE) for word in hypothesis_words] + + self._premises.append(premise_ids) + self._hypotheses.append(hypothesis_ids) + self._premise_transitions.append(get_shift_reduce(premise_items)) + self._hypothesis_transitions.append(get_shift_reduce(hypothesis_items)) + assert (len(self._premise_transitions[-1]) == + 2 * len(premise_words) - 1) + assert (len(self._hypothesis_transitions[-1]) == + 2 * len(hypothesis_words) - 1) + + self._labels.append(POSSIBLE_LABELS.index(items[0]) + 1) + + assert len(self._labels) == len(self._premises) + assert len(self._labels) == len(self._hypotheses) + assert len(self._labels) == len(self._premise_transitions) + assert len(self._labels) == len(self._hypothesis_transitions) + + def num_batches(self, batch_size): + """Calculate number of batches given batch size.""" + return int(math.ceil(len(self._labels) / batch_size)) + + def get_generator(self, batch_size): + """Obtain a generator for batched data. + + All examples of this SnliData object are randomly shuffled, sorted + according to the maximum sentence length of the premise and hypothesis + sentences in the pair, and batched. + + Args: + batch_size: Desired batch size. + + Returns: + A generator for data batches. The generator yields a 5-tuple: + label: An array of the shape (batch_size,). + premise: An array of the shape (max_premise_len, batch_size), wherein + max_premise_len is the maximum length of the (padded) premise + sentence in the batch. + premise_transitions: An array of the shape (2 * max_premise_len -3, + batch_size). + hypothesis: Same as `premise`, but for hypothesis sentences. + hypothesis_transitions: Same as `premise_transitions`, but for + hypothesis sentences. + All the elements of the 5-tuple have dtype `int64`. + """ + # Randomly shuffle examples. 
+ zipped = list(zip( + self._labels, self._premises, self._premise_transitions, + self._hypotheses, self._hypothesis_transitions)) + random.shuffle(zipped) + # Then sort the examples by maximum of the premise and hypothesis sentence + # lengths in the pair. During training, the batches are expected to be + # shuffled. So it is okay to leave them sorted by max length here. + (labels, premises, premise_transitions, hypotheses, + hypothesis_transitions) = zip( + *sorted(zipped, key=lambda x: max(len(x[1]), len(x[3])))) + + def _generator(): + begin = 0 + while begin < len(labels): + # The sorting above and the batching here makes sure that sentences of + # similar max lengths are batched together, minimizing the inefficiency + # due to uneven max lengths. The sentences are batched differently in + # each call to get_generator() due to the shuffling before sorting + # above. The pad_and_reverse_word_ids() and pad_transitions() functions + # take care of any remaining unevenness of the max sentence lengths. + end = min(begin + batch_size, len(labels)) + # Transpose, because the SPINN model requires time-major, instead of + # batch-major. + yield (labels[begin:end], + pad_and_reverse_word_ids(premises[begin:end]).T, + pad_transitions(premise_transitions[begin:end]).T, + pad_and_reverse_word_ids(hypotheses[begin:end]).T, + pad_transitions(hypothesis_transitions[begin:end]).T) + begin = end + return _generator diff --git a/tensorflow/contrib/eager/python/examples/spinn/data_test.py b/tensorflow/contrib/eager/python/examples/spinn/data_test.py new file mode 100644 index 0000000000000000000000000000000000000000..54fef2c3fe4111cd2d93ac109a5b8fffad0c2fad --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/spinn/data_test.py @@ -0,0 +1,270 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Unit tests for SPINN data module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import shutil +import tempfile + +import numpy as np +import tensorflow as tf + +from tensorflow.contrib.eager.python.examples.spinn import data + + +class DataTest(tf.test.TestCase): + + def setUp(self): + super(DataTest, self).setUp() + self._temp_data_dir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self._temp_data_dir) + super(DataTest, self).tearDown() + + def testGenNonParenthesisWords(self): + seq_with_parse = ( + "( Man ( ( ( ( ( wearing pass ) ( on ( a lanyard ) ) ) and " + ") ( standing ( in ( ( a crowd ) ( of people ) ) ) ) ) . ) )") + self.assertEqual( + ["man", "wearing", "pass", "on", "a", "lanyard", "and", "standing", + "in", "a", "crowd", "of", "people", "."], + data.get_non_parenthesis_words(seq_with_parse.split(" "))) + + def testGetShiftReduce(self): + seq_with_parse = ( + "( Man ( ( ( ( ( wearing pass ) ( on ( a lanyard ) ) ) and " + ") ( standing ( in ( ( a crowd ) ( of people ) ) ) ) ) . 
) )") + self.assertEqual( + [3, 3, 3, 2, 3, 3, 3, 2, 2, 2, 3, 2, 3, 3, 3, 3, 2, 3, 3, 2, 2, 2, 2, 2, + 3, 2, 2], data.get_shift_reduce(seq_with_parse.split(" "))) + + def testPadAndReverseWordIds(self): + id_sequences = [[0, 2, 3, 4, 5], + [6, 7, 8], + [9, 10, 11, 12, 13, 14, 15, 16]] + self.assertAllClose( + [[1, 1, 1, 1, 5, 4, 3, 2, 0], + [1, 1, 1, 1, 1, 1, 8, 7, 6], + [1, 16, 15, 14, 13, 12, 11, 10, 9]], + data.pad_and_reverse_word_ids(id_sequences)) + + def testPadTransitions(self): + unpadded = [[3, 3, 3, 2, 2, 2, 2], + [3, 3, 2, 2, 2]] + self.assertAllClose( + [[3, 3, 3, 2, 2, 2, 2], + [3, 3, 2, 2, 2, 1, 1]], + data.pad_transitions(unpadded)) + + def testCalculateBins(self): + length2count = { + 1: 10, + 2: 15, + 3: 25, + 4: 40, + 5: 35, + 6: 10} + self.assertEqual([2, 3, 4, 5, 6], + data.calculate_bins(length2count, 20)) + self.assertEqual([3, 4, 6], data.calculate_bins(length2count, 40)) + self.assertEqual([4, 6], data.calculate_bins(length2count, 60)) + + def testLoadVoacbulary(self): + snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") + fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt") + fake_dev_file = os.path.join(snli_1_0_dir, "snli_1.0_dev.txt") + os.makedirs(snli_1_0_dir) + + with open(fake_train_file, "wt") as f: + f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t" + "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t" + "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n") + f.write("neutral\t( ( Foo bar ) . )\t( ( foo baz ) . )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + with open(fake_dev_file, "wt") as f: + f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t" + "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t" + "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n") + f.write("neutral\t( ( Quux quuz ) ? 
)\t( ( Corge grault ) ! )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Quux quuz?\t.Corge grault!\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + + vocab = data.load_vocabulary(self._temp_data_dir) + self.assertSetEqual( + {".", "?", "!", "foo", "bar", "baz", "quux", "quuz", "corge", "grault"}, + vocab) + + def testLoadVoacbularyWithoutFileRaisesError(self): + with self.assertRaisesRegexp(ValueError, "Cannot find SNLI data files at"): + data.load_vocabulary(self._temp_data_dir) + + os.makedirs(os.path.join(self._temp_data_dir, "snli")) + with self.assertRaisesRegexp(ValueError, "Cannot find SNLI data files at"): + data.load_vocabulary(self._temp_data_dir) + + os.makedirs(os.path.join(self._temp_data_dir, "snli/snli_1.0")) + with self.assertRaisesRegexp(ValueError, "Cannot find SNLI data files at"): + data.load_vocabulary(self._temp_data_dir) + + def testLoadWordVectors(self): + glove_dir = os.path.join(self._temp_data_dir, "glove") + os.makedirs(glove_dir) + glove_file = os.path.join(glove_dir, "glove.42B.300d.txt") + + words = [".", ",", "foo", "bar", "baz"] + with open(glove_file, "wt") as f: + for i, word in enumerate(words): + f.write("%s " % word) + for j in range(data.WORD_VECTOR_LEN): + f.write("%.5f" % (i * 0.1)) + if j < data.WORD_VECTOR_LEN - 1: + f.write(" ") + else: + f.write("\n") + + vocab = {"foo", "bar", "baz", "qux", "."} + # Notice that "qux" is not present in `words`. 
+ word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab) + + self.assertEqual(6, len(word2index)) + self.assertEqual(0, word2index[""]) + self.assertEqual(1, word2index[""]) + self.assertEqual(2, word2index["."]) + self.assertEqual(3, word2index["foo"]) + self.assertEqual(4, word2index["bar"]) + self.assertEqual(5, word2index["baz"]) + self.assertEqual((6, data.WORD_VECTOR_LEN), embed.shape) + self.assertAllClose([0.0] * data.WORD_VECTOR_LEN, embed[0, :]) + self.assertAllClose([0.0] * data.WORD_VECTOR_LEN, embed[1, :]) + self.assertAllClose([0.0] * data.WORD_VECTOR_LEN, embed[2, :]) + self.assertAllClose([0.2] * data.WORD_VECTOR_LEN, embed[3, :]) + self.assertAllClose([0.3] * data.WORD_VECTOR_LEN, embed[4, :]) + self.assertAllClose([0.4] * data.WORD_VECTOR_LEN, embed[5, :]) + + def testLoadWordVectorsWithoutFileRaisesError(self): + vocab = {"foo", "bar", "baz", "qux", "."} + with self.assertRaisesRegexp( + ValueError, "Cannot find GloVe embedding file at"): + data.load_word_vectors(self._temp_data_dir, vocab) + + os.makedirs(os.path.join(self._temp_data_dir, "glove")) + with self.assertRaisesRegexp( + ValueError, "Cannot find GloVe embedding file at"): + data.load_word_vectors(self._temp_data_dir, vocab) + + def _createFakeSnliData(self, fake_snli_file): + # Four sentences in total. + with open(fake_snli_file, "wt") as f: + f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t" + "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t" + "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n") + f.write("neutral\t( ( Foo bar ) . )\t( ( foo . )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + f.write("contradiction\t( ( Bar foo ) . )\t( ( baz . 
)\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + f.write("entailment\t( ( Quux quuz ) . )\t( ( grault . )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + f.write("entailment\t( ( Quuz quux ) . )\t( ( garply . )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + + def _createFakeGloveData(self, glove_file): + words = [".", "foo", "bar", "baz", "quux", "quuz", "grault", "garply"] + with open(glove_file, "wt") as f: + for i, word in enumerate(words): + f.write("%s " % word) + for j in range(data.WORD_VECTOR_LEN): + f.write("%.5f" % (i * 0.1)) + if j < data.WORD_VECTOR_LEN - 1: + f.write(" ") + else: + f.write("\n") + + def testEncodeSingleSentence(self): + snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") + fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt") + os.makedirs(snli_1_0_dir) + self._createFakeSnliData(fake_train_file) + vocab = data.load_vocabulary(self._temp_data_dir) + glove_dir = os.path.join(self._temp_data_dir, "glove") + os.makedirs(glove_dir) + glove_file = os.path.join(glove_dir, "glove.42B.300d.txt") + self._createFakeGloveData(glove_file) + word2index, _ = data.load_word_vectors(self._temp_data_dir, vocab) + + sentence_variants = [ + "( Foo ( ( bar baz ) . ) )", + " ( Foo ( ( bar baz ) . ) ) ", + "( Foo ( ( bar baz ) . 
) )"] + for sentence in sentence_variants: + word_indices, shift_reduce = data.encode_sentence(sentence, word2index) + self.assertEqual(np.int64, word_indices.dtype) + self.assertEqual((5, 1), word_indices.shape) + self.assertAllClose( + np.array([[3, 3, 3, 2, 3, 2, 2]], dtype=np.int64).T, shift_reduce) + + def testSnliData(self): + snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") + fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt") + os.makedirs(snli_1_0_dir) + self._createFakeSnliData(fake_train_file) + + glove_dir = os.path.join(self._temp_data_dir, "glove") + os.makedirs(glove_dir) + glove_file = os.path.join(glove_dir, "glove.42B.300d.txt") + self._createFakeGloveData(glove_file) + + vocab = data.load_vocabulary(self._temp_data_dir) + word2index, _ = data.load_word_vectors(self._temp_data_dir, vocab) + + train_data = data.SnliData(fake_train_file, word2index) + self.assertEqual(4, train_data.num_batches(1)) + self.assertEqual(2, train_data.num_batches(2)) + self.assertEqual(2, train_data.num_batches(3)) + self.assertEqual(1, train_data.num_batches(4)) + + generator = train_data.get_generator(2)() + for _ in range(2): + label, prem, prem_trans, hypo, hypo_trans = next(generator) + self.assertEqual(2, len(label)) + self.assertEqual((4, 2), prem.shape) + self.assertEqual((5, 2), prem_trans.shape) + self.assertEqual((3, 2), hypo.shape) + self.assertEqual((3, 2), hypo_trans.shape) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py new file mode 100644 index 0000000000000000000000000000000000000000..eefc06d90d83b61d07a613643c913d3833a5f2c1 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/spinn/spinn_test.py @@ -0,0 +1,475 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import gc +import glob +import os +import shutil +import tempfile +import time + +import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin +import tensorflow as tf + +# pylint: disable=g-bad-import-order +import tensorflow.contrib.eager as tfe +from tensorflow.contrib.eager.python.examples.spinn import data +from third_party.examples.eager.spinn import spinn +from tensorflow.contrib.summary import summary_test_util +from tensorflow.python.eager import test +from tensorflow.python.framework import test_util +from tensorflow.python.training import checkpoint_utils +# pylint: enable=g-bad-import-order + + +def _generate_synthetic_snli_data_batch(sequence_length, + batch_size, + vocab_size): + """Generate a fake batch of SNLI data for testing.""" + with tf.device("cpu:0"): + labels = tf.random_uniform([batch_size], minval=1, maxval=4, dtype=tf.int64) + prem = tf.random_uniform( + (sequence_length, batch_size), maxval=vocab_size, dtype=tf.int64) + prem_trans = tf.constant(np.array( + [[3, 3, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3, + 2, 3, 3, 2, 2, 3, 3, 3, 2, 2, 2, 2, + 3, 2, 2]] * batch_size, dtype=np.int64).T) + hypo = tf.random_uniform( + (sequence_length, batch_size), maxval=vocab_size, 
dtype=tf.int64) + hypo_trans = tf.constant(np.array( + [[3, 3, 2, 3, 3, 3, 2, 2, 2, 3, 3, 3, + 2, 3, 3, 2, 2, 3, 3, 3, 2, 2, 2, 2, + 3, 2, 2]] * batch_size, dtype=np.int64).T) + if tfe.num_gpus(): + labels = labels.gpu() + prem = prem.gpu() + prem_trans = prem_trans.gpu() + hypo = hypo.gpu() + hypo_trans = hypo_trans.gpu() + return labels, prem, prem_trans, hypo, hypo_trans + + +def _test_spinn_config(d_embed, d_out, logdir=None, inference_sentences=None): + """Generate a config tuple for testing. + + Args: + d_embed: Embedding dimensions. + d_out: Model output dimensions. + logdir: Optional logdir. + inference_sentences: A 2-tuple of strings representing the sentences (with + binary parsing result), e.g., + ("( ( The dog ) ( ( is running ) . ) )", "( ( The dog ) ( moves . ) )"). + + Returns: + A config tuple. + """ + config_tuple = collections.namedtuple( + "Config", ["d_hidden", "d_proj", "d_tracker", "predict", + "embed_dropout", "mlp_dropout", "n_mlp_layers", "d_mlp", + "d_out", "projection", "lr", "batch_size", "epochs", + "force_cpu", "logdir", "log_every", "dev_every", "save_every", + "lr_decay_every", "lr_decay_by", "inference_premise", + "inference_hypothesis"]) + + inference_premise = inference_sentences[0] if inference_sentences else None + inference_hypothesis = inference_sentences[1] if inference_sentences else None + return config_tuple( + d_hidden=d_embed, + d_proj=d_embed * 2, + d_tracker=8, + predict=False, + embed_dropout=0.1, + mlp_dropout=0.1, + n_mlp_layers=2, + d_mlp=32, + d_out=d_out, + projection=True, + lr=2e-2, + batch_size=2, + epochs=20, + force_cpu=False, + logdir=logdir, + log_every=1, + dev_every=2, + save_every=2, + lr_decay_every=1, + lr_decay_by=0.75, + inference_premise=inference_premise, + inference_hypothesis=inference_hypothesis) + + +class SpinnTest(test_util.TensorFlowTestCase): + + def setUp(self): + super(SpinnTest, self).setUp() + self._test_device = "gpu:0" if tfe.num_gpus() else "cpu:0" + self._temp_data_dir = 
tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self._temp_data_dir) + super(SpinnTest, self).tearDown() + + def testBundle(self): + with tf.device(self._test_device): + lstm_iter = [np.array([[0, 1], [2, 3]], dtype=np.float32), + np.array([[0, -1], [-2, -3]], dtype=np.float32), + np.array([[0, 2], [4, 6]], dtype=np.float32), + np.array([[0, -2], [-4, -6]], dtype=np.float32)] + out = spinn._bundle(lstm_iter) + + self.assertEqual(2, len(out)) + self.assertEqual(tf.float32, out[0].dtype) + self.assertEqual(tf.float32, out[1].dtype) + self.assertAllEqual(np.array([[0, 2, 0, -2, 0, 4, 0, -4]]).T, + out[0].numpy()) + self.assertAllEqual(np.array([[1, 3, -1, -3, 2, 6, -2, -6]]).T, + out[1].numpy()) + + def testUnbunbdle(self): + with tf.device(self._test_device): + state = [np.array([[0, 1, 2], [3, 4, 5]], dtype=np.float32), + np.array([[0, -1, -2], [-3, -4, -5]], dtype=np.float32)] + out = spinn._unbundle(state) + + self.assertEqual(2, len(out)) + self.assertEqual(tf.float32, out[0].dtype) + self.assertEqual(tf.float32, out[1].dtype) + self.assertAllEqual(np.array([[0, 1, 2, 0, -1, -2]]), + out[0].numpy()) + self.assertAllEqual(np.array([[3, 4, 5, -3, -4, -5]]), + out[1].numpy()) + + def testReducer(self): + with tf.device(self._test_device): + batch_size = 3 + size = 10 + tracker_size = 8 + reducer = spinn.Reducer(size, tracker_size=tracker_size) + + left_in = [] + right_in = [] + tracking = [] + for _ in range(batch_size): + left_in.append(tf.random_normal((1, size * 2))) + right_in.append(tf.random_normal((1, size * 2))) + tracking.append(tf.random_normal((1, tracker_size * 2))) + + out = reducer(left_in, right_in, tracking=tracking) + self.assertEqual(batch_size, len(out)) + self.assertEqual(tf.float32, out[0].dtype) + self.assertEqual((1, size * 2), out[0].shape) + + def testReduceTreeLSTM(self): + with tf.device(self._test_device): + size = 10 + tracker_size = 8 + reducer = spinn.Reducer(size, tracker_size=tracker_size) + + lstm_in = np.array([[0, 1, 2, 
3, 4, 5, 6, 7, 8, 9], + [0, -1, -2, -3, -4, -5, -6, -7, -8, -9]], + dtype=np.float32) + c1 = np.array([[0, 1], [2, 3]], dtype=np.float32) + c2 = np.array([[0, -1], [-2, -3]], dtype=np.float32) + + h, c = reducer._tree_lstm(c1, c2, lstm_in) + self.assertEqual(tf.float32, h.dtype) + self.assertEqual(tf.float32, c.dtype) + self.assertEqual((2, 2), h.shape) + self.assertEqual((2, 2), c.shape) + + def testTracker(self): + with tf.device(self._test_device): + batch_size = 2 + size = 10 + tracker_size = 8 + buffer_length = 18 + stack_size = 3 + + tracker = spinn.Tracker(tracker_size, False) + tracker.reset_state() + + # Create dummy inputs for testing. + bufs = [] + buf = [] + for _ in range(buffer_length): + buf.append(tf.random_normal((batch_size, size * 2))) + bufs.append(buf) + self.assertEqual(1, len(bufs)) + self.assertEqual(buffer_length, len(bufs[0])) + self.assertEqual((batch_size, size * 2), bufs[0][0].shape) + + stacks = [] + stack = [] + for _ in range(stack_size): + stack.append(tf.random_normal((batch_size, size * 2))) + stacks.append(stack) + self.assertEqual(1, len(stacks)) + self.assertEqual(3, len(stacks[0])) + self.assertEqual((batch_size, size * 2), stacks[0][0].shape) + + for _ in range(2): + out1, out2 = tracker(bufs, stacks) + self.assertIsNone(out2) + self.assertEqual(batch_size, len(out1)) + self.assertEqual(tf.float32, out1[0].dtype) + self.assertEqual((1, tracker_size * 2), out1[0].shape) + + self.assertEqual(tf.float32, tracker.state.c.dtype) + self.assertEqual((batch_size, tracker_size), tracker.state.c.shape) + self.assertEqual(tf.float32, tracker.state.h.dtype) + self.assertEqual((batch_size, tracker_size), tracker.state.h.shape) + + def testSPINN(self): + with tf.device(self._test_device): + embedding_dims = 10 + d_tracker = 8 + sequence_length = 15 + num_transitions = 27 + + config_tuple = collections.namedtuple( + "Config", ["d_hidden", "d_proj", "d_tracker", "predict"]) + config = config_tuple( + embedding_dims, embedding_dims * 2, 
d_tracker, False) + s = spinn.SPINN(config) + + # Create some fake data. + buffers = tf.random_normal((sequence_length, 1, config.d_proj)) + transitions = tf.constant( + [[3], [3], [2], [3], [3], [3], [2], [2], [2], [3], [3], [3], + [2], [3], [3], [2], [2], [3], [3], [3], [2], [2], [2], [2], + [3], [2], [2]], dtype=tf.int64) + self.assertEqual(tf.int64, transitions.dtype) + self.assertEqual((num_transitions, 1), transitions.shape) + + out = s(buffers, transitions, training=True) + self.assertEqual(tf.float32, out.dtype) + self.assertEqual((1, embedding_dims), out.shape) + + def testSNLIClassifierAndTrainer(self): + with tf.device(self._test_device): + vocab_size = 40 + batch_size = 2 + d_embed = 10 + sequence_length = 15 + d_out = 4 + + config = _test_spinn_config(d_embed, d_out) + + # Create fake embedding matrix. + embed = tf.random_normal((vocab_size, d_embed)) + + model = spinn.SNLIClassifier(config, embed) + trainer = spinn.SNLIClassifierTrainer(model, config.lr) + + (labels, prem, prem_trans, hypo, + hypo_trans) = _generate_synthetic_snli_data_batch(sequence_length, + batch_size, + vocab_size) + + # Invoke model under non-training mode. + logits = model(prem, prem_trans, hypo, hypo_trans, training=False) + self.assertEqual(tf.float32, logits.dtype) + self.assertEqual((batch_size, d_out), logits.shape) + + # Invoke model under training mode. + logits = model(prem, prem_trans, hypo, hypo_trans, training=True) + self.assertEqual(tf.float32, logits.dtype) + self.assertEqual((batch_size, d_out), logits.shape) + + # Calculate loss.
+ loss1 = trainer.loss(labels, logits) + self.assertEqual(tf.float32, loss1.dtype) + self.assertEqual((), loss1.shape) + + loss2, logits = trainer.train_batch( + labels, prem, prem_trans, hypo, hypo_trans) + self.assertEqual(tf.float32, loss2.dtype) + self.assertEqual((), loss2.shape) + self.assertEqual(tf.float32, logits.dtype) + self.assertEqual((batch_size, d_out), logits.shape) + # Training on the batch should have led to a change in the loss value. + self.assertNotEqual(loss1.numpy(), loss2.numpy()) + + def _create_test_data(self, snli_1_0_dir): + fake_train_file = os.path.join(snli_1_0_dir, "snli_1.0_train.txt") + os.makedirs(snli_1_0_dir) + + # Four sentences in total. + with open(fake_train_file, "wt") as f: + f.write("gold_label\tsentence1_binary_parse\tsentence2_binary_parse\t" + "sentence1_parse\tsentence2_parse\tsentence1\tsentence2\t" + "captionID\tpairID\tlabel1\tlabel2\tlabel3\tlabel4\tlabel5\n") + f.write("neutral\t( ( Foo bar ) . )\t( ( foo . )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + f.write("contradiction\t( ( Bar foo ) . )\t( ( baz . )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + f.write("entailment\t( ( Quux quuz ) . )\t( ( grault . )\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + f.write("entailment\t( ( Quuz quux ) . )\t( ( garply . 
)\t" + "DummySentence1Parse\tDummySentence2Parse\t" + "Foo bar.\tfoo baz.\t" + "4705552913.jpg#2\t4705552913.jpg#2r1n\t" + "neutral\tentailment\tneutral\tneutral\tneutral\n") + + glove_dir = os.path.join(self._temp_data_dir, "glove") + os.makedirs(glove_dir) + glove_file = os.path.join(glove_dir, "glove.42B.300d.txt") + + words = [".", "foo", "bar", "baz", "quux", "quuz", "grault", "garply"] + with open(glove_file, "wt") as f: + for i, word in enumerate(words): + f.write("%s " % word) + for j in range(data.WORD_VECTOR_LEN): + f.write("%.5f" % (i * 0.1)) + if j < data.WORD_VECTOR_LEN - 1: + f.write(" ") + else: + f.write("\n") + + return fake_train_file + + def testInferSpinnWorks(self): + """Test inference with the spinn model.""" + snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") + self._create_test_data(snli_1_0_dir) + + vocab = data.load_vocabulary(self._temp_data_dir) + word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab) + + config = _test_spinn_config( + data.WORD_VECTOR_LEN, 4, + logdir=os.path.join(self._temp_data_dir, "logdir"), + inference_sentences=("( foo ( bar . ) )", "( bar ( foo . ) )")) + logits = spinn.train_or_infer_spinn( + embed, word2index, None, None, None, config) + self.assertEqual(np.float32, logits.dtype) + self.assertEqual((3,), logits.shape) + + def testInferSpinnThrowsErrorIfOnlyOneSentenceIsSpecified(self): + snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") + self._create_test_data(snli_1_0_dir) + + vocab = data.load_vocabulary(self._temp_data_dir) + word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab) + + config = _test_spinn_config( + data.WORD_VECTOR_LEN, 4, + logdir=os.path.join(self._temp_data_dir, "logdir"), + inference_sentences=("( foo ( bar . ) )", None)) + with self.assertRaises(ValueError): + spinn.train_or_infer_spinn(embed, word2index, None, None, None, config) + + def testTrainSpinn(self): + """Test with fake toy SNLI data and GloVe vectors.""" + + # 1. 
Create and load a fake SNLI data file and a fake GloVe embedding file. + snli_1_0_dir = os.path.join(self._temp_data_dir, "snli/snli_1.0") + fake_train_file = self._create_test_data(snli_1_0_dir) + + vocab = data.load_vocabulary(self._temp_data_dir) + word2index, embed = data.load_word_vectors(self._temp_data_dir, vocab) + + train_data = data.SnliData(fake_train_file, word2index) + dev_data = data.SnliData(fake_train_file, word2index) + test_data = data.SnliData(fake_train_file, word2index) + + # 2. Create a fake config. + config = _test_spinn_config( + data.WORD_VECTOR_LEN, 4, + logdir=os.path.join(self._temp_data_dir, "logdir")) + + # 3. Test training of a SPINN model. + trainer = spinn.train_or_infer_spinn( + embed, word2index, train_data, dev_data, test_data, config) + + # 4. Load train loss values from the summary files and verify that they + # decrease with training. + summary_file = glob.glob(os.path.join(config.logdir, "events.out.*"))[0] + events = summary_test_util.events_from_file(summary_file) + train_losses = [event.summary.value[0].simple_value for event in events + if event.summary.value + and event.summary.value[0].tag == "train/loss"] + self.assertEqual(config.epochs, len(train_losses)) + self.assertLess(train_losses[-1], train_losses[0]) + + # 5. Verify that checkpoints exist and contain all the expected variables.
+ self.assertTrue(glob.glob(os.path.join(config.logdir, "ckpt*"))) + ckpt_variable_names = [ + item[0] for item in checkpoint_utils.list_variables(config.logdir)] + self.assertIn("global_step", ckpt_variable_names) + for v in trainer.variables: + variable_name = v.name[:v.name.index(":")] if ":" in v.name else v.name + self.assertIn(variable_name, ckpt_variable_names) + + +class EagerSpinnSNLIClassifierBenchmark(test.Benchmark): + + def benchmarkEagerSpinnSNLIClassifier(self): + test_device = "gpu:0" if tfe.num_gpus() else "cpu:0" + with tf.device(test_device): + burn_in_iterations = 2 + benchmark_iterations = 10 + + vocab_size = 1000 + batch_size = 128 + sequence_length = 15 + d_embed = 200 + d_out = 4 + + embed = tf.random_normal((vocab_size, d_embed)) + + config = _test_spinn_config(d_embed, d_out) + model = spinn.SNLIClassifier(config, embed) + trainer = spinn.SNLIClassifierTrainer(model, config.lr) + + (labels, prem, prem_trans, hypo, + hypo_trans) = _generate_synthetic_snli_data_batch(sequence_length, + batch_size, + vocab_size) + + for _ in range(burn_in_iterations): + trainer.train_batch(labels, prem, prem_trans, hypo, hypo_trans) + + gc.collect() + start_time = time.time() + for _ in xrange(benchmark_iterations): + trainer.train_batch(labels, prem, prem_trans, hypo, hypo_trans) + wall_time = time.time() - start_time + # Named "examples"_per_sec to conform with other benchmarks. 
+ extras = {"examples_per_sec": benchmark_iterations / wall_time} + self.report_benchmark( + name="Eager_SPINN_SNLIClassifier_Benchmark", + iters=benchmark_iterations, + wall_time=wall_time, + extras=extras) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/eager/python/g3doc/guide.md b/tensorflow/contrib/eager/python/g3doc/guide.md index 147b7047f42b7ccba5829b61370e82e217ce5838..ffc1d0332eae605ce0444a225e53baa68954cae0 100644 --- a/tensorflow/contrib/eager/python/g3doc/guide.md +++ b/tensorflow/contrib/eager/python/g3doc/guide.md @@ -19,29 +19,34 @@ to models defined without using eager execution. ## Installation -Eager execution is **not** included in the latest release (version 1.4) of -TensorFlow. To use it, you will need to [build TensorFlow from -source](https://www.tensorflow.org/install/install_sources) or install the -nightly builds. +Eager execution is included in TensorFlow versions 1.5 and above. +Installation instructions at https://www.tensorflow.org/install/ -For example, the nightly builds can be installed using `pip`: +The contents of this guide are compatible with TensorFlow 1.5. +However, if you run into bugs that are fixed in source but not the +release, you may want to either [build from +source](https://www.tensorflow.org/install/install_sources) +or try the latest nightly builds. The nightly builds are available as: -- `pip install tf-nightly` (for CPU-only TensorFlow) -- `pip install tf-nightly-gpu` (for GPU-enabled TensorFlow) +- [`pip` packages](https://github.com/tensorflow/tensorflow/blob/master/README.md#installation) and -Or using `docker`, with [Jupyter Notebook](http://jupyter.org/) support: +- [docker](https://hub.docker.com/r/tensorflow/tensorflow/) images.
+ +For example, to run the latest nightly docker image: ```sh -# For CPU-only TensorFlow +# If you have a GPU, use https://github.com/NVIDIA/nvidia-docker +docker pull tensorflow/tensorflow:nightly-gpu +docker run --runtime=nvidia -it -p 8888:8888 tensorflow/tensorflow:nightly-gpu + +# If you do not have a GPU, use the CPU-only image docker pull tensorflow/tensorflow:nightly docker run -it -p 8888:8888 tensorflow/tensorflow:nightly - -# For GPU-enabled TensorFlow: -# (Requires https://github.com/NVIDIA/nvidia-docker) -nvidia-docker pull tensorflow/tensorflow:nightly-gpu -nvidia-docker run -it -p 8888:8888 tensorflow/tensorflow:nightly-gpu ``` +And then visit http://localhost:8888 in your browser for a Jupyter notebook +environment. + ## Getting Started With TensorFlow installed, eager execution is enabled via a single call: @@ -292,7 +297,7 @@ def loss(weight, bias): error = prediction(training_inputs, weight, bias) - training_outputs return tf.reduce_mean(tf.square(error)) -# Function that returns the the derivative of loss with respect to +# Function that returns the derivative of loss with respect to # weight and bias grad = tfe.gradients_function(loss) @@ -757,7 +762,7 @@ For example, to record summaries once every 100 global steps, use: ```python tf.train.get_or_create_global_step() # Ensuring the global step variable exists -writer = tf.contrib.summary.create_summary_file_writer(logdir) +writer = tf.contrib.summary.create_file_writer(logdir) for _ in range(iterations): with writer.as_default(): diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py index 2f8016ede3caee6dbb6fd8f5226f1464b5c3976b..ea8dbf2b46ea4bd0e33645ae3c590c4dd13f7a52 100644 --- a/tensorflow/contrib/eager/python/metrics_impl.py +++ b/tensorflow/contrib/eager/python/metrics_impl.py @@ -49,6 +49,20 @@ class Metric(object): Example use with graph execution: + ```python + m = SomeMetric(...) + inputs = ... 
# Some tensors to compute the metric on. + m_update = m(inputs) + # Variables defined in first call, so get the initialization op afterwards. + m_init = m.init_variables() # or tf.global_variables_initializer() + m_result = m.result() + with tf.Session() as sess: + sess.run(m_init) + for input in ...: + sess.run(m_update) + print(sess.run(m_result)) + ``` + Example use with graph execution with placeholders and feed_dict: ```python m = SomeMetric(...) m_placeholder = tf.placeholder(...) @@ -107,6 +121,7 @@ class Metric(object): """Returns op to execute to update this metric for these inputs. Returns None if eager execution is enabled. + Returns a graph-mode function if graph execution is enabled. Args: *args: @@ -183,6 +198,13 @@ class Metric(object): """Computes and returns a final value for the metric.""" raise NotImplementedError("Metrics must define a result() member function") + def value(self): + """In graph mode returns the result Tensor while in eager the callable.""" + if context.in_graph_mode(): + return self.result() + else: + return self.result + # We can support two different strategies of for doing data-parallel # distributed metric computations: # * Put metric variables on the first device and rely on small @@ -269,6 +291,9 @@ class Mean(Metric): Args: values: Tensor with the per-example value. weights: Optional weighting of each example. Defaults to 1. + + Returns: + The arguments, for easy chaining. """ if weights is None: self.denom.assign_add( @@ -280,6 +305,9 @@ class Mean(Metric): self.denom.assign_add(math_ops.reduce_sum(weights)) values = math_ops.cast(values, self.dtype) * weights self.numer.assign_add(math_ops.reduce_sum(values)) + if weights is None: + return values + return values, weights def result(self): t = self.numer / self.denom @@ -307,7 +335,13 @@ class Accuracy(Mean): per element of the Tensor. predictions: Tensor with the predicted label for each example. weights: Optional weighting of each example. Defaults to 1. 
+ + Returns: + The arguments, for easy chaining. """ matches = math_ops.equal(labels, predictions) matches = math_ops.cast(matches, dtypes.float64) super(Accuracy, self).call(matches, weights=weights) + if weights is None: + return labels, predictions + return labels, predictions, weights diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py index 96eb1b4f2a0e4c4af1f3310a2801b1b6aee285d6..a9ecaa3f8bced3043ea0eb0ac3aa8bfa65e9e1ff 100644 --- a/tensorflow/contrib/eager/python/metrics_test.py +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -27,6 +27,7 @@ from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.training import training_util @@ -67,7 +68,7 @@ class MetricsTest(test.TestCase): m([1, 10, 100]) training_util.get_or_create_global_step() logdir = tempfile.mkdtemp() - with summary_ops.create_summary_file_writer( + with summary_ops.create_file_writer( logdir, max_queue=0, name="t0").as_default(), summary_ops.always_record_summaries(): m.result() # As a side-effect will write summaries. 
@@ -137,7 +138,7 @@ class MetricsTest(test.TestCase): self.assertEqual(m1.name, "has space") self.assertEqual(m1.numer.name, "has_space/numer:0") - def testGraph(self): + def testGraphWithPlaceholder(self): with context.graph_mode(), self.test_session() as sess: m = metrics.Mean() p = array_ops.placeholder(dtypes.float32) @@ -153,6 +154,22 @@ class MetricsTest(test.TestCase): sess.run(accumulate, feed_dict={p: 7}) self.assertAllEqual(m.result().eval(), 7) + @test_util.run_in_graph_and_eager_modes() + def testGraphAndEagerTensor(self): + m = metrics.Mean() + inputs = ops.convert_to_tensor([1.0, 2.0]) + accumulate = m(inputs) + result = m.result() + self.evaluate(m.init_variables()) + self.evaluate(accumulate) + self.assertEqual(self.evaluate(result), 1.5) + # Second init resets all the variables. + self.evaluate(m.init_variables()) + inputs = ops.convert_to_tensor([2.0, 3.0]) + self.evaluate(m(inputs)) + value = m.value() + self.assertEqual(self.evaluate(value), 2.5) + def testTwoMeansGraph(self): # Verify two metrics with the same class and name don't # accidentally share state. 
@@ -163,6 +180,19 @@ class MetricsTest(test.TestCase): m2 = metrics.Mean() m2(2) + def testMetricsChain(self): + with context.graph_mode(), self.test_session(): + m1 = metrics.Mean() + m2 = metrics.Mean(name="m2") + update_m2 = m2(3.0) + update_m2_2 = m2(m1(1.0)) + m1.init_variables().run() + m2.init_variables().run() + update_m2.eval() + update_m2_2.eval() + self.assertAllEqual(m2.result().eval(), 2.0) + self.assertAllEqual(m1.result().eval(), 1.0) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/eager/python/network.py b/tensorflow/contrib/eager/python/network.py index 0388aaa8495f380595b2635529bc2e33e808b06f..e3c13cbd2e8ccd2ab79da74e0e97905c6ed5c02d 100644 --- a/tensorflow/contrib/eager/python/network.py +++ b/tensorflow/contrib/eager/python/network.py @@ -451,8 +451,30 @@ class Network(base.Layer): "at https://github.com/tensorflow/tensorflow/issues/new if this is " "important to you") + def add_loss(self, losses, inputs=None): + raise RuntimeError( + "add_loss is not supported in Network class yet. Please file an issue " + "at https://github.com/tensorflow/tensorflow/issues/new if this is " + "important to you") + + @property + def losses(self): + """Gather losses from `Layer`s in the `Network`. + + Note that when executing eagerly, `Layer.losses` evaluates + regularizers. When using graph execution, variable regularization ops have + already been created and are simply returned here. + + Returns: + A list of tensors. 
+ """ + layer_losses = [] + for layer in self.layers: + layer_losses.extend(layer.losses) + return layer_losses + # TODO(allenl): Support other Layer methods needed for graph mode, such as for - # losses and updates + # updates class Sequential(Network): diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py index e7835a63e6db926aa2d4b6c76c681c8a301757bd..3329fc6c513265deff41a368f5688dd605209c14 100644 --- a/tensorflow/contrib/eager/python/network_test.py +++ b/tensorflow/contrib/eager/python/network_test.py @@ -19,6 +19,7 @@ from __future__ import print_function import gc from tensorflow.contrib.eager.python import network +from tensorflow.contrib.layers.python.layers import regularizers from tensorflow.python.eager import context from tensorflow.python.eager import function from tensorflow.python.eager import test @@ -45,6 +46,22 @@ class MyNetwork(network.Network): return self.l1(x) +class RegularizedNetwork(network.Network): + + def __init__(self): + super(RegularizedNetwork, self).__init__() + self.l1 = self.track_layer(core.Dense( + 1, + bias_regularizer=regularizers.l1_regularizer(2.0), + kernel_regularizer=regularizers.l1_regularizer(2.0))) + self.l2 = self.track_layer(core.Dense( + 1, + bias_regularizer=regularizers.l1_regularizer(2.0))) + + def call(self, values): + return self.l2(self.l1(values)) + + class NetworkTest(test.TestCase): def _save_modify_load_network_built(self, net, global_step=None): @@ -88,15 +105,13 @@ class NetworkTest(test.TestCase): result = net(constant_op.constant([[2.0]])) self.assertEqual(34.0, self.evaluate(result)) - # TODO(akshayka): This test should be changed once an API for compiling - # `call` into a defun is implemented. def testReplacingNetworkCallWithDefun(self): net = MyNetwork(name="abcd") + net.call = function.defun(net.call) x = constant_op.constant([[2.0]]) net(x) # Force variables to be created. 
self.evaluate(net.trainable_variables[0].assign([[17.0]])) - net.call = function.defun(net.call) result = net(x) # Build and execute the TensorFlow function self.assertEqual(34.0, self.evaluate(result)) @@ -484,6 +499,18 @@ class NetworkTest(test.TestCase): _check_op_prefixes(expected_prefix="my_network_1/dense/", checked_ops=checked_ops) + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) + def testVariableRegularizers(self): + net = RegularizedNetwork() + net(constant_op.constant([[1.]])) + self.evaluate(net.variables[0].assign([[2.]])) + self.evaluate(net.variables[1].assign([3.])) + self.evaluate(net.variables[2].assign([[-2.]])) + self.evaluate(net.variables[3].assign([4.])) + self.assertAllEqual([4., 6., 8.], self.evaluate(net.losses)) + self.evaluate(net.variables[3].assign([5.])) + self.assertAllEqual([4., 6., 10.], self.evaluate(net.losses)) + @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True) def testDuplicateNameError(self): one = constant_op.constant([[1.]]) @@ -512,7 +539,7 @@ class NetworkTest(test.TestCase): # No issue here since the name is unique within its scope. name_conflict3 = MyNetwork(name="name_conflict") net2 = MyNetwork() # name=outside_scope/my_network_2 to avoid the - # variable_scope my_network_1 below. + # variable_scope my_network_1 below. vs_name_conflict = MyNetwork(name="vs_name_conflict") # conflict below with variable_scope.variable_scope("intervening_scope"): with variable_scope.variable_scope(captured_scope): @@ -661,7 +688,7 @@ class NetworkTest(test.TestCase): net2(one) # Layer names typically are globally unique rather than being unique within # the scope of their first use. However, within a Network they must be named - # locally so that previous Layer consutrciton does not interfere with + # locally so that previous Layer construction does not interfere with # variable naming (e.g. 
add a Layer construction before the Network, # suddenly your previously saved checkpoint is incompatible). self.assertEqual("dense", net1.l1.name) diff --git a/tensorflow/contrib/eager/python/saver.py b/tensorflow/contrib/eager/python/saver.py index 57b070ec6eeac00c77f199a846639d64c4957cd8..62421849c766a1124c726812428985c913c653a3 100644 --- a/tensorflow/contrib/eager/python/saver.py +++ b/tensorflow/contrib/eager/python/saver.py @@ -82,7 +82,7 @@ def restore_variables_on_create(save_path, map_func=None): map_func_wrapper = lambda self, x: x else: if not callable(map_func): - raise ValueError("map_func must be callaled.") + raise ValueError("map_func must be callable.") map_func_wrapper = lambda self, x: map_func(x) ckpt_var_cache = dict() diff --git a/tensorflow/contrib/eager/python/saver_test.py b/tensorflow/contrib/eager/python/saver_test.py index abc7e3690c76c4446bce6b945325f1ca15ef1c8b..1a7f7b85e688e80e3cf482f2754462888187d311 100644 --- a/tensorflow/contrib/eager/python/saver_test.py +++ b/tensorflow/contrib/eager/python/saver_test.py @@ -73,16 +73,6 @@ class SaverTest(test.TestCase): with self.assertRaisesRegexp(ValueError, 'v1'): saver.save(ckpt_prefix) - def testDifferentGraphError(self): - with ops.device(self._dev()): - with ops.Graph().as_default(): - v1 = resource_variable_ops.ResourceVariable(1.0, name='v1') - with ops.Graph().as_default(): - saver = _saver.Saver([v1]) - ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt') - with self.assertRaisesRegexp(ValueError, 'Graph'): - saver.save(ckpt_prefix) - def testSameObjectOK(self): with ops.device(self._dev()): v1 = resource_variable_ops.ResourceVariable(1.0, name='v1') diff --git a/tensorflow/contrib/eager/python/summary_writer.py b/tensorflow/contrib/eager/python/summary_writer.py deleted file mode 100644 index 5d8c41b545b3c9fd03af85f302ba05a394f085a4..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/summary_writer.py +++ /dev/null @@ -1,242 +0,0 @@ -# Copyright 2017 
The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""TensorBoard Summary Writer for TensorFlow Eager Execution.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import uuid - -from tensorflow.contrib.summary import gen_summary_ops -from tensorflow.python.eager import context -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import init_ops -from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.ops import state_ops -from tensorflow.python.ops import summary_op_util -from tensorflow.python.ops import variable_scope - - -def _maybe_cpu(v): - if isinstance(v, (ops.EagerTensor, ops.Tensor)): - return v.cpu() - else: - return v - - -def _summary_writer_function(name, tensor, function, family=None): - def record(): - with summary_op_util.summary_scope( - name, family, values=[tensor]) as (tag, scope): - function(tag, scope) - return True - return record - - -class SummaryWriter(object): - """Writes summaries for TensorBoard, compatible with eager execution. - - This class is the supported way of writing TensorBoard summaries under - eager execution. 
- """ - - _CPU_DEVICE = "cpu:0" - - def __init__(self, - logdir, - max_queue=10, - flush_secs=120, - filename_suffix=""): - """Summary writer for TensorBoard, compatible with eager execution. - - If necessary, multiple instances of `SummaryWriter` can be created, with - distinct `logdir`s and `name`s. Each `SummaryWriter` instance will retain - its independent `global_step` counter and data writing destination. - - Example: - ```python - writer = tfe.SummaryWriter("my_model") - - # ... Code that sets up the model and data batches ... - - for _ in xrange(train_iters): - loss = model.train_batch(batch) - writer.scalar("loss", loss) - writer.step() - ``` - - Args: - logdir: Directory in which summary files will be written. - max_queue: Number of summary items to buffer before flushing to - filesystem. If 0, summaries will be flushed immediately. - flush_secs: Number of secondsbetween forced commits to disk. - filename_suffix: Suffix of the event protobuf files in which the summary - data are stored. - - Raises: - ValueError: If this constructor is called not under eager execution. - """ - # TODO(apassos, ashankar): Make this class and the underlying - # contrib.summary_ops compatible with graph model and remove this check. - if not context.in_eager_mode(): - raise ValueError( - "Use of SummaryWriter is currently supported only with eager " - "execution enabled. File an issue at " - "https://github.com/tensorflow/tensorflow/issues/new to express " - "interest in fixing this.") - - # TODO(cais): Consider adding name keyword argument, which if None or empty, - # will register the global global_step that training_util.get_global_step() - # can find. 
- with context.device(self._CPU_DEVICE): - self._name = uuid.uuid4().hex - self._global_step = 0 - self._global_step_tensor = variable_scope.get_variable( - "global_step/summary_writer/" + self._name, - shape=[], dtype=dtypes.int64, - initializer=init_ops.zeros_initializer()) - self._global_step_dirty = False - self._resource = gen_summary_ops.summary_writer(shared_name=self._name) - gen_summary_ops.create_summary_file_writer( - self._resource, logdir, max_queue, flush_secs, filename_suffix) - # Delete the resource when this object is deleted - self._resource_deleter = resource_variable_ops.EagerResourceDeleter( - handle=self._resource, handle_device=self._CPU_DEVICE) - - def step(self): - """Increment the global step counter of this SummaryWriter instance.""" - self._global_step += 1 - self._global_step_dirty = True - - @property - def global_step(self): - """Obtain the current global_step value of this SummaryWriter instance. - - Returns: - An `int` representing the current value of the global_step of this - `SummaryWriter` instance. - """ - return self._global_step - - def _update_global_step_tensor(self): - with context.device(self._CPU_DEVICE): - if self._global_step_dirty: - self._global_step_dirty = False - return state_ops.assign(self._global_step_tensor, self._global_step) - else: - return self._global_step_tensor - - def generic(self, name, tensor, metadata, family=None): - """Write a generic-type summary. - - Args: - name: A name for the generated node. Will also serve as the series name in - TensorBoard. - tensor: A `Tensor` or compatible value type containing the value of the - summary. - metadata: Metadata about the summary. - family: Optional; if provided, used as the prefix of the summary tag name, - which controls the tab name used for display on Tensorboard. 
- """ - with context.device(self._CPU_DEVICE): - with summary_op_util.summary_scope( - name, family, values=[tensor]) as (tag, scope): - gen_summary_ops.write_summary( - self._resource, - self._update_global_step_tensor(), - _maybe_cpu(tensor), - tag, - _maybe_cpu(metadata), - name=scope) - - def scalar(self, name, tensor, family=None): - """Write a scalar summary. - - Args: - name: A name for the generated node. Will also serve as the series name in - TensorBoard. - tensor: A real numeric `Tensor` or compatible value type containing a - single value. - family: Optional; if provided, used as the prefix of the summary tag name, - which controls the tab name used for display on Tensorboard. - - Returns: - A summary writer function for scalars. - """ - with context.device(self._CPU_DEVICE): - with summary_op_util.summary_scope( - name, family, values=[tensor]) as (tag, scope): - gen_summary_ops.write_scalar_summary( - self._resource, self._update_global_step_tensor(), - tag, _maybe_cpu(tensor), name=scope) - - def histogram(self, name, tensor, family=None): - """Write a histogram summary. - - Args: - name: A name for the generated node. Will also serve as a series name in - TensorBoard. - tensor: A real numeric `Tensor` or compatible value type. Any shape. - Values to use to build the histogram. - family: Optional; if provided, used as the prefix of the summary tag name, - which controls the tab name used for display on Tensorboard. 
- """ - with context.device(self._CPU_DEVICE): - with summary_op_util.summary_scope( - name, family, values=[tensor]) as (tag, scope): - gen_summary_ops.write_histogram_summary( - self._resource, self._update_global_step_tensor(), - tag, _maybe_cpu(tensor), name=scope) - - def image(self, name, tensor, bad_color=None, max_images=3, family=None): - """Write an image summary.""" - with context.device(self._CPU_DEVICE): - if bad_color is None: - bad_color_ = constant_op.constant([255, 0, 0, 255], dtype=dtypes.uint8) - with summary_op_util.summary_scope( - name, family, values=[tensor]) as (tag, scope): - gen_summary_ops.write_image_summary( - self._resource, self._update_global_step_tensor(), - tag, _maybe_cpu(tensor), bad_color_, max_images, - name=scope) - - def audio(self, name, tensor, sample_rate, max_outputs, family=None): - """Write an audio summary. - - Args: - name: A name for the generated node. Will also serve as a series name in - TensorBoard. - tensor: A 3-D `float32` `Tensor` of shape `[batch_size, frames, channels]` - or a 2-D `float32` `Tensor` of shape `[batch_size, frames]`, or - compatible value type. - sample_rate: A Scalar `float32` `Tensor` indicating the sample rate of the - signal in hertz. - max_outputs: Max number of batch elements to generate audio for. - family: Optional; if provided, used as the prefix of the summary tag name, - which controls the tab name used for display on Tensorboard. 
- """ - with context.device(self._CPU_DEVICE): - with summary_op_util.summary_scope( - name, family, values=[tensor]) as (tag, scope): - gen_summary_ops.write_audio_summary( - self._resource, self._update_global_step_tensor(), - tag, - _maybe_cpu(tensor), - sample_rate=_maybe_cpu(sample_rate), - max_outputs=max_outputs, - name=scope) diff --git a/tensorflow/contrib/eager/python/summary_writer_test.py b/tensorflow/contrib/eager/python/summary_writer_test.py deleted file mode 100644 index 5ebb36d04fcba8f4558fa1c09716314af42f559f..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/eager/python/summary_writer_test.py +++ /dev/null @@ -1,150 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Unit tests for eager execution SummaryWriter.""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os -import shutil -import tempfile - -import numpy as np - -from tensorflow.contrib.eager.python import summary_writer -from tensorflow.core.util import event_pb2 -from tensorflow.python.eager import context -from tensorflow.python.eager import test -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.lib.io import tf_record -from tensorflow.python.platform import gfile - - -class SummaryWriterTest(test.TestCase): - - def setUp(self): - super(SummaryWriterTest, self).setUp() - self._test_device = "gpu:0" if context.num_gpus() else "cpu:0" - self._tmp_logdir = tempfile.mkdtemp() - with context.device(self._test_device): - # Use max_queue=0 so that summaries are immediately flushed to filesystem, - # making testing easier. 
- self._writer = summary_writer.SummaryWriter(self._tmp_logdir, max_queue=0) - - def tearDown(self): - if os.path.isdir(self._tmp_logdir): - shutil.rmtree(self._tmp_logdir) - super(SummaryWriterTest, self).tearDown() - - def _readLastEvent(self, logdir=None): - if not logdir: - logdir = self._tmp_logdir - files = [f for f in gfile.ListDirectory(logdir) - if not gfile.IsDirectory(os.path.join(logdir, f))] - file_path = os.path.join(logdir, files[0]) - records = list(tf_record.tf_record_iterator(file_path)) - event = event_pb2.Event() - event.ParseFromString(records[-1]) - return event - - def testGlobalStep(self): - with context.device(self._test_device): - orig_step = self._writer.global_step - self._writer.step() - self.assertEqual(orig_step + 1, self._writer.global_step) - self.assertEqual(orig_step + 1, self._writer.global_step) - self._writer.step() - self._writer.step() - self.assertEqual(orig_step + 3, self._writer.global_step) - - def testGenericSummary(self): - with context.device(self._test_device): - x = constant_op.constant(1337.0) - with context.device("cpu:0"): - metadata = constant_op.constant("foo") - self._writer.generic("x", x, metadata) - event = self._readLastEvent() - self.assertEqual("x", event.summary.value[0].tag) - - def testScalarSummary(self): - with context.device(self._test_device): - x = constant_op.constant(1337.0) - self._writer.scalar("x", x) - event = self._readLastEvent() - self.assertTrue("x", event.summary.value[0].tag) - self.assertEqual(1337.0, event.summary.value[0].simple_value) - - def testHistogramSummary(self): - with context.device(self._test_device): - y = constant_op.constant([1.0, 3.0, 3.0, 7.0]) - self._writer.histogram("y", y) - event = self._readLastEvent() - self.assertEqual("y", event.summary.value[0].tag) - self.assertTrue(event.summary.value[0].histo) - - def testImageSummary(self): - with context.device(self._test_device): - a = constant_op.constant([[10.0, 20.0], [-20.0, -10.0]]) - 
self._writer.histogram("image1", a) - event = self._readLastEvent() - self.assertEqual("image1", event.summary.value[0].tag) - self.assertTrue(event.summary.value[0].image) - - def testAudioSummary(self): - with context.device(self._test_device): - w = constant_op.constant(np.random.rand(3, 10, 2), dtype=dtypes.float32) - fs = constant_op.constant(44100.0, dtype=dtypes.float32) - max_outputs = 1 - self._writer.audio("audio1", w, fs, max_outputs) - event = self._readLastEvent() - self.assertTrue(event.summary.value[0].audio) - - def testTwoSummaryWritersGlobalStepsWorkWithoutCrosstalk(self): - tmp_logdir2 = os.path.join(self._tmp_logdir, "_writer2_") - writer2 = summary_writer.SummaryWriter(tmp_logdir2, max_queue=0) - - self.assertEqual(0, writer2.global_step) - self._writer.step() - self.assertEqual(0, writer2.global_step) - writer2.step() - writer2.step() - writer2.step() - self.assertEqual(3, writer2.global_step) - - x = constant_op.constant(1337.0) - writer_orig_step = self._writer.global_step - self._writer.step() - self._writer.scalar("x", x) - - event = self._readLastEvent() - self.assertEqual(writer_orig_step + 1, event.step) - - writer2.scalar("x", x) - event = self._readLastEvent(tmp_logdir2) - self.assertEqual(3, event.step) - - self._writer.step() - self._writer.scalar("x", x) - - event = self._readLastEvent() - self.assertEqual(writer_orig_step + 2, event.step) - - -# TODO(cais): Add performance benchmark for SummaryWriter. - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index 1697c879def8af5c05f3c9b11d318d570785d6de..d32bebf90c1e768d1efec26b3b78bf1a522a8f00 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -23,7 +23,9 @@ To use, at program startup, call `tfe.enable_eager_execution()`. 
@@list_devices @@num_gpus +@@py_func @@defun +@@make_template @@implicit_gradients @@implicit_value_and_gradients @@gradients_function @@ -50,13 +52,13 @@ To use, at program startup, call `tfe.enable_eager_execution()`. @@EagerVariableStore @@Network +@@Sequential @@save_network_checkpoint @@restore_network_checkpoint @@in_eager_mode @@in_graph_mode -@@IsolateTest @@run_test_in_graph_and_eager_modes @@DEVICE_PLACEMENT_EXPLICIT @@ -74,6 +76,7 @@ from __future__ import print_function from tensorflow.contrib.eager.python import metrics from tensorflow.contrib.eager.python.datasets import Iterator from tensorflow.contrib.eager.python.network import Network +from tensorflow.contrib.eager.python.network import Sequential from tensorflow.contrib.eager.python.network import save_network_checkpoint from tensorflow.contrib.eager.python.network import restore_network_checkpoint from tensorflow.contrib.eager.python.saver import get_optimizer_variables @@ -97,13 +100,16 @@ from tensorflow.python.eager.execution_callbacks import nan_callback from tensorflow.python.eager.execution_callbacks import seterr from tensorflow.python.framework.ops import enable_eager_execution from tensorflow.python.framework.ops import eager_run as run -from tensorflow.python.framework.test_util import IsolateTest from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes as run_test_in_graph_and_eager_modes from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Variable from tensorflow.python.ops.variable_scope import EagerVariableStore +from tensorflow.python.ops import script_ops +from tensorflow.python.ops import template from tensorflow.python.util.all_util import remove_undocumented +py_func = script_ops.eager_py_func defun = function.defun +make_template = template.make_template_internal implicit_gradients = backprop.implicit_grad implicit_value_and_gradients = backprop.implicit_val_and_grad gradients_function = backprop.gradients_function diff --git 
a/tensorflow/contrib/eager/python/tfe_test.py b/tensorflow/contrib/eager/python/tfe_test.py index 0dedb2fd7c0905801cd87c239ff2ee09eecb6080..b6659c2a1797feab261d756e78b45231dbea5a02 100644 --- a/tensorflow/contrib/eager/python/tfe_test.py +++ b/tensorflow/contrib/eager/python/tfe_test.py @@ -102,10 +102,6 @@ class TFETest(test_util.TensorFlowTestCase): # Expect at least one device. self.assertTrue(tfe.list_devices()) - def testNumGPUs(self): - devices = tfe.list_devices() - self.assertEqual(len(devices) - 1, tfe.num_gpus()) - def testAddCheckNumericsOpsRaisesError(self): with self.assertRaisesRegexp( RuntimeError, diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index 8395e2db5ec0ce6f4adae5fa2467159549e70143..6cdbed5b896577f5622b1bd0123c289c798bc0a5 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -88,8 +88,9 @@ py_library( py_test( name = "dnn_linear_combined_test", - size = "small", + size = "medium", srcs = ["python/estimator/dnn_linear_combined_test.py"], + shard_count = 3, srcs_version = "PY2AND3", tags = [ "no_pip", @@ -162,7 +163,7 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", + "//tensorflow/python:check_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:lookup_ops", @@ -176,7 +177,6 @@ py_library( "//tensorflow/python/estimator:metric_keys", "//tensorflow/python/estimator:model_fn", "//tensorflow/python/estimator:prediction_keys", - "//tensorflow/python/estimator:util", "//tensorflow/python/ops/losses", "//tensorflow/python/saved_model:signature_constants", ], @@ -204,6 +204,7 @@ py_test( "//tensorflow/python/estimator:metric_keys", "//tensorflow/python/estimator:model_fn", "//tensorflow/python/estimator:prediction_keys", + "//tensorflow/python/ops/losses", "//tensorflow/python/saved_model:signature_constants", "//third_party/py/numpy", "@six_archive//:six", 
@@ -330,23 +331,24 @@ py_library( "//tensorflow/python:device", "//tensorflow/python:device_lib", "//tensorflow/python:framework_ops", - "//tensorflow/python:gradients", "//tensorflow/python:math_ops", "//tensorflow/python:platform", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", "//tensorflow/python:state_ops", "//tensorflow/python:training", "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", "//tensorflow/python/estimator:export_output", "//tensorflow/python/estimator:model_fn", "//tensorflow/python/estimator:util", + "//tensorflow/python/ops/losses", "@six_archive//:six", ], ) cuda_py_test( name = "replicate_model_fn_test", - size = "small", + size = "medium", srcs = ["python/estimator/replicate_model_fn_test.py"], additional_deps = [ "//tensorflow/python/estimator", @@ -374,5 +376,9 @@ cuda_py_test( "//tensorflow/python:variables", ":replicate_model_fn", ], - tags = ["multi_gpu"], + tags = [ + "manual", + "multi_gpu", + "notap", + ], ) diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py index 8191e06faed004df6927708ea04a67b90bd464de..0f75b77050b0ba4c752a6a74fdc7024170b6f318 100644 --- a/tensorflow/contrib/estimator/__init__.py +++ b/tensorflow/contrib/estimator/__init__.py @@ -26,6 +26,7 @@ from tensorflow.contrib.estimator.python.estimator.head import * from tensorflow.contrib.estimator.python.estimator.linear import * from tensorflow.contrib.estimator.python.estimator.logit_fns import * from tensorflow.contrib.estimator.python.estimator.multi_head import * +from tensorflow.contrib.estimator.python.estimator.replicate_model_fn import * from tensorflow.python.util.all_util import remove_undocumented # pylint: enable=unused-import,line-too-long,wildcard-import @@ -45,6 +46,8 @@ _allowed_symbols = [ 'call_logit_fn', 'dnn_logit_fn_builder', 'linear_logit_fn_builder', + 'replicate_model_fn', + 'TowerOptimizer', ] remove_undocumented(__name__, 
allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/estimator/python/estimator/extenders.py b/tensorflow/contrib/estimator/python/estimator/extenders.py index 29c3c7358534f6e8ebbd31cbfcd7e34086d9b506..c99bf8badb35e6fffb7cae8761db9d402b8b3a8f 100644 --- a/tensorflow/contrib/estimator/python/estimator/extenders.py +++ b/tensorflow/contrib/estimator/python/estimator/extenders.py @@ -100,7 +100,7 @@ def add_metrics(estimator, metric_fn): def clip_gradients_by_norm(optimizer, clip_norm): - """Returns an optimizer which clips gradients before appliying them. + """Returns an optimizer which clips gradients before applying them. Example: diff --git a/tensorflow/contrib/estimator/python/estimator/extenders_test.py b/tensorflow/contrib/estimator/python/estimator/extenders_test.py index 5f4a3cc902c9cc07c0688ad41dab7391a641c133..ad1a8ef152b07ecbab33d9eb3184a2ae89def27d 100644 --- a/tensorflow/contrib/estimator/python/estimator/extenders_test.py +++ b/tensorflow/contrib/estimator/python/estimator/extenders_test.py @@ -20,8 +20,8 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.contrib.estimator.python.estimator import extenders +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.estimator import estimator_lib from tensorflow.python.estimator.canned import linear from tensorflow.python.feature_column import feature_column as fc diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py index a9311a20f127d92f02a95b8b48082fc90850635a..238cf287b768eee28b20202084eb244c085c8b75 100644 --- a/tensorflow/contrib/estimator/python/estimator/head.py +++ b/tensorflow/contrib/estimator/python/estimator/head.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function from tensorflow.python.estimator import model_fn -from tensorflow.python.estimator import util from 
tensorflow.python.estimator.canned import head as head_lib from tensorflow.python.estimator.canned import metric_keys from tensorflow.python.estimator.canned import prediction_keys @@ -29,7 +28,6 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics as metrics_lib @@ -44,6 +42,8 @@ _DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY def multi_class_head(n_classes, weight_column=None, label_vocabulary=None, + loss_reduction=losses.Reduction.SUM, + loss_fn=None, name=None): """Creates a `_Head` for multi class classification. @@ -64,6 +64,12 @@ def multi_class_head(n_classes, labels have shape `[batch_size, 1]`, the loss is the weighted sum over `batch_size`. + Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or + `(labels, logits, features)` as arguments and returns unreduced loss with + shape `[D0, D1, ... DN, 1]`. `loss_fn` must support integer `labels` with + shape `[D0, D1, ... DN, 1]`. Namely, the head applies `label_vocabulary` to + the input labels before passing them to `loss_fn`. + Args: n_classes: Number of classes, must be greater than 2 (for 2 classes, use `binary_classification_head`). @@ -76,6 +82,9 @@ def multi_class_head(n_classes, integer within [0, n_classes). If given, labels must be of string type and have any value in `label_vocabulary`. Note that errors will be raised if `label_vocabulary` is not provided but labels are strings. + loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to + reduce training loss over batch. Defaults to `SUM`. + loss_fn: Optional loss function. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. 
Also used as `name_scope` when creating ops. @@ -83,17 +92,25 @@ def multi_class_head(n_classes, An instance of `_Head` for multi class classification. Raises: - ValueError: if `n_classes`, `metric_class_ids` or `label_keys` is invalid. + ValueError: if `n_classes`, `label_vocabulary` or `loss_reduction` is + invalid. """ return head_lib._multi_class_head_with_softmax_cross_entropy_loss( # pylint:disable=protected-access n_classes=n_classes, weight_column=weight_column, label_vocabulary=label_vocabulary, + loss_reduction=loss_reduction, + loss_fn=loss_fn, name=name) def binary_classification_head( - weight_column=None, thresholds=None, label_vocabulary=None, name=None): + weight_column=None, + thresholds=None, + label_vocabulary=None, + loss_reduction=losses.Reduction.SUM, + loss_fn=None, + name=None): """Creates a `_Head` for single label binary classification. This head uses `sigmoid_cross_entropy_with_logits` loss. @@ -113,6 +130,12 @@ def binary_classification_head( labels have shape `[batch_size, 1]`, the loss is the weighted sum over `batch_size`. + Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or + `(labels, logits, features)` as arguments and returns unreduced loss with + shape `[D0, D1, ... DN, 1]`. `loss_fn` must support float `labels` with + shape `[D0, D1, ... DN, 1]`. Namely, the head applies `label_vocabulary` to + the input labels before passing them to `loss_fn`. + Args: weight_column: A string or a `_NumericColumn` created by `tf.feature_column.numeric_column` defining feature column representing @@ -128,6 +151,9 @@ def binary_classification_head( [0, 1]. If given, labels must be string type and have any value in `label_vocabulary`. Note that errors will be raised if `label_vocabulary` is not provided but labels are strings. + loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to + reduce training loss over batch. Defaults to `SUM`. + loss_fn: Optional loss function. name: name of the head. 
If provided, summary and metrics keys will be suffixed by `"/" + name`. Also used as `name_scope` when creating ops. @@ -135,17 +161,22 @@ def binary_classification_head( An instance of `_Head` for binary classification. Raises: - ValueError: if `thresholds` contains a value outside of `(0, 1)`. + ValueError: If `thresholds` contains a value outside of `(0, 1)`. + ValueError: If `loss_reduction` is invalid. """ return head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss( # pylint:disable=protected-access weight_column=weight_column, thresholds=thresholds, label_vocabulary=label_vocabulary, + loss_reduction=loss_reduction, + loss_fn=loss_fn, name=name) def regression_head(weight_column=None, label_dimension=1, + loss_reduction=losses.Reduction.SUM, + loss_fn=None, name=None): """Creates a `_Head` for regression using the `mean_squared_error` loss. @@ -164,6 +195,10 @@ def regression_head(weight_column=None, `[D0, D1, ... DN]`, `[D0, D1, ... DN, 1]` or `[D0, D1, ... DN, label_dimension]`. + Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or + `(labels, logits, features)` as arguments and returns unreduced loss with + shape `[D0, D1, ... DN, label_dimension]`. + Args: weight_column: A string or a `_NumericColumn` created by `tf.feature_column.numeric_column` defining feature column representing @@ -172,15 +207,23 @@ def regression_head(weight_column=None, label_dimension: Number of regression labels per example. This is the size of the last dimension of the labels `Tensor` (typically, this has shape `[batch_size, label_dimension]`). + loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to + reduce training loss over batch. Defaults to `SUM`. + loss_fn: Optional loss function. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. Also used as `name_scope` when creating ops. Returns: An instance of `_Head` for linear regression. 
+ + Raises: + ValueError: If `label_dimension` or `loss_reduction` is invalid. """ return head_lib._regression_head_with_mean_squared_error_loss( # pylint:disable=protected-access weight_column=weight_column, label_dimension=label_dimension, + loss_reduction=loss_reduction, + loss_fn=loss_fn, name=name) @@ -188,6 +231,7 @@ def multi_label_head(n_classes, weight_column=None, thresholds=None, label_vocabulary=None, + loss_reduction=losses.Reduction.SUM, loss_fn=None, name=None): """Creates a `_Head` for multi-label classification. @@ -202,7 +246,7 @@ def multi_label_head(n_classes, `batch_size`. The head expects `logits` with shape `[D0, D1, ... DN, n_classes]`. In many - applications, the shape is `[batch_size, label_n_classes]`. + applications, the shape is `[batch_size, n_classes]`. Labels can be: * A multi-hot tensor of shape `[D0, D1, ... DN, n_classes]` @@ -237,6 +281,8 @@ def multi_label_head(n_classes, [0, n_classes) or multi-hot Tensor. If given, labels must be SparseTensor string type and have any value in `label_vocabulary`. Also there will be errors if vocabulary is not provided and labels are string. + loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to + reduce training loss over batch. Defaults to `SUM`. loss_fn: Optional loss function. name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. Also used as `name_scope` when creating ops. @@ -245,7 +291,8 @@ def multi_label_head(n_classes, An instance of `_Head` for multi-label classification. Raises: - ValueError: if `n_classes`, `thresholds`, or `loss_fn` is invalid. + ValueError: if `n_classes`, `thresholds`, `loss_reduction` or `loss_fn` is + invalid. """ thresholds = tuple(thresholds) if thresholds else tuple() if n_classes is None or n_classes < 2: @@ -266,10 +313,14 @@ def multi_label_head(n_classes, 'Length of label_vocabulary must be n_classes ({}). 
' 'Given: {}'.format(n_classes, len(label_vocabulary))) if loss_fn: - _validate_loss_fn_args(loss_fn) + head_lib._validate_loss_fn_args(loss_fn) # pylint:disable=protected-access + if (loss_reduction not in losses.Reduction.all() or + loss_reduction == losses.Reduction.NONE): + raise ValueError('Invalid loss_reduction: {}'.format(loss_reduction)) return _MultiLabelHead( n_classes=n_classes, weight_column=weight_column, thresholds=thresholds, - label_vocabulary=label_vocabulary, loss_fn=loss_fn, name=name) + label_vocabulary=label_vocabulary, loss_reduction=loss_reduction, + loss_fn=loss_fn, name=name) class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access @@ -280,12 +331,14 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access weight_column=None, thresholds=None, label_vocabulary=None, + loss_reduction=losses.Reduction.SUM, loss_fn=None, name=None): self._n_classes = n_classes self._weight_column = weight_column self._thresholds = thresholds self._label_vocabulary = label_vocabulary + self._loss_reduction = loss_reduction self._loss_fn = loss_fn self._name = name @@ -344,9 +397,9 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access labels=processed_labels, logits=logits, expected_labels_dimension=self.logits_dimension) if self._loss_fn: - unweighted_loss = _call_loss_fn( + unweighted_loss = head_lib._call_loss_fn( # pylint:disable=protected-access loss_fn=self._loss_fn, labels=processed_labels, logits=logits, - features=features) + features=features, expected_loss_dim=1) else: unweighted_loss = losses.sigmoid_cross_entropy( multi_class_labels=processed_labels, logits=logits, @@ -356,19 +409,41 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access unweighted_loss, axis=-1, keep_dims=True) weights = head_lib._get_weights_and_check_match_logits( # pylint:disable=protected-access, features=features, weight_column=self._weight_column, logits=logits) - weighted_sum_loss = 
losses.compute_weighted_loss( - unweighted_loss, weights=weights, reduction=losses.Reduction.SUM) - # _weights() can return 1. - example_weight_sum = math_ops.reduce_sum( - weights * array_ops.ones_like(unweighted_loss)) + training_loss = losses.compute_weighted_loss( + unweighted_loss, weights=weights, reduction=self._loss_reduction) return head_lib.LossSpec( - weighted_sum_loss=weighted_sum_loss, - example_weight_sum=example_weight_sum, + training_loss=training_loss, + unreduced_loss=unweighted_loss, + weights=weights, processed_labels=processed_labels) def create_estimator_spec( - self, features, mode, logits, labels=None, train_op_fn=None): - """See `Head`.""" + self, features, mode, logits, labels=None, train_op_fn=None, + regularization_losses=None): + """Returns an `EstimatorSpec`. + + Args: + features: Input `dict` of `Tensor` or `SparseTensor` objects. + mode: Estimator's `ModeKeys`. + logits: logits `Tensor` with shape `[D0, D1, ... DN, n_classes]`. + For many applications, the shape is `[batch_size, n_classes]`. + labels: Labels with shape matching `logits`. Can be multi-hot `Tensor` + with shape `[D0, D1, ... DN, n_classes]` or `SparseTensor` with + `dense_shape` `[D0, D1, ... DN, ?]`. `labels` is required argument when + `mode` equals `TRAIN` or `EVAL`. + train_op_fn: Function that takes a scalar loss `Tensor` and returns + `train_op`. Required in TRAIN mode. + regularization_losses: A list of additional scalar losses to be added to + the training loss, such as regularization losses. These losses are + usually expressed as a batch average, so for best results users need to + set `loss_reduction=SUM_OVER_BATCH_SIZE` or + `loss_reduction=SUM_OVER_NONZERO_WEIGHTS` when creating the head to + avoid scaling errors. + Returns: + `EstimatorSpec`. + Raises: + ValueError: If `train_op_fn` is `None` in TRAIN mode. 
+ """ with ops.name_scope(self._name, 'head'): logits = head_lib._check_logits_final_dim(logits, self.logits_dimension) # pylint:disable=protected-access @@ -394,60 +469,74 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access export_output.PredictOutput(predictions)) }) - (weighted_sum_loss, example_weight_sum, + (training_loss, unreduced_loss, weights, processed_labels) = self.create_loss( features=features, mode=mode, logits=logits, labels=labels) + if regularization_losses: + regularization_loss = math_ops.add_n(regularization_losses) + regularized_training_loss = math_ops.add_n( + [training_loss, regularization_loss]) + else: + regularization_loss = None + regularized_training_loss = training_loss # Eval. if mode == model_fn.ModeKeys.EVAL: - weights = head_lib._get_weights_and_check_match_logits( # pylint:disable=protected-access, - features=features, weight_column=self._weight_column, logits=logits) return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.EVAL, predictions=predictions, - loss=weighted_sum_loss, + loss=regularized_training_loss, eval_metric_ops=self._eval_metric_ops( labels=processed_labels, probabilities=probabilities, weights=weights, - weighted_sum_loss=weighted_sum_loss, - example_weight_sum=example_weight_sum)) + unreduced_loss=unreduced_loss, + regularization_loss=regularization_loss)) # Train. if train_op_fn is None: raise ValueError('train_op_fn can not be None.') + # Only summarize mean_loss for SUM reduction to preserve backwards + # compatibility. Otherwise skip it to avoid unnecessary computation. 
+ if self._loss_reduction == losses.Reduction.SUM: + example_weight_sum = math_ops.reduce_sum( + weights * array_ops.ones_like(unreduced_loss)) + mean_loss = training_loss / example_weight_sum + else: + mean_loss = None with ops.name_scope(''): + keys = metric_keys.MetricKeys summary.scalar( - head_lib._summary_key(self._name, metric_keys.MetricKeys.LOSS), # pylint:disable=protected-access - weighted_sum_loss) - summary.scalar( - head_lib._summary_key( # pylint:disable=protected-access - self._name, metric_keys.MetricKeys.LOSS_MEAN), - weighted_sum_loss / example_weight_sum) + head_lib._summary_key(self._name, keys.LOSS), # pylint:disable=protected-access + regularized_training_loss) + if mean_loss is not None: + summary.scalar( + head_lib._summary_key(self._name, keys.LOSS_MEAN), # pylint:disable=protected-access + mean_loss) + if regularization_loss is not None: + summary.scalar( + head_lib._summary_key(self._name, keys.LOSS_REGULARIZATION), # pylint:disable=protected-access + regularization_loss) return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.TRAIN, predictions=predictions, - loss=weighted_sum_loss, - train_op=train_op_fn(weighted_sum_loss)) + loss=regularized_training_loss, + train_op=train_op_fn(regularized_training_loss)) - def _eval_metric_ops(self, labels, probabilities, weights, weighted_sum_loss, - example_weight_sum): + def _eval_metric_ops( + self, labels, probabilities, weights, unreduced_loss, + regularization_loss): """Returns a dict of metrics for eval_metric_ops.""" with ops.name_scope( None, 'metrics', - [labels, probabilities, weights, weighted_sum_loss, example_weight_sum - ]): + [labels, probabilities, weights, unreduced_loss, regularization_loss]): keys = metric_keys.MetricKeys metric_ops = { # Estimator already adds a metric for loss. head_lib._summary_key(self._name, keys.LOSS_MEAN): # pylint:disable=protected-access metrics_lib.mean( - # Both values and weights here are reduced, scalar Tensors. 
- # values is the actual mean we want, but we pass the scalar - # example_weight_sum in order to return the correct update_op - # alongside the value_op for streaming metrics. - values=(weighted_sum_loss / example_weight_sum), - weights=example_weight_sum, + values=unreduced_loss, + weights=weights, name=keys.LOSS_MEAN), head_lib._summary_key(self._name, keys.AUC): # pylint:disable=protected-access metrics_lib.auc(labels=labels, predictions=probabilities, @@ -457,6 +546,13 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access weights=weights, curve='PR', name=keys.AUC_PR), } + if regularization_loss is not None: + loss_regularization_key = head_lib._summary_key( # pylint:disable=protected-access + self._name, keys.LOSS_REGULARIZATION) + metric_ops[loss_regularization_key] = ( + metrics_lib.mean( + values=regularization_loss, + name=keys.LOSS_REGULARIZATION)) for threshold in self._thresholds: accuracy_key = keys.ACCURACY_AT_THRESHOLD % threshold metric_ops[head_lib._summary_key(self._name, accuracy_key)] = ( # pylint:disable=protected-access @@ -485,52 +581,3 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access threshold=threshold, name=recall_key)) return metric_ops - - -def _validate_loss_fn_args(loss_fn): - """Validates loss_fn arguments. - - Required arguments: labels, logits. - Optional arguments: features. - - Args: - loss_fn: The loss function. - Raises: - ValueError: If the signature is unexpected. - """ - loss_fn_args = util.fn_args(loss_fn) - for required_arg in ['labels', 'logits']: - if required_arg not in loss_fn_args: - raise ValueError( - 'loss_fn must contain argument: {}. 
' - 'Given arguments: {}'.format(required_arg, loss_fn_args)) - invalid_args = list(set(loss_fn_args) - set(['labels', 'logits', 'features'])) - if invalid_args: - raise ValueError('loss_fn has unexpected args: {}'.format(invalid_args)) - - -def _call_loss_fn(loss_fn, labels, logits, features): - """Calls loss_fn and checks the returned shape. - - Args: - loss_fn: The loss function. - labels: Processed labels Tensor. - logits: Logits Tensor of shape [batch_size, logits_dimension]. - features: Features dict. - Returns: - Loss Tensor with shape [batch_size, 1]. - """ - loss_fn_args = util.fn_args(loss_fn) - kwargs = {} - if 'features' in loss_fn_args: - kwargs['features'] = features - unweighted_loss = loss_fn(labels=labels, logits=logits, **kwargs) - batch_size = array_ops.shape(logits)[0] - loss_shape = array_ops.shape(unweighted_loss) - check_shape_op = control_flow_ops.Assert( - math_ops.reduce_all(math_ops.equal(loss_shape, [batch_size, 1])), - data=[ - 'loss_fn must return Tensor of shape [batch_size, 1]. 
Given: ', - loss_shape]) - with ops.control_dependencies([check_shape_op]): - return array_ops.identity(unweighted_loss) diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py index d1cf9090048470181818c573647923c9f5824dfa..43cdfec9689879201305385499b3b784e1593d60 100644 --- a/tensorflow/contrib/estimator/python/estimator/head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/head_test.py @@ -35,6 +35,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import string_ops +from tensorflow.python.ops.losses import losses from tensorflow.python.platform import test from tensorflow.python.saved_model import signature_constants from tensorflow.python.training import monitored_session @@ -132,6 +133,16 @@ class MultiLabelHead(test.TestCase): r'Length of label_vocabulary must be n_classes \(3\). 
Given: 2'): head_lib.multi_label_head(n_classes=3, label_vocabulary=['foo', 'bar']) + def test_invalid_loss_reduction(self): + with self.assertRaisesRegexp( + ValueError, r'Invalid loss_reduction: invalid_loss_reduction'): + head_lib.multi_label_head( + n_classes=3, loss_reduction='invalid_loss_reduction') + with self.assertRaisesRegexp( + ValueError, r'Invalid loss_reduction: none'): + head_lib.multi_label_head( + n_classes=3, loss_reduction=losses.Reduction.NONE) + def test_loss_fn_arg_labels_missing(self): def _loss_fn(logits): del logits # Unused @@ -262,17 +273,17 @@ class MultiLabelHead(test.TestCase): labels = np.array([[1, 0], [1, 1]], dtype=np.int64) # loss = labels * -log(sigmoid(logits)) + # (1 - labels) * -log(1 - sigmoid(logits)) - expected_weighted_sum_loss = np.sum( + expected_training_loss = np.sum( _sigmoid_cross_entropy(labels=labels, logits=logits)) - actual_weighted_sum_loss = head.create_loss( + actual_training_loss = head.create_loss( features={'x': np.array(((42,),), dtype=np.int32)}, mode=model_fn.ModeKeys.EVAL, logits=logits, labels=labels)[0] with self.test_session(): _initialize_variables(self, monitored_session.Scaffold()) - self.assertAllClose(expected_weighted_sum_loss, - actual_weighted_sum_loss.eval()) + self.assertAllClose(expected_training_loss, + actual_training_loss.eval()) def test_eval_create_loss_large_logits(self): """Tests head.create_loss for eval mode and large logits.""" @@ -286,9 +297,9 @@ class MultiLabelHead(test.TestCase): # For large logits, this is approximated as: # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits - expected_weighted_sum_loss = np.sum( + expected_training_loss = np.sum( np.array([[(10. + 10.) / 2.], [(15. + 0.) 
/ 2.]], dtype=np.float32)) - actual_weighted_sum_loss = head.create_loss( + actual_training_loss = head.create_loss( features={'x': np.array(((42,),), dtype=np.int32)}, mode=model_fn.ModeKeys.EVAL, logits=logits, @@ -296,9 +307,7 @@ class MultiLabelHead(test.TestCase): with self.test_session(): _initialize_variables(self, monitored_session.Scaffold()) self.assertAllClose( - expected_weighted_sum_loss, - actual_weighted_sum_loss.eval(), - atol=1e-4) + expected_training_loss, actual_training_loss.eval(), atol=1e-4) def test_eval_create_loss_labels_wrong_shape(self): """Tests head.create_loss for eval mode when labels has the wrong shape.""" @@ -307,7 +316,7 @@ class MultiLabelHead(test.TestCase): logits = np.array([[-1., 1.], [-1.5, 1.]], dtype=np.float32) labels_placeholder = array_ops.placeholder(dtype=dtypes.int64) - actual_weighted_sum_loss = head.create_loss( + actual_training_loss = head.create_loss( features={'x': np.array(((42,),), dtype=np.int32)}, mode=model_fn.ModeKeys.EVAL, logits=logits, @@ -317,14 +326,14 @@ class MultiLabelHead(test.TestCase): with self.assertRaisesRegexp( errors.InvalidArgumentError, r'\[expected_labels_shape: \] \[2 2\] \[labels_shape: \] \[2 1\]'): - actual_weighted_sum_loss.eval({ + actual_training_loss.eval({ labels_placeholder: np.array([[1], [1]], dtype=np.int64) }) with self.assertRaisesRegexp( errors.InvalidArgumentError, r'labels shape must be \[D0, D1, ... 
DN, 2\]\..*' r'\[Received shape: \] \[2\]'): - actual_weighted_sum_loss.eval({ + actual_training_loss.eval({ labels_placeholder: np.array([1, 1], dtype=np.int64) }) @@ -344,14 +353,14 @@ class MultiLabelHead(test.TestCase): return constant_op.constant(loss) head = head_lib.multi_label_head(n_classes=2, loss_fn=_loss_fn) - actual_weighted_sum_loss = head.create_loss( + actual_training_loss = head.create_loss( features={'x': np.array(((42,),), dtype=np.int32)}, mode=model_fn.ModeKeys.EVAL, logits=logits_input, labels=labels_input)[0] with self.test_session(): _initialize_variables(self, monitored_session.Scaffold()) - self.assertAllClose(np.sum(loss), actual_weighted_sum_loss.eval()) + self.assertAllClose(np.sum(loss), actual_training_loss.eval()) def test_eval_create_loss_loss_fn_wrong_shape(self): """Tests custom loss_fn that returns Tensor of unexpected shape.""" @@ -363,7 +372,7 @@ class MultiLabelHead(test.TestCase): logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32) labels = np.array([[1, 0], [1, 1]], dtype=np.int64) - actual_weighted_sum_loss = head.create_loss( + actual_training_loss = head.create_loss( features={'x': np.array(((42,),), dtype=np.int32)}, mode=model_fn.ModeKeys.EVAL, logits=logits, @@ -372,9 +381,9 @@ class MultiLabelHead(test.TestCase): _initialize_variables(self, monitored_session.Scaffold()) with self.assertRaisesRegexp( errors.InvalidArgumentError, - r'loss_fn must return Tensor of shape \[batch_size, 1\]\. ' - r'Given: \] \[2\]'): - actual_weighted_sum_loss.eval() + r'\[loss_fn must return Tensor of shape \[D0, D1, ... DN, 1\]\. 
\] ' + r'\[logits_shape: \] \[2 2\] \[loss_shape: \] \[2\]'): + actual_training_loss.eval() def test_eval_labels_none(self): """Tests that error is raised when labels is None.""" @@ -390,12 +399,13 @@ class MultiLabelHead(test.TestCase): def _test_eval( self, head, logits, labels, expected_loss, expected_metrics, - features=None): + features=None, regularization_losses=None): spec = head.create_estimator_spec( features=features or {}, mode=model_fn.ModeKeys.EVAL, logits=logits, - labels=labels) + labels=labels, + regularization_losses=regularization_losses) # Assert spec contains expected tensors. self.assertIsNotNone(spec.loss) @@ -477,6 +487,38 @@ class MultiLabelHead(test.TestCase): expected_loss=expected_loss, expected_metrics=expected_metrics) + def test_eval_with_regularization_losses(self): + n_classes = 2 + head = head_lib.multi_label_head( + n_classes, loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) + logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32) + labels = np.array([[1, 0], [1, 1]], dtype=np.int64) + regularization_losses = [1.5, 0.5] + expected_regularization_loss = 2. + # unregularized_loss = sum( + # labels * -log(sigmoid(logits)) + + # (1 - labels) * -log(1 - sigmoid(logits))) / batch_size + expected_unregularized_loss = np.sum( + _sigmoid_cross_entropy(labels=labels, logits=logits)) / 2. + expected_regularized_loss = ( + expected_unregularized_loss + expected_regularization_loss) + keys = metric_keys.MetricKeys + expected_metrics = { + keys.LOSS_MEAN: expected_unregularized_loss, + keys.LOSS_REGULARIZATION: expected_regularization_loss, + # auc and auc_pr cannot be reliably calculated for only 4 samples, but + # this assert tests that the algorithm remains consistent. 
+ keys.AUC: 0.3333, + keys.AUC_PR: 0.7639, + } + self._test_eval( + head=head, + logits=logits, + labels=labels, + expected_loss=expected_regularized_loss, + expected_metrics=expected_metrics, + regularization_losses=regularization_losses) + def test_eval_with_label_vocabulary(self): n_classes = 2 head = head_lib.multi_label_head( @@ -618,12 +660,44 @@ class MultiLabelHead(test.TestCase): # For large logits, this is approximated as: # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits - expected_weighted_sum_loss = np.sum( - np.array( - [[1. * (10. + 10.) / 2.], [2. * (15. + 0.) / 2.]], - dtype=np.float32)) - expected_example_weight_sum = 1. + 2. - actual_weighted_sum_loss, actual_example_weight_sum, _ = head.create_loss( + expected_unreduced_loss = [[(10. + 10.) / 2.], [(15. + 0.) / 2.]] + expected_weights = [[1.], [2.]] + expected_training_loss = 1. * (10. + 10.) / 2. + 2. * (15. + 0.) / 2. + training_loss, unreduced_loss, actual_weights, _ = head.create_loss( + features={ + 'x': np.array(((42,),), dtype=np.int32), + 'example_weights': weights + }, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels) + with self.test_session(): + _initialize_variables(self, monitored_session.Scaffold()) + self.assertAllClose( + expected_training_loss, training_loss.eval(), atol=1e-4) + self.assertAllClose( + expected_unreduced_loss, unreduced_loss.eval(), atol=1e-4) + self.assertAllClose(expected_weights, actual_weights.eval()) + + def test_train_create_loss_loss_reduction(self): + """Tests head.create_loss with loss_reduction.""" + n_classes = 2 + head = head_lib.multi_label_head( + n_classes, weight_column='example_weights', + loss_reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS) + + logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32) + labels = np.array([[1, 0], [1, 1]], dtype=np.int64) + weights = np.array([[1.], [2.]], dtype=np.float32) + # loss = labels * -log(sigmoid(logits)) + + # (1 - labels) * -log(1 - 
sigmoid(logits)) + # For large logits, this is approximated as: + # loss = labels * (logits < 0) * (-logits) + + # (1 - labels) * (logits > 0) * logits + expected_unreduced_loss = [[(10. + 10.) / 2.], [(15. + 0.) / 2.]] + expected_weights = [[1.], [2.]] + expected_training_loss = (1. * (10. + 10.) / 2. + 2. * (15. + 0.) / 2.) / 2. + training_loss, unreduced_loss, actual_weights, _ = head.create_loss( features={ 'x': np.array(((42,),), dtype=np.int32), 'example_weights': weights @@ -634,13 +708,10 @@ class MultiLabelHead(test.TestCase): with self.test_session(): _initialize_variables(self, monitored_session.Scaffold()) self.assertAllClose( - expected_weighted_sum_loss, - actual_weighted_sum_loss.eval(), - atol=1e-4) + expected_training_loss, training_loss.eval(), atol=1e-4) self.assertAllClose( - expected_example_weight_sum, - actual_example_weight_sum.eval(), - atol=1e-4) + expected_unreduced_loss, unreduced_loss.eval(), atol=1e-4) + self.assertAllClose(expected_weights, actual_weights.eval()) def test_train_labels_none(self): """Tests that error is raised when labels is None.""" @@ -791,6 +862,49 @@ class MultiLabelHead(test.TestCase): self._test_train( head=head, logits=logits, labels=labels, expected_loss=expected_loss) + def test_train_with_regularization_losses(self): + head = head_lib.multi_label_head( + n_classes=2, loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) + logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32) + labels = np.array([[1, 0], [1, 1]], dtype=np.int64) + regularization_losses = [1.5, 0.5] + # For large logits, sigmoid cross entropy loss is approximated as: + # loss = labels * (logits < 0) * (-logits) + + # (1 - labels) * (logits > 0) * logits => + # expected_unweighted_loss = [[10., 10.], [15., 0.]] + # Average over classes and over batch and add regularization loss. + expected_loss = 35. / 4. + 2. 
+ expected_summaries = { + metric_keys.MetricKeys.LOSS: expected_loss, + metric_keys.MetricKeys.LOSS_REGULARIZATION: 2., + } + expected_train_result = 'my_train_op' + def _train_op_fn(loss): + return string_ops.string_join( + [constant_op.constant(expected_train_result), + string_ops.as_string(loss, precision=3)]) + + spec = head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=logits, + labels=labels, + train_op_fn=_train_op_fn, + regularization_losses=regularization_losses) + + # Assert predictions, loss, train_op, and summaries. + tol = 1e-3 + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + self.assertIsNotNone(spec.scaffold.summary_op) + loss, train_result, summary_str = sess.run((spec.loss, spec.train_op, + spec.scaffold.summary_op)) + self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol) + self.assertEqual( + six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)), + train_result) + _assert_simple_summaries(self, expected_summaries, summary_str, tol) + def test_train_with_weights(self): n_classes = 2 head = head_lib.multi_label_head(n_classes, weight_column='example_weights') @@ -851,12 +965,15 @@ class MultiLabelHead(test.TestCase): labels = np.array([[[1, 0, 0], [1, 0, 0]], [[0, 1, 1], [0, 1, 1]]], dtype=np.int64) weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32) - # loss = [[10 + 10 + 0, 0 + 0 + 10], [0 + 0 + 12, 12 + 12 + 0]] / 3 - # = [[20/3, 10/3], [4, 8]] + # unreduced_loss = + # [[10 + 10 + 0, 0 + 0 + 10], [0 + 0 + 12, 12 + 12 + 0]] / 3 + # = [[20/3, 10/3], [4, 8]] + expected_unreduced_loss = [[[20./3.], [10./3.]], [[4.], [8.]]] + # weights are reshaped to [2, 2, 1] to match logits. 
+ expected_weights = [[[1.], [1.5]], [[2.], [2.5]]] # weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667 - expected_weighted_sum_loss = 39.6667 - expected_example_weight_sum = np.sum(weights) - actual_weighted_sum_loss, actual_example_weight_sum, _ = head.create_loss( + expected_training_loss = 39.6667 + training_loss, unreduced_loss, actual_weights, _ = head.create_loss( features={'weights': weights}, mode=model_fn.ModeKeys.TRAIN, logits=logits, @@ -865,11 +982,10 @@ class MultiLabelHead(test.TestCase): with self.test_session(): _initialize_variables(self, monitored_session.Scaffold()) self.assertAllClose( - expected_weighted_sum_loss, actual_weighted_sum_loss.eval(), - atol=atol) + expected_training_loss, training_loss.eval(), atol=atol) self.assertAllClose( - expected_example_weight_sum, actual_example_weight_sum.eval(), - atol=atol) + expected_unreduced_loss, unreduced_loss.eval(), atol=atol) + self.assertAllClose(expected_weights, actual_weights.eval()) def test_multi_dim_weighted_train(self): """Logits and labels of shape [2, 2, 3], weights [2, 2].""" diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head.py b/tensorflow/contrib/estimator/python/estimator/multi_head.py index f2a6eae03ec021e5c28d48b3887870d8a057e077..0346ddc24bffd61068177f4622bd03be4acd53d9 100644 --- a/tensorflow/contrib/estimator/python/estimator/multi_head.py +++ b/tensorflow/contrib/estimator/python/estimator/multi_head.py @@ -186,40 +186,44 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access logits_dict = logits else: logits_dict = self._split_logits(logits) - weighted_sum_losses = [] - example_weight_sums = [] + training_losses = [] labels_by_head = {} - for head in self._heads: - (weighted_sum_loss, - example_weight_sum, processed_labels) = head.create_loss( + unreduced_losses_by_head = {} + example_weights_by_head = {} + for i, head in enumerate(self._heads): + (training_loss, unreduced_loss, + weights, processed_labels) = head.create_loss( 
features, mode, logits_dict[head.name], labels[head.name]) - weighted_sum_losses.append(weighted_sum_loss) - example_weight_sums.append(example_weight_sum) + training_losses.append(training_loss) labels_by_head[head.name] = processed_labels + if self._head_weights: + head_weight = self._head_weights[i] + unreduced_losses_by_head[head.name] = math_ops.multiply( + unreduced_loss, head_weight) + example_weights_by_head[head.name] = math_ops.multiply( + weights, head_weight) + else: + unreduced_losses_by_head[head.name] = unreduced_loss + example_weights_by_head[head.name] = weights - weighted_sum_losses = tuple(weighted_sum_losses) - with ops.name_scope('merge_losses', - values=weighted_sum_losses + (self._head_weights or - tuple())): + training_losses = tuple(training_losses) + with ops.name_scope( + 'merge_losses', + values=training_losses + (self._head_weights or tuple())): if self._head_weights: - head_weighted_losses = [] - head_weighted_example_weight_sums = [] - for loss, example_weight_sum, weight in zip(weighted_sum_losses, - example_weight_sums, - self._head_weights): - head_weighted_losses.append(math_ops.multiply(loss, weight)) - head_weighted_example_weight_sums.append(math_ops.multiply( - example_weight_sum, weight)) - merged_weighted_sum_loss = math_ops.add_n(head_weighted_losses) - merged_example_weight_sum = math_ops.add_n( - head_weighted_example_weight_sums) + head_weighted_training_losses = [] + for training_loss, head_weight in zip( + training_losses, self._head_weights): + head_weighted_training_losses.append( + math_ops.multiply(training_loss, head_weight)) + merged_training_loss = math_ops.add_n(head_weighted_training_losses) else: - merged_weighted_sum_loss = math_ops.add_n(weighted_sum_losses) - merged_example_weight_sum = math_ops.add_n(example_weight_sums) + merged_training_loss = math_ops.add_n(training_losses) return head_lib.LossSpec( - weighted_sum_loss=merged_weighted_sum_loss, - example_weight_sum=merged_example_weight_sum, + 
training_loss=merged_training_loss, + unreduced_loss=unreduced_losses_by_head, + weights=example_weights_by_head, processed_labels=labels_by_head) def create_estimator_spec( diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py index 68f2d5d1cd53456f7dd82222e171b3619052321a..65ea89ba1b9236d0bf4d2de430fab168ef50bf97 100644 --- a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py @@ -370,7 +370,7 @@ class MultiHeadTest(test.TestCase): 'head1': np.array([[1, 0], [1, 1]], dtype=np.int64), 'head2': np.array([[0, 1, 0], [1, 1, 0]], dtype=np.int64), } - weighted_sum_loss, example_weight_sum, _ = multi_head.create_loss( + training_loss, unreduced_losses, weights, _ = multi_head.create_loss( features={ 'x': np.array(((42,),), dtype=np.int32), 'weights1': weights1, @@ -383,14 +383,23 @@ class MultiHeadTest(test.TestCase): with self.test_session(): # loss of the first head is [[(10 + 10) / 2], [(15 + 0) / 2]] # = [10, 7.5] - # weighted_sum_loss = 1 * 10 + 2 * 7.5 = 25 + # training_loss = 1 * 10 + 2 * 7.5 = 25 + # head-weighted unreduced_loss = 1 * [10, 7.5] + self.assertAllClose( + [[10.], [7.5]], unreduced_losses['head1'].eval(), rtol=tol, atol=tol) # loss of the second head is [[(20 + 20 + 20) / 3], [(30 + 0 + 0) / 3]] # = [20, 10] - # weighted_sum_loss = 2 * 20 + 3 * 10 = 70 - # head-weighted merge = 1 * 25 + 2 * 70 = 165 - self.assertAllClose(165, weighted_sum_loss.eval(), rtol=tol, atol=tol) - # example_weight_sum = 1 * (1 + 2) + 2 * (2 + 3) = 13 - self.assertAllClose(13., example_weight_sum.eval(), rtol=tol, atol=tol) + # training_loss = 2 * 20 + 3 * 10 = 70 + # head-weighted unreduced_loss = 2 * [20, 10] + self.assertAllClose( + [[40.], [20.]], unreduced_losses['head2'].eval(), rtol=tol, atol=tol) + # head-weighted training_loss = 1 * 25 + 2 * 70 = 165 + self.assertAllClose(165, 
training_loss.eval(), rtol=tol, atol=tol) + # head-weighted example weights + self.assertAllClose( + [[1.], [2.]], weights['head1'].eval(), rtol=tol, atol=tol) + self.assertAllClose( + [[4.], [6.]], weights['head2'].eval(), rtol=tol, atol=tol) def test_train_create_loss_logits_tensor(self): """Tests create_loss with logits Tensor.""" @@ -409,7 +418,7 @@ class MultiHeadTest(test.TestCase): 'head1': np.array([[1, 0], [1, 1]], dtype=np.int64), 'head2': np.array([[0, 1, 0], [1, 1, 0]], dtype=np.int64), } - weighted_sum_loss, example_weight_sum, _ = multi_head.create_loss( + training_loss, unreduced_losses, weights, _ = multi_head.create_loss( features={ 'x': np.array(((42,),), dtype=np.int32), 'weights1': weights1, @@ -422,14 +431,23 @@ class MultiHeadTest(test.TestCase): with self.test_session(): # loss of the first head is [[(10 + 10) / 2], [(15 + 0) / 2]] # = [10, 7.5] - # weighted_sum_loss = 1 * 10 + 2 * 7.5 = 25 + # training_loss = 1 * 10 + 2 * 7.5 = 25 + # head-weighted unreduced_loss = 1 * [10, 7.5] + self.assertAllClose( + [[10.], [7.5]], unreduced_losses['head1'].eval(), rtol=tol, atol=tol) # loss of the second head is [[(20 + 20 + 20) / 3], [(30 + 0 + 0) / 3]] # = [20, 10] - # weighted_sum_loss = 2 * 20 + 3 * 10 = 70 - # head-weighted merge = 1 * 25 + 2 * 70 = 165 - self.assertAllClose(165, weighted_sum_loss.eval(), rtol=tol, atol=tol) - # example_weight_sum = 1 * (1 + 2) + 2 * (2 + 3) = 13 - self.assertAllClose(13., example_weight_sum.eval(), rtol=tol, atol=tol) + # training_loss = 2 * 20 + 3 * 10 = 70 + # head-weighted unreduced_loss = 2 * [20, 10] + self.assertAllClose( + [[40.], [20.]], unreduced_losses['head2'].eval(), rtol=tol, atol=tol) + # head-weighted training_loss = 1 * 25 + 2 * 70 = 165 + self.assertAllClose(165, training_loss.eval(), rtol=tol, atol=tol) + # head-weighted example weights + self.assertAllClose( + [[1.], [2.]], weights['head1'].eval(), rtol=tol, atol=tol) + self.assertAllClose( + [[4.], [6.]], weights['head2'].eval(), rtol=tol, 
atol=tol) def test_train_create_loss_logits_tensor_multi_dim(self): """Tests create_loss with multi-dimensional logits of shape [2, 2, 5].""" @@ -455,20 +473,17 @@ class MultiHeadTest(test.TestCase): # loss2 = (0-2)^2 + (1+2)^2 + (0-2)^2 + (0-2)^2 + (1+2)^2 + (0-2)^2 + # (2+2)^2 + (2-2)^2 + (0+2)^2 + (2+2)^2 + (2-2)^2 + (0+2)^2 # = 74 - expected_weighted_sum_loss = 28. + 74. + expected_training_loss = 28. + 74. - weighted_sum_loss, example_weight_sum, _ = multi_head.create_loss( + training_loss = multi_head.create_loss( features={}, mode=model_fn.ModeKeys.TRAIN, logits=logits, - labels=labels) + labels=labels)[0] tol = 1e-3 with self.test_session(): self.assertAllClose( - expected_weighted_sum_loss, weighted_sum_loss.eval(), - rtol=tol, atol=tol) - self.assertAllClose( - 2. * 2. * 5., example_weight_sum.eval(), rtol=tol, atol=tol) + expected_training_loss, training_loss.eval(), rtol=tol, atol=tol) def test_train_one_head(self): head1 = head_lib.multi_label_head(n_classes=2, name='head1') diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py index d9c83aa86577aa129458c56887ff4668c103d0db..7134cd3f5a457a322f51066eb791133c3181d3fb 100644 --- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py +++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py @@ -23,6 +23,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from collections import defaultdict +from contextlib import contextmanager import copy import six @@ -41,20 +43,24 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.ops.losses import losses from tensorflow.python.platform import tf_logging -from tensorflow.python.training import training_util +from 
tensorflow.python.training import device_setter as device_setter_lib +from tensorflow.python.training import optimizer as optimizer_lib -def replicate_model_fn(model_fn, optimizer_fn, devices=None): - """Replicate `Estimator.model_fn` over GPUs within a single host. +def replicate_model_fn(model_fn, + loss_reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, + devices=None): + """Replicate `Estimator.model_fn` over GPUs. The given `model_fn` specifies a single forward pass of a model. To replicate such a model over GPUs, each GPU gets its own instance of the forward pass (a.k.a. a tower). The input features and labels get sharded into the chunks - that correspond to the number of GPUs. Each tower computes its own loss based + that correspond to the number of GPUs. Each tower computes a loss based on its input. For each such loss, gradients are computed. After that, the - available losses are summed to form aggregated loss. The available - gradients are summed too. Then, they update weights using the specified + available losses are aggregated to form aggregated loss. Available + gradients are summed. Then, they update weights using the specified optimizer. If `devices` are `None`, then all available GPUs are going to be used for @@ -63,36 +69,38 @@ def replicate_model_fn(model_fn, optimizer_fn, devices=None): Two modes of local replication over available GPUs are supported: 1) If exactly 1 GPU is detected, then variables and operations are placed - onto GPU. + onto the GPU. 2) If more than 1 GPU is detected, then variables are going to be placed on the CPU. Replicas of operations are placed on each individual GPU. Here is an example of how one might use their `model_fn` to run over GPUs: ```python - def optimizer_fn(): - return tf.train.GradientDescentOptimizer(learning_rate=0.001) ... def model_fn(...): # See `model_fn` in `Estimator`. loss = ... 
+ optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.001) + optimizer = tf.contrib.estimator.TowerOptimizer(optimizer) if mode == tf.estimator.ModeKeys.TRAIN: # See the section below on `EstimatorSpec.train_op`. - return EstimatorSpec(mode=mode, loss=loss, train_op=tf.noop()) + return EstimatorSpec(mode=mode, loss=loss, + train_op=optimizer.minimize(loss)) # No change for `ModeKeys.EVAL` or `ModeKeys.PREDICT`. return EstimatorSpec(...) ... classifier = tf.estimator.Estimator( - model_fn=replicate_model_fn.replicate_model_fn(model_fn, optimizer_fn)) + model_fn=tf.contrib.estimator.replicate_model_fn(model_fn)) ``` + Please see `DNNClassifierIntegrationTest` for an example with a canned + Estimator. + On `EstimatorSpec.train_op`: `model_fn` returns `EstimatorSpec.train_op` for `tf.estimator.GraphKeys.TRAIN`. It is typically derived using an optimizer. - `replicate_model_fn` ignores the returned `EstimatorSpec.train_op`, so there - is no need to use an optimizer inside the user's `model_fn`. The - `EstimatorSpec.loss` subgraph is going to be executed, while - `EstimatorSpec.train_op` isn't going to be executed. One could pass - `train_op=tf.noop()` to `EstimatorSpec`. + Towers are expected to populate it in the same way. Gradients from all towers + are reduced and applied in the last tower. To achieve that in the case of + multiple towers, `TowerOptimizer` needs to be used. See `TowerOptimizer`. On sharding input features and labels: Input features and labels are split for consumption by each tower. They are @@ -101,7 +109,7 @@ def replicate_model_fn(model_fn, optimizer_fn, devices=None): On reduction algorithms: Certain algorithms were chosen for aggregating results of computations on multiple towers: - - Losses from all towers are reduced using sum. + - Losses from all towers are reduced according to `loss_reduction`. - Gradients are reduced using sum for each trainable variable. - `eval_metrics_ops` are reduced per metric using `reduce_mean`. 
- `EstimatorSpec.predictions` and `EstimatorSpec.export_outputs` are @@ -109,65 +117,332 @@ def replicate_model_fn(model_fn, optimizer_fn, devices=None): - For all other fields of `EstimatorSpec` the values of the first tower are taken. - On replication of variables: + On distribution of variables: Variables are not duplicated between towers. Instead, they are placed on a single device as defined above and shared across towers. - Other current limitations: - - `predictions` are not supported for `ModeKeys.EVAL`. That is required for - `tf.contrib.estimator.add_metrics`. + On overhead: + If only one device is specified, then aggregation of loss and gradients + doesn't happen. Replication consists of placing `model_fn` onto the + specified device. + + On current limitations: + - `predictions` are not supported for `ModeKeys.EVAL`. They are required + for `tf.contrib.estimator.add_metrics`. Args: model_fn: `model_fn` as defined in `Estimator`. See the section above about the train_op argument of `EstimatorSpec`. - optimizer_fn: a function that returns an optimizer instance. The function - may accept one `params` argument. This is the `params` argument as - defined by `Estimator`. See the `Estimator` documentation for details. + loss_reduction: controls whether losses are summed or averaged. devices: Optional list of devices to replicate the model across. This argument can be used to replice only on the subset of available GPUs. If `None`, then all available GPUs are going to be used for replication. If no GPUs are available, then the model is going to be placed on the CPU. + Raises: + ValueError: if there is no `loss_reduction` or if TowerOptimizer is + mis-used. + Returns: A replicated version of the supplied `model_fn`. Returned function that conforms to the requirements of `Estimator`'s `model_fn` and can be used instead of the supplied `model_fn`. 
""" + return _replicate_model_fn_with_mode( + model_fn, + loss_reduction, + devices, + # TODO(isaprykin): Query the system configuration to choose modes other + # than `SHARED_LOCAL_PARAMETER_SERVER`, even though it is often + # appropriate. + mode=_VariableDistributionMode.SHARED_LOCAL_PARAMETER_SERVER) + + +class _VariableDistributionMode(object): + """Modes for variable distribution used for forcing a particular one. + + Forcing a mode is meant for performance experimentation purposes rather than + for general use cases. + """ + + SHARED_LOCAL_PARAMETER_SERVER = 1 + """Variables are placed on a single device and shared across all devices. + + Two ways to achieve this distribution over available GPUs are supported: + 1) If exactly 1 GPU is detected, then variables and operations are placed + onto GPU. + 2) If more than 1 GPU is detected, then variables are going to be placed on + the CPU. Replicas of operations are placed on each individual GPU. + """ + + SHARED_ROUND_ROBIN = 2 + """Variables are placed on all devices in a round-robin fashion. + + Every subsequent variable is placed on the next device. There is only one + copy of each variable that is shared across all devices. 
+ """ + + +def _replicate_model_fn_with_mode( + model_fn, + loss_reduction, + devices=None, + mode=_VariableDistributionMode.SHARED_LOCAL_PARAMETER_SERVER): + """A version of `replicate_model_fn` that allows to specify a `mode`.""" + if loss_reduction == losses.Reduction.NONE: + raise ValueError('Tower losses need to be reduced in some way, yet {} ' + 'reduction is specified.'.format(loss_reduction)) if not devices: devices = _get_local_devices('GPU') or _get_local_devices('CPU') - is_a_single_gpu_case = len(devices) == 1 and 'GPU' in devices[0] - local_ps_device = '/{}:0'.format('GPU' if is_a_single_gpu_case else 'CPU') + is_a_single_gpu_case = len(devices) == 1 and 'GPU' in devices[0].upper() + consolidation_device = devices[0] if is_a_single_gpu_case else '/CPU:0' - tf_logging.info('Replicating the `model_fn` across {}. Local parameter ' - 'server device is going to be {}.'.format( - devices, local_ps_device)) + ps_devices = [consolidation_device] + if mode == _VariableDistributionMode.SHARED_ROUND_ROBIN: + ps_devices = devices + + tf_logging.info('Replicating the `model_fn` across {}. Variables are going ' + 'to be placed on {}. Consolidation device is going to be {}.' + .format(devices, ps_devices, consolidation_device)) + + def single_device_model_fn(features, labels, mode, params=None, config=None): + """`model_fn` on a single device without reduction overhead.""" + return _get_loss_towers( + model_fn=model_fn, + mode=mode, + features=[features], + labels=[labels], + params=params, + loss_reduction=loss_reduction, + config=config, + devices=devices, + local_ps_devices=ps_devices)[0] # One device, so one spec is out. 
def replicated_model_fn(features, labels, mode, params=None, config=None): """Replicated version of `model_fn` to be used instead.""" feature_shards, label_shards = _split_batch( - features, labels, len(devices), device=local_ps_device) + features, labels, len(devices), device=consolidation_device) tower_specs = _get_loss_towers( model_fn=model_fn, mode=mode, features=feature_shards, labels=label_shards, params=params, + loss_reduction=loss_reduction, config=config, devices=devices, - local_ps_device=local_ps_device) + local_ps_devices=ps_devices) if mode == model_fn_lib.ModeKeys.TRAIN: - train_op = _minimize_towers(tower_specs, - _call_optimizer_fn(optimizer_fn, params)) + train_op = _minimize_towers(tower_specs) return _train_spec( - tower_specs, train_op, aggregation_device=local_ps_device) + tower_specs, train_op, aggregation_device=consolidation_device) elif mode == model_fn_lib.ModeKeys.EVAL: - return _eval_spec(tower_specs, aggregation_device=local_ps_device) + return _eval_spec(tower_specs, aggregation_device=consolidation_device) elif mode == model_fn_lib.ModeKeys.PREDICT: - return _predict_spec(tower_specs, aggregation_device=local_ps_device) + return _predict_spec(tower_specs, aggregation_device=consolidation_device) + + if len(devices) == 1: + return single_device_model_fn + else: + return replicated_model_fn + + +class TowerOptimizer(optimizer_lib.Optimizer): + """Gathers gradients from all towers and reduces them in the last one.""" + + COLLECTION_FOR_GRAPH_STATES = 'replicate_model_fn_graph_states' - return replicated_model_fn + def __init__(self, optimizer_or_optimizer_fn): + """Wrap an existing optimizer for gathering gradients across towers. + + Each invocation of model_fn has to call the same optimizers in the same + order. + + Multiple optimizers that use the same or different losses are supported. + + If TowerOptimizer is used but `replicate_model_fn` isn't, then no + aggregation will happen. 
All calls will simply be forwarded to the + underlying optimizer. The behavior is similar if there is only one tower. + + If TowerOptimizer is used together with SyncReplicasOptimizer that wraps + the user's optimizer, then it's the SyncReplicasOptimizer that needs to be + wrapped with TowerOptimizer. + + Args: + optimizer_or_optimizer_fn: an instance of optimizer to wrap. That + instance is going to be used for optimizer-specific logic. This can + also be a no-argument function that returns such an optimizer instance. + """ + self._optimizer_or_optimizer_fn = optimizer_or_optimizer_fn + + @staticmethod + def has_been_used(): + return TowerOptimizer._graph_state().has_tower_optimizer_been_used + + def get_slot(self, *args, **kwargs): + return self._get_optimizer().get_slot(*args, **kwargs) + + def get_slot_names(self, *args, **kwargs): + return self._get_optimizer().get_slot_names(*args, **kwargs) + + def get_name(self, *args, **kwargs): + return self._get_optimizer().get_name(*args, **kwargs) + + def variables(self, *args, **kwargs): + return self._get_optimizer().variables(*args, **kwargs) + + def compute_gradients(self, loss, *args, **kwargs): + """Compute gradients, but first, if needed, scale the loss.""" + loss = _scale_loss(loss, + self._graph_state().loss_reduction, + self._graph_state().number_of_towers) + return self._get_optimizer().compute_gradients(loss, *args, **kwargs) + + def apply_gradients(self, grads_and_vars, global_step=None, **kwargs): + """Collect gradients updates to apply them with the last tower.""" + if self._graph_state().number_of_towers == 1: + # Avoid the overhead of reduction if there's only one tower. + # + # There assumed to be only one tower if aggregation-related methods were + # not called by `_get_loss_towers`, for example if the model_fn uses + # TowerEstimator, but `replicate_model_fn` isn't used. 
+ return self._get_optimizer().apply_gradients(grads_and_vars, global_step, + **kwargs) + + self._graph_state().collect_gradients(grads_and_vars) + + if not self._graph_state().is_the_last_tower: + with ops_lib.control_dependencies(_extract_tensors(grads_and_vars)): + return self._construct_no_op_train_op() + else: + # Gradients need to be gathered and applied in the scope of the first + # tower, so that the tensors are accessible via names without prefixes. + var_scope, name_scope = self._graph_state().scopes_of_the_first_tower + with variable_scope.variable_scope(var_scope): + with ops_lib.name_scope(name_scope): + return self._apply_gathered_gradients(global_step, **kwargs) + + def _apply_gathered_gradients(self, global_step, **kwargs): + graph_state = self._graph_state() + optimizer = self._get_optimizer() + + grad_lists = {} + for grad, var in graph_state.get_latest_gradients_from_all_towers(): + if grad is not None: + grad_lists.setdefault(var, []).append(grad) + + aggregated_grads = [] + with ops_lib.name_scope('gradient_aggregating'): + for var, grads in six.iteritems(grad_lists): + grad = _compute_sum_on_device(grads, var.device) + aggregated_grads.append((grad, var)) + return optimizer.apply_gradients( + aggregated_grads, global_step=global_step, **kwargs) + + def _get_optimizer(self): + if callable(self._optimizer_or_optimizer_fn): + # If optimizer is given as a function then we need to wait till we are + # under the right graph context before constructing it. That's why the + # optimizer is constructed in _get_optimizer() rather than __init__(). 
+ self._optimizer_or_optimizer_fn = self._optimizer_or_optimizer_fn() + self._graph_state().has_tower_optimizer_been_used = True + return self._optimizer_or_optimizer_fn + + def _construct_no_op_train_op(self): + return control_flow_ops.no_op(name='train_op_placeholder') + + @staticmethod + def _graph_state(): + graph_states = ops_lib.get_default_graph().get_collection_ref( + TowerOptimizer.COLLECTION_FOR_GRAPH_STATES) + if not graph_states: + graph_states.append(TowerOptimizer._PerGraphState()) + return graph_states[-1] + + @staticmethod + def _did_towers_have_same_optimizer_calls(): + graph_state = TowerOptimizer._graph_state() + return graph_state.did_towers_have_same_optimizer_calls() + + @staticmethod + def _clear_graph_state(): + # Clearing the Graph collection will prevent _PerGraphState from being + # serialized. + ops_lib.get_default_graph().clear_collection( + TowerOptimizer.COLLECTION_FOR_GRAPH_STATES) + + class _PerGraphState(object): + """Gradient reduction related state of a Tensorflow graph.""" + + def __init__(self): + self._collected_grads_and_vars = defaultdict(list) + self._current_tower_index = 0 + self._number_of_towers = 1 + self._loss_reduction = None + # Scopes of the first tower that don't have a prefix: + self._variable_scope = None + self._name_scope = None + # If needed, alert that TowerOptimizer needs to be used with model_fn. 
+ self._has_tower_optimizer_been_used = False + + def collect_gradients(self, grads_and_vars): + self._collected_grads_and_vars[self._current_tower_index].append( + grads_and_vars) + + def get_latest_gradients_from_all_towers(self): + """Get gradients across towers for the last called optimizer.""" + grads_and_vars = [] + index_of_last_gradients = len( + self._collected_grads_and_vars[self._current_tower_index]) - 1 + for tower_id in range(self._current_tower_index + 1): + grads_and_vars.extend( + self._collected_grads_and_vars[tower_id][index_of_last_gradients]) + return grads_and_vars + + def set_reduction_across_towers(self, loss_reduction, number_of_towers): + self._loss_reduction = loss_reduction + self._number_of_towers = number_of_towers + + @contextmanager + def tower(self, tower_id, var_scope, name_scope): + if tower_id == 0: + self._variable_scope = var_scope + self._name_scope = name_scope + self._current_tower_index = tower_id + yield + + @property + def scopes_of_the_first_tower(self): + return self._variable_scope, self._name_scope + + @property + def is_the_last_tower(self): + return self._current_tower_index == (self._number_of_towers - 1) + + @property + def number_of_towers(self): + return self._number_of_towers + + @property + def loss_reduction(self): + return self._loss_reduction + + @property + def has_tower_optimizer_been_used(self): + return self._has_tower_optimizer_been_used + + @has_tower_optimizer_been_used.setter + def has_tower_optimizer_been_used(self, value): + self._has_tower_optimizer_been_used = value + + def did_towers_have_same_optimizer_calls(self): + total_number_of_grads = sum([ + len(grads) + for _, grads in six.iteritems(self._collected_grads_and_vars) + ]) + return total_number_of_grads % self._number_of_towers == 0 def _get_local_devices(device_type): @@ -182,6 +457,13 @@ def _get_local_devices(device_type): def _split_batch(features, labels, number_of_shards, device): """Split input features and labes into batches.""" + 
def ensure_divisible_by_shards(sequence): + batch_size = ops_lib.convert_to_tensor(sequence).get_shape()[0] + if batch_size % number_of_shards != 0: + raise ValueError( + 'Batch size {} needs to be divisible by the number of GPUs, which ' + 'is {}.'.format(batch_size, number_of_shards)) + def split_dictionary(dictionary): """Split a dictionary into shards.""" shards = [{} for _ in range(number_of_shards)] @@ -192,6 +474,7 @@ def _split_batch(features, labels, number_of_shards, device): sp_input=tensor, num_split=number_of_shards, axis=0)): shards[i][name] = shard else: + ensure_divisible_by_shards(tensor) for i, shard in enumerate(array_ops.split(tensor, number_of_shards)): shards[i][name] = shard return shards @@ -201,6 +484,7 @@ def _split_batch(features, labels, number_of_shards, device): if isinstance(features, dict): feature_shards = split_dictionary(features) else: + ensure_divisible_by_shards(features) feature_shards = array_ops.split(features, number_of_shards) if labels is None: @@ -208,6 +492,7 @@ def _split_batch(features, labels, number_of_shards, device): elif isinstance(labels, dict): label_shards = split_dictionary(labels) else: + ensure_divisible_by_shards(labels) label_shards = array_ops.split(labels, number_of_shards) return feature_shards, label_shards @@ -222,7 +507,8 @@ def _get_loss_towers(model_fn, params, config, devices, - local_ps_device, + local_ps_devices, + loss_reduction, name_scope_pattern=_DEFAULT_NAME_SCOPE_PATTERN): """Replicate the loss computation across devices.""" tower_specs = [] @@ -234,36 +520,64 @@ def _get_loss_towers(model_fn, if 'config' in model_fn_args: optional_params['config'] = copy.deepcopy(config) + # pylint: disable=protected-access + round_robin_strategy = device_setter_lib._RoundRobinStrategy( + num_tasks=len(local_ps_devices)) + TowerOptimizer._graph_state().set_reduction_across_towers( + loss_reduction, len(devices)) + for i, device in enumerate(devices): is_the_first_tower = (i == 0) device_setter = 
_local_device_setter( - worker_device=device, ps_device=local_ps_device) + worker_device=device, + ps_devices=local_ps_devices, + ps_strategy=round_robin_strategy) - # We would like to preserve the names of the variables and ops that a user - # might be relying on. Names with prefix are going to resolve to variables - # and ops of the first tower. + # We would like to preserve the names of the variables and ops that the user + # might be relying on. Names without a prefix are going to resolve to + # variables and ops of the first tower. name_scope = name_scope_pattern if is_the_first_tower: name_scope = '' - with variable_scope.variable_scope('', reuse=not is_the_first_tower): - with ops_lib.name_scope(name_scope.format(i)): - with ops_lib.device(device_setter): - labels_shard = None - if labels: - labels_shard = labels[i] - - tower_specs.append( - model_fn( - mode=mode, - features=features[i], - labels=labels_shard, - **optional_params)) + with variable_scope.variable_scope( + '', reuse=not is_the_first_tower) as var_scope: + with ops_lib.name_scope(name_scope.format(i)) as name_scope: + with TowerOptimizer._graph_state().tower( + tower_id=i, var_scope=var_scope, name_scope=name_scope): + with ops_lib.device(device_setter): + labels_shard = None + if labels: + labels_shard = labels[i] + + tower_spec = model_fn( + mode=mode, + features=features[i], + labels=labels_shard, + **optional_params) + + if (tower_spec.train_op is not None and len(devices) > 1 and + not TowerOptimizer.has_been_used()): + raise ValueError('Please wrap optimizers with TowerOptimizer' + ' in order to use replicate_model_fn with' + ' multiple `devices`.') + + # Scaling the loss here doesn't actually affect gradients. Another + # instance of scaling happens inside the TowerOptimizer. 
+ tower_spec = _scale_tower_loss( + tower_spec, loss_reduction, number_of_towers=len(devices)) + tower_specs.append(tower_spec) + + if not TowerOptimizer._did_towers_have_same_optimizer_calls(): + raise ValueError('Each invocation of model_fn was supposed to make the same' + ' optimizer calls.') + TowerOptimizer._clear_graph_state() + # pylint: enable=protected-access return tower_specs -def _local_device_setter(ps_device, worker_device): +def _local_device_setter(worker_device, ps_devices, ps_strategy): """A device setter that puts distributes Var/Ops to PS/workers.""" ps_ops = ['Variable', 'VariableV2', 'VarHandleOp'] @@ -273,7 +587,7 @@ def _local_device_setter(ps_device, worker_device): node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def if node_def.op in ps_ops: ps_device_spec = framework_device.DeviceSpec.from_string( - '{}'.format(ps_device)) + '{}'.format(ps_devices[ps_strategy(op)])) ps_device_spec.merge_from(current_device) return ps_device_spec.to_string() @@ -286,33 +600,33 @@ def _local_device_setter(ps_device, worker_device): return local_device_chooser -def _minimize_towers(tower_specs, optimizer): - """Aggregate and apply gradients for computed losses.""" - grad_lists = {} - for tower_spec in tower_specs: - with ops_lib.device(tower_spec.loss.device): - for grad, var in optimizer.compute_gradients(tower_spec.loss): - if grad is not None: - grad_lists.setdefault(var, []).append(grad) +def _scale_tower_loss(tower_spec, loss_reduction, number_of_towers): + """Produce an EstimatorSpec with approproriately scaled loss.""" + if tower_spec.loss is None: + return tower_spec + + estimator_spec = _asdict(tower_spec) + estimator_spec['loss'] = _scale_loss(tower_spec.loss, loss_reduction, + number_of_towers) + return model_fn_lib.EstimatorSpec(**estimator_spec) - aggregated_grads = [] - with ops_lib.name_scope('gradient_aggregating'): - for var, grads in six.iteritems(grad_lists): - grad = _compute_sum_on_device(grads, var.device) - 
aggregated_grads.append((grad, var)) - train_op = optimizer.apply_gradients( - aggregated_grads, global_step=training_util.get_global_step()) +def _scale_loss(loss, loss_reduction, number_of_towers): + """If needed, scale down the loss for averaging loss by summing.""" + if loss is None: + return None + if number_of_towers == 1: + return loss - return train_op + if loss_reduction != losses.Reduction.SUM: + return math_ops.div(loss, 1.0 * number_of_towers, name='averaged_loss') + else: + return loss -def _call_optimizer_fn(optimizer_fn, params): - arguments = {} - optimizer_fn_arguments = util.fn_args(optimizer_fn) - if 'params' in optimizer_fn_arguments: - arguments['params'] = params - return optimizer_fn(**arguments) +def _minimize_towers(tower_specs): + """`train_op` of the last tower applies aggregated gradients.""" + return tower_specs[-1].train_op def _compute_sum_on_device(values, device, name=None): @@ -335,7 +649,12 @@ def _train_spec(tower_specs, aggregation_device, aggregated_loss_name='loss'): """Populate replicated EstimatorSpec for `GraphKeys.TRAIN`.""" - estimator_spec = tower_specs[0]._asdict() + # Spec of the last tower is used as the template for the final spec, because + # some `EstimatorSpec.training_hooks` rely on calls made in model_fn. For + # example, `SyncReplicasOptimizerHook` validates the + # `SyncReplicasOptimizer.apply_gradients` call. `TowerEstimator` makes that + # call only in the last tower. 
+ estimator_spec = _asdict(tower_specs[-1]) estimator_spec['mode'] = model_fn_lib.ModeKeys.TRAIN estimator_spec['train_op'] = train_op estimator_spec['loss'] = _compute_sum_on_device( @@ -346,7 +665,7 @@ def _train_spec(tower_specs, def _eval_spec(tower_specs, aggregation_device, aggregated_loss_name='loss'): """Populate replicated EstimatorSpec for `GraphKeys.EVAL`.""" - estimator_spec = tower_specs[0]._asdict() + estimator_spec = _asdict(tower_specs[0]) estimator_spec['mode'] = model_fn_lib.ModeKeys.EVAL estimator_spec['loss'] = _compute_sum_on_device( [spec.loss for spec in tower_specs], aggregation_device, @@ -370,7 +689,7 @@ def _eval_spec(tower_specs, aggregation_device, aggregated_loss_name='loss'): def _reduce_metric_variables(number_of_towers): """Aggregate local variables used in metrics into the first tower.""" if number_of_towers == 1: - return control_flow_ops.no_op() + return control_flow_ops.no_op(name='no_eval_metric_reduction') metric_variables = ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES) variables_per_tower = len(metric_variables) // number_of_towers @@ -414,7 +733,7 @@ def _reduce_metric_variables(number_of_towers): def _predict_spec(tower_specs, aggregation_device): """Populate replicated EstimatorSpec for `GraphKeys.PREDICT`.""" - estimator_spec = tower_specs[0]._asdict() + estimator_spec = _asdict(tower_specs[0]) estimator_spec['mode'] = model_fn_lib.ModeKeys.PREDICT with ops_lib.device(aggregation_device): @@ -465,6 +784,17 @@ def _concat_tensor_dicts(*tensor_dicts): } +def _extract_tensors(tensors_and_vars): + tensors = [] + for tensor_and_var in tensors_and_vars: + tensor, _ = tensor_and_var + if isinstance(tensor, ops_lib.IndexedSlices): + tensors.append(tensor.values) + elif tensor is not None: + tensors.append(tensor) + return tensors + + def _dict_concat(*dicts): list_dict = {} for d in dicts: @@ -474,3 +804,19 @@ def _dict_concat(*dicts): for k, v in six.iteritems(d): list_dict.setdefault(k, []).append(v) return 
list_dict + + +def _asdict(namedtuple): + """Returns a namedtuple as a dictionary. + + This is required because `_asdict()` in Python 3.x.x is broken in classes + that inherit from `collections.namedtuple`. See + https://bugs.python.org/issue24931 for more details. + + Args: + namedtuple: An object that inherits from `collections.namedtuple`. + + Returns: + A dictionary version of the tuple. + """ + return {k: getattr(namedtuple, k) for k in namedtuple._fields} diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py index ffe69f89b4c4d48d329a1aef3aa3cad2b17b3fdf..d46a18aacfcd911c56a9f22dc9581060c7b458a6 100644 --- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py +++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py @@ -37,9 +37,11 @@ from tensorflow.python.feature_column import feature_column from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops as ops_lib +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import losses from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.ops import variable_scope @@ -49,15 +51,32 @@ from tensorflow.python.platform import gfile from tensorflow.python.platform import test from tensorflow.python.saved_model import signature_constants from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import adam +from tensorflow.python.training import device_setter from tensorflow.python.training import gradient_descent +from tensorflow.python.training import training +# TODO(isaprykin): Parametrize all the tests on +# 
replicate_model_fn._VariableDistributionMode when it's supported. class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase): def setUp(self): self._model_dir = tempfile.mkdtemp() - def test_complete_flow(self): + def test_complete_flow_with_public_version(self): + return self._complete_flow_with_mode(mode=None) + + def test_complete_flow_with_mode_local_ps_server(self): + return self._complete_flow_with_mode( + replicate_model_fn._VariableDistributionMode. + SHARED_LOCAL_PARAMETER_SERVER) + + def test_complete_flow_with_mode_round_robin(self): + return self._complete_flow_with_mode( + replicate_model_fn._VariableDistributionMode.SHARED_ROUND_ROBIN) + + def _complete_flow_with_mode(self, mode): n_classes = 3 input_dimension = 2 batch_size = 12 @@ -96,20 +115,30 @@ class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase): 0., len(x_data), len(x_data), dtype=np.int64)), 1) ] + def optimizer_fn(): + return optimizers.get_optimizer_instance('Adagrad', learning_rate=0.05) + estimator = dnn.DNNClassifier( hidden_units=(2, 2), + # Adagrad is configured with `get_optimizer_instance`, so the function + # form of `TowerOptimizer.__init__` is used. + optimizer=replicate_model_fn.TowerOptimizer(optimizer_fn), feature_columns=feature_columns, n_classes=n_classes, model_dir=self._model_dir) - def optimizer_fn(): - return optimizers.get_optimizer_instance('Adagrad', learning_rate=0.05) + if not mode: # Use the public `replicate_model_fn`. 
+ model_fn = replicate_model_fn.replicate_model_fn( + estimator.model_fn, devices=['/gpu:0', '/gpu:1', '/gpu:2']) + else: + model_fn = replicate_model_fn._replicate_model_fn_with_mode( + estimator.model_fn, + devices=['/gpu:0', '/gpu:1', '/gpu:2'], + loss_reduction=losses.Reduction.SUM, + mode=mode) estimator = estimator_lib.Estimator( - model_fn=replicate_model_fn.replicate_model_fn( - estimator.model_fn, - optimizer_fn, - devices=['/gpu:0', '/gpu:1', '/gpu:2']), + model_fn=model_fn, model_dir=estimator.model_dir, config=estimator.config, params=estimator.params) @@ -134,6 +163,10 @@ class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase): serving_input_receiver_fn) self.assertTrue(gfile.Exists(export_dir)) + # Nothing should be left in the graph so that it doesn't get serialized. + self.assertFalse(ops_lib.get_default_graph().get_collection_ref( + replicate_model_fn.TowerOptimizer.COLLECTION_FOR_GRAPH_STATES)) + def _as_label(self, data_in_float): return np.rint(data_in_float).astype(np.int64) @@ -153,28 +186,24 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): predictions = math_ops.multiply(features, c) - loss = None - if mode is not model_fn_lib.ModeKeys.PREDICT: - loss = losses.absolute_difference( - labels=labels, - predictions=predictions, - reduction=losses.Reduction.SUM) - loss = math_ops.reduce_sum(loss) + loss = losses.absolute_difference( + labels=labels, predictions=predictions, reduction=losses.Reduction.SUM) + loss = math_ops.reduce_sum(loss) metrics = { 'accuracy': metrics_lib.accuracy(labels, predictions), 'auc': metrics_lib.auc(labels, predictions) } + optimizer = replicate_model_fn.TowerOptimizer( + gradient_descent.GradientDescentOptimizer(params['learning_rate'])) + return model_fn_lib.EstimatorSpec( mode=mode, loss=loss, eval_metric_ops=metrics, predictions={'probabilities': predictions}, - train_op=control_flow_ops.no_op()) # This train_op isn't actually used. 
- - def optimizer_fn(self, params): - return gradient_descent.GradientDescentOptimizer(params['learning_rate']) + train_op=optimizer.minimize(loss)) @property def params(self): @@ -188,7 +217,9 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): with self.test_session() as session: replicated_model_fn = replicate_model_fn.replicate_model_fn( - self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1']) + self.model_fn, + loss_reduction=losses.Reduction.SUM, + devices=['/gpu:0', '/gpu:1']) estimator_spec = replicated_model_fn( features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) session.run(variables.global_variables_initializer()) @@ -197,31 +228,78 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): total_loss = (1.0 * 10 - 1.0) + (2.0 * 10 - 2.0) self.assertEqual(total_loss, session.run(estimator_spec.loss)) - # loss' of c is 3. + # derivative of loss = (1*c - 1) + (2*c - 2) is 3. # new value of c = 10 - learning rate * 3 = 7.0. session.run(estimator_spec.train_op) with variable_scope.variable_scope('', reuse=True): c = variable_scope.get_variable('c', dtype=dtypes.float64) self.assertEqual(7.0, session.run(c)) - def test_train_spec_with_optimizer_without_params(self): - - def optimizer_fn_without_params(): - return gradient_descent.GradientDescentOptimizer(learning_rate=1.0) - + def test_train_with_mean_reduction(self): features = np.array([[1.0], [2.0]]) labels = np.array([[1.0], [2.0]]) - with self.test_session() as session: # pylint: disable=unused-variable + with self.test_session() as session: + # Add another trainable variable that doesn't produce a gradient to + # verify that None gradients are supported. 
+ _ = variable_scope.get_variable( + 'another_variable', + initializer=constant_op.constant(1, dtype=dtypes.float64), + dtype=dtypes.float64) + replicated_model_fn = replicate_model_fn.replicate_model_fn( - self.model_fn, - optimizer_fn_without_params, - devices=['/gpu:0', '/gpu:1']) - # This call is going to fail if `replicated_model_fn` is still passing - # `params` inside `optimizer_fn`, even though the latter doesn't take any: + self.model_fn, losses.Reduction.MEAN, devices=['/gpu:0', '/gpu:1']) estimator_spec = replicated_model_fn( features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) - del estimator_spec + session.run(variables.global_variables_initializer()) + + # loss = feature * c - label + total_loss = ((1.0 * 10 - 1.0) + (2.0 * 10 - 2.0)) / 2.0 + self.assertEqual(total_loss, session.run(estimator_spec.loss)) + + # derivative of loss = (1*c - 1)/2 + (2*c - 2)/2 is 1.5. + # It's the same computation as without mean reduction, but the + # loss from every tower is scaled by 1/. + # new value of c = 10 - learning rate * 1.5 = 8.5 + session.run(estimator_spec.train_op) + with variable_scope.variable_scope('', reuse=True): + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual(8.5, session.run(c)) + + def test_train_two_steps_collected_gradients_are_reset_between_steps(self): + with ops_lib.Graph().as_default(): + features = array_ops.placeholder(dtypes.float64) + labels = array_ops.placeholder(dtypes.float64) + + feature_inputs = np.array([[1.0], [2.0]]), np.array([[1.5], [2.5]]) + label_inputs = np.array([[1.0], [2.0]]), np.array([[1.5], [2.5]]) + + # loss = feature * c - label + expected_losses = ((1.0 * 10 - 1.0) + (2.0 * 10 - 2.0), + (1.5 * 7.0 - 1.5) + (2.5 * 7.0 - 2.5)) + # Derivative of the loss is 1.0 + 2.0 for the first step and 1.5 + 2.5 + # for the second. 
+ expected_c = 10.0 - 3.0, 7.0 - 4.0 + + with self.test_session() as session, variable_scope.variable_scope( + '', reuse=variable_scope.AUTO_REUSE): + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, + loss_reduction=losses.Reduction.SUM, + devices=['/gpu:0', '/gpu:1']) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) + session.run(variables.global_variables_initializer()) + + for feature_input, label_input, loss, weight in zip( + feature_inputs, label_inputs, expected_losses, expected_c): + feeds = {features: feature_input, labels: label_input} + + self.assertEqual(loss, session.run(estimator_spec.loss, feeds)) + + session.run(estimator_spec.train_op, feeds) + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual(weight, session.run(c, feeds)) def test_eval(self): features = np.array([[0.01], [0.002]]) @@ -229,7 +307,9 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): with self.test_session() as session: replicated_model_fn = replicate_model_fn.replicate_model_fn( - self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1']) + self.model_fn, + loss_reduction=losses.Reduction.SUM, + devices=['/gpu:0', '/gpu:1']) estimator_spec = replicated_model_fn( features, labels, model_fn_lib.ModeKeys.EVAL, self.params) session.run(variables.local_variables_initializer()) @@ -252,13 +332,42 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): self.assertEqual(0, auc) self.assertNear(total_loss, session.run(estimator_spec.loss), 0.01) + def test_eval_with_mean_reduction(self): + features = np.array([[0.01], [0.002]]) + labels = np.array([[0.01], [0.02]]) + + with self.test_session() as session: + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, losses.Reduction.MEAN, devices=['/gpu:0', '/gpu:1']) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.EVAL, self.params) + 
session.run(variables.local_variables_initializer()) + session.run(variables.global_variables_initializer()) + + accuracy, a = estimator_spec.eval_metric_ops['accuracy'] + auc, b = estimator_spec.eval_metric_ops['auc'] + + session.run([a, b]) + accuracy = session.run(accuracy) + auc = session.run(auc) + + # loss[i] = features[i] * 10 - labels[i]. + # Accuracy is 0.0 (no match) in the first tower. + # Accuracy is 1.0 (match) in the second tower, since the feature + # times weight "c" happened to be equal to the label. + total_loss = ((0.01 * 10 - 0.01) + (0.002 * 10 - 0.02)) / 2.0 + + self.assertNear((0.0 + 1.0) / 2.0, accuracy, 0.01) + self.assertEqual(0, auc) + self.assertNear(total_loss, session.run(estimator_spec.loss), 0.01) + def test_predict(self): features = np.array([[0.01], [0.002]]) labels = np.array([[0.01], [0.02]]) with self.test_session() as session: replicated_model_fn = replicate_model_fn.replicate_model_fn( - self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1']) + self.model_fn, devices=['/gpu:0', '/gpu:1']) estimator_spec = replicated_model_fn( features, labels, model_fn_lib.ModeKeys.PREDICT, self.params) session.run(variables.global_variables_initializer()) @@ -273,7 +382,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): with self.test_session() as session: replicated_model_fn = replicate_model_fn.replicate_model_fn( - self.model_fn, self.optimizer_fn) + self.model_fn, devices=['/gpu:0']) estimator_spec = replicated_model_fn( features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) session.run(variables.global_variables_initializer()) @@ -295,7 +404,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): with self.test_session() as session: replicated_model_fn = replicate_model_fn.replicate_model_fn( - self.model_fn, self.optimizer_fn, devices=['/gpu:0']) + self.model_fn, devices=['/gpu:0']) estimator_spec = replicated_model_fn( features, labels, model_fn_lib.ModeKeys.EVAL, self.params) 
session.run(variables.local_variables_initializer()) @@ -323,7 +432,7 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): with self.test_session() as session: replicated_model_fn = replicate_model_fn.replicate_model_fn( - self.model_fn, self.optimizer_fn, devices=['/gpu:0']) + self.model_fn, devices=['/gpu:0']) estimator_spec = replicated_model_fn( features, labels, model_fn_lib.ModeKeys.PREDICT, self.params) session.run(variables.global_variables_initializer()) @@ -332,6 +441,451 @@ class ReplicateModelTest(test_util.TensorFlowTestCase): 'probabilities': np.array([[0.1], [0.02]]) }, session.run(estimator_spec.predictions)) + def test_batch_size_that_is_not_divisible_by_the_number_of_gpus(self): + features = np.array([[1.0], [2.0], [3.0]]) + labels = np.array([[1.0], [2.0], [3.0]]) + + with self.assertRaisesRegexp( + ValueError, '.*Batch.+size.+needs.+to.+be.+divisible.+by.+GPUs.+'): + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, devices=['/gpu:0', '/gpu:1']) + _ = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) + + def test_unsupported_loss_reduction(self): + with self.assertRaisesRegexp(ValueError, + '.+none.+reduction.+is.+specified.+'): + _ = replicate_model_fn.replicate_model_fn(self.model_fn, + losses.Reduction.NONE) + + def test_places_on_gpu_with_upper_case_spelling(self): + features = np.array([[0.01], [0.002]]) + labels = np.array([[0.01], [0.02]]) + + with self.test_session(): + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, devices=['/GPU:0']) + _ = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) + + with variable_scope.variable_scope('', reuse=True): + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual('/device:GPU:0', c.device) + + def test_places_on_gpu_with_lower_case_spelling(self): + features = np.array([[0.01], [0.002]]) + labels = np.array([[0.01], [0.02]]) + + with 
self.test_session(): + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, devices=['/gpu:0']) + _ = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) + + with variable_scope.variable_scope('', reuse=True): + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual('/device:GPU:0', c.device) + + +class ReplicateAcrossASingleDeviceWithoutTowerOptimizer( + test_util.TensorFlowTestCase): + + def model_fn(self, mode, features, labels, params): + c = variable_scope.get_variable( + 'c', + initializer=constant_op.constant(10, dtype=dtypes.float64), + dtype=dtypes.float64) + + predictions = math_ops.multiply(features, c) + + loss = losses.absolute_difference( + labels=labels, predictions=predictions, reduction=losses.Reduction.SUM) + loss = math_ops.reduce_sum(loss) + + metrics = { + 'accuracy': metrics_lib.accuracy(labels, predictions), + 'auc': metrics_lib.auc(labels, predictions) + } + + optimizer = gradient_descent.GradientDescentOptimizer( + params['learning_rate']) + + return model_fn_lib.EstimatorSpec( + mode=mode, + loss=loss, + eval_metric_ops=metrics, + predictions={'probabilities': predictions}, + train_op=optimizer.minimize(loss)) + + @property + def params(self): + params = {} + params['learning_rate'] = 1.0 + return params + + def test_train_single_tower(self): + features = np.array([[1.0], [2.0]]) + labels = np.array([[1.0], [2.0]]) + + with self.test_session() as session: + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, devices=['/gpu:0']) + estimator_spec = replicated_model_fn( + features, labels, model_fn_lib.ModeKeys.TRAIN, self.params) + session.run(variables.global_variables_initializer()) + + # loss = feature * c - label + total_loss = (1.0 * 10 - 1.0) + (2.0 * 10 - 2.0) + self.assertEqual(total_loss, session.run(estimator_spec.loss)) + + # loss' of c is 3. + # new value of c = 10 - learning rate * 3 = 7.0. 
+ session.run(estimator_spec.train_op) + with variable_scope.variable_scope('', reuse=True): + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual(7.0, session.run(c)) + + +class UseTowerEstimatorWithoutReplication(test_util.TensorFlowTestCase): + + def model_fn(self, mode, features, labels, params): + c = variable_scope.get_variable( + 'c', + initializer=constant_op.constant(10, dtype=dtypes.float64), + dtype=dtypes.float64) + + features = features['features'] + predictions = math_ops.multiply(features, c) + + loss = losses.absolute_difference( + labels=labels, predictions=predictions, reduction=losses.Reduction.SUM) + loss = math_ops.reduce_sum(loss) + + metrics = { + 'accuracy': metrics_lib.accuracy(labels, predictions), + 'auc': metrics_lib.auc(labels, predictions) + } + + optimizer = replicate_model_fn.TowerOptimizer( + gradient_descent.GradientDescentOptimizer(params['learning_rate'])) + + return model_fn_lib.EstimatorSpec( + mode=mode, + loss=loss, + eval_metric_ops=metrics, + predictions={'probabilities': predictions}, + train_op=optimizer.minimize(loss)) + + @property + def params(self): + params = {} + params['learning_rate'] = 1.0 + return params + + def test_train_single_tower(self): + features = np.array([[1.0], [2.0]]) + labels = np.array([[1.0], [2.0]]) + + train_input_fn = numpy_io.numpy_input_fn( + x={'features': features}, y=labels, batch_size=2, shuffle=False) + + with self.test_session(): + estimator = estimator_lib.Estimator( + model_fn=self.model_fn, + model_dir=tempfile.mkdtemp(), + params=self.params) + estimator.train(train_input_fn, steps=1) + + self.assertEqual(7.0, estimator.get_variable_value('c')) + + +class MakeSureSyncReplicasOptimizerWorks(test_util.TensorFlowTestCase): + + def model_fn(self, mode, features, labels, params): + c = variable_scope.get_variable( + 'c', + initializer=constant_op.constant(10, dtype=dtypes.float64), + dtype=dtypes.float64) + + features = features['features'] + predictions = 
math_ops.multiply(features, c) + + loss = losses.absolute_difference( + labels=labels, predictions=predictions, reduction=losses.Reduction.SUM) + loss = math_ops.reduce_sum(loss) + + metrics = { + 'accuracy': metrics_lib.accuracy(labels, predictions), + 'auc': metrics_lib.auc(labels, predictions) + } + + optimizer = gradient_descent.GradientDescentOptimizer( + params['learning_rate']) + optimizer = training.SyncReplicasOptimizer( + optimizer, replicas_to_aggregate=1) + sync_hook = optimizer.make_session_run_hook(True) + optimizer = replicate_model_fn.TowerOptimizer(optimizer) + + return model_fn_lib.EstimatorSpec( + mode=mode, + loss=loss, + eval_metric_ops=metrics, + training_hooks=[sync_hook], + predictions={'probabilities': predictions}, + train_op=optimizer.minimize( + loss, global_step=training.get_global_step())) + + @property + def params(self): + params = {} + params['learning_rate'] = 1.0 + return params + + def test_train_multiple_towers(self): + features = np.array([[1.0], [2.0]]) + labels = np.array([[1.0], [2.0]]) + + train_input_fn = numpy_io.numpy_input_fn( + x={'features': features}, y=labels, batch_size=2, shuffle=False) + + model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, + loss_reduction=losses.Reduction.SUM, + devices=['/gpu:0', '/gpu:1']) + + estimator = estimator_lib.Estimator( + model_fn=model_fn, model_dir=tempfile.mkdtemp(), params=self.params) + estimator.train(train_input_fn, steps=1) + + self.assertEqual(7.0, estimator.get_variable_value('c')) + + +class ReplicateWithTwoOptimizersTest(test_util.TensorFlowTestCase): + + def model_fn(self, mode, features, labels, params): + c = variable_scope.get_variable( + 'c', + initializer=constant_op.constant(10, dtype=dtypes.float64), + dtype=dtypes.float64) + + side_effects = variable_scope.get_variable( + 'side_effects', + initializer=constant_op.constant(0, dtype=dtypes.float64), + dtype=dtypes.float64, + trainable=False) + + predictions = math_ops.multiply(features, c) + + loss 
= losses.absolute_difference( + labels=labels, predictions=predictions, reduction=losses.Reduction.SUM) + loss = math_ops.reduce_sum(loss) + + metrics = { + 'accuracy': metrics_lib.accuracy(labels, predictions), + 'auc': metrics_lib.auc(labels, predictions) + } + + first_optimizer = replicate_model_fn.TowerOptimizer( + gradient_descent.GradientDescentOptimizer(1.0)) + second_optimizer = replicate_model_fn.TowerOptimizer( + adam.AdamOptimizer(1.0)) + + with ops_lib.control_dependencies([side_effects.assign_add(1.0)]): + first_grads_and_vars = first_optimizer.compute_gradients(loss) + + train_op = control_flow_ops.group( + [first_optimizer.apply_gradients(first_grads_and_vars), + second_optimizer.minimize(loss)]) + + return model_fn_lib.EstimatorSpec( + mode=mode, + loss=loss, + eval_metric_ops=metrics, + predictions={'probabilities': predictions}, + train_op=train_op) + + def test_train(self): + features = np.array([[1.0], [2.0]]) + labels = np.array([[1.0], [2.0]]) + + with self.test_session() as session: + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, + loss_reduction=losses.Reduction.SUM, + devices=['/gpu:0', '/gpu:1']) + estimator_spec = replicated_model_fn(features, labels, + model_fn_lib.ModeKeys.TRAIN, {}) + session.run(variables.global_variables_initializer()) + + # loss = feature * c - label + total_loss = (1.0 * 10 - 1.0) + (2.0 * 10 - 2.0) + self.assertEqual(total_loss, session.run(estimator_spec.loss)) + + # loss' of c is 3. + # new value of c = 10 - learning rate * 3 = 7.0. + # Adam subtracts another ~1. 
+ session.run(estimator_spec.train_op) + with variable_scope.variable_scope('', reuse=True): + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertNear(6.0, session.run(c), 0.000001) + + side_effects = variable_scope.get_variable( + 'side_effects', dtype=dtypes.float64) + self.assertNear(2.0, session.run(side_effects), 0.000001) + + +class ReplicateWithTwoLossesAndOneOptimizer(test_util.TensorFlowTestCase): + + def setUp(self): + self._should_skip_optimizer = False + self._towers_left_before_skipping_optimizer = -1 + + def incorrectly_skip_optimizer_for_tower(self, tower_number): + self._should_skip_optimizer = True + self._towers_left_before_skipping_optimizer = tower_number + + def should_skip_optimizer(self): + if not self._should_skip_optimizer: + return False + if self._towers_left_before_skipping_optimizer == 0: + return True + else: + self._towers_left_before_skipping_optimizer -= 1 + return False + + def model_fn(self, mode, features, labels, params): + c = variable_scope.get_variable( + 'c', + initializer=constant_op.constant(10, dtype=dtypes.float64), + dtype=dtypes.float64) + d = variable_scope.get_variable( + 'd', + initializer=constant_op.constant(2, dtype=dtypes.float64), + dtype=dtypes.float64) + + predictions = math_ops.multiply(features, c) + + loss = losses.absolute_difference( + labels=labels, predictions=predictions, reduction=losses.Reduction.SUM) + loss = math_ops.reduce_sum(loss) + + another_predictions = math_ops.multiply(features, d) + another_loss = losses.absolute_difference( + labels=labels, + predictions=another_predictions, + reduction=losses.Reduction.SUM) + another_loss = math_ops.reduce_sum(another_loss) + + total_loss = math_ops.add(loss, another_loss) + + metrics = { + 'accuracy': metrics_lib.accuracy(labels, predictions), + 'auc': metrics_lib.auc(labels, predictions) + } + + train_ops = [] + + optimizer = replicate_model_fn.TowerOptimizer( + gradient_descent.GradientDescentOptimizer(1.0)) + 
train_ops.append(optimizer.minimize(loss, var_list=[c])) + if not self.should_skip_optimizer(): + another_optimizer = replicate_model_fn.TowerOptimizer( + gradient_descent.GradientDescentOptimizer(1.0)) + train_ops.append(another_optimizer.minimize(another_loss, var_list=[d])) + + train_op = control_flow_ops.group(train_ops) + return model_fn_lib.EstimatorSpec( + mode=mode, + loss=total_loss, + eval_metric_ops=metrics, + predictions={'probabilities': predictions}, + train_op=train_op) + + def test_train(self): + features = np.array([[1.0], [2.0]]) + labels = np.array([[1.0], [2.0]]) + + with self.test_session() as session: + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, + loss_reduction=losses.Reduction.SUM, + devices=['/gpu:0', '/gpu:1']) + estimator_spec = replicated_model_fn(features, labels, + model_fn_lib.ModeKeys.TRAIN, {}) + session.run(variables.global_variables_initializer()) + + # For each tower, loss = (feature * c - label) + (feature * d - label). + total_loss = (1.0 * 10 - 1.0 + 1.0 * 2.0 - 1.0) + ( + 2.0 * 10 - 2.0 + 2.0 * 2.0 - 2.0) + self.assertEqual(total_loss, session.run(estimator_spec.loss)) + + session.run(estimator_spec.train_op) + + # loss' of c or loss' of d is 3. + # new value of c = 10 - learning rate * 3 = 7.0. + # new value of d = 2 - learning rate * 3 = -1.0. 
+ with variable_scope.variable_scope('', reuse=True): + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertNear(7.0, session.run(c), 0.000001) + d = variable_scope.get_variable('d', dtype=dtypes.float64) + self.assertNear(-1.0, session.run(d), 0.000001) + + def test_different_optimizer_calls_within_towers(self): + self.incorrectly_skip_optimizer_for_tower(1) + + features = np.array([[1.0], [2.0]]) + labels = np.array([[1.0], [2.0]]) + + with self.test_session(), ops_lib.Graph().as_default(): + with self.assertRaisesRegexp( + ValueError, '.+was.+supposed.+to.+make.+same.+optimizer.+calls.+'): + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, devices=['/gpu:0', '/gpu:1']) + _ = replicated_model_fn(features, labels, model_fn_lib.ModeKeys.TRAIN, + {}) + + +class FailToWrapOptimizerInTheModelFn(test_util.TensorFlowTestCase): + + def model_fn(self, mode, features, labels, params): + c = variable_scope.get_variable( + 'c', + initializer=constant_op.constant(10, dtype=dtypes.float64), + dtype=dtypes.float64) + + predictions = math_ops.multiply(features, c) + + loss = losses.absolute_difference( + labels=labels, predictions=predictions, reduction=losses.Reduction.SUM) + loss = math_ops.reduce_sum(loss) + + metrics = { + 'accuracy': metrics_lib.accuracy(labels, predictions), + 'auc': metrics_lib.auc(labels, predictions) + } + + optimizer = gradient_descent.GradientDescentOptimizer(1.0) + train_op = optimizer.minimize(loss) + + return model_fn_lib.EstimatorSpec( + mode=mode, + loss=loss, + eval_metric_ops=metrics, + predictions={'probabilities': predictions}, + train_op=train_op) + + def test_train(self): + features = np.array([[1.0], [2.0]]) + labels = np.array([[1.0], [2.0]]) + + with self.test_session(): + with self.assertRaisesRegexp(ValueError, + 'Please.+wrap.+with.+TowerOptimizer'): + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, devices=['/gpu:0', '/gpu:1']) + _ = 
replicated_model_fn(features, labels, model_fn_lib.ModeKeys.TRAIN, + {}) + class GetLossTowersTest(test_util.TensorFlowTestCase): @@ -358,8 +912,9 @@ class GetLossTowersTest(test_util.TensorFlowTestCase): labels=[[0.6], [0.6]], params=None, config=None, + loss_reduction=losses.Reduction.SUM, devices=['/gpu:0', '/gpu:1'], - local_ps_device='/gpu:0', + local_ps_devices=['/gpu:0'], name_scope_pattern='test_tower_{}') session.run(variables.global_variables_initializer()) @@ -382,6 +937,89 @@ class GetLossTowersTest(test_util.TensorFlowTestCase): c = variable_scope.get_variable('c', dtype=dtypes.float64) self.assertEqual(0.25, session.run(c)) + def test_gradients_are_computed_with_mean_reduction(self): + with self.test_session() as session: + tower_specs = replicate_model_fn._get_loss_towers( + self.model_fn, + mode=model_fn_lib.ModeKeys.EVAL, + features=[[0.6], [1.6]], + labels=[[0.6], [0.6]], + params=None, + loss_reduction=losses.Reduction.MEAN, + config=None, + devices=['/gpu:0', '/gpu:1'], + local_ps_devices=['/gpu:0'], + name_scope_pattern='test_tower_{}') + session.run(variables.global_variables_initializer()) + + self.assertEqual(len(tower_specs), 2) + + self.assertEqual('/device:GPU:0', tower_specs[0].loss.device) + self.assertEqual('averaged_loss:0', tower_specs[0].loss.name) + self.assertEqual(0.5, session.run(tower_specs[0].loss)) + + self.assertEqual('/device:GPU:1', tower_specs[1].loss.device) + self.assertEqual('test_tower_1/averaged_loss:0', tower_specs[1].loss.name) + # The input batch for the second tower had a loss that is 1.0 + # bigger: 0.6 vs 1.6. 
+ self.assertEqual(1.0, session.run(tower_specs[1].loss)) + + self.assertEqual(1, len(variables.global_variables())) + self.assertEqual(1, len(variables.trainable_variables())) + + with variable_scope.variable_scope('', reuse=True): + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual(0.25, session.run(c)) + + def test_variables_are_round_robined_correctly(self): + """Test that creates multiple variables and tests round-robin placement.""" + + def model_fn(mode, features, labels, params): + del params + for variable_name in ['a', 'b', 'c', 'd']: + c = variable_scope.get_variable( + variable_name, + initializer=constant_op.constant(0.25, dtype=dtypes.float64), + dtype=dtypes.float64) + + predictions = math_ops.add(np.array([0.1, 0.2, 0.3, features[0]]), c) + labels = np.array([0.1, 0.2, 0.3, labels[0]]) + loss = losses.absolute_difference( + labels=labels, + predictions=predictions, + reduction=losses.Reduction.SUM) + return model_fn_lib.EstimatorSpec( + mode=mode, loss=math_ops.reduce_sum(loss)) + + with self.test_session() as session: + tower_specs = replicate_model_fn._get_loss_towers( + model_fn, + mode=None, + features=[[0.6], [1.6], [2.6]], + labels=[[0.6], [0.6], [2.6]], + params=None, + loss_reduction=losses.Reduction.SUM, + config=None, + devices=['/gpu:0', '/gpu:1', '/gpu:3'], + local_ps_devices=['/gpu:0', '/gpu:1', '/gpu:3'], + name_scope_pattern='test_tower_{}') + session.run(variables.global_variables_initializer()) + + self.assertEqual(len(tower_specs), 3) + self.assertEqual('/device:GPU:0', tower_specs[0].loss.device) + self.assertEqual('/device:GPU:1', tower_specs[1].loss.device) + self.assertEqual('/device:GPU:3', tower_specs[2].loss.device) + + with variable_scope.variable_scope('', reuse=True): + a = variable_scope.get_variable('a', dtype=dtypes.float64) + self.assertEqual('/device:GPU:0', a.device) + b = variable_scope.get_variable('b', dtype=dtypes.float64) + self.assertEqual('/device:GPU:1', b.device) + c = 
variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual('/device:GPU:3', c.device) + d = variable_scope.get_variable('d', dtype=dtypes.float64) + self.assertEqual('/device:GPU:0', d.device) + class SplitBatchTest(test_util.TensorFlowTestCase): @@ -390,8 +1028,13 @@ class SplitBatchTest(test_util.TensorFlowTestCase): return list(map(evaluate_items, first_list)), list( map(evaluate_items, second_list)) + def assertSparseValuesEqual(self, a, b): + self.assertAllEqual(a.indices, b.indices) + self.assertAllEqual(a.values, b.values) + self.assertAllEqual(a.dense_shape, b.dense_shape) + def test_simple_half_split(self): - with self.test_session() as session: # pylint: disable=unused-variable + with self.test_session(): features = [0.0, 1.0, 2.0, 3.0] labels = [10.0, 11.0, 12.0, 13.0] feature_shards, label_shards = replicate_model_fn._split_batch( @@ -404,7 +1047,7 @@ class SplitBatchTest(test_util.TensorFlowTestCase): self.assertAllEqual([[10.0, 11.0], [12.0, 13.0]], label_shards) def test_to_each_their_own(self): - with self.test_session() as session: # pylint: disable=unused-variable + with self.test_session(): features = [0.0, 1.0, 2.0, 3.0] labels = [10.0, 11.0, 12.0, 13.0] feature_shards, label_shards = replicate_model_fn._split_batch( @@ -417,7 +1060,7 @@ class SplitBatchTest(test_util.TensorFlowTestCase): self.assertAllEqual([[10.0], [11.0], [12.0], [13.0]], label_shards) def test_one_batch(self): - with self.test_session() as session: # pylint: disable=unused-variable + with self.test_session(): features = [0.0, 1.0, 2.0, 3.0] labels = [10.0, 11.0, 12.0, 13.0] feature_shards, label_shards = replicate_model_fn._split_batch( @@ -430,7 +1073,7 @@ class SplitBatchTest(test_util.TensorFlowTestCase): self.assertAllEqual([[10.0, 11.0, 12.0, 13.0]], label_shards) def test_half_split_in_dictionary(self): - with self.test_session() as session: # pylint: disable=unused-variable + with self.test_session(): features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': 
[4.0, 5.0, 6.0, 7.0]} labels = [10.0, 11.0, 12.0, 13.0] @@ -444,6 +1087,58 @@ class SplitBatchTest(test_util.TensorFlowTestCase): self.assertAllEqual([10.0, 11.0], label_shards[0].eval()) self.assertAllEqual([12.0, 13.0], label_shards[1].eval()) + def test_sparse_tensor_can_be_split_unevenly(self): + with self.test_session(): + features = { + 'x': + sparse_tensor.SparseTensor( + indices=[[0, 0], [1, 2], [2, 2]], + values=[1.0, 2.0, 3.0], + dense_shape=[3, 4]) + } + labels = np.array([[1.0], [2.0]]) + + feature_shards, label_shards = replicate_model_fn._split_batch( + features, labels, 2, device='/gpu:0') + + self.assertSparseValuesEqual( + sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 2]], values=[1., 2.], dense_shape=[2, 4]), + feature_shards[0]['x'].eval()) + self.assertSparseValuesEqual( + sparse_tensor.SparseTensorValue( + indices=[[0, 2]], values=[3.], dense_shape=[1, 4]), + feature_shards[1]['x'].eval()) + self.assertAllEqual([[1.0]], label_shards[0].eval()) + self.assertAllEqual([[2.0]], label_shards[1].eval()) + + def test_sparse_tensor_can_be_split_unevenly_repeated_row(self): + with self.test_session(): + features = { + 'x': + sparse_tensor.SparseTensor( + indices=[[0, 0], [1, 0], [1, 1]], + values=[1.0, 2.0, 3.0], + dense_shape=[3, 4]) + } + labels = np.array([[1.0], [2.0]]) + + feature_shards, label_shards = replicate_model_fn._split_batch( + features, labels, 2, device='/gpu:0') + + self.assertSparseValuesEqual( + sparse_tensor.SparseTensorValue( + indices=[[0, 0], [1, 0], [1, 1]], + values=[1., 2., 3.], + dense_shape=[2, 4]), feature_shards[0]['x'].eval()) + + second_batch = feature_shards[1]['x'].eval() + self.assertFalse(len(second_batch.indices)) + self.assertFalse(len(second_batch.values)) + self.assertAllEqual([1, 4], second_batch.dense_shape) + self.assertAllEqual([[1.0]], label_shards[0].eval()) + self.assertAllEqual([[2.0]], label_shards[1].eval()) + def test_one_batch_in_dictionary(self): with self.test_session() as session: # 
pylint: disable=unused-variable features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]} @@ -600,11 +1295,12 @@ class PredictSpecTest(test_util.TensorFlowTestCase): self.model_fn, mode=None, features=[[0.1], [0.2]], + loss_reduction=losses.Reduction.SUM, labels=[[], []], params=None, config=None, devices=['/gpu:0', '/gpu:1'], - local_ps_device='/gpu:0', + local_ps_devices=['/gpu:0'], ) session.run(variables.global_variables_initializer()) @@ -718,16 +1414,14 @@ class ReduceMetricVariablesTest(test_util.TensorFlowTestCase): variables.variables_initializer( ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))) - with self.assertRaisesRegexp(ValueError, ''): + with self.assertRaisesRegexp( + ValueError, '.+Expected.+local.+variables.+but.+got.+instead.+'): session.run( replicate_model_fn._reduce_metric_variables(number_of_towers=3)) class MergeExportOutputsTest(test_util.TensorFlowTestCase): - def optimizer_fn(self): - return gradient_descent.GradientDescentOptimizer(1.0) - def model_fn(self, mode, features, labels, params): c = variable_scope.get_variable( 'c', @@ -769,7 +1463,6 @@ class MergeExportOutputsTest(test_util.TensorFlowTestCase): loss=math_ops.reduce_sum(loss), eval_metric_ops=metrics, predictions=predictions, - train_op=loss, # This train_op isn't actually used. 
export_outputs=export_outputs) def replicate_estimator_spec(self, session): @@ -777,13 +1470,13 @@ class MergeExportOutputsTest(test_util.TensorFlowTestCase): labels = np.array([0.01, 0.02]) replicated_model_fn = replicate_model_fn.replicate_model_fn( - self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1']) + self.model_fn, devices=['/gpu:0', '/gpu:1']) estimator_spec = replicated_model_fn(features, labels, model_fn_lib.ModeKeys.PREDICT, {}) session.run(variables.global_variables_initializer()) return estimator_spec - def test_merde_predict_output(self): + def test_merge_predict_output(self): with self.test_session() as session: estimator_spec = self.replicate_estimator_spec(session) self.assertAllClose( @@ -850,25 +1543,66 @@ class GetLocalDevicesTest(test_util.TensorFlowTestCase): class LocalDeviceSetterTest(test_util.TensorFlowTestCase): def test_vars_are_on_ps_but_ops_are_on_workers(self): + ps_devices = ['/device:GPU:3'] + round_robin = device_setter._RoundRobinStrategy(num_tasks=len(ps_devices)) + local_device_setter = replicate_model_fn._local_device_setter( - ps_device='/device:GPU:3', worker_device='/device:GPU:2') + ps_devices=ps_devices, + ps_strategy=round_robin, + worker_device='/device:GPU:2') with ops_lib.device(local_device_setter): - c = variables.Variable(0.01) + a = variables.Variable(0.01) + self.assertEqual('/device:GPU:3', a.device) + + b = variables.Variable(0.02) + self.assertEqual('/device:GPU:3', b.device) + + c = variables.Variable(0.03) self.assertEqual('/device:GPU:3', c.device) - cc = variables.Variable(0.02) - self.assertEqual('/device:GPU:3', cc.device) + a_op = array_ops.concat(a, axis=0) + self.assertEqual('/device:GPU:2', a_op.device) - ccc = variables.Variable(0.03) - self.assertEqual('/device:GPU:3', ccc.device) + b_op = array_ops.concat(b, axis=0) + self.assertEqual('/device:GPU:2', b_op.device) + + def test_round_robin_placement(self): + ps_devices = [ + '/device:GPU:0', '/device:GPU:1', '/device:GPU:3', 
'/device:GPU:4' + ] + round_robin = device_setter._RoundRobinStrategy(num_tasks=len(ps_devices)) + + local_device_setter = replicate_model_fn._local_device_setter( + ps_devices=ps_devices, + ps_strategy=round_robin, + worker_device='/device:GPU:2') + + with ops_lib.device(local_device_setter): + a = variables.Variable(0.01) + self.assertEqual('/device:GPU:0', a.device) + + b = variables.Variable(0.02) + self.assertEqual('/device:GPU:1', b.device) + + c = variables.Variable(0.03) + self.assertEqual('/device:GPU:3', c.device) + + a_op = array_ops.concat(a, axis=0) + self.assertEqual('/device:GPU:2', a_op.device) + + b_op = array_ops.concat(b, axis=0) + self.assertEqual('/device:GPU:2', b_op.device) + + c = variables.Variable(0.03) + self.assertEqual('/device:GPU:4', c.device) + + d = variables.Variable(0.03) + self.assertEqual('/device:GPU:0', d.device) c_op = array_ops.concat(c, axis=0) self.assertEqual('/device:GPU:2', c_op.device) - cc_op = array_ops.concat(cc, axis=0) - self.assertEqual('/device:GPU:2', cc_op.device) - class ComputeSumWithDevicePlacementTest(test_util.TensorFlowTestCase): @@ -939,7 +1673,7 @@ class ComputeSumWithDevicePlacementTest(test_util.TensorFlowTestCase): dense_shape=constant_op.constant([2])) b = ops_lib.IndexedSlices(constant_op.constant([3.0, 4.0]), [0, 1]) - with self.assertRaisesRegexp(ValueError, ''): + with self.assertRaisesRegexp(ValueError, '.+name.+not.+expected.+'): _ = replicate_model_fn._compute_sum_on_device( [a, b], device='/device:GPU:0', name='cant_name_indexslices') diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD index fe86a20ab1f69a0eaf9d7486142451dac6337274..180f1b68f3b56113dfbbfc100bd04efc3bb8b31f 100644 --- a/tensorflow/contrib/factorization/BUILD +++ b/tensorflow/contrib/factorization/BUILD @@ -221,6 +221,7 @@ py_test( name = "kmeans_test", size = "medium", srcs = ["python/ops/kmeans_test.py"], + shard_count = 4, srcs_version = "PY2AND3", tags = ["notsan"], # b/67512932 
deps = [ diff --git a/tensorflow/contrib/factorization/examples/BUILD b/tensorflow/contrib/factorization/examples/BUILD index 363baa121ab3854a802ca3606e35597d31b35a57..bbe842bd5ccc7357805adda1df42ba8799fcd8f2 100644 --- a/tensorflow/contrib/factorization/examples/BUILD +++ b/tensorflow/contrib/factorization/examples/BUILD @@ -21,3 +21,14 @@ tf_py_test( ], tags = ["notsan"], ) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), +) diff --git a/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc b/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc index 31d08bfb65ea49e1378ffba480771d38ce16abec..a8c5d0763c28ba2b54f217405f0da65533f26b91 100644 --- a/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc +++ b/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc @@ -57,11 +57,11 @@ typedef Eigen::Map< class MaskedMatmulOp : public OpKernel { public: - explicit MaskedMatmulOp(OpKernelConstruction* context) - : OpKernel(context) { - OP_REQUIRES_OK(context, context->MatchSignature( - {DT_FLOAT, DT_FLOAT, DT_INT64, DT_BOOL, DT_BOOL}, - {DT_FLOAT})); + explicit MaskedMatmulOp(OpKernelConstruction* context) : OpKernel(context) { + OP_REQUIRES_OK( + context, + context->MatchSignature( + {DT_FLOAT, DT_FLOAT, DT_INT64, DT_BOOL, DT_BOOL}, {DT_FLOAT})); } void Compute(OpKernelContext* context) override { @@ -110,12 +110,11 @@ class MaskedMatmulOp : public OpKernel { num_nonzero_elements, 2); Tensor* prod_values_tensor; - OP_REQUIRES_OK(context, - context->allocate_output( - 0, TensorShape({num_nonzero_elements}), - &prod_values_tensor)); - EigenMatFloatMap prod_values(prod_values_tensor->vec().data(), - 1, num_nonzero_elements); + OP_REQUIRES_OK(context, context->allocate_output( + 0, TensorShape({num_nonzero_elements}), + &prod_values_tensor)); + EigenMatFloatMap prod_values(prod_values_tensor->vec().data(), 1, + num_nonzero_elements); auto get_a_index = [&indices_mat, 
&a_dim_0](int64 i) { int64 a_index = internal::SubtleMustCopy(indices_mat(i, 0)); @@ -182,8 +181,8 @@ class MaskedMatmulOp : public OpKernel { } }; // Shard the work. - worker_threads.workers->ParallelFor( - num_nonzero_elements, cost_per_unit, work); + worker_threads.workers->ParallelFor(num_nonzero_elements, cost_per_unit, + work); } }; REGISTER_KERNEL_BUILDER(Name("MaskedMatmul").Device(DEVICE_CPU), diff --git a/tensorflow/contrib/factorization/python/ops/clustering_ops.py b/tensorflow/contrib/factorization/python/ops/clustering_ops.py index 96cc80ce241347ebca5b68140f1b1c8b9898ae72..23137e0a973c0bdd2cdbd97159f7fd310178bf54 100644 --- a/tensorflow/contrib/factorization/python/ops/clustering_ops.py +++ b/tensorflow/contrib/factorization/python/ops/clustering_ops.py @@ -192,11 +192,11 @@ class KMeans(object): # Computes Euclidean distance. Note the first and third terms are # broadcast additions. squared_distance = ( - math_ops.reduce_sum(math_ops.square(inp), 1, keep_dims=True) - + math_ops.reduce_sum(math_ops.square(inp), 1, keepdims=True) - 2 * math_ops.matmul(inp, clusters, transpose_b=True) + array_ops.transpose( math_ops.reduce_sum( - math_ops.square(clusters), 1, keep_dims=True))) + math_ops.square(clusters), 1, keepdims=True))) output.append(squared_distance) return output @@ -261,8 +261,8 @@ class KMeans(object): inp, clusters, 1) if self._distance_metric == COSINE_DISTANCE: distances *= 0.5 - output.append((score, array_ops.squeeze(distances), - array_ops.squeeze(indices))) + output.append((score, array_ops.squeeze(distances, [-1]), + array_ops.squeeze(indices, [-1]))) return zip(*output) def _clusters_l2_normalized(self): diff --git a/tensorflow/contrib/factorization/python/ops/gmm.py b/tensorflow/contrib/factorization/python/ops/gmm.py index 0d67e09f8151b48c97094b6b48f26e63443707ef..b2dfe48b2dbe0ec0975f865bba95a7ceba0f590c 100644 --- a/tensorflow/contrib/factorization/python/ops/gmm.py +++ b/tensorflow/contrib/factorization/python/ops/gmm.py @@ -24,17 
+24,16 @@ import numpy as np from tensorflow.contrib import framework from tensorflow.contrib.factorization.python.ops import gmm_ops from tensorflow.contrib.framework.python.framework import checkpoint_utils -from tensorflow.contrib.framework.python.ops import variables from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import model_fn as model_fn_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import logging_ops as logging -from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops.control_flow_ops import with_dependencies from tensorflow.python.training import session_run_hook +from tensorflow.python.training import training_util def _streaming_sum(scalar_tensor): @@ -70,8 +69,8 @@ class _InitializeClustersHook(session_run_hook.SessionRunHook): class GMM(estimator.Estimator): """An estimator for GMM clustering.""" SCORES = 'scores' + LOG_LIKELIHOOD = 'loss' ASSIGNMENTS = 'assignments' - ALL_SCORES = 'all_scores' def __init__(self, num_clusters, @@ -113,10 +112,7 @@ class GMM(estimator.Estimator): yield result[GMM.ASSIGNMENTS] def score(self, input_fn=None, batch_size=None, steps=None): - """Predict total sum of distances to nearest clusters. - - Note that this function is different from the corresponding one in sklearn - which returns the negative of the sum of distances. + """Predict total log-likelihood. Args: input_fn: see predict. @@ -124,11 +120,11 @@ class GMM(estimator.Estimator): steps: see predict. Returns: - Total sum of distances to nearest clusters. + Total log-likelihood. 
""" results = self.evaluate(input_fn=input_fn, batch_size=batch_size, steps=steps) - return np.sum(results[GMM.SCORES]) + return np.log(np.sum(np.exp(results[GMM.SCORES]))) def weights(self): """Returns the cluster weights.""" @@ -158,26 +154,26 @@ class GMM(estimator.Estimator): def _model_fn(features, labels, mode, config): """Model function.""" assert labels is None, labels - (all_scores, + (loss, + scores, model_predictions, - losses, training_op, + training_op, init_op, is_initialized) = gmm_ops.gmm(self._parse_tensor_or_dict(features), self._training_initial_clusters, self._num_clusters, self._random_seed, self._covariance_type, self._params) - incr_step = state_ops.assign_add(variables.get_global_step(), 1) - loss = math_ops.reduce_sum(losses) + incr_step = state_ops.assign_add(training_util.get_global_step(), 1) training_op = with_dependencies([training_op, incr_step], loss) training_hooks = [_InitializeClustersHook( init_op, is_initialized, config.is_chief)] predictions = { - GMM.ALL_SCORES: all_scores[0], GMM.ASSIGNMENTS: model_predictions[0][0], } eval_metric_ops = { - GMM.SCORES: _streaming_sum(loss), + GMM.SCORES: scores, + GMM.LOG_LIKELIHOOD: _streaming_sum(loss), } return model_fn_lib.ModelFnOps(mode=mode, predictions=predictions, eval_metric_ops=eval_metric_ops, diff --git a/tensorflow/contrib/factorization/python/ops/gmm_ops.py b/tensorflow/contrib/factorization/python/ops/gmm_ops.py index a61681c7f5a69a0fff1089404fc80b95c1c3106e..98d6434f4752b224201e38bed05ccd14428a758b 100644 --- a/tensorflow/contrib/factorization/python/ops/gmm_ops.py +++ b/tensorflow/contrib/factorization/python/ops/gmm_ops.py @@ -21,7 +21,6 @@ from __future__ import division from __future__ import print_function import numpy as np -from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -36,7 +35,6 @@ from tensorflow.python.ops import state_ops from 
tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.ops.embedding_ops import embedding_lookup -from tensorflow.python.summary import summary # Machine epsilon. MEPS = np.finfo(float).eps @@ -253,14 +251,16 @@ class GmmAlgorithm(object): return ret def scores(self): - """Returns the distances to each class. + """Returns the per-sample likelihood fo the data. Returns: - A tuple with two Tensors. The first contains the distance to - each class. The second contains the distance to the assigned - class. + Log probabilities of each data point. """ - return (self._all_scores, self._scores) + return self._scores + + def log_likelihood_op(self): + """Returns the log-likelihood operation.""" + return self._log_likelihood_op def _define_graph(self, data): """Define graph for a single iteration. @@ -276,7 +276,8 @@ class GmmAlgorithm(object): self._define_expectation_operation(shard_id) self._define_partial_maximization_operation(shard_id, shard) self._define_maximization_operation(len(data)) - self._define_distance_to_clusters(data) + self._define_loglikelihood_operation() + self._define_score_samples() def _define_full_covariance_probs(self, shard_id, shard): """Defines the full covariance probabilties per example in a class. @@ -440,50 +441,20 @@ class GmmAlgorithm(object): state_ops.assign( self._covs, new_covs, validate_shape=False)) - def _define_distance_to_clusters(self, data): - """Defines the Mahalanobis distance to the assigned Gaussian.""" - # TODO(xavigonzalvo): reuse (input - mean) * cov^-1 * (input - - # mean) from log probability function. 
- self._all_scores = [] - for shard in data: - all_scores = [] - shard = array_ops.expand_dims(shard, 0) - for c in xrange(self._num_classes): - if self._covariance_type == FULL_COVARIANCE: - cov = self._covs[c, :, :] - elif self._covariance_type == DIAG_COVARIANCE: - cov = array_ops.diag(self._covs[c, :]) - inverse = linalg_ops.matrix_inverse(cov + self._min_var) - inv_cov = array_ops.tile( - array_ops.expand_dims(inverse, 0), - array_ops.stack([self._num_examples, 1, 1])) - diff = array_ops.transpose(shard - self._means[c, :, :], perm=[1, 0, 2]) - m_left = math_ops.matmul(diff, inv_cov) - all_scores.append( - math_ops.sqrt( - math_ops.matmul( - m_left, array_ops.transpose( - diff, perm=[0, 2, 1])))) - self._all_scores.append( - array_ops.reshape( - array_ops.concat(all_scores, 1), - array_ops.stack([self._num_examples, self._num_classes]))) - - # Distance to the associated class. - self._all_scores = array_ops.concat(self._all_scores, 0) - assignments = array_ops.concat(self.assignments(), 0) - rows = math_ops.to_int64(math_ops.range(0, self._num_examples)) - indices = array_ops.concat( - [array_ops.expand_dims(rows, 1), array_ops.expand_dims(assignments, 1)], - 1) - self._scores = array_ops.gather_nd(self._all_scores, indices) - def _define_loglikelihood_operation(self): """Defines the total log-likelihood of current iteration.""" - self._ll_op = [] + op = [] for prior_probs in self._prior_probs: - self._ll_op.append(math_ops.reduce_sum(math_ops.log(prior_probs))) - summary.scalar('ll', math_ops.reduce_sum(self._ll_op)) + op.append(math_ops.reduce_logsumexp(prior_probs)) + self._log_likelihood_op = math_ops.reduce_logsumexp(op) + + def _define_score_samples(self): + """Defines the likelihood of each data sample.""" + op = [] + for shard_id, prior_probs in enumerate(self._prior_probs): + op.append(prior_probs + math_ops.log(self._w[shard_id])) + self._scores = array_ops.squeeze( + math_ops.reduce_logsumexp(op, axis=2, keep_dims=True), axis=0) def gmm(inp, @@ 
-511,14 +482,9 @@ def gmm(inp, Returns: Note: tuple of lists returned to be consistent with skflow A tuple consisting of: - all_scores: A matrix (or list of matrices) of dimensions (num_input, - num_clusters) where the value is the distance of an input vector and a - cluster center. assignments: A vector (or list of vectors). Each element in the vector corresponds to an input row in 'inp' and specifies the cluster id corresponding to the input. - scores: Similar to assignments but specifies the distance to the - assigned cluster instead. training_op: an op that runs an iteration of training. init_op: an op that runs the initialization. """ @@ -532,6 +498,7 @@ def gmm(inp, gmm_tool = GmmAlgorithm(inp, num_clusters, initial_means, params, covariance_type, random_seed) assignments = gmm_tool.assignments() - all_scores, scores = gmm_tool.scores() - return ([all_scores], [assignments], [scores], gmm_tool.training_ops(), + scores = gmm_tool.scores() + loss = gmm_tool.log_likelihood_op() + return (loss, scores, [assignments], gmm_tool.training_ops(), gmm_tool.init_ops(), gmm_tool.is_initialized()) diff --git a/tensorflow/contrib/factorization/python/ops/gmm_ops_test.py b/tensorflow/contrib/factorization/python/ops/gmm_ops_test.py index c50e82db8a230012ba13c1d7ad7e28c23bd27355..888c3c238c2654ea11ea3bf8270d6c3fcd951a03 100644 --- a/tensorflow/contrib/factorization/python/ops/gmm_ops_test.py +++ b/tensorflow/contrib/factorization/python/ops/gmm_ops_test.py @@ -122,17 +122,23 @@ class GmmOpsTest(test.TestCase): g.seed = 5 with self.test_session() as sess: data = constant_op.constant(self.data, dtype=dtypes.float32) - _, assignments, _, training_op, init_op, _ = gmm_ops.gmm( + loss_op, scores, assignments, training_op, init_op, _ = gmm_ops.gmm( data, 'random', num_classes, random_seed=self.seed) variables.global_variables_initializer().run() sess.run(init_op) + first_loss = sess.run(loss_op) for _ in xrange(self.iterations): sess.run(training_op) assignments = 
sess.run(assignments) + end_loss = sess.run(loss_op) + scores = sess.run(scores) + self.assertEqual((self.num_examples, 1), scores.shape) accuracy = np.mean( np.asarray(self.true_assignments) == np.squeeze(assignments)) logging.info('Accuracy: %f', accuracy) + logging.info('First loss: %f, end loss: %f', first_loss, end_loss) + self.assertGreater(end_loss, first_loss) self.assertGreater(accuracy, 0.98) def testParams(self): diff --git a/tensorflow/contrib/factorization/python/ops/gmm_test.py b/tensorflow/contrib/factorization/python/ops/gmm_test.py index 7717b47daefce9ff65b1f1e84f671a463cf2e826..00a4734eb6d89cd02484f1c5161366377cc71208 100644 --- a/tensorflow/contrib/factorization/python/ops/gmm_test.py +++ b/tensorflow/contrib/factorization/python/ops/gmm_test.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function import numpy as np -from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib.factorization.python.ops import gmm as gmm_lib from tensorflow.contrib.learn.python.learn.estimators import kmeans @@ -30,12 +29,9 @@ from tensorflow.python.framework import random_seed as random_seed_lib from tensorflow.python.ops import array_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import random_ops -from tensorflow.python.platform import flags from tensorflow.python.platform import test from tensorflow.python.training import queue_runner -FLAGS = flags.FLAGS - class GMMTest(test.TestCase): @@ -64,9 +60,8 @@ class GMMTest(test.TestCase): self.batch_size = self.num_points self.true_centers = self.make_random_centers(self.num_centers, self.num_dims) - self.points, self.assignments, self.scores = self.make_random_points( + self.points, self.assignments = self.make_random_points( self.true_centers, self.num_points) - self.true_score = np.add.reduce(self.scores) # Use initial means from kmeans (just like scikit-learn does). 
clusterer = kmeans.KMeansClustering(num_clusters=self.num_centers) @@ -86,24 +81,7 @@ class GMMTest(test.TestCase): offsets = np.round( np.random.randn(num_points, num_dims).astype(np.float32) * 20) points = centers[assignments] + offsets - means = [ - np.mean( - points[assignments == center], axis=0) - for center in xrange(num_centers) - ] - covs = [ - np.cov(points[assignments == center].T) - for center in xrange(num_centers) - ] - scores = [] - for r in xrange(num_points): - scores.append( - np.sqrt( - np.dot( - np.dot(points[r, :] - means[assignments[r]], - np.linalg.inv(covs[assignments[r]])), points[r, :] - - means[assignments[r]]))) - return (points, assignments, scores) + return (points, assignments) def test_weights(self): """Tests the shape of the weights.""" @@ -136,8 +114,7 @@ class GMMTest(test.TestCase): gmm.fit(input_fn=self.input_fn(), steps=10) score2 = gmm.score(input_fn=self.input_fn(batch_size=self.num_points), steps=1) - self.assertGreater(score1, score2) - self.assertNear(self.true_score, score2, self.true_score * 0.15) + self.assertLess(score1, score2) def test_infer(self): gmm = gmm_lib.GMM(self.num_centers, @@ -149,8 +126,7 @@ class GMMTest(test.TestCase): # Make a small test set num_points = 40 - points, true_assignments, true_offsets = ( - self.make_random_points(clusters, num_points)) + points, true_assignments = self.make_random_points(clusters, num_points) assignments = [] for item in gmm.predict_assignments( @@ -159,11 +135,6 @@ class GMMTest(test.TestCase): assignments = np.ravel(assignments) self.assertAllEqual(true_assignments, assignments) - # Test score - score = gmm.score(input_fn=self.input_fn(points=points, - batch_size=num_points), steps=1) - self.assertNear(score, np.sum(true_offsets), 4.05) - def _compare_with_sklearn(self, cov_type): # sklearn version. 
iterations = 40 diff --git a/tensorflow/contrib/factorization/python/ops/kmeans.py b/tensorflow/contrib/factorization/python/ops/kmeans.py index 9a5413fc3f2642443621b33d325e3d8c893fd6ac..c861cfff544a78617aa1ace730b50c094cf16330 100644 --- a/tensorflow/contrib/factorization/python/ops/kmeans.py +++ b/tensorflow/contrib/factorization/python/ops/kmeans.py @@ -25,6 +25,7 @@ import time from tensorflow.contrib.factorization.python.ops import clustering_ops from tensorflow.python.estimator import estimator from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator.export import export_output from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -32,6 +33,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics from tensorflow.python.ops import state_ops from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.saved_model import signature_constants from tensorflow.python.summary import summary from tensorflow.python.training import session_run_hook from tensorflow.python.training import training_util @@ -141,7 +143,7 @@ class _ModelFn(object): def model_fn(self, features, mode, config): """Model function for the estimator. - Note that this does not take a `1abels` arg. This works, but `input_fn` must + Note that this does not take a `labels` arg. This works, but `input_fn` must return either `features` or, equivalently, `(features, None)`. 
Args: @@ -207,6 +209,15 @@ class _ModelFn(object): training_hooks.append( _LossRelativeChangeHook(loss, self._relative_tolerance)) + export_outputs = { + KMeansClustering.ALL_DISTANCES: + export_output.PredictOutput(all_distances[0]), + KMeansClustering.CLUSTER_INDEX: + export_output.PredictOutput(model_predictions[0]), + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: + export_output.PredictOutput(model_predictions[0]) + } + return model_fn_lib.EstimatorSpec( mode=mode, predictions={ @@ -216,7 +227,8 @@ class _ModelFn(object): loss=loss, train_op=training_op, eval_metric_ops={KMeansClustering.SCORE: metrics.mean(loss)}, - training_hooks=training_hooks) + training_hooks=training_hooks, + export_outputs=export_outputs) # TODO(agarwal,ands): support sharded input. diff --git a/tensorflow/contrib/factorization/python/ops/kmeans_test.py b/tensorflow/contrib/factorization/python/ops/kmeans_test.py index 4709d7942583f1406a3fa0ff3a078d0283872ea6..f9598bfc08c05ea3bba88b3135da0cf2e6bb0c95 100644 --- a/tensorflow/contrib/factorization/python/ops/kmeans_test.py +++ b/tensorflow/contrib/factorization/python/ops/kmeans_test.py @@ -194,15 +194,7 @@ class KMeansTest(KMeansTestBase): score = kmeans.score(input_fn=self.input_fn(batch_size=self.num_points)) self.assertNear(self.true_score, score, self.true_score * 0.01) - def test_infer(self): - kmeans = self._kmeans() - # Make a call to fit to initialize the cluster centers. 
- max_steps = 1 - kmeans.train(input_fn=self.input_fn(), max_steps=max_steps) - clusters = kmeans.cluster_centers() - - # Make a small test set - num_points = 10 + def _infer_helper(self, kmeans, clusters, num_points): points, true_assignments, true_offsets = make_random_points( clusters, num_points) input_fn = self.input_fn(batch_size=num_points, points=points, num_epochs=1) @@ -223,6 +215,17 @@ class KMeansTest(KMeansTestBase): np.sum(np.square(clusters), axis=1, keepdims=True))) self.assertAllClose(transform, true_transform, rtol=0.05, atol=10) + def test_infer(self): + kmeans = self._kmeans() + # Make a call to fit to initialize the cluster centers. + max_steps = 1 + kmeans.train(input_fn=self.input_fn(), max_steps=max_steps) + clusters = kmeans.cluster_centers() + + # Run inference on small datasets. + self._infer_helper(kmeans, clusters, 10) + self._infer_helper(kmeans, clusters, 1) + class KMeansTestMultiStageInit(KMeansTestBase): diff --git a/tensorflow/contrib/feature_column/BUILD b/tensorflow/contrib/feature_column/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..6fc053759c58d30c24657dd22e7d12be46fc7a7e --- /dev/null +++ b/tensorflow/contrib/feature_column/BUILD @@ -0,0 +1,37 @@ +package( + default_visibility = [ + "//tensorflow:internal", + ], +) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "py_test") + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +py_library( + name = "feature_column_py", + srcs = ["__init__.py"], + srcs_version = "PY2AND3", + deps = [ + ":sequential_feature_column", + ], +) + +py_library( + name = "sequential_feature_column", + srcs = ["python/feature_column/sequential_feature_column.py"], + srcs_version = "PY2AND3", + deps = [], +) diff --git a/tensorflow/contrib/feature_column/__init__.py b/tensorflow/contrib/feature_column/__init__.py new file mode 
100644 index 0000000000000000000000000000000000000000..6da7b126931effae9cc97091a27070d7013450d4 --- /dev/null +++ b/tensorflow/contrib/feature_column/__init__.py @@ -0,0 +1,30 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Experimental utilities for tf.feature_column.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.contrib.feature_column.python.feature_column.sequential_feature_column import * + +from tensorflow.python.util.all_util import remove_undocumented +# pylint: enable=unused-import,line-too-long,wildcard-import + +_allowed_symbols = [ +] + +remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column.py b/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column.py new file mode 100644 index 0000000000000000000000000000000000000000..690a44ff4368663306733300a1ea70397fb93e1e --- /dev/null +++ b/tensorflow/contrib/feature_column/python/feature_column/sequential_feature_column.py @@ -0,0 +1,19 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Experimental methods for tf.feature_column sequential input.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function diff --git a/tensorflow/contrib/ffmpeg/BUILD b/tensorflow/contrib/ffmpeg/BUILD index dc5a04a0b15870babbc98cf104e109caf829901c..eccce99071dc1477cf4f3bb152f3304b3b0fc35a 100644 --- a/tensorflow/contrib/ffmpeg/BUILD +++ b/tensorflow/contrib/ffmpeg/BUILD @@ -155,7 +155,10 @@ tf_py_test( data = [ ":test_data", ], - tags = ["manual"], + tags = [ + "manual", + "notap", + ], ) py_library( diff --git a/tensorflow/contrib/ffmpeg/__init__.py b/tensorflow/contrib/ffmpeg/__init__.py index 871dff7bbe4912f0daf2bc184d6b0f12510abee7..daba965a98893b992abdc598ec713f13020d6e91 100644 --- a/tensorflow/contrib/ffmpeg/__init__.py +++ b/tensorflow/contrib/ffmpeg/__init__.py @@ -26,6 +26,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_audio +from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_video from tensorflow.contrib.ffmpeg.ffmpeg_ops import encode_audio from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_video diff --git a/tensorflow/contrib/ffmpeg/decode_audio_op.cc b/tensorflow/contrib/ffmpeg/decode_audio_op.cc index 
92fad70b1f9cc55e0690a3fbb35abcf56aa68f16..5ab57ca4cd413bd92f1576278b22d2602c905309 100644 --- a/tensorflow/contrib/ffmpeg/decode_audio_op.cc +++ b/tensorflow/contrib/ffmpeg/decode_audio_op.cc @@ -44,7 +44,7 @@ const char* kValidFileFormats[] = {"mp3", "mp4", "ogg", "wav"}; void Decode(OpKernelContext* context, const tensorflow::StringPiece& file_contents, const string& file_format, const int32 samples_per_second, - const int32 channel_count) { + const int32 channel_count, const string& stream) { // Write the input data to a temp file. const string temp_filename = io::GetTempFilename(file_format); OP_REQUIRES_OK(context, WriteFile(temp_filename, file_contents)); @@ -54,7 +54,7 @@ void Decode(OpKernelContext* context, std::vector output_samples; Status result = ffmpeg::ReadAudioFile(temp_filename, file_format, samples_per_second, - channel_count, &output_samples); + channel_count, stream, &output_samples); if (result.code() == error::Code::NOT_FOUND) { OP_REQUIRES( context, result.ok(), @@ -99,7 +99,12 @@ void Decode(OpKernelContext* context, */ class DecodeAudioOpV2 : public OpKernel { public: - explicit DecodeAudioOpV2(OpKernelConstruction* context) : OpKernel(context) {} + explicit DecodeAudioOpV2(OpKernelConstruction* context) : OpKernel(context) { + string stream; + if (context->GetAttr("stream", &stream).ok()) { + stream_ = stream; + } + } void Compute(OpKernelContext* context) override { OP_REQUIRES( @@ -153,8 +158,12 @@ class DecodeAudioOpV2 : public OpKernel { errors::InvalidArgument("channel_count must be positive, but got: ", channel_count)); - Decode(context, contents, file_format, samples_per_second, channel_count); + Decode(context, contents, file_format, samples_per_second, channel_count, + stream_); } + + private: + string stream_; }; REGISTER_KERNEL_BUILDER(Name("DecodeAudioV2").Device(DEVICE_CPU), @@ -166,6 +175,7 @@ REGISTER_OP("DecodeAudioV2") .Input("samples_per_second: int32") .Input("channel_count: int32") .Output("sampled_audio: float") + 
.Attr("stream: string = ''") .SetShapeFn([](shape_inference::InferenceContext* c) { const Tensor* channels_tensor = c->input_tensor(3); if (channels_tensor == nullptr) { @@ -237,7 +247,7 @@ class DecodeAudioOp : public OpKernel { const tensorflow::StringPiece file_contents = contents.scalar()(); Decode(context, file_contents, file_format_, samples_per_second_, - channel_count_); + channel_count_, ""); } private: diff --git a/tensorflow/contrib/ffmpeg/decode_audio_op_test.py b/tensorflow/contrib/ffmpeg/decode_audio_op_test.py index 0d7c9cb99e8a5fad4a7ccf86d7253170ace91fd7..3dc663bb6f589d09ed067eae09d7d7dd0c40ec95 100644 --- a/tensorflow/contrib/ffmpeg/decode_audio_op_test.py +++ b/tensorflow/contrib/ffmpeg/decode_audio_op_test.py @@ -33,7 +33,8 @@ class DecodeAudioOpTest(test.TestCase): def _loadFileAndTest(self, filename, file_format, duration_sec, samples_per_second, channel_count, - samples_per_second_tensor=None, feed_dict=None): + samples_per_second_tensor=None, feed_dict=None, + stream=None): """Loads an audio file and validates the output tensor. Args: @@ -49,6 +50,9 @@ class DecodeAudioOpTest(test.TestCase): feed_dict: Used when evaluating the `decode_audio` op. If not provided, will be empty. Useful when providing a placeholder for `samples_per_second_tensor`. + stream: A string specifying which stream from the content file + should be decoded. The default value is '' which leaves the + decision to ffmpeg. 
""" if samples_per_second_tensor is None: samples_per_second_tensor = samples_per_second @@ -62,7 +66,7 @@ class DecodeAudioOpTest(test.TestCase): contents, file_format=file_format, samples_per_second=samples_per_second_tensor, - channel_count=channel_count) + channel_count=channel_count, stream=stream) audio = audio_op.eval(feed_dict=feed_dict or {}) self.assertEqual(len(audio.shape), 2) self.assertNear( @@ -72,6 +76,17 @@ class DecodeAudioOpTest(test.TestCase): 0.1 * audio.shape[0]) self.assertEqual(audio.shape[1], channel_count) + def testStreamIdentifier(self): + # mono_16khz_mp3_32khz_aac.mp4 was generated from: + # ffmpeg -i tensorflow/contrib/ffmpeg/testdata/mono_16khz_mp3.mp4 \ + # -i tensorflow/contrib/ffmpeg/testdata/mono_32khz_aac.mp4 \ + # -strict -2 -map 0:a -map 1:a \ + # tensorflow/contrib/ffmpeg/testdata/mono_16khz_mp3_32khz_aac.mp4 + self._loadFileAndTest('mono_16khz_mp3_32khz_aac.mp4', 'mp4', 2.77, 20000, + 1, stream='0') + self._loadFileAndTest('mono_16khz_mp3_32khz_aac.mp4', 'mp4', 2.77, 20000, + 1, stream='1') + def testMonoMp3(self): self._loadFileAndTest('mono_16khz.mp3', 'mp3', 0.57, 20000, 1) self._loadFileAndTest('mono_16khz.mp3', 'mp3', 0.57, 20000, 2) diff --git a/tensorflow/contrib/ffmpeg/decode_video_op.cc b/tensorflow/contrib/ffmpeg/decode_video_op.cc index d44032968d559bec14722902a4d47d22c46ea4aa..6f8ad486d10a825a277749157d68fa671b9f8d3a 100644 --- a/tensorflow/contrib/ffmpeg/decode_video_op.cc +++ b/tensorflow/contrib/ffmpeg/decode_video_op.cc @@ -102,16 +102,12 @@ REGISTER_OP("DecodeVideo") return Status::OK(); }) .Doc(R"doc( -Processes the contents of an audio file into a tensor using FFmpeg to decode +Processes the contents of an video file into a tensor using FFmpeg to decode the file. -One row of the tensor is created for each channel in the audio file. Each -channel contains audio samples starting at the beginning of the audio and -having `1/samples_per_second` time between them. 
If the `channel_count` is -different from the contents of the file, channels will be merged or created. - -contents: The binary audio file contents, as a string or rank-0 string - tensor. +contents: The binary contents of the video file to decode. This is a + scalar. +output: A rank-4 `Tensor` that has `[frames, height, width, 3]` RGB as output. )doc"); } // namespace ffmpeg diff --git a/tensorflow/contrib/ffmpeg/decode_video_op_test.py b/tensorflow/contrib/ffmpeg/decode_video_op_test.py index 4d1fac4ef8afbf44cd45bae065f8a95b0527079a..b43b6b8919223bd7731209d5423b142601396ea5 100644 --- a/tensorflow/contrib/ffmpeg/decode_video_op_test.py +++ b/tensorflow/contrib/ffmpeg/decode_video_op_test.py @@ -20,11 +20,9 @@ from __future__ import print_function import os.path -import six +import six # pylint: disable=unused-import from tensorflow.contrib import ffmpeg -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops from tensorflow.python.ops import image_ops from tensorflow.python.platform import resource_loader from tensorflow.python.platform import test @@ -32,7 +30,8 @@ from tensorflow.python.platform import test class DecodeVideoOpTest(test.TestCase): - def _loadFileAndTest(self, filename, width, height, frames, bmp_filename, index): + def _loadFileAndTest(self, filename, width, height, frames, bmp_filename, + index): """Loads an video file and validates the output tensor. Args: @@ -40,6 +39,8 @@ class DecodeVideoOpTest(test.TestCase): width: The width of the video. height: The height of the video. frames: The frames of the video. + bmp_filename: The filename for the bmp file. + index: Index location inside the video. 
""" with self.test_session(): path = os.path.join(resource_loader.get_data_files_path(), 'testdata', @@ -48,7 +49,7 @@ class DecodeVideoOpTest(test.TestCase): contents = f.read() bmp_path = os.path.join(resource_loader.get_data_files_path(), 'testdata', - bmp_filename) + bmp_filename) with open(bmp_path, 'rb') as f: bmp_contents = f.read() @@ -58,7 +59,7 @@ class DecodeVideoOpTest(test.TestCase): video_op = ffmpeg.decode_video(contents) video = video_op.eval() self.assertEqual(video.shape, (frames, height, width, 3)) - self.assertAllEqual(video[index,:,:,:], image) + self.assertAllEqual(video[index, :, :, :], image) def testMp4(self): self._loadFileAndTest('small.mp4', 560, 320, 166, 'small_100.bmp', 99) diff --git a/tensorflow/contrib/ffmpeg/default/BUILD b/tensorflow/contrib/ffmpeg/default/BUILD index 949ae9ad9e4b045ee1b5cc82d49c0e7468c2005d..6b455567d766dbe6d380a498bd7f521db27e077b 100644 --- a/tensorflow/contrib/ffmpeg/default/BUILD +++ b/tensorflow/contrib/ffmpeg/default/BUILD @@ -19,6 +19,7 @@ cc_library( ], deps = [ "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", "@protobuf_archive//:protobuf_headers", ], ) diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc index 201774e1d011f35df9c3803f2ed8818cc9b1c1c2..e61221a6b0d34373279a379f356c99c379488182 100644 --- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc +++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc @@ -44,39 +44,43 @@ std::vector FfmpegAudioCommandLine(const string& input_filename, const string& output_filename, const string& input_format_id, int32 samples_per_second, - int32 channel_count) { - return {"-nostats", // No additional progress display. - "-nostdin", // No interactive commands accepted. - "-f", input_format_id, // eg: "mp3" - "-probesize", StrCat(kDefaultProbeSize), "-i", input_filename, - "-loglevel", "info", // Enable verbose logging to support debugging. 
- "-map_metadata", "-1", // Copy global metadata from input to output. - "-vn", // No video recording. - "-ac:a:0", StrCat(channel_count), "-ar:a:0", - StrCat(samples_per_second), - // Output set (in several ways) to signed 16-bit little-endian ints. - "-codec:a:0", "pcm_s16le", "-sample_fmt", "s16", "-f", "s16le", - "-sn", // No subtitle recording. - "-y", // Overwrite output file. - StrCat(output_filename)}; + int32 channel_count, + const string& stream) { + std::vector command({ + "-nostats", // No additional progress display. + "-nostdin", // No interactive commands accepted. + "-f", input_format_id, // eg: "mp3" + "-probesize", StrCat(kDefaultProbeSize), "-i", input_filename, + "-loglevel", "error", // Print errors only. + "-hide_banner", // Skip printing build options, version, etc. + "-map_metadata", "-1", // Copy global metadata from input to output. + "-vn", // No video recording. + "-ac:a:0", StrCat(channel_count), "-ar:a:0", StrCat(samples_per_second), + // Output set (in several ways) to signed 16-bit little-endian ints. + "-codec:a:0", "pcm_s16le", "-sample_fmt", "s16", "-f", "s16le", + "-sn", // No subtitle recording. + "-y" // Overwrite output file. + }); + if (!stream.empty()) { + command.emplace_back("-map"); + command.emplace_back(StrCat("0:", stream)); + } + command.emplace_back(StrCat(output_filename)); + + return command; } std::vector FfmpegVideoCommandLine(const string& input_filename, const string& output_filename) { return {"-nostats", // No additional progress display. "-nostdin", // No interactive commands accepted. - "-i", - input_filename, - "-f", - "image2pipe", - "-probesize", - StrCat(kDefaultProbeSize), - "-loglevel", - "info", // Enable verbose logging to support debugging. - "-vcodec", - "rawvideo", - "-pix_fmt", - "rgb24", + "-i", input_filename, "-f", "image2pipe", "-probesize", + StrCat(kDefaultProbeSize), "-loglevel", + // Info is needed to get the information about stream, etc. 
+ // It is generated to a separate file, not stdout/stderr. + "info", + "-hide_banner", // Skip printing build options, version, etc. + "-vcodec", "rawvideo", "-pix_fmt", "rgb24", "-y", // Overwrite output file. StrCat(output_filename)}; } @@ -121,7 +125,6 @@ bool IsBinaryInstalled(const string& binary_name) { std::transform(args.begin(), args.end(), std::back_inserter(args_chars), [](const string& s) { return const_cast(s.c_str()); }); args_chars.push_back(nullptr); - ::execvp(kFfmpegExecutable, args_chars.data()); // exec only returns on error. const int error = errno; @@ -220,7 +223,8 @@ string BuildWavFile(int32 samples_per_second, int32 channel_count, Status ReadInfoFile(const string& filename, uint32* width, uint32* height, uint32* frames) { string data; - ReadFileToString(Env::Default(), filename, &data); + TF_QCHECK_OK(ReadFileToString(Env::Default(), filename, &data)) + << "Could not read FFmpeg file: " << filename; bool in_output = false; bool in_mapping = false; uint32 frames_value = 0; @@ -305,13 +309,12 @@ Status WriteFile(const string& filename, StringPiece contents) { Status ReadAudioFile(const string& filename, const string& audio_format_id, int32 samples_per_second, int32 channel_count, - std::vector* output_samples) { + const string& stream, std::vector* output_samples) { // Create an argument list. string output_filename = io::GetTempFilename("raw"); const std::vector args = FfmpegAudioCommandLine(filename, output_filename, audio_format_id, - samples_per_second, channel_count); - + samples_per_second, channel_count, stream); // Unfortunately, it's impossible to differentiate an exec failure due to the // binary being missing and an error from the binary's execution. Therefore, // check to see if the binary *should* be available. If not, return an error @@ -365,7 +368,6 @@ Status ReadVideoFile(const string& filename, std::vector* output_data, // Create an argument list. 
const std::vector args = FfmpegVideoCommandLine(filename, output_filename); - // Execute ffmpeg and report errors. pid_t child_pid = ::fork(); if (child_pid < 0) { @@ -377,7 +379,7 @@ Status ReadVideoFile(const string& filename, std::vector* output_data, open(stderr_filename.c_str(), O_RDWR | O_CREAT | O_APPEND, 0600); if (fd < 0) { const int error = errno; - LOG(ERROR) << "FFmpeg stderr file coule not be created: " + LOG(ERROR) << "FFmpeg stderr file could not be created: " << strerror(error); ::_exit(error); } diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc index 85b61b26163d87a10d4e316720b4f633e038bbec..05728b3d37570d06f2f8af67e3b0612d21d07601 100644 --- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc +++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_test.cc @@ -32,10 +32,8 @@ namespace tensorflow { namespace ffmpeg { namespace { -const char kTestWavFilename[] = - "contrib/ffmpeg/testdata/mono_10khz.wav"; -const char kTestMp3Filename[] = - "contrib/ffmpeg/testdata/test_sound1.mp3"; +const char kTestWavFilename[] = "contrib/ffmpeg/testdata/mono_10khz.wav"; +const char kTestMp3Filename[] = "contrib/ffmpeg/testdata/test_sound1.mp3"; // Set to true via a command line flag iff the test is expected to have FFmpeg // installed. 
@@ -139,7 +137,7 @@ TEST(FfmpegLibTest, TestRoundTripWav) { } // namespace ffmpeg } // namespace tensorflow -int main(int argc, char **argv) { +int main(int argc, char** argv) { tensorflow::string usage = tensorflow::ffmpeg::ParseTestFlags(&argc, argv); testing::InitGoogleTest(&argc, argv); if (argc != 1) { diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc index 39e7e90cccf1012eb42261bde55d0dc3b7f278ef..d6c885a32424334bfc28c830e3701f219aa244ee 100644 --- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc +++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc @@ -20,9 +20,8 @@ #include #include - -#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/contrib/ffmpeg/ffmpeg_lib.h b/tensorflow/contrib/ffmpeg/ffmpeg_lib.h index c5ea1432bf8b61c87615074a93a45325371c4c87..a8d5a0dd83fb504b5e6671c3e82dc7d2dd3e6a9b 100644 --- a/tensorflow/contrib/ffmpeg/ffmpeg_lib.h +++ b/tensorflow/contrib/ffmpeg/ffmpeg_lib.h @@ -13,8 +13,8 @@ // limitations under the License. // ============================================================================= -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_FFMPEG_FFMPEG_LIB_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_FFMPEG_FFMPEG_LIB_H_ +#ifndef TENSORFLOW_CONTRIB_FFMPEG_FFMPEG_LIB_H_ +#define TENSORFLOW_CONTRIB_FFMPEG_FFMPEG_LIB_H_ #include #include @@ -42,7 +42,7 @@ Status WriteFile(const string& filename, tensorflow::StringPiece contents); // contain a separate sample for each channel. Frames are ordered by time. 
Status ReadAudioFile(const string& filename, const string& audio_format_id, int32 samples_per_second, int32 channel_count, - std::vector* output_samples); + const string& stream, std::vector* output_samples); // Creates an audio file using ffmpeg in a specific format. The samples are in // [-1.0, 1.0]. If there are multiple channels in the audio then each frame will @@ -61,4 +61,4 @@ Status ReadVideoFile(const string& filename, std::vector* output_data, } // namespace ffmpeg } // namespace tensorflow -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_FFMPEG_DEFAULT_FFMPEG_LIB_H_ +#endif // TENSORFLOW_CONTRIB_FFMPEG_DEFAULT_FFMPEG_LIB_H_ diff --git a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py index 78ead471d2cf9f0654a06dc022d7cc592d14c710..020b5c99c61019254bef0b1dff6bc5901c92758a 100644 --- a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py +++ b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.ffmpeg.ops import gen_decode_audio_op_py +from tensorflow.contrib.ffmpeg.ops import gen_decode_video_op_py from tensorflow.contrib.ffmpeg.ops import gen_encode_audio_op_py from tensorflow.contrib.ffmpeg.ops import gen_decode_video_op_py from tensorflow.contrib.util import loader @@ -30,7 +31,7 @@ _ffmpeg_so = loader.load_op_library( def decode_audio(contents, file_format=None, samples_per_second=None, - channel_count=None): + channel_count=None, stream=None): """Create an op that decodes the contents of an audio file. Note that ffmpeg is free to select the "best" audio track from an mp4. @@ -50,6 +51,9 @@ def decode_audio(contents, file_format=None, samples_per_second=None, `contents` have more than this number, then some channels will be merged or dropped. If `contents` has fewer than this, then additional channels will be created from the existing ones. 
+ stream: A string specifying which stream from the content file + should be decoded, e.g., '0' means the 0-th stream. + The default value is '' which leaves the decision to ffmpeg. Returns: A rank-2 tensor that has time along dimension 0 and channels along @@ -60,7 +64,7 @@ def decode_audio(contents, file_format=None, samples_per_second=None, """ return gen_decode_audio_op_py.decode_audio_v2( contents, file_format=file_format, samples_per_second=samples_per_second, - channel_count=channel_count) + channel_count=channel_count, stream=stream) ops.NotDifferentiable('DecodeAudio') diff --git a/tensorflow/contrib/ffmpeg/testdata/mono_16khz_mp3_32khz_aac.mp4 b/tensorflow/contrib/ffmpeg/testdata/mono_16khz_mp3_32khz_aac.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..2485da86d60837800fbb0b390c440e674de25993 Binary files /dev/null and b/tensorflow/contrib/ffmpeg/testdata/mono_16khz_mp3_32khz_aac.mp4 differ diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD index 5b659ddaa1386736eb8cc05a203ed1827ccd160e..9e5f54f0973eae899ca65e4098358107053cb7d4 100644 --- a/tensorflow/contrib/framework/BUILD +++ b/tensorflow/contrib/framework/BUILD @@ -11,11 +11,12 @@ package(default_visibility = [ ]) load("//tensorflow:tensorflow.bzl", "py_test") -load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs") load("//tensorflow:tensorflow.bzl", "tf_kernel_library") +load("//tensorflow:tensorflow.bzl", "cuda_py_test") +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") tf_custom_op_py_library( name = "framework_py", @@ -31,8 +32,10 @@ tf_custom_op_py_library( "python/ops/arg_scope.py", "python/ops/audio_ops.py", "python/ops/checkpoint_ops.py", + "python/ops/critical_section_ops.py", "python/ops/ops.py", "python/ops/prettyprint_ops.py", + 
"python/ops/script_ops.py", "python/ops/sort_ops.py", "python/ops/variables.py", ], @@ -60,6 +63,7 @@ tf_custom_op_py_library( "//tensorflow/python:math_ops", "//tensorflow/python:platform", "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python:script_ops", "//tensorflow/python:sparse_tensor", "//tensorflow/python:state_ops", "//tensorflow/python:state_ops_gen", @@ -70,6 +74,7 @@ tf_custom_op_py_library( "//tensorflow/python:variable_scope", "//tensorflow/python:variables", "//tensorflow/python/eager:context", + "//tensorflow/python/eager:function", "//third_party/py/numpy", "@six_archive//:six", ], @@ -173,6 +178,21 @@ py_test( ], ) +cuda_py_test( + name = "critical_section_test", + size = "medium", + srcs = ["python/ops/critical_section_test.py"], + additional_deps = [ + "//tensorflow/python:client_testlib", + ":framework_py", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:gradients", + "//tensorflow/python:platform_test", + "//tensorflow/python:resource_variable_ops", + ], +) + py_test( name = "accumulate_n_v2_eager_test", size = "small", diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index 4edc77f86ba786ca547b8d3842e2cf02833fbbac..a49d42cd525434d4ffd4a6bb0d8854dc707b9280 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -53,6 +53,7 @@ See the @{$python/contrib.framework} guide. @@assign_from_values_fn @@create_global_step @@filter_variables +@@fuse_op @@get_global_step @@get_or_create_global_step @@get_local_variables @@ -81,7 +82,15 @@ See the @{$python/contrib.framework} guide. 
@@load_linear_multiclass_bias_initializer @@load_variable_slot_initializer +@@py_func @@sort + +@@get_placeholders + +@@CriticalSection + +@@BoundedTensorSpec +@@TensorSpec """ from __future__ import absolute_import @@ -96,6 +105,9 @@ from tensorflow.contrib.framework.python.ops import * from tensorflow.python.framework.ops import prepend_name_scope from tensorflow.python.framework.ops import strip_name_scope +from tensorflow.python.framework.tensor_spec import BoundedTensorSpec +from tensorflow.python.framework.tensor_spec import TensorSpec + from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = ['nest'] diff --git a/tensorflow/contrib/framework/kernels/zero_initializer_op.cc b/tensorflow/contrib/framework/kernels/zero_initializer_op.cc index 6677dca752f84fc1ba7548b7739df04b7aaf14f7..5bf6b67529579e71a615c27e035111a58d5c02e0 100644 --- a/tensorflow/contrib/framework/kernels/zero_initializer_op.cc +++ b/tensorflow/contrib/framework/kernels/zero_initializer_op.cc @@ -21,8 +21,8 @@ limitations under the License. 
#include "tensorflow/contrib/framework/kernels/zero_initializer_op.h" -#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" namespace tensorflow { @@ -81,8 +81,8 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); #define REGISTER_GPU_KERNELS(T) REGISTER_KERNELS(GPU, T); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS); #undef REGISTER_GPU_KERNELS -#endif // GOOGLE_CUDA +#endif // GOOGLE_CUDA #undef REGISTER_KERNELS -} // namespace tensorflow +} // namespace tensorflow diff --git a/tensorflow/contrib/framework/kernels/zero_initializer_op.h b/tensorflow/contrib/framework/kernels/zero_initializer_op.h index 14c9268efa869ffd48b01dd2add44990ef7a43f8..99389a5ab6aa73c2ab0e522dd0f9fbc7093c8f4a 100644 --- a/tensorflow/contrib/framework/kernels/zero_initializer_op.h +++ b/tensorflow/contrib/framework/kernels/zero_initializer_op.h @@ -29,5 +29,5 @@ struct TensorSetZero { }; } // namespace functor -} // end namespace tensorflow -#endif // TENSORFLOW_CONTRIB_FRAMEWORK_KERNELS_ZERO_INITIALIZER_OP_H_ +} // end namespace tensorflow +#endif // TENSORFLOW_CONTRIB_FRAMEWORK_KERNELS_ZERO_INITIALIZER_OP_H_ diff --git a/tensorflow/contrib/framework/ops/variable_ops.cc b/tensorflow/contrib/framework/ops/variable_ops.cc index 1ee8e1498cf07559fe3db78ef832e2cdf26bea1c..706134ba9a51de6253ba7463b17ff662ea740ed0 100644 --- a/tensorflow/contrib/framework/ops/variable_ops.cc +++ b/tensorflow/contrib/framework/ops/variable_ops.cc @@ -26,8 +26,8 @@ REGISTER_OP("ZeroInitializer") .Attr("T: realnumbertype") .SetAllowsUninitializedInput() .SetShapeFn([](InferenceContext* c) { - c->set_output(0, c->input(0)); - return Status::OK(); + c->set_output(0, c->input(0)); + return Status::OK(); }) .Doc(R"doc( Initialize 'ref' with all zeros. 
This op requires that the tensor is not diff --git a/tensorflow/contrib/framework/python/framework/graph_util.py b/tensorflow/contrib/framework/python/framework/graph_util.py index 6d5cde5c9e118d372a6532bfc593bd08b9e18a7b..49eec3a3f1a0f357ea3adfade51e71cb0f89942d 100644 --- a/tensorflow/contrib/framework/python/framework/graph_util.py +++ b/tensorflow/contrib/framework/python/framework/graph_util.py @@ -133,6 +133,18 @@ def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes, def get_placeholders(graph): """Get placeholders of a graph. + For example: + + ```python + a = tf.placeholder(dtype=tf.float32, shape=[2, 2], name='a') + b = tf.placeholder(dtype=tf.int32, shape=[3, 2], name='b') + + tf.contrib.framework.get_placeholders(tf.get_default_graph()) + # Returns: + # [<tf.Tensor 'a:0' shape=(2, 2) dtype=float32>, + # <tf.Tensor 'b:0' shape=(3, 2) dtype=int32>] + ``` + Args: graph: A tf.Graph. Returns: @@ -150,5 +162,5 @@ def get_placeholders(graph): # The return value (a Tensor) of placeholder() is the # first output of this operation in fact. operations = graph.get_operations() - result = [i.outputs[0] for i in operations if i.type == 'Placeholder'] + result = [i.outputs[0] for i in operations if i.type == "Placeholder"] return result diff --git a/tensorflow/contrib/framework/python/framework/graph_util_test.py b/tensorflow/contrib/framework/python/framework/graph_util_test.py index 0722fafc132c0db2ad621f6f9345185f34c643f5..b8a6d109e19211d271c2b15bac66ddacd38fe395 100644 --- a/tensorflow/contrib/framework/python/framework/graph_util_test.py +++ b/tensorflow/contrib/framework/python/framework/graph_util_test.py @@ -90,8 +90,9 @@ class GetPlaceholdersTest(test.TestCase): with ops.Graph().as_default() as g: placeholders = [array_ops.placeholder(dtypes.float32) for _ in range(5)] results = graph_util.get_placeholders(g) - self.assertEqual(sorted(placeholders, key=lambda x: x._id), # pylint: disable=protected-access - sorted(results, key=lambda x: x._id)) # pylint: disable=protected-access + self.assertEqual( + sorted(placeholders, key=lambda x:
x._id), # pylint: disable=protected-access + sorted(results, key=lambda x: x._id)) # pylint: disable=protected-access if __name__ == '__main__': diff --git a/tensorflow/contrib/framework/python/framework/tensor_util_test.py b/tensorflow/contrib/framework/python/framework/tensor_util_test.py index 2effe8eb26e98caa2707315d5f2e0e530ead31d3..8cdb340f2ddd9b3a7f55c1937ef045f4627e99be 100644 --- a/tensorflow/contrib/framework/python/framework/tensor_util_test.py +++ b/tensorflow/contrib/framework/python/framework/tensor_util_test.py @@ -30,6 +30,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors_impl from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import variables as variables_lib from tensorflow.python.platform import test @@ -77,6 +78,7 @@ class AssertScalarIntTest(test.TestCase): [3, 4], dtype=dtypes.int32)) +@test_util.with_c_api class WithShapeTest(test.TestCase): def _assert_with_shape(self, tensor, expected_value, expected_shape, @@ -213,16 +215,25 @@ class WithShapeTest(test.TestCase): tensor_partial_shape.set_shape([None, 2]) for incompatible_shape in [[0], [1]]: + if ops._USE_C_API: + error_message = "Shapes must be equal rank, but are 2 and 1" + else: + error_message = r"Shapes \(\?, 2\) and \([01],\) are not compatible" self.assertRaisesRegexp( - ValueError, r"Shapes \(\?, 2\) and \([01],\) are not compatible", + ValueError, error_message, tensor_util.with_shape, incompatible_shape, tensor_partial_shape) for incompatible_shape in [[1, 2, 1]]: self.assertRaisesRegexp(ValueError, "Dimensions must be equal", tensor_util.with_shape, incompatible_shape, tensor_partial_shape) for incompatible_shape in [[2, 1]]: + if ops._USE_C_API: + error_message = (r"Dimension 1 in both shapes must be equal, but are " + r"2 and 1. 
Shapes are \[\?,2\] and \[2,1\].") + else: + error_message = r"Shapes \(\?, 2\) and \(2, 1\) are not compatible" self.assertRaisesRegexp( - ValueError, r"Shapes \(\?, 2\) and \(2, 1\) are not compatible", + ValueError, error_message, tensor_util.with_shape, incompatible_shape, tensor_partial_shape) compatible_shape = [2, 2] diff --git a/tensorflow/contrib/framework/python/ops/__init__.py b/tensorflow/contrib/framework/python/ops/__init__.py index 685bb94779762ce46ee342e7e0a182c54be64743..c4976497f5fa95d82e492153b117681f693eaa13 100644 --- a/tensorflow/contrib/framework/python/ops/__init__.py +++ b/tensorflow/contrib/framework/python/ops/__init__.py @@ -22,8 +22,10 @@ from __future__ import print_function # pylint: disable=wildcard-import from tensorflow.contrib.framework.python.ops.arg_scope import * from tensorflow.contrib.framework.python.ops.checkpoint_ops import * +from tensorflow.contrib.framework.python.ops.critical_section_ops import * from tensorflow.contrib.framework.python.ops.ops import * from tensorflow.contrib.framework.python.ops.prettyprint_ops import * +from tensorflow.contrib.framework.python.ops.script_ops import * from tensorflow.contrib.framework.python.ops.sort_ops import * from tensorflow.contrib.framework.python.ops.variables import * # pylint: enable=wildcard-import diff --git a/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py b/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py index 2375ee4f550616ff60d20b87b5773704d8fbbe1e..476528b0dd3df05239d5dc402b466e06dd789985 100644 --- a/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py +++ b/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py @@ -22,6 +22,7 @@ from __future__ import print_function from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import 
math_ops @@ -108,4 +109,3 @@ def _AddNGrad(op, grad): """Same as gradient for AddN. Copies the gradient to all inputs.""" # Not broadcasting. return [grad] * len(op.inputs) - diff --git a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_eager_test.py b/tensorflow/contrib/framework/python/ops/accumulate_n_v2_eager_test.py index 8f44698da851b48abf831e957c80fa1643a58bda..35974b9e21d2d7423777a95a99f51c9cb4b453b2 100644 --- a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_eager_test.py +++ b/tensorflow/contrib/framework/python/ops/accumulate_n_v2_eager_test.py @@ -27,16 +27,11 @@ import numpy as np from tensorflow.contrib.framework.python.ops import accumulate_n_v2 as av2 from tensorflow.python.eager import backprop -from tensorflow.python.eager import context as eager_context -from tensorflow.python.eager import tape from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes as dtypes_lib from tensorflow.python.framework import ops from tensorflow.python.framework import test_util -from tensorflow.python.ops import gradients -from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test diff --git a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py b/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py index b5e9f8df79262635bf579a6bf2260bc40c140c6f..45962098e93acfac414396ddbeaa847701ff2b4b 100644 --- a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py +++ b/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py @@ -22,7 +22,6 @@ import numpy as np from tensorflow.contrib.framework.python.ops import accumulate_n_v2 as av2 -from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes as dtypes_lib from tensorflow.python.framework import ops from tensorflow.python.framework import test_util @@ -31,7 +30,6 @@ from tensorflow.python.ops import variables 
from tensorflow.python.platform import googletest - class AccumulateNV2Test(test_util.TensorFlowTestCase): """Tests of the new, differentiable version of accumulate_n""" @@ -62,8 +60,9 @@ class AccumulateNV2Test(test_util.TensorFlowTestCase): accum_n = av2.accumulate_n_v2(input_vars) sess.run(variables.global_variables_initializer()) accum_n_grad = gradients.gradients(accum_n, input_vars) - self.assertAllEqual(np.repeat(1.0, num_inputs), # d/dx (x + y + ...) = 1 - [g.eval() for g in accum_n_grad]) + self.assertAllEqual( + np.repeat(1.0, num_inputs), # d/dx (x + y + ...) = 1 + [g.eval() for g in accum_n_grad]) # The tests below used to be in a separate class under cwise_ops_test.py, # which did not run in the default test target. @@ -75,8 +74,8 @@ class AccumulateNV2Test(test_util.TensorFlowTestCase): np.random.rand(16, 16, 16, 16).astype(np.float32) for _ in range(20) ] random_tensors = [ - ops.convert_to_tensor( - x, dtype=dtypes_lib.float32) for x in random_arrays + ops.convert_to_tensor(x, dtype=dtypes_lib.float32) + for x in random_arrays ] tf_val = av2.accumulate_n_v2(random_tensors) np_val = random_arrays[0] @@ -95,21 +94,21 @@ class AccumulateNV2Test(test_util.TensorFlowTestCase): with self.assertRaises(ValueError): a = variables.Variable(0.2) b = variables.Variable(0.1) - tf_val = av2.accumulate_n_v2([a,b], shape=[2,2]) # Should be shape=[] + tf_val = av2.accumulate_n_v2([a, b], shape=[2, 2]) # Should be shape=[] def testIncompatibleShapes(self): with self.test_session(): with self.assertRaises(ValueError): - a = variables.Variable(np.array([0.1,0.2])) - b = variables.Variable(np.array([[0.3],[0.4]])) - tf_val = av2.accumulate_n_v2([a,b]) + a = variables.Variable(np.array([0.1, 0.2])) + b = variables.Variable(np.array([[0.3], [0.4]])) + tf_val = av2.accumulate_n_v2([a, b]) def testWrongType(self): with self.test_session(): with self.assertRaises(TypeError): a = variables.Variable(0.2, dtype=np.float32) b = variables.Variable(0.1, dtype=np.float32) - tf_val 
= av2.accumulate_n_v2([a,b], tensor_dtype=np.int32) + tf_val = av2.accumulate_n_v2([a, b], tensor_dtype=np.int32) def testWrongTypeOneInput(self): # Scenario that used to trigger a bug, even when testWrongType() worked diff --git a/tensorflow/contrib/framework/python/ops/arg_scope.py b/tensorflow/contrib/framework/python/ops/arg_scope.py index 2bce00fde2459878a12027bb4d98bd3818bc92a2..409657fe1da0e5540cd2ad6070d86737c039e91f 100644 --- a/tensorflow/contrib/framework/python/ops/arg_scope.py +++ b/tensorflow/contrib/framework/python/ops/arg_scope.py @@ -53,7 +53,8 @@ net = layers.conv2d(net, 256, [5, 5], scope='conv2') ``` - Example of how to use tf.contrib.framework.add_arg_scope to enable your function to be called within an arg_scope later: + Example of how to use tf.contrib.framework.add_arg_scope to enable your + function to be called within an arg_scope later: @tf.contrib.framework.add_arg_scope def conv2d(*args, **kwargs) @@ -65,11 +66,10 @@ from __future__ import print_function from tensorflow.python.util import tf_contextlib from tensorflow.python.util import tf_decorator -__all__ = ['arg_scope', - 'add_arg_scope', - 'current_arg_scope', - 'has_arg_scope', - 'arg_scoped_arguments'] +__all__ = [ + 'arg_scope', 'add_arg_scope', 'current_arg_scope', 'has_arg_scope', + 'arg_scoped_arguments' +] _ARGSTACK = [{}] @@ -172,6 +172,7 @@ def add_arg_scope(func): Returns: A tuple with the decorated function func_with_args(). 
""" + def func_with_args(*args, **kwargs): current_scope = current_arg_scope() current_args = kwargs @@ -180,6 +181,7 @@ def add_arg_scope(func): current_args = current_scope[key_func].copy() current_args.update(kwargs) return func(*args, **current_args) + _add_op(func) setattr(func_with_args, '_key_op', _key_op(func)) return tf_decorator.make_decorator(func, func_with_args) diff --git a/tensorflow/contrib/framework/python/ops/critical_section_ops.py b/tensorflow/contrib/framework/python/ops/critical_section_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..182fec924febb74a23b82b1664d137f033f3b1b4 --- /dev/null +++ b/tensorflow/contrib/framework/python/ops/critical_section_ops.py @@ -0,0 +1,324 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Critical Section object and execution logic.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +# TODO(ebrevdo): Re-enable once CriticalSection is in core. 
+# from tensorflow.core.protobuf import critical_section_pb2 + +from tensorflow.python.eager import context +from tensorflow.python.eager import function +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_resource_variable_ops +from tensorflow.python.util import nest + + +# Graph Keys +CRITICAL_SECTIONS = "critical_sections" +CRITICAL_SECTION_EXECUTIONS = "critical_section_executions" + + +class _ExecutionSignature( + collections.namedtuple("_ExecutionSignature", + ("op", "exclusive_resource_access"))): + """A class storing an `ExecuteInCriticalResource` op and associated attrs.""" + pass + + +class CriticalSection(object): + """Critical section. + + A `CriticalSection` object is a resource in the graph which executes subgraphs + in **serial** order. A common example of a subgraph one may wish to run + exclusively is the one given by the following function: + + ```python + v = resource_variable_ops.ResourceVariable(0.0, name="v") + + def count(): + value = v.read_value() + with tf.control_dependencies([value]): + with tf.control_dependencies([v.assign_add(1)]): + return tf.identity(value) + ``` + + Here, a snapshot of `v` is captured in `value`; and then `v` is updated. + The snapshot value is returned. + + If multiple workers or threads all execute `count` in parallel, there is no + guarantee that access to the variable `v` is atomic at any point within + any thread's calculation of `count`. In fact, even implementing an atomic + counter that guarantees that the user will see each value `0, 1, ...,` is + currently impossible. + + The solution is to ensure any access to the underlying resource `v` is + only processed through a critical section: + + ```python + cs = CriticalSection() + f1 = cs.execute(count) + f2 = cs.execute(count) + output = f1 + f2 + session.run(output) + ``` + The functions `f1` and `f2` will be executed serially, and updates to `v` + will be atomic. 
+ + **NOTES** + + All resource objects, including the critical section and any captured + variables of functions executed on that critical section, will be + colocated to the same device (host and cpu/gpu). + + When using multiple critical sections on the same resources, there is no + guarantee of exclusive access to those resources. This behavior is disallowed + by default (but see the kwarg `exclusive_resource_access`). + + For example, running the same function in two separate critical sections + will not ensure serial execution: + + ```python + v = tf.get_variable("v", initializer=0.0, use_resource=True) + def accumulate(up): + x = v.read_value() + with tf.control_dependencies([x]): + with tf.control_dependencies([v.assign_add(up)]): + return tf.identity(x) + ex1 = CriticalSection().execute( + accumulate, 1.0, exclusive_resource_access=False) + ex2 = CriticalSection().execute( + accumulate, 1.0, exclusive_resource_access=False) + bad_sum = ex1 + ex2 + sess.run(v.initializer) + sess.run(bad_sum) # May return 0.0 + ``` + """ + + def __init__(self, name=None, critical_section_def=None, import_scope=None): + """Creates a critical section.""" + if critical_section_def and name is not None: + raise ValueError("critical_section_def and name are mutually exclusive.") + if critical_section_def: + self._init_from_proto(critical_section_def, import_scope=import_scope) + else: + self._init_from_args(name) + + def _init_from_proto(self, critical_section_def, import_scope): + raise NotImplementedError("Not yet implemented") + # TODO(ebrevdo): Re-enable once CriticalSection is in core. + # assert isinstance( + # critical_section_def, critical_section_pb2.CriticalSectionDef) + # # Create from critical_section_def. 
+ # g = ops.get_default_graph() + # self._handle = g.as_graph_element( + # ops.prepend_name_scope( + # critical_section_def.critical_section_name, + # import_scope=import_scope)) + + def _init_from_args(self, name): + """Initialize the CriticalSection from constructor arguments.""" + with ops.name_scope(name, "CriticalSection", []) as name: + with ops.control_dependencies(None): + # pylint: disable=protected-access + handle_name = ops._name_from_scope_name(name) + container = ops.get_default_graph()._container + # pylint: enable=protected-access + if container is None: + container = "" + self._handle = gen_resource_variable_ops.critical_section_op( + shared_name=handle_name, name=name) + if context.in_graph_mode(): + ops.add_to_collections(CRITICAL_SECTIONS, self) + + @property + def name(self): + return self._handle.op.name + + def execute(self, fn, *args, **kwargs): + """Execute function `fn(*args, **kwargs)` inside the CriticalSection. + + Args: + fn: The function to execute. Must return at least one tensor. + *args: Additional positional arguments to `fn`. + **kwargs: Additional keyword arguments to `fn`. + Several keywords are reserved for `execute`. These are: + + - name; The name to use when creating the execute operation. + - exclusive_resource_access; Whether the resources required by + `fn` should be exclusive to this `CriticalSection`. Default: `True`. + You may want to set this to `False` if you will be accessing a + resource in read-only mode in two different CriticalSections. + + Returns: + The tensors returned from `fn(*args, **kwargs)`. + + Raises: + ValueError: If `fn` attempts to use this `CriticalSection` in any nested + way. + ValueError: If `exclusive_resource_access` is not provided (is `True`) and + another `CriticalSection` has an execution requesting the same + resources as in `*args`, `**kwargs`, and any additionaly captured + inputs in `fn`. 
Note, even if `exclusive_resource_access` is `True`, + if another execution in another `CriticalSection` was created without + `exclusive_resource_access=True`, a `ValueError` will be raised. + """ + name = kwargs.pop("name", None) + exclusive_resource_access = kwargs.pop("exclusive_resource_access", True) + + args = nest.map_structure(ops.convert_to_tensor, args) + with ops.name_scope(name, "critical_section_execute", []): + fn_op = function.make_defun_op(fn, *args, **kwargs) + flat_dtypes = nest.flatten(fn_op.output_dtypes) + flat_shapes = nest.flatten(fn_op.output_shapes) + all_inputs = nest.flatten(args) + fn_op.captured_inputs + if self._handle in all_inputs: + raise ValueError("The function fn attempts to access the " + "CriticalSection in which it would be running. This " + "is illegal and would cause deadlocks. " + "CriticalSection: %s." % self._handle) + + if context.in_graph_mode(): + # Collections and op introspection does not work in eager + # mode. This is generally ok; since eager mode (as of + # writing) executes sequentially anyway. + all_input_resources = [ + x for x in all_inputs if x.dtype == dtypes.resource] + for sg in ops.get_collection(CRITICAL_SECTION_EXECUTIONS): + if sg.op.inputs[0].name == self._handle.name: + # Other executions in the same critical section are allowed. + continue + if not (exclusive_resource_access or sg.exclusive_resource_access): + # Neither execution requested exclusive access. + continue + sg_input_names = [y.name for y in sg.op.inputs[1:]] + for res in all_input_resources: + if res.name in sg_input_names: + raise ValueError( + "This execution would access resource %s; but either this " + "execution (CriticalSection: %s) or Execution '%s' " + "(CriticalSection: %s) requested exclusive resource access " + "of this resource for their critical section. Did you mean " + "to call execute with keyword argument " + "exclusive_resource_access=False?" 
+ % (res.name, + self.name, + sg.op.name, + sg.op.inputs[0].op.name)) + + flat_outputs = gen_resource_variable_ops.execute_in_critical_section( + critical_section=self._handle, + arguments=all_inputs, + f=fn_op, + output_types=flat_dtypes, + output_shapes=flat_shapes) + + if context.in_graph_mode(): + if isinstance(flat_outputs, ops.Operation): + flat_outputs = [flat_outputs] + op = (flat_outputs[0].op if isinstance(flat_outputs[0], ops.Tensor) + else flat_outputs[0]) + signature = _ExecutionSignature( + op=op, + exclusive_resource_access=exclusive_resource_access) + ops.add_to_collections( + CRITICAL_SECTION_EXECUTIONS, signature) + + return (flat_outputs[0] + if (len(flat_outputs) == 1 + and isinstance(flat_outputs[0], ops.Operation)) + else nest.pack_sequence_as(fn_op.output_dtypes, flat_outputs)) + + # TODO(ebrevdo): Re-enable once CriticalSection is in core. + + # def to_proto(self, export_scope=None): + # """Converts a `CriticalSection` to a `CriticalSectoinDef` protocol buffer. + + # Args: + # export_scope: Optional `string`. Name scope to remove. + + # Returns: + # A `CriticalSectionDef` protocol buffer, or `None` if the + # `CriticalSection` is not in the specified name scope. + # """ + # if export_scope is None or self.handle.name.startswith(export_scope): + # cs_def = critical_section_pb2.CriticalSectionDef() + # cs_def.critical_section_name = ops.strip_name_scope( + # self._handle.name, export_scope) + # return cs_def + # else: + # return None + + # @staticmethod + # def from_proto(critical_section_def, import_scope=None): + # return CriticalSection( + # critical_section_def=critical_section_def, import_scope=import_scope) + + +# TODO(ebrevdo): Re-enable once CriticalSection is in core. + +# def _execution_to_proto_fn(execution_signature, export_scope=None): +# """Converts `_ExecutionSignature` to a `CriticalSectionExecutionDef`. + +# Args: +# execution_signature: Instance of `_ExecutionSignature`. +# export_scope: The export scope, if any. 
+ +# Returns: +# An instance of `CriticalSectionExecutionDef`. +# """ +# if (export_scope is None +# or execution_signature.op.name.startswith(export_scope)): +# op_def = critical_section_pb2.CriticalSectionExecutionDef() +# op_def.execute_in_critical_section_name = ops.strip_name_scope( +# execution_signature.op.name, export_scope) +# op_def.exclusive_resource_access = ( +# execution_signature.exclusive_resource_access) +# return op_def +# else: +# return None + + +# def _execution_from_proto_fn(op_def, import_scope=None): +# """Converts a `CriticalSectionExecutionDef` to a `_ExecutionSignature`.""" +# assert isinstance( +# op_def, critical_section_pb2.CriticalSectionExecutionDef) + +# # Create from op_def. +# g = ops.get_default_graph() +# execution_op = g.as_graph_element( +# ops.prepend_name_scope( +# op_def.execute_in_critical_section_name, +# import_scope=import_scope)) +# return _ExecutionSignature( +# op=execution_op, +# exclusive_resource_access=op_def.exclusive_resource_access) + +# ops.register_proto_function( +# CRITICAL_SECTIONS, +# proto_type=critical_section_pb2.CriticalSectionDef, +# to_proto=CriticalSection.to_proto, +# from_proto=CriticalSection.from_proto) + +# ops.register_proto_function( +# CRITICAL_SECTION_EXECUTIONS, +# proto_type=critical_section_pb2.CriticalSectionExecutionDef, +# to_proto=_execution_to_proto_fn, +# from_proto=_execution_from_proto_fn) diff --git a/tensorflow/contrib/framework/python/ops/critical_section_test.py b/tensorflow/contrib/framework/python/ops/critical_section_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a416724d3ba1719471d70667e140f9cd2daf86c7 --- /dev/null +++ b/tensorflow/contrib/framework/python/ops/critical_section_test.py @@ -0,0 +1,178 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""critical section tests.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.framework.python.ops import critical_section_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.platform import test +# TODO(ebrevdo): Re-enable once CriticalSection is in core. 
+# from tensorflow.python.training import saver as saver_lib + + +class CriticalSectionTest(test.TestCase): + + @test_util.run_in_graph_and_eager_modes() + def testCreateCriticalSection(self): + cs = critical_section_ops.CriticalSection(name="cs") + v = resource_variable_ops.ResourceVariable(0.0, name="v") + + def fn(a, b): + c = v.read_value() + with ops.control_dependencies([c]): + nv = v.assign_add(a * b) + with ops.control_dependencies([nv]): + return array_ops.identity(c) + + num_concurrent = 1000 + r = [cs.execute(fn, 1.0, 2.0) for _ in range(num_concurrent)] + self.evaluate(v.initializer) + r_value = self.evaluate(r) + self.assertAllClose([2.0 * i for i in range(num_concurrent)], + sorted(r_value)) + + @test_util.run_in_graph_and_eager_modes() + def testCreateCriticalSectionFnReturnsOp(self): + cs = critical_section_ops.CriticalSection(name="cs") + v = resource_variable_ops.ResourceVariable(0.0, name="v") + + def fn_return_op(a, b): + c = v.read_value() + with ops.control_dependencies([c]): + nv = v.assign_add(a * b) + with ops.control_dependencies([nv]): + return () + + num_concurrent = 100 + r = [cs.execute(fn_return_op, 1.0, 2.0) for _ in range(num_concurrent)] + self.evaluate(v.initializer) + self.evaluate(r) + final_v = self.evaluate(v) + self.assertAllClose(2.0 * num_concurrent, final_v) + + def testCreateCriticalSectionRaw(self): + cs = critical_section_ops.CriticalSection(name="cs") + v = resource_variable_ops.ResourceVariable(0.0, name="v") + + @function.Defun(dtypes.float32, dtypes.float32) + def fn(a, b): + c = v.read_value() + with ops.control_dependencies([c]): + nv = v.assign_add(a * b) + with ops.control_dependencies([nv]): + return array_ops.identity(c) + + def execute(fn, *args): + output_args = fn.definition.signature.output_arg + return resource_variable_ops.execute_in_critical_section( + critical_section=cs._handle, + arguments=list(args) + fn.captured_inputs, + f=fn, + output_types=[out.type for out in output_args], + 
output_shapes=[tensor_shape.TensorShape(None) for _ in output_args]) + + num_concurrent = 1000 + r = [execute(fn, 1.0, 2.0)[0] for _ in range(num_concurrent)] + self.evaluate(v.initializer) + r_value = self.evaluate(r) + self.assertAllClose([2.0 * i for i in range(num_concurrent)], + sorted(r_value)) + + def testCollection(self): + cs = critical_section_ops.CriticalSection(name="cs") + self.assertIn( + cs, ops.get_collection(critical_section_ops.CRITICAL_SECTIONS)) + execute_op = cs.execute(lambda x: x + 1, 1.0).op + self.assertIn( + execute_op, + [signature.op for signature in + ops.get_collection(critical_section_ops.CRITICAL_SECTION_EXECUTIONS)]) + + @test_util.run_in_graph_and_eager_modes() + def testRecursiveCriticalSectionAccessIsIllegal(self): + cs = critical_section_ops.CriticalSection(name="cs") + def fn(x): + return cs.execute(lambda x: x+1, x) + with self.assertRaisesRegexp( + ValueError, + r"attempts to access the CriticalSection in which it would be running"): + cs.execute(fn, 1.0) + + def testMultipleCSExecutionsRequestSameResource(self): + cs0 = critical_section_ops.CriticalSection() + cs1 = critical_section_ops.CriticalSection() + v = resource_variable_ops.ResourceVariable(0.0, name="v") + cs0.execute(lambda: v + 1) + # It's OK for the same CriticalSection to access this resource. + cs0.execute(lambda: v - 1) + # It's *not* OK for a different CriticalSection to access it by + # default. + with self.assertRaisesRegexp( + ValueError, "requested exclusive resource access"): + cs1.execute(lambda: v + 1) + # It's not even OK if the second call doesn't request exclusive access. + with self.assertRaisesRegexp( + ValueError, "requested exclusive resource access"): + cs1.execute(lambda: v + 1, exclusive_resource_access=False) + + v2 = resource_variable_ops.ResourceVariable(0.0, name="v2") + cs0.execute(lambda: v2 + 1, exclusive_resource_access=False) + # It's OK if neither requests exclusive resource access. 
+ cs1.execute(lambda: v2 + 1, exclusive_resource_access=False) + + # It's not OK if the second request requires exclusive resource + # access. + with self.assertRaisesRegexp( + ValueError, "requested exclusive resource access"): + cs1.execute(lambda: v2 + 1) + + # TODO(ebrevdo): Re-enable once CriticalSection is in core. + # + # def testCriticalSectionAndExecuteOpSaverRoundTrip(self): + # cs = critical_section_ops.CriticalSection() + # r = cs.execute(lambda x: x + 1, 1.0) + # graph = ops.get_default_graph() + # meta_graph = saver_lib.export_meta_graph( + # graph=graph, collection_list=graph.get_all_collection_keys()) + # graph_copy = ops.Graph() + # with graph_copy.as_default(): + # _ = saver_lib.import_meta_graph(meta_graph, import_scope="imported") + # restored_cs = ops.get_collection(critical_section_ops.CRITICAL_SECTIONS) + # restored_exec = ops.get_collection( + # critical_section_ops.CRITICAL_SECTION_EXECUTIONS) + # self.assertEqual(1, len(restored_cs)) + # self.assertEqual(1, len(restored_exec)) + # self.assertEqual(restored_cs[0].name, "imported/%s" % cs.name) + # self.assertEqual(restored_exec[0].op.name, "imported/%s" % r.op.name) + + # def testToProto(self): + # cs = critical_section_ops.CriticalSection(name="cs") + # proto = cs.to_proto() + # self.assertEqual(proto.critical_section_name, cs._handle.name) + # cs_copy = critical_section_ops.CriticalSection.from_proto(proto) + # self.assertEqual(cs_copy._handle, cs._handle) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/framework/python/ops/script_ops.py b/tensorflow/contrib/framework/python/ops/script_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..5d269fefdcfae7902b35e0f29f8cd12fcc58b882 --- /dev/null +++ b/tensorflow/contrib/framework/python/ops/script_ops.py @@ -0,0 +1,143 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Script Language Operators. See the @{$python/script_ops} guide. + +@@py_func +""" + +# pylint: disable=g-bad-name +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops.script_ops import py_func as _py_func +from tensorflow.python.util import nest + +__all__ = ['py_func'] + + +def py_func(func, + args=(), + kwargs=None, + output_types=None, + output_shapes=None, + stateful=True, + name=None): + """Wraps a python function and uses it as a TensorFlow op. + + This function is a wrapper around `tf.py_func` and improve it with kwargs + and output_shapes. Further it changed some argument names. + + Given a python function `func`, which takes numpy arrays as its + inputs and returns numpy arrays as its outputs, wrap this function as an + operation in a TensorFlow graph. 
The following snippet constructs a simple + TensorFlow graph that invokes the `np.sinh()` NumPy function as an operation + in the graph: + + ```python + def my_func(x): + # x will be a numpy array with the contents of the placeholder below + return np.sinh(x) + inp = tf.placeholder(tf.float32) + y = tf.py_func(my_func, [inp], tf.float32) + ``` + + + **N.B.** The `tf.py_func()` operation has the following known limitations: + + * The body of the function (i.e. `func`) will not be serialized in a + `GraphDef`. Therefore, you should not use this function if you need to + serialize your model and restore it in a different environment. + + * The operation must run in the same address space as the Python program + that calls `tf.py_func()`. If you are using distributed TensorFlow, you + must run a `tf.train.Server` in the same process as the program that calls + `tf.py_func()` and you must pin the created operation to a device in that + server (e.g. using `with tf.device():`). + + Args: + func: A Python function, which accepts a list of NumPy `ndarray` objects + having element types that match the corresponding `tf.Tensor` objects + in `inp`, and returns a list of `ndarray` objects (or a single `ndarray`) + having element types that match the corresponding values in `Tout`. + args: A list of `Tensor` objects. + kwargs: A dict with `Tensor` objects as values. + output_types: A nested structure of tensorflow data types or a single + tensorflow data type if there is only one, indicating what `func` returns. + output_shapes: Same as output_types, except the types are replaced with + shapes (optional). + stateful: (Boolean.) If True, the function should be considered stateful. + If a function is stateless, when given the same input it will return the + same output and have no observable side effects. Optimizations such as + common subexpression elimination are only performed on stateless + operations. + name: A name for the operation (optional). 
+ + Returns: + Tensorflow op that wraps the input python function. + """ + + if kwargs is None: + kwargs = {} + + if not isinstance(args, (list, tuple)): + raise TypeError('args must be list and not {}. args: {}'.format( + type(args), args)) + + if not isinstance(kwargs, dict): + raise TypeError('kwargs must be dict and not {}. args: {}'.format( + type(kwargs), kwargs)) + + # For dynamic type inference use callable output_types and output_shapes + if callable(output_types): + # If callable assume same signature and call with tensors and get the types + output_types = output_types(*args, **kwargs) + if callable(output_shapes): + # If callable assume same signature and call with tensors and get the shapes + output_shapes = output_shapes(*args, **kwargs) + + flat_output_types = nest.flatten(output_types) + args = (args, kwargs) + flat_args = nest.flatten(args) + + def python_function_wrapper(*py_args): + py_args, py_kwargs = nest.pack_sequence_as(args, py_args) + + ret = func(*py_args, **py_kwargs) + # TODO(alextp): Catch Exceptions and improve msg, because tensorflow + # is not able to preserve the traceback, i.e. the Exception does not + # contain any information where the Exception was raised. 
+ nest.assert_shallow_structure(output_types, ret) + return nest.flatten(ret) + + flat_values = _py_func( + python_function_wrapper, + flat_args, + flat_output_types, + stateful=stateful, + name=name) + + if output_shapes is not None: + # I am not sure if this is necessary + output_shapes = nest.map_structure_up_to( + output_types, tensor_shape.as_shape, output_shapes) + + flattened_shapes = nest.flatten(output_shapes) + for ret_t, shape in zip(flat_values, flattened_shapes): + ret_t.set_shape(shape) + + return nest.pack_sequence_as(output_types, flat_values) diff --git a/tensorflow/contrib/framework/python/ops/variables.py b/tensorflow/contrib/framework/python/ops/variables.py index 07b7857e7b2114d251ebb5c14eda9dff0d55bbef..0754c3e0e30a340910a43a3ce86f6ca10afe848e 100644 --- a/tensorflow/contrib/framework/python/ops/variables.py +++ b/tensorflow/contrib/framework/python/ops/variables.py @@ -25,6 +25,7 @@ import re from tensorflow.contrib.framework.python.ops import add_arg_scope as contrib_add_arg_scope from tensorflow.contrib.framework.python.ops import gen_variable_ops from tensorflow.contrib.util import loader +from tensorflow.core.protobuf import saver_pb2 from tensorflow.python import pywrap_tensorflow from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import dtypes @@ -32,9 +33,8 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import gen_state_ops -from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import resource_loader +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import saver as tf_saver from tensorflow.python.training import training_util from tensorflow.python.util.deprecation import deprecated @@ -441,7 +441,7 @@ def get_unique_variable(var_op_name): 
""" 
candidates = get_variables(scope=var_op_name) if not candidates: - raise ValueError('Couldnt find variable %s' % var_op_name) + raise ValueError('Couldn\'t find variable %s' % var_op_name) for candidate in candidates: if candidate.op.name == var_op_name: @@ -685,7 +685,8 @@ def assign_from_checkpoint_fn(model_path, var_list, ignore_missing_vars=False, 'Variable %s missing in checkpoint %s', var, model_path) var_list = available_vars if var_list: - saver = tf_saver.Saver(var_list, reshape=reshape_variables) + saver = tf_saver.Saver(var_list, reshape=reshape_variables, + write_version=saver_pb2.SaverDef.V1) def callback(session): saver.restore(session, model_path) return callback diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc index 88306094ab9947c9c78b03c0013f6afc88316803..0e06575d96f9b9538f0245b12d48cfd7c0e8d981 100644 --- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc @@ -34,6 +34,7 @@ limitations under the License. #include "tensorflow/core/util/use_cudnn.h" #if GOOGLE_CUDA +#include "cuda/include/cudnn.h" #include "tensorflow/core/kernels/conv_ops_gpu.h" #include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/util/activation_mode.h" @@ -278,6 +279,28 @@ Status TransformNHWCToNCHW(OpKernelContext* ctx, const Tensor& nhwc_tensor, return Status::OK(); } +// Adjusts padding so cudnn supports it. Sets `adjusted_padding` to be the +// adjusted padding, and `extra_padding_before` and `extra_padding_after` to be +// the extra padding that FusedConv needs to apply before calling cudnn. 
+void AdjustPaddingForCudnn(int padding, bool is_int8x4, int filter_size, + int* adjusted_padding, int* extra_padding_before, + int* extra_padding_after) { +#if CUDNN_VERSION < 7000 + if (is_int8x4 && filter_size >= 6) { + // TODO(b/70795525): Remove after NVIDIA fixes this bug with int8 fused + // convolution. I don't know cuDNN7 still has the bug, so enable this + // workaround for cuDNN6 or older. + *adjusted_padding = 0; + *extra_padding_before = padding / 2; + *extra_padding_after = padding - *extra_padding_before; + return; + } +#endif + *adjusted_padding = padding / 2 * 2; + *extra_padding_before = 0; + *extra_padding_after = padding % 2; +} + template void LaunchFusedConv2DBiasActivationOp:: launch(OpKernelContext* ctx, bool cudnn_use_autotune, @@ -303,7 +326,7 @@ void LaunchFusedConv2DBiasActivationOp:: stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major, &cc_minor); OP_REQUIRES( - ctx, cc_major >= 6 && cc_minor >= 1, + ctx, ((cc_major == 6 && cc_minor >= 1) || cc_major > 6), errors::Unimplemented( "FusedConv2DBiasActivation for int8 is only supported on GPUs with " "compute capability 6.1 or later.")); @@ -338,12 +361,21 @@ void LaunchFusedConv2DBiasActivationOp:: 0, (output_rows - 1) * row_stride + filter_rows - conv_input_rows); padding_cols = std::max( 0, (output_cols - 1) * col_stride + filter_cols - conv_input_cols); - const int padding_rows_parity = padding_rows & 1; - const int padding_cols_parity = padding_cols & 1; - if ((padding_rows_parity | padding_cols_parity) != 0) { + int extra_top_padding = 0; + int extra_bottom_padding = 0; + int extra_left_padding = 0; + int extra_right_padding = 0; + AdjustPaddingForCudnn(padding_rows, is_int8x4, filter_rows, &padding_rows, + &extra_top_padding, &extra_bottom_padding); + AdjustPaddingForCudnn(padding_cols, is_int8x4, filter_cols, &padding_cols, + &extra_left_padding, &extra_right_padding); + if (extra_top_padding != 0 || extra_bottom_padding != 0 || + extra_left_padding != 0 || 
extra_right_padding != 0) { Tensor transformed_input; - const int new_conv_input_rows = conv_input_rows + padding_rows_parity; - const int new_conv_input_cols = conv_input_cols + padding_cols_parity; + const int new_conv_input_rows = + conv_input_rows + extra_top_padding + extra_bottom_padding; + const int new_conv_input_cols = + conv_input_cols + extra_left_padding + extra_right_padding; using VectT = typename Int8x4ToInt32::type>::type; auto pad_data_format = is_int8x4 ? FORMAT_NCHW : data_format; @@ -361,8 +393,9 @@ void LaunchFusedConv2DBiasActivationOp:: maybe_padded_conv_input.reinterpret_last_dimension()); functor::PadInput()( - ctx->eigen_device(), conv_input_eigen_tensor, {{0, 0}}, - {{padding_rows_parity, padding_cols_parity}}, + ctx->eigen_device(), conv_input_eigen_tensor, + {{extra_top_padding, extra_left_padding}}, + {{extra_bottom_padding, extra_right_padding}}, padded_conv_input_eigen_tensor, pad_data_format); conv_input = &maybe_padded_conv_input; @@ -439,6 +472,8 @@ void LaunchFusedConv2DBiasActivationOp:: .set_feature_map_count(output_depth) .set_layout(data_layout); dnn::ConvolutionDescriptor conv_desc; + CHECK_EQ(0, padding_rows % 2); + CHECK_EQ(0, padding_cols % 2); conv_desc.set_vertical_filter_stride(row_stride) .set_horizontal_filter_stride(col_stride) .set_zero_padding_height(padding_rows / 2) @@ -493,6 +528,8 @@ void LaunchFusedConv2DBiasActivationOp:: {{conv_input_rows, conv_input_cols}}, output_depth, {{filter_rows, filter_cols}}, + // TODO(yangzihao): Add support for arbitrary dilations for fused conv. 
+ {{1, 1}}, // dilation_rows, dilation_cols {{row_stride, col_stride}}, {{padding_rows, padding_cols}}, conv_input->dtype(), diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h b/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h index dc43af11580ce5fda74ee25da6c151a5b89c7aee..ba52697679dafc239b1dac5562573b3589877a8c 100644 --- a/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_FUSED_CONV_KERNELS_FUSED_CONV_OPS_GPU_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_FUSED_CONV_KERNELS_FUSED_CONV_OPS_GPU_H_ +#ifndef TENSORFLOW_CONTRIB_FUSED_CONV_KERNELS_FUSED_CONV_OPS_GPU_H_ +#define TENSORFLOW_CONTRIB_FUSED_CONV_KERNELS_FUSED_CONV_OPS_GPU_H_ #if GOOGLE_CUDA @@ -30,11 +30,12 @@ class FusedConvParameters : public ConvParameters { public: FusedConvParameters(int64 batch, int64 in_depths, const SpatialArray& in, int64 out_depths, const SpatialArray& filter, - const SpatialArray& stride, const SpatialArray& padding, - DataType dtype, int device_id, bool has_side_input, + const SpatialArray& dilation, const SpatialArray& stride, + const SpatialArray& padding, DataType dtype, + int device_id, bool has_side_input, ActivationMode activation_mode) - : ConvParameters(batch, in_depths, in, out_depths, filter, stride, - padding, dtype, device_id), + : ConvParameters(batch, in_depths, in, out_depths, filter, dilation, + stride, padding, dtype, device_id), activation_mode_(activation_mode), has_side_input_(has_side_input) { hash_code_ = Hash64Combine(hash_code_, has_side_input); @@ -71,4 +72,4 @@ class FusedConvParameters : public ConvParameters { #endif // GOOGLE_CUDA -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_FUSED_CONV_KERNELS_FUSED_CONV_OPS_GPU_H_ 
+#endif // TENSORFLOW_CONTRIB_FUSED_CONV_KERNELS_FUSED_CONV_OPS_GPU_H_ diff --git a/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc index 887ebc5a6c35379476fa1a643c866d38e2b25699..bafd1d59418f0ba47ebbdaabbf06f8e5471fc1a1 100644 --- a/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc +++ b/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc @@ -25,13 +25,6 @@ limitations under the License. namespace tensorflow { -namespace { -// Return the string containing the list of valid activation modes, that can be -// used as an Attr() in REGISTER_OP. -string GetAllActivationModeAttrString() { return "activation_mode: {'Relu'}"; } - -} // namespace - // -------------------------------------------------------------------------- // TODO(pauldonnelly): Add support for double inputs and scales to this Op, @@ -52,6 +45,7 @@ REGISTER_OP("FusedConv2DBiasActivation") .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'") .Attr("filter_format: {'HWIO', 'OIHW', 'OIHW_VECT_I'} = 'HWIO'") .Attr("activation_mode: {'Relu'} = 'Relu'") + .Attr("dilations: list(int) = [1, 1, 1, 1]") .SetShapeFn([](shape_inference::InferenceContext* c) { using shape_inference::ShapeHandle; using shape_inference::DimensionHandle; @@ -151,6 +145,11 @@ REGISTER_OP("FusedConv2DBiasActivation") kernel_height, kernel_width, input_channels % 4 ]` activation_mode: The activation applied to the output. Currently must be "Relu". + dilations: 1-D tensor of length 4. The dilation factor for each dimension + of `input`. If set to k > 1, there will be k-1 skipped cells between + each filter element on that dimension. The dimension order is determined + by the value of `data_format`, see above for details. Dilations in the + batch and depth dimensions must be 1. 
)doc"); } // namespace tensorflow diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_benchmark.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_benchmark.py index a65d4bc50ff796977e8ea7f652b7cbe3fe37f673..96cdd8b1ca4d56d12d38ea961ae73f3a3aa28968 100644 --- a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_benchmark.py +++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_benchmark.py @@ -116,7 +116,7 @@ def build_fused_conv_bias_relu_graph(device, input_shape, filter_shape, strides, for _ in range(1, num_iters): with ops.control_dependencies([fused_out]): # pylint: disable=g-line-too-long - fused_out = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( + fused_out = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( # pylint: disable=line-too-long inp, filt, bias, @@ -166,10 +166,10 @@ class FusedConv2DBiasActivationBenchmark(test.Benchmark): duration = (time.time() - start_time) / num_iters print("%s inputshape:%s filtershape:%s strides:%s padding:%s " - "%d iters: %.8f sec" % - (device, str(input_shape).replace(" ", ""), - str(filter_shape).replace(" ", ""), - str(strides).replace(" ", ""), padding, num_iters, duration)) + "%d iters: %.8f sec" % (device, str(input_shape).replace(" ", ""), + str(filter_shape).replace(" ", ""), + str(strides).replace(" ", ""), padding, + num_iters, duration)) name_template = ( "conv2d_{device}_input_shape_{inputshape}_filter_shape_{filtershape}_" "strides_{strides}_padding_{padding}") diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py index 2a18f3eeecc7e0e69c54b219886a263136f01b2c..bb155aa2496cbafd9f0630d3dffb2ba69395186c 100644 --- a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py +++ 
b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py @@ -658,6 +658,36 @@ def SimulateFusedConv2dBiasActivationInt8(conv_input_scale, conv_input, kernel, class FusedConvInt8Tests(test.TestCase): _test_params = [ + { + "batch_size": 1, + "input_channels": 4, + "output_channels": 4, + "input_height": 8, + "input_width": 8, + "filter_height": 6, + "filter_width": 6, + "vertical_stride": 2, + "horizontal_stride": 2, + "conv_input_scale": 0.002, + "side_input_scale": 0.0, + "bias_scale": 1, + "padding_type": "SAME" + }, + { + "batch_size": 1, + "input_channels": 4, + "output_channels": 4, + "input_height": 6, + "input_width": 6, + "filter_height": 6, + "filter_width": 6, + "vertical_stride": 2, + "horizontal_stride": 2, + "conv_input_scale": 0.002, + "side_input_scale": 0.0, + "bias_scale": 1, + "padding_type": "SAME" + }, { "batch_size": 2, "input_channels": 8, diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD index abe4665caa9b23b5663df48487c6c77d33d15c59..5db34f0f8db93620b8b4a6b71f63b66ac718ee30 100644 --- a/tensorflow/contrib/gan/BUILD +++ b/tensorflow/contrib/gan/BUILD @@ -56,6 +56,7 @@ py_test( srcs = ["python/train_test.py"], srcs_version = "PY2AND3", deps = [ + ":features", ":namedtuples", ":train", "//tensorflow/contrib/framework:framework_py", @@ -82,6 +83,7 @@ py_library( deps = [ ":classifier_metrics", ":eval_utils", + ":sliced_wasserstein", ":summaries", "//tensorflow/python:util", ], @@ -116,7 +118,7 @@ py_library( deps = [ ":clip_weights", ":conditioning_utils", - ":tensor_pool", + ":random_tensor_pool", ":virtual_batchnorm", "//tensorflow/python:util", ], @@ -175,6 +177,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":losses_impl", + ":namedtuples", "//tensorflow/python:util", ], ) @@ -186,6 +189,9 @@ py_test( deps = [ ":tuple_losses", "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:variables", 
"//third_party/py/numpy", ], ) @@ -221,10 +227,10 @@ py_test( ) py_library( - name = "tensor_pool", + name = "random_tensor_pool", srcs = [ - "python/features/python/tensor_pool.py", - "python/features/python/tensor_pool_impl.py", + "python/features/python/random_tensor_pool.py", + "python/features/python/random_tensor_pool_impl.py", ], srcs_version = "PY2AND3", deps = [ @@ -239,11 +245,11 @@ py_library( ) py_test( - name = "tensor_pool_test", - srcs = ["python/features/python/tensor_pool_test.py"], + name = "random_tensor_pool_test", + srcs = ["python/features/python/random_tensor_pool_test.py"], srcs_version = "PY2AND3", deps = [ - ":tensor_pool", + ":random_tensor_pool", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", @@ -393,6 +399,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":eval_utils", + ":namedtuples", "//tensorflow/python:array_ops", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", @@ -502,6 +509,41 @@ py_test( ], ) +py_library( + name = "sliced_wasserstein", + srcs = [ + "python/eval/python/sliced_wasserstein.py", + "python/eval/python/sliced_wasserstein_impl.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:linalg_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn", + "//tensorflow/python:nn_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:script_ops", + "//tensorflow/python:util", + "//third_party/py/numpy", + ], +) + +py_test( + name = "sliced_wasserstein_test", + srcs = ["python/eval/python/sliced_wasserstein_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":sliced_wasserstein", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:random_ops", + "//third_party/py/numpy", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/gan/README.md 
b/tensorflow/contrib/gan/README.md index 4bca0a1d62a2b404c6783c7cfe3b5c67cfc58221..4ead66ca13e74bacc0e4679a8d5c4e0f23d04b69 100644 --- a/tensorflow/contrib/gan/README.md +++ b/tensorflow/contrib/gan/README.md @@ -99,8 +99,8 @@ gan_model = tfgan.gan_model( # Build the GAN loss. gan_loss = tfgan.gan_loss( gan_model, - generator_loss_fn=tfgan_losses.wasserstein_generator_loss, - discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss) + generator_loss_fn=tfgan.losses.wasserstein_generator_loss, + discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss) # Create the train ops, which calculate gradients and apply updates to weights. train_ops = tfgan.gan_train_ops( @@ -161,8 +161,8 @@ gan_model = tfgan.gan_model( # Build the GAN loss and standard pixel loss. gan_loss = tfgan.gan_loss( gan_model, - generator_loss_fn=tfgan_losses.wasserstein_generator_loss, - discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss, + generator_loss_fn=tfgan.losses.wasserstein_generator_loss, + discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss, gradient_penalty=1.0) l1_pixel_loss = tf.norm(gan_model.real_data - gan_model.generated_data, ord=1) @@ -193,8 +193,8 @@ gan_model = tfgan.gan_model( # Build the GAN loss and standard pixel loss. gan_loss = tfgan.gan_loss( gan_model, - generator_loss_fn=tfgan_losses.least_squares_generator_loss, - discriminator_loss_fn=tfgan_losses.least_squares_discriminator_loss) + generator_loss_fn=tfgan.losses.least_squares_generator_loss, + discriminator_loss_fn=tfgan.losses.least_squares_discriminator_loss) l1_pixel_loss = tf.norm(gan_model.real_data - gan_model.generated_data, ord=1) # Modify the loss tuple to include the pixel loss. @@ -223,8 +223,8 @@ gan_model = tfgan.infogan_model( # Build the GAN loss with mutual information penalty. 
gan_loss = tfgan.gan_loss( gan_model, - generator_loss_fn=tfgan_losses.wasserstein_generator_loss, - discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss, + generator_loss_fn=tfgan.losses.wasserstein_generator_loss, + discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss, gradient_penalty=1.0, mutual_information_penalty_weight=1.0) diff --git a/tensorflow/contrib/gan/__init__.py b/tensorflow/contrib/gan/__init__.py index dff361fdc42708ea69999c2def4721f9d49fcf14..f1946c7f925660eae3aaa650c437e03da1f33d6c 100644 --- a/tensorflow/contrib/gan/__init__.py +++ b/tensorflow/contrib/gan/__init__.py @@ -12,7 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TFGAN grouped API. Please see README.md for details and usage.""" +"""TFGAN is a lightweight library for training and evaluating GANs. + +In addition to providing the infrastructure for easily training and evaluating +GANS, this library contains modules for a TFGAN-backed Estimator, +evaluation metrics, features (such as virtual batch normalization), and losses. +Please see README.md for details and usage. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/estimator/__init__.py b/tensorflow/contrib/gan/python/estimator/__init__.py index 8c4a18228039cb4f2c06e0333f4b8408f1f631e9..c9f7bc61b25230e4159cf8cbc7c9cceead0aa706 100644 --- a/tensorflow/contrib/gan/python/estimator/__init__.py +++ b/tensorflow/contrib/gan/python/estimator/__init__.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TFGAN grouped API. Please see README.md for details and usage.""" +"""TFGAN estimator module. 
+ +GANEstimator provides all the infrastructure support of a TensorFlow Estimator +with the feature support of TFGAN. +""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py index 058dc1d1f8cc176dcdb81268da2c4704d7eddc99..082c42eba180917e732bb7890129dfa94bf00fec 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py @@ -59,7 +59,11 @@ _summary_type_map = { class GANEstimator(estimator.Estimator): """An estimator for Generative Adversarial Networks (GANs). - This Estimator is backed by TFGAN. + This Estimator is backed by TFGAN. The network functions follow the TFGAN API + except for one exception: if either `generator_fn` or `discriminator_fn` have + an argument called `mode`, then the tf.Estimator mode is passed in for that + argument. This helps with operations like batch normalization, which have + different train and evaluation behavior. Example: @@ -96,7 +100,7 @@ class GANEstimator(estimator.Estimator): # Generate samples from generator. predictions = np.array([ x for x in gan_estimator.predict(predict_input_fn)]) - ``` + ``` """ def __init__(self, @@ -107,6 +111,7 @@ class GANEstimator(estimator.Estimator): discriminator_loss_fn=None, generator_optimizer=None, discriminator_optimizer=None, + get_hooks_fn=None, add_summaries=None, use_loss_summaries=True, config=None): @@ -137,6 +142,10 @@ class GANEstimator(estimator.Estimator): work. discriminator_optimizer: Same as `generator_optimizer`, but for the discriminator updates. + get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a + list of hooks. These hooks are run on the generator and discriminator + train ops, and can be used to implement the GAN training scheme. + Defaults to `train.get_sequential_train_hooks()`. 
add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`. use_loss_summaries: If `True`, add loss summaries. If `False`, does not. If `None`, uses defaults. @@ -151,7 +160,7 @@ class GANEstimator(estimator.Estimator): else discriminator_optimizer) gan_head = head_lib.gan_head( generator_loss_fn, discriminator_loss_fn, gopt, dopt, - use_loss_summaries) + use_loss_summaries, get_hooks_fn=get_hooks_fn) return _gan_model_fn( features, labels, mode, generator_fn, discriminator_fn, gan_head, add_summaries) @@ -160,11 +169,6 @@ class GANEstimator(estimator.Estimator): model_fn=_model_fn, model_dir=model_dir, config=config) -def _use_check_shapes(real_data): - """Determines whether TFGAN should check Tensor shapes.""" - return isinstance(real_data, ops.Tensor) - - def _gan_model_fn( features, labels, @@ -233,16 +237,18 @@ def _gan_model_fn( def _make_gan_model(generator_fn, discriminator_fn, real_data, generator_inputs, generator_scope, add_summaries, mode): """Make a `GANModel`, and optionally pass in `mode`.""" - # If `generator_fn` has an argument `mode`, pass mode to it. + # If network functions have an argument `mode`, pass mode to it. 
if 'mode' in inspect.getargspec(generator_fn).args: generator_fn = functools.partial(generator_fn, mode=mode) + if 'mode' in inspect.getargspec(discriminator_fn).args: + discriminator_fn = functools.partial(discriminator_fn, mode=mode) gan_model = tfgan_train.gan_model( generator_fn, discriminator_fn, real_data, generator_inputs, generator_scope=generator_scope, - check_shapes=_use_check_shapes(real_data)) + check_shapes=False) if add_summaries: if not isinstance(add_summaries, (tuple, list)): add_summaries = [add_summaries] diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py index e752f0bcccda418b79d4fdabb27807394cbbb425..387a62bd741bd42c03dc1bf70592060c29ccd7a8 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py @@ -54,7 +54,8 @@ def generator_fn(noise_dict, mode): return layers.fully_connected(noise, noise.shape[1].value) -def discriminator_fn(data, _): +def discriminator_fn(data, unused_conditioning, mode): + del unused_conditioning, mode return layers.fully_connected(data, 1) @@ -99,7 +100,6 @@ def mock_head(testcase, expected_generator_inputs, expected_real_data, else: testcase.assertEqual(discriminator_scope_name, gan_model.discriminator_scope.name) - testcase.assertEqual(_or_none(discriminator_fn), gan_model.discriminator_fn) with ops.control_dependencies(assertions): if mode == model_fn_lib.ModeKeys.TRAIN: diff --git a/tensorflow/contrib/gan/python/estimator/python/head_impl.py b/tensorflow/contrib/gan/python/estimator/python/head_impl.py index 204c646e194319c0e63599da0b2a4909ef270ef3..a21358c50bbdb4a1a929b0c5bc322cec4c9923b5 100644 --- a/tensorflow/contrib/gan/python/estimator/python/head_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/head_impl.py @@ -71,7 +71,7 @@ class GANHead(head._Head): # pylint: disable=protected-access def 
__init__(self, generator_loss_fn, discriminator_loss_fn, generator_optimizer, discriminator_optimizer, use_loss_summaries=True, - get_hooks_fn=tfgan_train.get_sequential_train_hooks(), + get_hooks_fn=None, name=None): """`Head` for GAN training. @@ -86,10 +86,12 @@ class GANHead(head._Head): # pylint: disable=protected-access use_loss_summaries: If `True`, add loss summaries. If `False`, does not. If `None`, uses defaults. get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list - of hooks. + of hooks. Defaults to `train.get_sequential_train_hooks()` name: name of the head. If provided, summary and metrics keys will be suffixed by `"/" + name`. """ + if get_hooks_fn is None: + get_hooks_fn = tfgan_train.get_sequential_train_hooks() # TODO(joelshor): Validate inputs. if use_loss_summaries in [True, False]: diff --git a/tensorflow/contrib/gan/python/eval/__init__.py b/tensorflow/contrib/gan/python/eval/__init__.py index bb8046187807d0cc584f7174eb9aac578855c110..f86b8513053a45f9830411f7df2c32d1f36a97b2 100644 --- a/tensorflow/contrib/gan/python/eval/__init__.py +++ b/tensorflow/contrib/gan/python/eval/__init__.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TFGAN grouped API. Please see README.md for details and usage.""" +"""TFGAN evaluation module. + +This module supports techniques such as Inception Score, Frechet Inception +distance, and Sliced Wasserstein distance. +""" # pylint: disable=,wildcard-import,unused-import from __future__ import absolute_import @@ -22,10 +26,12 @@ from __future__ import print_function # Collapse eval into a single namespace. 
from tensorflow.contrib.gan.python.eval.python import classifier_metrics from tensorflow.contrib.gan.python.eval.python import eval_utils +from tensorflow.contrib.gan.python.eval.python import sliced_wasserstein from tensorflow.contrib.gan.python.eval.python import summaries from tensorflow.contrib.gan.python.eval.python.classifier_metrics import * from tensorflow.contrib.gan.python.eval.python.eval_utils import * +from tensorflow.contrib.gan.python.eval.python.sliced_wasserstein import * from tensorflow.contrib.gan.python.eval.python.summaries import * # pylint: enable=wildcard-import,unused-import @@ -33,7 +39,10 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ 'classifier_metrics', + 'sliced_wasserstein_distance', 'summaries', 'eval_utils', -] + classifier_metrics.__all__ + summaries.__all__ + eval_utils.__all__ +] + ( + classifier_metrics.__all__ + sliced_wasserstein.__all__ + + summaries.__all__ + eval_utils.__all__) remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py index bb65f05b5a17e9a872e41d1dcb05aeb3cd6f6f40..fdfabd07c13f689d075ecbb8786d725fa8a62d01 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py @@ -28,6 +28,7 @@ from __future__ import division from __future__ import print_function import functools +import os import sys import tarfile @@ -57,8 +58,10 @@ __all__ = [ 'run_inception', 'inception_score', 'classifier_score', + 'classifier_score_from_logits', 'frechet_inception_distance', 'frechet_classifier_distance', + 'frechet_classifier_distance_from_activations', 'INCEPTION_DEFAULT_IMAGE_SIZE', ] @@ -187,20 +190,34 @@ def get_graph_def_from_resource(filename): return graph_pb2.GraphDef.FromString(resource_loader.load_resource(filename)) -def 
get_graph_def_from_url_tarball(url, filename): - """Get a GraphDef proto from a tarball on the web.""" - def _progress(count, block_size, total_size): - sys.stdout.write('\r>> Downloading %s %.1f%%' % ( - url, float(count * block_size) / float(total_size) * 100.0)) - sys.stdout.flush() - tar_filename, _ = urllib.request.urlretrieve(url, reporthook=_progress) +def get_graph_def_from_url_tarball(url, filename, tar_filename=None): + """Get a GraphDef proto from a tarball on the web. + + Args: + url: Web address of tarball + filename: Filename of graph definition within tarball + tar_filename: Temporary download filename (None = always download) + + Returns: + A GraphDef loaded from a file in the downloaded tarball. + """ + if not (tar_filename and os.path.exists(tar_filename)): + + def _progress(count, block_size, total_size): + sys.stdout.write('\r>> Downloading %s %.1f%%' % + (url, + float(count * block_size) / float(total_size) * 100.0)) + sys.stdout.flush() + + tar_filename, _ = urllib.request.urlretrieve(url, tar_filename, _progress) with tarfile.open(tar_filename, 'r:gz') as tar: proto_str = tar.extractfile(filename).read() return graph_pb2.GraphDef.FromString(proto_str) def _default_graph_def_fn(): - return get_graph_def_from_url_tarball(INCEPTION_URL, INCEPTION_FROZEN_GRAPH) + return get_graph_def_from_url_tarball(INCEPTION_URL, INCEPTION_FROZEN_GRAPH, + os.path.basename(INCEPTION_URL)) def run_inception(images, @@ -222,13 +239,13 @@ def run_inception(images, image_size: Required image width and height. See unit tests for the default values. input_tensor: Name of input Tensor. - output_tensor: Name of output Tensor. This function will compute activations - at the specified layer. Examples include INCEPTION_V3_OUTPUT and - INCEPTION_V3_FINAL_POOL which would result in this function computing + output_tensor: Name or list of output Tensors. This function will compute + activations at the specified layer. 
Examples include INCEPTION_V3_OUTPUT + and INCEPTION_V3_FINAL_POOL which would result in this function computing the final logits or the penultimate pooling layer. Returns: - Logits. + Tensor or Tensors corresponding to computed `output_tensor`. Raises: ValueError: If images are not the correct size. @@ -244,8 +261,14 @@ def run_inception(images, activations = run_image_classifier(images, graph_def, input_tensor, output_tensor) - if array_ops.rank(activations) != 2: - activations = layers.flatten(activations) + if isinstance(activations, list): + for i, activation in enumerate(activations): + if array_ops.rank(activation) != 2: + activations[i] = layers.flatten(activation) + else: + if array_ops.rank(activations) != 2: + activations = layers.flatten(activations) + return activations @@ -257,23 +280,26 @@ def run_image_classifier(tensor, graph_def, input_tensor, tensor: An Input tensor. graph_def: A GraphDef proto. input_tensor: Name of input tensor in graph def. - output_tensor: Name of output tensor in graph def. + output_tensor: A tensor name or list of tensor names in graph def. scope: Name scope for classifier. Returns: - Classifier output. Shape depends on the classifier used, but is often - [batch, classes]. + Classifier output if `output_tensor` is a string, or a list of outputs if + `output_tensor` is a list. Raises: - ValueError: If `image_size` is not `None`, and `tensor` are not the correct - size. + ValueError: If `input_tensor` or `output_tensor` aren't in the graph_def. 
""" input_map = {input_tensor: tensor} - return_elements = [output_tensor] - classifier_output = importer.import_graph_def( - graph_def, input_map, return_elements, name=scope)[0] + is_singleton = isinstance(output_tensor, str) + if is_singleton: + output_tensor = [output_tensor] + classifier_outputs = importer.import_graph_def( + graph_def, input_map, output_tensor, name=scope) + if is_singleton: + classifier_outputs = classifier_outputs[0] - return classifier_output + return classifier_outputs def classifier_score(images, classifier_fn, num_batches=1): @@ -289,6 +315,11 @@ def classifier_score(images, classifier_fn, num_batches=1): which captures how different the network's classification prediction is from the prior distribution over classes. + NOTE: This function consumes images, computes their logits, and then + computes the classifier score. If you would like to precompute many logits for + large batches, use classifier_score_from_logits(), which this method also + uses. + Args: images: Images to calculate the classifier score for. classifier_fn: A function that takes images and produces logits based on a @@ -312,6 +343,34 @@ def classifier_score(images, classifier_fn, num_batches=1): swap_memory=True, name='RunClassifier') logits = array_ops.concat(array_ops.unstack(logits), 0) + + return classifier_score_from_logits(logits) + + +def classifier_score_from_logits(logits): + """Classifier score for evaluating a generative model from logits. + + This method computes the classifier score for a set of logits. This can be + used independently of the classifier_score() method, especially in the case + of using large batches during evaluation where we would like to precompute all + of the logits before computing the classifier score. + + This technique is described in detail in https://arxiv.org/abs/1606.03498. 
In + summary, this function calculates: + + exp( E[ KL(p(y|x) || p(y)) ] ) + + which captures how different the network's classification prediction is from + the prior distribution over classes. + + Args: + logits: Precomputed 2D tensor of logits that will be used to + compute the classifier score. + + Returns: + The classifier score. A floating-point scalar of the same type as the output + of `logits`. + """ logits.shape.assert_has_rank(2) # Use maximum precision for best results. @@ -328,6 +387,7 @@ def classifier_score(images, classifier_fn, num_batches=1): if logits_dtype != dtypes.float64: final_score = math_ops.cast(final_score, logits_dtype) + return final_score @@ -406,6 +466,11 @@ def frechet_classifier_distance(real_images, sample size to compute frechet classifier distance when comparing two generative models. + NOTE: This function consumes images, computes their activations, and then + computes the classifier distance. If you would like to precompute many + activations for real and generated images for large batches, please use + frechet_classifier_distance_from_activations(), which this method also uses. + Args: real_images: Real images to use to compute Frechet Inception distance. generated_images: Generated images to use to compute Frechet Inception @@ -417,7 +482,7 @@ def frechet_classifier_distance(real_images, Returns: The Frechet Inception distance. A floating-point scalar of the same type - as the output of `classifier_fn` + as the output of `classifier_fn`. """ real_images_list = array_ops.split( @@ -436,31 +501,69 @@ def frechet_classifier_distance(real_images, swap_memory=True, name='RunClassifier') - activations_dtype = activations.dtype # Split the activations by the real and generated images. real_a, gen_a = array_ops.split(activations, [num_batches, num_batches], 0) # Ensure the activations have the right shapes. 
real_a = array_ops.concat(array_ops.unstack(real_a), 0) gen_a = array_ops.concat(array_ops.unstack(gen_a), 0) - if activations_dtype != dtypes.float64: - real_a = math_ops.to_double(real_a) - gen_a = math_ops.to_double(gen_a) - real_a.shape.assert_has_rank(2) - gen_a.shape.assert_has_rank(2) + return frechet_classifier_distance_from_activations(real_a, gen_a) + + +def frechet_classifier_distance_from_activations( + real_activations, generated_activations): + """Classifier distance for evaluating a generative model from activations. + + This method computes the Frechet classifier distance from activations of + real images and generated images. This can be used independently of the + frechet_classifier_distance() method, especially in the case of using large + batches during evaluation where we would like to precompute all of the + activations before computing the classifier distance. + + This technique is described in detail in https://arxiv.org/abs/1706.08500. + Given two Gaussian distributions with means m and m_w and covariance matrices + C and C_w, this function calculates + + |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2)) + + which captures how different the distributions of real images and generated + images (or more accurately, their visual features) are. Note that unlike the + Inception score, this is a true distance and utilizes information about real + world images. + + Args: + real_activations: 2D Tensor containing activations of real data. Shape is + [batch_size, activation_size]. + generated_activations: 2D Tensor containing activations of generated data. + Shape is [batch_size, activation_size]. + + Returns: + The Frechet Inception distance. A floating-point scalar of the same type + as the output of the activations. 
+ + """ + real_activations.shape.assert_has_rank(2) + generated_activations.shape.assert_has_rank(2) + + activations_dtype = real_activations.dtype + if activations_dtype != dtypes.float64: + real_activations = math_ops.to_double(real_activations) + generated_activations = math_ops.to_double(generated_activations) # Compute mean and covariance matrices of activations. - m = math_ops.reduce_mean(real_a, 0) - m_v = math_ops.reduce_mean(gen_a, 0) - num_examples = math_ops.to_double(array_ops.shape(real_a)[0]) + m = math_ops.reduce_mean(real_activations, 0) + m_v = math_ops.reduce_mean(generated_activations, 0) + num_examples = math_ops.to_double(array_ops.shape(real_activations)[0]) # sigma = (1 / (n - 1)) * (X - mu) (X - mu)^T + real_centered = real_activations - m sigma = math_ops.matmul( - real_a - m, real_a - m, transpose_a=True) / (num_examples - 1) + real_centered, real_centered, transpose_a=True) / (num_examples - 1) + gen_centered = generated_activations - m_v sigma_v = math_ops.matmul( - gen_a - m_v, gen_a - m_v, transpose_a=True) / (num_examples - 1) + gen_centered, gen_centered, transpose_a=True) / (num_examples - 1) # Find the Tr(sqrt(sigma sigma_v)) component of FID sqrt_trace_component = trace_sqrt_product(sigma, sigma_v) diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py index 92e0a995748c1c4c2ddfff0daae59be5a6eaefb4..61dc8646ddc10605561ae6b19e90f4739c346608 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py @@ -181,7 +181,8 @@ class ClassifierMetricsTest(test.TestCase): batch_size = 3 img = array_ops.ones([batch_size, 299, 299, 3]) pool = _run_with_mock( - classifier_metrics.run_inception, img, + classifier_metrics.run_inception, + img, output_tensor=classifier_metrics.INCEPTION_FINAL_POOL) self.assertTrue(isinstance(pool, ops.Tensor)) @@ -190,10 
+191,32 @@ class ClassifierMetricsTest(test.TestCase): # Check that none of the model variables are trainable. self.assertListEqual([], variables.trainable_variables()) + def test_run_inception_multiple_outputs(self): + """Test `run_inception` graph construction with multiple outputs.""" + batch_size = 3 + img = array_ops.ones([batch_size, 299, 299, 3]) + logits, pool = _run_with_mock( + classifier_metrics.run_inception, + img, + output_tensor=[ + classifier_metrics.INCEPTION_OUTPUT, + classifier_metrics.INCEPTION_FINAL_POOL + ]) + + self.assertTrue(isinstance(logits, ops.Tensor)) + self.assertTrue(isinstance(pool, ops.Tensor)) + logits.shape.assert_is_compatible_with([batch_size, 1001]) + pool.shape.assert_is_compatible_with([batch_size, 2048]) + + # Check that none of the model variables are trainable. + self.assertListEqual([], variables.trainable_variables()) + def test_inception_score_graph(self): """Test `inception_score` graph construction.""" - score = _run_with_mock(classifier_metrics.inception_score, - array_ops.zeros([6, 299, 299, 3]), num_batches=3) + score = _run_with_mock( + classifier_metrics.inception_score, + array_ops.zeros([6, 299, 299, 3]), + num_batches=3) self.assertTrue(isinstance(score, ops.Tensor)) score.shape.assert_has_rank(0) @@ -231,12 +254,14 @@ class ClassifierMetricsTest(test.TestCase): array_ops.zeros([8, 10], dtype=dtypes.int32), p_logits, q) with self.assertRaisesRegexp(ValueError, 'must be floating type'): - classifier_metrics._kl_divergence( - p, array_ops.zeros([8, 10], dtype=dtypes.int32), q) + classifier_metrics._kl_divergence(p, + array_ops.zeros( + [8, 10], dtype=dtypes.int32), q) with self.assertRaisesRegexp(ValueError, 'must be floating type'): - classifier_metrics._kl_divergence( - p, p_logits, array_ops.zeros([10], dtype=dtypes.int32)) + classifier_metrics._kl_divergence(p, p_logits, + array_ops.zeros( + [10], dtype=dtypes.int32)) with self.assertRaisesRegexp(ValueError, 'must have rank 2'): 
classifier_metrics._kl_divergence(array_ops.zeros([8]), p_logits, q) @@ -249,8 +274,9 @@ class ClassifierMetricsTest(test.TestCase): def test_inception_score_value(self): """Test that `inception_score` gives the correct value.""" - logits = np.array([np.array([1, 2] * 500 + [4]), - np.array([4, 5] * 500 + [6])]) + logits = np.array( + [np.array([1, 2] * 500 + [4]), + np.array([4, 5] * 500 + [6])]) unused_image = array_ops.zeros([2, 299, 299, 3]) incscore = _run_with_mock(classifier_metrics.inception_score, unused_image) @@ -268,9 +294,11 @@ class ClassifierMetricsTest(test.TestCase): test_pool_real_a = np.float32(np.random.randn(512, 256)) test_pool_gen_a = np.float32(np.random.randn(512, 256)) - fid_op = _run_with_mock(classifier_metrics.frechet_classifier_distance, - test_pool_real_a, test_pool_gen_a, - classifier_fn=lambda x: x) + fid_op = _run_with_mock( + classifier_metrics.frechet_classifier_distance, + test_pool_real_a, + test_pool_gen_a, + classifier_fn=lambda x: x) with self.test_session() as sess: actual_fid = sess.run(fid_op) @@ -279,6 +307,33 @@ class ClassifierMetricsTest(test.TestCase): self.assertAllClose(expected_fid, actual_fid, 0.0001) + def test_frechet_classifier_distance_covariance(self): + """Test that `frechet_classifier_distance` takes covariance into account.""" + np.random.seed(0) + + # Make num_examples > num_features to ensure scipy's sqrtm function + # doesn't return a complex matrix. 
+ test_pool_reals, test_pool_gens = [], [] + for i in range(1, 11, 2): + test_pool_reals.append(np.float32(np.random.randn(2048, 256) * i)) + test_pool_gens.append(np.float32(np.random.randn(2048, 256) * i)) + + fid_ops = [] + for i in range(len(test_pool_reals)): + fid_ops.append(_run_with_mock( + classifier_metrics.frechet_classifier_distance, + test_pool_reals[i], + test_pool_gens[i], + classifier_fn=lambda x: x)) + + fids = [] + with self.test_session() as sess: + for fid_op in fid_ops: + fids.append(sess.run(fid_op)) + + # Check that the FIDs increase monotonically. + self.assertTrue(all(fid_a < fid_b for fid_a, fid_b in zip(fids, fids[1:]))) + def test_trace_sqrt_product_value(self): """Test that `trace_sqrt_product` gives the correct value.""" np.random.seed(0) diff --git a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein.py b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein.py new file mode 100644 index 0000000000000000000000000000000000000000..523968bed91f1021ae629bf52c405cf5c2d7b917 --- /dev/null +++ b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein.py @@ -0,0 +1,28 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Model evaluation tools for TFGAN.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.gan.python.eval.python import sliced_wasserstein_impl +# pylint: disable=wildcard-import +from tensorflow.contrib.gan.python.eval.python.sliced_wasserstein_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +__all__ = sliced_wasserstein_impl.__all__ +remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..9bebcacbe46d85fc4226c4275b71b3ecbde57a97 --- /dev/null +++ b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py @@ -0,0 +1,282 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Implementation of Sliced Wasserstein Distance. 
+ +Proposed in https://arxiv.org/abs/1710.10196 and the official Theano +implementation that we used as reference can be found here: +https://github.com/tkarras/progressive_growing_of_gans + +Note: this is not an exact distance but an approximation through random +projections. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import script_ops + +__all__ = ['sliced_wasserstein_distance'] +_GAUSSIAN_FILTER = np.float32([[1, 4, 6, 4, 1], [4, 16, 24, 16, 4], [ + 6, 24, 36, 24, 6 +], [4, 16, 24, 16, 4], [1, 4, 6, 4, 1]]).reshape([5, 5, 1, 1]) / 256.0 + + +def _laplacian_pyramid(batch, num_levels): + """Compute a Laplacian pyramid. + + Args: + batch: (tensor) The batch of images (batch, height, width, channels). + num_levels: (int) Desired number of hierarchical levels. + Returns: + List of tensors from the highest to lowest resolution. 
+ """ + gaussian_filter = constant_op.constant(_GAUSSIAN_FILTER) + + def spatial_conv(batch, gain): + s = array_ops.shape(batch) + padded = array_ops.pad(batch, [[0, 0], [2, 2], [2, 2], [0, 0]], 'REFLECT') + xt = array_ops.transpose(padded, [0, 3, 1, 2]) + xt = array_ops.reshape(xt, [s[0] * s[3], s[1] + 4, s[2] + 4, 1]) + conv_out = nn_ops.conv2d(xt, gaussian_filter * gain, [1] * 4, 'VALID') + conv_xt = array_ops.reshape(conv_out, [s[0], s[3], s[1], s[2]]) + conv_xt = array_ops.transpose(conv_xt, [0, 2, 3, 1]) + return conv_xt + + def pyr_down(batch): # matches cv2.pyrDown() + return spatial_conv(batch, 1)[:, ::2, ::2] + + def pyr_up(batch): # matches cv2.pyrUp() + s = array_ops.shape(batch) + zeros = array_ops.zeros([3 * s[0], s[1], s[2], s[3]]) + res = array_ops.concat([batch, zeros], 0) + res = array_ops.batch_to_space(res, crops=[[0, 0], [0, 0]], block_size=2) + res = spatial_conv(res, 4) + return res + + pyramid = [math_ops.to_float(batch)] + for _ in range(1, num_levels): + pyramid.append(pyr_down(pyramid[-1])) + pyramid[-2] -= pyr_up(pyramid[-1]) + return pyramid + + +def _batch_to_patches(batch, patches_per_image, patch_size): + """Extract patches from a batch. + + Args: + batch: (tensor) The batch of images (batch, height, width, channels). + patches_per_image: (int) Number of patches to extract per image. + patch_size: (int) Size of the patches (size, size, channels) to extract. + Returns: + Tensor (batch*patches_per_image, patch_size, patch_size, channels) of + patches. + """ + + def py_func_random_patches(batch): + """Numpy wrapper.""" + batch_size, height, width, channels = batch.shape + patch_count = patches_per_image * batch_size + hs = patch_size // 2 + # Randomly pick patches. + patch_id, y, x, chan = np.ogrid[0:patch_count, -hs:hs + 1, -hs:hs + 1, 0:3] + img_id = patch_id // patches_per_image + # pylint: disable=g-no-augmented-assignment + # Need explicit addition for broadcast to work properly. 
+ y = y + np.random.randint(hs, height - hs, size=(patch_count, 1, 1, 1)) + x = x + np.random.randint(hs, width - hs, size=(patch_count, 1, 1, 1)) + # pylint: enable=g-no-augmented-assignment + idx = ((img_id * height + y) * width + x) * channels + chan + patches = batch.flat[idx] + return patches + + patches = script_ops.py_func( + py_func_random_patches, [batch], batch.dtype, stateful=False) + return patches + + +def _normalize_patches(patches): + """Normalize patches by their mean and standard deviation. + + Args: + patches: (tensor) The batch of patches (batch, size, size, channels). + Returns: + Tensor (batch, size, size, channels) of the normalized patches. + """ + patches = array_ops.concat(patches, 0) + mean, variance = nn.moments(patches, [1, 2, 3], keep_dims=True) + patches = (patches - mean) / math_ops.sqrt(variance) + return array_ops.reshape(patches, [array_ops.shape(patches)[0], -1]) + + +def _sort_rows(matrix, num_rows): + """Sort matrix rows by the last column. + + Args: + matrix: a matrix of values (row,col). + num_rows: (int) number of sorted rows to return from the matrix. + Returns: + Tensor (num_rows, col) of the sorted matrix top K rows. + """ + tmatrix = array_ops.transpose(matrix, [1, 0]) + sorted_tmatrix = nn_ops.top_k(tmatrix, num_rows)[0] + return array_ops.transpose(sorted_tmatrix, [1, 0]) + + +def _sliced_wasserstein(a, b, random_sampling_count, random_projection_dim): + """Compute the approximate sliced Wasserstein distance. + + Args: + a: (matrix) Distribution "a" of samples (row, col). + b: (matrix) Distribution "b" of samples (row, col). + random_sampling_count: (int) Number of random projections to average. + random_projection_dim: (int) Dimension of the random projection space. + Returns: + Float containing the approximate distance between "a" and "b". + """ + s = array_ops.shape(a) + means = [] + for _ in range(random_sampling_count): + # Random projection matrix. 
+ proj = random_ops.random_normal( + [array_ops.shape(a)[1], random_projection_dim]) + proj *= math_ops.rsqrt( + math_ops.reduce_sum(math_ops.square(proj), 0, keep_dims=True)) + # Project both distributions and sort them. + proj_a = math_ops.matmul(a, proj) + proj_b = math_ops.matmul(b, proj) + proj_a = _sort_rows(proj_a, s[0]) + proj_b = _sort_rows(proj_b, s[0]) + # Pairwise Wasserstein distance. + wdist = math_ops.reduce_mean(math_ops.abs(proj_a - proj_b)) + means.append(wdist) + return math_ops.reduce_mean(means) + + +def _sliced_wasserstein_svd(a, b): + """Compute the approximate sliced Wasserstein distance using an SVD. + + This is not part of the paper, it's a variant with possibly more accurate + measure. + + Args: + a: (matrix) Distribution "a" of samples (row, col). + b: (matrix) Distribution "b" of samples (row, col). + Returns: + Float containing the approximate distance between "a" and "b". + """ + s = array_ops.shape(a) + # Random projection matrix. + sig, u = linalg_ops.svd(array_ops.concat([a, b], 0))[:2] + proj_a, proj_b = array_ops.split(u * sig, 2, axis=0) + proj_a = _sort_rows(proj_a[:, ::-1], s[0]) + proj_b = _sort_rows(proj_b[:, ::-1], s[0]) + # Pairwise Wasserstein distance. + wdist = math_ops.reduce_mean(math_ops.abs(proj_a - proj_b)) + return wdist + + +def sliced_wasserstein_distance(real_images, + fake_images, + resolution_min=16, + patches_per_image=64, + patch_size=7, + random_sampling_count=1, + random_projection_dim=7 * 7 * 3, + use_svd=False): + """Compute the Wasserstein distance between two distributions of images. + + Note that measure vary with the number of images. Use 8192 images to get + numbers comparable to the ones in the original paper. + + Args: + real_images: (tensor) Real images (batch, height, width, channels). + fake_images: (tensor) Fake images (batch, height, width, channels). + resolution_min: (int) Minimum resolution for the Laplacion pyramid. 
+ patches_per_image: (int) Number of patches to extract per image per + Laplacian level. + patch_size: (int) Width of a square patch. + random_sampling_count: (int) Number of random projections to average. + random_projection_dim: (int) Dimension of the random projection space. + use_svd: experimental method to compute a more accurate distance. + Returns: + List of tuples (distance_real, distance_fake) for each level of the + Laplacian pyramid from the highest resoluion to the lowest. + distance_real is the Wasserstein distance between real images + distance_fake is the Wasserstein distance between real and fake images. + Raises: + ValueError: If the inputs shapes are incorrect. Input tensor dimensions + (batch, height, width, channels) are expected to be known at graph + construction time. In addition height and width must be the same and the + number of colors should be exactly 3. Real and fake images must have the + same size. + """ + height = real_images.shape[1] + real_images.shape.assert_is_compatible_with([None, None, height, 3]) + fake_images.shape.assert_is_compatible_with(real_images.shape) + + # Select resolutions. + resolution_full = int(height) + resolution_min = min(resolution_min, resolution_full) + resolution_max = resolution_full + # Base loss of detail. + resolutions = [ + 2**i + for i in range( + int(np.log2(resolution_max)), + int(np.log2(resolution_min)) - 1, -1) + ] + + # Gather patches for each level of the Laplacian pyramids. 
+ patches_real, patches_fake, patches_test = ( + [[] for _ in resolutions] for _ in range(3)) + for lod, level in enumerate( + _laplacian_pyramid(real_images, len(resolutions))): + patches_real[lod].append( + _batch_to_patches(level, patches_per_image, patch_size)) + patches_test[lod].append( + _batch_to_patches(level, patches_per_image, patch_size)) + + for lod, level in enumerate( + _laplacian_pyramid(fake_images, len(resolutions))): + patches_fake[lod].append( + _batch_to_patches(level, patches_per_image, patch_size)) + + for lod in range(len(resolutions)): + for patches in [patches_real, patches_test, patches_fake]: + patches[lod] = _normalize_patches(patches[lod]) + + # Evaluate scores. + scores = [] + for lod in range(len(resolutions)): + if not use_svd: + scores.append( + (_sliced_wasserstein(patches_real[lod], patches_test[lod], + random_sampling_count, random_projection_dim), + _sliced_wasserstein(patches_real[lod], patches_fake[lod], + random_sampling_count, random_projection_dim))) + else: + scores.append( + (_sliced_wasserstein_svd(patches_real[lod], patches_test[lod]), + _sliced_wasserstein_svd(patches_real[lod], patches_fake[lod]))) + return scores diff --git a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_test.py b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_test.py new file mode 100644 index 0000000000000000000000000000000000000000..871f1ad54e2559f5df28efa78f99997a866f7087 --- /dev/null +++ b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_test.py @@ -0,0 +1,131 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Sliced Wasserstein Distance.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from scipy import ndimage +from tensorflow.contrib.gan.python.eval.python import sliced_wasserstein_impl as swd +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.platform import test + + +class ClassifierMetricsTest(test.TestCase): + + def test_laplacian_pyramid(self): + # The numpy/scipy code for reference estimation comes from: + # https://github.com/tkarras/progressive_growing_of_gans + gaussian_filter = np.float32([[1, 4, 6, 4, 1], [4, 16, 24, 16, 4], [ + 6, 24, 36, 24, 6 + ], [4, 16, 24, 16, 4], [1, 4, 6, 4, 1]]) / 256.0 + + def np_pyr_down(minibatch): # matches cv2.pyrDown() + assert minibatch.ndim == 4 + return ndimage.convolve( + minibatch, + gaussian_filter[np.newaxis, np.newaxis, :, :], + mode='mirror')[:, :, ::2, ::2] + + def np_pyr_up(minibatch): # matches cv2.pyrUp() + assert minibatch.ndim == 4 + s = minibatch.shape + res = np.zeros((s[0], s[1], s[2] * 2, s[3] * 2), minibatch.dtype) + res[:, :, ::2, ::2] = minibatch + return ndimage.convolve( + res, + gaussian_filter[np.newaxis, np.newaxis, :, :] * 4.0, + mode='mirror') + + def np_laplacian_pyramid(minibatch, num_levels): + # Note: there's a bug in the original SWD, fixed repeatability. 
+ pyramid = [minibatch.astype('f').copy()] + for _ in range(1, num_levels): + pyramid.append(np_pyr_down(pyramid[-1])) + pyramid[-2] -= np_pyr_up(pyramid[-1]) + return pyramid + + data = np.random.normal(size=[256, 3, 32, 32]).astype('f') + pyramid = np_laplacian_pyramid(data, 3) + data_tf = array_ops.placeholder(dtypes.float32, [256, 32, 32, 3]) + pyramid_tf = swd._laplacian_pyramid(data_tf, 3) + with self.test_session() as sess: + pyramid_tf = sess.run( + pyramid_tf, feed_dict={ + data_tf: data.transpose(0, 2, 3, 1) + }) + for x in range(3): + self.assertAllClose( + pyramid[x].transpose(0, 2, 3, 1), pyramid_tf[x], atol=1e-6) + + def test_sliced_wasserstein_distance(self): + """Test the distance.""" + d1 = random_ops.random_uniform([256, 32, 32, 3]) + d2 = random_ops.random_normal([256, 32, 32, 3]) + wfunc = swd.sliced_wasserstein_distance(d1, d2) + with self.test_session() as sess: + wscores = [sess.run(x) for x in wfunc] + self.assertAllClose( + np.array([0.014, 0.014], 'f'), + np.array([x[0] for x in wscores], 'f'), + rtol=0.15) + self.assertAllClose( + np.array([0.014, 0.020], 'f'), + np.array([x[1] for x in wscores], 'f'), + rtol=0.15) + + def test_sliced_wasserstein_distance_svd(self): + """Test the distance.""" + d1 = random_ops.random_uniform([256, 32, 32, 3]) + d2 = random_ops.random_normal([256, 32, 32, 3]) + wfunc = swd.sliced_wasserstein_distance(d1, d2, use_svd=True) + with self.test_session() as sess: + wscores = [sess.run(x) for x in wfunc] + self.assertAllClose( + np.array([0.013, 0.013], 'f'), + np.array([x[0] for x in wscores], 'f'), + rtol=0.15) + self.assertAllClose( + np.array([0.014, 0.019], 'f'), + np.array([x[1] for x in wscores], 'f'), + rtol=0.15) + + def test_swd_mismatched(self): + """Test the inputs mismatched shapes are detected.""" + d1 = random_ops.random_uniform([256, 32, 32, 3]) + d2 = random_ops.random_normal([256, 32, 31, 3]) + d3 = random_ops.random_normal([256, 31, 32, 3]) + d4 = random_ops.random_normal([255, 32, 32, 3]) + 
with self.assertRaises(ValueError): + swd.sliced_wasserstein_distance(d1, d2) + with self.assertRaises(ValueError): + swd.sliced_wasserstein_distance(d1, d3) + with self.assertRaises(ValueError): + swd.sliced_wasserstein_distance(d1, d4) + + def test_swd_not_rgb(self): + """Test that only RGB is supported.""" + d1 = random_ops.random_uniform([256, 32, 32, 1]) + d2 = random_ops.random_normal([256, 32, 32, 1]) + with self.assertRaises(ValueError): + swd.sliced_wasserstein_distance(d1, d2) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py index 508b4d20d8767f42246a0d0c87f911b7ac612f45..0d1afad72da8a8e087239868e25ddebe23490d1e 100644 --- a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.gan.python import namedtuples from tensorflow.contrib.gan.python.eval.python import eval_utils from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -38,16 +39,26 @@ def _assert_is_image(data): data.shape[1:].assert_is_fully_defined() -def add_gan_model_image_summaries(gan_model, grid_size=4): +def add_gan_model_image_summaries(gan_model, grid_size=4, model_summaries=True): """Adds image summaries for real and fake images. Args: gan_model: A GANModel tuple. grid_size: The size of an image grid. + model_summaries: Also add summaries of the model. Raises: ValueError: If real and generated data aren't images. 
""" + if isinstance(gan_model, namedtuples.CycleGANModel): + saved_params = locals() + saved_params.pop('gan_model', None) + with ops.name_scope('cyclegan_x2y_image_summaries'): + add_gan_model_image_summaries(gan_model.model_x2y, **saved_params) + with ops.name_scope('cyclegan_y2x_image_summaries'): + add_gan_model_image_summaries(gan_model.model_y2x, **saved_params) + return + _assert_is_image(gan_model.real_data) _assert_is_image(gan_model.generated_data) @@ -73,7 +84,9 @@ def add_gan_model_image_summaries(gan_model, grid_size=4): image_shape=generated_image_shape, num_channels=generated_channels), max_outputs=1) - add_gan_model_summaries(gan_model) + + if model_summaries: + add_gan_model_summaries(gan_model) def add_image_comparison_summaries(gan_model, num_comparisons=2, @@ -96,6 +109,15 @@ def add_image_comparison_summaries(gan_model, num_comparisons=2, ValueError: If the generator input, real, and generated data aren't all the same size. """ + if isinstance(gan_model, namedtuples.CycleGANModel): + saved_params = locals() + saved_params.pop('gan_model', None) + with ops.name_scope('cyclegan_x2y_image_comparison_summaries'): + add_image_comparison_summaries(gan_model.model_x2y, **saved_params) + with ops.name_scope('cyclegan_y2x_image_comparison_summaries'): + add_image_comparison_summaries(gan_model.model_y2x, **saved_params) + return + _assert_is_image(gan_model.generator_inputs) _assert_is_image(gan_model.generated_data) _assert_is_image(gan_model.real_data) @@ -133,6 +155,13 @@ def add_gan_model_summaries(gan_model): Args: gan_model: A GANModel tuple. 
""" + if isinstance(gan_model, namedtuples.CycleGANModel): + with ops.name_scope('cyclegan_x2y_summaries'): + add_gan_model_summaries(gan_model.model_x2y) + with ops.name_scope('cyclegan_y2x_summaries'): + add_gan_model_summaries(gan_model.model_y2x) + return + with ops.name_scope('generator_variables'): for var in gan_model.generator_variables: summary.histogram(var.name, var) @@ -147,6 +176,13 @@ def add_regularization_loss_summaries(gan_model): Args: gan_model: A GANModel tuple. """ + if isinstance(gan_model, namedtuples.CycleGANModel): + with ops.name_scope('cyclegan_x2y_regularization_loss_summaries'): + add_regularization_loss_summaries(gan_model.model_x2y) + with ops.name_scope('cyclegan_y2x_regularization_loss_summaries'): + add_regularization_loss_summaries(gan_model.model_y2x) + return + if gan_model.generator_scope: summary.scalar( 'generator_regularization_loss', diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_test.py b/tensorflow/contrib/gan/python/eval/python/summaries_test.py index a3b02bcefc6cbaa6e24131b336b5c9c072bde52c..7956db43348c0cc0f3d372e92a2e343f5aa62013 100644 --- a/tensorflow/contrib/gan/python/eval/python/summaries_test.py +++ b/tensorflow/contrib/gan/python/eval/python/summaries_test.py @@ -57,40 +57,89 @@ def get_gan_model(): discriminator_fn=discriminator_model) +def get_cyclegan_model(): + with variable_scope.variable_scope('x2y'): + model_x2y = get_gan_model() + with variable_scope.variable_scope('y2x'): + model_y2x = get_gan_model() + return namedtuples.CycleGANModel( + model_x2y=model_x2y, + model_y2x=model_y2x, + reconstructed_x=array_ops.zeros([3, 30, 35, 6]), + reconstructed_y=array_ops.zeros([3, 30, 35, 6])) + + class SummariesTest(test.TestCase): - def testAddGanModelImageSummaries(self): - summaries.add_gan_model_image_summaries(get_gan_model(), grid_size=2) + def _test_add_gan_model_image_summaries_impl(self, get_model_fn, + expected_num_summary_ops, + model_summaries): + 
summaries.add_gan_model_image_summaries(get_model_fn(), grid_size=2, + model_summaries=model_summaries) - self.assertEquals(5, len(ops.get_collection(ops.GraphKeys.SUMMARIES))) + self.assertEquals(expected_num_summary_ops, + len(ops.get_collection(ops.GraphKeys.SUMMARIES))) with self.test_session(use_gpu=True): variables.global_variables_initializer().run() summary.merge_all().eval() - def testAddGanModelSummaries(self): - summaries.add_gan_model_summaries(get_gan_model()) + def test_add_gan_model_image_summaries(self): + self._test_add_gan_model_image_summaries_impl(get_gan_model, 5, True) + + def test_add_gan_model_image_summaries_no_model(self): + self._test_add_gan_model_image_summaries_impl(get_gan_model, 2, False) - self.assertEquals(3, len(ops.get_collection(ops.GraphKeys.SUMMARIES))) + def test_add_gan_model_image_summaries_for_cyclegan(self): + self._test_add_gan_model_image_summaries_impl(get_cyclegan_model, 10, + True) + + def _test_add_gan_model_summaries_impl(self, get_model_fn, + expected_num_summary_ops): + summaries.add_gan_model_summaries(get_model_fn()) + + self.assertEquals(expected_num_summary_ops, + len(ops.get_collection(ops.GraphKeys.SUMMARIES))) with self.test_session(use_gpu=True): variables.global_variables_initializer().run() summary.merge_all().eval() - def testAddRegularizationLossSummaries(self): - summaries.add_regularization_loss_summaries(get_gan_model()) + def test_add_gan_model_summaries(self): + self._test_add_gan_model_summaries_impl(get_gan_model, 3) + + def test_add_gan_model_summaries_for_cyclegan(self): + self._test_add_gan_model_summaries_impl(get_cyclegan_model, 6) - self.assertEquals(2, len(ops.get_collection(ops.GraphKeys.SUMMARIES))) + def _test_add_regularization_loss_summaries_impl(self, get_model_fn, + expected_num_summary_ops): + summaries.add_regularization_loss_summaries(get_model_fn()) + + self.assertEquals(expected_num_summary_ops, + len(ops.get_collection(ops.GraphKeys.SUMMARIES))) with 
self.test_session(use_gpu=True): summary.merge_all().eval() + def test_add_regularization_loss_summaries(self): + self._test_add_regularization_loss_summaries_impl(get_gan_model, 2) + + def test_add_regularization_loss_summaries_for_cyclegan(self): + self._test_add_regularization_loss_summaries_impl(get_cyclegan_model, 4) + # TODO(joelshor): Add correctness test. - def testAddImageComparisonSummaries(self): - summaries.add_image_comparison_summaries( - get_gan_model(), display_diffs=True) + def _test_add_image_comparison_summaries_impl(self, get_model_fn, + expected_num_summary_ops): + summaries.add_image_comparison_summaries(get_model_fn(), display_diffs=True) - self.assertEquals(1, len(ops.get_collection(ops.GraphKeys.SUMMARIES))) + self.assertEquals(expected_num_summary_ops, + len(ops.get_collection(ops.GraphKeys.SUMMARIES))) with self.test_session(use_gpu=True): summary.merge_all().eval() + def test_add_image_comparison_summaries(self): + self._test_add_image_comparison_summaries_impl(get_gan_model, 1) + + def test_add_image_comparison_summaries_for_cyclegan(self): + self._test_add_image_comparison_summaries_impl(get_cyclegan_model, 2) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/gan/python/features/__init__.py b/tensorflow/contrib/gan/python/features/__init__.py index 6d0972f8db418d6fcf517cc6f7e96093ae08a9e4..4816daf760143af9f1502873b123ffad8e5ec8ce 100644 --- a/tensorflow/contrib/gan/python/features/__init__.py +++ b/tensorflow/contrib/gan/python/features/__init__.py @@ -12,7 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TFGAN grouped API. Please see README.md for details and usage.""" +"""TFGAN features module. + +This module includes support for virtual batch normalization, buffer replay, +conditioning, etc. 
+""" from __future__ import absolute_import from __future__ import division @@ -22,10 +26,12 @@ from __future__ import print_function # pylint: disable=unused-import,wildcard-import from tensorflow.contrib.gan.python.features.python import clip_weights from tensorflow.contrib.gan.python.features.python import conditioning_utils +from tensorflow.contrib.gan.python.features.python import random_tensor_pool from tensorflow.contrib.gan.python.features.python import virtual_batchnorm from tensorflow.contrib.gan.python.features.python.clip_weights import * from tensorflow.contrib.gan.python.features.python.conditioning_utils import * +from tensorflow.contrib.gan.python.features.python.random_tensor_pool import * from tensorflow.contrib.gan.python.features.python.virtual_batchnorm import * # pylint: enable=unused-import,wildcard-import @@ -33,5 +39,6 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = clip_weights.__all__ _allowed_symbols += conditioning_utils.__all__ +_allowed_symbols += random_tensor_pool.__all__ _allowed_symbols += virtual_batchnorm.__all__ remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/gan/python/features/python/clip_weights_test.py b/tensorflow/contrib/gan/python/features/python/clip_weights_test.py index 030e37ec679ec58e3b534fd3644ffe1d23173404..2b7bb5f14e7f3d1b3f913d3426efaaae19079ffb 100644 --- a/tensorflow/contrib/gan/python/features/python/clip_weights_test.py +++ b/tensorflow/contrib/gan/python/features/python/clip_weights_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Tests for tfgan.python.features.clip_weights.""" +"""Tests for features.clip_weights.""" from __future__ import absolute_import from __future__ import division @@ -31,17 +31,18 @@ class ClipWeightsTest(test.TestCase): """Tests for `discriminator_weight_clip`.""" def setUp(self): + super(ClipWeightsTest, self).setUp() self.variables = [variables.Variable(2.0)] self.tuple = collections.namedtuple( 'VarTuple', ['discriminator_variables'])(self.variables) def _test_weight_clipping_helper(self, use_tuple): - loss = self.variables[0] * 2.0 + loss = self.variables[0] opt = training.GradientDescentOptimizer(1.0) if use_tuple: - opt_clip = clip_weights.weight_clip(opt, self.variables, 0.1) + opt_clip = clip_weights.clip_variables(opt, self.variables, 0.1) else: - opt_clip = clip_weights.discriminator_weight_clip(opt, self.tuple, 0.1) + opt_clip = clip_weights.clip_discriminator_weights(opt, self.tuple, 0.1) train_op1 = opt.minimize(loss, var_list=self.variables) train_op2 = opt_clip.minimize(loss, var_list=self.variables) @@ -72,10 +73,14 @@ class ClipWeightsTest(test.TestCase): clip_weights.clip_discriminator_weights(opt, self.tuple, weight_clip=-1) else: with self.assertRaisesRegexp(ValueError, 'must be positive'): - clip_weights.clip_weights(opt, self.variables, weight_clip=-1) + clip_weights.clip_variables(opt, self.variables, weight_clip=-1) def test_incorrect_weight_clip_value_argsonly(self): self._test_incorrect_weight_clip_value_helper(False) def test_incorrect_weight_clip_value_tuple(self): self._test_incorrect_weight_clip_value_helper(True) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/gan/python/features/python/tensor_pool.py b/tensorflow/contrib/gan/python/features/python/random_tensor_pool.py similarity index 86% rename from tensorflow/contrib/gan/python/features/python/tensor_pool.py rename to 
tensorflow/contrib/gan/python/features/python/random_tensor_pool.py index 0bd2fa3db9427315ed623bc4d47d74683777bb94..ca904971fa8cb0440d3e0c9060f13cc214c9eaad 100644 --- a/tensorflow/contrib/gan/python/features/python/tensor_pool.py +++ b/tensorflow/contrib/gan/python/features/python/random_tensor_pool.py @@ -25,11 +25,11 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.gan.python.features.python import tensor_pool_impl +from tensorflow.contrib.gan.python.features.python import random_tensor_pool_impl # pylint: disable=wildcard-import -from tensorflow.contrib.gan.python.features.python.tensor_pool_impl import * +from tensorflow.contrib.gan.python.features.python.random_tensor_pool_impl import * # pylint: enable=wildcard-import from tensorflow.python.util.all_util import remove_undocumented -__all__ = tensor_pool_impl.__all__ +__all__ = random_tensor_pool_impl.__all__ remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/features/python/tensor_pool_impl.py b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py similarity index 67% rename from tensorflow/contrib/gan/python/features/python/tensor_pool_impl.py rename to tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py index 79318a69d291f11b7978e898423f1dd3e757466f..4cfae0de4451880cf8229903b0eb74b1c6e2e04d 100644 --- a/tensorflow/contrib/gan/python/features/python/tensor_pool_impl.py +++ b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_impl.py @@ -42,8 +42,14 @@ __all__ = [ ] -def tensor_pool(input_value, - pool_size, +def _to_tuple(x): + if isinstance(x, (list, tuple)): + return tuple(x) + return (x,) + + +def tensor_pool(input_values, + pool_size=50, pooling_probability=0.5, name='tensor_pool'): """Queue storing input values and returning random previously stored ones. 
@@ -57,15 +63,18 @@ def tensor_pool(input_value, `pool_size` = 0 or `pooling_probability` = 0. Args: - input_value: A `Tensor` from which to read values to be pooled. - pool_size: An integer specifying the maximum size of the pool. + input_values: A `Tensor`, or a list or tuple of `Tensor`s from which to read + values to be pooled. + pool_size: An integer specifying the maximum size of the pool. Defaults to + 50. pooling_probability: A float `Tensor` specifying the probability of getting a value from the pool, as opposed to just the current input. name: A string prefix for the name scope for all tensorflow ops. Returns: - A `Tensor` which is with given probability either the `input_value` or a - randomly chosen sample that was previously inserted in the pool. + A `Tensor`, or a list or tuple of `Tensor`s (according to the type ofx + `input_values`) which is with given probability either the `input_values` or + a randomly chosen sample that was previously inserted in the pool. Raises: ValueError: If `pool_size` is negative. 
@@ -74,45 +83,57 @@ def tensor_pool(input_value, if pool_size < 0: raise ValueError('`pool_size` is negative.') elif pool_size == 0: - return input_value + return input_values - with ops.name_scope('{}_pool_queue'.format(name), - values=[input_value, pooling_probability]): + original_input_values = input_values + input_values = _to_tuple(input_values) + + with ops.name_scope( + '{}_pool_queue'.format(name), + values=input_values + (pooling_probability,)): pool_queue = data_flow_ops.RandomShuffleQueue( capacity=pool_size, min_after_dequeue=0, - dtypes=[input_value.dtype], + dtypes=[v.dtype for v in input_values], shapes=None) # In pseudeo code this code does the following: # if not pool_full: - # enqueue(input_value) - # return input_value + # enqueue(input_values) + # return input_values # else - # dequeue_value = dequeue_random_sample() - # enqueue(input_value) + # dequeue_values = dequeue_random_sample() + # enqueue(input_values) # if rand() < pooling_probability: - # return dequeue_value + # return dequeue_values # else - # return input_value + # return input_values def _get_input_value_pooled(): - enqueue_op = pool_queue.enqueue(input_value) + enqueue_op = pool_queue.enqueue(input_values) with ops.control_dependencies([enqueue_op]): - return array_ops.identity(input_value) + return tuple(array_ops.identity(v) for v in input_values) def _get_random_pool_value_and_enqueue_input(): - dequeue_value = pool_queue.dequeue() - with ops.control_dependencies([dequeue_value]): - enqueue_op = pool_queue.enqueue(input_value) + dequeue_values = _to_tuple(pool_queue.dequeue()) + with ops.control_dependencies(dequeue_values): + enqueue_op = pool_queue.enqueue(input_values) with ops.control_dependencies([enqueue_op]): prob = random_ops.random_uniform( (), dtype=dtypes.float32) < pooling_probability - return control_flow_ops.cond(prob, lambda: dequeue_value, - lambda: input_value) + return control_flow_ops.cond(prob, lambda: dequeue_values, + lambda: input_values) - output_value 
= control_flow_ops.cond( + output_values = _to_tuple(control_flow_ops.cond( pool_queue.size() < pool_size, _get_input_value_pooled, - _get_random_pool_value_and_enqueue_input) + _get_random_pool_value_and_enqueue_input)) + + # Make sure that the shape of `output_value` is set. + for input_value, output_value in zip(input_values, output_values): + output_value.set_shape(input_value.shape) - return output_value + if isinstance(original_input_values, list): + return list(output_values) + elif isinstance(original_input_values, tuple): + return output_values + return output_values[0] diff --git a/tensorflow/contrib/gan/python/features/python/tensor_pool_test.py b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py similarity index 70% rename from tensorflow/contrib/gan/python/features/python/tensor_pool_test.py rename to tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py index 49b77bb3fc56b91cd419f76b6eea920df7efe4a7..d8cf549cf71838178c9da01df462d41d81595fe5 100644 --- a/tensorflow/contrib/gan/python/features/python/tensor_pool_test.py +++ b/tensorflow/contrib/gan/python/features/python/random_tensor_pool_test.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -"""Tests for tf.contrib.gan.python.features.tensor_pool.""" +"""Tests for tf.contrib.gan.python.features.random_tensor_pool.""" from __future__ import absolute_import from __future__ import division @@ -20,7 +20,7 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib.gan.python.features.python import tensor_pool_impl as tensor_pool +from tensorflow.contrib.gan.python.features.python.random_tensor_pool_impl import tensor_pool from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -32,7 +32,8 @@ class TensorPoolTest(test.TestCase): """Checks that `input_value` can have unknown shape.""" input_value = array_ops.placeholder( dtype=dtypes.int32, shape=[None, None, 3]) - output_value = tensor_pool.tensor_pool(input_value, pool_size=10) + output_value = tensor_pool(input_value, pool_size=10) + self.assertEqual(output_value.shape.as_list(), [None, None, 3]) with self.test_session(use_gpu=True) as session: for i in range(10): @@ -43,7 +44,8 @@ class TensorPoolTest(test.TestCase): def test_pool_sequence(self): """Checks that values are pooled and returned maximally twice.""" input_value = array_ops.placeholder(dtype=dtypes.int32, shape=[]) - output_value = tensor_pool.tensor_pool(input_value, pool_size=10) + output_value = tensor_pool(input_value, pool_size=10) + self.assertEqual(output_value.shape.as_list(), []) with self.test_session(use_gpu=True) as session: outs = [] @@ -59,8 +61,9 @@ class TensorPoolTest(test.TestCase): def test_never_pool(self): """Checks that setting `pooling_probability` to zero works.""" input_value = array_ops.placeholder(dtype=dtypes.int32, shape=[]) - output_value = tensor_pool.tensor_pool( + output_value = tensor_pool( input_value, pool_size=10, pooling_probability=0.0) + self.assertEqual(output_value.shape.as_list(), []) with 
self.test_session(use_gpu=True) as session: for i in range(50): @@ -72,10 +75,11 @@ class TensorPoolTest(test.TestCase): input_value = array_ops.placeholder(dtype=dtypes.int32, shape=[]) pool_size = 10 pooling_probability = 0.2 - output_value = tensor_pool.tensor_pool( + output_value = tensor_pool( input_value, pool_size=pool_size, pooling_probability=pooling_probability) + self.assertEqual(output_value.shape.as_list(), []) with self.test_session(use_gpu=True) as session: not_pooled = 0 @@ -89,6 +93,24 @@ class TensorPoolTest(test.TestCase): 1 - pooling_probability, atol=0.03) + def test_input_values_tuple(self): + """Checks that `input_values` can be a tuple.""" + input_values = (array_ops.placeholder(dtype=dtypes.int32, shape=[]), + array_ops.placeholder(dtype=dtypes.int32, shape=[])) + output_values = tensor_pool(input_values, pool_size=3) + self.assertEqual(len(output_values), len(input_values)) + for output_value in output_values: + self.assertEqual(output_value.shape.as_list(), []) + + with self.test_session(use_gpu=True) as session: + for i in range(10): + outs = session.run(output_values, { + input_values[0]: i, + input_values[1]: i + 1 + }) + self.assertEqual(len(outs), len(input_values)) + self.assertEqual(outs[1] - outs[0], 1) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/gan/python/losses/__init__.py b/tensorflow/contrib/gan/python/losses/__init__.py index 290ff867a1e443f20a63e27fd97f53fed8a6cc11..d9bf8ebfdf65dfc76e4569dcaf26e0e51c7fc107 100644 --- a/tensorflow/contrib/gan/python/losses/__init__.py +++ b/tensorflow/contrib/gan/python/losses/__init__.py @@ -12,7 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TFGAN grouped API. Please see README.md for details and usage.""" +"""TFGAN losses and penalties. + +Losses can be used with individual arguments or with GANModel tuples. 
+""" from __future__ import absolute_import from __future__ import division diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl.py b/tensorflow/contrib/gan/python/losses/python/losses_impl.py index 940762cf2aa0f473cd41d9d543e2773b565a5248..39588b7219ebac1cc4855532be3fcc38e6381134 100644 --- a/tensorflow/contrib/gan/python/losses/python/losses_impl.py +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl.py @@ -67,6 +67,7 @@ __all__ = [ 'wasserstein_gradient_penalty', 'mutual_information_penalty', 'combine_adversarial_loss', + 'cycle_consistency_loss', ] @@ -304,6 +305,7 @@ def wasserstein_gradient_penalty( discriminator_fn, discriminator_scope, epsilon=1e-10, + target=1.0, weights=1.0, scope=None, loss_collection=ops.GraphKeys.LOSSES, @@ -323,6 +325,8 @@ def wasserstein_gradient_penalty( discriminator_scope: If not `None`, reuse discriminators from this scope. epsilon: A small positive number added for numerical stability when computing the gradient norm. + target: Optional Python number or `Tensor` indicating the target value of + gradient norm. Defaults to 1.0. weights: Optional `Tensor` whose rank is either 0, or the same rank as `real_data` and `generated_data`, and must be broadcastable to them (i.e., all dimensions must be either `1`, or the same as the @@ -373,7 +377,7 @@ def wasserstein_gradient_penalty( # For numerical stability, add epsilon to the sum before taking the square # root. Note tf.norm does not add epsilon. 
slopes = math_ops.sqrt(gradient_squares + epsilon) - penalties = math_ops.square(slopes - 1.0) + penalties = math_ops.square(slopes / target - 1.0) penalty = losses.compute_weighted_loss( penalties, weights, scope=scope, loss_collection=loss_collection, reduction=reduction) @@ -915,3 +919,63 @@ def combine_adversarial_loss(main_loss, array_ops.stop_gradient(adv_coeff) * adversarial_loss) return final_loss + + +def cycle_consistency_loss(data_x, + reconstructed_data_x, + data_y, + reconstructed_data_y, + scope=None, + add_summaries=False): + """Defines the cycle consistency loss. + + The cyclegan model has two partial models where `model_x2y` generator F maps + data set X to Y, `model_y2x` generator G maps data set Y to X. For a `data_x` + in data set X, we could reconstruct it by + * reconstructed_data_x = G(F(data_x)) + Similarly + * reconstructed_data_y = F(G(data_y)) + + The cycle consistency loss is about the difference between data and + reconstructed data, namely + * loss_x2x = |data_x - G(F(data_x))| (L1-norm) + * loss_y2y = |data_y - F(G(data_y))| (L1-norm) + * loss = (loss_x2x + loss_y2y) / 2 + where `loss` is the final result. + + See https://arxiv.org/abs/1703.10593 for more details. + + Args: + data_x: A `Tensor` of data X. + reconstructed_data_x: A `Tensor` of reconstructed data X. + data_y: A `Tensor` of data Y. + reconstructed_data_y: A `Tensor` of reconstructed data Y. + scope: The scope for the operations performed in computing the loss. + Defaults to None. + add_summaries: Whether or not to add detailed summaries for the loss. + Defaults to False. + + Returns: + A scalar `Tensor` of cycle consistency loss. + """ + + def _partial_cycle_consistency_loss(data, reconstructed_data): + # Following the original implementation + # https://github.com/junyanz/CycleGAN/blob/master/models/cycle_gan_model.lua + # use L1-norm of pixel-wise error normalized by data size so that + # `cycle_loss_weight` can be specified independent of image size. 
+ return math_ops.reduce_mean(math_ops.abs(data - reconstructed_data)) + + with ops.name_scope( + scope, + 'cycle_consistency_loss', + values=[data_x, reconstructed_data_x, data_y, reconstructed_data_y]): + loss_x2x = _partial_cycle_consistency_loss(data_x, reconstructed_data_x) + loss_y2y = _partial_cycle_consistency_loss(data_y, reconstructed_data_y) + loss = (loss_x2x + loss_y2y) / 2.0 + if add_summaries: + summary.scalar('cycle_consistency_loss_x2x', loss_x2x) + summary.scalar('cycle_consistency_loss_y2y', loss_y2y) + summary.scalar('cycle_consistency_loss', loss) + + return loss diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py index b5cd8c92ba180e981e0faf877021cb6d69dc34b4..dbaa624ae9d6a5a5949db692e52c0c1deb18b8df 100644 --- a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py @@ -481,6 +481,29 @@ class GradientPenaltyTest(test.TestCase, _PenaltyTest): }) self.assertAlmostEqual(self._expected_loss, loss, 5) + def test_loss_with_gradient_norm_target(self): + """Test loss value with non default gradient norm target.""" + generated_data = array_ops.placeholder(dtypes.float32, shape=(None, None)) + real_data = array_ops.placeholder(dtypes.float32, shape=(None, None)) + + loss = tfgan_losses.wasserstein_gradient_penalty( + generated_data, + real_data, + self._kwargs['generator_inputs'], + self._kwargs['discriminator_fn'], + self._kwargs['discriminator_scope'], + target=2.0) + + with self.test_session() as sess: + variables.global_variables_initializer().run() + loss = sess.run( + loss, + feed_dict={ + generated_data: self._generated_data_np, + real_data: self._real_data_np, + }) + self.assertAlmostEqual(1.0, loss, 5) + def test_reuses_scope(self): """Test that gradient penalty reuses discriminator scope.""" num_vars = len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)) @@ -620,7 +643,34 
@@ class CombineAdversarialLossTest(test.TestCase): with self.test_session(use_gpu=True) as sess: for _ in range(10): # spot check closeness on more than one sample. gnorm_np, precond_gnorm_np = sess.run([gnorm, precond_gnorm]) - self.assertNear(gnorm_np, precond_gnorm_np, 1e-5) + self.assertNear(gnorm_np, precond_gnorm_np, 1e-4) + + +class CycleConsistencyLossTest(test.TestCase): + """Tests for cycle_consistency_loss.""" + + def setUp(self): + super(CycleConsistencyLossTest, self).setUp() + + self._data_x_np = [[1.0, 2, 3], [4, 5, 6]] + self._reconstructed_data_x_np = [[7.0, 8, 9], [10, 11, 12]] + self._data_y_np = [1.0, 9] + self._reconstructed_data_y_np = [-2.0, 3] + + self._data_x = constant_op.constant(self._data_x_np, dtype=dtypes.float32) + self._reconstructed_data_x = constant_op.constant( + self._reconstructed_data_x_np, dtype=dtypes.float32) + self._data_y = constant_op.constant(self._data_y_np, dtype=dtypes.float32) + self._reconstructed_data_y = constant_op.constant( + self._reconstructed_data_y_np, dtype=dtypes.float32) + + def test_correct_loss(self): + loss = tfgan_losses.cycle_consistency_loss( + self._data_x, self._reconstructed_data_x, self._data_y, + self._reconstructed_data_y) + with self.test_session(use_gpu=True): + variables.global_variables_initializer().run() + self.assertNear(5.25, loss.eval(), 1e-5) if __name__ == '__main__': diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py index b341f03a0ddaacca8b036189516c71908bee50eb..dcc3f94c2d6b9e5e44036e7cc1a9d1bb39104fb5 100644 --- a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py +++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py @@ -60,6 +60,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.gan.python import namedtuples from tensorflow.contrib.gan.python.losses.python import 
losses_impl from tensorflow.python.util import tf_inspect @@ -78,6 +79,7 @@ __all__ = [ 'wasserstein_gradient_penalty', 'mutual_information_penalty', 'combine_adversarial_loss', + 'cycle_consistency_loss', ] @@ -246,3 +248,32 @@ def combine_adversarial_loss(gan_loss, scalar_summaries, gradient_summaries) return gan_loss._replace(generator_loss=combined_loss) + + +def cycle_consistency_loss(cyclegan_model, scope=None, add_summaries=False): + """Defines the cycle consistency loss. + + Uses `cycle_consistency_loss` to compute the cycle consistency loss for a + `cyclegan_model`. + + Args: + cyclegan_model: A `CycleGANModel` namedtuple. + scope: The scope for the operations performed in computing the loss. + Defaults to None. + add_summaries: Whether or not to add detailed summaries for the loss. + Defaults to False. + + Returns: + A scalar `Tensor` of cycle consistency loss. + + Raises: + ValueError: If `cyclegan_model` is not a `CycleGANModel` namedtuple. + """ + if not isinstance(cyclegan_model, namedtuples.CycleGANModel): + raise ValueError( + '`cyclegan_model` must be a `CycleGANModel`. Instead, was %s.' 
% + type(cyclegan_model)) + return losses_impl.cycle_consistency_loss( + cyclegan_model.model_x2y.generator_inputs, cyclegan_model.reconstructed_x, + cyclegan_model.model_y2x.generator_inputs, cyclegan_model.reconstructed_y, + scope, add_summaries) diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py index 215b15ef6915d0b8113def35987ed6ab85617bcc..aa1ef11172dee6799994b87f70a3883cd67fd15b 100644 --- a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py +++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py @@ -22,8 +22,11 @@ import collections import numpy as np +from tensorflow.contrib.gan.python import namedtuples from tensorflow.contrib.gan.python.losses.python import tuple_losses_impl as tfgan_losses - +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -125,6 +128,7 @@ manual_tests = [ 'combine_adversarial_loss', 'mutual_information_penalty', 'wasserstein_gradient_penalty', + 'cycle_consistency_loss', ] discriminator_keyword_args = { @@ -139,6 +143,38 @@ generator_keyword_args = { } +class CycleConsistencyLossTest(test.TestCase): + + def setUp(self): + super(CycleConsistencyLossTest, self).setUp() + + def _partial_model(generator_inputs_np): + model = namedtuples.GANModel(*[None] * 11) + return model._replace( + generator_inputs=constant_op.constant( + generator_inputs_np, dtype=dtypes.float32)) + + self._model_x2y = _partial_model([1, 2]) + self._model_y2x = _partial_model([5, 6]) + + def test_model_type(self): + """Test the input model type for `cycle_consistency_loss`.""" + with self.assertRaises(ValueError): + tfgan_losses.cycle_consistency_loss(self._model_x2y) + + def test_correct_loss(self): + """Test the output of `cycle_consistency_loss`.""" + loss = tfgan_losses.cycle_consistency_loss( + 
namedtuples.CycleGANModel( + model_x2y=self._model_x2y, + model_y2x=self._model_y2x, + reconstructed_x=constant_op.constant([9, 8], dtype=dtypes.float32), + reconstructed_y=constant_op.constant([7, 2], dtype=dtypes.float32))) + with self.test_session(use_gpu=True): + variables.global_variables_initializer().run() + self.assertNear(5.0, loss.eval(), 1e-5) + + if __name__ == '__main__': for loss_name in tfgan_losses.__all__: if loss_name in manual_tests: continue diff --git a/tensorflow/contrib/gan/python/namedtuples.py b/tensorflow/contrib/gan/python/namedtuples.py index 48f5e8e47dbcd5d32c23806b967a0d1e7403d2f7..25cfeafeec9000b0dc3849ebe646e59c1b4d1cc3 100644 --- a/tensorflow/contrib/gan/python/namedtuples.py +++ b/tensorflow/contrib/gan/python/namedtuples.py @@ -30,7 +30,9 @@ __all__ = [ 'GANModel', 'InfoGANModel', 'ACGANModel', + 'CycleGANModel', 'GANLoss', + 'CycleGANLoss', 'GANTrainOps', 'GANTrainSteps', ] @@ -79,6 +81,7 @@ class InfoGANModel( collections.namedtuple('InfoGANModel', GANModel._fields + ( 'structured_generator_inputs', 'predicted_distributions', + 'discriminator_and_aux_fn', ))): """An InfoGANModel contains all the pieces needed for InfoGAN training. @@ -91,6 +94,8 @@ class InfoGANModel( predicted_distributions: A list of tf.Distributions. Predicted by the recognizer, and used to evaluate the likelihood of the structured noise. List length should match `structured_generator_inputs`. + discriminator_and_aux_fn: The original discriminator function that returns + a tuple of (logits, `predicted_distributions`). """ @@ -112,6 +117,25 @@ class ACGANModel( """ +class CycleGANModel( + collections.namedtuple( + 'CycleGANModel', + ('model_x2y', 'model_y2x', 'reconstructed_x', 'reconstructed_y'))): + """A CycleGANModel contains all the pieces needed for CycleGAN training. + + The model `model_x2y` generator F maps data set X to Y, while the model + `model_y2x` generator G maps data set Y to X. + + See https://arxiv.org/abs/1703.10593 for more details. 
+ + Args: + model_x2y: A `GANModel` namedtuple whose generator maps data set X to Y. + model_y2x: A `GANModel` namedtuple whose generator maps data set Y to X. + reconstructed_x: A `Tensor` of reconstructed data X which is G(F(X)). + reconstructed_y: A `Tensor` of reconstructed data Y which is F(G(Y)). + """ + + class GANLoss( collections.namedtuple('GANLoss', ( 'generator_loss', @@ -125,6 +149,18 @@ class GANLoss( """ +class CycleGANLoss( + collections.namedtuple('CycleGANLoss', ('loss_x2y', 'loss_y2x'))): + """CycleGANLoss contains the losses for `CycleGANModel`. + + See https://arxiv.org/abs/1703.10593 for more details. + + Args: + loss_x2y: A `GANLoss` namedtuple representing the loss of `model_x2y`. + loss_y2x: A `GANLoss` namedtuple representing the loss of `model_y2x`. + """ + + class GANTrainOps( collections.namedtuple('GANTrainOps', ( 'generator_train_op', diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py index e9443f766bdc59cf45513c93e14390cd6126c295..776eb11ecb1624544d24611d8fe6ca19768b8313 100644 --- a/tensorflow/contrib/gan/python/train.py +++ b/tensorflow/contrib/gan/python/train.py @@ -52,7 +52,9 @@ __all__ = [ 'gan_model', 'infogan_model', 'acgan_model', + 'cyclegan_model', 'gan_loss', + 'cyclegan_loss', 'gan_train_ops', 'gan_train', 'get_sequential_train_hooks', @@ -215,7 +217,8 @@ def infogan_model( disc_scope, lambda x, y: discriminator_fn(x, y)[0], # conform to non-InfoGAN API structured_generator_inputs, - predicted_distributions) + predicted_distributions, + discriminator_fn) def acgan_model( @@ -276,14 +279,16 @@ def acgan_model( generator_inputs = _convert_tensor_or_l_or_d(generator_inputs) generated_data = generator_fn(generator_inputs) with variable_scope.variable_scope(discriminator_scope) as dis_scope: - (discriminator_gen_outputs, discriminator_gen_classification_logits - ) = _validate_acgan_discriminator_outputs( - discriminator_fn(generated_data, generator_inputs)) + with 
ops.name_scope(dis_scope.name+'/generated/'): + (discriminator_gen_outputs, discriminator_gen_classification_logits + ) = _validate_acgan_discriminator_outputs( + discriminator_fn(generated_data, generator_inputs)) with variable_scope.variable_scope(dis_scope, reuse=True): - real_data = ops.convert_to_tensor(real_data) - (discriminator_real_outputs, discriminator_real_classification_logits - ) = _validate_acgan_discriminator_outputs( - discriminator_fn(real_data, generator_inputs)) + with ops.name_scope(dis_scope.name+'/real/'): + real_data = ops.convert_to_tensor(real_data) + (discriminator_real_outputs, discriminator_real_classification_logits + ) = _validate_acgan_discriminator_outputs( + discriminator_fn(real_data, generator_inputs)) if check_shapes: if not generated_data.shape.is_compatible_with(real_data.shape): raise ValueError( @@ -304,6 +309,76 @@ def acgan_model( discriminator_gen_classification_logits) +def cyclegan_model( + # Lambdas defining models. + generator_fn, + discriminator_fn, + # data X and Y. + data_x, + data_y, + # Optional scopes. + generator_scope='Generator', + discriminator_scope='Discriminator', + model_x2y_scope='ModelX2Y', + model_y2x_scope='ModelY2X', + # Options. + check_shapes=True): + """Returns a CycleGAN model outputs and variables. + + See https://arxiv.org/abs/1703.10593 for more details. + + Args: + generator_fn: A python lambda that takes `data_x` or `data_y` as inputs and + returns the outputs of the GAN generator. + discriminator_fn: A python lambda that takes `real_data`/`generated data` + and `generator_inputs`. Outputs a Tensor in the range [-inf, inf]. + data_x: A `Tensor` of dataset X. Must be the same shape as `data_y`. + data_y: A `Tensor` of dataset Y. Must be the same shape as `data_x`. + generator_scope: Optional generator variable scope. Useful if you want to + reuse a subgraph that has already been created. Defaults to 'Generator'. + discriminator_scope: Optional discriminator variable scope. 
Useful if you + want to reuse a subgraph that has already been created. Defaults to + 'Discriminator'. + model_x2y_scope: Optional variable scope for model x2y variables. Defaults + to 'ModelX2Y'. + model_y2x_scope: Optional variable scope for model y2x variables. Defaults + to 'ModelY2X'. + check_shapes: If `True`, check that generator produces Tensors that are the + same shape as `data_x` (`data_y`). Otherwise, skip this check. + + Returns: + A `CycleGANModel` namedtuple. + + Raises: + ValueError: If `check_shapes` is True and `data_x` or the generator output + does not have the same shape as `data_y`. + """ + + # Create models. + def _define_partial_model(input_data, output_data): + return gan_model( + generator_fn=generator_fn, + discriminator_fn=discriminator_fn, + real_data=output_data, + generator_inputs=input_data, + generator_scope=generator_scope, + discriminator_scope=discriminator_scope, + check_shapes=check_shapes) + + with variable_scope.variable_scope(model_x2y_scope): + model_x2y = _define_partial_model(data_x, data_y) + with variable_scope.variable_scope(model_y2x_scope): + model_y2x = _define_partial_model(data_y, data_x) + + with variable_scope.variable_scope(model_y2x.generator_scope, reuse=True): + reconstructed_x = model_y2x.generator_fn(model_x2y.generated_data) + with variable_scope.variable_scope(model_x2y.generator_scope, reuse=True): + reconstructed_y = model_x2y.generator_fn(model_y2x.generated_data) + + return namedtuples.CycleGANModel(model_x2y, model_y2x, reconstructed_x, + reconstructed_y) + + def _validate_aux_loss_weight(aux_loss_weight, name='aux_loss_weight'): if isinstance(aux_loss_weight, ops.Tensor): aux_loss_weight.shape.assert_is_compatible_with([]) @@ -326,6 +401,56 @@ def _use_aux_loss(aux_loss_weight): return False +def _tensor_pool_adjusted_model(model, tensor_pool_fn): + """Adjusts model using `tensor_pool_fn`. + + Args: + model: A GANModel tuple. 
+ tensor_pool_fn: A function that takes (generated_data, generator_inputs), + stores them in an internal pool and returns a previously stored + (generated_data, generator_inputs) with some probability. For example + tfgan.features.tensor_pool. + + Returns: + A new GANModel tuple where discriminator outputs are adjusted by taking + pooled generator outputs as inputs. Returns the original model if + `tensor_pool_fn` is None. + + Raises: + ValueError: If tensor pool does not support the `model`. + """ + if tensor_pool_fn is None: + return model + + pooled_generated_data, pooled_generator_inputs = tensor_pool_fn( + (model.generated_data, model.generator_inputs)) + + if isinstance(model, namedtuples.GANModel): + with variable_scope.variable_scope(model.discriminator_scope, reuse=True): + dis_gen_outputs = model.discriminator_fn(pooled_generated_data, + pooled_generator_inputs) + return model._replace(discriminator_gen_outputs=dis_gen_outputs) + elif isinstance(model, namedtuples.ACGANModel): + with variable_scope.variable_scope(model.discriminator_scope, reuse=True): + (dis_pooled_gen_outputs, + dis_pooled_gen_classification_logits) = model.discriminator_fn( + pooled_generated_data, pooled_generator_inputs) + return model._replace( + discriminator_gen_outputs=dis_pooled_gen_outputs, + discriminator_gen_classification_logits= + dis_pooled_gen_classification_logits) + elif isinstance(model, namedtuples.InfoGANModel): + with variable_scope.variable_scope(model.discriminator_scope, reuse=True): + (dis_pooled_gen_outputs, + pooled_predicted_distributions) = model.discriminator_and_aux_fn( + pooled_generated_data, pooled_generator_inputs) + return model._replace( + discriminator_gen_outputs=dis_pooled_gen_outputs, + predicted_distributions=pooled_predicted_distributions) + else: + raise ValueError('Tensor pool does not support `model`: %s.' % type(model)) + + def gan_loss( # GANModel. model, @@ -335,9 +460,11 @@ def gan_loss( # Auxiliary losses. 
gradient_penalty_weight=None, gradient_penalty_epsilon=1e-10, + gradient_penalty_target=1.0, mutual_information_penalty_weight=None, aux_cond_generator_weight=None, aux_cond_discriminator_weight=None, + tensor_pool_fn=None, # Options. add_summaries=True): """Returns losses necessary to train generator and discriminator. @@ -355,6 +482,9 @@ def gan_loss( small positive value used by the gradient penalty function for numerical stability. Note some applications will need to increase this value to avoid NaNs. + gradient_penalty_target: If `gradient_penalty_weight` is not None, a Python + number or `Tensor` indicating the target value of gradient norm. See the + CIFAR10 section of https://arxiv.org/abs/1710.10196. Defaults to 1.0. mutual_information_penalty_weight: If not `None`, must be a non-negative Python number or Tensor indicating how much to weight the mutual information penalty. See https://arxiv.org/abs/1606.03657 for more @@ -363,6 +493,10 @@ def gan_loss( https://arxiv.org/abs/1610.09585 aux_cond_discriminator_weight: If not None: add a classification loss as in https://arxiv.org/abs/1610.09585 + tensor_pool_fn: A function that takes (generated_data, generator_inputs), + stores them in an internal pool and returns previously stored + (generated_data, generator_inputs). For example + `tf.gan.features.tensor_pool`. Defaults to None (not using tensor pool). add_summaries: Whether or not to add summaries for the losses. Returns: @@ -402,12 +536,17 @@ def gan_loss( # Create standard losses. gen_loss = generator_loss_fn(model, add_summaries=add_summaries) - dis_loss = discriminator_loss_fn(model, add_summaries=add_summaries) + dis_loss = discriminator_loss_fn( + _tensor_pool_adjusted_model(model, tensor_pool_fn), + add_summaries=add_summaries) # Add optional extra losses. 
if _use_aux_loss(gradient_penalty_weight): gp_loss = tfgan_losses.wasserstein_gradient_penalty( - model, epsilon=gradient_penalty_epsilon, add_summaries=add_summaries) + model, + epsilon=gradient_penalty_epsilon, + target=gradient_penalty_target, + add_summaries=add_summaries) dis_loss += gradient_penalty_weight * gp_loss if _use_aux_loss(mutual_information_penalty_weight): info_loss = tfgan_losses.mutual_information_penalty( @@ -436,6 +575,69 @@ def gan_loss( return namedtuples.GANLoss(gen_loss + gen_reg_loss, dis_loss + dis_reg_loss) +def cyclegan_loss( + model, + # Loss functions. + generator_loss_fn=tfgan_losses.least_squares_generator_loss, + discriminator_loss_fn=tfgan_losses.least_squares_discriminator_loss, + # Auxiliary losses. + cycle_consistency_loss_fn=tfgan_losses.cycle_consistency_loss, + cycle_consistency_loss_weight=10.0, + # Options + **kwargs): + """Returns the losses for a `CycleGANModel`. + + See https://arxiv.org/abs/1703.10593 for more details. + + Args: + model: A `CycleGANModel` namedtuple. + generator_loss_fn: The loss function on the generator. Takes a `GANModel` + named tuple. + discriminator_loss_fn: The loss function on the discriminator. Takes a + `GANModel` namedtuple. + cycle_consistency_loss_fn: The cycle consistency loss function. Takes a + `CycleGANModel` namedtuple. + cycle_consistency_loss_weight: A non-negative Python number or a scalar + `Tensor` indicating how much to weigh the cycle consistency loss. + **kwargs: Keyword args to pass directly to `gan_loss` to construct the loss + for each partial model of `model`. + + Returns: + A `CycleGANLoss` namedtuple. + + Raises: + ValueError: If `model` is not a `CycleGANModel` namedtuple. + """ + # Sanity checks. + if not isinstance(model, namedtuples.CycleGANModel): + raise ValueError( + '`model` must be a `CycleGANModel`. Instead, was %s.' % type(model)) + + # Defines cycle consistency loss. 
+ cycle_consistency_loss = cycle_consistency_loss_fn( + model, add_summaries=kwargs.get('add_summaries', True)) + cycle_consistency_loss_weight = _validate_aux_loss_weight( + cycle_consistency_loss_weight, 'cycle_consistency_loss_weight') + aux_loss = cycle_consistency_loss_weight * cycle_consistency_loss + + # Defines losses for each partial model. + def _partial_loss(partial_model): + partial_loss = gan_loss( + partial_model, + generator_loss_fn=generator_loss_fn, + discriminator_loss_fn=discriminator_loss_fn, + **kwargs) + return partial_loss._replace( + generator_loss=partial_loss.generator_loss + aux_loss) + + with ops.name_scope('cyclegan_loss_x2y'): + loss_x2y = _partial_loss(model.model_x2y) + with ops.name_scope('cyclegan_loss_y2x'): + loss_y2x = _partial_loss(model.model_y2x) + + return namedtuples.CycleGANLoss(loss_x2y, loss_y2x) + + def _get_update_ops(kwargs, gen_scope, dis_scope, check_for_unused_ops=True): """Gets generator and discriminator update ops. @@ -503,6 +705,24 @@ def gan_train_ops( A GANTrainOps tuple of (generator_train_op, discriminator_train_op) that can be used to train a generator/discriminator pair. """ + if isinstance(model, namedtuples.CycleGANModel): + saved_params = locals() + saved_params.pop('model', None) + saved_params.pop('loss', None) + kwargs = saved_params.pop('kwargs', {}) + saved_params.update(kwargs) + with ops.name_scope('cyclegan_x2y_train'): + train_ops_x2y = gan_train_ops(model.model_x2y, loss.loss_x2y, + **saved_params) + with ops.name_scope('cyclegan_y2x_train'): + train_ops_y2x = gan_train_ops(model.model_y2x, loss.loss_y2x, + **saved_params) + return namedtuples.GANTrainOps( + (train_ops_x2y.generator_train_op, train_ops_y2x.generator_train_op), + (train_ops_x2y.discriminator_train_op, + train_ops_y2x.discriminator_train_op), + training_util.get_or_create_global_step().assign_add(1)) + # Create global step increment op. 
global_step = training_util.get_or_create_global_step() global_step_inc = global_step.assign_add(1) diff --git a/tensorflow/contrib/gan/python/train_test.py b/tensorflow/contrib/gan/python/train_test.py index 6b27b6926102b6e5a7ff134ceed75c23459a6534..f9bdaa74c948ecee11d5cfd89f06087924f8dace 100644 --- a/tensorflow/contrib/gan/python/train_test.py +++ b/tensorflow/contrib/gan/python/train_test.py @@ -23,6 +23,7 @@ import numpy as np from tensorflow.contrib.framework.python.ops import variables as variables_lib from tensorflow.contrib.gan.python import namedtuples from tensorflow.contrib.gan.python import train +from tensorflow.contrib.gan.python.features.python import random_tensor_pool from tensorflow.contrib.slim.python.slim import learning as slim_learning from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -145,14 +146,16 @@ def get_infogan_model(): return namedtuples.InfoGANModel( *get_gan_model(), structured_generator_inputs=[constant_op.constant(0)], - predicted_distributions=[categorical.Categorical([1.0])]) + predicted_distributions=[categorical.Categorical([1.0])], + discriminator_and_aux_fn=infogan_discriminator_model) def get_callable_infogan_model(): return namedtuples.InfoGANModel( *get_callable_gan_model(), structured_generator_inputs=[constant_op.constant(0)], - predicted_distributions=[categorical.Categorical([1.0])]) + predicted_distributions=[categorical.Categorical([1.0])], + discriminator_and_aux_fn=infogan_discriminator_model) def create_infogan_model(): @@ -207,12 +210,63 @@ def create_callable_acgan_model(): one_hot_labels=array_ops.one_hot([0, 1, 2], 10)) +def get_cyclegan_model(): + return namedtuples.CycleGANModel( + model_x2y=get_gan_model(), + model_y2x=get_gan_model(), + reconstructed_x=array_ops.ones([1, 2, 3]), + reconstructed_y=array_ops.zeros([1, 2, 3])) + + +def get_callable_cyclegan_model(): + return namedtuples.CycleGANModel( + model_x2y=get_callable_gan_model(), + 
model_y2x=get_callable_gan_model(), + reconstructed_x=array_ops.ones([1, 2, 3]), + reconstructed_y=array_ops.zeros([1, 2, 3])) + + +def create_cyclegan_model(): + return train.cyclegan_model( + generator_model, + discriminator_model, + data_x=array_ops.zeros([1, 2]), + data_y=array_ops.ones([1, 2])) + + +def create_callable_cyclegan_model(): + return train.cyclegan_model( + Generator(), + Discriminator(), + data_x=array_ops.zeros([1, 2]), + data_y=array_ops.ones([1, 2])) + + def get_sync_optimizer(): return sync_replicas_optimizer.SyncReplicasOptimizer( gradient_descent.GradientDescentOptimizer(learning_rate=1.0), replicas_to_aggregate=1) +def get_tensor_pool_fn(pool_size): + + def tensor_pool_fn_impl(input_values): + return random_tensor_pool.tensor_pool(input_values, pool_size=pool_size) + + return tensor_pool_fn_impl + + +def get_tensor_pool_fn_for_infogan(pool_size): + + def tensor_pool_fn_impl(input_values): + generated_data, generator_inputs = input_values + output_values = random_tensor_pool.tensor_pool( + [generated_data] + generator_inputs, pool_size=pool_size) + return output_values[0], output_values[1:] + + return tensor_pool_fn_impl + + class GANModelTest(test.TestCase): """Tests for `gan_model`.""" @@ -239,6 +293,13 @@ class GANModelTest(test.TestCase): self._test_output_type_helper( get_callable_acgan_model, namedtuples.ACGANModel) + def test_output_type_cyclegan(self): + self._test_output_type_helper(get_cyclegan_model, namedtuples.CycleGANModel) + + def test_output_type_callable_cyclegan(self): + self._test_output_type_helper(get_callable_cyclegan_model, + namedtuples.CycleGANModel) + def test_no_shape_check(self): def dummy_generator_model(_): return (None, None) @@ -286,6 +347,17 @@ class GANLossTest(test.TestCase): def test_output_type_callable_acgan(self): self._test_output_type_helper(get_callable_acgan_model) + def test_output_type_cyclegan(self): + loss = train.cyclegan_loss(create_cyclegan_model(), add_summaries=True) + 
self.assertIsInstance(loss, namedtuples.CycleGANLoss) + self.assertGreater(len(ops.get_collection(ops.GraphKeys.SUMMARIES)), 0) + + def test_output_type_callable_cyclegan(self): + loss = train.cyclegan_loss( + create_callable_cyclegan_model(), add_summaries=True) + self.assertIsInstance(loss, namedtuples.CycleGANLoss) + self.assertGreater(len(ops.get_collection(ops.GraphKeys.SUMMARIES)), 0) + # Test gradient penalty option. def _test_grad_penalty_helper(self, create_gan_model_fn): model = create_gan_model_fn() @@ -409,6 +481,142 @@ class GANLossTest(test.TestCase): def test_callable_acgan(self): self._test_acgan_helper(create_callable_acgan_model) + # Test that CycleGan models work. + def _test_cyclegan_helper(self, create_gan_model_fn): + model = create_gan_model_fn() + loss = train.cyclegan_loss(model) + self.assertIsInstance(loss, namedtuples.CycleGANLoss) + + # Check values. + with self.test_session(use_gpu=True) as sess: + variables.global_variables_initializer().run() + (loss_x2y_gen_np, loss_x2y_dis_np, loss_y2x_gen_np, + loss_y2x_dis_np) = sess.run([ + loss.loss_x2y.generator_loss, loss.loss_x2y.discriminator_loss, + loss.loss_y2x.generator_loss, loss.loss_y2x.discriminator_loss + ]) + + self.assertGreater(loss_x2y_gen_np, loss_x2y_dis_np) + self.assertGreater(loss_y2x_gen_np, loss_y2x_dis_np) + self.assertTrue(np.isscalar(loss_x2y_gen_np)) + self.assertTrue(np.isscalar(loss_x2y_dis_np)) + self.assertTrue(np.isscalar(loss_y2x_gen_np)) + self.assertTrue(np.isscalar(loss_y2x_dis_np)) + + def test_cyclegan(self): + self._test_cyclegan_helper(create_cyclegan_model) + + def test_callable_cyclegan(self): + self._test_cyclegan_helper(create_callable_cyclegan_model) + + def _check_tensor_pool_adjusted_model_outputs(self, tensor1, tensor2, + pool_size): + history_values = [] + with self.test_session(use_gpu=True) as sess: + variables.global_variables_initializer().run() + for i in range(2 * pool_size): + t1, t2 = sess.run([tensor1, tensor2]) + 
history_values.append(t1) + if i < pool_size: + # For [0, pool_size), the pool is not full, tensor1 should be equal + # to tensor2 as the pool. + self.assertAllEqual(t1, t2) + else: + # For [pool_size, ?), the pool is full, tensor2 must be equal to some + # historical values of tensor1 (which is previously stored in the + # pool). + self.assertTrue(any([(v == t2).all() for v in history_values])) + + # Test `_tensor_pool_adjusted_model` for gan model. + def test_tensor_pool_adjusted_model_gan(self): + model = create_gan_model() + + new_model = train._tensor_pool_adjusted_model(model, None) + # 'Generator/dummy_g:0' and 'Discriminator/dummy_d:0' + self.assertEqual(2, len(ops.get_collection(ops.GraphKeys.VARIABLES))) + self.assertIs(new_model.discriminator_gen_outputs, + model.discriminator_gen_outputs) + + pool_size = 5 + new_model = train._tensor_pool_adjusted_model( + model, get_tensor_pool_fn(pool_size=pool_size)) + self.assertIsNot(new_model.discriminator_gen_outputs, + model.discriminator_gen_outputs) + # Check values. + self._check_tensor_pool_adjusted_model_outputs( + model.discriminator_gen_outputs, new_model.discriminator_gen_outputs, + pool_size) + + # Test _tensor_pool_adjusted_model for infogan model. + def test_tensor_pool_adjusted_model_infogan(self): + model = create_infogan_model() + + pool_size = 5 + new_model = train._tensor_pool_adjusted_model( + model, get_tensor_pool_fn_for_infogan(pool_size=pool_size)) + # 'Generator/dummy_g:0' and 'Discriminator/dummy_d:0' + self.assertEqual(2, len(ops.get_collection(ops.GraphKeys.VARIABLES))) + self.assertIsNot(new_model.discriminator_gen_outputs, + model.discriminator_gen_outputs) + self.assertIsNot(new_model.predicted_distributions, + model.predicted_distributions) + # Check values. + self._check_tensor_pool_adjusted_model_outputs( + model.discriminator_gen_outputs, new_model.discriminator_gen_outputs, + pool_size) + + # Test _tensor_pool_adjusted_model for acgan model. 
+ def test_tensor_pool_adjusted_model_acgan(self): + model = create_acgan_model() + + pool_size = 5 + new_model = train._tensor_pool_adjusted_model( + model, get_tensor_pool_fn(pool_size=pool_size)) + # 'Generator/dummy_g:0' and 'Discriminator/dummy_d:0' + self.assertEqual(2, len(ops.get_collection(ops.GraphKeys.VARIABLES))) + self.assertIsNot(new_model.discriminator_gen_outputs, + model.discriminator_gen_outputs) + self.assertIsNot(new_model.discriminator_gen_classification_logits, + model.discriminator_gen_classification_logits) + # Check values. + self._check_tensor_pool_adjusted_model_outputs( + model.discriminator_gen_outputs, new_model.discriminator_gen_outputs, + pool_size) + + # Test tensor pool. + def _test_tensor_pool_helper(self, create_gan_model_fn): + model = create_gan_model_fn() + if isinstance(model, namedtuples.InfoGANModel): + tensor_pool_fn = get_tensor_pool_fn_for_infogan(pool_size=5) + else: + tensor_pool_fn = get_tensor_pool_fn(pool_size=5) + loss = train.gan_loss(model, tensor_pool_fn=tensor_pool_fn) + self.assertTrue(isinstance(loss, namedtuples.GANLoss)) + + # Check values. 
+ with self.test_session(use_gpu=True) as sess: + variables.global_variables_initializer().run() + for _ in range(10): + sess.run([loss.generator_loss, loss.discriminator_loss]) + + def test_tensor_pool_gan(self): + self._test_tensor_pool_helper(create_gan_model) + + def test_tensor_pool_callable_gan(self): + self._test_tensor_pool_helper(create_callable_gan_model) + + def test_tensor_pool_infogan(self): + self._test_tensor_pool_helper(create_infogan_model) + + def test_tensor_pool_callable_infogan(self): + self._test_tensor_pool_helper(create_callable_infogan_model) + + def test_tensor_pool_acgan(self): + self._test_tensor_pool_helper(create_acgan_model) + + def test_tensor_pool_callable_acgan(self): + self._test_tensor_pool_helper(create_callable_acgan_model) + def test_doesnt_crash_when_in_nested_scope(self): with variable_scope.variable_scope('outer_scope'): gan_model = train.gan_model( diff --git a/tensorflow/contrib/gdr/BUILD b/tensorflow/contrib/gdr/BUILD index bdbe6f0a72621e59562fe113da101ff5a2b8c06d..707ae25d485c64f15694ee0e357f32b619d3cd33 100644 --- a/tensorflow/contrib/gdr/BUILD +++ b/tensorflow/contrib/gdr/BUILD @@ -82,6 +82,7 @@ tf_cuda_library( "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core/distributed_runtime:graph_mgr", + "//tensorflow/core/distributed_runtime:recent_request_ids", "//tensorflow/core/distributed_runtime:rendezvous_mgr_interface", "//tensorflow/core/distributed_runtime:worker", "//tensorflow/core/distributed_runtime:worker_cache", @@ -103,6 +104,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/distributed_runtime:base_rendezvous_mgr", + "//tensorflow/core/distributed_runtime:request_id", "//tensorflow/core/distributed_runtime:tensor_coding", "//tensorflow/core/distributed_runtime:worker_cache", "//tensorflow/core/distributed_runtime:worker_env", diff --git a/tensorflow/contrib/gdr/README.md b/tensorflow/contrib/gdr/README.md index 
34ce60b360822888aa6223c89362ae1b0d9d991f..8242d93f129904828a11b61d48f2df8fb0f88bc3 100644 --- a/tensorflow/contrib/gdr/README.md +++ b/tensorflow/contrib/gdr/README.md @@ -119,4 +119,4 @@ In the original design (as in the reference), tensor buffers are only registered Reference === -Bairen Yi, Jiacheng Xia, Li Chen, and Kai Chen. 2017. Towards Zero Copy Dataflows using RDMA. In Proceedings of SIGCOMM Posters and Demos'17, Los Angeles, CA, USA, August 22-24, 2017, 3 pages. https://doi.org/10.1145/3123878.3123907 +Bairen Yi, Jiacheng Xia, Li Chen, and Kai Chen. 2017. Towards Zero Copy Dataflows using RDMA. In Proceedings of SIGCOMM Posters and Demos'17, Los Angeles, CA, USA, August 22-24, 2017, 3 pages. https://doi.org/10.1145/3123878.3131975 diff --git a/tensorflow/contrib/gdr/gdr_memory_manager.cc b/tensorflow/contrib/gdr/gdr_memory_manager.cc index 5c7ac744289ab7729b4cc43ab9bedc9342284e65..81e70ae30a4c72dbcedd1aabfe758ecca4c8b366 100644 --- a/tensorflow/contrib/gdr/gdr_memory_manager.cc +++ b/tensorflow/contrib/gdr/gdr_memory_manager.cc @@ -86,8 +86,9 @@ int TryToReadNumaNode(ibv_device* device) { if (strings::safe_strto32(content, &value)) { if (value < 0) { LOG(INFO) << "Successful NUMA node read from SysFS had negative value (" - << value << "), but there must be at least one NUMA node" - ", so returning NUMA node zero"; + << value + << "), but there must be at least one NUMA node" + ", so returning NUMA node zero"; return 0; } LOG(INFO) << "NUMA node for device: " << device->name << " is " << value; @@ -290,8 +291,8 @@ Status GdrMemoryManager::Init() { // Host memory allocators for (Allocator* allocator : allocators) { auto* visitable_allocator = dynamic_cast(allocator); - CHECK(visitable_allocator) << "is not visitable for instrumentation" - << allocator->Name(); + CHECK(visitable_allocator) + << "is not visitable for instrumentation" << allocator->Name(); // Make sure we don't instrument the same allocator twice if (instrumented_.find(allocator) == 
std::end(instrumented_)) { visitable_allocator->AddAllocVisitor(alloc_visitor); @@ -635,8 +636,8 @@ void GdrMemoryManager::TensorFromTransportOptions( } else { checksum = GPUUtil::Checksum(*tensor); } - CHECK(checksum == remote_mr.checksum()) << "Checksum mismatch: " << checksum - << "!=" << remote_mr.checksum(); + CHECK(checksum == remote_mr.checksum()) + << "Checksum mismatch: " << checksum << "!=" << remote_mr.checksum(); #endif } done(Status::OK()); diff --git a/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc b/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc index adef2aac33e3e0839a268eabe2496e58861535c5..28f68cec8cce126f1b177a73e197ccd7ab749f4a 100644 --- a/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc +++ b/tensorflow/contrib/gdr/gdr_rendezvous_mgr.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/process_util.h" +#include "tensorflow/core/distributed_runtime/request_id.h" #include "tensorflow/core/distributed_runtime/tensor_coding.h" #include "tensorflow/core/distributed_runtime/worker_cache.h" #include "tensorflow/core/distributed_runtime/worker_interface.h" @@ -47,6 +48,7 @@ class GdrRecvTensorCall : public BaseRecvTensorCall { recv_args_(recv_args) { req_.set_step_id(step_id); req_.set_rendezvous_key(key.data(), key.size()); + req_.set_request_id(GetUniqueRequestId()); } ~GdrRecvTensorCall() override {} diff --git a/tensorflow/contrib/gdr/gdr_worker.cc b/tensorflow/contrib/gdr/gdr_worker.cc index 568641234731a458a05886d12066ee9f55fa58aa..ce1d8d2d73000559f03046aceacb169890ecc1b6 100644 --- a/tensorflow/contrib/gdr/gdr_worker.cc +++ b/tensorflow/contrib/gdr/gdr_worker.cc @@ -41,17 +41,26 @@ namespace tensorflow { GdrWorker::GdrWorker(WorkerEnv* worker_env, RemoteMemoryManager* remote_memory_manager) - : GrpcWorker(worker_env), remote_memory_manager_(remote_memory_manager) {} + : GrpcWorker(worker_env), + 
remote_memory_manager_(remote_memory_manager), + recv_tensor_recent_request_ids_(100000) {} void GdrWorker::GrpcRecvTensorAsync(CallOptions* opts, const RecvTensorRequest* request, ::grpc::ByteBuffer* response, StatusCallback done) { + Status s = recv_tensor_recent_request_ids_.TrackUnique( + request->request_id(), "RecvTensor (GdrWorker)", *request); + if (!s.ok()) { + done(s); + return; + } + const int64 step_id = request->step_id(); const string& key = request->rendezvous_key(); TRACEPRINTF("RecvTensor: %lld %s", step_id, key.c_str()); Rendezvous::ParsedKey parsed; - Status s = Rendezvous::ParseKey(key, &parsed); + s = Rendezvous::ParseKey(key, &parsed); Device* src_dev = nullptr; if (s.ok()) { s = PrepareRecvTensor(parsed, &src_dev); diff --git a/tensorflow/contrib/gdr/gdr_worker.h b/tensorflow/contrib/gdr/gdr_worker.h index a30b7baaedcbc80d93d7f37756732c37d2435935..54081f655ec087d78ac07974656257dcf478bcef 100644 --- a/tensorflow/contrib/gdr/gdr_worker.h +++ b/tensorflow/contrib/gdr/gdr_worker.h @@ -18,6 +18,7 @@ limitations under the License. 
#include "tensorflow/contrib/gdr/gdr_memory_manager.h" +#include "tensorflow/core/distributed_runtime/recent_request_ids.h" #include "tensorflow/core/distributed_runtime/rpc/grpc_worker_service.h" namespace tensorflow { @@ -38,6 +39,7 @@ class GdrWorker : public GrpcWorker { private: RemoteMemoryManager* remote_memory_manager_; // Not owned + RecentRequestIds recv_tensor_recent_request_ids_; }; } // namespace tensorflow diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py index 2a97a79070ea3a0e634d76c5877e2307b6e2e577..14ac5296657d48c7f9e94d220c9e7e28af4d4353 100644 --- a/tensorflow/contrib/graph_editor/transform.py +++ b/tensorflow/contrib/graph_editor/transform.py @@ -173,6 +173,9 @@ def copy_op_handler(info, op, copy_shape=True): if op._original_op: op_._original_op = op._original_op + # Add op to the graph + info.graph_._add_op(op_) + return op_, op_.outputs diff --git a/tensorflow/contrib/hvx/README.md b/tensorflow/contrib/hvx/README.md index 5a6f2f3086d708e5264b0483c211902ac8dce5f6..163993a3f6bb1bedcdffb32944a98c7cc846878e 100644 --- a/tensorflow/contrib/hvx/README.md +++ b/tensorflow/contrib/hvx/README.md @@ -1,60 +1,67 @@ # TensorFlow Runtime with HVX Acceleration -## Description +This README explain how to build and use the TensorFlow runtime with HVX Acceleration. HVX is an extension of Hexagon, a DSP provided by Qualcomm, which can compute vector calculations faster using less energy than ARM processors. -This README explain how to build and use the TensorFlow Runtime with HVX Acceleration. HVX is an extension of Hexagon which is a DSP provided by qualcomm which can compute vector calculations faster using lower energy than ARM processors. +## Dependencies + +* [Android SDK](https://developer.android.com/studio/index.html). +* [Android NDK](https://developer.android.com/ndk/index.html). Save the path in `${NDK_ROOT}`. 
+* A rooted Qualcomm-based Android device connected to the computer (preferably, a [Snapdragon Development Board](https://developer.qualcomm.com/hardware/additional-snapdragon), but it could be a rooted phone with a Qualcomm SoC, albeit this guide may not work with it). The device needs to be rooted for development and testing purposes, and shouldn't be needed in production. See [Behold, The Snapdragon MDP](https://developer.qualcomm.com/blog/behold-snapdragon-mdp) for more information. +* [Hexagon SDK v3.0](https://developer.qualcomm.com/software/hexagon-dsp-sdk/tools). Save the path in `${QUALCOMM_SDK}`. +* The current directory should be TensorFlow source code (`git clone https://github.com/tensorflow/tensorflow.git && cd tensorflow`), and saved into `${TF_ROOT_DIR}`. + +You may also need to add a test signature in the device to run HVX-based binaries. Follow the instructions in `${QUALCOMM_SDK}/docs/Tools_Signing.html`, using Python 2. + +Note that if the device is not rooted, you may not be able to get the serial number, push the test signature and/or run binary files that call HVX libraries. ## Quick Start Guide -We provides several tools to build and run inference with this runtime quickly. +We provide several tools to build and run inference with this runtime quickly. -#### All-in-one script to run inception model with prebuild hexagon library -If you don’t need to build your own implementation of hexagon HVX, we provide a shortcut to execute graphs by using pre-compiled binaries. +### Run inception model with a prebuilt Hexagon library +If you don’t need to build your own implementation of Hexagon HVX, we provide a shortcut to execute graphs by using pre-compiled binaries. 
+ +```shell +./tensorflow/contrib/makefile/samples/build_and_run_inception_hexagon.sh -p ``` -git clone https://github.com/tensorflow/tensorflow.git -cd tensorflow -NDK_ROOT="/path/to/ndk" ./tensorflow/contrib/makefile/build_all_android.sh -X -``` -(-X downloads dependencies to hexagon HVX and graphs, and copy all dependencies to android and execute a test) -#### All-in-one script to run inception model by building entire libraries from source code - If you want to build your own implementation of hexagon HVX, we provide a sample all-in-one script to execute graphs which downloads source and build everything for hexagon. +The `-p` option makes the script download dependencies (i.e., Hexagon HVX binaries and graphs models), copy them to the Android device and execute a test. -``` -git clone https://github.com/tensorflow/tensorflow.git -cd tensorflow -QUALCOMM_SDK="/path/to/qualcomm/sdk" NDK_ROOT="/path/to/ndk" ./tensorflow/contrib/makefile/samples/build_and_run_inception_hexagon.sh +### Run inception model by building all from the source code + +If you want to build your own implementation of Hexagon HVX, we provide a sample all-in-one script to execute graphs which downloads the source and builds everything that's necessary. + +```shell +./tensorflow/contrib/makefile/samples/build_and_run_inception_hexagon.sh ``` ## Building libraries If you've finished walking through the quick start guide, you may want to try building each binary manually. -#### Build libhexagon_nn_skel.so -Download hexagon nn library from codeaurora.org and build it. +### Build libhexagon\_nn\_skel.so -``` +Download Hexagon NN library from codeaurora.org and build it. + +```shell git clone https://source.codeaurora.org/quic/hexagon_nn/nnlib cd nnlib ``` -(Just follow instructions in README.HOW_TO_BUILD. You can find libhexagon_nn_skel.so in hexagon_Release_dynamic_toolv72_v60/ship) -Then copy the generated binary to GEN_LIBS_DIR +Just follow the instructions in `README.HOW_TO_BUILD`. 
You can find the file `libhexagon_nn_skel.so` in `hexagon_Release_dynamic_toolv72_v60/ship`. +Then copy the generated binary to `${GEN_LIBS_DIR}`. -``` +```shell GEN_LIBS_DIR="/path/to/a/dir/to/store/hexagon/libraries" cp -v "hexagon_Release_dynamic_toolv72_v60/ship/libhexagon_nn_skel.so" "${GEN_LIBS_DIR}" ``` -#### Build libhexagon_controller.so +### Build libhexagon\_controller.so + Download tensorflow and build hexagon controller. -``` -git clone https://github.com/tensorflow/tensorflow.git -cd tensorflow -TF_ROOT_DIR="$(pwd)" -QUALCOMM_SDK="/path/to/qualcomm/sdk" +```shell GENERATED_NNLIB_DIRECTORY="/path/to/nnlib" GENERATED_HEXAGON_CONTROLLER_DIRECTORY="${QUALCOMM_SDK}/examples/common/generated_hexagon_controller" rm -rf "${GENERATED_HEXAGON_CONTROLLER_DIRECTORY}" @@ -70,12 +77,12 @@ make tree VERBOSE=1 V=android_Release cp -v "${GENERATED_HEXAGON_CONTROLLER_DIRECTORY}/android_Release/ship/libhexagon_controller.so" "${GEN_LIBS_DIR}" ``` -#### Build tensorflow linking hexagon library -Build tensorflow with the build_all_android.sh with specifying -x option. +### Build TensorFlow linking Hexagon library -``` +Build TensorFlow with `build_all_android.sh` specifying the `-x` option. + +```shell BUILD_ALL_ANDROID_PATH="${TF_ROOT_DIR}/tensorflow/contrib/makefile/build_all_android.sh" -NDK_ROOT="/path/to/ndk/root" CC_PREFIX=${CC_PREFIX} NDK_ROOT=${NDK_ROOT} "${BUILD_ALL_ANDROID_PATH}" \ -x "${GEN_LIBS_DIR}" \ @@ -83,11 +90,11 @@ CC_PREFIX=${CC_PREFIX} NDK_ROOT=${NDK_ROOT} "${BUILD_ALL_ANDROID_PATH}" \ -t hexagon_graph_execution ``` -#### Push binaries to your Android device +### Push binaries to your Android device Before running tests on your Android device, you need to push several binaries to it. 
-``` +```shell adb push "${GEN_LIBS_DIR}/libhexagon_controller.so" "/data/local/tmp" adb push "${GEN_LIBS_DIR}/libhexagon_nn_skel.so" "/vendor/lib/rfsa/adsp" adb push -p \ @@ -100,40 +107,54 @@ adb shell chmod "${ANDROID_EXEC_FILE_MODE}" \ adb wait-for-device ``` -#### Run tests on the device +### Run tests on the device Finally, you can run the inference tests on your device. -``` +```shell adb shell 'LD_LIBRARY_PATH=/data/local/tmp:$LD_LIBRARY_PATH' \ "/data/local/tmp/hexagon_graph_execution" ``` -#### Troubleshooting -If you're using the Open-Q 820 Snapdragon development kit, you may run into an issue with running the executable due to a missing testsig library. From the Hexagon SDK documentation: *Dynamic shared objects are required to be digitally signed and then authenticated at runtime before they are allowed to be loaded and executed.* Generating a testsig library is necessary to run the unsigned sample library built from this project. +### Troubleshooting + +#### Testsig issue + +If you're using the Open-Q 820 Snapdragon Development Kit, you may run into an issue with running the executable due to a missing `testsig` library. From the Hexagon SDK documentation: *Dynamic shared objects are required to be digitally signed and then authenticated at runtime before they are allowed to be loaded and executed.* Generating a testsig library is necessary to run the unsigned sample library built from this project. -If the lack of a testsig library is your problem, you will see errors of the type: +If the lack of a `testsig` library is your problem, you will see errors of the type: `vendor/qcom/proprietary/adsprpc/src/fastrpc_apps_user.c:169::error: -1: 0 == (nErr = remotectl_open(name, (int*)ph, dlerrstr, sizeof(dlerrstr), &dlerr))` -appearing in adb logcat. - -There are several ways to create the testsig library, the only prerequisite is Python and the correct version of the Hexagon-SDK. The following steps is one way to create this library: -1. 
Run adb as root: `adb root` -2. Run the command `adb shell cat /sys/devices/soc0/serial_number` -3. Convert the decimal number you get as output to hex -4. Run the python script: `python ${QUALCOMM_SDK}/tools/elfsigner/elfsigner.py -t $(SERIAL_NUMBER_HEX_VALUE)` -5. The output of the python script is a shared library stored in ${QUALCOMM_SDK}/tools/elfsigner/output/testsig-$(SERIAL_NUMBER_HEX_VALUE).so -6. Push the shared library to your device: +appearing in `adb logcat` or ["Expected: (version) >= (1), actual: 0 vs 1" while running a binary from adb](https://github.com/tensorflow/tensorflow/issues/11210). + +You need to add a test signature, as described at the beginning of this README. After rebooting your device, you should be able to run the sample application. + +#### Qualcomm SDK Linux installation fails with "Malformed \uxxxx encoding" + +The installation file is based on LaunchAnywhere, which fails in Linux if the `PS1` env variable contains non-common Unicode chars: + ``` -adb root -adb wait-for-device -adb remount -adb wait-for-device -adb shell mkdir /system/lib/rfsa -adb shell mkdir /system/lib/rfsa/adsp -adb push ${QUALCOMM_SDK}/tools/elfsigner/output/testsig-$(SERIAL_NUMBER_HEX_VALUE).so /system/lib/rfsa/adsp/ +Preparing to install... +Extracting the JRE from the installer archive... +Unpacking the JRE... +Extracting the installation resources from the installer archive... +Configuring the installer for this system's environment... + +Launching installer... + +An internal LaunchAnywhere application error has occurred and this application cannot proceed. (LAX) + +Stack Trace: +java.lang.IllegalArgumentException: Malformed \uxxxx encoding. 
+ at java.util.Properties.loadConvert(Properties.java:574) + at java.util.Properties.load0(Properties.java:391) + at java.util.Properties.load(Properties.java:317) + at com.zerog.common.java.util.PropertiesUtil.loadProperties(Unknown Source) + at com.zerog.lax.LAX.(Unknown Source) + at com.zerog.lax.LAX.main(Unknown Source) ``` -After rebooting your device, you should be able to run the sample application. +It can be solved by temporarily assigning the `PS1` environment variable to something simple, such as '$'. + +## Maintainers -Maintainers: -- Satoshi Kataoka (satok@google.com, github.com/satok16) +* Satoshi Kataoka (satok@google.com, github.com/satok16) diff --git a/tensorflow/contrib/image/BUILD b/tensorflow/contrib/image/BUILD index 157e97d237021d95c935a6be66aa57842b97125c..3ff02e085ee63fabf42b3cc4389f4605455f3800 100755 --- a/tensorflow/contrib/image/BUILD +++ b/tensorflow/contrib/image/BUILD @@ -9,10 +9,12 @@ package(default_visibility = ["//visibility:public"]) load( "//tensorflow:tensorflow.bzl", + "tf_cc_test", "tf_custom_op_library", "tf_gen_op_libs", "tf_gen_op_wrapper_py", "tf_kernel_library", + "tf_py_test", ) load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") @@ -23,6 +25,8 @@ tf_custom_op_library( "kernels/bipartite_match_op.cc", "kernels/image_ops.cc", "kernels/image_ops.h", + "kernels/segmentation_ops.cc", + "kernels/segmentation_ops.h", "ops/image_ops.cc", ], gpu_srcs = [ @@ -37,6 +41,8 @@ tf_kernel_library( "kernels/bipartite_match_op.cc", "kernels/image_ops.cc", "kernels/image_ops.h", + "kernels/segmentation_ops.cc", + "kernels/segmentation_ops.h", ], gpu_srcs = [ "kernels/image_ops_gpu.cu.cc", @@ -77,6 +83,7 @@ tf_custom_op_py_library( "//tensorflow/python:array_ops", "//tensorflow/python:common_shapes", "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:linalg_ops", 
"//tensorflow/python:math_ops", @@ -106,10 +113,33 @@ tf_custom_op_library( name = "python/ops/_distort_image_ops.so", srcs = [ "kernels/adjust_hsv_in_yiq_op.cc", + "kernels/adjust_hsv_in_yiq_op.h", "ops/distort_image_ops.cc", ], + gpu_srcs = [ + "kernels/adjust_hsv_in_yiq_op_gpu.cu.cc", + "kernels/adjust_hsv_in_yiq_op.h", + ], deps = [ - "@protobuf_archive//:protobuf", + "//tensorflow/core/kernels:gpu_util_hdrs", + ], +) + +tf_cc_test( + name = "adjust_hsv_in_yiq_op_test", + size = "small", + srcs = [ + "kernels/adjust_hsv_in_yiq_op.h", + "kernels/adjust_hsv_in_yiq_op_test.cc", + ], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:ops_testutil", + "//tensorflow/core/kernels:ops_util", + "//third_party/eigen3", ], ) @@ -122,19 +152,6 @@ tf_gen_op_wrapper_py( deps = [":distort_image_ops_op_lib"], ) -cc_library( - name = "distort_image_ops_cc", - srcs = [ - "kernels/adjust_hsv_in_yiq_op.cc", - ], - deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//third_party/eigen3", - ], - alwayslink = 1, -) - py_library( name = "distort_image_py", srcs = [ @@ -177,6 +194,21 @@ cuda_py_test( ], ) +tf_py_test( + name = "segmentation_test", + size = "medium", + srcs = ["python/kernel_tests/segmentation_test.py"], + additional_deps = [ + ":distort_image_py", + ":image_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + tf_custom_op_library( name = "python/ops/_single_image_random_dot_stereograms.so", srcs = [ @@ -222,6 +254,23 @@ py_library( ], ) +cuda_py_test( + name = "single_image_random_dot_stereograms_ops_test", + size = "medium", + srcs = ["python/kernel_tests/single_image_random_dot_stereograms_ops_test.py"], + additional_deps = [ + ":distort_image_py", + ":image_py", + 
":single_image_random_dot_stereograms_py", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/image/__init__.py b/tensorflow/contrib/image/__init__.py index d030dffadeb9d67f7ffcbc197a2a3feb9b3b122d..cc8ed117ba2edcc7a53e609381166f17a2fbb45e 100755 --- a/tensorflow/contrib/image/__init__.py +++ b/tensorflow/contrib/image/__init__.py @@ -20,6 +20,8 @@ This module provides functions for image manipulation; currently, chrominance transformas (including changing saturation and hue) in YIQ space and projective transforms (including rotation) are supported. +## Image Transformation `Ops` + @@angles_to_projective_transforms @@compose_transforms @@adjust_yiq_hsv @@ -28,19 +30,29 @@ projective transforms (including rotation) are supported. 
@@transform @@translate @@translations_to_projective_transforms + +## Image Segmentation `Ops` + +@@connected_components + +## Matching `Ops` + @@bipartite_match + +## Random Dot Stereogram `Ops` + @@single_image_random_dot_stereograms """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# pylint: disable=line-too-long from tensorflow.contrib.image.python.ops.distort_image_ops import adjust_hsv_in_yiq from tensorflow.contrib.image.python.ops.distort_image_ops import random_hsv_in_yiq from tensorflow.contrib.image.python.ops.image_ops import angles_to_projective_transforms from tensorflow.contrib.image.python.ops.image_ops import compose_transforms +from tensorflow.contrib.image.python.ops.image_ops import connected_components from tensorflow.contrib.image.python.ops.image_ops import rotate from tensorflow.contrib.image.python.ops.image_ops import transform from tensorflow.contrib.image.python.ops.image_ops import translate diff --git a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc index f4962ed69dc68d4bad06ef29d7a167e0ba8ae044..478b716d88321101c971789f36c0ff8ecd3f418e 100644 --- a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc +++ b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc @@ -12,14 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#endif + +#include "tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h" #include -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/types.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/work_sharder.h" @@ -36,10 +37,10 @@ class AdjustHsvInYiqOpBase : public OpKernel { struct ComputeOptions { const Tensor* input = nullptr; + Tensor* output = nullptr; const Tensor* delta_h = nullptr; const Tensor* scale_s = nullptr; const Tensor* scale_v = nullptr; - Tensor* output = nullptr; int64 channel_count = 0; }; @@ -65,7 +66,7 @@ class AdjustHsvInYiqOpBase : public OpKernel { scale_v.shape().DebugString())); auto channels = input.dim_size(input.dims() - 1); OP_REQUIRES( - context, channels == 3, + context, channels == kChannelSize, errors::InvalidArgument("input must have 3 channels but instead has ", channels, " channels.")); @@ -101,53 +102,21 @@ class AdjustHsvInYiqOp : public AdjustHsvInYiqOpBase { const Tensor* input = options.input; Tensor* output = options.output; const int64 channel_count = options.channel_count; - static const int kChannelSize = 3; auto input_data = input->shaped({channel_count, kChannelSize}); const float delta_h = options.delta_h->scalar()(); const float scale_s = options.scale_s->scalar()(); const float scale_v = options.scale_v->scalar()(); auto output_data = output->shaped({channel_count, kChannelSize}); + float tranformation_matrix[kChannelSize * kChannelSize] = {0}; + internal::compute_tranformation_matrix( + delta_h, scale_s, scale_v, tranformation_matrix); const int kCostPerChannel = 
10; const DeviceBase::CpuWorkerThreads& worker_threads = *context->device()->tensorflow_cpu_worker_threads(); Shard(worker_threads.num_threads, worker_threads.workers, channel_count, kCostPerChannel, - [channel_count, &input_data, &output_data, delta_h, scale_s, scale_v]( + [channel_count, &input_data, &output_data, &tranformation_matrix]( int64 start_channel, int64 end_channel) { - // Using approximate linear transfomation described in: - // https://beesbuzz.biz/code/hsv_color_transforms.php - /** Get the constants from sympy - from sympy import Matrix - from sympy.abc import u, w - # Projection matrix to YIQ. http://en.wikipedia.org/wiki/YIQ - tyiq = Matrix([[0.299, 0.587, 0.114], - [0.596, -0.274, -0.322], - [0.211, -0.523, 0.312]]) - # Hue rotation matrix in YIQ space. - hue_proj = Matrix(3,3, [v, 0, 0, 0, vsu, -vsw, 0, vsw, vsu]) - m = tyiq.inv() * hue_proj * tyiq - **/ - // TODO(huangyp): directly compute the projection matrix from tyiq. - static const float t[kChannelSize][kChannelSize][kChannelSize] = { - {{.299, .701, .16862179492229}, - {.587, -.587, .329804745287403}, - {.114, -.114, -0.498426540209694}}, - {{.299, -.299, -.327963394172371}, - {.587, .413, .0346106879248821}, - {.114, -.114, .293352706247489}}, - {{.299, -.299, 1.24646136576682}, - {.587, -.587, -1.04322888291964}, - {.114, .886, -.203232482847173}}}; - float m[kChannelSize][kChannelSize] = {{0.}}; - float su = scale_s * std::cos(delta_h); - float sw = scale_s * std::sin(delta_h); - for (int q_index = 0; q_index < kChannelSize; q_index++) { - for (int p_index = 0; p_index < kChannelSize; p_index++) { - m[q_index][p_index] = scale_v * (t[q_index][p_index][0] + - t[q_index][p_index][1] * su + - t[q_index][p_index][2] * sw); - } - } // Applying projection matrix to input RGB vectors. 
const float* p = input_data.data() + start_channel * kChannelSize; float* q = output_data.data() + start_channel * kChannelSize; @@ -155,7 +124,9 @@ class AdjustHsvInYiqOp : public AdjustHsvInYiqOpBase { for (int q_index = 0; q_index < kChannelSize; q_index++) { q[q_index] = 0; for (int p_index = 0; p_index < kChannelSize; p_index++) { - q[q_index] += m[q_index][p_index] * p[p_index]; + q[q_index] += + p[p_index] * + tranformation_matrix[q_index + kChannelSize * p_index]; } } p += kChannelSize; @@ -165,8 +136,33 @@ class AdjustHsvInYiqOp : public AdjustHsvInYiqOpBase { } }; -REGISTER_KERNEL_BUILDER(Name("AdjustHsvInYiq").Device(DEVICE_CPU), - AdjustHsvInYiqOp); +REGISTER_KERNEL_BUILDER( + Name("AdjustHsvInYiq").Device(DEVICE_CPU).TypeConstraint("T"), + AdjustHsvInYiqOp); + +#if GOOGLE_CUDA +template <> +class AdjustHsvInYiqOp : public AdjustHsvInYiqOpBase { + public: + explicit AdjustHsvInYiqOp(OpKernelConstruction* context) + : AdjustHsvInYiqOpBase(context) {} + + void DoCompute(OpKernelContext* ctx, const ComputeOptions& options) override { + const int64 number_of_elements = options.input->NumElements(); + if (number_of_elements <= 0) { + return; + } + const float* delta_h = options.delta_h->flat().data(); + const float* scale_s = options.scale_s->flat().data(); + const float* scale_v = options.scale_v->flat().data(); + functor::AdjustHsvInYiqGPU()(ctx, options.channel_count, options.input, + delta_h, scale_s, scale_v, options.output); + } +}; + +REGISTER_KERNEL_BUILDER( + Name("AdjustHsvInYiq").Device(DEVICE_GPU).TypeConstraint("T"), + AdjustHsvInYiqOp); +#endif -// TODO(huangyp): add the GPU kernel } // namespace tensorflow diff --git a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h new file mode 100644 index 0000000000000000000000000000000000000000..8968da6d8241ca7cd548910a024a618913c3ed70 --- /dev/null +++ b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h @@ -0,0 +1,87 @@ +/* Copyright 
2017 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_IMAGE_KERNELS_ADJUST_HSV_IN_YIQ_OP_H_ +#define TENSORFLOW_CONTRIB_IMAGE_KERNELS_ADJUST_HSV_IN_YIQ_OP_H_ + +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#endif // GOOGLE_CUDA + +#include +#include "third_party/eigen3/Eigen/Core" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" + +namespace tensorflow { + +static constexpr int kChannelSize = 3; + +namespace internal { + +template +EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void compute_tranformation_matrix( + const float delta_h, const float scale_s, const float scale_v, + float* matrix) { + static_assert(MATRIX_SIZE == kChannelSize * kChannelSize, + "Size of matrix should be 9."); + // Projection matrix from RGB to YIQ. Numbers from wikipedia + // https://en.wikipedia.org/wiki/YIQ + Eigen::Matrix3f yiq; + /* clang-format off */ + yiq << 0.299, 0.587, 0.114, + 0.596, -0.274, -0.322, + 0.211, -0.523, 0.312; + Eigen::Matrix3f yiq_inverse; + yiq_inverse << 1, 0.95617069, 0.62143257, + 1, -0.2726886, -0.64681324, + 1, -1.103744, 1.70062309; + /* clang-format on */ + // Construct hsv linear transformation matrix in YIQ space. 
+ // https://beesbuzz.biz/code/hsv_color_transforms.php + float vsu = scale_v * scale_s * std::cos(delta_h); + float vsw = scale_v * scale_s * std::sin(delta_h); + Eigen::Matrix3f hsv_transform; + /* clang-format off */ + hsv_transform << scale_v, 0, 0, + 0, vsu, -vsw, + 0, vsw, vsu; + /* clang-format on */ + // Compute final transformation matrix = inverse_yiq * hsv_transform * yiq + Eigen::Map> eigen_matrix(matrix); + eigen_matrix = yiq_inverse * hsv_transform * yiq; +} +} // namespace internal + +#if GOOGLE_CUDA +typedef Eigen::GpuDevice GPUDevice; + +namespace functor { + +struct AdjustHsvInYiqGPU { + void operator()(OpKernelContext* ctx, int channel_count, + const Tensor* const input, const float* const delta_h, + const float* const scale_s, const float* const scale_v, + Tensor* const output); +}; + +} // namespace functor + +#endif // GOOGLE_CUDA + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_IMAGE_KERNELS_ADJUST_HSV_IN_YIQ_OP_H_ diff --git a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc new file mode 100644 index 0000000000000000000000000000000000000000..b71ff9cd507faac66b3a33d3c02ec9b5901d814a --- /dev/null +++ b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc @@ -0,0 +1,84 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h" +#include "tensorflow/core/kernels/gpu_utils.h" +#include "tensorflow/core/platform/stream_executor.h" +#include "tensorflow/core/util/cuda_kernel_helper.h" + +namespace tensorflow { + +namespace internal { + +__global__ void compute_tranformation_matrix_cuda(const float* const delta_h, + const float* const scale_s, + const float* const scale_v, + float* const matrix, + const int matrix_size) { + if (matrix_size == kChannelSize * kChannelSize) { + compute_tranformation_matrix( + *delta_h, *scale_s, *scale_v, matrix); + } +} +} // namespace internal + +namespace functor { + +void AdjustHsvInYiqGPU::operator()(OpKernelContext* ctx, int channel_count, + const Tensor* const input, + const float* const delta_h, + const float* const scale_s, + const float* const scale_v, + Tensor* const output) { + const uint64 m = channel_count; + const uint64 k = kChannelSize; + const uint64 n = kChannelSize; + auto* cu_stream = ctx->eigen_device().stream(); + OP_REQUIRES(ctx, cu_stream, errors::Internal("No GPU stream available.")); + Tensor tranformation_matrix; + OP_REQUIRES_OK(ctx, ctx->allocate_temp( + DT_FLOAT, TensorShape({kChannelSize * kChannelSize}), + &tranformation_matrix)); + // TODO(huangyp): It takes about 3.5 us to comute tranformation_matrix + // with one thread. Improve its performance if necessary. + internal::compute_tranformation_matrix_cuda<<<1, 1, 0, cu_stream>>>( + delta_h, scale_s, scale_v, tranformation_matrix.flat().data(), + tranformation_matrix.flat().size()); + // Call cuBlas C = A * B directly. 
+ auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose; + auto a_ptr = + AsDeviceMemory(input->flat().data(), input->flat().size()); + auto b_ptr = AsDeviceMemory(tranformation_matrix.flat().data(), + tranformation_matrix.flat().size()); + auto c_ptr = AsDeviceMemory(output->flat().data(), + output->flat().size()); + auto* stream = ctx->op_device_context()->stream(); + OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available.")); + // TODO(huangyp): share/use autotune cublas algorithms in Matmul.op. + bool blas_launch_status = + stream + ->ThenBlasGemm(no_transpose, no_transpose, n, m, k, 1.0f, b_ptr, n, + a_ptr, k, 0.0f, &c_ptr, n) + .ok(); + if (!blas_launch_status) { + ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m, + ", n=", n, ", k=", k)); + } +} +} // namespace functor +} // namespace tensorflow +#endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_test.cc b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4cbbd277840133c9419f9ce3d945b7d099679dc0 --- /dev/null +++ b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_test.cc @@ -0,0 +1,48 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +class AdjustHsvInYiqOpTest : public OpsTestBase { + protected: +}; + +TEST_F(AdjustHsvInYiqOpTest, IdentiyTransformMatrix) { + Tensor matrix(allocator(), DT_FLOAT, TensorShape({9})); + internal::compute_tranformation_matrix<9>(0.0, 1.0, 1.0, + matrix.flat().data()); + Tensor expected(allocator(), DT_FLOAT, TensorShape({9})); + test::FillValues(&expected, {1, 0, 0, 0, 1, 0, 0, 0, 1}); + test::ExpectClose(matrix, expected); +} + +TEST_F(AdjustHsvInYiqOpTest, ScaleValueTransformMatrix) { + float scale_v = 2.3; + Tensor matrix(allocator(), DT_FLOAT, TensorShape({9})); + internal::compute_tranformation_matrix<9>(0.0, 1.0, scale_v, + matrix.flat().data()); + Tensor expected(allocator(), DT_FLOAT, TensorShape({9})); + test::FillValues(&expected, + {scale_v, 0, 0, 0, scale_v, 0, 0, 0, scale_v}); + test::ExpectClose(matrix, expected); +} + +} // end namespace tensorflow diff --git a/tensorflow/contrib/image/kernels/image_ops.cc b/tensorflow/contrib/image/kernels/image_ops.cc index 6adf837ca0ab506bd18f5e2e1fc1847e31d782bf..c2e32da133b32c8fe169302668031af8bace2c22 100644 --- a/tensorflow/contrib/image/kernels/image_ops.cc +++ b/tensorflow/contrib/image/kernels/image_ops.cc @@ -43,9 +43,9 @@ template struct FillProjectiveTransform; typedef Eigen::ThreadPoolDevice CPUDevice; using functor::FillProjectiveTransform; +using generator::Interpolation; using generator::INTERPOLATION_BILINEAR; using generator::INTERPOLATION_NEAREST; -using generator::Interpolation; using generator::ProjectiveGenerator; template @@ -72,11 +72,12 @@ class ImageProjectiveTransform : public OpKernel { const Tensor& transform_t = 
ctx->input(1); OP_REQUIRES(ctx, images_t.shape().dims() == 4, errors::InvalidArgument("Input images must have rank 4")); - OP_REQUIRES(ctx, (TensorShapeUtils::IsMatrix(transform_t.shape()) && - (transform_t.dim_size(0) == images_t.dim_size(0) || - transform_t.dim_size(0) == 1) && - transform_t.dim_size(1) == - ProjectiveGenerator::kNumParameters), + OP_REQUIRES(ctx, + (TensorShapeUtils::IsMatrix(transform_t.shape()) && + (transform_t.dim_size(0) == images_t.dim_size(0) || + transform_t.dim_size(0) == 1) && + transform_t.dim_size(1) == + ProjectiveGenerator::kNumParameters), errors::InvalidArgument( "Input transform should be num_images x 8 or 1 x 8")); auto images = images_t.tensor(); diff --git a/tensorflow/contrib/image/kernels/segmentation_ops.cc b/tensorflow/contrib/image/kernels/segmentation_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..fe8bf6e21c7b7310527668324571774e8bc50893 --- /dev/null +++ b/tensorflow/contrib/image/kernels/segmentation_ops.cc @@ -0,0 +1,139 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs for ImageConnectedComponents in ../ops/image_ops.cc, and description +// of the algorithm in segmentation_ops.h. 
+ +#define EIGEN_USE_THREADS + +#include "tensorflow/contrib/image/kernels/segmentation_ops.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +using tensorflow::functor::BlockedImageUnionFindFunctor; +using tensorflow::functor::FindRootFunctor; +using tensorflow::functor::ImageConnectedComponentsFunctor; +using tensorflow::functor::TensorRangeFunctor; + +using OutputType = typename BlockedImageUnionFindFunctor::OutputType; + +// Computes connected components on batches of 2D images. +template +class ImageConnectedComponents : public OpKernel { + public: + explicit ImageConnectedComponents(OpKernelConstruction* ctx) + : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + const Tensor& images_t = ctx->input(0); + OP_REQUIRES(ctx, images_t.shape().dims() == 3, + errors::InvalidArgument("Input images must have rank 3")); + Tensor forest_t, rank_t; + OP_REQUIRES_OK(ctx, ctx->allocate_temp(tensorflow::DT_INT64, + images_t.shape(), &forest_t)); + OP_REQUIRES_OK(ctx, ctx->allocate_temp(tensorflow::DT_INT64, + images_t.shape(), &rank_t)); + Tensor* output_t; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, images_t.shape(), &output_t)); + + // Fill forest with values from 0 to n - 1, so that each node points to + // itself. + TensorRangeFunctor()(ctx->eigen_device(), + forest_t.flat()); + auto rank = rank_t.tensor(); + rank.device(ctx->eigen_device()) = rank.constant(OutputType(0)); + + const auto images = images_t.tensor(); + auto forest = forest_t.tensor(); + ImageConnectedComponentsFunctor()( + ctx, output_t->flat(), images, forest, rank); + } +}; + +using CPUDevice = Eigen::ThreadPoolDevice; + +namespace functor { + +// Connected components CPU implementation. See `segmentation_ops.h` for a +// description of the algorithm. 
+template +struct ImageConnectedComponentsFunctor { + void operator()(OpKernelContext* ctx, + typename TTypes::Flat output, + typename TTypes::ConstTensor images, + typename TTypes::Tensor forest, + typename TTypes::Tensor rank) { + const int64 num_images = images.dimension(0), + num_rows = images.dimension(1), num_cols = images.dimension(2), + num_elements = images.size(); + // Bail out early for an empty image--no work to do. + if (num_elements == 0) { + return; + } + auto worker_threads = ctx->device()->tensorflow_cpu_worker_threads(); + BlockedImageUnionFindFunctor union_find( + images.data(), num_rows, num_cols, forest.data(), rank.data()); + while (union_find.can_merge()) { + union_find.merge_blocks(); + int64 num_blocks_vertically = union_find.num_blocks_vertically(); + int64 num_blocks_horizontally = union_find.num_blocks_horizontally(); + // Merging each block calls union_down for each pixel in a row of the + // block, and union_right for each pixel in a column of the block. Assume + // 20 instructions for each call to union_down or union_right. find() may + // loop more while searching for the root, but this should not be very + // significant. 
+ int cost = (union_find.block_height() + union_find.block_width()) * 20; + Shard(worker_threads->num_threads, worker_threads->workers, + num_images * num_blocks_vertically * num_blocks_horizontally, cost, + [&union_find, num_images, num_blocks_vertically, + num_blocks_horizontally](int64 start_block, int64 limit_block) { + for (int64 i = start_block; i < limit_block; i++) { + int64 block_x = i % num_blocks_horizontally; + int64 block_y = + (i / num_blocks_horizontally) % num_blocks_vertically; + int64 image = + i / (num_blocks_horizontally * num_blocks_vertically); + union_find.merge_internal_block_edges(image, block_y, block_x); + } + }); + } + FindRootFunctor()(ctx->eigen_device(), output, + images.data(), union_find); + } +}; + +} // end namespace functor + +#define REGISTER_IMAGE_CONNECTED_COMPONENTS(TYPE) \ + REGISTER_KERNEL_BUILDER(Name("ImageConnectedComponents") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("dtype"), \ + ImageConnectedComponents) +// Connected components (arguably) make sense for number, bool, and string types +TF_CALL_NUMBER_TYPES(REGISTER_IMAGE_CONNECTED_COMPONENTS); +TF_CALL_bool(REGISTER_IMAGE_CONNECTED_COMPONENTS); +TF_CALL_string(REGISTER_IMAGE_CONNECTED_COMPONENTS); +#undef REGISTER_IMAGE_CONNECTED_COMPONENTS + +// TODO(ringwalt): Implement on GPU. We probably want to stick to the original +// algorithm by Stava and Benes there for efficiency (computing small blocks in +// shared memory in CUDA thread blocks, instead of starting with single-pixel +// blocks). + +} // end namespace tensorflow diff --git a/tensorflow/contrib/image/kernels/segmentation_ops.h b/tensorflow/contrib/image/kernels/segmentation_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..0957d5fd10f02daad3d8d51aadec9ce9da2660b5 --- /dev/null +++ b/tensorflow/contrib/image/kernels/segmentation_ops.h @@ -0,0 +1,303 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_IMAGE_KERNELS_SEGMENTATION_OPS_H_ +#define TENSORFLOW_CONTRIB_IMAGE_KERNELS_SEGMENTATION_OPS_H_ + +// Connected component analysis. The op is described in ../ops/image_ops.cc. A +// description of the algorithm appears below. + +#define EIGEN_USE_THREADS + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/work_sharder.h" + +namespace tensorflow { + +namespace functor { + +template +bool is_nonzero(T value) { + return value != T(0); +} + +template <> +bool is_nonzero(string value) { + return value.size() != 0; +} + +// Processes each pixel of an image for union-find, in parallel blocks. This is +// loosely based on the algorithm in "GPU Computing Gems" by Ondrej Stava and +// Bedrich Benes, available here: +// http://hpcg.purdue.edu/bbenes/papers/Stava2011CCL.pdf +// The bulk of the process uses blocks of each image, which have each been +// processed separately. As long as there are multiple blocks in the image, we +// double the height and width of the blocks, creating new blocks which each +// consist of 2x2 previous sub-blocks. On each new block, we process adjacent +// pixels from the previous sub-blocks serially. 
However, the new blocks are not +// connected, so we can process each block in parallel. +// The GPU algorithm first processes blocks of a fixed size in GPU shared +// memory, with one image block per CUDA thread block. On the CPU, we just start +// with a block size of a single pixel, and borrow the rest of the algorithm +// unchanged. +template +class BlockedImageUnionFindFunctor { + public: + using OutputType = int64; + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE BlockedImageUnionFindFunctor( + const T* images, const int64 num_rows, const int64 num_cols, + OutputType* forest, OutputType* rank) + : images_(images), + num_rows_(num_rows), + num_cols_(num_cols), + block_height_(1), + block_width_(1), + forest_(forest), + rank_(rank) {} + + // Returns the root of the tree that the pixel at the given index belongs to. + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE OutputType + find(OutputType index) const { + while (forest_[index] != index) { + index = forest_[index]; + } + return index; + } + + // Returns the number of blocks along the y axis. + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE int64 num_blocks_vertically() const { + return (num_rows_ + block_height_ - 1) / block_height_; + } + + // Returns the number of blocks along the x axis. + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE int64 num_blocks_horizontally() const { + return (num_cols_ + block_width_ - 1) / block_width_; + } + + // Returns the total number of blocks in each image. + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE int64 num_blocks() const { + return num_blocks_vertically() * num_blocks_horizontally(); + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE int64 block_height() const { + return block_height_; + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE int64 block_width() const { + return block_width_; + } + + // Returns whether we may merge again (the image contains more than one + // block). 
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool can_merge() const { + return block_height_ < num_rows_ || block_width_ < num_cols_; + } + + // Doubles the block size. After this method, you must call + // `merge_internal_block_edges` for each image and each *new* block's xy + // coordinates (typically in parallel). + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void merge_blocks() { + block_height_ *= 2; + block_width_ *= 2; + } + + // Processes pairs of pixels within the block which were adjacent in the four + // sub-blocks. This must be done at each stage so that the connected + // components in each block are joined correctly. + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void merge_internal_block_edges( + int64 image_index, int64 block_vertical_index, + int64 block_horizontal_index) const { + int64 block_start_y = block_vertical_index * block_height_; + int64 block_start_x = block_horizontal_index * block_width_; + // Merge the 4 sub-blocks horizontally (fixing the vertical seam). + int64 block_center_x = block_start_x + block_width_ / 2 - 1; + if (0 <= block_center_x && block_center_x + 1 < num_cols_) { + int64 merge_blocks_limit_y = + std::min(num_rows_, block_start_y + block_height_); + for (int64 y = block_start_y; y < merge_blocks_limit_y; y++) { + union_right(image_index, y, block_center_x); + } + } + // Merge the 4 sub-blocks vertically (fixing the horizontal seam). + int64 block_center_y = block_start_y + block_height_ / 2 - 1; + if (0 <= block_center_y && block_center_y + 1 < num_rows_) { + int64 merge_blocks_limit_x = + std::min(num_cols_, block_start_x + block_width_); + for (int64 x = block_start_x; x < merge_blocks_limit_x; x++) { + union_down(image_index, block_center_y, x); + } + } + } + + private: + // The input image(s). + const T* const images_; + const int64 num_rows_; + const int64 num_cols_; + // Current height of each sub-block of the image. + int64 block_height_; + // Current width of each sub-block of the image. 
+ int64 block_width_; + // Union-find forest. This has the same size as `images_`, and each entry + // holds the index of its parent in `images_` (roots hold their own index). + // Cycles should not occur. + OutputType* const forest_; + // Union-find rank of each pixel. + OutputType* const rank_; + + // Unions the pixel with the pixel below it if applicable (both pixels are + // true, and the pixel is not in the last row). + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void union_down(OutputType batch, + OutputType row, + OutputType col) const { + T pixel = read_pixel(batch, row, col); + if (is_nonzero(pixel)) { + const int64 index_a = col + num_cols_ * (row + num_rows_ * batch); + if (row + 1 < num_rows_ && read_pixel(batch, row + 1, col) == pixel) { + const int64 index_b = col + num_cols_ * (row + 1 + num_rows_ * batch); + do_union(index_a, index_b); + } + } + } + + // Unions the pixel with the pixel to the right of it if applicable. + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void union_right(OutputType batch, + OutputType row, + OutputType col) const { + T pixel = read_pixel(batch, row, col); + if (is_nonzero(pixel)) { + const int64 index_a = col + num_cols_ * (row + num_rows_ * batch); + if (col + 1 < num_cols_ && read_pixel(batch, row, col + 1) == pixel) { + const int64 index_b = col + 1 + num_cols_ * (row + num_rows_ * batch); + do_union(index_a, index_b); + } + } + } + + // Reads a pixel value in the images. + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T + read_pixel(const OutputType batch, const OutputType row, + const OutputType col) const { + return images_[col + num_cols_ * (row + num_rows_ * batch)]; + } + + // Unions the trees that the two pixels belong to, using their index in the + // `images_` array. + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void do_union( + OutputType index_a, OutputType index_b) const { + // Find the roots of index_a and index_b in the forest, and make one the + // child of the other. 
+ index_a = find(index_a); + index_b = find(index_b); + const OutputType rank_a = rank_[index_a]; + const OutputType rank_b = rank_[index_b]; + OutputType parent, child; + if (index_a == index_b) { + return; + } else if (rank_a < rank_b) { + parent = index_a; + child = index_b; + } else { + parent = index_b; + child = index_a; + rank_[parent]++; + } + forest_[child] = parent; + } +}; + +// Runs the ImageUnionFindFunctor on all pixels. Will require different CPU and +// GPU implementations. +template +class ImageConnectedComponentsFunctor { + public: + using OutputType = typename BlockedImageUnionFindFunctor::OutputType; + + void operator()(OpKernelContext* ctx, + typename TTypes::ConstTensor images, + typename TTypes::Tensor forest, + typename TTypes::Tensor rank); +}; + +// Fills a flat Tensor with indices from 0 to n - 1. +template +class TensorRangeFunctor { + public: + using OutputType = typename BlockedImageUnionFindFunctor::OutputType; + + void operator()(const Device& device, + typename TTypes::Flat tensor) { + tensor.device(device) = tensor.generate(TensorRangeGenerator()); + } + + private: + class TensorRangeGenerator { + public: + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE OutputType + operator()(const Eigen::array& coords) const { + return coords[0]; + } + }; +}; + +// Given the union-find forest, generates the root index for each node. This +// gives us arbitrary, usually non-consecutive ids for each connected component. +// The ids are massaged in Python to get deterministic, consecutive ids. 
+template +class FindRootFunctor { + public: + using OutputType = typename BlockedImageUnionFindFunctor::OutputType; + + void operator()(const Device& device, + typename TTypes::Flat component_ids, + const T* images, + const BlockedImageUnionFindFunctor& union_find) { + component_ids.device(device) = + component_ids.generate(FindRootGenerator(images, union_find)); + } + + private: + class FindRootGenerator { + const T* const images_; + const BlockedImageUnionFindFunctor union_find_; + + public: + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE FindRootGenerator( + const T* images, BlockedImageUnionFindFunctor union_find) + : images_(images), union_find_(union_find) {} + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE OutputType + operator()(const Eigen::array& coords) const { + if (is_nonzero(images_[coords[0]])) { + // True pixels have an arbitrary segment id > 0. The segment ids will be + // made contiguous later. + return union_find_.find(coords[0]) + 1; + } else { + // False pixels have a segment of 0. 
+ return 0; + } + } + }; +}; + +} // end namespace functor + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_IMAGE_KERNELS_SEGMENTATION_OPS_H_ diff --git a/tensorflow/contrib/image/kernels/single_image_random_dot_stereograms_ops.cc b/tensorflow/contrib/image/kernels/single_image_random_dot_stereograms_ops.cc index 9f0bf37aed3fc9aeefb7602ef3fda4cfd76f1917..8f9a5c28039b74a874028826ca8a6d5a36ab7cf4 100755 --- a/tensorflow/contrib/image/kernels/single_image_random_dot_stereograms_ops.cc +++ b/tensorflow/contrib/image/kernels/single_image_random_dot_stereograms_ops.cc @@ -143,8 +143,8 @@ class SingleImageRandomDotStereogramsOp : public OpKernel { } data_box_left = deltaX_border_image / 2; // Center DATA in X dimension - data_box_width = data_Xwindow; // width of scan line - data_box_height = data_Ywindow; // hight of image + data_box_width = data_Xwindow; // width of scan line + data_box_height = data_Ywindow; // hight of image const T* inputZ = input_tensor.flat().data(); // Flatten input Z buffer diff --git a/tensorflow/contrib/image/ops/image_ops.cc b/tensorflow/contrib/image/ops/image_ops.cc index 4527fdd87a8be3390fb0840410218ab74a27f0d2..68771b3d054a64ba94141c092e20df1ed6b2339b 100644 --- a/tensorflow/contrib/image/ops/image_ops.cc +++ b/tensorflow/contrib/image/ops/image_ops.cc @@ -98,4 +98,34 @@ col_to_row_match_indices: A vector of length num_columns, which is the number `col_to_row_match_indices[j]`. )doc"); +REGISTER_OP("ImageConnectedComponents") + .Input("image: dtype") + .Output("components: int64") + .Attr( + "dtype: {int64, int32, uint16, int16, uint8, int8, half, float, " + "double, bool, string}") + .SetShapeFn([](InferenceContext* c) { + return shape_inference::UnchangedShape(c); + }) + .Doc(R"doc( +Find the connected components of image(s). + +For each image (along the 0th axis), all connected components of adjacent pixels +with the same non-zero value are detected and given unique ids. 
+ +The returned `components` tensor has 0s for the zero pixels of `images`, and +arbitrary nonzero ids for the connected components of nonzero values. Ids are +unique across all of the images, and are in row-major order by the first pixel +in the component. + +Uses union-find with union by rank but not path compression, giving a runtime of +`O(n log n)`. See: + https://en.wikipedia.org/wiki/Disjoint-set_data_structure#Time_Complexity + +image: Image(s) with shape (N, H, W). +components: Component ids for each pixel in "image". Same shape as "image". Zero + pixels all have an output of 0, and all components of adjacent pixels with + the same value are given consecutive ids, starting from 1. +)doc"); + } // namespace tensorflow diff --git a/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc b/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc index f8b56ab1c5400694b3aa8d4a0c19c7769aa8cbce..8139d4272d6950815bd39a64e86e0f7422e6f799 100755 --- a/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc +++ b/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc @@ -19,6 +19,10 @@ limitations under the License. namespace tensorflow { +using shape_inference::DimensionHandle; +using shape_inference::InferenceContext; +using shape_inference::ShapeHandle; + REGISTER_OP("SingleImageRandomDotStereograms") .Attr("T: {double,float,int64,int32}") .Input("depth_values: T") @@ -37,6 +41,28 @@ REGISTER_OP("SingleImageRandomDotStereograms") "output_image_shape: shape = { dim {size:1024} dim {size: 768} dim " "{size: 1}}") .Attr("output_data_window: shape = { dim {size:1022} dim {size: 757}}") + .SetShapeFn([](InferenceContext* c) { + // Validate that the output_image_shape attr is correct. + // NOTE: The output_image_shape is [X, Y, C] + // while the output data is [Y, X, C] (or [H, W, C]). 
+ // As a result, by default the output_image_shape has the value + // of [1024, 768, 1] but the output data will be [768, 1024, 1]. + PartialTensorShape shape; + TF_RETURN_IF_ERROR(c->GetAttr("output_image_shape", &shape)); + ShapeHandle output_image_shape; + TF_RETURN_IF_ERROR( + c->MakeShapeFromPartialTensorShape(shape, &output_image_shape)); + DimensionHandle x_dim = c->Dim(output_image_shape, 0); + DimensionHandle y_dim = c->Dim(output_image_shape, 1); + + int colors; + TF_RETURN_IF_ERROR(c->GetAttr("number_colors", &colors)); + + c->set_output( + 0, c->MakeShape( + {y_dim, x_dim, colors > 256 ? c->MakeDim(3) : c->MakeDim(1)})); + return Status::OK(); + }) .Doc(R"doc( Outputs a single image random dot stereogram for export via encode_PNG/JPG OP. diff --git a/tensorflow/contrib/image/python/kernel_tests/distort_image_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/distort_image_ops_test.py index b85f19d29b79defa10493bdbaa4a1b237cb2a9ee..a495b58b7f6481d4cdedf73f23615d0390eb6a45 100644 --- a/tensorflow/contrib/image/python/kernel_tests/distort_image_ops_test.py +++ b/tensorflow/contrib/image/python/kernel_tests/distort_image_ops_test.py @@ -172,7 +172,7 @@ class AdjustValueInYiqTest(test_util.TensorFlowTestCase): raise AssertionError('Invalid test style: %s' % (test_style)) y_np = self._adjust_value_in_yiq_np(x_np, scale) y_tf = self._adjust_value_in_yiq_tf(x_np, scale) - self.assertAllClose(y_tf, y_np, rtol=2e-5, atol=1e-5) + self.assertAllClose(y_tf, y_np, rtol=2e-4, atol=1e-4) def test_invalid_shapes(self): x_np = np.random.rand(2, 3) * 255. 
@@ -237,7 +237,7 @@ class AdjustSaturationInYiqTest(test_util.TensorFlowTestCase): raise AssertionError('Invalid test style: %s' % (test_style)) y_baseline = self._adjust_saturation_in_yiq_np(x_np, scale) y_tf = self._adjust_saturation_in_yiq_tf(x_np, scale) - self.assertAllClose(y_tf, y_baseline, rtol=2e-5, atol=1e-5) + self.assertAllClose(y_tf, y_baseline, rtol=2e-4, atol=1e-4) def test_invalid_shapes(self): x_np = np.random.rand(2, 3) * 255. @@ -291,6 +291,9 @@ class AdjustHueInYiqBenchmark(test.Benchmark): def benchmark_adjust_hue_in_yiqCpuAll(self): self._benchmark_adjust_hue_in_yiq('/cpu:0', None) + def benchmark_adjust_hue_in_yiq_gpu_all(self): + self._benchmark_adjust_hue_in_yiq(test.gpu_device_name(), None) + class AdjustSaturationInYiqBenchmark(test.Benchmark): @@ -333,6 +336,9 @@ class AdjustSaturationInYiqBenchmark(test.Benchmark): def benchmark_adjust_saturation_in_yiq_cpu_all(self): self._benchmark_adjust_saturation_in_yiq('/cpu:0', None) + def benchmark_adjust_saturation_in_yiq_gpu_all(self): + self._benchmark_adjust_saturation_in_yiq(test.gpu_device_name(), None) + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/contrib/image/python/kernel_tests/segmentation_test.py b/tensorflow/contrib/image/python/kernel_tests/segmentation_test.py new file mode 100644 index 0000000000000000000000000000000000000000..48066cbacefe6b229a1f485486f11e8b8af7704f --- /dev/null +++ b/tensorflow/contrib/image/python/kernel_tests/segmentation_test.py @@ -0,0 +1,189 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for connected component analysis.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging + +import numpy as np + +from tensorflow.contrib.image.python.ops import image_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import googletest + +# Image for testing connected_components, with a single, winding component. 
+SNAKE = np.asarray( + [[0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 1, 1, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 1, 1, 1, 1, 1, 1, 1, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 1, 1, 1, 1, 1, 0], + [0, 1, 0, 0, 0, 0, 0, 1, 0], + [0, 1, 1, 1, 1, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0]]) # pyformat: disable + + +class SegmentationTest(test_util.TensorFlowTestCase): + + def testDisconnected(self): + arr = math_ops.cast( + [[1, 0, 0, 1, 0, 0, 0, 0, 1], + [0, 1, 0, 0, 0, 1, 0, 1, 0], + [1, 0, 1, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0]], + dtypes.bool) # pyformat: disable + expected = ( + [[1, 0, 0, 2, 0, 0, 0, 0, 3], + [0, 4, 0, 0, 0, 5, 0, 6, 0], + [7, 0, 8, 0, 0, 0, 9, 0, 0], + [0, 0, 0, 0, 10, 0, 0, 0, 0], + [0, 0, 11, 0, 0, 0, 0, 0, 0]]) # pyformat: disable + with self.test_session(): + self.assertAllEqual(image_ops.connected_components(arr).eval(), expected) + + def testSimple(self): + arr = [[0, 1, 0], [1, 1, 1], [0, 1, 0]] + with self.test_session(): + # Single component with id 1. + self.assertAllEqual( + image_ops.connected_components(math_ops.cast( + arr, dtypes.bool)).eval(), arr) + + def testSnake(self): + with self.test_session(): + # Single component with id 1. + self.assertAllEqual( + image_ops.connected_components(math_ops.cast( + SNAKE, dtypes.bool)).eval(), SNAKE) + + def testSnake_disconnected(self): + for i in range(SNAKE.shape[0]): + for j in range(SNAKE.shape[1]): + with self.test_session(): + # If we disconnect any part of the snake except for the endpoints, + # there will be 2 components. 
+ if SNAKE[i, j] and (i, j) not in [(1, 1), (6, 3)]: + disconnected_snake = SNAKE.copy() + disconnected_snake[i, j] = 0 + components = image_ops.connected_components( + math_ops.cast(disconnected_snake, dtypes.bool)).eval() + self.assertEqual(components.max(), 2, 'disconnect (%d, %d)' % (i, + j)) + bins = np.bincount(components.ravel()) + # Nonzero number of pixels labeled 0, 1, or 2. + self.assertGreater(bins[0], 0) + self.assertGreater(bins[1], 0) + self.assertGreater(bins[2], 0) + + def testMultipleImages(self): + images = [[[1, 1, 1, 1], + [1, 0, 0, 1], + [1, 0, 0, 1], + [1, 1, 1, 1]], + [[1, 0, 0, 1], + [0, 0, 0, 0], + [0, 0, 0, 0], + [1, 0, 0, 1]], + [[1, 1, 0, 1], + [0, 1, 1, 0], + [1, 0, 1, 0], + [0, 0, 1, 1]]] # pyformat: disable + expected = [[[1, 1, 1, 1], + [1, 0, 0, 1], + [1, 0, 0, 1], + [1, 1, 1, 1]], + [[2, 0, 0, 3], + [0, 0, 0, 0], + [0, 0, 0, 0], + [4, 0, 0, 5]], + [[6, 6, 0, 7], + [0, 6, 6, 0], + [8, 0, 6, 0], + [0, 0, 6, 6]]] # pyformat: disable + with self.test_session(): + self.assertAllEqual( + image_ops.connected_components(math_ops.cast( + images, dtypes.bool)).eval(), expected) + + def testZeros(self): + with self.test_session(): + self.assertAllEqual( + image_ops.connected_components( + array_ops.zeros((100, 20, 50), dtypes.bool)).eval(), + np.zeros((100, 20, 50))) + + def testOnes(self): + with self.test_session(): + self.assertAllEqual( + image_ops.connected_components( + array_ops.ones((100, 20, 50), dtypes.bool)).eval(), + np.tile(np.arange(100)[:, None, None] + 1, [1, 20, 50])) + + def testOnes_small(self): + with self.test_session(): + self.assertAllEqual( + image_ops.connected_components(array_ops.ones((3, 5), + dtypes.bool)).eval(), + np.ones((3, 5))) + + def testRandom_scipy(self): + np.random.seed(42) + images = np.random.randint(0, 2, size=(10, 100, 200)).astype(np.bool) + expected = connected_components_reference_implementation(images) + if expected is None: + return + with self.test_session(): + self.assertAllEqual( + 
image_ops.connected_components(images).eval(), expected) + + +def connected_components_reference_implementation(images): + try: + # pylint: disable=g-import-not-at-top + from scipy.ndimage import measurements + except ImportError: + logging.exception('Skipping test method because scipy could not be loaded') + return + image_or_images = np.asarray(images) + if len(image_or_images.shape) == 2: + images = image_or_images[None, :, :] + elif len(image_or_images.shape) == 3: + images = image_or_images + components = np.asarray([measurements.label(image)[0] for image in images]) + # Get the count of nonzero ids for each image, and offset each image's nonzero + # ids using the cumulative sum. + num_ids_per_image = components.reshape( + [-1, components.shape[1] * components.shape[2]]).max(axis=-1) + positive_id_start_per_image = np.cumsum(num_ids_per_image) + for i in range(components.shape[0]): + new_id_start = positive_id_start_per_image[i - 1] if i > 0 else 0 + components[i, components[i] > 0] += new_id_start + if len(image_or_images.shape) == 2: + return components[0, :, :] + else: + return components + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/contrib/image/python/kernel_tests/single_image_random_dot_stereograms_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/single_image_random_dot_stereograms_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3f4029e558d92a2b6539456bf9cf49ec2d21c9f3 --- /dev/null +++ b/tensorflow/contrib/image/python/kernel_tests/single_image_random_dot_stereograms_ops_test.py @@ -0,0 +1,82 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the 'License'); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an 'AS IS' BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for python single_image_random_dot_stereograms_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.image.python.ops.single_image_random_dot_stereograms \ + import single_image_random_dot_stereograms +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import test_util +from tensorflow.python.platform import googletest + +class SingleImageRandomDotStereogramsTest(test_util.TensorFlowTestCase): + + def test_shape_function_default(self): + """ + NOTE: The output_image_shape is [X, Y, C] + while the output data is [Y, X, C] (or [H, W, C]). + As a result, by default the output_image_shape has the value + of [1024, 768, 1], but the output data will be [768, 1024, 1]. + """ + x_np = [[1, 2, 3, 3, 2, 1], + [1, 2, 3, 4, 5, 2], + [1, 2, 3, 4, 5, 3], + [1, 2, 3, 4, 5, 4], + [6, 5, 4, 4, 5, 5]] + x_tf = constant_op.constant(x_np) + # By default [1024, 768, 1] => [768, 1024, 1]. + sirds_1 = single_image_random_dot_stereograms( + x_tf, + convergence_dots_size=8, + number_colors=256, + normalize=True) + shape_1 = sirds_1.get_shape().as_list() + self.assertEqual(shape_1, [768, 1024, 1]) + with self.test_session(): + r_tf_1 = sirds_1.eval() + self.assertAllEqual(shape_1, r_tf_1.shape) + + # If color > 256 then [1024, 768, 3] => [768, 1024, 3]. 
+ sirds_2 = single_image_random_dot_stereograms( + x_tf, + convergence_dots_size=8, + number_colors=512, + normalize=True) + shape_2 = sirds_2.get_shape().as_list() + self.assertEqual(shape_2, [768, 1024, 3]) + with self.test_session(): + r_tf_2 = sirds_2.eval() + self.assertAllEqual(shape_2, r_tf_2.shape) + + # If explicitly set output_image_shape to [1200, 800, 1], + # then the output data should be [800, 1200, 1]. + sirds_3 = single_image_random_dot_stereograms( + x_tf, + convergence_dots_size=8, + number_colors=256, + normalize=True, + output_image_shape=[1200, 800, 1]) + shape_3 = sirds_3.get_shape().as_list() + self.assertEqual(shape_3, [800, 1200, 1]) + with self.test_session(): + r_tf_3 = sirds_3.eval() + self.assertAllEqual(shape_3, r_tf_3.shape) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/contrib/image/python/ops/image_ops.py b/tensorflow/contrib/image/python/ops/image_ops.py index faedee6f87772016561671bacd87f88657eafffb..c139ae89d8d682d6b87813c3a21703ffa762f28e 100644 --- a/tensorflow/contrib/image/python/ops/image_ops.py +++ b/tensorflow/contrib/image/python/ops/image_ops.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import resource_loader @@ -34,11 +35,12 @@ _image_ops_so = loader.load_op_library( _IMAGE_DTYPES = set( [dtypes.uint8, dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64]) +ops.RegisterShape("ImageConnectedComponents")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("ImageProjectiveTransform")(common_shapes.call_cpp_shape_fn) def rotate(images, angles, interpolation="NEAREST", name=None): - """Rotate image(s) by the passed angle(s) in radians. 
+ """Rotate image(s) counterclockwise by the passed angle(s) in radians. Args: images: A tensor of shape (num_images, num_rows, num_columns, num_channels) @@ -288,31 +290,76 @@ def compose_transforms(*transforms): """ assert transforms, "transforms cannot be empty" with ops.name_scope("compose_transforms"): - composed = _flat_transforms_to_matrices(transforms[0]) + composed = flat_transforms_to_matrices(transforms[0]) for tr in transforms[1:]: # Multiply batches of matrices. - composed = math_ops.matmul(composed, _flat_transforms_to_matrices(tr)) - return _transform_matrices_to_flat(composed) + composed = math_ops.matmul(composed, flat_transforms_to_matrices(tr)) + return matrices_to_flat_transforms(composed) -def _flat_transforms_to_matrices(transforms): - # Make the transform(s) 2D in case the input is a single transform. - transforms = array_ops.reshape(transforms, constant_op.constant([-1, 8])) - num_transforms = array_ops.shape(transforms)[0] - # Add a column of ones for the implicit last entry in the matrix. - return array_ops.reshape( - array_ops.concat( - [transforms, array_ops.ones([num_transforms, 1])], axis=1), - constant_op.constant([-1, 3, 3])) +def flat_transforms_to_matrices(transforms): + """Converts `tf.contrib.image` projective transforms to affine matrices. + Note that the output matrices map output coordinates to input coordinates. For + the forward transformation matrix, call `tf.linalg.inv` on the result. -def _transform_matrices_to_flat(transform_matrices): - # Flatten each matrix. - transforms = array_ops.reshape(transform_matrices, - constant_op.constant([-1, 9])) - # Divide each matrix by the last entry (normally 1). - transforms /= transforms[:, 8:9] - return transforms[:, :8] + Args: + transforms: Vector of length 8, or batches of transforms with shape + `(N, 8)`. + + Returns: + 3D tensor of matrices with shape `(N, 3, 3)`. 
The output matrices map the + *output coordinates* (in homogeneous coordinates) of each transform to the + corresponding *input coordinates*. + + Raises: + ValueError: If `transforms` have an invalid shape. + """ + with ops.name_scope("flat_transforms_to_matrices"): + transforms = ops.convert_to_tensor(transforms, name="transforms") + if transforms.shape.ndims not in (1, 2): + raise ValueError("Transforms should be 1D or 2D, got: %s" % transforms) + # Make the transform(s) 2D in case the input is a single transform. + transforms = array_ops.reshape(transforms, constant_op.constant([-1, 8])) + num_transforms = array_ops.shape(transforms)[0] + # Add a column of ones for the implicit last entry in the matrix. + return array_ops.reshape( + array_ops.concat( + [transforms, array_ops.ones([num_transforms, 1])], axis=1), + constant_op.constant([-1, 3, 3])) + + +def matrices_to_flat_transforms(transform_matrices): + """Converts affine matrices to `tf.contrib.image` projective transforms. + + Note that we expect matrices that map output coordinates to input coordinates. + To convert forward transformation matrices, call `tf.linalg.inv` on the + matrices and use the result here. + + Args: + transform_matrices: One or more affine transformation matrices, for the + reverse transformation in homogeneous coordinates. Shape `(3, 3)` or + `(N, 3, 3)`. + + Returns: + 2D tensor of flat transforms with shape `(N, 8)`, which may be passed into + `tf.contrib.image.transform`. + + Raises: + ValueError: If `transform_matrices` have an invalid shape. + """ + with ops.name_scope("matrices_to_flat_transforms"): + transform_matrices = ops.convert_to_tensor( + transform_matrices, name="transform_matrices") + if transform_matrices.shape.ndims not in (2, 3): + raise ValueError( + "Matrices should be 2D or 3D, got: %s" % transform_matrices) + # Flatten each matrix. 
+ transforms = array_ops.reshape(transform_matrices, + constant_op.constant([-1, 9])) + # Divide each matrix by the last entry (normally 1). + transforms /= transforms[:, 8:9] + return transforms[:, :8] @ops.RegisterGradient("ImageProjectiveTransform") @@ -344,9 +391,9 @@ def _image_projective_transform_grad(op, grad): raise TypeError("Transforms should have rank 1 or 2.") # Invert transformations - transforms = _flat_transforms_to_matrices(transforms=transforms) + transforms = flat_transforms_to_matrices(transforms=transforms) inverse = linalg_ops.matrix_inverse(transforms) - transforms = _transform_matrices_to_flat(inverse) + transforms = matrices_to_flat_transforms(inverse) output = gen_image_ops.image_projective_transform( grad, transforms, interpolation=interpolation) if len(image_or_images.get_shape()) == 2: @@ -395,4 +442,72 @@ def bipartite_match(distance_mat, return result +def connected_components(images): + """Labels the connected components in a batch of images. + + A component is a set of pixels in a single input image, which are all adjacent + and all have the same non-zero value. The components use a squared + connectivity of one (all True entries are joined with their neighbors above, + below, left, and right). Components across all images have consecutive ids 1 + through n. Components are labeled according to the first pixel of the + component appearing in row-major order (lexicographic order by + image_index_in_batch, row, col). Zero entries all have an output id of 0. + + This op is equivalent to `scipy.ndimage.measurements.label` on a 2D array + with the default structuring element (which is the connectivity used here). + + Args: + images: A 2D (H, W) or 3D (N, H, W) Tensor of boolean image(s). + + Returns: + Components with the same shape as `images`. False entries in `images` have + value 0, and all True entries map to a component id > 0. + + Raises: + TypeError: if `images` is not 2D or 3D.
+ """ + with ops.name_scope("connected_components"): + image_or_images = ops.convert_to_tensor(images, name="images") + if len(image_or_images.get_shape()) == 2: + images = image_or_images[None, :, :] + elif len(image_or_images.get_shape()) == 3: + images = image_or_images + else: + raise TypeError( + "images should have rank 2 (HW) or 3 (NHW). Static shape is %s" % + image_or_images.get_shape()) + components = gen_image_ops.image_connected_components(images) + + # TODO(ringwalt): Component id renaming should be done in the op, to avoid + # constructing multiple additional large tensors. + components_flat = array_ops.reshape(components, [-1]) + unique_ids, id_index = array_ops.unique(components_flat) + id_is_zero = array_ops.where(math_ops.equal(unique_ids, 0))[:, 0] + # Map each nonzero id to consecutive values. + nonzero_consecutive_ids = math_ops.range( + array_ops.shape(unique_ids)[0] - array_ops.shape(id_is_zero)[0]) + 1 + + def no_zero(): + # No need to insert a zero into the ids. + return nonzero_consecutive_ids + + def has_zero(): + # Insert a zero in the consecutive ids where zero appears in unique_ids. + # id_is_zero has length 1. 
+ zero_id_ind = math_ops.to_int32(id_is_zero[0]) + ids_before = nonzero_consecutive_ids[:zero_id_ind] + ids_after = nonzero_consecutive_ids[zero_id_ind:] + return array_ops.concat([ids_before, [0], ids_after], axis=0) + + new_ids = control_flow_ops.cond( + math_ops.equal(array_ops.shape(id_is_zero)[0], 0), no_zero, has_zero) + components = array_ops.reshape( + array_ops.gather(new_ids, id_index), array_ops.shape(components)) + if len(image_or_images.get_shape()) == 2: + return components[0, :, :] + else: + return components + + ops.NotDifferentiable("BipartiteMatch") +ops.NotDifferentiable("ImageConnectedComponents") diff --git a/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py b/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py index bb766e59d2cee648042cc08be466796d9233ad66..d4a6a5bcbb52511d4093587814100b2a0e8b2420 100755 --- a/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py +++ b/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py @@ -26,18 +26,20 @@ _sirds_ops = loader.load_op_library( resource_loader.get_path_to_datafile( "_single_image_random_dot_stereograms.so")) -def single_image_random_dot_stereograms( - depth_values, - hidden_surface_removal=None, - convergence_dots_size=None, - dots_per_inch=None, - eye_separation=None, mu=None, - normalize=None, normalize_max=None, - normalize_min=None, - border_level=None, - number_colors=None, - output_image_shape=None, - output_data_window=None): + +def single_image_random_dot_stereograms(depth_values, + hidden_surface_removal=None, + convergence_dots_size=None, + dots_per_inch=None, + eye_separation=None, + mu=None, + normalize=None, + normalize_max=None, + normalize_min=None, + border_level=None, + number_colors=None, + output_image_shape=None, + output_data_window=None): """Output a RandomDotStereogram Tensor for export via encode_PNG/JPG OP. 
Given the 2-D tensor 'depth_values' with encoded Z values, this operation @@ -45,7 +47,8 @@ def single_image_random_dot_stereograms( for the encode_PNG/JPG ops. Be careful with image compression as this may corrupt the encode 3-D data witin the image. - Based upon [this paper](http://www.learningace.com/doc/4331582/b6ab058d1e206d68ab60e4e1ead2fe6e/sirds-paper). + Based upon [this + paper](http://www.learningace.com/doc/4331582/b6ab058d1e206d68ab60e4e1ead2fe6e/sirds-paper). This outputs a SIRDS image as picture_out.png: @@ -113,7 +116,8 @@ def single_image_random_dot_stereograms( hidden_surface_removal=hidden_surface_removal, convergence_dots_size=convergence_dots_size, dots_per_inch=dots_per_inch, - eye_separation=eye_separation, mu=mu, + eye_separation=eye_separation, + mu=mu, normalize=normalize, normalize_max=normalize_max, normalize_min=normalize_min, @@ -123,4 +127,5 @@ def single_image_random_dot_stereograms( output_data_window=output_data_window) return result + ops.NotDifferentiable("SingleImageRandomDotStereograms") diff --git a/tensorflow/contrib/input_pipeline/kernels/input_pipeline_kernels.cc b/tensorflow/contrib/input_pipeline/kernels/input_pipeline_kernels.cc index ca288c1f737d25faac678f5c199d5c1e49f721cb..886f6798150c57d8066546b0919481d3878882fc 100644 --- a/tensorflow/contrib/input_pipeline/kernels/input_pipeline_kernels.cc +++ b/tensorflow/contrib/input_pipeline/kernels/input_pipeline_kernels.cc @@ -34,9 +34,8 @@ class ObtainNextOp : public OpKernel { // Allocate output. Tensor* output_tensor = nullptr; - OP_REQUIRES_OK( - ctx, - ctx->allocate_output("out_element", TensorShape({}), &output_tensor)); + OP_REQUIRES_OK(ctx, ctx->allocate_output("out_element", TensorShape({}), + &output_tensor)); // Obtain mutex for the "counter" tensor. 
mutex* mu; diff --git a/tensorflow/contrib/kafka/BUILD b/tensorflow/contrib/kafka/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..efb403462a6e5df5b69ac0735ffc03f40d4a252c --- /dev/null +++ b/tensorflow/contrib/kafka/BUILD @@ -0,0 +1,105 @@ +package( + default_visibility = ["//visibility:private"], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs") +load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") +load("//tensorflow:tensorflow.bzl", "tf_kernel_library") +load("//tensorflow:tensorflow.bzl", "tf_py_test") + +tf_kernel_library( + name = "kafka_kernels", + srcs = ["kernels/kafka_dataset_ops.cc"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core/kernels:bounds_check_lib", + "//tensorflow/core/kernels:dataset", + "//third_party/eigen3", + "@kafka", + ], +) + +tf_gen_op_libs( + op_lib_names = ["kafka_ops"], + deps = [ + "//tensorflow/core:lib", + ], +) + +tf_gen_op_wrapper_py( + name = "gen_kafka_ops", + out = "python/ops/gen_kafka_ops.py", + require_shape_functions = True, + deps = [":kafka_ops_op_lib"], +) + +py_library( + name = "kafka", + srcs = [ + "__init__.py", + "python/ops/kafka_dataset_ops.py", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":gen_kafka_ops", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:iterator_ops", + "//tensorflow/python/data/ops:readers", + ], +) + +# The Kafka server has to be setup before running the test. 
+# The Kafka server is set up through Docker so the Docker engine +# has to be installed. +# +# Once the Docker engine is ready: +# To set up the Kafka server: +# $ bash tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh start kafka +# +# After the test is complete: +# To tear down the Kafka server: +# $ bash tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh stop kafka +tf_py_test( + name = "kafka_test", + srcs = ["python/kernel_tests/kafka_test.py"], + additional_deps = [ + ":kafka", + "//third_party/py/numpy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], + tags = [ + "manual", + "notap", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/ndlstm/python/__init__.py b/tensorflow/contrib/kafka/__init__.py similarity index 68% rename from tensorflow/contrib/ndlstm/python/__init__.py rename to tensorflow/contrib/kafka/__init__.py index 1aa51a6ec40c042ca3c26c6b08e5bdb8a42a12bd..4d755c40568dfa2f7f6f617cf3180268837a5ca0 100644 --- a/tensorflow/contrib/ndlstm/python/__init__.py +++ b/tensorflow/contrib/kafka/__init__.py @@ -4,7 +4,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -12,14 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Init file, giving convenient access to all ndlstm ops.""" +"""Kafka Dataset.
+ +@@KafkaDataset +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -# pylint: disable=wildcard-import,g-importing-member -from tensorflow.contrib.ndlstm.python.lstm1d import * -from tensorflow.contrib.ndlstm.python.lstm2d import * -from tensorflow.contrib.ndlstm.python.misc import * -# pylint: enable=wildcard-import +from tensorflow.contrib.kafka.python.ops.kafka_dataset_ops import KafkaDataset + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + "KafkaDataset", +] + +remove_undocumented(__name__) diff --git a/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc b/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..88ef5f357113372b0a2d0cb13382ac980a61252d --- /dev/null +++ b/tensorflow/contrib/kafka/kernels/kafka_dataset_ops.cc @@ -0,0 +1,321 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/kernels/dataset.h" + +#include "tensorflow/core/framework/tensor.h" + +#include "src-cpp/rdkafkacpp.h" + +namespace tensorflow { + +class KafkaDatasetOp : public DatasetOpKernel { + public: + using DatasetOpKernel::DatasetOpKernel; + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + const Tensor* topics_tensor; + OP_REQUIRES_OK(ctx, ctx->input("topics", &topics_tensor)); + OP_REQUIRES( + ctx, topics_tensor->dims() <= 1, + errors::InvalidArgument("`topics` must be a scalar or a vector.")); + + std::vector topics; + topics.reserve(topics_tensor->NumElements()); + for (int i = 0; i < topics_tensor->NumElements(); ++i) { + topics.push_back(topics_tensor->flat()(i)); + } + + std::string servers = ""; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "servers", &servers)); + std::string group = ""; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "group", &group)); + bool eof = false; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "eof", &eof)); + int64 timeout = -1; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "timeout", &timeout)); + OP_REQUIRES(ctx, (timeout > 0), + errors::InvalidArgument( + "Timeout value should be large than 0, got ", timeout)); + *output = new Dataset(ctx, std::move(topics), servers, group, eof, timeout); + } + + private: + class Dataset : public GraphDatasetBase { + public: + Dataset(OpKernelContext* ctx, std::vector topics, + const string& servers, const string& group, const bool eof, + const int64 timeout) + : GraphDatasetBase(ctx), + topics_(std::move(topics)), + servers_(servers), + group_(group), + eof_(eof), + timeout_(timeout) {} + + std::unique_ptr MakeIterator( + const string& prefix) const override { + return std::unique_ptr( + new Iterator({this, strings::StrCat(prefix, "::Kafka")})); + } + + const DataTypeVector& output_dtypes() const override { + static DataTypeVector* dtypes = new 
DataTypeVector({DT_STRING}); + return *dtypes; + } + + const std::vector& output_shapes() const override { + static std::vector* shapes = + new std::vector({{}}); + return *shapes; + } + + string DebugString() override { return "KafkaDatasetOp::Dataset"; } + + protected: + Status AsGraphDefInternal(DatasetGraphDefBuilder* b, + Node** output) const override { + Node* topics = nullptr; + TF_RETURN_IF_ERROR(b->AddVector(topics_, &topics)); + Node* servers = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(servers_, &servers)); + Node* group = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(group_, &group)); + Node* eof = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(eof_, &eof)); + Node* timeout = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(timeout_, &timeout)); + TF_RETURN_IF_ERROR( + b->AddDataset(this, {topics, servers, group, eof, timeout}, output)); + return Status::OK(); + } + + private: + class Iterator : public DatasetIterator { + public: + explicit Iterator(const Params& params) + : DatasetIterator(params) {} + + Status GetNextInternal(IteratorContext* ctx, + std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + do { + // We are currently processing a topic, so try to read the next line. + if (consumer_.get()) { + while (true) { + if (limit_ >= 0 && + (topic_partition_->offset() >= limit_ || offset_ >= limit_)) { + // EOF current topic + break; + } + std::unique_ptr message( + consumer_->consume(dataset()->timeout_)); + if (message->err() == RdKafka::ERR_NO_ERROR) { + // Produce the line as output. 
+ Tensor line_tensor(cpu_allocator(), DT_STRING, {}); + line_tensor.scalar()() = + std::string(static_cast(message->payload()), + message->len()); + out_tensors->emplace_back(std::move(line_tensor)); + *end_of_sequence = false; + // Sync offset + offset_ = message->offset(); + return Status::OK(); + } + + if (message->err() == RdKafka::ERR__PARTITION_EOF && + dataset()->eof_) { + // EOF current topic + break; + } + if (message->err() != RdKafka::ERR__TIMED_OUT) { + return errors::Internal("Failed to consume:", + message->errstr()); + } + message.reset(nullptr); + consumer_->poll(0); + } + + // We have reached the end of the current topic, so maybe + // move on to next topic. + ResetStreamsLocked(); + ++current_topic_index_; + } + + // Iteration ends when there are no more topics to process. + if (current_topic_index_ == dataset()->topics_.size()) { + *end_of_sequence = true; + return Status::OK(); + } + + TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); + } while (true); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_topic_index"), + current_topic_index_)); + + // `consumer_` is empty if + // 1. GetNext has not been called even once. + // 2. All topics have been read and iterator has been exhausted. + if (consumer_.get()) { + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("current_pos"), offset_)); + } + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + ResetStreamsLocked(); + int64 current_topic_index; + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_topic_index"), + &current_topic_index)); + current_topic_index_ = size_t(current_topic_index); + // The key "current_pos" is written only if the iterator was saved + // with an open topic.
+ if (reader->Contains(full_name("current_pos"))) { + int64 current_pos; + TF_RETURN_IF_ERROR( + reader->ReadScalar(full_name("current_pos"), &current_pos)); + + TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); + topic_partition_->set_offset(current_pos); + if (topic_partition_->offset() != current_pos) { + return errors::Internal("Failed to restore to offset ", + current_pos); + } + offset_ = current_pos; + } + return Status::OK(); + } + + private: + // Sets up Kafka streams to read from the topic at + // `current_topic_index_`. + Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (current_topic_index_ >= dataset()->topics_.size()) { + return errors::InvalidArgument( + "current_topic_index_:", current_topic_index_, + " >= topics_.size():", dataset()->topics_.size()); + } + + // Actually move on to next topic. + string entry = dataset()->topics_[current_topic_index_]; + + std::vector parts = str_util::Split(entry, ":"); + if (parts.size() < 1) { + return errors::InvalidArgument("Invalid parameters: ", entry); + } + string topic = parts[0]; + int32 partition = 0; + if (parts.size() > 1) { + if (!strings::safe_strto32(parts[1], &partition)) { + return errors::InvalidArgument("Invalid parameters: ", entry); + } + } + int64 offset = 0; + if (parts.size() > 2) { + if (!strings::safe_strto64(parts[2], &offset)) { + return errors::InvalidArgument("Invalid parameters: ", entry); + } + } + + topic_partition_.reset( + RdKafka::TopicPartition::create(topic, partition, offset)); + + offset_ = topic_partition_->offset(); + limit_ = -1; + if (parts.size() > 3) { + if (!strings::safe_strto64(parts[3], &limit_)) { + return errors::InvalidArgument("Invalid parameters: ", entry); + } + } + + std::unique_ptr conf( + RdKafka::Conf::create(RdKafka::Conf::CONF_GLOBAL)); + std::unique_ptr topic_conf( + RdKafka::Conf::create(RdKafka::Conf::CONF_TOPIC)); + + std::string errstr; + + RdKafka::Conf::ConfResult result = + conf->set("default_topic_conf", topic_conf.get(),
errstr); + if (result != RdKafka::Conf::CONF_OK) { + return errors::Internal("Failed to set default_topic_conf:", errstr); + } + + result = conf->set("bootstrap.servers", dataset()->servers_, errstr); + if (result != RdKafka::Conf::CONF_OK) { + return errors::Internal("Failed to set bootstrap.servers ", + dataset()->servers_, ":", errstr); + } + result = conf->set("group.id", dataset()->group_, errstr); + if (result != RdKafka::Conf::CONF_OK) { + return errors::Internal("Failed to set group.id ", dataset()->group_, + ":", errstr); + } + + consumer_.reset(RdKafka::KafkaConsumer::create(conf.get(), errstr)); + if (!consumer_.get()) { + return errors::Internal("Failed to create consumer:", errstr); + } + + std::vector partitions; + partitions.emplace_back(topic_partition_.get()); + RdKafka::ErrorCode err = consumer_->assign(partitions); + if (err != RdKafka::ERR_NO_ERROR) { + return errors::Internal( + "Failed to assign partition [", topic_partition_->topic(), ", ", + topic_partition_->partition(), ", ", topic_partition_->offset(), + "]:", RdKafka::err2str(err)); + } + + return Status::OK(); + } + + // Resets all Kafka streams. 
+ void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + consumer_->unassign(); + consumer_->close(); + consumer_.reset(nullptr); + } + + mutex mu_; + size_t current_topic_index_ GUARDED_BY(mu_) = 0; + int64 offset_ GUARDED_BY(mu_) = 0; + int64 limit_ GUARDED_BY(mu_) = -1; + std::unique_ptr topic_partition_ GUARDED_BY(mu_); + std::unique_ptr consumer_ GUARDED_BY(mu_); + }; + + const std::vector topics_; + const std::string servers_; + const std::string group_; + const bool eof_; + const int64 timeout_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("KafkaDataset").Device(DEVICE_CPU), + KafkaDatasetOp); + +} // namespace tensorflow diff --git a/tensorflow/contrib/kafka/ops/kafka_ops.cc b/tensorflow/contrib/kafka/ops/kafka_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..8cdf16103bab2b22d51c144d21a589e1e39f2f0b --- /dev/null +++ b/tensorflow/contrib/kafka/ops/kafka_ops.cc @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +REGISTER_OP("KafkaDataset") + .Input("topics: string") + .Input("servers: string") + .Input("group: string") + .Input("eof: bool") + .Input("timeout: int64") + .Output("handle: variant") + .SetIsStateful() + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"doc( +Creates a dataset that emits the messages of one or more Kafka topics. + +topics: A `tf.string` tensor containing one or more subscriptions, + in the format of [topic:partition:offset:length], + by default length is -1 for unlimited. +servers: A list of bootstrap servers. +group: The consumer group id. +eof: If True, the kafka reader will stop on EOF. +timeout: The timeout value for the Kafka Consumer to wait + (in millisecond). +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py new file mode 100644 index 0000000000000000000000000000000000000000..621911876fc502ece76b08eb6c28697b3c12c863 --- /dev/null +++ b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.py @@ -0,0 +1,115 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. 
+# ============================================================================== +"""Tests for KafkaDataset.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.kafka.python.ops import kafka_dataset_ops +from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class KafkaDatasetTest(test.TestCase): + + def setUp(self): + # The Kafka server has to be set up before the test + # and torn down after the test manually. + # The docker engine has to be installed. + # + # To set up the Kafka server: + # $ bash kafka_test.sh start kafka + # + # To tear down the Kafka server: + # $ bash kafka_test.sh stop kafka + pass + + def testKafkaDataset(self): + topics = array_ops.placeholder(dtypes.string, shape=[None]) + num_epochs = array_ops.placeholder(dtypes.int64, shape=[]) + batch_size = array_ops.placeholder(dtypes.int64, shape=[]) + + repeat_dataset = kafka_dataset_ops.KafkaDataset( + topics, group="test", eof=True).repeat(num_epochs) + batch_dataset = repeat_dataset.batch(batch_size) + + iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types) + init_op = iterator.make_initializer(repeat_dataset) + init_batch_op = iterator.make_initializer(batch_dataset) + get_next = iterator.get_next() + + with self.test_session() as sess: + # Basic test: read from topic 0. + sess.run(init_op, feed_dict={topics: ["test:0:0:4"], num_epochs: 1}) + for i in range(5): + self.assertEqual("D" + str(i), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Basic test: read from topic 1.
+ sess.run(init_op, feed_dict={topics: ["test:0:5:-1"], num_epochs: 1}) + for i in range(5): + self.assertEqual("D" + str(i + 5), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Basic test: read from both topics. + sess.run( + init_op, + feed_dict={ + topics: ["test:0:0:4", "test:0:5:-1"], + num_epochs: 1 + }) + for j in range(2): + for i in range(5): + self.assertEqual("D" + str(i + j * 5), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test repeated iteration through both files. + sess.run( + init_op, + feed_dict={ + topics: ["test:0:0:4", "test:0:5:-1"], + num_epochs: 10 + }) + for _ in range(10): + for j in range(2): + for i in range(5): + self.assertEqual("D" + str(i + j * 5), sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test batched and repeated iteration through both files. + sess.run( + init_batch_op, + feed_dict={ + topics: ["test:0:0:4", "test:0:5:-1"], + num_epochs: 10, + batch_size: 5 + }) + for _ in range(10): + self.assertAllEqual(["D" + str(i) for i in range(5)], + sess.run(get_next)) + self.assertAllEqual(["D" + str(i + 5) for i in range(5)], + sess.run(get_next)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh new file mode 100644 index 0000000000000000000000000000000000000000..adf027b8e714124cde2b4618546e20c6b7162e1f --- /dev/null +++ b/tensorflow/contrib/kafka/python/kernel_tests/kafka_test.sh @@ -0,0 +1,48 @@ +#!/usr/bin/env bash +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +set -e +set -o pipefail + +if [ "$#" -ne 2 ]; then + echo "Usage: $0 start|stop " >&2 + exit 1 +fi + +container=$2 +if [ "$1" == "start" ]; then + docker run -d --rm --net=host --name=$container spotify/kafka + echo Wait 5 secs until kafka is up and running + sleep 5 + echo Create test topic + docker exec $container bash -c '/opt/kafka_2.11-0.10.1.0/bin/kafka-topics.sh --create --zookeeper localhost:2181 --replication-factor 1 --partitions 1 --topic test' + echo Create test message + docker exec $container bash -c 'echo -e "D0\nD1\nD2\nD3\nD4\nD5\nD6\nD7\nD8\nD9" > /test' + echo Produce test message + docker exec $container bash -c '/opt/kafka_2.11-0.10.1.0/bin/kafka-console-producer.sh --topic test --broker-list 127.0.0.1:9092 < /test' + + echo Container $container started successfully +elif [ "$1" == "stop" ]; then + docker rm -f $container + + echo Container $container stopped successfully +else + echo "Usage: $0 start|stop " >&2 + exit 1 +fi + + + diff --git a/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..8e51d27a342359881de072c3979a2b5a7fc034ea --- /dev/null +++ b/tensorflow/contrib/kafka/python/ops/kafka_dataset_ops.py @@ -0,0 +1,74 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Kafka Dataset.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.kafka.python.ops import gen_kafka_ops +from tensorflow.python.data.ops.readers import Dataset +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape + + +class KafkaDataset(Dataset): + """A Kafka Dataset that consumes the message. + """ + + def __init__(self, + topics, + servers="localhost", + group="", + eof=False, + timeout=1000): + """Create a KafkaReader. + + Args: + topics: A `tf.string` tensor containing one or more subscriptions, + in the format of [topic:partition:offset:length], + by default length is -1 for unlimited. + servers: A list of bootstrap servers. + group: The consumer group id. + eof: If True, the kafka reader will stop on EOF. + timeout: The timeout value for the Kafka Consumer to wait + (in millisecond). 
+ """ + super(KafkaDataset, self).__init__() + self._topics = ops.convert_to_tensor( + topics, dtype=dtypes.string, name="topics") + self._servers = ops.convert_to_tensor( + servers, dtype=dtypes.string, name="servers") + self._group = ops.convert_to_tensor( + group, dtype=dtypes.string, name="group") + self._eof = ops.convert_to_tensor(eof, dtype=dtypes.bool, name="eof") + self._timeout = ops.convert_to_tensor( + timeout, dtype=dtypes.int64, name="timeout") + + def _as_variant_tensor(self): + return gen_kafka_ops.kafka_dataset(self._topics, self._servers, self._group, + self._eof, self._timeout) + + @property + def output_classes(self): + return ops.Tensor + + @property + def output_shapes(self): + return tensor_shape.scalar() + + @property + def output_types(self): + return dtypes.string diff --git a/tensorflow/contrib/keras/api/__init__.py b/tensorflow/contrib/keras/api/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..52e83069cb0c68b510da46149248369dce376647 100644 --- a/tensorflow/contrib/keras/api/__init__.py +++ b/tensorflow/contrib/keras/api/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function diff --git a/tensorflow/contrib/kernel_methods/BUILD b/tensorflow/contrib/kernel_methods/BUILD index a2f320ab11291e4049c8367e1f133a4fbcb72a62..eff7dfeb4c1117e40f4faf43c5e92a52cffd6528 100644 --- a/tensorflow/contrib/kernel_methods/BUILD +++ b/tensorflow/contrib/kernel_methods/BUILD @@ -83,9 +83,11 @@ py_test( srcs_version = "PY2AND3", deps = [ ":kernel_methods", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", "//tensorflow/python:framework_for_generated_wrappers", + "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/kernel_methods/python/losses.py b/tensorflow/contrib/kernel_methods/python/losses.py index 208b0e1c9dbe93fb99e17e7be5ed5b6e30f4e201..f182fef067b7f523bc5ca63227265be40528b171 100644 --- a/tensorflow/contrib/kernel_methods/python/losses.py +++ b/tensorflow/contrib/kernel_methods/python/losses.py @@ -73,13 +73,13 @@ def sparse_multiclass_hinge_loss( labels)) as scope: # Check logits Tensor has valid rank. - logits_shape = logits.get_shape() - logits_rank = logits_shape.ndims + logits_rank = logits.get_shape().ndims if logits_rank != 2: raise ValueError( 'logits should have rank 2 ([batch_size, num_classes]). Given rank is' ' {}'.format(logits_rank)) - batch_size, num_classes = logits_shape[0].value, logits_shape[1].value + logits_shape = array_ops.shape(logits) + batch_size, num_classes = logits_shape[0], logits_shape[1] logits = math_ops.to_float(logits) # Check labels have valid type. 
diff --git a/tensorflow/contrib/kernel_methods/python/losses_test.py b/tensorflow/contrib/kernel_methods/python/losses_test.py index 8a1a5ffe56ba283bfae514738fa87e4055f8934e..72507539f813d14064bc58f03b6db4781abc9438 100644 --- a/tensorflow/contrib/kernel_methods/python/losses_test.py +++ b/tensorflow/contrib/kernel_methods/python/losses_test.py @@ -18,10 +18,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.contrib.kernel_methods.python import losses from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -114,6 +117,27 @@ class SparseMulticlassHingeLossTest(test.TestCase): loss = losses.sparse_multiclass_hinge_loss(labels, logits) self.assertAlmostEqual(loss.eval(), 0.0, 3) + def testUnknownShape(self): + """Result keeps same with `testZeroLossInt32Labels`""" + logits_np = np.array([[1.2, -1.4, -1.0], [1.4, 1.8, 4.0], [0.5, 1.8, -1.0]]) + labels_np = np.array([0, 2, 1], dtype=np.int32) + + logits_shapes = [ + [3, 3], # batch_size, num_classes + [None, 3], + [3, None], + [None, None] + ] + + for batch_size, num_classes in logits_shapes: + with self.test_session(): + logits = array_ops.placeholder( + dtypes.float32, shape=(batch_size, num_classes)) + labels = array_ops.placeholder(dtypes.int32, shape=(batch_size,)) + loss = losses.sparse_multiclass_hinge_loss(labels, logits) + result = loss.eval(feed_dict={logits: logits_np, labels: labels_np}) + self.assertAlmostEqual(result, 0.0, 3) + def testCorrectPredictionsSomeClassesInsideMargin(self): """Loss is > 0 even if true class logits are higher than other classes.""" with self.test_session(): diff --git a/tensorflow/contrib/kfac/examples/convnet.py b/tensorflow/contrib/kfac/examples/convnet.py index 
558bc294bc8ac129b3055ed46623c78a0d5a33e3..39d80addaac1fe855a37255b32bf4412b99df46a 100644 --- a/tensorflow/contrib/kfac/examples/convnet.py +++ b/tensorflow/contrib/kfac/examples/convnet.py @@ -286,7 +286,7 @@ def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master, damping=0.001, layer_collection=layer_collection, momentum=0.9) - inv_update_queue = oq.OpQueue(optimizer.inv_updates_dict.values()) + inv_update_queue = oq.OpQueue(optimizer.inv_update_ops) sync_optimizer = tf.train.SyncReplicasOptimizer( opt=optimizer, replicas_to_aggregate=_num_gradient_tasks(num_worker_tasks)) diff --git a/tensorflow/contrib/kfac/examples/mlp.py b/tensorflow/contrib/kfac/examples/mlp.py index 4275ceadc210ff471109b596e1c9aa260ce31ab5..87eed03888c894a04c0521d1ce5ee8975b60776b 100644 --- a/tensorflow/contrib/kfac/examples/mlp.py +++ b/tensorflow/contrib/kfac/examples/mlp.py @@ -239,3 +239,88 @@ def train_mnist_multitower(data_dir, }) return minimize( loss, accuracy, layer_collection, session_config=session_config) + + +def train_mnist_estimator(data_dir, num_epochs, use_fake_data=False): + """Train an MLP on MNIST using tf.estimator. + + Args: + data_dir: string. Directory to read MNIST examples from. + num_epochs: int. Number of passes to make over the training set. + use_fake_data: bool. If True, generate a synthetic dataset. + + Returns: + accuracy of model on the final minibatch of training data. + """ + + # Load a dataset. + def input_fn(): + tf.logging.info("Loading MNIST into memory.") + return mnist.load_mnist( + data_dir, + num_epochs=num_epochs, + batch_size=64, + flatten_images=True, + use_fake_data=use_fake_data) + + def model_fn(features, labels, mode, params): + """Model function for MLP trained with K-FAC. + + Args: + features: Tensor of shape [batch_size, input_size]. Input features. + labels: Tensor of shape [batch_size]. Target labels for training. + mode: tf.estimator.ModeKey. Must be TRAIN. + params: ignored. 
+ + Returns: + EstimatorSpec for training. + + Raises: + ValueError: If 'mode' is anything other than TRAIN. + """ + del params + + if mode != tf.estimator.ModeKeys.TRAIN: + raise ValueError("Only training is supposed with this API.") + + # Build a ConvNet. + layer_collection = lc.LayerCollection() + loss, accuracy = build_model( + features, labels, num_labels=10, layer_collection=layer_collection) + + # Train with K-FAC. + global_step = tf.train.get_or_create_global_step() + optimizer = opt.KfacOptimizer( + learning_rate=tf.train.exponential_decay( + 0.00002, global_step, 10000, 0.5, staircase=True), + cov_ema_decay=0.95, + damping=0.0001, + layer_collection=layer_collection, + momentum=0.99) + + # Run cov_update_op every step. Run 1 inv_update_ops per step. + cov_update_op = optimizer.cov_update_op + inv_update_op = tf.group( + tf.contrib.kfac.utils.batch_execute( + global_step, optimizer.inv_update_thunks, batch_size=1)) + with tf.control_dependencies([cov_update_op, inv_update_op]): + train_op = optimizer.minimize(loss, global_step=global_step) + + # Print metrics every 5 sec. + hooks = [ + tf.train.LoggingTensorHook( + { + "loss": loss, + "accuracy": accuracy + }, every_n_secs=5), + ] + return tf.estimator.EstimatorSpec( + mode=mode, loss=loss, train_op=train_op, training_hooks=hooks) + + run_config = tf.estimator.RunConfig( + model_dir="/tmp/mnist", save_checkpoints_steps=1, keep_checkpoint_max=100) + + # Train until input_fn() is empty with Estimator. This is a prerequisite for + # TPU compatibility. 
+ estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config) + estimator.train(input_fn=input_fn) diff --git a/tensorflow/contrib/kfac/examples/mlp_mnist_main.py b/tensorflow/contrib/kfac/examples/mlp_mnist_main.py index b318c71a568be2d717745579df24134ceb3b6a0b..9c34ade1d2018135b3636fddb9dcc65839cd59de 100644 --- a/tensorflow/contrib/kfac/examples/mlp_mnist_main.py +++ b/tensorflow/contrib/kfac/examples/mlp_mnist_main.py @@ -33,7 +33,11 @@ FLAGS = None def main(argv): _ = argv - if FLAGS.num_towers > 1: + if FLAGS.use_estimator: + if FLAGS.num_towers != 1: + raise ValueError("Only 1 device supported in tf.estimator example.") + mlp.train_mnist_estimator(FLAGS.data_dir, num_epochs=200) + elif FLAGS.num_towers > 1: mlp.train_mnist_multitower( FLAGS.data_dir, num_epochs=200, num_towers=FLAGS.num_towers) else: @@ -52,5 +56,9 @@ if __name__ == "__main__": type=int, default=1, help="Number of CPUs to split minibatch across.") + parser.add_argument( + "--use_estimator", + action="store_true", + help="Use tf.estimator API to train.") FLAGS, unparsed = parser.parse_known_args() tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/kfac/examples/mnist.py b/tensorflow/contrib/kfac/examples/mnist.py index cf92c909f4b5201bc0ffda5703136f46c7058ec6..547c4ab25d589192f2a5b65987be3b05128fe298 100644 --- a/tensorflow/contrib/kfac/examples/mnist.py +++ b/tensorflow/contrib/kfac/examples/mnist.py @@ -63,7 +63,7 @@ def load_mnist(data_dir, images = mnist_data.train.images labels = mnist_data.train.labels - dataset = tf.contrib.data.Dataset.from_tensor_slices((np.asarray( + dataset = tf.data.Dataset.from_tensor_slices((np.asarray( images, dtype=np.float32), np.asarray(labels, dtype=np.int64))) return (dataset.repeat(num_epochs).shuffle(num_examples).batch(batch_size) .make_one_shot_iterator().get_next()) diff --git a/tensorflow/contrib/kfac/examples/tests/convnet_test.py b/tensorflow/contrib/kfac/examples/tests/convnet_test.py index 
3c98c54ef6cbd527aa0035e0b6f40be961c6308d..8d86c2bb5150cd4bc8a2b21ba050e904929e0fe9 100644 --- a/tensorflow/contrib/kfac/examples/tests/convnet_test.py +++ b/tensorflow/contrib/kfac/examples/tests/convnet_test.py @@ -96,7 +96,7 @@ class ConvNetTest(tf.test.TestCase): """ x = np.asarray([[1.], [2.]]).astype(np.float32) y = np.asarray([1., 2.]).astype(np.float32) - x, y = (tf.contrib.data.Dataset.from_tensor_slices((x, y)) + x, y = (tf.data.Dataset.from_tensor_slices((x, y)) .repeat(100).batch(2).make_one_shot_iterator().get_next()) w = tf.get_variable("w", shape=[1, 1], initializer=tf.zeros_initializer()) y_hat = tf.matmul(x, w) diff --git a/tensorflow/contrib/kfac/examples/tests/mlp_test.py b/tensorflow/contrib/kfac/examples/tests/mlp_test.py index 34a942d27f64e2583c686c2ba3240bc636ed918b..22da6c29f1b364d94432315988d844db9b95ec28 100644 --- a/tensorflow/contrib/kfac/examples/tests/mlp_test.py +++ b/tensorflow/contrib/kfac/examples/tests/mlp_test.py @@ -53,6 +53,11 @@ class MlpTest(tf.test.TestCase): mlp.train_mnist_multitower( data_dir=None, num_epochs=1, num_towers=2, use_fake_data=True) + def testTrainMnistEstimator(self): + with tf.Graph().as_default(): + # Ensure model training doesn't crash. 
+ mlp.train_mnist_estimator(data_dir=None, num_epochs=1, use_fake_data=True) + if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/BUILD b/tensorflow/contrib/kfac/python/kernel_tests/BUILD index 95fba59e3c96ae3c69e0b154740785b0d2bcb3c9..f4ed978174a9ddd8b54a88e60bfb48a67a2e76d2 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/BUILD +++ b/tensorflow/contrib/kfac/python/kernel_tests/BUILD @@ -17,12 +17,17 @@ py_test( "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:init_ops", + "//tensorflow/python:linalg_ops", "//tensorflow/python:math_ops", "//tensorflow/python:random_ops", + "//tensorflow/python:training", "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//third_party/py/numpy", ], ) @@ -110,12 +115,15 @@ py_test( srcs_version = "PY2AND3", deps = [ "//tensorflow/contrib/kfac/python/ops:utils", + "//tensorflow/contrib/tpu", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:linalg_ops", "//tensorflow/python:random_seed", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py index 9b28c45c7263208d21b1514ae5f05b7e81e315a3..bfdb69ad02caaa57827e0ae6b3c9fc0d0ed03754 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.contrib.kfac.python.ops import estimator from 
tensorflow.contrib.kfac.python.ops import layer_collection as lc from tensorflow.contrib.kfac.python.ops import utils @@ -25,11 +27,15 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops +from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables from tensorflow.python.platform import test +from tensorflow.python.training import training_util _ALL_ESTIMATION_MODES = ["gradients", "empirical", "curvature_prop", "exact"] @@ -119,6 +125,114 @@ class EstimatorTest(test.TestCase): estimator.FisherEstimator([self.weights], 0.1, 0.2, self.layer_collection, mode) + def test_cov_update_thunks(self): + """Ensures covariance update ops run once per global_step.""" + with self._graph.as_default(), self.test_session() as sess: + fisher_estimator = estimator.FisherEstimator( + variables=[self.weights], + layer_collection=self.layer_collection, + cov_ema_decay=0.0, + damping=0.0) + + # Construct an op that executes one covariance update per step. + global_step = training_util.get_or_create_global_step() + cov_matrices = [ + fisher_factor.get_cov() + for fisher_factor in self.layer_collection.get_factors() + ] + cov_update_op_thunks = fisher_estimator.cov_update_thunks + cov_update_op = control_flow_ops.case( + [(math_ops.equal(global_step, i), thunk) + for i, thunk in enumerate(cov_update_op_thunks)]) + increment_global_step = global_step.assign_add(1) + + sess.run(variables.global_variables_initializer()) + initial_cov_values = sess.run(cov_matrices) + + # Ensure there's one update per covariance matrix. 
+ self.assertEqual(len(cov_matrices), len(cov_update_op_thunks)) + + # Test is no-op if only 1 covariance matrix. + assert len(cov_matrices) > 1 + + for i in range(len(cov_matrices)): + # Compare new and old covariance values + new_cov_values = sess.run(cov_matrices) + is_cov_equal = [ + np.allclose(initial_cov_value, new_cov_value) + for (initial_cov_value, + new_cov_value) in zip(initial_cov_values, new_cov_values) + ] + num_cov_equal = sum(is_cov_equal) + + # Ensure exactly one covariance matrix changes per step. + self.assertEqual(num_cov_equal, len(cov_matrices) - i) + + # Run all covariance update ops. + sess.run(cov_update_op) + sess.run(increment_global_step) + + def test_inv_update_thunks(self): + """Ensures inverse update ops run once per global_step.""" + with self._graph.as_default(), self.test_session() as sess: + fisher_estimator = estimator.FisherEstimator( + variables=[self.weights], + layer_collection=self.layer_collection, + cov_ema_decay=0.0, + damping=0.0) + + # Construct op that updates one inverse per global step. + global_step = training_util.get_or_create_global_step() + inv_matrices = [ + matrix + for fisher_factor in self.layer_collection.get_factors() + for matrix in fisher_factor._inverses_by_damping.values() + ] + inv_update_op_thunks = fisher_estimator.inv_update_thunks + inv_update_op = control_flow_ops.case( + [(math_ops.equal(global_step, i), thunk) + for i, thunk in enumerate(inv_update_op_thunks)]) + increment_global_step = global_step.assign_add(1) + + sess.run(variables.global_variables_initializer()) + initial_inv_values = sess.run(inv_matrices) + + # Ensure there's one update per inverse matrix. This is true as long as + # there's no fan-in/fan-out or parameter re-use. + self.assertEqual(len(inv_matrices), len(inv_update_op_thunks)) + + # Test is no-op if only 1 invariance matrix. + assert len(inv_matrices) > 1 + + # Assign each covariance matrix a value other than the identity. 
This + # ensures that the inverse matrices are updated to something different as + # well. + cov_matrices = [ + fisher_factor.get_cov() + for fisher_factor in self.layer_collection.get_factors() + ] + sess.run([ + cov_matrix.assign(2 * linalg_ops.eye(int(cov_matrix.shape[0]))) + for cov_matrix in cov_matrices + ]) + + for i in range(len(inv_matrices)): + # Compare new and old inverse values + new_inv_values = sess.run(inv_matrices) + is_inv_equal = [ + np.allclose(initial_inv_value, new_inv_value) + for (initial_inv_value, + new_inv_value) in zip(initial_inv_values, new_inv_values) + ] + num_inv_equal = sum(is_inv_equal) + + # Ensure exactly one inverse matrix changes per step. + self.assertEqual(num_inv_equal, len(inv_matrices) - i) + + # Run all inverse update ops. + sess.run(inv_update_op) + sess.run(increment_global_step) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py index 5f2b5c6cace9cd18f4cc5590ff55a9b39680a381..82accd57f0c37d140238f1884fce956654d14227 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py @@ -40,6 +40,21 @@ def _make_psd(dim): return array_ops.constant(mat) +class UtilsTest(test.TestCase): + + def testComputePiTracenorm(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + left_factor = array_ops.diag([1., 2., 0., 1.]) + right_factor = array_ops.ones([2., 2.]) + + # pi is the sqrt of the left trace norm divided by the right trace norm + pi = fb.compute_pi_tracenorm(left_factor, right_factor) + + pi_val = sess.run(pi) + self.assertEqual(1., pi_val) + + class FullFBTest(test.TestCase): def testFullFBInitSingleTensor(self): @@ -301,8 +316,7 @@ class FullyConnectedDiagonalFB(test.TestCase): multiply_result_big, multiply_inverse_result_big = 
self.runFisherBlockOps( self.w, [self.inputs], [self.outputs], [self.output_grads]) multiply_result_small, multiply_inverse_result_small = ( - self.runFisherBlockOps(self.w, - np.split(self.inputs, 2), + self.runFisherBlockOps(self.w, np.split(self.inputs, 2), np.split(self.outputs, 2), np.split(self.output_grads, 2))) @@ -584,8 +598,7 @@ class ConvDiagonalFBTest(test.TestCase): multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps( self.w, [self.inputs], [self.outputs], [self.output_grads]) multiply_result_small, multiply_inverse_result_small = ( - self.runFisherBlockOps(self.w, - np.split(self.inputs, 2), + self.runFisherBlockOps(self.w, np.split(self.inputs, 2), np.split(self.outputs, 2), np.split(self.output_grads, 2))) @@ -608,8 +621,9 @@ class ConvDiagonalFBTest(test.TestCase): self.kernel_size, self.kernel_size, self.input_channels + 1, self.output_channels ]) - expected_result = (expected_result[:, :, 0:-1, :], np.reshape( - expected_result[:, :, -1, :], [self.output_channels])) + expected_result = (expected_result[:, :, 0:-1, :], + np.reshape(expected_result[:, :, -1, :], + [self.output_channels])) self.assertEqual(len(result), 2) self.assertAllClose(expected_result[0], result[0]) @@ -692,8 +706,8 @@ class ConvKFCBasicFBTest(test.TestCase): sess.run(block._input_factor.make_inverse_update_ops()) sess.run(block._output_factor.make_inverse_update_ops()) - vector = (np.arange(1, 15).reshape(7, 2).astype(np.float32), np.arange( - 2, 4).reshape(2, 1).astype(np.float32)) + vector = (np.arange(1, 15).reshape(7, 2).astype(np.float32), + np.arange(2, 4).reshape(2, 1).astype(np.float32)) output = block.multiply_inverse((array_ops.constant(vector[0]), array_ops.constant(vector[1]))) @@ -776,11 +790,50 @@ class ConvKFCBasicFBTest(test.TestCase): self.assertAllClose(output_flat, explicit) +class FullyConnectedSeriesFBTest(test.TestCase): + + def testFullyConnectedSeriesFBInit(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) 
+ inputs = array_ops.constant([1., 2.]) + outputs = array_ops.constant([3., 4.]) + block = fb.FullyConnectedSeriesFB( + lc.LayerCollection(), inputs=[inputs], outputs=[outputs]) + self.assertAllEqual([outputs], block.tensors_to_compute_grads()) + + def testInstantiateFactorsHasBias(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + inputs = array_ops.constant([[1., 2.], [3., 4.]]) + outputs = array_ops.constant([[3., 4.], [5., 6.]]) + block = fb.FullyConnectedSeriesFB( + lc.LayerCollection(), + inputs=[inputs], + outputs=[outputs], + has_bias=True) + grads = outputs**2 + block.instantiate_factors(((grads,),), 0.5) + + def testInstantiateFactorsNoBias(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + inputs = array_ops.constant([[1., 2.], [3., 4.]]) + outputs = array_ops.constant([[3., 4.], [5., 6.]]) + block = fb.FullyConnectedSeriesFB( + lc.LayerCollection(), + inputs=[inputs], + outputs=[outputs], + has_bias=False) + grads = outputs**2 + block.instantiate_factors(((grads,),), 0.5) + + def as_tensors(tensor_or_tuple): """Converts a potentially nested tuple of np.array to Tensors.""" if isinstance(tensor_or_tuple, (tuple, list)): return tuple(as_tensors(t) for t in tensor_or_tuple) return ops.convert_to_tensor(tensor_or_tuple) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py index 5e2ce5a3096f5b523fafad56be742154d79e4803..753378d9f4a0d8762bafbee2ec27d6c71783dda1 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py @@ -35,18 +35,27 @@ from tensorflow.python.platform import test class MaybeColocateTest(test.TestCase): + def setUp(self): + self._colocate_cov_ops_with_inputs = ff.COLOCATE_COV_OPS_WITH_INPUTS + + def tearDown(self): + ff.set_global_constants( + 
colocate_cov_ops_with_inputs=self._colocate_cov_ops_with_inputs) + def testFalse(self): + ff.set_global_constants(colocate_cov_ops_with_inputs=False) with tf_ops.Graph().as_default(): a = constant_op.constant([2.0], name='a') - with ff._maybe_colocate_with(a, False): + with ff.maybe_colocate_with(a): b = constant_op.constant(3.0, name='b') self.assertEqual([b'loc:@a'], a.op.colocation_groups()) self.assertEqual([b'loc:@b'], b.op.colocation_groups()) def testTrue(self): + ff.set_global_constants(colocate_cov_ops_with_inputs=True) with tf_ops.Graph().as_default(): a = constant_op.constant([2.0], name='a') - with ff._maybe_colocate_with(a, True): + with ff.maybe_colocate_with(a): b = constant_op.constant(3.0, name='b') self.assertEqual([b'loc:@a'], a.op.colocation_groups()) self.assertEqual([b'loc:@a'], b.op.colocation_groups()) @@ -67,12 +76,19 @@ class FisherFactorTestingDummy(ff.FisherFactor): def _num_sources(self): return 1 + @property + def _dtype(self): + return dtypes.float32 + def _compute_new_cov(self): raise NotImplementedError def instantiate_covariance(self): pass + def make_inverse_update_ops(self): + return [] + class InverseProvidingFactorTestingDummy(ff.InverseProvidingFactor): """Dummy class to test the non-abstract methods on ff.InverseProvidingFactor. @@ -94,6 +110,10 @@ class InverseProvidingFactorTestingDummy(ff.InverseProvidingFactor): def _num_sources(self): return 1 + @property + def _dtype(self): + return dtypes.float32 + def _compute_new_cov(self): raise NotImplementedError @@ -109,7 +129,7 @@ class NumericalUtilsTest(test.TestCase): random_seed.set_random_seed(200) x = npr.randn(100, 3) - cov = ff._compute_cov(array_ops.constant(x)) + cov = ff.compute_cov(array_ops.constant(x)) np_cov = np.dot(x.T, x) / x.shape[0] self.assertAllClose(sess.run(cov), np_cov) @@ -121,7 +141,7 @@ class NumericalUtilsTest(test.TestCase): normalizer = 10. 
x = npr.randn(100, 3) - cov = ff._compute_cov(array_ops.constant(x), normalizer) + cov = ff.compute_cov(array_ops.constant(x), normalizer=normalizer) np_cov = np.dot(x.T, x) / normalizer self.assertAllClose(sess.run(cov), np_cov) @@ -132,7 +152,7 @@ class NumericalUtilsTest(test.TestCase): m, n = 3, 4 a = npr.randn(m, n) - a_homog = ff._append_homog(array_ops.constant(a)) + a_homog = ff.append_homog(array_ops.constant(a)) np_result = np.hstack([a, np.ones((m, 1))]) self.assertAllClose(sess.run(a_homog), np_result) @@ -267,13 +287,13 @@ class InverseProvidingFactorTest(test.TestCase): for i in range(1, ff.EIGENVALUE_DECOMPOSITION_THRESHOLD + 1): factor.register_damped_inverse(1. / i) ops = factor.make_inverse_update_ops() - self.assertEqual(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD, len(ops)) + self.assertEqual(1, len(ops)) sess.run(tf_variables.global_variables_initializer()) new_invs = [] + sess.run(ops) for i in range(1, ff.EIGENVALUE_DECOMPOSITION_THRESHOLD + 1): # The inverse op will assign the damped inverse of cov to the inv var. - sess.run(ops[i - 1]) new_invs.append(sess.run(factor._inverses_by_damping[1. / i])) # We want to see that the new invs are all different from each other. 
for i in range(len(new_invs)): @@ -331,6 +351,16 @@ class FullFactorTest(test.TestCase): factor = ff.FullFactor((tensor,), 32) self.assertEqual([6, 6], factor.get_cov().get_shape().as_list()) + def testFullFactorInitFloat64(self): + with tf_ops.Graph().as_default(): + dtype = dtypes.float64_ref + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') + factor = ff.FullFactor((tensor,), 32) + cov = factor.get_cov() + self.assertEqual(cov.dtype, dtype) + self.assertEqual([6, 6], cov.get_shape().as_list()) + def testMakeCovarianceUpdateOp(self): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) @@ -351,6 +381,16 @@ class NaiveDiagonalFactorTest(test.TestCase): factor = ff.NaiveDiagonalFactor((tensor,), 32) self.assertEqual([6, 1], factor.get_cov().get_shape().as_list()) + def testNaiveDiagonalFactorInitFloat64(self): + with tf_ops.Graph().as_default(): + dtype = dtypes.float64_ref + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') + factor = ff.NaiveDiagonalFactor((tensor,), 32) + cov = factor.get_cov() + self.assertEqual(cov.dtype, dtype) + self.assertEqual([6, 1], cov.get_shape().as_list()) + def testMakeCovarianceUpdateOp(self): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) @@ -364,18 +404,25 @@ class NaiveDiagonalFactorTest(test.TestCase): class FullyConnectedKroneckerFactorTest(test.TestCase): - def _testFullyConnectedKroneckerFactorInit(self, has_bias, final_shape): + def _testFullyConnectedKroneckerFactorInit(self, + has_bias, + final_shape, + dtype=dtypes.float32_ref): with tf_ops.Graph().as_default(): random_seed.set_random_seed(200) - tensor = array_ops.ones((2, 3), name='a/b/c') + tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') factor = ff.FullyConnectedKroneckerFactor((tensor,), has_bias=has_bias) - self.assertEqual(final_shape, factor.get_cov().get_shape().as_list()) 
+ cov = factor.get_cov() + self.assertEqual(cov.dtype, dtype) + self.assertEqual(final_shape, cov.get_shape().as_list()) def testFullyConnectedKroneckerFactorInitNoBias(self): - self._testFullyConnectedKroneckerFactorInit(False, [3, 3]) + for dtype in (dtypes.float32_ref, dtypes.float64_ref): + self._testFullyConnectedKroneckerFactorInit(False, [3, 3], dtype=dtype) def testFullyConnectedKroneckerFactorInitWithBias(self): - self._testFullyConnectedKroneckerFactorInit(True, [4, 4]) + for dtype in (dtypes.float32_ref, dtypes.float64_ref): + self._testFullyConnectedKroneckerFactorInit(True, [4, 4], dtype=dtype) def testMakeCovarianceUpdateOpWithBias(self): with tf_ops.Graph().as_default(), self.test_session() as sess: @@ -418,6 +465,18 @@ class ConvInputKroneckerFactorTest(test.TestCase): self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1], factor.get_cov().get_shape().as_list()) + def testConvInputKroneckerFactorInitFloat64(self): + with tf_ops.Graph().as_default(): + dtype = dtypes.float64_ref + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') + factor = ff.ConvInputKroneckerFactor( + tensor, (1, 2, 3, 4), 3, 2, has_bias=True) + cov = factor.get_cov() + self.assertEqual(cov.dtype, dtype) + self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1], + cov.get_shape().as_list()) + def testMakeCovarianceUpdateOpWithBias(self): with tf_ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) @@ -453,6 +512,16 @@ class ConvOutputKroneckerFactorTest(test.TestCase): factor = ff.ConvOutputKroneckerFactor((tensor,)) self.assertEqual([5, 5], factor.get_cov().get_shape().as_list()) + def testConvOutputKroneckerFactorInitFloat64(self): + with tf_ops.Graph().as_default(): + dtype = dtypes.float64_ref + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3, 4, 5), dtype=dtype, name='a/b/c') + factor = ff.ConvOutputKroneckerFactor((tensor,)) + cov = factor.get_cov() + self.assertEqual(cov.dtype, dtype) + 
self.assertEqual([5, 5], cov.get_shape().as_list()) + def testConvOutputKroneckerFactorInitNotEnoughDims(self): with tf_ops.Graph().as_default(): random_seed.set_random_seed(200) @@ -471,5 +540,49 @@ class ConvOutputKroneckerFactorTest(test.TestCase): self.assertAllClose([[43, 46.5], [46.5, 51.5]], new_cov) +class FullyConnectedMultiKFTest(test.TestCase): + + def testFullyConnectedMultiKFInit(self): + with tf_ops.Graph().as_default(): + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), name='a/b/c') + tensor_list = [tensor] + factor = ff.FullyConnectedMultiKF((tensor_list,), has_bias=False) + self.assertEqual([3, 3], factor.get_cov().get_shape().as_list()) + + def testFullyConnectedMultiKFInitFloat64(self): + with tf_ops.Graph().as_default(): + dtype = dtypes.float64_ref + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), dtype=dtype, name='a/b/c') + tensor_list = [tensor] + factor = ff.FullyConnectedMultiKF((tensor_list,), has_bias=False) + cov = factor.get_cov() + self.assertEqual(cov.dtype, dtype) + self.assertEqual([3, 3], cov.get_shape().as_list()) + + def testMakeCovarianceUpdateOpWithBias(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') + tensor_list = [tensor] + factor = ff.FullyConnectedMultiKF((tensor_list,), has_bias=True) + + sess.run(tf_variables.global_variables_initializer()) + new_cov = sess.run(factor.make_covariance_update_op(.5)) + self.assertAllClose([[3, 3.5, 1], [3.5, 5.5, 1.5], [1, 1.5, 1]], new_cov) + + def testMakeCovarianceUpdateOpNoBias(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') + tensor_list = [tensor] + factor = ff.FullyConnectedMultiKF((tensor_list,)) + + sess.run(tf_variables.global_variables_initializer()) + new_cov = 
sess.run(factor.make_covariance_update_op(.5)) + self.assertAllClose([[3, 3.5], [3.5, 5.5]], new_cov) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py index 39ce3e9337157c8206107bc40c489e44019743ab..ae787b6f1ac90218f2ac73d37fb270df0b822de2 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py @@ -113,6 +113,113 @@ class CategoricalLogitsNegativeLogProbLossTest(test.TestCase): self.assertListEqual(loss.input_minibatches, tower_logits) self.assertEqual(loss.num_registered_minibatches, num_towers) + def testMultiplyFisherSingleVector(self): + with ops.Graph().as_default(), self.test_session() as sess: + logits = np.array([1., 2., 3.]) + loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits) + + # the LossFunction.multiply_fisher docstring only says it supports the + # case where the vector is the same shape as the input natural parameters + # (i.e. 
the logits here), but here we also test leading dimensions + vector = np.array([1., 2., 3.]) + vectors = [vector, vector.reshape(1, -1), np.stack([vector] * 4)] + + probs = np.exp(logits - np.logaddexp.reduce(logits)) + fisher = np.diag(probs) - np.outer(probs, probs) + + for vector in vectors: + result = loss.multiply_fisher(vector) + expected_result = np.dot(vector, fisher) + self.assertAllClose(expected_result, sess.run(result)) + + def testMultiplyFisherBatch(self): + with ops.Graph().as_default(), self.test_session() as sess: + logits = np.array([[1., 2., 3.], [4., 6., 8.]]) + loss = loss_functions.CategoricalLogitsNegativeLogProbLoss(logits) + + vector = np.array([[1., 2., 3.], [5., 3., 1.]]) + + na = np.newaxis + probs = np.exp(logits - np.logaddexp.reduce(logits, axis=-1, + keepdims=True)) + fishers = probs[..., na] * np.eye(3) - probs[..., na] * probs[..., na, :] + + result = loss.multiply_fisher(vector) + expected_result = np.matmul(vector[..., na, :], fishers)[..., 0, :] + self.assertEqual(sess.run(result).shape, logits.shape) + self.assertAllClose(expected_result, sess.run(result)) + + +class OnehotCategoricalLogitsNegativeLogProbLossTest(test.TestCase): + + def testSample(self): + """Ensure samples can be drawn.""" + with ops.Graph().as_default(), self.test_session() as sess: + logits = np.asarray([ + [0., 0., 0.], # + [1., -1., 0.] + ]).astype(np.float32) + loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss( + array_ops.constant(logits)) + sample = loss.sample(42) + sample = sess.run(sample) + self.assertEqual(sample.shape, (2, 3)) + + def testEvaluateOnTargets(self): + """Ensure log probability can be evaluated correctly.""" + with ops.Graph().as_default(), self.test_session() as sess: + logits = np.asarray([ + [0., 0., 0.], # + [1., -1., 0.] 
+ ]).astype(np.float32) + targets = np.asarray([2, 1]).astype(np.int32) + loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss( + array_ops.constant(logits), targets=array_ops.one_hot(targets, 3)) + neg_log_prob = loss.evaluate() + neg_log_prob = sess.run(neg_log_prob) + + # Calculate explicit log probability of targets. + probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True) + log_probs = np.log([ + probs[0, targets[0]], # + probs[1, targets[1]] + ]) + expected_log_prob = np.sum(log_probs) + + self.assertAllClose(neg_log_prob, -expected_log_prob) + + def testEvaluateOnSample(self): + """Ensure log probability of a sample can be drawn.""" + with ops.Graph().as_default(), self.test_session() as sess: + logits = np.asarray([ + [0., 0., 0.], # + [1., -1., 0.] + ]).astype(np.float32) + loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss( + array_ops.constant(logits)) + neg_log_prob = loss.evaluate_on_sample(42) + + # Simply ensure this doesn't crash. As the output is random, it's + # difficult to say if the output is correct or not... 
+ neg_log_prob = sess.run(neg_log_prob) + + def testMultiMinibatchRegistration(self): + """Ensure this loss function supports registering multiple minibatches.""" + with ops.Graph().as_default(): + tower_logits = [] + loss = None + num_towers = 5 + for _ in range(num_towers): + logits = random_ops.random_uniform(shape=[2, 3]) + tower_logits.append(logits) + if loss is None: + loss = loss_functions.OnehotCategoricalLogitsNegativeLogProbLoss( + logits) + else: + loss.register_additional_minibatch(logits) + self.assertListEqual(loss.input_minibatches, tower_logits) + self.assertEqual(loss.num_registered_minibatches, num_towers) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py index 55fe38e3e9aab2dbd70a45cdc8fa0c208b036db0..97a97adbf5577cd2694d3055acaa59258ad27964 100644 --- a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py +++ b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py @@ -22,11 +22,15 @@ import numpy as np import numpy.random as npr from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.contrib.tpu.python.tpu import tpu_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -95,6 +99,18 @@ class SubGraphTest(test.TestCase): filtered_list = sub_graph.filter_list(input_list) self.assertEqual(filtered_list, [b]) + def testVariableUses(self): + with ops.Graph().as_default(): + var = variable_scope.get_variable('var', shape=[10, 10]) + resource_var = variable_scope.get_variable( + 'resource_var', shape=[10, 10], use_resource=True) + x = 
array_ops.zeros([3, 10]) + z0 = math_ops.matmul(x, var) + math_ops.matmul(x, var) + z1 = math_ops.matmul(x, resource_var) + sub_graph = utils.SubGraph((z0, z1)) + self.assertEqual(2, sub_graph.variable_uses(var)) + self.assertEqual(1, sub_graph.variable_uses(resource_var)) + class UtilsTest(test.TestCase): @@ -222,18 +238,6 @@ class UtilsTest(test.TestCase): self.assertAllClose(b, np.array([4., 5.])) self.assertAllClose(c, np.array([[6.], [7.], [8.], [9.]])) - def testComputePi(self): - with ops.Graph().as_default(), self.test_session() as sess: - random_seed.set_random_seed(200) - left_factor = array_ops.diag([1., 2., 0., 1.]) - right_factor = array_ops.ones([2., 2.]) - - # pi is the sqrt of the left trace norm divided by the right trace norm - pi = utils.compute_pi(left_factor, right_factor) - - pi_val = sess.run(pi) - self.assertEqual(1., pi_val) - def testPosDefInvCholesky(self): with ops.Graph().as_default(), self.test_session() as sess: random_seed.set_random_seed(200) @@ -265,6 +269,62 @@ class UtilsTest(test.TestCase): np_inv = np.linalg.inv(x + damp * np.eye(size)) self.assertAllClose(sess.run(tf_inv), np_inv) + def testCrossReplicaMean(self): + """Ensures that cross_replica_mean() executes only when num_shards > 1.""" + with ops.Graph().as_default(): + with tpu_function.tpu_shard_context(4): + tensor = array_ops.zeros([], dtype=dtypes.float32) + mean = utils.cross_replica_mean(tensor) + self.assertNotEqual(mean, tensor) + + with ops.Graph().as_default(): + with tpu_function.tpu_shard_context(1): + tensor = array_ops.zeros([], dtype=dtypes.float32) + mean = utils.cross_replica_mean(tensor) + self.assertEqual(mean, tensor) + + with ops.Graph().as_default(): + with self.assertRaises(ValueError): # Outside of TPU context. 
+ tensor = array_ops.zeros([], dtype=dtypes.float32) + mean = utils.cross_replica_mean(tensor) + + def testBatchExecute(self): + """Ensure batch_execute runs in a round-robin fashion.""" + + def increment_var(var): + return lambda: var.assign_add(1) + + with ops.Graph().as_default(), self.test_session() as sess: + i = variable_scope.get_variable('i', initializer=0) + accumulators = [ + variable_scope.get_variable('var%d' % j, initializer=0) + for j in range(3) + ] + thunks = [increment_var(var) for var in accumulators] + increment_accumulators = utils.batch_execute(i, thunks, 2) + increment_i = i.assign_add(1) + + sess.run(variables.global_variables_initializer()) + + # Ensure one op per thunk. + self.assertEqual(3, len(increment_accumulators)) + + # Ensure round-robin execution. + values = [] + for _ in range(5): + sess.run(increment_accumulators) + sess.run(increment_i) + values.append(sess.run(accumulators)) + self.assertAllClose( + [ + [1, 1, 0], # + [2, 1, 1], # + [2, 2, 2], # + [3, 3, 2], # + [4, 3, 3] + ], + values) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/kfac/python/ops/BUILD b/tensorflow/contrib/kfac/python/ops/BUILD index b2272a4cee09b35ff672514077b4b128b870b772..ee6549b109399766579b6ea18a987ae2c8275983 100644 --- a/tensorflow/contrib/kfac/python/ops/BUILD +++ b/tensorflow/contrib/kfac/python/ops/BUILD @@ -38,6 +38,7 @@ py_library( ":utils", "//tensorflow/python:array_ops", "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", "//tensorflow/python:linalg_ops", "//tensorflow/python:math_ops", "//tensorflow/python:special_math_ops", @@ -64,6 +65,7 @@ py_library( srcs = ["loss_functions.py"], srcs_version = "PY2AND3", deps = [ + "//tensorflow/contrib/distributions:distributions_py", "//tensorflow/python:array_ops", "//tensorflow/python:math_ops", "//tensorflow/python:tensor_shape", @@ -195,7 +197,9 @@ py_library( srcs = ["utils.py"], srcs_version = "PY2AND3", deps = [ + "//tensorflow/contrib/tpu", 
"//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:gradients", diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py index 27ff951f16112e09b82ac6885072d966de09983f..a7b1f9d35c931fc44408be804479e758f28f7110 100644 --- a/tensorflow/contrib/kfac/python/ops/estimator.py +++ b/tensorflow/contrib/kfac/python/ops/estimator.py @@ -20,7 +20,6 @@ from __future__ import print_function import contextlib import itertools -import math import numpy as np @@ -67,7 +66,21 @@ class _DeviceContextGenerator(object): class FisherEstimator(object): - """Fisher estimator class supporting various approximations of the Fisher.""" + """Fisher estimator class supporting various approximations of the Fisher. + + Attributes: + cov_update_thunks: list of no-arg functions. Executing a function adds + covariance update ops for a single FisherFactor to the graph. + cov_update_ops: List of Ops. Running an op updates covariance matrices for a + single FisherFactor. + cov_update_op: Op. Running updates covariance matrices for all + FisherFactors. + inv_update_thunks: list of no-arg functions. Executing a function adds + inverse update ops for a single FisherFactor to the graph. + inv_update_ops: List of Ops. Running an op updates inverse matrices for a + single FisherFactor. + inv_update_op: Op. Running updates inverse matrices for all FisherFactors. + """ def __init__(self, variables, @@ -75,7 +88,7 @@ class FisherEstimator(object): damping, layer_collection, estimation_mode="gradients", - colocate_gradients_with_ops=False, + colocate_gradients_with_ops=True, cov_devices=None, inv_devices=None): """Create a FisherEstimator object. @@ -111,7 +124,7 @@ class FisherEstimator(object): is more expensive to compute than the other three options by a factor equal to the output dimension, roughly speaking. 
colocate_gradients_with_ops: Whether we should request gradients be - colocated with their respective ops. + colocated with their respective ops. (Default: True) cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance computations will be placed on these devices in a round-robin fashion. Can be None, which means that no devices are specified. @@ -123,12 +136,13 @@ class FisherEstimator(object): ValueError: If no losses have been registered with layer_collection. """ + self._cov_ema_decay = cov_ema_decay self._variables = variables self._damping = damping self._estimation_mode = estimation_mode self._layers = layer_collection self._layers.create_subgraph() - self._check_registration(variables) + self._layers.check_registration(variables) self._gradient_fns = { "gradients": self._get_grads_lists_gradients, "empirical": self._get_grads_lists_empirical, @@ -136,13 +150,31 @@ class FisherEstimator(object): "exact": self._get_grads_lists_exact } self._colocate_gradients_with_ops = colocate_gradients_with_ops + + # TODO(b/70674513): Factor device placement outside of this class. 
self._cov_device_context_generator = _DeviceContextGenerator(cov_devices) if inv_devices == cov_devices: self._inv_device_context_generator = self._cov_device_context_generator else: self._inv_device_context_generator = _DeviceContextGenerator(inv_devices) - setup = self._setup(cov_ema_decay) - self.cov_update_op, self.inv_update_op, self.inv_updates_dict = setup + + self._instantiate_factors() + + self.cov_update_thunks = [ + self._create_cov_update_thunk(factor) + for factor in self._layers.get_factors() + ] + self.cov_update_ops = [thunk() for thunk in self.cov_update_thunks] + self.cov_update_op = control_flow_ops.group( + self.cov_update_ops, name="cov_update_op") + + self.inv_update_thunks = [ + self._create_inv_update_thunk(factor) + for factor in self._layers.get_factors() + ] + self.inv_update_ops = [thunk() for thunk in self.inv_update_thunks] + self.inv_update_op = control_flow_ops.group( + self.inv_update_ops, name="inv_update_op") @property def variables(self): @@ -203,61 +235,8 @@ class FisherEstimator(object): return self._apply_transformation(vecs_and_vars, lambda fb, vec: fb.multiply(vec)) - def _check_registration(self, variables): - """Checks that all variable uses have been registered properly. - - Args: - variables: List of variables. - - Raises: - ValueError: If any registered variables are not included in the list. - ValueError: If any variable in the list is not registered. - ValueError: If any variable in the list is registered with the wrong - number of "uses" in the subgraph recorded (vs the number of times that - variable is actually used in the subgraph). - """ - # Note that overlapping parameters (i.e. those that share variables) will - # be caught by layer_collection.LayerParametersDict during registration. 
- - reg_use_map = self._layers.get_use_count_map() - - error_messages = [] - - for var in variables: - total_uses = self._layers.subgraph.variable_uses(var) - reg_uses = reg_use_map[var] - - if reg_uses == 0: - error_messages.append("Variable {} not registered.".format(var)) - elif (not math.isinf(reg_uses)) and reg_uses != total_uses: - error_messages.append( - "Variable {} registered with wrong number of uses ({} " - "registrations vs {} uses).".format(var, reg_uses, total_uses)) - - num_get_vars = len(reg_use_map) - - if num_get_vars > len(variables): - error_messages.append("{} registered variables were not included in list." - .format(num_get_vars - len(variables))) - - if error_messages: - error_messages = [ - "Found the following errors with variable registration:" - ] + error_messages - raise ValueError("\n\t".join(error_messages)) - - def _setup(self, cov_ema_decay): - """Sets up the various operations. - - Args: - cov_ema_decay: The decay factor used when calculating the covariance - estimate moving averages. - - Returns: - A triple (covs_update_op, invs_update_op, inv_updates_dict), where - covs_update_op is the grouped Op to update all the covariance estimates, - invs_update_op is the grouped Op to update all the inverses, and - inv_updates_dict is a dict mapping Op names to individual inverse updates. + def _instantiate_factors(self): + """Instantiates FisherFactors' variables. Raises: ValueError: If estimation_mode was improperly specified at construction. 
@@ -282,20 +261,25 @@ class FisherEstimator(object): with self._cov_device_context_generator(): fb.instantiate_factors(grads_list, self.damping) - cov_updates = [ - factor.make_covariance_update_op(cov_ema_decay) - for factor in self._layers.get_factors() - ] - inv_updates = {op.name: op for op in self._get_all_inverse_update_ops()} + def _create_cov_update_thunk(self, factor): + """Constructs a covariance update thunk for a single FisherFactor.""" + + def thunk(): + with tf_ops.name_scope( + "create_cov_update_thunk", values=[self._cov_ema_decay]): + return factor.make_covariance_update_op(self._cov_ema_decay) + + return thunk - return control_flow_ops.group(*cov_updates), control_flow_ops.group( - *inv_updates.values()), inv_updates + def _create_inv_update_thunk(self, factor): + """Constructs an inverse update thunk for a single FisherFactor.""" - def _get_all_inverse_update_ops(self): - for factor in self._layers.get_factors(): - with self._inv_device_context_generator(): - for op in factor.make_inverse_update_ops(): - yield op + def thunk(): + with tf_ops.name_scope("create_inv_update_thunk"): + with self._inv_device_context_generator(): + return control_flow_ops.group(factor.make_inverse_update_ops()) + + return thunk def _get_grads_lists_gradients(self, tensors): grads_flat = gradients_impl.gradients( @@ -333,11 +317,7 @@ class FisherEstimator(object): return tuple((grad,) for grad in grads_all) def _get_grads_lists_exact(self, tensors): - """Returns a list of all gradients, computing them exactly. - - Args: - tensors: Tensors for which to compute gradients. - """ + """No docstring required.""" # Loop over all coordinates of all losses. 
grads_all = [] for loss in self._layers.losses: diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py index e822a1213a4132522be8031401609c78572cb1a6..0d2fa706f5853570bb8c04a9b9ac3378e2f2386e 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py @@ -38,6 +38,7 @@ from __future__ import division from __future__ import print_function import abc +import enum # pylint: disable=g-bad-import-order import six @@ -52,14 +53,61 @@ from tensorflow.python.ops import math_ops # damping /= num_replications ** NORMALIZE_DAMPING_POWER NORMALIZE_DAMPING_POWER = 1.0 +# Methods for adjusting damping for FisherBlocks. See +# compute_pi_adjusted_damping() for details. +PI_OFF_NAME = "off" +PI_TRACENORM_NAME = "tracenorm" +PI_TYPE = PI_TRACENORM_NAME -def set_global_constants(normalize_damping_power=None): + +def set_global_constants(normalize_damping_power=None, pi_type=None): """Sets various global constants used by the classes in this module.""" global NORMALIZE_DAMPING_POWER + global PI_TYPE if normalize_damping_power is not None: NORMALIZE_DAMPING_POWER = normalize_damping_power + if pi_type is not None: + PI_TYPE = pi_type + + +def normalize_damping(damping, num_replications): + """Normalize damping after adjusting scale by NORMALIZE_DAMPING_POWER.""" + if NORMALIZE_DAMPING_POWER: + return damping / (num_replications ** NORMALIZE_DAMPING_POWER) + return damping + + +def compute_pi_tracenorm(left_cov, right_cov): + """Computes the scalar constant pi for Tikhonov regularization/damping. + + pi = sqrt( (trace(A) / dim(A)) / (trace(B) / dim(B)) ) + See section 6.3 of https://arxiv.org/pdf/1503.05671.pdf for details. + + Args: + left_cov: The left Kronecker factor "covariance". + right_cov: The right Kronecker factor "covariance". + + Returns: + The computed scalar constant pi for these Kronecker Factors (as a Tensor). 
+ """ + # Instead of dividing by the dim of the norm, we multiply by the dim of the + # other norm. This works out the same in the ratio. + left_norm = math_ops.trace(left_cov) * right_cov.shape.as_list()[0] + right_norm = math_ops.trace(right_cov) * left_cov.shape.as_list()[0] + return math_ops.sqrt(left_norm / right_norm) + + +def compute_pi_adjusted_damping(left_cov, right_cov, damping): + + if PI_TYPE == PI_TRACENORM_NAME: + pi = compute_pi_tracenorm(left_cov, right_cov) + return (damping * pi, damping / pi) + + elif PI_TYPE == PI_OFF_NAME: + return (damping, damping) + @six.add_metaclass(abc.ABCMeta) class FisherBlock(object): @@ -153,7 +201,7 @@ class FullFB(FisherBlock): self._factor.register_damped_inverse(damping) def multiply_inverse(self, vector): - inverse = self._factor.get_inverse(self._damping) + inverse = self._factor.get_damped_inverse(self._damping) out_flat = math_ops.matmul(inverse, utils.tensors_to_column(vector)) return utils.column_to_tensors(vector, out_flat) @@ -410,9 +458,8 @@ class ConvDiagonalFB(FisherBlock): inputs_shape[1] * inputs_shape[2] // (self._strides[1] * self._strides[2])) - if NORMALIZE_DAMPING_POWER: - damping /= self._num_locations ** NORMALIZE_DAMPING_POWER - self._damping = damping + self._damping = (self._num_locations + * normalize_damping(damping, self._num_locations)) self._factor = self._layer_collection.make_or_get_factor( fisher_factors.ConvDiagonalFactor, @@ -465,11 +512,10 @@ class KroneckerProductFB(FisherBlock): Args: damping: The base damping factor (float or Tensor) for the damped inverse. 
""" - pi = utils.compute_pi(self._input_factor.get_cov(), - self._output_factor.get_cov()) - - self._input_damping = (damping**0.5) * pi - self._output_damping = (damping**0.5) / pi + self._input_damping, self._output_damping = compute_pi_adjusted_damping( + self._input_factor.get_cov(), + self._output_factor.get_cov(), + damping**0.5) self._input_factor.register_damped_inverse(self._input_damping) self._output_factor.register_damped_inverse(self._output_damping) @@ -487,8 +533,9 @@ class KroneckerProductFB(FisherBlock): return 1.0 def multiply_inverse(self, vector): - left_factor_inv = self._input_factor.get_inverse(self._input_damping) - right_factor_inv = self._output_factor.get_inverse(self._output_damping) + left_factor_inv = self._input_factor.get_damped_inverse(self._input_damping) + right_factor_inv = self._output_factor.get_damped_inverse( + self._output_damping) reshaped_vector = utils.layer_params_to_mat2d(vector) reshaped_out = math_ops.matmul(left_factor_inv, math_ops.matmul(reshaped_vector, @@ -650,8 +697,8 @@ class ConvKFCBasicFB(KroneckerProductFB): grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list) # Infer number of locations upon which convolution is applied. 
- self._num_locations = _num_conv_locations(inputs.shape.as_list(), - self._strides) + self._num_locations = num_conv_locations(inputs.shape.as_list(), + self._strides) self._input_factor = self._layer_collection.make_or_get_factor( fisher_factors.ConvInputKroneckerFactor, @@ -660,11 +707,9 @@ class ConvKFCBasicFB(KroneckerProductFB): self._output_factor = self._layer_collection.make_or_get_factor( fisher_factors.ConvOutputKroneckerFactor, (grads_list,)) - if NORMALIZE_DAMPING_POWER: - damping /= self._num_locations**NORMALIZE_DAMPING_POWER - self._damping = damping - + damping = normalize_damping(damping, self._num_locations) self._register_damped_input_and_output_inverses(damping) + self._damping = damping @property def _renorm_coeff(self): @@ -717,6 +762,267 @@ def _concat_along_batch_dim(tensor_list): return array_ops.concat(tensor_list, axis=0) -def _num_conv_locations(input_shape, strides): - """Returns the number of locations a Conv kernel is applied to.""" +def num_conv_locations(input_shape, strides): + """Returns the number of spatial locations a 2D Conv kernel is applied to. + + Args: + input_shape: list representing shape of inputs to the Conv layer. + strides: list representing strides for the Conv kernel. + + Returns: + A scalar |T| denoting the number of spatial locations for the Conv layer. + """ return input_shape[1] * input_shape[2] // (strides[1] * strides[2]) + + +class FullyConnectedMultiIndepFB(KroneckerProductFB): + """FisherBlock for fully-connected layers that share parameters. + """ + + def __init__(self, layer_collection, inputs, outputs, has_bias=False): + """Creates a FullyConnectedMultiIndepFB block. + + Args: + layer_collection: LayerCollection instance. + inputs: list or tuple of Tensors. Each Tensor has shape [batch_size, + inputs_size]. + outputs: list or tuple of Tensors. Each Tensor has shape [batch_size, + outputs_size]. + has_bias: bool. 
If True, estimates Fisher with respect to a bias + parameter as well as the layer's parameters. + """ + + assert len(inputs) == len(outputs) + # We need to make sure inputs and outputs are tuples and not lists so that + # they get hashed by layer_collection.make_or_get_factor properly. + self._inputs = tuple(inputs) + self._outputs = tuple(outputs) + self._has_bias = has_bias + self._num_uses = len(inputs) + + super(FullyConnectedMultiIndepFB, self).__init__(layer_collection) + + @property + def num_registered_minibatches(self): + # TODO(b/69411207): Add support for registering additional minibatches. + return 1 + + def instantiate_factors(self, grads_list, damping): + + self._input_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedMultiKF, + ((self._inputs,), self._has_bias)) + + self._output_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedMultiKF, (grads_list,)) + + damping = normalize_damping(damping, self._num_uses) + self._register_damped_input_and_output_inverses(damping) + + @property + def _renorm_coeff(self): + return self._num_uses + + def tensors_to_compute_grads(self): + return self._outputs + + def num_inputs(self): + return len(self._inputs) + + +class SeriesFBApproximation(enum.IntEnum): + """See FullyConnectedSeriesFB.__init__ for description and usage.""" + option1 = 1 + option2 = 2 + + +class FullyConnectedSeriesFB(FisherBlock): + """FisherBlock for fully-connected layers that share parameters across time. + + See the following preprint for details: + https://openreview.net/pdf?id=HyMTkQZAb + + See the end of the appendix of the paper for a pseudo-code of the + algorithm being implemented by multiply_inverse here. Note that we are + using pre-computed versions of certain matrix-matrix products to speed + things up. This is explicitly explained wherever it is done. 
+ """ + + def __init__(self, + layer_collection, + inputs, + outputs, + has_bias=False, + option=SeriesFBApproximation.option2): + """Constructs a new `FullyConnectedSeriesFB`. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + inputs: List of tensors of shape [batch_size, input_size]. + Inputs to the layer. + outputs: List of tensors of shape [batch_size, input_size]. + Outputs of the layer (before activations). + has_bias: Whether the layer includes a bias parameter. + option: A `SeriesFBApproximation` specifying the simplifying assumption + to be used in this block. `option1` approximates the cross-covariance + over time as a symmetric matrix, while `option2` makes + the assumption that training sequences are infinitely long. See section + 3.5 of the paper for more details. + """ + + assert len(inputs) == len(outputs) + # We need to make sure inputs and outputs are tuples and not lists so that + # they get hashed by layer_collection.make_or_get_factor properly. + self._inputs = tuple(inputs) + self._outputs = tuple(outputs) + self._has_bias = has_bias + self._num_timesteps = len(inputs) + self._option = option + + super(FullyConnectedSeriesFB, self).__init__(layer_collection) + + @property + def num_registered_minibatches(self): + # TODO(b/69411207): Add support for registering additional minibatches. 
+ return 1 + + def instantiate_factors(self, grads_list, damping): + + self._input_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedMultiKF, ((self._inputs,), self._has_bias)) + + self._output_factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedMultiKF, (grads_list,)) + + damping = normalize_damping(damping, self._num_timesteps) + self._damping_input, self._damping_output = compute_pi_adjusted_damping( + self._input_factor.get_cov(), + self._output_factor.get_cov(), + damping**0.5) + + if self._option == SeriesFBApproximation.option1: + self._input_factor.register_option1quants(self._damping_input) + self._output_factor.register_option1quants(self._damping_output) + elif self._option == SeriesFBApproximation.option2: + self._input_factor.register_option2quants(self._damping_input) + self._output_factor.register_option2quants(self._damping_output) + else: + raise ValueError( + "Unrecognized FullyConnectedSeriesFB approximation: {}".format( + self._option)) + + def multiply_inverse(self, vector): + # pylint: disable=invalid-name + + Z = utils.layer_params_to_mat2d(vector) + + # Derivations were done for "batch_dim==1" case so we need to convert to + # that orientation: + Z = array_ops.transpose(Z) + + if self._option == SeriesFBApproximation.option1: + + # Note that L_A = A0^(-1/2) * U_A and L_G = G0^(-1/2) * U_G. + L_A, psi_A = self._input_factor.get_option1quants(self._damping_input) + L_G, psi_G = self._output_factor.get_option1quants(self._damping_output) + + def gamma(x): + # We are assuming that each case has the same number of time-steps. + # If this stops being the case one shouldn't simply replace this T + # with its average value. Instead, one needs to go back to the + # definition of the gamma function from the paper. 
+ T = self._num_timesteps + return (1 - x)**2 / (T * (1 - x**2) - 2 * x * (1 - x**T)) + + # Y = gamma( psi_G*psi_A^T ) (computed element-wise) + # Even though Y is Z-independent we are recomputing it from the psi's + # each time since Y depends on both A and G quantities, and it is relatively + # cheap to compute. + Y = gamma(array_ops.reshape(psi_G, [int(psi_G.shape[0]), -1]) * psi_A) + + # Z = L_G^T * Z * L_A + # This is equivalent to the following computation from the original + # pseudo-code: + # Z = G0^(-1/2) * Z * A0^(-1/2) + # Z = U_G^T * Z * U_A + Z = math_ops.matmul(L_G, math_ops.matmul(Z, L_A), transpose_a=True) + + # Z = Z .* Y + Z *= Y + + # Z = L_G * Z * L_A^T + # This is equivalent to the following computation from the original + # pseudo-code: + # Z = U_G * Z * U_A^T + # Z = G0^(-1/2) * Z * A0^(-1/2) + Z = math_ops.matmul(L_G, math_ops.matmul(Z, L_A, transpose_b=True)) + + elif self._option == SeriesFBApproximation.option2: + + # Note that P_A = A_1^T * A_0^(-1) and P_G = G_1^T * G_0^(-1), + # and K_A = A_0^(-1/2) * E_A and K_G = G_0^(-1/2) * E_G. + P_A, K_A, mu_A = self._input_factor.get_option2quants(self._damping_input) + P_G, K_G, mu_G = self._output_factor.get_option2quants( + self._damping_output) + + # Our approach differs superficially from the pseudo-code in the paper + # in order to reduce the total number of matrix-matrix multiplies.
+ # In particular, the first three computations in the pseudo code are + # Z = G0^(-1/2) * Z * A0^(-1/2) + # Z = Z - hPsi_G^T * Z * hPsi_A + # Z = E_G^T * Z * E_A + # Noting that hPsi = C0^(-1/2) * C1 * C0^(-1/2), so that + # C0^(-1/2) * hPsi = C0^(-1) * C1 * C0^(-1/2) = P^T * C0^(-1/2) + # the entire computation can be written as + # Z = E_G^T * (G0^(-1/2) * Z * A0^(-1/2) + # - hPsi_G^T * G0^(-1/2) * Z * A0^(-1/2) * hPsi_A) * E_A + # = E_G^T * (G0^(-1/2) * Z * A0^(-1/2) + # - G0^(-1/2) * P_G * Z * P_A^T * A0^(-1/2)) * E_A + # = E_G^T * G0^(-1/2) * Z * A0^(-1/2) * E_A + # - E_G^T* G0^(-1/2) * P_G * Z * P_A^T * A0^(-1/2) * E_A + # = K_G^T * Z * K_A - K_G^T * P_G * Z * P_A^T * K_A + # This final expression is computed by the following two lines: + # Z = Z - P_G * Z * P_A^T + Z -= math_ops.matmul(P_G, math_ops.matmul(Z, P_A, transpose_b=True)) + # Z = K_G^T * Z * K_A + Z = math_ops.matmul(K_G, math_ops.matmul(Z, K_A), transpose_a=True) + + # Z = Z ./ (1*1^T - mu_G*mu_A^T) + # Be careful with the outer product. We don't want to accidentally + # make it an inner-product instead. + tmp = 1.0 - array_ops.reshape(mu_G, [int(mu_G.shape[0]), -1]) * mu_A + # Prevent some numerical issues by setting any 0.0 eigs to 1.0 + tmp += 1.0 * math_ops.cast(math_ops.equal(tmp, 0.0), dtype=tmp.dtype) + Z /= tmp + + # We now perform the transpose/reverse version of the operations + # derived above, whose derivation from the original pseudo-code is + # analogous. + # Z = K_G * Z * K_A^T + Z = math_ops.matmul(K_G, math_ops.matmul(Z, K_A, transpose_b=True)) + + # Z = Z - P_G^T * Z * P_A + Z -= math_ops.matmul(P_G, math_ops.matmul(Z, P_A), transpose_a=True) + + # Z = normalize (1/E[T]) * Z + # Note that this normalization is done because we compute the statistics + # by averaging, not summing, over time. (And the gradient is presumably + # summed over time, not averaged, and thus their scales are different.)
+ Z /= math_ops.cast(self._num_timesteps, Z.dtype) + + # Convert back to the "batch_dim==0" orientation. + Z = array_ops.transpose(Z) + + return utils.mat2d_to_layer_params(vector, Z) + + # pylint: enable=invalid-name + + def multiply(self, vector): + raise NotImplementedError + + def tensors_to_compute_grads(self): + return self._outputs + + def num_inputs(self): + return len(self._inputs) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py index 59389f8d385c18f50914d690cfaa2825ef807ed3..ac396309206fe09af65c2b70840a513fb25b579b 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py @@ -33,6 +33,10 @@ _allowed_symbols = [ 'ConvKFCBasicFB', 'ConvDiagonalFB', 'set_global_constants', + 'compute_pi_tracenorm', + 'compute_pi_adjusted_damping', + 'num_conv_locations', + 'normalize_damping' ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py index fbc192f1dcfa0b384e2cb31c43af3651436321ea..bcba18ae147c6ceca50bc9a2a17e01fc201d88c1 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py @@ -27,8 +27,11 @@ import six from tensorflow.contrib.kfac.python.ops import utils from tensorflow.python.framework import ops as tf_ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn from tensorflow.python.ops import special_math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables @@ -50,11 +53,15 @@ EIGENVALUE_DECOMPOSITION_THRESHOLD = 2 # matrix powers. Must be nonnegative. 
EIGENVALUE_CLIPPING_THRESHOLD = 0.0 +# Colocate the covariance ops and variables with the input tensors for each +# factor. +COLOCATE_COV_OPS_WITH_INPUTS = True + @contextlib.contextmanager -def _maybe_colocate_with(op, colocate_cov_ops_with_inputs): - """Context to colocate with `op` if `colocate_cov_ops_with_inputs`.""" - if colocate_cov_ops_with_inputs: +def maybe_colocate_with(op): + """Context to colocate with `op` if `COLOCATE_COV_OPS_WITH_INPUTS`.""" + if COLOCATE_COV_OPS_WITH_INPUTS: if isinstance(op, (list, tuple)): with tf_ops.colocate_with(op[0]): yield @@ -68,12 +75,14 @@ def _maybe_colocate_with(op, colocate_cov_ops_with_inputs): def set_global_constants(init_covariances_at_zero=None, zero_debias=None, eigenvalue_decomposition_threshold=None, - eigenvalue_clipping_threshold=None): + eigenvalue_clipping_threshold=None, + colocate_cov_ops_with_inputs=None): """Sets various global constants used by the classes in this module.""" global INIT_COVARIANCES_AT_ZERO global ZERO_DEBIAS global EIGENVALUE_DECOMPOSITION_THRESHOLD global EIGENVALUE_CLIPPING_THRESHOLD + global COLOCATE_COV_OPS_WITH_INPUTS if init_covariances_at_zero is not None: INIT_COVARIANCES_AT_ZERO = init_covariances_at_zero @@ -83,6 +92,8 @@ def set_global_constants(init_covariances_at_zero=None, EIGENVALUE_DECOMPOSITION_THRESHOLD = eigenvalue_decomposition_threshold if eigenvalue_clipping_threshold is not None: EIGENVALUE_CLIPPING_THRESHOLD = eigenvalue_clipping_threshold + if colocate_cov_ops_with_inputs is not None: + COLOCATE_COV_OPS_WITH_INPUTS = colocate_cov_ops_with_inputs def inverse_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument @@ -101,7 +112,55 @@ def diagonal_covariance_initializer(shape, dtype, partition_info): # pylint: di return array_ops.ones(shape, dtype) -def _compute_cov(tensor, normalizer=None): +def extract_image_patches(image, ksizes, strides, padding, name=None): + """Extracts image patches for an N-dimensional convolution. 
+ + This function is a compatibility wrapper over tf.extract_image_patches(), as + ExtractImagePatches isn't yet implemented in XLA. + + Args: + image: Tensor of shape [batch, in_x, in_y, ..., in_channels]. Input images. + All dimensions except 'batch' must be defined. + ksizes: [filter_x, filter_y, ...]. Spatial shape of filter in each + dimension. + strides: [stride_x, stride_y, ...]. Spatial stride for filter in each + dimension. + padding: str. "VALID" or "SAME". + name: str or None. name of Op. + + Returns: + result: [batch, out_x, out_y, ..., filter_x, filter_y, ..., in_channels]. + Contains image patches to which conv kernel would be applied for each + output location. [out_x, out_y, ...] depends on padding. + """ + if not utils.on_tpu(): + return array_ops.extract_image_patches( + image, + ksizes=([1] + list(ksizes) + [1]), + strides=([1] + list(strides) + [1]), + rates=[1, 1, 1, 1], + padding=padding, + name=name) + + with tf_ops.name_scope(name, "extract_image_patches", + [image, ksizes, strides, padding]): + batch = image.shape.as_list()[0] + in_channels = image.shape.as_list()[-1] + + # Map each input feature to a location in the output. + out_channels = np.prod(ksizes) * in_channels + filters = linalg_ops.eye(out_channels), + filters = array_ops.reshape(filters, ksizes + [in_channels, out_channels]) + + result = nn.convolution(image, filters, padding, strides=strides) + out_spatial = result.shape.as_list()[1:-1] + result = array_ops.reshape( + result, [batch or -1] + out_spatial + ksizes + [in_channels]) + + return result + + +def compute_cov(tensor, tensor_right=None, normalizer=None): """Compute the empirical second moment of the rows of a 2D Tensor. This function is meant to be applied to random matrices for which the true row @@ -109,6 +168,8 @@ def _compute_cov(tensor, normalizer=None): Args: tensor: A 2D Tensor. + tensor_right: An optional 2D Tensor. 
If provided, this function computes + the matrix product tensor^T * tensor_right instead of tensor^T * tensor. normalizer: optional scalar for the estimator (by default, the normalizer is the number of rows of tensor). @@ -117,12 +178,17 @@ def _compute_cov(tensor, normalizer=None): """ if normalizer is None: normalizer = array_ops.shape(tensor)[0] - cov = (math_ops.matmul(tensor, tensor, transpose_a=True) / math_ops.cast( - normalizer, tensor.dtype)) - return (cov + array_ops.transpose(cov)) / math_ops.cast(2, cov.dtype) + if tensor_right is None: + cov = ( + math_ops.matmul(tensor, tensor, transpose_a=True) / math_ops.cast( + normalizer, tensor.dtype)) + return (cov + array_ops.transpose(cov)) / math_ops.cast(2.0, cov.dtype) + else: + return (math_ops.matmul(tensor, tensor_right, transpose_a=True) / + math_ops.cast(normalizer, tensor.dtype)) -def _append_homog(tensor): +def append_homog(tensor): """Appends a homogeneous coordinate to the last dimension of a Tensor. Args: @@ -135,7 +201,7 @@ def _append_homog(tensor): rank = len(tensor.shape.as_list()) shape = array_ops.concat([array_ops.shape(tensor)[:-1], [1]], axis=0) ones = array_ops.ones(shape, dtype=tensor.dtype) - return array_ops.concat([tensor, ones], axis=rank-1) + return array_ops.concat([tensor, ones], axis=rank - 1) def scope_string_from_params(params): @@ -173,8 +239,8 @@ def scope_string_from_params(params): elif isinstance(param, (tf_ops.Tensor, variables.Variable)): name_parts.append(scope_string_from_name(param)) else: - raise ValueError( - "Encountered an unsupported param type {}".format(type(param))) + raise ValueError("Encountered an unsupported param type {}".format( + type(param))) return "_".join(name_parts) @@ -225,6 +291,10 @@ class FisherFactor(object): """ pass + @abc.abstractproperty + def _dtype(self): + pass + @property def _cov_initializer(self): return covariance_initializer @@ -236,7 +306,8 @@ class FisherFactor(object): "cov", initializer=self._cov_initializer, 
shape=self._cov_shape, - trainable=False) + trainable=False, + dtype=self._dtype) @abc.abstractmethod def _compute_new_cov(self, idx=0): @@ -250,15 +321,27 @@ class FisherFactor(object): Returns: An Op for updating the covariance Variable referenced by _cov. """ - new_cov = math_ops.add_n( - tuple(self._compute_new_cov(idx) for idx in range(self._num_sources))) - - return moving_averages.assign_moving_average( - self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS) + new_cov_contribs = tuple(self._compute_new_cov(idx) + for idx in range(self._num_sources)) + # This gets the job done but we might want a better solution in the future. + # In particular, we could have a separate way of specifying where + # the cov variables finally end up, independent of where their various + # contributions are computed. Right now these are the same thing, but in + # the future we might want to perform the cov computations on each tower, + # so that each tower will be considered a "source" (allowing us to reuse + # the existing "source" code for this). + with maybe_colocate_with(new_cov_contribs[0]): + new_cov = math_ops.add_n(new_cov_contribs) + # Synchronize value across all TPU cores. + if utils.on_tpu(): + new_cov = utils.cross_replica_mean(new_cov) + return moving_averages.assign_moving_average( + self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS) + @abc.abstractmethod def make_inverse_update_ops(self): """Create and return update ops corresponding to registered computations.""" - return [] + pass def get_cov(self): return self._cov @@ -273,6 +356,13 @@ class InverseProvidingFactor(FisherFactor): _cov_shape properties. """ + # TODO(b/69108481): This class (and its subclasses) should be refactored to + # serve the matrix quantities it computes as both (potentially stale) + # variables, updated by the inverse update ops, and fresh values stored in + # tensors that are recomputed once every session.run() call.
Currently matpower + # and damp_inverse have the former behavior, while eigendecomposition has + # the latter. + def __init__(self): self._inverses_by_damping = {} self._matpower_by_exp_and_damping = {} @@ -283,6 +373,10 @@ class InverseProvidingFactor(FisherFactor): def register_damped_inverse(self, damping): """Registers a damped inverse needed by a FisherBlock. + This creates a variable and signals make_inverse_update_ops to make the + corresponding update op. The variable can be read via the method + get_inverse. + Args: damping: The damping value (float or Tensor) for this factor. """ @@ -293,12 +387,17 @@ class InverseProvidingFactor(FisherFactor): "inv_damp{}".format(damping_string), initializer=inverse_initializer, shape=self._cov_shape, - trainable=False) + trainable=False, + dtype=self._dtype) self._inverses_by_damping[damping] = inv def register_matpower(self, exp, damping): """Registers a matrix power needed by a FisherBlock. + This creates a variable and signals make_inverse_update_ops to make the + corresponding update op. The variable can be read via the method + get_matpower. + Args: exp: The exponent (float or Tensor) to raise the matrix to. damping: The damping value (float or Tensor). 
@@ -311,59 +410,81 @@ class InverseProvidingFactor(FisherFactor): "matpower_exp{}_damp{}".format(exp_string, damping_string), initializer=inverse_initializer, shape=self._cov_shape, - trainable=False) + trainable=False, + dtype=self._dtype) self._matpower_by_exp_and_damping[(exp, damping)] = matpower - def register_eigendecomp(self): - """Registers that an eigendecomposition is needed by a FisherBlock.""" - if not self._eigendecomp: - self._eigendecomp = linalg_ops.self_adjoint_eig(self._cov) - def make_inverse_update_ops(self): """Create and return update ops corresponding to registered computations.""" - ops = super(InverseProvidingFactor, self).make_inverse_update_ops() + ops = [] + + # We do this to ensure that we don't reuse the eigendecomp from old calls + # to make_inverse_update_ops that may be placed on different devices. This + # can happen if the user has both a permanent and lazily constructed + # version of the inverse ops (and only uses one of them). + self.reset_eigendecomp() num_inverses = len(self._inverses_by_damping) matrix_power_registered = bool(self._matpower_by_exp_and_damping) - use_eig = (self._eigendecomp or matrix_power_registered or - num_inverses >= EIGENVALUE_DECOMPOSITION_THRESHOLD) + use_eig = ( + self._eigendecomp or matrix_power_registered or + num_inverses >= EIGENVALUE_DECOMPOSITION_THRESHOLD) if use_eig: - self.register_eigendecomp() # ensures self._eigendecomp is set - eigenvalues, eigenvectors = self._eigendecomp # pylint: disable=unpacking-non-sequence - - # The matrix self._cov is positive semidefinite by construction, but the - # numerical eigenvalues could be negative due to numerical errors, so here - # we clip them to be at least EIGENVALUE_CLIPPING_THRESHOLD.
- clipped_eigenvalues = math_ops.maximum(eigenvalues, - EIGENVALUE_CLIPPING_THRESHOLD) + eigenvalues, eigenvectors = self.get_eigendecomp() # pylint: disable=unpacking-non-sequence for damping, inv in self._inverses_by_damping.items(): ops.append( inv.assign( - math_ops.matmul(eigenvectors / (clipped_eigenvalues + damping), + math_ops.matmul(eigenvectors / (eigenvalues + damping), array_ops.transpose(eigenvectors)))) for (exp, damping), matpower in self._matpower_by_exp_and_damping.items(): ops.append( matpower.assign( - math_ops.matmul(eigenvectors * (clipped_eigenvalues + damping)** - exp, array_ops.transpose(eigenvectors)))) + math_ops.matmul(eigenvectors * + (eigenvalues + damping)**exp, + array_ops.transpose(eigenvectors)))) + # These ops share computation and should be run on a single device. + ops = [control_flow_ops.group(*ops)] else: for damping, inv in self._inverses_by_damping.items(): ops.append(inv.assign(utils.posdef_inv(self._cov, damping))) return ops - def get_inverse(self, damping): + def get_damped_inverse(self, damping): + # Note that this function returns a variable which gets updated by the + # inverse ops. It may be stale / inconsistent with the latest value of + # get_cov(). return self._inverses_by_damping[damping] def get_matpower(self, exp, damping): + # Note that this function returns a variable which gets updated by the + # inverse ops. It may be stale / inconsistent with the latest value of + # get_cov(). return self._matpower_by_exp_and_damping[(exp, damping)] def get_eigendecomp(self): + """Creates or retrieves eigendecomposition of self._cov.""" + # Unlike get_damped_inverse and get_matpower this doesn't retrieve a stored + # variable, but instead always computes a fresh version from the current + # value of get_cov(). 
+ if not self._eigendecomp: + eigenvalues, eigenvectors = linalg_ops.self_adjoint_eig(self._cov) + + # The matrix self._cov is positive semidefinite by construction, but the + # numerical eigenvalues could be negative due to numerical errors, so here + # we clip them to be at least EIGENVALUE_CLIPPING_THRESHOLD. + clipped_eigenvalues = math_ops.maximum(eigenvalues, + EIGENVALUE_CLIPPING_THRESHOLD) + self._eigendecomp = (clipped_eigenvalues, eigenvectors) + return self._eigendecomp + def reset_eigendecomp(self): + self._eigendecomp = None + class FullFactor(InverseProvidingFactor): """FisherFactor for a full matrix representation of the Fisher of a parameter. @@ -374,41 +495,38 @@ class FullFactor(InverseProvidingFactor): def __init__(self, params_grads, - batch_size, - colocate_cov_ops_with_inputs=False): + batch_size): self._batch_size = batch_size - self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs - self._orig_params_grads_name = scope_string_from_params( - [params_grads, self._batch_size]) - params_grads_flat = [] - for params_grad in params_grads: - with _maybe_colocate_with(params_grad, - self._colocate_cov_ops_with_inputs): - col = utils.tensors_to_column(params_grad) - params_grads_flat.append(col) - self._params_grads_flat = tuple(params_grads_flat) + self._params_grads = tuple(utils.ensure_sequence(params_grad) + for params_grad in params_grads) super(FullFactor, self).__init__() @property def _var_scope(self): - return "ff_full/" + self._orig_params_grads_name + return "ff_full/" + scope_string_from_params( + [self._params_grads, self._batch_size]) @property def _cov_shape(self): - size = self._params_grads_flat[0].shape[0] - return [size, size] + size = sum(param_grad.shape.num_elements() + for param_grad in self._params_grads[0]) + return (size, size) @property def _num_sources(self): - return len(self._params_grads_flat) + return len(self._params_grads) + + @property + def _dtype(self): + return self._params_grads[0][0].dtype def 
_compute_new_cov(self, idx=0): # This will be a very basic rank 1 estimate - with _maybe_colocate_with(self._params_grads_flat[idx], - self._colocate_cov_ops_with_inputs): - return ((self._params_grads_flat[idx] * array_ops.transpose( - self._params_grads_flat[idx])) / math_ops.cast( - self._batch_size, self._params_grads_flat[idx].dtype)) + with maybe_colocate_with(self._params_grads[idx]): + params_grads_flat = utils.tensors_to_column(self._params_grads[idx]) + return ((params_grads_flat * array_ops.transpose( + params_grads_flat)) / math_ops.cast(self._batch_size, + params_grads_flat.dtype)) class DiagonalFactor(FisherFactor): @@ -421,6 +539,9 @@ class DiagonalFactor(FisherFactor): def _cov_initializer(self): return diagonal_covariance_initializer + def make_inverse_update_ops(self): + return [] + class NaiveDiagonalFactor(DiagonalFactor): """FisherFactor for a diagonal approximation of any type of param's Fisher. @@ -431,38 +552,36 @@ class NaiveDiagonalFactor(DiagonalFactor): def __init__(self, params_grads, - batch_size, - colocate_cov_ops_with_inputs=False): + batch_size): + self._params_grads = tuple(utils.ensure_sequence(params_grad) + for params_grad in params_grads) self._batch_size = batch_size - self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs - params_grads_flat = [] - for params_grad in params_grads: - with _maybe_colocate_with(params_grad, - self._colocate_cov_ops_with_inputs): - col = utils.tensors_to_column(params_grad) - params_grads_flat.append(col) - self._params_grads = tuple(params_grads_flat) - self._orig_params_grads_name = scope_string_from_params( - [self._params_grads, self._batch_size]) super(NaiveDiagonalFactor, self).__init__() @property def _var_scope(self): - return "ff_naivediag/" + self._orig_params_grads_name + return "ff_naivediag/" + scope_string_from_params( + [self._params_grads, self._batch_size]) @property def _cov_shape(self): - return self._params_grads[0].shape + size = 
sum(param_grad.shape.num_elements() + for param_grad in self._params_grads[0]) + return (size, 1) @property def _num_sources(self): return len(self._params_grads) + @property + def _dtype(self): + return self._params_grads[0][0].dtype + def _compute_new_cov(self, idx=0): - with _maybe_colocate_with(self._params_grads[idx], - self._colocate_cov_ops_with_inputs): - return (math_ops.square(self._params_grads[idx]) / math_ops.cast( - self._batch_size, self._params_grads[idx].dtype)) + with maybe_colocate_with(self._params_grads[idx]): + params_grads_flat = utils.tensors_to_column(self._params_grads[idx]) + return (math_ops.square(params_grads_flat) / math_ops.cast( + self._batch_size, params_grads_flat.dtype)) class FullyConnectedDiagonalFactor(DiagonalFactor): @@ -471,18 +590,15 @@ class FullyConnectedDiagonalFactor(DiagonalFactor): Given in = [batch_size, input_size] and out_grad = [batch_size, output_size], approximates the covariance as, - Cov(in, out) = (1/batch_size) \sum_{i} outer(in[i], out_grad[i]) ** 2.0 + Cov(in, out) = (1/batch_size) sum_{i} outer(in[i], out_grad[i]) ** 2.0 where the square is taken element-wise. """ - # TODO(jamesmartens): add units tests for this class - def __init__(self, inputs, outputs_grads, - has_bias=False, - colocate_cov_ops_with_inputs=False): + has_bias=False): """Instantiate FullyConnectedDiagonalFactor. Args: @@ -491,44 +607,46 @@ class FullyConnectedDiagonalFactor(DiagonalFactor): outputs_grads: List of Tensors of shape [batch_size, output_size]. Gradient of loss with respect to layer's preactivations. has_bias: bool. If True, append '1' to each input. - colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with - their inputs. 
""" + self._inputs = inputs + self._has_bias = has_bias self._outputs_grads = outputs_grads - self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs self._batch_size = array_ops.shape(inputs)[0] - self._orig_tensors_name = scope_string_from_params((inputs,) + - tuple(outputs_grads)) - - # Note that we precompute the required operations on the inputs since the - # inputs don't change with the 'idx' argument to _compute_new_cov. (Only - # the target entry of _outputs_grads changes with idx.) - with _maybe_colocate_with(inputs, self._colocate_cov_ops_with_inputs): - if has_bias: - inputs = _append_homog(inputs) - self._squared_inputs = math_ops.square(inputs) + self._squared_inputs = None super(FullyConnectedDiagonalFactor, self).__init__() @property def _var_scope(self): - return "ff_diagfc/" + self._orig_tensors_name + return "ff_diagfc/" + scope_string_from_params( + (self._inputs,) + tuple(self._outputs_grads)) @property def _cov_shape(self): - return [self._squared_inputs.shape[1], self._outputs_grads[0].shape[1]] + return [self._inputs.shape[1] + self._has_bias, + self._outputs_grads[0].shape[1]] @property def _num_sources(self): return len(self._outputs_grads) + @property + def _dtype(self): + return self._outputs_grads[0].dtype + def _compute_new_cov(self, idx=0): # The well-known special formula that uses the fact that the entry-wise # square of an outer product is the outer-product of the entry-wise squares. # The gradient is the outer product of the input and the output gradients, # so we just square both and then take their outer-product. 
- with _maybe_colocate_with(self._squared_inputs, - self._colocate_cov_ops_with_inputs): + with maybe_colocate_with(self._outputs_grads[idx]): + # We only need to compute squared_inputs once + if self._squared_inputs is None: + inputs = self._inputs + if self._has_bias: + inputs = append_homog(self._inputs) + self._squared_inputs = math_ops.square(inputs) + new_cov = math_ops.matmul( self._squared_inputs, math_ops.square(self._outputs_grads[idx]), @@ -540,16 +658,13 @@ class FullyConnectedDiagonalFactor(DiagonalFactor): class ConvDiagonalFactor(DiagonalFactor): """FisherFactor for a diagonal approx of a convolutional layer's Fisher.""" - # TODO(jamesmartens): add units tests for this class - def __init__(self, inputs, outputs_grads, filter_shape, strides, padding, - has_bias=False, - colocate_cov_ops_with_inputs=False): + has_bias=False): """Creates a ConvDiagonalFactor object. Args: @@ -564,53 +679,63 @@ class ConvDiagonalFactor(DiagonalFactor): padding: The padding in this layer (1-D of Tensor length 4). has_bias: Python bool. If True, the layer is assumed to have a bias parameter in addition to its filter parameter. - colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with - their inputs. """ + self._inputs = inputs self._filter_shape = filter_shape + self._strides = strides + self._padding = padding self._has_bias = has_bias self._outputs_grads = outputs_grads - self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs - - self._orig_tensors_name = scope_string_from_name((inputs,) - + tuple(outputs_grads)) - - # Note that we precompute the required operations on the inputs since the - # inputs don't change with the 'idx' argument to _compute_new_cov. (Only - # the target entry of _outputs_grads changes with idx.) 
- with _maybe_colocate_with(inputs, self._colocate_cov_ops_with_inputs): - filter_height, filter_width, _, _ = self._filter_shape - patches = array_ops.extract_image_patches( - inputs, - ksizes=[1, filter_height, filter_width, 1], - strides=strides, - rates=[1, 1, 1, 1], - padding=padding) - - if has_bias: - patches = _append_homog(patches) - - self._patches = patches + self._patches = None super(ConvDiagonalFactor, self).__init__() @property def _var_scope(self): - return "ff_convdiag/" + self._orig_tensors_name + return "ff_convdiag/" + scope_string_from_name( + (self._inputs,) + tuple(self._outputs_grads)) @property def _cov_shape(self): filter_height, filter_width, in_channels, out_channels = self._filter_shape - return [filter_height * filter_width * in_channels + self._has_bias, - out_channels] + return [ + filter_height * filter_width * in_channels + self._has_bias, + out_channels + ] @property def _num_sources(self): return len(self._outputs_grads) + @property + def _dtype(self): + return self._outputs_grads[0].dtype + + def make_covariance_update_op(self, ema_decay): + with maybe_colocate_with(self._inputs): + filter_height, filter_width, _, _ = self._filter_shape + + # TODO(b/64144716): there is potential here for a big savings in terms + # of memory use. 
+ patches = extract_image_patches( + self._inputs, + ksizes=[filter_height, filter_width], + strides=self._strides[1:-1], + padding=self._padding) + + if self._has_bias: + patches = append_homog(patches) + + self._patches = patches + + op = super(ConvDiagonalFactor, self).make_covariance_update_op(ema_decay) + + self._patches = None + + return op + def _compute_new_cov(self, idx=0): - with _maybe_colocate_with(self._outputs_grads[idx], - self._colocate_cov_ops_with_inputs): + with maybe_colocate_with(self._outputs_grads[idx]): outputs_grad = self._outputs_grads[idx] batch_size = array_ops.shape(self._patches)[0] @@ -634,23 +759,18 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor): def __init__(self, tensors, - has_bias=False, - colocate_cov_ops_with_inputs=False): + has_bias=False): """Instantiate FullyConnectedKroneckerFactor. Args: tensors: List of Tensors of shape [batch_size, n]. Represents either a layer's inputs or its output's gradients. - has_bias: bool. If True, assume this factor is for the layer's inputs and - append '1' to each row. - colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with - their inputs. + has_bias: bool. If True, append '1' to each row. """ # The tensor argument is either a tensor of input activations or a tensor of # output pre-activation gradients. 
self._has_bias = has_bias self._tensors = tensors - self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs super(FullyConnectedKroneckerFactor, self).__init__() @property @@ -667,13 +787,16 @@ class FullyConnectedKroneckerFactor(InverseProvidingFactor): def _num_sources(self): return len(self._tensors) + @property + def _dtype(self): + return self._tensors[0].dtype + def _compute_new_cov(self, idx=0): - with _maybe_colocate_with(self._tensors[idx], - self._colocate_cov_ops_with_inputs): + with maybe_colocate_with(self._tensors[idx]): tensor = self._tensors[idx] if self._has_bias: - tensor = _append_homog(tensor) - return _compute_cov(tensor) + tensor = append_homog(tensor) + return compute_cov(tensor) class ConvInputKroneckerFactor(InverseProvidingFactor): @@ -682,7 +805,7 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): Estimates E[ a a^T ] where a is the inputs to a convolutional layer given example x. Expectation is taken over all examples and locations. - Equivalent to \Omega in https://arxiv.org/abs/1602.01407 for details. See + Equivalent to Omega in https://arxiv.org/abs/1602.01407 for details. See Section 3.1 Estimating the factors. """ @@ -691,8 +814,7 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): filter_shape, strides, padding, - has_bias=False, - colocate_cov_ops_with_inputs=False): + has_bias=False): """Initializes ConvInputKroneckerFactor. Args: @@ -704,15 +826,12 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): width_stride, in_channel_stride]. padding: str. Padding method for layer. "SAME" or "VALID". has_bias: bool. If True, append 1 to in_channel. - colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with - their inputs. 
""" self._filter_shape = filter_shape self._strides = strides self._padding = padding self._has_bias = has_bias self._inputs = inputs - self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs super(ConvInputKroneckerFactor, self).__init__() @property @@ -732,27 +851,44 @@ class ConvInputKroneckerFactor(InverseProvidingFactor): def _num_sources(self): return 1 + @property + def _dtype(self): + return self._inputs.dtype + def _compute_new_cov(self, idx=0): if idx != 0: raise ValueError("ConvInputKroneckerFactor only supports idx = 0") - # TODO(jamesmartens): factor this patches stuff out into a utility function - with _maybe_colocate_with(self._inputs, self._colocate_cov_ops_with_inputs): + with maybe_colocate_with(self._inputs): filter_height, filter_width, in_channels, _ = self._filter_shape - patches = array_ops.extract_image_patches( + + # TODO(b/64144716): there is potential here for a big savings in terms of + # memory use. + patches = extract_image_patches( self._inputs, - ksizes=[1, filter_height, filter_width, 1], - strides=self._strides, - rates=[1, 1, 1, 1], + ksizes=[filter_height, filter_width], + strides=self._strides[1:-1], padding=self._padding) flatten_size = (filter_height * filter_width * in_channels) + # patches_flat below is the matrix [[A_l]] from the KFC paper (tilde + # omitted over A for clarity). It has shape M|T| x J|Delta| (eq. 14), + # where M = minibatch size, |T| = number of spatial locations, + # |Delta| = number of spatial offsets, and J = number of input maps + # for convolutional layer l. patches_flat = array_ops.reshape(patches, [-1, flatten_size]) - + # We append a homogenous coordinate to patches_flat if the layer has + # bias parameters. This gives us [[A_l]]_H from the paper. if self._has_bias: - patches_flat = _append_homog(patches_flat) - - return _compute_cov(patches_flat) + patches_flat = append_homog(patches_flat) + # We call compute_cov without passing in a normalizer. 
compute_cov uses + # the first dimension of patches_flat i.e. M|T| as the normalizer by + # default. Hence we end up computing 1/M|T| * [[A_l]]^T [[A_l]], with + # shape J|Delta| x J|Delta|. This is related to hat{Omega}_l from + # the paper but has a different scale here for consistency with + # ConvOutputKroneckerFactor. + # (Tilde omitted over A for clarity.) + return compute_cov(patches_flat) class ConvOutputKroneckerFactor(InverseProvidingFactor): @@ -762,22 +898,19 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor): given example x and ds = (d / d s) log(p(y|x, w)). Expectation is taken over all examples and locations. - Equivalent to \Gamma in https://arxiv.org/abs/1602.01407 for details. See + Equivalent to Gamma in https://arxiv.org/abs/1602.01407 for details. See Section 3.1 Estimating the factors. """ - def __init__(self, outputs_grads, colocate_cov_ops_with_inputs=False): + def __init__(self, outputs_grads): """Initializes ConvOutputKroneckerFactor. Args: outputs_grads: list of Tensors. Each Tensor is of shape [batch_size, height, width, out_channels]. - colocate_cov_ops_with_inputs: Whether to colocate cov_update ops with - their inputs. """ self._out_channels = outputs_grads[0].shape.as_list()[3] self._outputs_grads = outputs_grads - self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs super(ConvOutputKroneckerFactor, self).__init__() @property @@ -793,9 +926,292 @@ class ConvOutputKroneckerFactor(InverseProvidingFactor): def _num_sources(self): return len(self._outputs_grads) + @property + def _dtype(self): + return self._outputs_grads[0].dtype + def _compute_new_cov(self, idx=0): - with _maybe_colocate_with(self._outputs_grads[idx], - self._colocate_cov_ops_with_inputs): + with maybe_colocate_with(self._outputs_grads[idx]): + # reshaped_tensor below is the matrix DS_l defined in the KFC paper + # (tilde omitted over S for clarity). 
It has shape M|T| x I, where + # M = minibatch size, |T| = number of spatial locations, and + # I = number of output maps for convolutional layer l. reshaped_tensor = array_ops.reshape(self._outputs_grads[idx], [-1, self._out_channels]) - return _compute_cov(reshaped_tensor) + # Following the reasoning in ConvInputKroneckerFactor._compute_new_cov, + # compute_cov here returns 1/M|T| * DS_l^T DS_l = hat{Gamma}_l + # as defined in the paper, with shape I x I. + # (Tilde omitted over S for clarity.) + return compute_cov(reshaped_tensor) + + +class FullyConnectedMultiKF(InverseProvidingFactor): + """Kronecker factor for a fully connected recurrent layer.""" + + def __init__(self, + tensor_lists, + has_bias=False): + """Constructs a new `FullyConnectedMultiKF`. + + Args: + tensor_lists: List of lists of Tensors of shape [batch_size, n]. + has_bias: bool. If True, '1' is appended to each row. + """ + + self._tensor_lists = tensor_lists + self._has_bias = has_bias + self._batch_size = array_ops.shape(tensor_lists[0][0])[0] + self._num_timesteps = len(tensor_lists[0]) + self._tensors = [None] * len(tensor_lists) + + self._cov_dt1 = None + self._option1quants_by_damping = {} + self._option2quants_by_damping = {} + + super(FullyConnectedMultiKF, self).__init__() + + @property + def _var_scope(self): + return "ff_fc_multi/" + scope_string_from_params(self._tensor_lists) + + @property + def _num_sources(self): + return len(self._tensor_lists) + + @property + def _dtype(self): + return self._tensor_lists[0][0].dtype + + def make_covariance_update_op(self, ema_decay): + + op = super(FullyConnectedMultiKF, self).make_covariance_update_op(ema_decay) + + if self._cov_dt1 is not None: + new_cov_dt1_contribs = tuple(self._compute_new_cov_dt1(idx) + for idx in range(self._num_sources)) + + with maybe_colocate_with(new_cov_dt1_contribs[0]): + new_cov_dt1 = math_ops.add_n(new_cov_dt1_contribs) + + op2 = moving_averages.assign_moving_average( + self._cov_dt1, new_cov_dt1, ema_decay, 
zero_debias=ZERO_DEBIAS) + + # TODO(b/69112164): + # It's important that _cov and _cov_dt1 remain consistent with each + # other while the inverse ops are happening. How can we ensure this? + # We will need to add explicit synchronization for this to + # work with asynchronous training. + op = control_flow_ops.group(op, op2) + + return op + + def _compute_new_cov(self, idx=0): + with maybe_colocate_with(self._tensor_lists[idx]): + tensor = array_ops.concat(self._tensor_lists[idx], 0) + if self._has_bias: + tensor = append_homog(tensor) + # We save these so they can be used by _compute_new_cov_dt1 + self._tensors[idx] = tensor + return compute_cov(tensor) + + def _compute_new_cov_dt1(self, idx=0): + tensor = self._tensors[idx] + with maybe_colocate_with(tensor): + # Is there a more elegant way to do this computation? + tensor_present = tensor[:-self._batch_size, :] + tensor_future = tensor[self._batch_size:, :] + # We specify a normalizer for this computation to ensure a PSD Fisher + # block estimate. This is equivalent to padding with zeros, as was done + # in Section B.2 of the appendix. + normalizer = self._num_timesteps * self._batch_size + return compute_cov( + tensor_future, tensor_right=tensor_present, normalizer=normalizer) + + @property + def _cov_shape(self): + size = self._tensor_lists[0][0].shape[1] + self._has_bias + return [size, size] + + @property + def _vec_shape(self): + size = self._tensor_lists[0][0].shape[1] + self._has_bias + return [size] + + def get_option1quants(self, damping): + return self._option1quants_by_damping[damping] + + def get_option2quants(self, damping): + return self._option2quants_by_damping[damping] + + def get_cov_dt1(self): + assert self._cov_dt1 is not None + return self._cov_dt1 + + def register_cov_dt1(self): + """Create a variable representing temporal cross-covariance. + + (This is technically the second moment, not covariance, since it's + not mean subtracted.) 
+ """ + if self._cov_dt1 is None: + with variable_scope.variable_scope(self._var_scope): + self._cov_dt1 = variable_scope.get_variable( + "cov_dt1", + initializer=init_ops.zeros_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + + def register_option1quants(self, damping): + + self.register_cov_dt1() + + if damping not in self._option1quants_by_damping: + # It's questionable as to whether we should initialize with stuff like + # this at all. Ideally these values should never be used until they are + # updated at least once. + damping_string = scalar_or_tensor_to_string(damping) + with variable_scope.variable_scope(self._var_scope): + Lmat = variable_scope.get_variable( # pylint: disable=invalid-name + "Lmat_damp{}".format(damping_string), + initializer=inverse_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + psi = variable_scope.get_variable( + "psi_damp{}".format(damping_string), + initializer=init_ops.ones_initializer, + shape=self._vec_shape, + trainable=False, + dtype=self._dtype) + + self._option1quants_by_damping[damping] = (Lmat, psi) + + def register_option2quants(self, damping): + + self.register_cov_dt1() + + if damping not in self._option2quants_by_damping: + # It's questionable as to whether we should initialize with stuff like + # this at all. Ideally these values should never be used until they are + # updated at least once. 
+ damping_string = scalar_or_tensor_to_string(damping) + with variable_scope.variable_scope(self._var_scope): + Pmat = variable_scope.get_variable( # pylint: disable=invalid-name + "Pmat_damp{}".format(damping_string), + initializer=inverse_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + Kmat = variable_scope.get_variable( # pylint: disable=invalid-name + "Kmat_damp{}".format(damping_string), + initializer=inverse_initializer, + shape=self._cov_shape, + trainable=False, + dtype=self._dtype) + mu = variable_scope.get_variable( + "mu_damp{}".format(damping_string), + initializer=init_ops.ones_initializer, + shape=self._vec_shape, + trainable=False, + dtype=self._dtype) + + self._option2quants_by_damping[damping] = (Pmat, Kmat, mu) + + def make_inverse_update_ops(self): + """Create and return update ops corresponding to registered computations.""" + # TODO(b/69918258): Add correctness tests for this method. + # pylint: disable=invalid-name + + ops = super(FullyConnectedMultiKF, self).make_inverse_update_ops() + + if (len(self._option1quants_by_damping) + + len(self._option2quants_by_damping)): + + # Note that C0 and C1 are stand-ins for A0 and A1, or G0 and G1, from + # the pseudo-code in the original paper. Because the computations for + # the A and G case are essentially the same they can both be performed by + # the same class (this one). + + C1 = self.get_cov_dt1() + + # Get the eigendecomposition of C0 (= self.get_cov()) + eigen_e, eigen_V = self.get_eigendecomp() + + # TODO(b/69678661): Note, there is an implicit assumption here that C1 + # and C0 (as represented here by its eigen-decomp) are consistent. This + # could fail to be the case if self._cov and self._cov_dt1 are not updated + # consistently, or are somehow read between or during the cov updates. + # Can this possibly happen? Is there a way to prevent it? 
+ + for damping, (Lmat_var, + psi_var) in self._option1quants_by_damping.items(): + + invsqrtC0 = math_ops.matmul( + eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True) + + # Might need to enforce symmetry lost due to numerical issues. + invsqrtC0 = (invsqrtC0 + array_ops.transpose(invsqrtC0)) / 2.0 + + # Impose the symmetry assumed by "Option 1" on a copy of C1; C1 itself + # must stay as-is for "Option 2" below. Strangely the code can work okay + # without this, depending on how psd_eig is defined. I'm not sure why. + symC1 = (C1 + array_ops.transpose(C1)) / 2.0 + + # hPsi = C0^(-1/2) * C1 * C0^(-1/2) (hPsi means hat{Psi}) + hPsi = math_ops.matmul(math_ops.matmul(invsqrtC0, symC1), invsqrtC0) + + # Compute the decomposition U*diag(psi)*U^T = hPsi + psi, U = utils.posdef_eig(hPsi) + + # L = C0^(-1/2) * U + Lmat = math_ops.matmul(invsqrtC0, U) + + ops.append(Lmat_var.assign(Lmat)) + ops.append(psi_var.assign(psi)) + + for damping, (Pmat_var, Kmat_var, + mu_var) in self._option2quants_by_damping.items(): + + # compute C0^(-1/2) + invsqrtC0 = math_ops.matmul( + eigen_V * (eigen_e + damping)**(-0.5), eigen_V, transpose_b=True) + + # Might need to enforce symmetry lost due to numerical issues. + invsqrtC0 = (invsqrtC0 + array_ops.transpose(invsqrtC0)) / 2.0 + + # Compute the product C0^(-1/2) * C1 + invsqrtC0C1 = math_ops.matmul(invsqrtC0, C1) + + # hPsi = C0^(-1/2) * C1 * C0^(-1/2) (hPsi means hat{Psi}) + hPsi = math_ops.matmul(invsqrtC0C1, invsqrtC0) + + # Compute the decomposition E*diag(mu)*E^T = hPsi^T * hPsi + # Note that we use the notation mu instead of "m" for the eigenvalues. + # Instead of computing the product hPsi^T * hPsi and then doing an + # eigen-decomposition of this we just compute the SVD of hPsi and then + # square the singular values to get the eigenvalues.
For a justification + # of this approach, see: + # https://en.wikipedia.org/wiki/Singular-value_decomposition#Relation_to_eigenvalue_decomposition + sqrtmu, _, E = linalg_ops.svd(hPsi) + mu = math_ops.square(sqrtmu) + + # Mathematically, the eigenvalues should not exceed 1.0, but + # due to numerical issues, or possible issues with inconsistent + # values of C1 and (the eigen-decomposition of) C0 they might. So + # we enforce this condition. + mu = math_ops.minimum(mu, 1.0) + + # P = (C0^(-1/2) * C1)^T * C0^(-1/2) = C_1^T * C_0^(-1) + Pmat = math_ops.matmul(invsqrtC0C1, invsqrtC0, transpose_a=True) + + # K = C_0^(-1/2) * E + Kmat = math_ops.matmul(invsqrtC0, E) + + ops.append(Pmat_var.assign(Pmat)) + ops.append(Kmat_var.assign(Kmat)) + ops.append(mu_var.assign(mu)) + + return [control_flow_ops.group(*ops)] + + # pylint: enable=invalid-name diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py index 23ee93cd405bbf719939df89d525c812ee061f8b..ad93919149c287b1932dd2b6bd772c0dab26192d 100644 --- a/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py +++ b/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py @@ -41,6 +41,9 @@ _allowed_symbols = [ "ConvOutputKroneckerFactor", "ConvDiagonalFactor", "set_global_constants", + "maybe_colocate_with", + "compute_cov", + "append_homog" ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py index 3a005ee39dd9400c21ae6c41fad5351d7fff2aac..8d450f04f379701e46a18b2e34bbbd6fcfcce2bb 100644 --- a/tensorflow/contrib/kfac/python/ops/layer_collection.py +++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py @@ -26,7 +26,9 @@ from __future__ import print_function from collections import defaultdict from collections import OrderedDict +from functools import partial +import math import six from 
tensorflow.contrib.kfac.python.ops import fisher_blocks as fb @@ -57,20 +59,22 @@ _CONV2D_APPROX_TO_BLOCK_TYPES = { APPROX_DIAGONAL_NAME: fb.ConvDiagonalFB, } +APPROX_KRONECKER_INDEP_NAME = "kron_indep" +APPROX_KRONECKER_SERIES_1_NAME = "kron_series_1" +APPROX_KRONECKER_SERIES_2_NAME = "kron_series_2" + +_FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES = { + APPROX_KRONECKER_INDEP_NAME: fb.FullyConnectedMultiIndepFB, + APPROX_KRONECKER_SERIES_1_NAME: partial(fb.FullyConnectedSeriesFB, + option=1), + APPROX_KRONECKER_SERIES_2_NAME: partial(fb.FullyConnectedSeriesFB, + option=2) +} + # Possible value for 'reuse' keyword argument. Sets 'reuse' to # tf.get_variable_scope().reuse. VARIABLE_SCOPE = "VARIABLE_SCOPE" -# TODO(jamesmartens): need to add find_canonical_output back into this somewhere - - -def ensure_sequence(obj): - """If `obj` isn't a tuple or list, return a tuple containing `obj`.""" - if isinstance(obj, (tuple, list)): - return obj - else: - return (obj,) - class LayerParametersDict(OrderedDict): """An OrderedDict where keys are Tensors or tuples of Tensors. 
@@ -130,7 +134,6 @@ class LayerCollection(object): def __init__(self, graph=None, - colocate_cov_ops_with_inputs=False, name="LayerCollection"): self.fisher_blocks = LayerParametersDict() self.fisher_factors = OrderedDict() @@ -142,7 +145,8 @@ class LayerCollection(object): self._default_generic_approximation = APPROX_FULL_NAME self._default_fully_connected_approximation = APPROX_KRONECKER_NAME self._default_convolution_2d_approximation = APPROX_KRONECKER_NAME - self._colocate_cov_ops_with_inputs = colocate_cov_ops_with_inputs + self._default_fully_connected_multi_approximation = ( + APPROX_KRONECKER_SERIES_2_NAME) with variable_scope.variable_scope(None, default_name=name) as scope: self._var_scope = scope.name @@ -152,19 +156,13 @@ class LayerCollection(object): """LossFunctions registered with this LayerCollection.""" return list(self._loss_dict.values()) - def is_variable_registered(self, variable): - """Checks whether the variable has already been registered. - - Args: - variable: A single variable or tensor. - Returns: - True if the variable has been registered either by itself or as part of a - tuple. 
- """ - return any([ - variable in key if isinstance(key, (tuple, list)) else variable == key - for key in self.fisher_blocks.keys() - ]) + @property + def registered_variables(self): + """A tuple of all of the variables currently registered.""" + tuple_of_tuples = (utils.ensure_sequence(key) for key, block + in six.iteritems(self.fisher_blocks)) + flat_tuple = tuple(item for tuple_ in tuple_of_tuples for item in tuple_) + return flat_tuple @property def linked_parameters(self): @@ -213,6 +211,16 @@ class LayerCollection(object): value)) self._default_convolution_2d_approximation = value + @property + def default_fully_connected_multi_approximation(self): + return self._default_fully_connected_multi_approximation + + def set_default_fully_connected_multi_approximation(self, value): + if value not in _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES: + raise ValueError("{} is not a valid approximation for a fully-connected " + "multi layer.".format(value)) + self._default_fully_connected_multi_approximation = value + def register_block(self, layer_key, fisher_block, reuse=VARIABLE_SCOPE): """Validates and registers the layer_key associated with the fisher_block. @@ -221,7 +229,7 @@ class LayerCollection(object): existing registrations and to register if valid. fisher_block: The associated `FisherBlock`. reuse: Method to use for inserting new `FisherBlock`s. One of True, False, - or VARIABLE_SCOPE. + or 'VARIABLE_SCOPE'. 
Raises: ValueError: If `layer_key` was already registered and reuse is `False`, @@ -258,9 +266,9 @@ class LayerCollection(object): variable_to_block = { var: (params, block) for (params, block) in self.fisher_blocks.items() - for var in ensure_sequence(params) + for var in utils.ensure_sequence(params) } - for variable in ensure_sequence(layer_key): + for variable in utils.ensure_sequence(layer_key): if variable in variable_to_block: prev_key, prev_block = variable_to_block[variable] raise ValueError( @@ -272,13 +280,65 @@ class LayerCollection(object): def get_use_count_map(self): """Returns a dict of variables to their number of registrations.""" + # TODO(b/70283403): Reimplement this in the old way, where each + # registration function would be responsible for incrementing the count. + # Also, this version has a bug: it won't do the right thing for generic + # registration for parameters that are shared. i.e. it won't set the use + # count to infinity. vars_to_uses = defaultdict(int) for key, block in six.iteritems(self.fisher_blocks): - key = key if isinstance(key, (tuple, list)) else (key,) + n = ( + block.num_inputs()*block.num_registered_minibatches if isinstance( + block, (fb.FullyConnectedSeriesFB, fb.FullyConnectedMultiIndepFB)) + else block.num_registered_minibatches) + key = utils.ensure_sequence(key) for k in key: - vars_to_uses[k] += block.num_registered_minibatches + vars_to_uses[k] += n return vars_to_uses + def check_registration(self, variables): + """Checks that all variable uses have been registered properly. + + Args: + variables: List of variables. + + Raises: + ValueError: If any registered variables are not included in the list. + ValueError: If any variable in the list is not registered. + ValueError: If any variable in the list is registered with the wrong + number of "uses" in the subgraph recorded (vs the number of times that + variable is actually used in the subgraph). + """ + # Note that overlapping parameters (i.e. 
those that share variables) will + # be caught by layer_collection.LayerParametersDict during registration. + + reg_use_map = self.get_use_count_map() + + error_messages = [] + + for var in variables: + total_uses = self.subgraph.variable_uses(var) + reg_uses = reg_use_map[var] + + if reg_uses == 0: + error_messages.append("Variable {} not registered.".format(var)) + elif (not math.isinf(reg_uses)) and reg_uses != total_uses: + error_messages.append( + "Variable {} registered with wrong number of uses ({} " + "registrations vs {} uses).".format(var, reg_uses, total_uses)) + + num_get_vars = len(reg_use_map) + + if num_get_vars > len(variables): + error_messages.append("{} registered variables were not included in list." + .format(num_get_vars - len(variables))) + + if error_messages: + error_messages = [ + "Found the following errors with variable registration:" + ] + error_messages + raise ValueError("\n\t".join(error_messages)) + def get_blocks(self): return self.fisher_blocks.values() @@ -312,12 +372,12 @@ class LayerCollection(object): ValueError: If the parameters were already registered in a layer or identified as part of an incompatible group. """ - params = frozenset(ensure_sequence(params)) + params = frozenset(utils.ensure_sequence(params)) # Check if any of the variables in 'params' is already in # 'self.fisher_blocks.keys()'. 
for registered_params, fisher_block in self.fisher_blocks.items(): - registered_params_set = set(ensure_sequence(registered_params)) + registered_params_set = set(utils.ensure_sequence(registered_params)) for variable in params: if (variable in registered_params_set and params != registered_params_set): @@ -351,7 +411,7 @@ class LayerCollection(object): def _get_linked_approx(self, params): """If params were linked, return their specified approximation.""" - params_set = frozenset(ensure_sequence(params)) + params_set = frozenset(utils.ensure_sequence(params)) if params_set in self.linked_parameters: return self.linked_parameters[params_set] else: @@ -370,11 +430,11 @@ class LayerCollection(object): this layer. Weight matrix should have shape [input_size, output_size]. Bias should have shape [output_size]. inputs: Tensor of shape [batch_size, input_size]. Inputs to layer. - outputs: Tensor of shape [batch_size, output_size]. Preactivations + outputs: Tensor of shape [batch_size, output_size]. Outputs produced by layer. - approx: str. One of APPROX_KRONECKER_NAME or APPROX_DIAGONAL_NAME. + approx: str. One of "kron" or "diagonal". reuse: bool or str. If True, reuse an existing FisherBlock. If False, - create a new FisherBlock. If VARIABLE_SCOPE, use + create a new FisherBlock. If "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. Raises: @@ -416,10 +476,10 @@ class LayerCollection(object): inputs: Tensor of shape [batch_size, height, width, in_channels]. Inputs to layer. outputs: Tensor of shape [batch_size, height, width, out_channels]. - Preactivations produced by layer. - approx: str. One of APPROX_KRONECKER_NAME or APPROX_DIAGONAL_NAME. + Output produced by layer. + approx: str. One of "kron" or "diagonal". reuse: bool or str. If True, reuse an existing FisherBlock. If False, - create a new FisherBlock. If VARIABLE_SCOPE, use + create a new FisherBlock. If "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. 
Raises: @@ -449,14 +509,11 @@ class LayerCollection(object): """Registers a generic layer. Args: - params: Tensor or 2-tuple of Tensors corresponding to weight and bias of - this layer. Weight matrix should have shape [kernel_height, - kernel_width, in_channels, out_channels]. Bias should have shape - [out_channels]. + params: Tensor or tuple of Tensors corresponding to the parameters. batch_size: 0-D Tensor. Size of the minibatch. - approx: str. One of APPROX_KRONECKER_NAME or APPROX_DIAGONAL_NAME. + approx: str. One of "full" or "diagonal". reuse: bool or str. If True, reuse an existing FisherBlock. If False, - create a new FisherBlock. If VARIABLE_SCOPE, use + create a new FisherBlock. If "VARIABLE_SCOPE", use tf.get_variable_scope().reuse. Raises: @@ -477,6 +534,47 @@ class LayerCollection(object): block = self.register_block(params, block_type(self, params), reuse=reuse) block.register_additional_minibatch(batch_size) + def register_fully_connected_multi(self, params, inputs, outputs, + approx=None): + """Register fully connected layers with shared parameters. + + This can handle general fully-connected layers with shared parameters, but + has specialized approximations to deal with the case where there is a + meaningful linear order to the shared instances (such as in an RNN). + + Args: + params: Tensor or 2-tuple of Tensors corresponding to weight and bias of + this layer. Weight matrix should have shape [input_size, output_size]. + Bias should have shape [output_size]. + inputs: A list of tensors, each of shape [batch_size, input_size]. Inputs + to layer. In the case of RNNs, one Tensor per time step. + outputs: A list of tensors, the same length as 'inputs', each of shape + [batch_size, output_size]. Outputs produced by layer. In the case of + RNNs, one Tensor per time step. + approx: str. One of "kron_indep", "kron_series_1", or "kron_series_2". + + Raises: + ValueError: For improper value to 'approx'. 
+ """ + if approx is None: + approx = self._get_linked_approx(params) + if approx is None: + approx = self.default_fully_connected_multi_approximation + has_bias = isinstance(params, (tuple, list)) + + # TODO(b/70283649): something along the lines of find_canonical_output + # should be added back in here (and for the other block types, arguably). + + if approx not in _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES: + raise ValueError("Bad value {} for approx.".format(approx)) + block_type = _FULLY_CONNECTED_MULTI_APPROX_TO_BLOCK_TYPES[approx] + + # For now we don't support multiple minibatches for this type of layer, so + # we set reuse=False + self.register_block(params, + block_type(self, inputs, outputs, has_bias=has_bias), + reuse=False) + def register_categorical_predictive_distribution(self, logits, seed=None, @@ -619,7 +717,6 @@ class LayerCollection(object): key = cls, args if key not in self.fisher_factors: - colo = self._colocate_cov_ops_with_inputs with variable_scope.variable_scope(self._var_scope): - self.fisher_factors[key] = cls(*args, colocate_cov_ops_with_inputs=colo) + self.fisher_factors[key] = cls(*args) return self.fisher_factors[key] diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py b/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py index d6bf61a210203dd74d4e93b65005f660b1fab4ff..f8aa230d9ca1f542950f56b1e6cf1ab7ccd3d05f 100644 --- a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py +++ b/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py @@ -36,6 +36,9 @@ _allowed_symbols = [ "APPROX_DIAGONAL_NAME", "APPROX_FULL_NAME", "VARIABLE_SCOPE", + "APPROX_KRONECKER_INDEP_NAME", + "APPROX_KRONECKER_SERIES_1_NAME", + "APPROX_KRONECKER_SERIES_2_NAME" ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions.py b/tensorflow/contrib/kfac/python/ops/loss_functions.py index 
e2e5bc3ffea3e52087c24802948bc8260e3b199a..cb3e698b9ceab920785adf735f88bd8e535a628f 100644 --- a/tensorflow/contrib/kfac/python/ops/loss_functions.py +++ b/tensorflow/contrib/kfac/python/ops/loss_functions.py @@ -22,6 +22,7 @@ import abc import six +from tensorflow.contrib.distributions.python.ops import onehot_categorical from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops @@ -91,13 +92,13 @@ class LossFunction(object): @abc.abstractmethod def _evaluate(self, targets): - """Evaluates the log probability of the targets. + """Evaluates the negative log probability of the targets. Args: targets: Tensor that distribution can calculate log_prob() of. Returns: - log probability of each target, summed across all targets. + negative log probability of each target, summed across all targets. """ pass @@ -659,19 +660,20 @@ class CategoricalLogitsNegativeLogProbLoss(DistributionNegativeLogProbLoss, def multiply_fisher(self, vector): probs = self._probs - return vector * probs - math_ops.reduce_sum(vector * probs, axis=1) * probs + return vector * probs - probs * math_ops.reduce_sum( + vector * probs, axis=-1, keep_dims=True) def multiply_fisher_factor(self, vector): probs = self._probs sqrt_probs = self._sqrt_probs return sqrt_probs * vector - probs * math_ops.reduce_sum( - sqrt_probs * vector, axis=1, keep_dims=True) + sqrt_probs * vector, axis=-1, keep_dims=True) def multiply_fisher_factor_transpose(self, vector): probs = self._probs sqrt_probs = self._sqrt_probs return sqrt_probs * vector - sqrt_probs * math_ops.reduce_sum( - probs * vector, axis=1, keep_dims=True) + probs * vector, axis=-1, keep_dims=True) def multiply_fisher_factor_replicated_one_hot(self, index): assert len(index) == 1, "Length of index was {}".format(len(index)) @@ -785,3 +787,16 @@ def insert_slice_in_zeros(slice_to_insert, dim, dim_size, position): after[dim] = dim_size - position - 1 return 
array_ops.pad(slice_to_insert, list(zip(before, after))) + + +class OnehotCategoricalLogitsNegativeLogProbLoss( + CategoricalLogitsNegativeLogProbLoss): + """Neg log prob loss for a categorical distribution with onehot targets. + + Identical to CategoricalLogitsNegativeLogProbLoss except that the underlying + distribution is OneHotCategorical as opposed to Categorical. + """ + + @property + def dist(self): + return onehot_categorical.OneHotCategorical(logits=self._logits) diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py b/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py index e9bb4f14e9e24128382832fcdaccdc9b24017046..705a871d482565897e7ac850327729a6186f1746 100644 --- a/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py +++ b/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py @@ -31,6 +31,7 @@ _allowed_symbols = [ "NormalMeanNegativeLogProbLoss", "NormalMeanVarianceNegativeLogProbLoss", "CategoricalLogitsNegativeLogProbLoss", + "OnehotCategoricalLogitsNegativeLogProbLoss", "MultiBernoulliNegativeLogProbLoss", "MultiBernoulliNegativeLogProbLoss", "insert_slice_in_zeros", diff --git a/tensorflow/contrib/kfac/python/ops/op_queue.py b/tensorflow/contrib/kfac/python/ops/op_queue.py index 831870fca451c585cb1a1dc6b24aad757e2bbaa8..b6d9d37a31a949b154b79e6f3677289a0d167373 100644 --- a/tensorflow/contrib/kfac/python/ops/op_queue.py +++ b/tensorflow/contrib/kfac/python/ops/op_queue.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data.python.ops import dataset_ops +from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import ops as tf_ops diff --git a/tensorflow/contrib/kfac/python/ops/optimizer.py b/tensorflow/contrib/kfac/python/ops/optimizer.py index ecf7f3e4e5ab7d9c151f760fdab733bc3830e37b..1974b07acfc879dc4bc844db9af88fd1043d6698 100644 --- a/tensorflow/contrib/kfac/python/ops/optimizer.py +++ 
b/tensorflow/contrib/kfac/python/ops/optimizer.py @@ -41,12 +41,12 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): damping, layer_collection, var_list=None, - momentum=0., + momentum=0.9, momentum_type="regular", norm_constraint=None, name="KFAC", estimation_mode="gradients", - colocate_gradients_with_ops=False, + colocate_gradients_with_ops=True, cov_devices=None, inv_devices=None): """Initializes the KFAC optimizer with the given settings. @@ -70,8 +70,8 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): var_list: Optional list or tuple of variables to train. Defaults to the list of variables collected in the graph under the key `GraphKeys.TRAINABLE_VARIABLES`. - momentum: The momentum value for this optimizer. Only applies when - momentum_type is 'regular' or 'adam'. (Default: 0) + momentum: The momentum decay constant to use. Only applies when + momentum_type is 'regular' or 'adam'. (Default: 0.9) momentum_type: The type of momentum to use in this optimizer, one of 'regular', 'adam', or 'qmodel'. (Default: 'regular') norm_constraint: float or Tensor. If specified, the update is scaled down @@ -85,6 +85,7 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): more a more detailed description of these options. colocate_gradients_with_ops: Whether we should request gradients we compute in the estimator be colocated with their respective ops. + (Default: True) cov_devices: Iterable of device strings (e.g. '/gpu:0'). Covariance computations will be placed on these devices in a round-robin fashion. Can be None, which means that no devices are specified. 
@@ -136,12 +137,32 @@ class KfacOptimizer(gradient_descent.GradientDescentOptimizer): self._batch_size = array_ops.shape(layer_collection.losses[0].inputs)[0] self._losses = layer_collection.losses - self.cov_update_op = self._fisher_est.cov_update_op - self.inv_update_op = self._fisher_est.inv_update_op - self.inv_updates_dict = self._fisher_est.inv_updates_dict - super(KfacOptimizer, self).__init__(learning_rate, name=name) + @property + def cov_update_thunks(self): + return self._fisher_est.cov_update_thunks + + @property + def cov_update_ops(self): + return self._fisher_est.cov_update_ops + + @property + def cov_update_op(self): + return self._fisher_est.cov_update_op + + @property + def inv_update_thunks(self): + return self._fisher_est.inv_update_thunks + + @property + def inv_update_ops(self): + return self._fisher_est.inv_update_ops + + @property + def inv_update_op(self): + return self._fisher_est.inv_update_op + @property def variables(self): return self._fisher_est.variables diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py index d5461c9f2ea0512ad7c4f2d393ac8e7f441d1b77..e89508fa46b6e2ce278e5373e6c9d17203ad1ef2 100644 --- a/tensorflow/contrib/kfac/python/ops/utils.py +++ b/tensorflow/contrib/kfac/python/ops/utils.py @@ -20,16 +20,22 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.tpu.python.ops import tpu_ops +from tensorflow.contrib.tpu.python.tpu import tpu_function from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variables # Method used for inverting 
matrices. POSDEF_INV_METHOD = "cholesky" +POSDEF_EIG_METHOD = "self_adjoint" def set_global_constants(posdef_inv_method=None): @@ -161,33 +167,11 @@ def mat2d_to_layer_params(vector_template, mat2d): return array_ops.reshape(mat2d, vector_template.shape) -def compute_pi(left_factor, right_factor): - """Computes the scalar constant pi for Tikhonov regularization/damping. - - pi = sqrt( (trace(A) / dim(A)) / (trace(B) / dim(B)) ) - See section 6.3 of https://arxiv.org/pdf/1503.05671.pdf for details. - - Args: - left_factor: The left Kronecker factor Tensor. - right_factor: The right Kronecker factor Tensor. - - Returns: - The computed scalar constant pi for these Kronecker Factors (as a Tensor). - """ - # Instead of dividing by the dim of the norm, we multiply by the dim of the - # other norm. This works out the same in the ratio. - left_norm = math_ops.trace(left_factor) * right_factor.get_shape().as_list()[ - 0] - right_norm = math_ops.trace(right_factor) * left_factor.get_shape().as_list()[ - 0] - return math_ops.sqrt(left_norm / right_norm) - - def posdef_inv(tensor, damping): """Computes the inverse of tensor + damping * identity.""" identity = linalg_ops.eye(tensor.shape.as_list()[0], dtype=tensor.dtype) damping = math_ops.cast(damping, dtype=tensor.dtype) - return posdef_inv_funcs[POSDEF_INV_METHOD](tensor, identity, damping) + return posdef_inv_functions[POSDEF_INV_METHOD](tensor, identity, damping) def posdef_inv_matrix_inverse(tensor, identity, damping): @@ -209,23 +193,51 @@ def posdef_inv_eig(tensor, identity, damping): eigenvectors / eigenvalues, eigenvectors, transpose_b=True) -posdef_inv_funcs = { +posdef_inv_functions = { "matrix_inverse": posdef_inv_matrix_inverse, "cholesky": posdef_inv_cholesky, "eig": posdef_inv_eig, } +def posdef_eig(mat): + """Computes the eigendecomposition of a positive semidefinite matrix.""" + return posdef_eig_functions[POSDEF_EIG_METHOD](mat) + + +def posdef_eig_svd(mat): + """Computes the singular values and left singular 
vectors of a matrix.""" + evals, evecs, _ = linalg_ops.svd(mat) + + return evals, evecs + + +def posdef_eig_self_adjoint(mat): + """Computes eigendecomposition using self_adjoint_eig.""" + evals, evecs = linalg_ops.self_adjoint_eig(mat) + evals = math_ops.abs(evals) # Should be equivalent to svd approach. + + return evals, evecs + + +posdef_eig_functions = { + "self_adjoint": posdef_eig_self_adjoint, + "svd": posdef_eig_svd, +} + + class SubGraph(object): """Defines a subgraph given by all the dependencies of a given set of outputs. """ def __init__(self, outputs): + # Set of all ancestor Tensors, Ops to 'outputs'. self._members = set() self._recurse_add(outputs) def _recurse_add(self, nodes): + """Recursively adds all of nodes' ancestors.""" for node in nodes: if node in self._members: continue @@ -241,8 +253,25 @@ class SubGraph(object): return node in self._members def variable_uses(self, var): - """Computes number of times a variable is used.""" - return len(self._members.intersection(set(var.value().consumers()))) + """Computes number of times a variable is used. + + Args: + var: Variable or ResourceVariable instance. + + Returns: + Number of times a variable is used within this subgraph. + + Raises: + ValueError: If 'var' is not a variable type. + """ + if isinstance(var, resource_variable_ops.ResourceVariable): + var = var.handle + elif isinstance(var, variables.Variable): + var = var.value() + else: + raise ValueError("%s does not appear to be a variable." % str(var)) + + return len(self._members.intersection(set(var.consumers()))) def filter_list(self, node_list): """Filters 'node_list' to nodes in this subgraph.""" @@ -287,5 +316,109 @@ def fwd_gradients(ys, xs, grad_xs=None, stop_gradients=None): return dysdx + +def on_tpu(): + """Returns True when building a TPU computation.""" + return tpu_function.get_tpu_context().number_of_shards is not None + + +def cross_replica_mean(tensor, name=None): + """Takes mean value of a Tensor across all TPU cores. 
+ + Args: + tensor: Tensor to be synchronized. + name: None or string. Name of Op. + + Returns: + Average of Tensor across all TPU cores. + + Raises: + ValueError: If called outside of TPU context. + """ + with ops.name_scope(name, "cross_replica_mean", [tensor]): + num_shards = tpu_function.get_tpu_context().number_of_shards + if num_shards is None: + raise ValueError( + "Cannot take cross_replica_mean() outside of TPU Context.") + if num_shards == 1: + return tensor + return tpu_ops.cross_replica_sum(tensor / num_shards) + + +def ensure_sequence(obj): + """If `obj` isn't a tuple or list, return a tuple containing `obj`.""" + if isinstance(obj, (tuple, list)): + return obj + else: + return (obj,) + + +def batch_execute(global_step, thunks, batch_size, name=None): + """Executes a subset of ops per global step. + + Given a list of thunks, each of which produces a single stateful op, + ensures that exactly 'batch_size' ops are run per global step. Ops are + scheduled in a round-robin fashion. For example, with 3 ops + + global_step | op0 | op1 | op2 + ------------+-----+-----+----- + 0 | x | x | + ------------+-----+-----+----- + 1 | x | | x + ------------+-----+-----+----- + 2 | | x | x + ------------+-----+-----+----- + 3 | x | x | + ------------+-----+-----+----- + 4 | x | | x + + Does not guarantee order of op execution within a single global step. + + Args: + global_step: Tensor indicating time. Determines which ops run. + thunks: List of thunks. Each thunk encapsulates one op. Return values are + ignored. + batch_size: int. Number of ops to execute per global_step. + name: string or None. Name scope for newly added ops. + + Returns: + List of ops. Exactly 'batch_size' ops are guaranteed to have an effect + every global step. 
+ """ + + def true_fn(thunk): + """Ensures thunk is executed and returns an Op (not a Tensor).""" + + def result(): + with ops.control_dependencies([thunk()]): + return control_flow_ops.no_op() + + return result + + def false_fn(_): + """Executes a no-op.""" + + def result(): + return control_flow_ops.no_op() + + return result + + with ops.name_scope(name, "batch_execute"): + true_fns = [true_fn(thunk) for thunk in thunks] + false_fns = [false_fn(thunk) for thunk in thunks] + num_thunks = len(thunks) + conditions = [ + math_ops.less( + math_ops.mod(batch_size - 1 + global_step * batch_size - j, + num_thunks), batch_size) for j in range(num_thunks) + ] + result = [ + control_flow_ops.cond(condition, true_fn, false_fn) + for (condition, true_fn, + false_fn) in zip(conditions, true_fns, false_fns) + ] + return result + + # TODO(b/69623235): Add a function for finding tensors that share gradients # to eliminate redundant fisher factor computations. diff --git a/tensorflow/contrib/kfac/python/ops/utils_lib.py b/tensorflow/contrib/kfac/python/ops/utils_lib.py index 9df07d69aad5e61f9cfb994c9a63fdec04f025fe..fe8e39c212c2c3381f9aa6fdb9fdf423ff958481 100644 --- a/tensorflow/contrib/kfac/python/ops/utils_lib.py +++ b/tensorflow/contrib/kfac/python/ops/utils_lib.py @@ -24,13 +24,13 @@ from tensorflow.python.util.all_util import remove_undocumented # pylint: enable=unused-import,line-too-long,wildcard-import _allowed_symbols = [ + "set_global_constants", "SequenceDict", "tensors_to_column", "column_to_tensors", "kronecker_product", "layer_params_to_mat2d", "mat2d_to_layer_params", - "compute_pi", "posdef_inv", "posdef_inv_matrix_inverse", "posdef_inv_cholesky", @@ -38,6 +38,8 @@ _allowed_symbols = [ "SubGraph", "generate_random_signs", "fwd_gradients", + "ensure_sequence", + "batch_execute", ] remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/labeled_tensor/python/ops/core_test.py 
b/tensorflow/contrib/labeled_tensor/python/ops/core_test.py index 1f4a3ef568efc459d4a36fcb0d5de7e0bce8335c..e70b4923749d89aba1bd0187857d762305daeb07 100644 --- a/tensorflow/contrib/labeled_tensor/python/ops/core_test.py +++ b/tensorflow/contrib/labeled_tensor/python/ops/core_test.py @@ -225,7 +225,7 @@ class LabeledTensorTest(test_util.Base): tensor = array_ops.placeholder(dtypes.string, [None]) actual = core.LabeledTensor(tensor, ['x']) self.assertIsNone(actual.axes['x'].size) - self.assertIs(actual.axes['x'].value, tensor.get_shape()[0]) + self.assertIsNone(actual.axes['x'].value.value) def test_eq(self): self.assertEqual(self.lt, self.lt) diff --git a/tensorflow/contrib/layers/__init__.py b/tensorflow/contrib/layers/__init__.py index 6c624929f20503054e0258aad8a843f4a201be64..337c9e06b870b2cca53fcdbf3d94225660e193c4 100644 --- a/tensorflow/contrib/layers/__init__.py +++ b/tensorflow/contrib/layers/__init__.py @@ -27,6 +27,7 @@ See the @{$python/contrib.layers} guide. @@convolution2d_transpose @@conv3d_transpose @@convolution3d_transpose +@@dense_to_sparse @@dropout @@elu @@embedding_lookup_unique @@ -34,6 +35,7 @@ See the @{$python/contrib.layers} guide. @@fully_connected @@GDN @@gdn +@@images_to_sequence @@layer_norm @@linear @@max_pool2d @@ -49,6 +51,7 @@ See the @{$python/contrib.layers} guide. 
@@scale_gradient @@separable_conv2d @@separable_convolution2d +@@sequence_to_images @@softmax @@spatial_softmax @@stack diff --git a/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc b/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc index 932c5ab99249feda1e3a7f2d707ce4237fe7177f..01893d60615a9b4ded2afc88c6de0168d4be0921 100644 --- a/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc +++ b/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc @@ -423,8 +423,9 @@ class SparseFeatureCrossOp : public OpKernel { "Input values should be a std::vector but received shape ", values_list_in[i].shape().DebugString(), " at position ", i)); OP_REQUIRES( - context, indices_list_in[i].shape().dim_size(0) == - values_list_in[i].shape().dim_size(0), + context, + indices_list_in[i].shape().dim_size(0) == + values_list_in[i].shape().dim_size(0), errors::InvalidArgument( "Expected size of values to be ", indices_list_in[i].shape().dim_size(0), " got ", diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py index 226d933d85d91600e36ffb84212703e10455bfbb..b7d34d6435789e54403926a342481971e854b449 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column.py +++ b/tensorflow/contrib/layers/python/layers/feature_column.py @@ -156,6 +156,10 @@ from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import deprecation +# Imports the core `InputLayer` symbol in contrib during development. +InputLayer = fc_core.InputLayer # pylint: disable=invalid-name + + class _LinearEmbeddingLookupArguments( collections.namedtuple("_LinearEmbeddingLookupArguments", ["input_tensor", @@ -521,7 +525,7 @@ def sparse_column_with_integerized_feature(column_name, Args: column_name: A string defining sparse column name. - bucket_size: An int that is > 1. The number of buckets. It should be bigger + bucket_size: An int that is >= 1. 
The number of buckets. It should be bigger than maximum feature. In other words features in this column should be an int64 in range [0, bucket_size) combiner: A string specifying how to reduce if the sparse column is @@ -539,7 +543,7 @@ def sparse_column_with_integerized_feature(column_name, An integerized _SparseColumn definition. Raises: - ValueError: bucket_size is not greater than 1. + ValueError: bucket_size is less than 1. ValueError: dtype is not integer. """ return _SparseColumnIntegerized( @@ -748,6 +752,10 @@ class _WeightedSparseColumn( {self.weight_column_name: parsing_ops.VarLenFeature(self.dtype)}) return config + @property + def lookup_config(self): + return self.sparse_id_column.lookup_config + @property def key(self): """Returns a string which will be used as a key when we do sorting.""" diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops.py b/tensorflow/contrib/layers/python/layers/feature_column_ops.py index fa0047f05d893f6543ddb1680824a32469e13293..78affea44cbfb92523063968dbc1be98841854db 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_ops.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_ops.py @@ -97,10 +97,13 @@ def _input_from_feature_columns(columns_to_tensors, trainable, scope, output_rank, - default_name): + default_name, + cols_to_outs=None): """Implementation of `input_from(_sequence)_feature_columns`.""" columns_to_tensors = columns_to_tensors.copy() check_feature_columns(feature_columns) + if cols_to_outs is not None and not isinstance(cols_to_outs, dict): + raise ValueError('cols_to_outs must be a dict unless None') with variable_scope.variable_scope(scope, default_name=default_name, values=columns_to_tensors.values()): @@ -144,6 +147,8 @@ def _input_from_feature_columns(columns_to_tensors, except ValueError as e: raise ValueError('Error creating input layer for column: {}.\n' '{}, {}'.format(column.name, e, ee)) + if cols_to_outs is not None: + cols_to_outs[column] = 
output_tensors[-1] return array_ops.concat(output_tensors, output_rank - 1) @@ -151,7 +156,8 @@ def input_from_feature_columns(columns_to_tensors, feature_columns, weight_collections=None, trainable=True, - scope=None): + scope=None, + cols_to_outs=None): """A tf.contrib.layers style input layer builder based on FeatureColumns. Generally a single example in training data is described with feature columns. @@ -196,6 +202,8 @@ def input_from_feature_columns(columns_to_tensors, trainable: If `True` also add variables to the graph collection `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). scope: Optional scope for variable_scope. + cols_to_outs: Optional dict from feature column to output tensor, + which is concatenated into the returned tensor. Returns: A Tensor which can be consumed by hidden layers in the neural network. @@ -209,7 +217,8 @@ def input_from_feature_columns(columns_to_tensors, trainable, scope, output_rank=2, - default_name='input_from_feature_columns') + default_name='input_from_feature_columns', + cols_to_outs=cols_to_outs) @experimental diff --git a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py index fbfa0e32de55edab3c90189ddfe05ab826ac9167..e6bbd86ab722c4e853a59f816bed8a8ac1fe9ede 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_ops_test.py @@ -607,6 +607,31 @@ class CreateInputLayersForDNNsTest(test.TestCase): # Verify cross compatibility: Core builder output should equal to contrib. 
self.assertAllEqual(output.eval().shape, output_core.eval().shape) + def testAllDNNColumnsWithColumnwiseOutputs(self): + sparse_column = feature_column.sparse_column_with_keys( + "ids", ["a", "b", "c", "unseen"]) + real_valued_column = feature_column.real_valued_column("income", 2) + one_hot_column = feature_column.one_hot_column(sparse_column) + embedding_column = feature_column.embedding_column(sparse_column, 10) + features = { + "ids": + sparse_tensor.SparseTensor( + values=["c", "b", "a"], + indices=[[0, 0], [1, 0], [2, 0]], + dense_shape=[3, 1]), + "income": + constant_op.constant([[20.3, 10], [110.3, 0.4], [-3.0, 30.4]]), + } + columns = [one_hot_column, embedding_column, real_valued_column] + cols_to_outs = {} + feature_column_ops.input_from_feature_columns( + features, columns, cols_to_outs=cols_to_outs) + with self.test_session(): + variables_lib.global_variables_initializer().run() + lookup_ops.tables_initializer().run() + for column in columns: + self.assertTrue(column in cols_to_outs) + def testRealValuedColumn(self): real_valued = feature_column.real_valued_column("price") features = {"price": constant_op.constant([[20.], [110], [-3]])} diff --git a/tensorflow/contrib/layers/python/layers/feature_column_test.py b/tensorflow/contrib/layers/python/layers/feature_column_test.py index 5ae885b7202357326bd8494d382adb57fa636d20..fc8f153fe3abdc83aca5abfa9a4bb5f5d5531480 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_test.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_test.py @@ -102,6 +102,16 @@ class FeatureColumnTest(test.TestCase): weighted_ids = fc.weighted_sparse_column(ids, "weights") self.assertEqual(weighted_ids.name, "ids_weighted_by_weights") + def testWeightedSparseColumnWithVocabularyFile(self): + ids = fc.sparse_column_with_vocabulary_file( + "ids", "a_file", num_oov_buckets=7, vocab_size=3) + weighted_ids = fc.weighted_sparse_column(ids, "weights") + self.assertEqual(weighted_ids.name, 
"ids_weighted_by_weights") + self.assertEqual(weighted_ids.lookup_config, ids.lookup_config) + self.assertEqual(weighted_ids.lookup_config.vocab_size, 3) + self.assertEqual(weighted_ids.lookup_config.num_oov_buckets, 7) + self.assertEqual(weighted_ids.lookup_config.vocabulary_file, "a_file") + def testWeightedSparseColumnDeepCopy(self): ids = fc.sparse_column_with_keys("ids", ["marlo", "omar", "stringer"]) weighted = fc.weighted_sparse_column(ids, "weights") @@ -211,8 +221,8 @@ class FeatureColumnTest(test.TestCase): weighted_sparse_col = fc.weighted_sparse_column(ids, "weights") self.assertEqual(weighted_sparse_col.name, "ids_weighted_by_weights") - b = fc.shared_embedding_columns([sparse_col, weighted_sparse_col], - dimension=4, combiner="mean") + b = fc.shared_embedding_columns( + [sparse_col, weighted_sparse_col], dimension=4, combiner="mean") self.assertEqual(len(b), 2) self.assertEqual(b[0].shared_embedding_name, "a1_ids_weighted_by_weights_shared_embedding") @@ -220,8 +230,8 @@ class FeatureColumnTest(test.TestCase): "a1_ids_weighted_by_weights_shared_embedding") # Tries reversing order to check compatibility condition. - b = fc.shared_embedding_columns([weighted_sparse_col, sparse_col], - dimension=4, combiner="mean") + b = fc.shared_embedding_columns( + [weighted_sparse_col, sparse_col], dimension=4, combiner="mean") self.assertEqual(len(b), 2) self.assertEqual(b[0].shared_embedding_name, "a1_ids_weighted_by_weights_shared_embedding") @@ -230,18 +240,17 @@ class FeatureColumnTest(test.TestCase): # Tries adding two weighted columns to check compatibility between them. 
weighted_sparse_col_2 = fc.weighted_sparse_column(ids, "weights_2") - b = fc.shared_embedding_columns([weighted_sparse_col, - weighted_sparse_col_2], - dimension=4, combiner="mean") + b = fc.shared_embedding_columns( + [weighted_sparse_col, weighted_sparse_col_2], + dimension=4, + combiner="mean") self.assertEqual(len(b), 2) self.assertEqual( b[0].shared_embedding_name, - "ids_weighted_by_weights_ids_weighted_by_weights_2_shared_embedding" - ) + "ids_weighted_by_weights_ids_weighted_by_weights_2_shared_embedding") self.assertEqual( b[1].shared_embedding_name, - "ids_weighted_by_weights_ids_weighted_by_weights_2_shared_embedding" - ) + "ids_weighted_by_weights_ids_weighted_by_weights_2_shared_embedding") def testSharedEmbeddingColumnDeterminism(self): # Tests determinism in auto-generated shared_embedding_name. @@ -276,10 +285,10 @@ class FeatureColumnTest(test.TestCase): columns = fc.shared_embedding_columns( [a1, a2], dimension=4, combiner="mean") columns_copy = copy.deepcopy(columns) - self.assertEqual( - columns_copy[0].shared_embedding_name, "a1_a2_shared_embedding") - self.assertEqual( - columns_copy[1].shared_embedding_name, "a1_a2_shared_embedding") + self.assertEqual(columns_copy[0].shared_embedding_name, + "a1_a2_shared_embedding") + self.assertEqual(columns_copy[1].shared_embedding_name, + "a1_a2_shared_embedding") def testOneHotColumn(self): a = fc.sparse_column_with_keys("a", ["a", "b", "c", "d"]) @@ -326,11 +335,11 @@ class FeatureColumnTest(test.TestCase): weighted_ids = fc.weighted_sparse_column(ids, "weights") one_hot = fc.one_hot_column(weighted_ids) features = { - 'ids': constant_op.constant([['marlo', 'unknown', 'omar']]), - 'weights': constant_op.constant([[2., 4., 6.]]) + "ids": constant_op.constant([["marlo", "unknown", "omar"]]), + "weights": constant_op.constant([[2., 4., 6.]]) } one_hot_tensor = feature_column_ops.input_from_feature_columns( - features, [one_hot]) + features, [one_hot]) with self.test_session() as sess: 
sess.run(variables.global_variables_initializer()) sess.run(lookup_ops.tables_initializer()) @@ -339,11 +348,9 @@ class FeatureColumnTest(test.TestCase): def testMissingValueInOneHotColumnForSparseColumnWithKeys(self): ids = fc.sparse_column_with_keys("ids", ["marlo", "omar", "stringer"]) one_hot = fc.one_hot_column(ids) - features = { - 'ids': constant_op.constant([['marlo', 'unknown', 'omar']]) - } + features = {"ids": constant_op.constant([["marlo", "unknown", "omar"]])} one_hot_tensor = feature_column_ops.input_from_feature_columns( - features, [one_hot]) + features, [one_hot]) with self.test_session() as sess: sess.run(variables.global_variables_initializer()) sess.run(lookup_ops.tables_initializer()) @@ -369,8 +376,7 @@ class FeatureColumnTest(test.TestCase): self.assertEqual(d4.default_value, None) self.assertEqual(d4.is_sparse, True) # Default value is a list but dimension is None. - with self.assertRaisesRegexp(ValueError, - "Only scalar default value.*"): + with self.assertRaisesRegexp(ValueError, "Only scalar default value.*"): fc._real_valued_var_len_column("g5", default_value=[2., 3.]) def testRealValuedVarLenColumnDtypes(self): @@ -380,18 +386,19 @@ class FeatureColumnTest(test.TestCase): "rvc": parsing_ops.VarLenFeature(dtype=dtypes.float32) }, rvc.config) - rvc = fc._real_valued_var_len_column("rvc", default_value=0, - is_sparse=False) - self.assertDictEqual( - { - "rvc": parsing_ops.FixedLenSequenceFeature(shape=[], - dtype=dtypes.float32, - allow_missing=True, - default_value=0.0) - }, rvc.config) - - rvc = fc._real_valued_var_len_column("rvc", dtype=dtypes.int32, - default_value=0, is_sparse=True) + rvc = fc._real_valued_var_len_column( + "rvc", default_value=0, is_sparse=False) + self.assertDictEqual({ + "rvc": + parsing_ops.FixedLenSequenceFeature( + shape=[], + dtype=dtypes.float32, + allow_missing=True, + default_value=0.0) + }, rvc.config) + + rvc = fc._real_valued_var_len_column( + "rvc", dtype=dtypes.int32, default_value=0, is_sparse=True) 
self.assertDictEqual( { "rvc": parsing_ops.VarLenFeature(dtype=dtypes.int32) @@ -399,8 +406,8 @@ class FeatureColumnTest(test.TestCase): with self.assertRaisesRegexp(TypeError, "dtype must be convertible to float"): - fc._real_valued_var_len_column("rvc", dtype=dtypes.string, - default_value="", is_sparse=True) + fc._real_valued_var_len_column( + "rvc", dtype=dtypes.string, default_value="", is_sparse=True) def testRealValuedColumn(self): a = fc.real_valued_column("aaa") @@ -494,13 +501,13 @@ class FeatureColumnTest(test.TestCase): for output_rank in range(1, 3 + len(dimensions)): with variable_scope.variable_scope("output_rank_{}".format(output_rank)): real_valued_output = real_valued_column._to_dnn_input_layer( - constant_op.constant( - real_valued_input, dtype=dtypes.float32), + constant_op.constant(real_valued_input, dtype=dtypes.float32), output_rank=output_rank) with self.test_session() as sess: real_valued_eval = sess.run(real_valued_output) - expected_shape = (input_shape[:output_rank - 1] + - [np.prod(input_shape[output_rank - 1:])]) + expected_shape = ( + input_shape[:output_rank - 1] + + [np.prod(input_shape[output_rank - 1:])]) self.assertEquals(expected_shape, list(real_valued_eval.shape)) def testRealValuedColumnDensification(self): @@ -510,8 +517,7 @@ class FeatureColumnTest(test.TestCase): "sparse_real_valued1", is_sparse=True) sparse_tensor = sparse_tensor_lib.SparseTensor( values=[2.0, 5.0], indices=[[0, 0], [2, 0]], dense_shape=[3, 1]) - with self.assertRaisesRegexp( - ValueError, "Set is_sparse to False"): + with self.assertRaisesRegexp(ValueError, "Set is_sparse to False"): real_valued_column._to_dnn_input_layer(sparse_tensor) def testRealValuedColumnDeepCopy(self): @@ -539,9 +545,8 @@ class FeatureColumnTest(test.TestCase): def testBucketizedColumnRequiresRealValuedColumnDimension(self): with self.assertRaisesRegexp( TypeError, "source_column must be an instance of _RealValuedColumn.*"): - 
fc.bucketized_column(fc._real_valued_var_len_column("bbb", - is_sparse=True), - [0]) + fc.bucketized_column( + fc._real_valued_var_len_column("bbb", is_sparse=True), [0]) def testBucketizedColumnRequiresSortedBuckets(self): with self.assertRaisesRegexp(ValueError, @@ -644,20 +649,14 @@ class FeatureColumnTest(test.TestCase): def testRealValuedColumnDtypes(self): rvc = fc.real_valued_column("rvc") - self.assertDictEqual( - { - "rvc": parsing_ops.FixedLenFeature( - [1], dtype=dtypes.float32) - }, - rvc.config) + self.assertDictEqual({ + "rvc": parsing_ops.FixedLenFeature([1], dtype=dtypes.float32) + }, rvc.config) rvc = fc.real_valued_column("rvc", dtype=dtypes.int32) - self.assertDictEqual( - { - "rvc": parsing_ops.FixedLenFeature( - [1], dtype=dtypes.int32) - }, - rvc.config) + self.assertDictEqual({ + "rvc": parsing_ops.FixedLenFeature([1], dtype=dtypes.int32) + }, rvc.config) with self.assertRaisesRegexp(ValueError, "dtype must be convertible to float"): @@ -692,8 +691,9 @@ class FeatureColumnTest(test.TestCase): batch_size = 4 dense_scalar_input = [1, 2, 3, 4] sparse_column = fc.sparse_column_with_integerized_feature("values", 10) - features = {"values": - constant_op.constant(dense_scalar_input, dtype=dtypes.int64)} + features = { + "values": constant_op.constant(dense_scalar_input, dtype=dtypes.int64) + } sparse_column.insert_transformed_feature(features) sparse_output = features[sparse_column] expected_shape = [batch_size, 1] @@ -721,8 +721,7 @@ class FeatureColumnTest(test.TestCase): def testSparseColumnKeysDeepCopy(self): """Tests deepcopy of sparse_column_with_keys.""" - column = fc.sparse_column_with_keys( - "a", keys=["key0", "key1", "key2"]) + column = fc.sparse_column_with_keys("a", keys=["key0", "key1", "key2"]) self.assertEqual("a", column.name) column_copy = copy.deepcopy(column) self.assertEqual("a", column_copy.name) @@ -775,8 +774,9 @@ class FeatureColumnTest(test.TestCase): a = fc.sparse_column_with_hash_bucket("cross_aaa", hash_bucket_size=100) 
b = fc.sparse_column_with_hash_bucket("cross_bbb", hash_bucket_size=100) cross_col = fc.crossed_column(set([a, b]), hash_bucket_size=10000) - one_hot_col = fc.one_hot_column(fc.sparse_column_with_hash_bucket( - "sparse_column_for_one_hot", hash_bucket_size=100)) + one_hot_col = fc.one_hot_column( + fc.sparse_column_with_hash_bucket( + "sparse_column_for_one_hot", hash_bucket_size=100)) scattered_embedding_col = fc.scattered_embedding_column( "scattered_embedding_column", size=100, dimension=10, hash_key=1) feature_columns = set([ @@ -799,17 +799,13 @@ class FeatureColumnTest(test.TestCase): "str_id_weights_column": parsing_ops.VarLenFeature(dtypes.float32), "real_valued_column1": - parsing_ops.FixedLenFeature( - [1], dtype=dtypes.float32), + parsing_ops.FixedLenFeature([1], dtype=dtypes.float32), "real_valued_column2": - parsing_ops.FixedLenFeature( - [5], dtype=dtypes.float32), + parsing_ops.FixedLenFeature([5], dtype=dtypes.float32), "real_valued_column_for_bucketization1": - parsing_ops.FixedLenFeature( - [1], dtype=dtypes.float32), + parsing_ops.FixedLenFeature([1], dtype=dtypes.float32), "real_valued_column_for_bucketization2": - parsing_ops.FixedLenFeature( - [4], dtype=dtypes.float32), + parsing_ops.FixedLenFeature([4], dtype=dtypes.float32), "cross_aaa": parsing_ops.VarLenFeature(dtypes.string), "cross_bbb": @@ -839,11 +835,14 @@ class FeatureColumnTest(test.TestCase): real_valued_col0 = fc._real_valued_var_len_column( "real_valued_column0", is_sparse=True) real_valued_col1 = fc._real_valued_var_len_column( - "real_valued_column1", dtype=dtypes.int64, default_value=0, + "real_valued_column1", + dtype=dtypes.int64, + default_value=0, is_sparse=False) feature_columns = set([real_valued_col0, real_valued_col1]) expected_config = { - "real_valued_column0": parsing_ops.VarLenFeature(dtype=dtypes.float32), + "real_valued_column0": + parsing_ops.VarLenFeature(dtype=dtypes.float32), "real_valued_column1": parsing_ops.FixedLenSequenceFeature( [], dtype=dtypes.int64, 
allow_missing=True, default_value=0), @@ -864,7 +863,9 @@ class FeatureColumnTest(test.TestCase): real_valued_col5 = fc._real_valued_var_len_column( "real_valued_column5", default_value=2, is_sparse=True) real_valued_col6 = fc._real_valued_var_len_column( - "real_valued_column6", dtype=dtypes.int64, default_value=1, + "real_valued_column6", + dtype=dtypes.int64, + default_value=1, is_sparse=False) feature_columns = [ real_valued_col1, real_valued_col2, real_valued_col3, real_valued_col4, @@ -892,8 +893,7 @@ class FeatureColumnTest(test.TestCase): parsing_ops.VarLenFeature(dtype=dtypes.float32), "real_valued_column6": parsing_ops.FixedLenSequenceFeature( - [], dtype=dtypes.int64, allow_missing=True, - default_value=1) + [], dtype=dtypes.int64, allow_missing=True, default_value=1) }, config) @@ -1094,8 +1094,8 @@ class FeatureColumnTest(test.TestCase): # This will initialize the crossed column weights from provided checkpoint # and return a [4, 1] tensor which is same as weights variable. Since we # won't modify weights, this should be same as 'saved_col_weights'. - _, col_weights, _ = (feature_column_ops.weighted_sum_from_feature_columns( - { + _, col_weights, _ = ( + feature_column_ops.weighted_sum_from_feature_columns({ sparse_col_1.name: input_tensor, sparse_col_2.name: input_tensor }, [crossed_col_initialized], 1)) diff --git a/tensorflow/contrib/layers/python/layers/initializers.py b/tensorflow/contrib/layers/python/layers/initializers.py index b12a882d9ae88f7cf4f920cfa5872e5de1c67290..51610f21b24f1d40f26630cc1e69ca723d130639 100644 --- a/tensorflow/contrib/layers/python/layers/initializers.py +++ b/tensorflow/contrib/layers/python/layers/initializers.py @@ -79,7 +79,8 @@ def variance_scaling_initializer(factor=2.0, mode='FAN_IN', uniform=False, ``` * To get [Delving Deep into Rectifiers]( - http://arxiv.org/pdf/1502.01852v1.pdf), use (Default):
+ http://arxiv.org/pdf/1502.01852v1.pdf) (also known as the "MSRA + initialization"), use (Default):
`factor=2.0 mode='FAN_IN' uniform=False` * To get [Convolutional Architecture for Fast Feature Embedding]( http://arxiv.org/abs/1408.5093), use:
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index 6cd586a5f016c76cc52b340bfd0d32fa08f23748..5c1ff9ec267f1bccd9bee44a4b19e7ed3ec24cf0 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -29,6 +29,7 @@ from tensorflow.contrib.framework.python.ops import variables from tensorflow.contrib.layers.python.layers import initializers from tensorflow.contrib.layers.python.layers import utils from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops @@ -54,47 +55,18 @@ from tensorflow.python.layers.maxout import maxout # TODO(b/28426988): Replace legacy_* fns migrated from slim. # TODO(b/28426988): Remove legacy_* when all uses have migrated to new API. -__all__ = ['avg_pool2d', - 'avg_pool3d', - 'batch_norm', - 'bias_add', - 'conv2d', - 'conv3d', - 'conv2d_in_plane', - 'conv2d_transpose', - 'conv3d_transpose', - 'convolution', - 'convolution2d', - 'convolution2d_in_plane', - 'convolution2d_transpose', - 'convolution3d', - 'convolution3d_transpose', - 'dropout', - 'elu', - 'flatten', - 'fully_connected', - 'GDN', - 'gdn', - 'layer_norm', - 'linear', - 'pool', - 'max_pool2d', - 'max_pool3d', - 'one_hot_encoding', - 'relu', - 'relu6', - 'repeat', - 'scale_gradient', - 'separable_conv2d', - 'separable_convolution2d', - 'softmax', - 'spatial_softmax', - 'stack', - 'unit_norm', - 'legacy_fully_connected', - 'legacy_linear', - 'legacy_relu', - 'maxout'] +__all__ = [ + 'avg_pool2d', 'avg_pool3d', 'batch_norm', 'bias_add', 'conv2d', 'conv3d', + 'conv2d_in_plane', 'conv2d_transpose', 'conv3d_transpose', 'convolution', + 'convolution2d', 'convolution2d_in_plane', 'convolution2d_transpose', + 'convolution3d', 'convolution3d_transpose', 'dense_to_sparse', + 'dropout', 
'elu', 'flatten', 'fully_connected', 'GDN', 'gdn', + 'images_to_sequence', 'layer_norm', 'linear', 'pool', 'max_pool2d', + 'max_pool3d', 'one_hot_encoding', 'relu', 'relu6', 'repeat', + 'scale_gradient', 'separable_conv2d', 'separable_convolution2d', + 'sequence_to_images', 'softmax', 'spatial_softmax', 'stack', 'unit_norm', + 'legacy_fully_connected', 'legacy_linear', 'legacy_relu', 'maxout' +] DATA_FORMAT_NCHW = 'NCHW' DATA_FORMAT_NHWC = 'NHWC' @@ -139,13 +111,14 @@ def avg_pool2d(inputs, raise ValueError('data_format has to be either NCHW or NHWC.') with ops.name_scope(scope, 'AvgPool2D', [inputs]) as sc: inputs = ops.convert_to_tensor(inputs) - df = ('channels_first' if data_format and data_format.startswith('NC') - else 'channels_last') - layer = pooling_layers.AveragePooling2D(pool_size=kernel_size, - strides=stride, - padding=padding, - data_format=df, - _scope=sc) + df = ('channels_first' + if data_format and data_format.startswith('NC') else 'channels_last') + layer = pooling_layers.AveragePooling2D( + pool_size=kernel_size, + strides=stride, + padding=padding, + data_format=df, + _scope=sc) outputs = layer.apply(inputs) return utils.collect_named_outputs(outputs_collections, sc, outputs) @@ -187,13 +160,14 @@ def avg_pool3d(inputs, raise ValueError('data_format has to be either NCDHW or NDHWC.') with ops.name_scope(scope, 'AvgPool3D', [inputs]) as sc: inputs = ops.convert_to_tensor(inputs) - df = ('channels_first' if data_format and data_format.startswith('NC') - else 'channels_last') - layer = pooling_layers.AveragePooling3D(pool_size=kernel_size, - strides=stride, - padding=padding, - data_format=df, - _scope=sc) + df = ('channels_first' + if data_format and data_format.startswith('NC') else 'channels_last') + layer = pooling_layers.AveragePooling3D( + pool_size=kernel_size, + strides=stride, + padding=padding, + data_format=df, + _scope=sc) outputs = layer.apply(inputs) return utils.collect_named_outputs(outputs_collections, sc, outputs) @@ -298,8 
+272,8 @@ def _fused_batch_norm(inputs, raise ValueError('Inputs %s has undefined rank' % inputs.name) elif original_rank not in [2, 4]: raise ValueError('Inputs %s has unsupported rank.' - ' Expected 2 or 4 but got %d' % ( - inputs.name, original_rank)) + ' Expected 2 or 4 but got %d' % (inputs.name, + original_rank)) if original_rank == 2: channels = inputs.get_shape()[-1].value if channels is None: @@ -393,6 +367,7 @@ def _fused_batch_norm(inputs, def _fused_batch_norm_training(): return nn.fused_batch_norm( inputs, gamma, beta, epsilon=epsilon, data_format=data_format) + def _fused_batch_norm_inference(): return nn.fused_batch_norm( inputs, @@ -403,9 +378,9 @@ def _fused_batch_norm(inputs, epsilon=epsilon, is_training=False, data_format=data_format) - outputs, mean, variance = utils.smart_cond(is_training, - _fused_batch_norm_training, - _fused_batch_norm_inference) + + outputs, mean, variance = utils.smart_cond( + is_training, _fused_batch_norm_training, _fused_batch_norm_inference) # If `is_training` doesn't have a constant value, because it is a `Tensor`, # a `Variable` or `Placeholder` then is_training_value will be None and @@ -415,6 +390,7 @@ def _fused_batch_norm(inputs, if need_updates: if updates_collections is None: no_updates = lambda: outputs + def _force_updates(): """Internal function forces updates moving_vars if is_training.""" update_moving_mean = moving_averages.assign_moving_average( @@ -424,9 +400,11 @@ def _fused_batch_norm(inputs, with ops.control_dependencies( [update_moving_mean, update_moving_variance]): return array_ops.identity(outputs) + outputs = utils.smart_cond(is_training, _force_updates, no_updates) else: moving_vars_fn = lambda: (moving_mean, moving_variance) + def _delay_updates(): """Internal function that delay updates moving_vars if is_training.""" update_moving_mean = moving_averages.assign_moving_average( @@ -434,9 +412,9 @@ def _fused_batch_norm(inputs, update_moving_variance = moving_averages.assign_moving_average( 
moving_variance, variance, decay, zero_debias=False) return update_moving_mean, update_moving_variance - update_mean, update_variance = utils.smart_cond(is_training, - _delay_updates, - moving_vars_fn) + + update_mean, update_variance = utils.smart_cond( + is_training, _delay_updates, moving_vars_fn) ops.add_to_collections(updates_collections, update_mean) ops.add_to_collections(updates_collections, update_variance) @@ -479,7 +457,12 @@ def batch_norm(inputs, Sergey Ioffe, Christian Szegedy - Can be used as a normalizer function for conv2d and fully_connected. + Can be used as a normalizer function for conv2d and fully_connected. The + normalization is over all but the last dimension if `data_format` is `NHWC` + and all but the second dimension if `data_format` is `NCHW`. In case of a 2D + tensor this corresponds to the batch dimension, while in case of a 4D tensor + this + corresponds to the batch and space dimensions. Note: when training, the moving_mean and moving_variance need to be updated. By default the update ops are placed in `tf.GraphKeys.UPDATE_OPS`, so they @@ -535,8 +518,8 @@ def batch_norm(inputs, then the batch normalization uses weighted mean and variance. (This can be used to correct for bias in training example selection.) - fused: if `True`, use a faster, fused implementation if possible. - If `None`, use the system recommended implementation. + fused: if `None` or `True`, use a faster, fused implementation if possible. + If `False`, use the system recommended implementation. data_format: A string. `NHWC` (default) and `NCHW` are supported. zero_debias_moving_mean: Use zero_debias for moving_mean. It creates a new pair of variables 'moving_mean/biased' and 'moving_mean/local_step'. @@ -588,10 +571,9 @@ def batch_norm(inputs, # implementation in normalization_layers.BatchNormalization. 
inputs = ops.convert_to_tensor(inputs) rank = inputs.get_shape().ndims - possible_to_fuse = (batch_weights is None and - not renorm and - rank in [2, 4] and - adjustment is None) + possible_to_fuse = ( + batch_weights is None and not renorm and rank in [2, 4] and + adjustment is None) if fused and possible_to_fuse and ( zero_debias_moving_mean or rank == 2 or updates_collections is not ops.GraphKeys.UPDATE_OPS): @@ -619,7 +601,9 @@ def batch_norm(inputs, layer_variable_getter = _build_variable_getter() with variable_scope.variable_scope( - scope, 'BatchNorm', [inputs], reuse=reuse, + scope, + 'BatchNorm', [inputs], + reuse=reuse, custom_getter=layer_variable_getter) as sc: inputs = ops.convert_to_tensor(inputs) @@ -667,15 +651,15 @@ def batch_norm(inputs, outputs = layer.apply(inputs, training=is_training) # Add variables to collections. - _add_variable_to_collections( - layer.moving_mean, variables_collections, 'moving_mean') - _add_variable_to_collections( - layer.moving_variance, variables_collections, 'moving_variance') + _add_variable_to_collections(layer.moving_mean, variables_collections, + 'moving_mean') + _add_variable_to_collections(layer.moving_variance, variables_collections, + 'moving_variance') if layer.beta is not None: _add_variable_to_collections(layer.beta, variables_collections, 'beta') if layer.gamma is not None: - _add_variable_to_collections( - layer.gamma, variables_collections, 'gamma') + _add_variable_to_collections(layer.gamma, variables_collections, + 'gamma') if activation_fn is not None: outputs = activation_fn(outputs) @@ -715,8 +699,8 @@ def batch_norm(inputs, params_shape = inputs_shape[-1:] params_shape_broadcast = None if not params_shape.is_fully_defined(): - raise ValueError('Inputs %s has undefined channels dimension %s.' % ( - inputs.name, params_shape)) + raise ValueError('Inputs %s has undefined channels dimension %s.' % + (inputs.name, params_shape)) # Allocate parameters for the beta and gamma of the normalization. 
beta, gamma = None, None @@ -727,23 +711,25 @@ def batch_norm(inputs, 'beta') beta_initializer = param_initializers.get('beta', init_ops.zeros_initializer()) - beta = variables.model_variable('beta', - shape=params_shape, - dtype=dtype, - initializer=beta_initializer, - collections=beta_collections, - trainable=trainable) + beta = variables.model_variable( + 'beta', + shape=params_shape, + dtype=dtype, + initializer=beta_initializer, + collections=beta_collections, + trainable=trainable) if scale: - gamma_collections = utils.get_variable_collections(variables_collections, - 'gamma') + gamma_collections = utils.get_variable_collections( + variables_collections, 'gamma') gamma_initializer = param_initializers.get('gamma', init_ops.ones_initializer()) - gamma = variables.model_variable('gamma', - shape=params_shape, - dtype=dtype, - initializer=gamma_initializer, - collections=gamma_collections, - trainable=trainable) + gamma = variables.model_variable( + 'gamma', + shape=params_shape, + dtype=dtype, + initializer=gamma_initializer, + collections=gamma_collections, + trainable=trainable) # Create moving_mean and moving_variance variables and add them to the # appropriate collections. 
We disable variable partitioning while creating @@ -792,8 +778,8 @@ def batch_norm(inputs, mean, variance = nn.moments(inputs, moments_axes) else: if data_format == DATA_FORMAT_NCHW: - mean, variance = nn.weighted_moments(inputs, moments_axes, - batch_weights, keep_dims=True) + mean, variance = nn.weighted_moments( + inputs, moments_axes, batch_weights, keepdims=True) mean = array_ops.reshape(mean, [-1]) variance = array_ops.reshape(variance, [-1]) else: @@ -802,19 +788,21 @@ def batch_norm(inputs, moving_vars_fn = lambda: (moving_mean, moving_variance) if updates_collections is None: + def _force_updates(): """Internal function forces updates moving_vars if is_training.""" update_moving_mean = moving_averages.assign_moving_average( moving_mean, mean, decay, zero_debias=zero_debias_moving_mean) update_moving_variance = moving_averages.assign_moving_average( moving_variance, variance, decay, zero_debias=False) - with ops.control_dependencies([update_moving_mean, - update_moving_variance]): + with ops.control_dependencies( + [update_moving_mean, update_moving_variance]): return array_ops.identity(mean), array_ops.identity(variance) - mean, variance = utils.smart_cond(is_training, - _force_updates, + + mean, variance = utils.smart_cond(is_training, _force_updates, moving_vars_fn) else: + def _delay_updates(): """Internal function that delay updates moving_vars if is_training.""" update_moving_mean = moving_averages.assign_moving_average( @@ -823,9 +811,8 @@ def batch_norm(inputs, moving_variance, variance, decay, zero_debias=False) return update_moving_mean, update_moving_variance - update_mean, update_variance = utils.smart_cond(is_training, - _delay_updates, - moving_vars_fn) + update_mean, update_variance = utils.smart_cond( + is_training, _delay_updates, moving_vars_fn) ops.add_to_collections(updates_collections, update_mean) ops.add_to_collections(updates_collections, update_variance) # Use computed moments during training and moving_vars otherwise. 
@@ -893,8 +880,8 @@ def bias_add(inputs, """ if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC): raise ValueError('data_format has to be either NCHW or NHWC.') - with variable_scope.variable_scope(scope, 'BiasAdd', [inputs], - reuse=reuse) as sc: + with variable_scope.variable_scope( + scope, 'BiasAdd', [inputs], reuse=reuse) as sc: inputs = ops.convert_to_tensor(inputs) dtype = inputs.dtype.base_dtype inputs_shape = inputs.get_shape() @@ -909,13 +896,16 @@ def bias_add(inputs, raise ValueError('`C` dimension must be known but is None') biases_collections = utils.get_variable_collections(variables_collections, 'biases') - biases = variables.model_variable('biases', - shape=[num_features,], - dtype=dtype, - initializer=initializer, - regularizer=regularizer, - collections=biases_collections, - trainable=trainable) + biases = variables.model_variable( + 'biases', + shape=[ + num_features, + ], + dtype=dtype, + initializer=initializer, + regularizer=regularizer, + collections=biases_collections, + trainable=trainable) outputs = nn.bias_add(inputs, biases, data_format=data_format) if activation_fn is not None: outputs = activation_fn(outputs) @@ -1015,8 +1005,10 @@ def convolution(inputs, if data_format not in [None, 'NWC', 'NCW', 'NHWC', 'NCHW', 'NDHWC', 'NCDHW']: raise ValueError('Invalid data_format: %r' % (data_format,)) - layer_variable_getter = _build_variable_getter( - {'bias': 'biases', 'kernel': 'weights'}) + layer_variable_getter = _build_variable_getter({ + 'bias': 'biases', + 'kernel': 'weights' + }) with variable_scope.variable_scope( scope, 'Conv', [inputs], reuse=reuse, @@ -1034,26 +1026,27 @@ def convolution(inputs, raise ValueError('Convolution not supported for input with rank', input_rank) - df = ('channels_first' if data_format and data_format.startswith('NC') - else 'channels_last') - layer = layer_class(filters=num_outputs, - kernel_size=kernel_size, - strides=stride, - padding=padding, - data_format=df, - dilation_rate=rate, - 
activation=None, - use_bias=not normalizer_fn and biases_initializer, - kernel_initializer=weights_initializer, - bias_initializer=biases_initializer, - kernel_regularizer=weights_regularizer, - bias_regularizer=biases_regularizer, - activity_regularizer=None, - trainable=trainable, - name=sc.name, - dtype=inputs.dtype.base_dtype, - _scope=sc, - _reuse=reuse) + df = ('channels_first' + if data_format and data_format.startswith('NC') else 'channels_last') + layer = layer_class( + filters=num_outputs, + kernel_size=kernel_size, + strides=stride, + padding=padding, + data_format=df, + dilation_rate=rate, + activation=None, + use_bias=not normalizer_fn and biases_initializer, + kernel_initializer=weights_initializer, + bias_initializer=biases_initializer, + kernel_regularizer=weights_regularizer, + bias_regularizer=biases_regularizer, + activity_regularizer=None, + trainable=trainable, + name=sc.name, + dtype=inputs.dtype.base_dtype, + _scope=sc, + _reuse=reuse) outputs = layer.apply(inputs) # Add variables to collections. 
@@ -1069,6 +1062,7 @@ def convolution(inputs, outputs = activation_fn(outputs) return utils.collect_named_outputs(outputs_collections, sc.name, outputs) + convolution2d = convolution convolution3d = convolution @@ -1144,13 +1138,14 @@ def convolution2d_in_plane( weights_shape = [kernel_h, kernel_w, 1, 1] weights_collections = utils.get_variable_collections( variables_collections, 'weights') - weights = variables.model_variable('weights', - shape=weights_shape, - dtype=dtype, - initializer=weights_initializer, - regularizer=weights_regularizer, - collections=weights_collections, - trainable=trainable) + weights = variables.model_variable( + 'weights', + shape=weights_shape, + dtype=dtype, + initializer=weights_initializer, + regularizer=weights_regularizer, + collections=weights_collections, + trainable=trainable) depthwise_weights = array_ops.tile(weights, [1, 1, num_filters_in, 1]) outputs = nn.depthwise_conv2d(inputs, depthwise_weights, [1, stride_h, stride_w, 1], padding) @@ -1161,13 +1156,16 @@ def convolution2d_in_plane( if biases_initializer is not None: biases_collections = utils.get_variable_collections( variables_collections, 'biases') - biases = variables.model_variable('biases', - shape=[num_filters_in,], - dtype=dtype, - initializer=biases_initializer, - regularizer=biases_regularizer, - collections=biases_collections, - trainable=trainable) + biases = variables.model_variable( + 'biases', + shape=[ + num_filters_in, + ], + dtype=dtype, + initializer=biases_initializer, + regularizer=biases_regularizer, + collections=biases_collections, + trainable=trainable) outputs = nn.bias_add(outputs, biases) if activation_fn is not None: @@ -1240,19 +1238,23 @@ def convolution2d_transpose( ValueError: If `data_format` is neither `NHWC` nor `NCHW`. ValueError: If `C` dimension of `inputs` is None. 
""" - layer_variable_getter = _build_variable_getter( - {'bias': 'biases', 'kernel': 'weights'}) + layer_variable_getter = _build_variable_getter({ + 'bias': 'biases', + 'kernel': 'weights' + }) with variable_scope.variable_scope( - scope, 'Conv2d_transpose', [inputs], reuse=reuse, + scope, + 'Conv2d_transpose', [inputs], + reuse=reuse, custom_getter=layer_variable_getter) as sc: if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC): raise ValueError('data_format has to be either NCHW or NHWC.') inputs = ops.convert_to_tensor(inputs) - df = ('channels_first' if data_format and data_format.startswith('NC') - else 'channels_last') + df = ('channels_first' + if data_format and data_format.startswith('NC') else 'channels_last') layer = convolutional_layers.Convolution2DTranspose( filters=num_outputs, kernel_size=kernel_size, @@ -1349,19 +1351,23 @@ def convolution3d_transpose( ValueError: If `data_format` is neither `NDHWC` nor `NCDHW`. ValueError: If `C` dimension of `inputs` is None. 
""" - layer_variable_getter = _build_variable_getter( - {'bias': 'biases', 'kernel': 'weights'}) + layer_variable_getter = _build_variable_getter({ + 'bias': 'biases', + 'kernel': 'weights' + }) with variable_scope.variable_scope( - scope, 'Conv3d_transpose', [inputs], reuse=reuse, + scope, + 'Conv3d_transpose', [inputs], + reuse=reuse, custom_getter=layer_variable_getter) as sc: if data_format not in (DATA_FORMAT_NCDHW, DATA_FORMAT_NDHWC): raise ValueError('data_format has to be either NCDHW or NDHWC.') inputs = ops.convert_to_tensor(inputs) - df = ('channels_first' if data_format and data_format.startswith('NC') - else 'channels_last') + df = ('channels_first' + if data_format and data_format.startswith('NC') else 'channels_last') layer = convolutional_layers.Convolution3DTranspose( filters=num_outputs, kernel_size=kernel_size, @@ -1396,6 +1402,30 @@ def convolution3d_transpose( return utils.collect_named_outputs(outputs_collections, sc.name, outputs) +@add_arg_scope +def dense_to_sparse(tensor, eos_token=0, outputs_collections=None, scope=None): + """Converts a dense tensor into a sparse tensor. + An example use would be to convert dense labels to sparse ones + so that they can be fed to the ctc_loss. + + Args: + tensor: An `int` `Tensor` to be converted to a `Sparse`. + eos_token: An integer. + It is part of the target label that signfies the end of a sentence. + outputs_collections: Collection to add the outputs. + scope: Optional scope for name_scope. 
+ """ + with variable_scope.variable_scope(scope, 'dense_to_sparse', [tensor]) as sc: + tensor = ops.convert_to_tensor(tensor) + indices = array_ops.where( + math_ops.not_equal(tensor, constant_op.constant(eos_token, + tensor.dtype))) + values = array_ops.gather_nd(tensor, indices) + shape = array_ops.shape(tensor, out_type=dtypes.int64) + outputs = sparse_tensor.SparseTensor(indices, values, shape) + return utils.collect_named_outputs(outputs_collections, sc.name, outputs) + + @add_arg_scope def dropout(inputs, keep_prob=0.5, @@ -1430,19 +1460,18 @@ def dropout(inputs, with variable_scope.variable_scope( scope, 'Dropout', [inputs], custom_getter=_model_variable_getter) as sc: inputs = ops.convert_to_tensor(inputs) - layer = core_layers.Dropout(rate=1 - keep_prob, - noise_shape=noise_shape, - seed=seed, - name=sc.name, - _scope=sc) + layer = core_layers.Dropout( + rate=1 - keep_prob, + noise_shape=noise_shape, + seed=seed, + name=sc.name, + _scope=sc) outputs = layer.apply(inputs, training=is_training) return utils.collect_named_outputs(outputs_collections, sc.name, outputs) @add_arg_scope -def flatten(inputs, - outputs_collections=None, - scope=None): +def flatten(inputs, outputs_collections=None, scope=None): """Flattens the input while maintaining the batch_size. Assumes that the first dimension represents the batch. 
@@ -1474,8 +1503,8 @@ def _sparse_inner_flatten(inputs, new_rank): outer_dimensions = inputs.dense_shape[:new_rank - 1] inner_dimensions = inputs.dense_shape[new_rank - 1:] - new_shape = array_ops.concat((outer_dimensions, - [math_ops.reduce_prod(inner_dimensions)]), 0) + new_shape = array_ops.concat( + (outer_dimensions, [math_ops.reduce_prod(inner_dimensions)]), 0) flattened = sparse_ops.sparse_reshape(inputs, new_shape) return flattened @@ -1541,10 +1570,18 @@ def _inner_flatten(inputs, new_rank, output_collections=None, scope=None): return utils.collect_named_outputs(output_collections, sc, flattened) -def _model_variable_getter(getter, name, shape=None, dtype=None, - initializer=None, regularizer=None, trainable=True, - collections=None, caching_device=None, - partitioner=None, rename=None, use_resource=None, +def _model_variable_getter(getter, + name, + shape=None, + dtype=None, + initializer=None, + regularizer=None, + trainable=True, + collections=None, + caching_device=None, + partitioner=None, + rename=None, + use_resource=None, **_): """Getter that uses model_variable for compatibility with core layers.""" short_name = name.split('/')[-1] @@ -1553,25 +1590,34 @@ def _model_variable_getter(getter, name, shape=None, dtype=None, name_components[-1] = rename[short_name] name = '/'.join(name_components) return variables.model_variable( - name, shape=shape, dtype=dtype, initializer=initializer, - regularizer=regularizer, collections=collections, trainable=trainable, - caching_device=caching_device, partitioner=partitioner, - custom_getter=getter, use_resource=use_resource) + name, + shape=shape, + dtype=dtype, + initializer=initializer, + regularizer=regularizer, + collections=collections, + trainable=trainable, + caching_device=caching_device, + partitioner=partitioner, + custom_getter=getter, + use_resource=use_resource) def _build_variable_getter(rename=None): """Build a model variable getter that respects scope getter and renames.""" + # VariableScope will 
nest the getters def layer_variable_getter(getter, *args, **kwargs): kwargs['rename'] = rename return _model_variable_getter(getter, *args, **kwargs) + return layer_variable_getter def _add_variable_to_collections(variable, collections_set, collections_name): """Adds variable (or all its parts) to all collections with that name.""" - collections = utils.get_variable_collections( - collections_set, collections_name) or [] + collections = utils.get_variable_collections(collections_set, + collections_name) or [] variables_list = [variable] if isinstance(variable, tf_variables.PartitionedVariable): variables_list = [v for v in variable] @@ -1640,15 +1686,19 @@ def fully_connected(inputs, ValueError: If x has rank less than 2 or if its last dimension is not set. """ if not isinstance(num_outputs, six.integer_types): - raise ValueError( - 'num_outputs should be int or long, got %s.' % (num_outputs,)) + raise ValueError('num_outputs should be int or long, got %s.' % + (num_outputs,)) - layer_variable_getter = _build_variable_getter({'bias': 'biases', - 'kernel': 'weights'}) + layer_variable_getter = _build_variable_getter({ + 'bias': 'biases', + 'kernel': 'weights' + }) with variable_scope.variable_scope( - scope, 'fully_connected', [inputs], - reuse=reuse, custom_getter=layer_variable_getter) as sc: + scope, + 'fully_connected', [inputs], + reuse=reuse, + custom_getter=layer_variable_getter) as sc: inputs = ops.convert_to_tensor(inputs) layer = core_layers.Dense( units=num_outputs, @@ -1754,15 +1804,17 @@ class GDN(base.Layer): inverse=False, beta_min=1e-6, gamma_init=.1, - reparam_offset=2 ** -18, + reparam_offset=2**-18, data_format='channels_last', activity_regularizer=None, trainable=True, name=None, **kwargs): - super(GDN, self).__init__(trainable=trainable, name=name, - activity_regularizer=activity_regularizer, - **kwargs) + super(GDN, self).__init__( + trainable=trainable, + name=name, + activity_regularizer=activity_regularizer, + **kwargs) self.inverse = 
inverse self._beta_min = beta_min self._gamma_init = gamma_init @@ -1797,8 +1849,9 @@ class GDN(base.Layer): with ops.name_scope(name, 'GDNLowerBound', [inputs, bound]) as scope: inputs = ops.convert_to_tensor(inputs, name='inputs') bound = ops.convert_to_tensor(bound, name='bound') - with ops.get_default_graph().gradient_override_map( - {'Maximum': 'GDNLowerBound'}): + with ops.get_default_graph().gradient_override_map({ + 'Maximum': 'GDNLowerBound' + }): return math_ops.maximum(inputs, bound, name=scope) @staticmethod @@ -1825,12 +1878,14 @@ class GDN(base.Layer): raise ValueError('The channel dimension of the inputs to `GDN` ' 'must be defined.') self._input_rank = input_shape.ndims - self.input_spec = base.InputSpec(ndim=input_shape.ndims, - axes={channel_axis: num_channels}) + self.input_spec = base.InputSpec( + ndim=input_shape.ndims, axes={ + channel_axis: num_channels + }) - pedestal = array_ops.constant(self._reparam_offset ** 2, dtype=self.dtype) + pedestal = array_ops.constant(self._reparam_offset**2, dtype=self.dtype) beta_bound = array_ops.constant( - (self._beta_min + self._reparam_offset ** 2) ** .5, dtype=self.dtype) + (self._beta_min + self._reparam_offset**2)**.5, dtype=self.dtype) gamma_bound = array_ops.constant(self._reparam_offset, dtype=self.dtype) def beta_initializer(shape, dtype=None, partition_info=None): @@ -1844,19 +1899,21 @@ class GDN(base.Layer): eye = linalg_ops.eye(shape[0], dtype=dtype) return math_ops.sqrt(self._gamma_init * eye + pedestal) - beta = self.add_variable('reparam_beta', - shape=[num_channels], - initializer=beta_initializer, - dtype=self.dtype, - trainable=True) + beta = self.add_variable( + 'reparam_beta', + shape=[num_channels], + initializer=beta_initializer, + dtype=self.dtype, + trainable=True) beta = self._lower_bound(beta, beta_bound) self.beta = math_ops.square(beta) - pedestal - gamma = self.add_variable('reparam_gamma', - shape=[num_channels, num_channels], - initializer=gamma_initializer, - 
dtype=self.dtype, - trainable=True) + gamma = self.add_variable( + 'reparam_gamma', + shape=[num_channels, num_channels], + initializer=gamma_initializer, + dtype=self.dtype, + trainable=True) gamma = self._lower_bound(gamma, gamma_bound) self.gamma = math_ops.square(gamma) - pedestal @@ -1871,8 +1928,11 @@ class GDN(base.Layer): # Compute normalization pool. if self.data_format == 'channels_first': - norm_pool = nn.convolution(math_ops.square(inputs), gamma, 'VALID', - data_format='NC' + 'DHW'[-(ndim - 2):]) + norm_pool = nn.convolution( + math_ops.square(inputs), + gamma, + 'VALID', + data_format='NC' + 'DHW' [-(ndim - 2):]) if ndim == 3: norm_pool = array_ops.expand_dims(norm_pool, 2) norm_pool = nn.bias_add(norm_pool, self.beta, data_format='NCHW') @@ -1896,7 +1956,7 @@ class GDN(base.Layer): outputs.set_shape(inputs.get_shape()) return outputs - def _compute_output_shape(self, input_shape): + def compute_output_shape(self, input_shape): channel_axis = self._channel_axis() input_shape = tensor_shape.TensorShape(input_shape) if not 3 <= input_shape.ndim <= 5: @@ -1914,7 +1974,7 @@ def gdn(inputs, inverse=False, beta_min=1e-6, gamma_init=.1, - reparam_offset=2 ** -18, + reparam_offset=2**-18, data_format='channels_last', activity_regularizer=None, trainable=True, @@ -1980,17 +2040,18 @@ def gdn(inputs, Returns: Output tensor. 
""" - layer = GDN(inverse=inverse, - beta_min=beta_min, - gamma_init=gamma_init, - reparam_offset=reparam_offset, - data_format=data_format, - activity_regularizer=activity_regularizer, - trainable=trainable, - name=name, - dtype=inputs.dtype.base_dtype, - _scope=name, - _reuse=reuse) + layer = GDN( + inverse=inverse, + beta_min=beta_min, + gamma_init=gamma_init, + reparam_offset=reparam_offset, + data_format=data_format, + activity_regularizer=activity_regularizer, + trainable=trainable, + name=name, + dtype=inputs.dtype.base_dtype, + _scope=name, + _reuse=reuse) return layer.apply(inputs) @@ -2066,8 +2127,8 @@ def layer_norm(inputs, or if `inputs.shape[begin_params_axis:]` is not fully defined at graph build time. """ - with variable_scope.variable_scope(scope, 'LayerNorm', [inputs], - reuse=reuse) as sc: + with variable_scope.variable_scope( + scope, 'LayerNorm', [inputs], reuse=reuse) as sc: inputs = ops.convert_to_tensor(inputs) inputs_shape = inputs.shape inputs_rank = inputs_shape.ndims @@ -2077,15 +2138,14 @@ def layer_norm(inputs, if begin_norm_axis < 0: begin_norm_axis = inputs_rank + begin_norm_axis if begin_params_axis >= inputs_rank or begin_norm_axis >= inputs_rank: - raise ValueError( - 'begin_params_axis (%d) and begin_norm_axis (%d) ' - 'must be < rank(inputs) (%d)' - % (begin_params_axis, begin_norm_axis, inputs_rank)) + raise ValueError('begin_params_axis (%d) and begin_norm_axis (%d) ' + 'must be < rank(inputs) (%d)' % + (begin_params_axis, begin_norm_axis, inputs_rank)) params_shape = inputs_shape[begin_params_axis:] if not params_shape.is_fully_defined(): raise ValueError( - 'Inputs %s: shape(inputs)[%s:] is not fully defined: %s' % ( - inputs.name, begin_params_axis, inputs_shape)) + 'Inputs %s: shape(inputs)[%s:] is not fully defined: %s' % + (inputs.name, begin_params_axis, inputs_shape)) # Allocate parameters for the beta and gamma of the normalization. 
beta, gamma = None, None if center: @@ -2099,8 +2159,8 @@ def layer_norm(inputs, collections=beta_collections, trainable=trainable) if scale: - gamma_collections = utils.get_variable_collections(variables_collections, - 'gamma') + gamma_collections = utils.get_variable_collections( + variables_collections, 'gamma') gamma = variables.model_variable( 'gamma', shape=params_shape, @@ -2114,7 +2174,11 @@ def layer_norm(inputs, # Compute layer normalization using the batch_normalization function. variance_epsilon = 1e-12 outputs = nn.batch_normalization( - inputs, mean, variance, offset=beta, scale=gamma, + inputs, + mean, + variance, + offset=beta, + scale=gamma, variance_epsilon=variance_epsilon) outputs.set_shape(inputs_shape) if activation_fn is not None: @@ -2122,6 +2186,34 @@ def layer_norm(inputs, return utils.collect_named_outputs(outputs_collections, sc.name, outputs) +@add_arg_scope +def images_to_sequence(inputs, data_format=DATA_FORMAT_NHWC, + outputs_collections=None, scope=None): + """Convert a batch of images into a batch of sequences. + Args: + inputs: a (num_images, height, width, depth) tensor + data_format: A string. `NHWC` (default) and `NCHW` are supported. + outputs_collections: The collections to which the outputs are added. + scope: Optional scope for name_scope. 
+ Returns: + (width, num_images*height, depth) sequence tensor + """ + if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC): + raise ValueError('data_format has to be either NCHW or NHWC.') + with ops.name_scope(scope, 'ImagesToSequence', [inputs]) as sc: + inputs = ops.convert_to_tensor(inputs) + df = ('channels_first' + if data_format and data_format.startswith('NC') else 'channels_last') + if df == 'channels_first': + inputs = array_ops.transpose(inputs, [0, 2, 3, 1]) + _, _, width, depth = inputs.get_shape().as_list() + s = array_ops.shape(inputs) + batch_size, height = s[0], s[1] + transposed = array_ops.transpose(inputs, [2, 0, 1, 3]) + outputs = array_ops.reshape(transposed, [width, batch_size * height, depth]) + return utils.collect_named_outputs(outputs_collections, sc, outputs) + + @add_arg_scope def max_pool2d(inputs, kernel_size, @@ -2160,13 +2252,14 @@ def max_pool2d(inputs, raise ValueError('data_format has to be either NCHW or NHWC.') with ops.name_scope(scope, 'MaxPool2D', [inputs]) as sc: inputs = ops.convert_to_tensor(inputs) - df = ('channels_first' if data_format and data_format.startswith('NC') - else 'channels_last') - layer = pooling_layers.MaxPooling2D(pool_size=kernel_size, - strides=stride, - padding=padding, - data_format=df, - _scope=sc) + df = ('channels_first' + if data_format and data_format.startswith('NC') else 'channels_last') + layer = pooling_layers.MaxPooling2D( + pool_size=kernel_size, + strides=stride, + padding=padding, + data_format=df, + _scope=sc) outputs = layer.apply(inputs) return utils.collect_named_outputs(outputs_collections, sc, outputs) @@ -2209,13 +2302,14 @@ def max_pool3d(inputs, raise ValueError('data_format has to be either NCDHW or NDHWC.') with ops.name_scope(scope, 'MaxPool3D', [inputs]) as sc: inputs = ops.convert_to_tensor(inputs) - df = ('channels_first' if data_format and data_format.startswith('NC') - else 'channels_last') - layer = pooling_layers.MaxPooling3D(pool_size=kernel_size, - 
strides=stride, - padding=padding, - data_format=df, - _scope=sc) + df = ('channels_first' + if data_format and data_format.startswith('NC') else 'channels_last') + layer = pooling_layers.MaxPooling3D( + pool_size=kernel_size, + strides=stride, + padding=padding, + data_format=df, + _scope=sc) outputs = layer.apply(inputs) return utils.collect_named_outputs(outputs_collections, sc, outputs) @@ -2268,8 +2362,8 @@ def pool(inputs, """ # pylint: enable=line-too-long - with ops.name_scope(scope, '%s_pool' % - (pooling_type.lower()), [inputs]) as sc: + with ops.name_scope(scope, '%s_pool' % (pooling_type.lower()), + [inputs]) as sc: inputs = ops.convert_to_tensor(inputs) input_rank = inputs.get_shape().ndims if input_rank is None: @@ -2314,18 +2408,16 @@ def one_hot_encoding(labels, labels = ops.convert_to_tensor(labels) if labels.dtype == dtypes.int32: labels = standard_ops.to_int64(labels) - outputs = standard_ops.one_hot(labels, - num_classes, - on_value=on_value, - off_value=off_value) + outputs = standard_ops.one_hot( + labels, num_classes, on_value=on_value, off_value=off_value) return utils.collect_named_outputs(outputs_collections, sc, outputs) def _apply_activation(y, activation_fn, output_collections): if activation_fn is not None: y = activation_fn(y) - ops.add_to_collections(list(output_collections or []) + - [ops.GraphKeys.ACTIVATIONS], y) + ops.add_to_collections( + list(output_collections or []) + [ops.GraphKeys.ACTIVATIONS], y) return y @@ -2370,7 +2462,7 @@ def repeat(inputs, repetitions, layer, *args, **kwargs): scope = 'repeat' outputs = inputs for i in range(repetitions): - kwargs['scope'] = scope + '_' + str(i+1) + kwargs['scope'] = scope + '_' + str(i + 1) outputs = layer(outputs, *args, **kwargs) return outputs @@ -2385,8 +2477,8 @@ def _scale_gradient_grad(op, grad): return [grad * op.inputs[1], None] -@function.Defun(python_grad_func=_scale_gradient_grad, - shape_func=_scale_gradient_shape) +@function.Defun( + 
python_grad_func=_scale_gradient_grad, shape_func=_scale_gradient_shape) def scale_gradient(inputs, gradient_multiplier): """Identity operation, but with the gradient multiplied by a tensor. @@ -2491,18 +2583,21 @@ def separable_convolution2d( """ if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC): raise ValueError('data_format has to be either NCHW or NHWC.') - layer_variable_getter = _build_variable_getter( - {'bias': 'biases', - 'depthwise_kernel': 'depthwise_weights', - 'pointwise_kernel': 'pointwise_weights'}) + layer_variable_getter = _build_variable_getter({ + 'bias': 'biases', + 'depthwise_kernel': 'depthwise_weights', + 'pointwise_kernel': 'pointwise_weights' + }) with variable_scope.variable_scope( - scope, 'SeparableConv2d', [inputs], reuse=reuse, + scope, + 'SeparableConv2d', [inputs], + reuse=reuse, custom_getter=layer_variable_getter) as sc: inputs = ops.convert_to_tensor(inputs) - df = ('channels_first' if data_format and data_format.startswith('NC') - else 'channels_last') + df = ('channels_first' + if data_format and data_format.startswith('NC') else 'channels_last') if num_outputs is not None: # Apply separable conv using the SeparableConvolution2D layer. 
layer = convolutional_layers.SeparableConvolution2D( @@ -2535,8 +2630,8 @@ def separable_convolution2d( _add_variable_to_collections(layer.pointwise_kernel, variables_collections, 'weights') if layer.bias is not None: - _add_variable_to_collections(layer.bias, - variables_collections, 'biases') + _add_variable_to_collections(layer.bias, variables_collections, + 'biases') if normalizer_fn is not None: normalizer_params = normalizer_params or {} @@ -2551,8 +2646,7 @@ def separable_convolution2d( weights_collections = utils.get_variable_collections( variables_collections, 'weights') - depthwise_shape = [kernel_h, kernel_w, - num_filters_in, depth_multiplier] + depthwise_shape = [kernel_h, kernel_w, num_filters_in, depth_multiplier] depthwise_weights = variables.model_variable( 'depthwise_weights', shape=depthwise_shape, @@ -2561,11 +2655,18 @@ def separable_convolution2d( regularizer=weights_regularizer, trainable=trainable, collections=weights_collections) - strides = [1, 1, stride_h, stride_w] if data_format.startswith('NC') else [1, stride_h, stride_w, 1] + strides = [1, 1, stride_h, + stride_w] if data_format.startswith('NC') else [ + 1, stride_h, stride_w, 1 + ] - outputs = nn.depthwise_conv2d(inputs, depthwise_weights, strides, padding, - rate=utils.two_element_tuple(rate), - data_format=data_format) + outputs = nn.depthwise_conv2d( + inputs, + depthwise_weights, + strides, + padding, + rate=utils.two_element_tuple(rate), + data_format=data_format) num_outputs = depth_multiplier * num_filters_in if normalizer_fn is not None: @@ -2575,13 +2676,16 @@ def separable_convolution2d( if biases_initializer is not None: biases_collections = utils.get_variable_collections( variables_collections, 'biases') - biases = variables.model_variable('biases', - shape=[num_outputs,], - dtype=dtype, - initializer=biases_initializer, - regularizer=biases_regularizer, - trainable=trainable, - collections=biases_collections) + biases = variables.model_variable( + 'biases', + shape=[ + 
num_outputs, + ], + dtype=dtype, + initializer=biases_initializer, + regularizer=biases_regularizer, + trainable=trainable, + collections=biases_collections) outputs = nn.bias_add(outputs, biases, data_format=data_format) if activation_fn is not None: @@ -2589,6 +2693,36 @@ def separable_convolution2d( return utils.collect_named_outputs(outputs_collections, sc.name, outputs) +@add_arg_scope +def sequence_to_images(inputs, height, output_data_format='channels_last', + outputs_collections=None, scope=None): + """Convert a batch of sequences into a batch of images. + Args: + inputs: (num_steps, num_batches, depth) sequence tensor + height: the height of the images + output_data_format: Format of output tensor. + Currently supports `'channels_first'` and `'channels_last'`. + outputs_collections: The collections to which the outputs are added. + scope: Optional scope for name_scope. + Returns: + A tensor representing the output of the operation. + """ + with ops.name_scope(scope, 'SequenceToImages', [inputs]) as sc: + inputs = ops.convert_to_tensor(inputs) + width, num_batches, depth = inputs.get_shape().as_list() + if num_batches is None: + num_batches = -1 + else: + num_batches = num_batches // height + reshaped = array_ops.reshape(inputs, + [width, num_batches, height, depth]) + if output_data_format == 'channels_first': + outputs = array_ops.transpose(reshaped, [1, 3, 2, 0]) + else: + outputs = array_ops.transpose(reshaped, [1, 2, 0, 3]) + return utils.collect_named_outputs(outputs_collections, sc, outputs) + + @add_arg_scope def softmax(logits, scope=None): """Performs softmax on Nth dimension of N-dimensional logit tensor. @@ -2651,7 +2785,7 @@ def spatial_softmax(features, ValueError: If unexpected data_format specified. ValueError: If num_channels dimension is unspecified. 
""" - with variable_scope.variable_scope(name, 'spatial_softmax'): + with variable_scope.variable_scope(name, 'spatial_softmax'): shape = array_ops.shape(features) static_shape = features.shape if data_format == DATA_FORMAT_NHWC: @@ -2663,44 +2797,52 @@ def spatial_softmax(features, if num_channels.value is None: raise ValueError('The num_channels dimension of the inputs to ' '`spatial_softmax` should be defined. Found `None`.') - - with ops.name_scope('spatial_softmax_op', 'spatial_softmax_op', [features]): + + with ops.name_scope('spatial_softmax_op', 'spatial_softmax_op', [features]): # Create tensors for x and y coordinate values, scaled to range [-1, 1]. - pos_x, pos_y = array_ops.meshgrid(math_ops.lin_space(-1., 1., num=height), - math_ops.lin_space(-1., 1., num=width), - indexing='ij') + pos_x, pos_y = array_ops.meshgrid( + math_ops.lin_space(-1., 1., num=height), + math_ops.lin_space(-1., 1., num=width), + indexing='ij') pos_x = array_ops.reshape(pos_x, [height * width]) pos_y = array_ops.reshape(pos_y, [height * width]) + if temperature is None: - temperature_collections = utils.get_variable_collections( + temp_initializer = init_ops.ones_initializer() + else: + temp_initializer = init_ops.constant_initializer(temperature) + + if not trainable: + temp_collections = None + else: + temp_collections = utils.get_variable_collections( variables_collections, 'temperature') - temperature = variables.model_variable( - 'temperature', - shape=(), - dtype=dtypes.float32, - initializer=init_ops.ones_initializer(), - collections=temperature_collections, - trainable=trainable) + + temperature = variables.model_variable( + 'temperature', + shape=(), + dtype=dtypes.float32, + initializer=temp_initializer, + collections=temp_collections, + trainable=trainable) if data_format == 'NCHW': features = array_ops.reshape(features, [-1, height * width]) else: features = array_ops.reshape( array_ops.transpose(features, [0, 3, 1, 2]), [-1, height * width]) - - softmax_attention = 
nn.softmax(features/temperature) + + softmax_attention = nn.softmax(features / temperature) expected_x = math_ops.reduce_sum( - pos_x * softmax_attention, [1], keep_dims=True) + pos_x * softmax_attention, [1], keepdims=True) expected_y = math_ops.reduce_sum( - pos_y * softmax_attention, [1], keep_dims=True) + pos_y * softmax_attention, [1], keepdims=True) expected_xy = array_ops.concat([expected_x, expected_y], 1) - feature_keypoints = array_ops.reshape( - expected_xy, [-1, num_channels.value * 2]) + feature_keypoints = array_ops.reshape(expected_xy, + [-1, num_channels.value * 2]) feature_keypoints.set_shape([None, num_channels.value * 2]) return feature_keypoints - - def stack(inputs, layer, stack_args, **kwargs): """Builds a stack of layers by applying layer repeatedly using stack_args. @@ -2748,7 +2890,7 @@ def stack(inputs, layer, stack_args, **kwargs): scope = 'stack' outputs = inputs for i in range(len(stack_args)): - kwargs['scope'] = scope + '_' + str(i+1) + kwargs['scope'] = scope + '_' + str(i + 1) layer_args = stack_args[i] if not isinstance(layer_args, (list, tuple)): layer_args = [layer_args] @@ -2779,11 +2921,10 @@ def unit_norm(inputs, dim, epsilon=1e-7, scope=None): raise ValueError('The input rank must be known.') input_rank = len(inputs.get_shape().as_list()) if dim < 0 or dim >= input_rank: - raise ValueError( - 'dim must be positive but smaller than the input rank.') + raise ValueError('dim must be positive but smaller than the input rank.') - lengths = math_ops.sqrt(epsilon + math_ops.reduce_sum( - math_ops.square(inputs), dim, True)) + lengths = math_ops.sqrt( + epsilon + math_ops.reduce_sum(math_ops.square(inputs), dim, True)) multiples = [] if dim > 0: multiples.append(array_ops.ones([dim], dtypes.int32)) @@ -2827,7 +2968,7 @@ def poincare_normalize(x, axis=1, epsilon=1e-5, name=None): """ with ops.name_scope(name, 'poincare_normalize', [x]) as name: x = ops.convert_to_tensor(x, name='x') - square_sum = 
math_ops.reduce_sum(math_ops.square(x), axis, keep_dims=True) + square_sum = math_ops.reduce_sum(math_ops.square(x), axis, keepdims=True) x_inv_norm = math_ops.rsqrt(square_sum) x_inv_norm = math_ops.minimum((1. - epsilon) * x_inv_norm, 1.) return math_ops.multiply(x, x_inv_norm, name=name) @@ -2924,29 +3065,31 @@ def legacy_fully_connected(x, raise ValueError('last dimension of x must be known but is None') dtype = x.dtype.base_dtype - weight_collections = set(list(weight_collections or []) + - [ops.GraphKeys.GLOBAL_VARIABLES]) - w = variable_scope.get_variable('weights', - shape=[num_input_units, num_output_units], - dtype=dtype, - initializer=weight_init, - collections=weight_collections, - regularizer=weight_regularizer, - trainable=trainable) - x_2_dim = x if len(dims) <= 2 else array_ops.reshape(x, - [-1, num_input_units]) + weight_collections = set( + list(weight_collections or []) + [ops.GraphKeys.GLOBAL_VARIABLES]) + w = variable_scope.get_variable( + 'weights', + shape=[num_input_units, num_output_units], + dtype=dtype, + initializer=weight_init, + collections=weight_collections, + regularizer=weight_regularizer, + trainable=trainable) + x_2_dim = x if len(dims) <= 2 else array_ops.reshape( + x, [-1, num_input_units]) y = standard_ops.matmul(x_2_dim, w) if bias_init is not None: - bias_collections = set(list(bias_collections or []) + - [ops.GraphKeys.GLOBAL_VARIABLES]) - b = variable_scope.get_variable('bias', - shape=[num_output_units], - dtype=dtype, - initializer=bias_init, - collections=bias_collections, - regularizer=bias_regularizer, - trainable=trainable) + bias_collections = set( + list(bias_collections or []) + [ops.GraphKeys.GLOBAL_VARIABLES]) + b = variable_scope.get_variable( + 'bias', + shape=[num_output_units], + dtype=dtype, + initializer=bias_init, + collections=bias_collections, + regularizer=bias_regularizer, + trainable=trainable) y = nn.bias_add(y, b) diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py 
b/tensorflow/contrib/layers/python/layers/layers_test.py index a05e464a26d8167707ce6d6455aca50b0416aa1f..0f062adbab3ca9acfb89543b69c7c957bbdf5dd8 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -44,6 +44,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import random_ops +from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import template from tensorflow.python.ops import variable_scope @@ -126,8 +127,8 @@ class AvgPool3DTest(test.TestCase): def testInvalidDataFormat(self): depth, height, width = 3, 6, 9 images = np.random.uniform(size=(5, depth, height, width, 3)) - with self.assertRaisesRegexp(ValueError, - 'data_format has to be either NCDHW or NDHWC.'): + with self.assertRaisesRegexp( + ValueError, 'data_format has to be either NCDHW or NDHWC.'): _layers.avg_pool3d(images, [3, 3, 3], data_format='CDHWN') def testCreateAvgPool(self): @@ -147,7 +148,8 @@ class AvgPool3DTest(test.TestCase): def testCollectOutputs(self): depth, height, width = 3, 6, 9 images = random_ops.random_uniform((5, depth, height, width, 3), seed=1) - output = _layers.avg_pool3d(images, [3, 3, 3], outputs_collections='outputs') + output = _layers.avg_pool3d( + images, [3, 3, 3], outputs_collections='outputs') output_collected = ops.get_collection('outputs')[0] self.assertEqual(output_collected.aliases, ['AvgPool3D']) self.assertEqual(output_collected, output) @@ -182,7 +184,8 @@ class AvgPool3DTest(test.TestCase): depth, height, width = 3, 6, 9 images = random_ops.random_uniform((5, depth, height, width, 3), seed=1) output = _layers.avg_pool3d(images, [3, 3, 3], stride=1, padding='SAME') - self.assertListEqual(output.get_shape().as_list(), [5, depth, height, width, 3]) + self.assertListEqual(output.get_shape().as_list(), + [5, 
depth, height, width, 3]) def testGlobalAvgPool(self): depth, height, width = 3, 6, 9 @@ -514,7 +517,9 @@ class ConvolutionTest(test.TestCase): with arg_scope( [layers_lib.convolution2d], normalizer_fn=_layers.batch_norm, - normalizer_params={'decay': 0.9}): + normalizer_params={ + 'decay': 0.9 + }): net = layers_lib.convolution2d(images, 32, [3, 3]) net = layers_lib.convolution2d(net, 32, [3, 3]) self.assertEqual(len(variables.get_variables()), 8) @@ -528,7 +533,9 @@ class ConvolutionTest(test.TestCase): with arg_scope( [layers_lib.convolution2d], normalizer_fn=_layers.batch_norm, - normalizer_params={'decay': 0.9}): + normalizer_params={ + 'decay': 0.9 + }): net = layers_lib.convolution2d(images, 32, [3, 3], scope='Conv') net = layers_lib.convolution2d( net, 32, [3, 3], scope='Conv', reuse=True) @@ -701,7 +708,7 @@ class Convolution2dTransposeTests(test.TestCase): _layers.convolution2d_transpose(images, 32, 3, data_format='CHWN') def testOutputSizeWithStrideOneSamePaddingNCHW(self): - # `NCHW` data fomat is only supported for `GPU` device. + # `NCHW` data format is only supported for `GPU` device. 
if test.is_gpu_available(cuda_only=True): with self.test_session(use_gpu=True) as sess: num_filters = 32 @@ -1030,7 +1037,8 @@ class Convolution2dTransposeTests(test.TestCase): for _ in range(10): num_filters = 1 input_size = [ - 1, np.random.randint(1, max_image_size), + 1, + np.random.randint(1, max_image_size), np.random.randint(1, max_image_size), 1 ] filter_size = [ @@ -1184,8 +1192,10 @@ class ConvolutionInPlaneTest(test.TestCase): with self.test_session() as sess: sess.run(init_op) - result = sess.run(horz_gradients, - feed_dict={image: np.ones((1, 10, 10, 1))}) + result = sess.run( + horz_gradients, feed_dict={ + image: np.ones((1, 10, 10, 1)) + }) expected = np.zeros((1, 10, 9, 1)) self.assertAllEqual(result, expected) @@ -1292,6 +1302,19 @@ class ConvolutionInPlaneTest(test.TestCase): self.assertAllClose(result, expected, rtol=1e-5, atol=1e-5) +class DenseToSparseTest(test.TestCase): + + def testDenseFromConstantToSparse(self): + expected_constant = np.reshape(np.arange(24, dtype=np.int64), (3, 4, 2)) + tensor = constant_op.constant(expected_constant) + sparse = _layers.dense_to_sparse(tensor) + dense = sparse_ops.sparse_to_dense(sparse.indices, sparse.dense_shape, + sparse.values) + with self.test_session() as sess: + constant = sess.run(dense) + self.assertAllEqual(expected_constant, constant) + + class DropoutTest(test.TestCase): def testCreateDropout(self): @@ -1406,8 +1429,7 @@ class FlattenTest(test.TestCase): with ops.Graph().as_default() as g, self.test_session(g): inputs = array_ops.placeholder(dtype=dtypes.float32) inputs.set_shape(tensor_shape.TensorShape((5,))) - with self.assertRaisesRegexp(ValueError, - 'incompatible with the layer'): + with self.assertRaisesRegexp(ValueError, 'incompatible with the layer'): _layers.flatten(inputs) def testUnknownLastDim(self): @@ -1717,7 +1739,9 @@ class FCTest(test.TestCase): with arg_scope( [_layers.fully_connected], normalizer_fn=_layers.batch_norm, - normalizer_params={'decay': 0.9}): + 
normalizer_params={ + 'decay': 0.9 + }): net = _layers.fully_connected(images, 27) net = _layers.fully_connected(net, 27) self.assertEqual(len(variables.get_variables()), 8) @@ -1733,7 +1757,9 @@ class FCTest(test.TestCase): with arg_scope( [_layers.fully_connected], normalizer_fn=_layers.batch_norm, - normalizer_params={'decay': 0.9}): + normalizer_params={ + 'decay': 0.9 + }): net = _layers.fully_connected(images, 27, scope='fc1') net = _layers.fully_connected(net, 27, scope='fc1', reuse=True) self.assertEqual(len(variables.get_variables()), 4) @@ -1747,6 +1773,12 @@ class BatchNormTest(test.TestCase): expected_var *= correction_factor return expected_var, correction_factor + def testBatchNormCenterFalse(self): + a = array_ops.placeholder(dtype=dtypes.float32, shape=(10, 10, 10, 10)) + # Test that center=False builds a valid graph. + _layers.batch_norm( + a, center=False, data_format='NCHW', zero_debias_moving_mean=True) + def testUnknownShape(self): with ops.Graph().as_default() as g, self.test_session(g): inputs = array_ops.placeholder(dtype=dtypes.float32) @@ -1782,8 +1814,8 @@ class BatchNormTest(test.TestCase): images = np.random.uniform(size=(5, height, width, 3)).astype( dtype.as_numpy_dtype) output = _layers.batch_norm(images, fused=fused) - expected_name = ('BatchNorm/FusedBatchNorm' if fused else - 'BatchNorm/batchnorm') + expected_name = ('BatchNorm/FusedBatchNorm' + if fused else 'BatchNorm/batchnorm') self.assertTrue(output.op.name.startswith(expected_name)) self.assertListEqual(output.get_shape().as_list(), [5, height, width, 3]) self.assertEqual( @@ -2002,8 +2034,8 @@ class BatchNormTest(test.TestCase): expected_var = np.var(image_values, axis=axis) if fused: # Add Bessel's correction - expected_var, _ = self._addBesselsCorrection(batch_size * height * - width, expected_var) + expected_var, _ = self._addBesselsCorrection( + batch_size * height * width, expected_var) images = constant_op.constant( image_values, shape=image_shape, 
dtype=dtypes.float32) output = _layers.batch_norm( @@ -2164,7 +2196,7 @@ class BatchNormTest(test.TestCase): # After initialization moving_mean == 0 and moving_variance == 1. self.assertAllClose(mean, [0] * 3) self.assertAllClose(variance, [1] * 3) - # Simulate assigment from saver restore. + # Simulate assignment from saver restore. init_assigns = [ state_ops.assign(moving_mean, expected_mean), state_ops.assign(moving_variance, expected_var) @@ -2522,8 +2554,8 @@ class BatchNormTest(test.TestCase): expected_var = np.var(image_values, axis=axis) if fused: # Add Bessel's correction - expected_var, _ = self._addBesselsCorrection(batch_size * height * - width, expected_var) + expected_var, _ = self._addBesselsCorrection( + batch_size * height * width, expected_var) images = constant_op.constant( image_values, shape=image_shape, dtype=dtypes.float32) output = _layers.batch_norm( @@ -2553,8 +2585,9 @@ class BatchNormTest(test.TestCase): np_output, new_images_gradients = sess.run([output, images_gradients]) # The outputs should be close to 0.0 mean and 1.0 variance self.assertAllClose( - np.mean( - np_output, axis=axis), [0] * channels, rtol=0.001, atol=0.001) + np.mean(np_output, axis=axis), [0] * channels, + rtol=0.001, + atol=0.001) self.assertAllClose( np.var(np_output, axis=axis), [1] * channels, rtol=0.01, atol=0.01) # The gradients should change slowly while updating moving_mean. 
@@ -2582,14 +2615,14 @@ class BatchNormTest(test.TestCase): channels = 3 with self.test_session() as sess: images = (np.ones((5, height, width, channels)) * 9.0).astype('f') - beta = init_ops.constant_initializer((np.ones(channels) * 5.0).astype( - 'f')) - gamma = init_ops.constant_initializer((np.ones(channels) * 2.0).astype( - 'f')) - mean = init_ops.constant_initializer((np.ones(channels) * 5.0).astype( - 'f')) - variance = init_ops.constant_initializer((np.ones(channels) * 4.0).astype( - 'f')) + beta = init_ops.constant_initializer( + (np.ones(channels) * 5.0).astype('f')) + gamma = init_ops.constant_initializer( + (np.ones(channels) * 2.0).astype('f')) + mean = init_ops.constant_initializer( + (np.ones(channels) * 5.0).astype('f')) + variance = init_ops.constant_initializer( + (np.ones(channels) * 4.0).astype('f')) output = _layers.batch_norm( images, is_training=False, @@ -2610,21 +2643,18 @@ class BatchNormTest(test.TestCase): with self.test_session(use_gpu=True) as sess: images = np.arange(np.product(shape), dtype=np.float32).reshape(shape) beta = init_ops.constant_initializer( - np.arange( - 2, channels + 2, dtype=np.float32)) + np.arange(2, channels + 2, dtype=np.float32)) gamma = init_ops.constant_initializer( - np.arange( - 10, channels + 10, dtype=np.float32) * 2.0) + np.arange(10, channels + 10, dtype=np.float32) * 2.0) mean = init_ops.constant_initializer( - np.arange( - 3, channels + 3, dtype=np.float32) * 5.0) + np.arange(3, channels + 3, dtype=np.float32) * 5.0) variance = init_ops.constant_initializer( - np.arange( - 1, channels + 1, dtype=np.float32) * 4.0) + np.arange(1, channels + 1, dtype=np.float32) * 4.0) if data_format == 'NCHW': # Reshape inputs from NHWC to NCHW format. 
images = array_ops.transpose( - images, [0, len(shape) - 1] + list(range(1, len(shape) - 1))) + images, [0, len(shape) - 1] + list(range(1, + len(shape) - 1))) output = _layers.batch_norm( images, is_training=is_training, @@ -2727,16 +2757,16 @@ class BatchNormTest(test.TestCase): # Tests that the adjustment is appropriately passed to and used by the core # BN layer. all_adjustments = [] + def _create_adjustment(shape): adjustments = [array_ops.ones(shape[-1:]), array_ops.zeros(shape[-1:])] all_adjustments.extend(adjustments) return adjustments + depth = 8 images = array_ops.zeros([10, 5, 5, depth]) output = _layers.batch_norm( - images, - is_training=True, - adjustment=_create_adjustment) + images, is_training=True, adjustment=_create_adjustment) self.assertListEqual(output.shape.as_list(), images.shape.as_list()) self.assertEqual(len(all_adjustments), 2) self.assertListEqual(all_adjustments[0].shape.as_list(), [depth]) @@ -2801,7 +2831,10 @@ class LayerNormTest(test.TestCase): # output_train and output_eval should be the same. 
self.assertAllClose(sess.run([output_train]), sess.run([output_eval])) - def doOutputTest(self, input_shape, tol=1e-5, begin_norm_axis=1, + def doOutputTest(self, + input_shape, + tol=1e-5, + begin_norm_axis=1, dtype=dtypes.float64): expected_mean = np.zeros(input_shape[:begin_norm_axis]) expected_var = np.ones(input_shape[:begin_norm_axis]) @@ -2832,13 +2865,10 @@ class LayerNormTest(test.TestCase): # Layer-norm implemented in numpy eps = 1e-12 expected_out = ( - (gamma * ( - input_values - - np.mean(input_values, axis=moments_axis, keepdims=True)) - / np.sqrt( - eps - + np.var(input_values, axis=moments_axis, keepdims=True))) - + beta) + (gamma * (input_values - np.mean( + input_values, axis=moments_axis, keepdims=True)) / + np.sqrt(eps + np.var( + input_values, axis=moments_axis, keepdims=True))) + beta) self.assertAllClose(expected_mean, mean, atol=tol, rtol=tol) self.assertAllClose(expected_var, var, atol=tol) # The full computation gets a bigger tolerance @@ -2856,10 +2886,10 @@ class LayerNormTest(test.TestCase): def testOutput4DInputNormOnInnermostAxis(self): # Equivalent tests - self.doOutputTest((100, 10, 10, 3), begin_norm_axis=3, tol=1e-4, - dtype=dtypes.float64) - self.doOutputTest((100, 10, 10, 3), begin_norm_axis=-1, tol=1e-4, - dtype=dtypes.float64) + self.doOutputTest( + (100, 10, 10, 3), begin_norm_axis=3, tol=1e-4, dtype=dtypes.float64) + self.doOutputTest( + (100, 10, 10, 3), begin_norm_axis=-1, tol=1e-4, dtype=dtypes.float64) def testOutputSmallInput(self): self.doOutputTest((10, 10, 10, 30)) @@ -2896,7 +2926,7 @@ class GDNTest(test.TestCase): x = np.random.uniform(size=(1, 2, 3, 4)[:ndim]) y = self._runGDN(x, x.shape, False, 'channels_last') self.assertEqual(x.shape, y.shape) - self.assertAllClose(y, x / np.sqrt(1 + .1 * (x ** 2)), rtol=0, atol=1e-6) + self.assertAllClose(y, x / np.sqrt(1 + .1 * (x**2)), rtol=0, atol=1e-6) def testChannelsFirst(self): # `bias_add` doesn't support NCHW on CPU. 
@@ -2905,8 +2935,7 @@ class GDNTest(test.TestCase): x = np.random.uniform(size=(4, 3, 2, 1)[:ndim]) y = self._runGDN(x, x.shape, False, 'channels_first') self.assertEqual(x.shape, y.shape) - self.assertAllClose( - y, x / np.sqrt(1 + .1 * (x ** 2)), rtol=0, atol=1e-6) + self.assertAllClose(y, x / np.sqrt(1 + .1 * (x**2)), rtol=0, atol=1e-6) def testWrongDims(self): for ndim in [1, 2, 6]: @@ -2918,7 +2947,29 @@ class GDNTest(test.TestCase): x = np.random.uniform(size=(1, 2, 3, 4)) y = self._runGDN(x, x.shape, True, 'channels_last') self.assertEqual(x.shape, y.shape) - self.assertAllClose(y, x * np.sqrt(1 + .1 * (x ** 2)), rtol=0, atol=1e-6) + self.assertAllClose(y, x * np.sqrt(1 + .1 * (x**2)), rtol=0, atol=1e-6) + + +class ImagesToSequenceTest(test.TestCase): + + def testInvalidDataFormat(self): + height, width = 7, 11 + images = np.random.uniform(size=(5, height, width, 2)) + with self.assertRaisesRegexp(ValueError, + 'data_format has to be either NCHW or NHWC.'): + _layers.images_to_sequence(images, data_format='CHWN') + + def testImagesToSequenceDims(self): + height, width = 7, 11 + images = np.random.uniform(size=(2, height, width, 5)).astype(np.float32) + output = _layers.images_to_sequence(images) + self.assertListEqual(output.get_shape().as_list(), [11, 14, 5]) + + def testImagesToSequenceNCHW(self): + height, width = 7, 11 + images = np.random.uniform(size=(2, 5, height, width)).astype(np.float32) + output = _layers.images_to_sequence(images, data_format='NCHW') + self.assertListEqual(output.get_shape().as_list(), [11, 14, 5]) class MaxPool2DTest(test.TestCase): @@ -2995,20 +3046,22 @@ class MaxPool3DTest(test.TestCase): def testInvalidDataFormat(self): depth, height, width = 3, 6, 9 images = np.random.uniform(size=(5, depth, height, width, 3)) - with self.assertRaisesRegexp(ValueError, - 'data_format has to be either NCDHW or NDHWC.'): + with self.assertRaisesRegexp( + ValueError, 'data_format has to be either NCDHW or NDHWC.'): _layers.max_pool3d(images, 
[3, 3, 3], data_format='CDHWN') def testCreateMaxPool(self): depth, height, width = 3, 6, 9 - images = np.random.uniform(size=(5, depth, height, width, 3)).astype(np.float32) + images = np.random.uniform(size=(5, depth, height, width, 3)).astype( + np.float32) output = _layers.max_pool3d(images, [3, 3, 3]) self.assertEqual(output.op.name, 'MaxPool3D/MaxPool3D') self.assertListEqual(output.get_shape().as_list(), [5, 1, 2, 4, 3]) def testCreateMaxPoolNCDHW(self): depth, height, width = 3, 6, 9 - images = np.random.uniform(size=(5, 3, depth, height, width)).astype(np.float32) + images = np.random.uniform(size=(5, 3, depth, height, width)).astype( + np.float32) output = _layers.max_pool3d(images, [3, 3, 3], data_format='NCDHW') self.assertEquals(output.op.name, 'MaxPool3D/transpose_1') self.assertListEqual(output.get_shape().as_list(), [5, 3, 1, 2, 4]) @@ -3016,7 +3069,8 @@ class MaxPool3DTest(test.TestCase): def testCollectOutputs(self): depth, height, width = 3, 6, 9 images = random_ops.random_uniform((5, depth, height, width, 3), seed=1) - output = _layers.max_pool3d(images, [3, 3, 3], outputs_collections='outputs') + output = _layers.max_pool3d( + images, [3, 3, 3], outputs_collections='outputs') output_collected = ops.get_collection('outputs')[0] self.assertEqual(output_collected.aliases, ['MaxPool3D']) self.assertEqual(output_collected, output) @@ -3051,7 +3105,8 @@ class MaxPool3DTest(test.TestCase): depth, height, width = 3, 6, 9 images = random_ops.random_uniform((5, depth, height, width, 3), seed=1) output = _layers.max_pool3d(images, [3, 3, 3], stride=1, padding='SAME') - self.assertListEqual(output.get_shape().as_list(), [5, depth, height, width, 3]) + self.assertListEqual(output.get_shape().as_list(), + [5, depth, height, width, 3]) def testGlobalMaxPool(self): depth, height, width = 3, 6, 9 @@ -3231,7 +3286,11 @@ class SeparableConv2dTest(test.TestCase): images = random_ops.random_uniform((5, height, width, 3), seed=1) regularizer = 
regularizers.l2_regularizer(0.01) layers_lib.separable_conv2d( - images, 32, [3, 3], 2, weights_regularizer=regularizer) + images, + 32, [3, 3], + 2, + weights_regularizer=regularizer, + weights_initializer=init_ops.ones_initializer()) self.assertEqual( len(ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)), 2) weight_decay = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)[0] @@ -3239,12 +3298,31 @@ class SeparableConv2dTest(test.TestCase): weight_decay.op.name, 'SeparableConv2d/depthwise_kernel/Regularizer/l2_regularizer') sess.run(variables_lib.global_variables_initializer()) - self.assertLessEqual(sess.run(weight_decay), 0.05) + depth_weight_one = sess.run(weight_decay) weight_decay = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)[1] self.assertEqual( weight_decay.op.name, 'SeparableConv2d/pointwise_kernel/Regularizer/l2_regularizer') - self.assertLessEqual(sess.run(weight_decay), 0.05) + pointwise_weight_one = sess.run(weight_decay) + + regularizer = regularizers.l2_regularizer(1.0) + layers_lib.separable_conv2d( + images, + 32, [3, 3], + 2, + weights_regularizer=regularizer, + weights_initializer=init_ops.ones_initializer()) + self.assertEqual( + len(ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)), 4) + weight_decay = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)[2] + sess.run(variables_lib.global_variables_initializer()) + depth_weight_two = sess.run(weight_decay) + weight_decay = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)[3] + pointwise_weight_two = sess.run(weight_decay) + + self.assertAllClose( + [100.0 * depth_weight_one, 100.0 * pointwise_weight_one], + [depth_weight_two, pointwise_weight_two]) def testReuseConvWithWeightDecay(self): height, width = 3, 3 @@ -3332,11 +3410,18 @@ class SeparableConv2dTest(test.TestCase): batch, height, width = 4, 10, 12 kernel_dim, stride = 3, 2 images = random_ops.random_uniform((batch, 3, height, width), seed=1) - output = layers_lib.separable_conv2d(images, 
num_outputs=num_filters, kernel_size=[kernel_dim, kernel_dim], - depth_multiplier=2, stride=stride, padding='VALID', data_format='NCHW') - self.assertListEqual( - output.get_shape().as_list(), [batch, correct_output_filters, - (height - kernel_dim + 1) // stride, (width - kernel_dim + 1) // stride]) + output = layers_lib.separable_conv2d( + images, + num_outputs=num_filters, + kernel_size=[kernel_dim, kernel_dim], + depth_multiplier=2, + stride=stride, + padding='VALID', + data_format='NCHW') + self.assertListEqual(output.get_shape().as_list(), [ + batch, correct_output_filters, (height - kernel_dim + 1) // stride, + (width - kernel_dim + 1) // stride + ]) class ScaleGradientTests(test.TestCase): @@ -3355,6 +3440,33 @@ class ScaleGradientTests(test.TestCase): np.testing.assert_array_equal([3 * 2], g_x.eval()) +class SequenceToImagesTest(test.TestCase): + + def testImagesToSequenceDims(self): + num_batches = 14 + num_time_steps = 11 + num_channels = 5 + desired_height = 7 + sequence = np.random.uniform(size=(num_time_steps, + num_batches, + num_channels)).astype(np.float32) + output = _layers.sequence_to_images(sequence, desired_height) + self.assertListEqual(output.get_shape().as_list(), [2, 7, 11, 5]) + + def testImagesToSequenceNCHW(self): + num_batches = 14 + num_time_steps = 11 + num_channels = 5 + desired_height = 7 + sequence = np.random.uniform(size=(num_time_steps, + num_batches, + num_channels)).astype(np.float32) + output = _layers.sequence_to_images(sequence, + desired_height, + output_data_format='channels_first') + self.assertListEqual(output.get_shape().as_list(), [2, 5, 7, 11]) + + class SoftmaxTests(test.TestCase): def setUp(self): @@ -3433,8 +3545,7 @@ class SpatialSoftmaxTests(test.TestCase): sess.run(variables_lib.global_variables_initializer()) feed_dict = {features: np_features} keypoints = sess.run(spatial_softmax, feed_dict) - self.assertAllEqual(keypoints.shape, - (batch_shape[0], batch_shape[3] * 2)) + self.assertAllEqual(keypoints.shape, 
(batch_shape[0], batch_shape[3] * 2)) def testSpatialSoftmaxShapeNCHW(self): batch_shape = (2, 2, 35, 35) @@ -3445,8 +3556,7 @@ class SpatialSoftmaxTests(test.TestCase): sess.run(variables_lib.global_variables_initializer()) feed_dict = {features: np_features} keypoints = sess.run(spatial_softmax, feed_dict) - self.assertAllEqual(keypoints.shape, - (batch_shape[0], batch_shape[1] * 2)) + self.assertAllEqual(keypoints.shape, (batch_shape[0], batch_shape[1] * 2)) def testTwoMaxActivationsSameChannel(self): batch_size, height, width, nchannels = (2, 35, 35, 1) @@ -3465,8 +3575,8 @@ class SpatialSoftmaxTests(test.TestCase): x_loc = [avg_x] y_loc = [avg_y] - np_keypoints = self._SpatialSoftmax( - x_loc, y_loc, height, width, batch_size, nchannels) + np_keypoints = self._SpatialSoftmax(x_loc, y_loc, height, width, batch_size, + nchannels) # Make sure expected location keypoints matches actual location keypoints. with self.test_session() as sess: @@ -3484,13 +3594,13 @@ class SpatialSoftmaxTests(test.TestCase): spatial_softmax = _layers.spatial_softmax(features) np_features = np.zeros(batch_shape, dtype=np.float32) - edges = [(0, 0), (0, width-1), (height-1, 0), (height-1, width-1)] + edges = [(0, 0), (0, width - 1), (height - 1, 0), (height - 1, width - 1)] x_loc, y_loc = zip(*edges) for c in range(nchannels): np_features[:, x_loc[c], y_loc[c], c] = 100. - np_keypoints = self._SpatialSoftmax( - x_loc, y_loc, height, width, batch_size, nchannels) + np_keypoints = self._SpatialSoftmax(x_loc, y_loc, height, width, batch_size, + nchannels) # Make sure expected location keypoints matches actual location keypoints. with self.test_session() as sess: @@ -3519,10 +3629,10 @@ class SpatialSoftmaxTests(test.TestCase): np_features1[:, x_loc[c], y_loc[c], c] = 100. np_features2[:, x_loc[c], y_loc[c], c] = 100. 
- np_keypoints1 = self._SpatialSoftmax( - x_loc, y_loc, height1, width1, batch_size, nchannels) - np_keypoints2 = self._SpatialSoftmax( - x_loc, y_loc, height2, width2, batch_size, nchannels) + np_keypoints1 = self._SpatialSoftmax(x_loc, y_loc, height1, width1, + batch_size, nchannels) + np_keypoints2 = self._SpatialSoftmax(x_loc, y_loc, height2, width2, + batch_size, nchannels) # Make sure expected location keypoints matches actual location keypoints. with self.test_session() as sess: @@ -3548,8 +3658,8 @@ class SpatialSoftmaxTests(test.TestCase): for c in range(nchannels): np_features[:, x_loc[c], y_loc[c], c] = 100. - np_keypoints = self._SpatialSoftmax( - x_loc, y_loc, height, width, batch_size, nchannels) + np_keypoints = self._SpatialSoftmax(x_loc, y_loc, height, width, batch_size, + nchannels) # Make sure expected location keypoints matches actual location keypoints. with self.test_session() as sess: @@ -3571,8 +3681,8 @@ class SpatialSoftmaxTests(test.TestCase): for c in range(nchannels): np_features[:, c, x_loc[c], y_loc[c]] = 100. - np_keypoints = self._SpatialSoftmax( - x_loc, y_loc, height, width, batch_size, nchannels) + np_keypoints = self._SpatialSoftmax(x_loc, y_loc, height, width, batch_size, + nchannels) # Make sure expected location keypoints matches actual location keypoints. 
with self.test_session() as sess: @@ -3667,8 +3777,7 @@ class UnitNormTests(test.TestCase): image = random_ops.random_uniform((height, width, 3)) output = _layers.unit_norm(image, dim=dim, epsilon=1e-6) norms = math_ops.sqrt( - math_ops.reduce_sum( - math_ops.square(output), reduction_indices=dim)) + math_ops.reduce_sum(math_ops.square(output), reduction_indices=dim)) shape = [height, width, 3] del shape[dim] @@ -3704,8 +3813,7 @@ class UnitNormTests(test.TestCase): image = array_ops.placeholder(dtypes.float32, (None, None, 3)) output = _layers.unit_norm(image, dim=dim, epsilon=1e-6) norms = math_ops.sqrt( - math_ops.reduce_sum( - math_ops.square(output), reduction_indices=dim)) + math_ops.reduce_sum(math_ops.square(output), reduction_indices=dim)) with self.test_session(): actual = norms.eval({image: placeholder_value}) @@ -3769,8 +3877,8 @@ class PoincareNormalizeTest(test.TestCase): with self.test_session(): x_tf = constant_op.constant(x_np, name='x') y_tf = _layers.poincare_normalize(x_tf, dim) - err = gradient_checker.compute_gradient_error(x_tf, x_shape, - y_tf, x_shape) + err = gradient_checker.compute_gradient_error(x_tf, x_shape, y_tf, + x_shape) print('PoinCareNormalize gradient err = %g ' % err) self.assertLess(err, 1e-4) @@ -3782,14 +3890,9 @@ class LegacyFullyConnectedTest(test.TestCase): test.TestCase.setUp(self) random_seed.set_random_seed(1234) self.input = constant_op.constant([[1., 2., 3.], [-4., 15., -6.]]) - self.input_3_dim_arr = [[[1., 1.1, 1.2], - [2., 2.1, 2.2], - [3., 3.1, 3.2], - [4., 4.1, 4.2]], - [[5., 5.1, 5.2], - [6., 6.1, 6.2], - [7., 7.1, 7.2], - [8., 8.1, 8.2]]] + self.input_3_dim_arr = [[[1., 1.1, 1.2], [2., 2.1, 2.2], [3., 3.1, 3.2], + [4., 4.1, 4.2]], [[5., 5.1, 5.2], [6., 6.1, 6.2], + [7., 7.1, 7.2], [8., 8.1, 8.2]]] self.input_3_dim = constant_op.constant(self.input_3_dim_arr) assert not ops.get_collection(ops.GraphKeys.SUMMARIES) @@ -3884,15 +3987,10 @@ class LegacyFullyConnectedTest(test.TestCase): 
self._custom_initializers(self.input, 2, [[13.0, 13.0], [11.0, 11.0]]) def test_custom_initializers_multi_dim(self): - self._custom_initializers(self.input_3_dim, 2, - [[[7.6, 7.6], - [13.6, 13.6], - [19.6, 19.6], - [25.6, 25.6]], - [[31.6, 31.6], - [37.6, 37.6], - [43.6, 43.6], - [49.6, 49.6]]]) + self._custom_initializers( + self.input_3_dim, 2, + [[[7.6, 7.6], [13.6, 13.6], [19.6, 19.6], [25.6, 25.6]], + [[31.6, 31.6], [37.6, 37.6], [43.6, 43.6], [49.6, 49.6]]]) def test_custom_collections(self): layers_lib.legacy_relu( @@ -4002,12 +4100,16 @@ class LegacyFullyConnectedTest(test.TestCase): with self.test_session() as sess: variables_lib.global_variables_initializer().run() # we can feed in input with first dimension 2 - shape_value = sess.run(array_ops.shape(y), - feed_dict={x: self.input_3_dim_arr}) + shape_value = sess.run( + array_ops.shape(y), feed_dict={ + x: self.input_3_dim_arr + }) self.assertAllClose(shape_value, [2, 4, 1]) # we can feed in input with first dimension 1 - shape_value = sess.run(array_ops.shape(y), - feed_dict={x: [self.input_3_dim_arr[0]]}) + shape_value = sess.run( + array_ops.shape(y), feed_dict={ + x: [self.input_3_dim_arr[0]] + }) self.assertAllClose(shape_value, [1, 4, 1]) # we cannot feed in input with inconsistent dimensions with self.assertRaises(ValueError): diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib.py b/tensorflow/contrib/layers/python/layers/rev_block_lib.py index 31a1b38bd4832c5816136cab3297aa22e843b0f3..123275e1fde047cd3772528641b2e3b09742fbdc 100644 --- a/tensorflow/contrib/layers/python/layers/rev_block_lib.py +++ b/tensorflow/contrib/layers/python/layers/rev_block_lib.py @@ -34,12 +34,13 @@ from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib.framework.python import ops as contrib_framework_ops from tensorflow.python.framework import function from tensorflow.python.framework import ops as framework_ops +from tensorflow.python.layers import base from 
tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import math_ops -from tensorflow.python.ops import template from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest __all__ = ["rev_block", "RevBlock", "recompute_grad"] @@ -137,7 +138,17 @@ def _rev_block_forward(x1, return y1, y2 -class RevBlock(object): +def _scope_wrap(fn, scope): + + @functools.wraps(fn) + def wrap(*args, **kwargs): + with variable_scope.variable_scope(scope): + return fn(*args, **kwargs) + + return wrap + + +class RevBlock(base.Layer): """Block of reversible layers. See rev_block.""" def __init__(self, @@ -146,7 +157,10 @@ class RevBlock(object): num_layers=1, f_side_input=None, g_side_input=None, - use_efficient_backprop=True): + use_efficient_backprop=True, + name="revblock", + **kwargs): + super(RevBlock, self).__init__(name=name, **kwargs) if isinstance(f, list): assert len(f) == num_layers @@ -158,18 +172,8 @@ class RevBlock(object): else: g = [g] * num_layers - scope_prefix = "revblock/revlayer_%d/" - f_scope = scope_prefix + "f" - g_scope = scope_prefix + "g" - - f = [ - template.make_template(f_scope % i, fn, create_scope_now_=True) - for i, fn in enumerate(f) - ] - g = [ - template.make_template(g_scope % i, fn, create_scope_now_=True) - for i, fn in enumerate(g) - ] + f = [_scope_wrap(fn, "revlayer_%d/f" % i) for i, fn in enumerate(f)] + g = [_scope_wrap(fn, "revlayer_%d/g" % i) for i, fn in enumerate(g)] self.f = f self.g = g @@ -180,6 +184,39 @@ class RevBlock(object): self._use_efficient_backprop = use_efficient_backprop + def call(self, inputs, forward=True): + vs = variable_scope.get_variable_scope() + vars_before = vs.global_variables() + + if forward: + x1, x2 = inputs + out = self._forward(x1, x2) + else: + y1, y2 = inputs + out = self._backward(y1, y2) + + # Add any 
created variables to the Layer's variable stores + new_vars = vs.global_variables()[len(vars_before):] + train_vars = vs.trainable_variables() + for new_var in new_vars: + if new_var in train_vars: + self._trainable_weights.append(new_var) + else: + self._non_trainable_weights.append(new_var) + + return out + + def forward(self, x1, x2): + return self.apply([x1, x2]) + + def backward(self, y1, y2): + return self.apply([y1, y2], forward=False) + + def build(self, _): + logging.warn("RevBlock constructs its variables on first call, not on " + "build.") + self.built = True + def _efficient_grad_fn(self, inputs, variables, ys, grad_ys): """Custom gradient fn for a block of reversible residual layers.""" side_inputs = inputs[2:] @@ -228,17 +265,18 @@ class RevBlock(object): f.reverse() g.reverse() - for i in xrange(self.num_layers): - ys, grad_ys, f_ret, g_ret = _rev_layer_backward( - ys, grad_ys, f[i], g[i], f_vars[i], self.f_side_input, g_vars[i], - self.g_side_input) + with variable_scope.variable_scope(self.scope_name, reuse=True): + for i in xrange(self.num_layers): + ys, grad_ys, f_ret, g_ret = _rev_layer_backward( + ys, grad_ys, f[i], g[i], f_vars[i], self.f_side_input, g_vars[i], + self.g_side_input) - grad_f_vars, grad_f_side = f_ret - grad_g_vars, grad_g_side = g_ret - f_var_grads.append(grad_f_vars) - g_var_grads.append(grad_g_vars) - f_side_grads.append(grad_f_side) - g_side_grads.append(grad_g_side) + grad_f_vars, grad_f_side = f_ret + grad_g_vars, grad_g_side = g_ret + f_var_grads.append(grad_f_vars) + g_var_grads.append(grad_g_vars) + f_side_grads.append(grad_f_side) + g_side_grads.append(grad_g_side) # Accumulate layer gradients for f_side_input and g_side_input acc_f_side_grads = _acc_grads(*f_side_grads) @@ -265,7 +303,7 @@ class RevBlock(object): grad_x1, grad_x2 = grad_ys return [grad_x1, grad_x2] + side_input_grads, variable_grads - def forward(self, x1, x2): + def _forward(self, x1, x2): """Run forward through the reversible layers.""" side_inputs 
= [self.f_side_input, self.g_side_input] @@ -275,7 +313,7 @@ class RevBlock(object): self._efficient_grad_fn if self._use_efficient_backprop else None) @_fn_with_custom_grad(custom_grad_fn) - def _forward(x1_, x2_, *flat_side_inputs): + def _forward_wrap(x1_, x2_, *flat_side_inputs): f_side, g_side = nest.pack_sequence_as(side_inputs, flat_side_inputs) return _rev_block_forward( x1_, @@ -287,9 +325,9 @@ class RevBlock(object): g_side_input=g_side, gate_outputs=self._use_efficient_backprop) - return _forward(x1, x2, *flat_side_inputs) + return _forward_wrap(x1, x2, *flat_side_inputs) - def backward(self, y1, y2): + def _backward(self, y1, y2): """Run backward through the reversible layers.""" f = list(self.f) @@ -356,7 +394,14 @@ def rev_block(x1, Returns: y1, y2: tuple of float Tensors. """ - block = RevBlock(f, g, num_layers, f_side_input, g_side_input, is_training) + block = RevBlock( + f=f, + g=g, + num_layers=num_layers, + f_side_input=f_side_input, + g_side_input=g_side_input, + use_efficient_backprop=is_training, + _reuse=variable_scope.get_variable_scope().reuse) return block.forward(x1, x2) diff --git a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py index a420753fd5728e7eef4f135d4943d25e8e05d5c2..cbcbcd75114a522b95631e4e7e95c1641b0a9987 100644 --- a/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py +++ b/tensorflow/contrib/layers/python/layers/rev_block_lib_test.py @@ -188,13 +188,46 @@ class RevBlockTest(test.TestCase): def f(x): x = convolutional.conv1d(x, self.CHANNELS // 2, 3, padding="same") - x = core_layers.batch_normalization(x, training=True) + x = layers.batch_norm(x, is_training=True) x = convolutional.conv1d(x, self.CHANNELS // 2, 3, padding="same") - x = core_layers.batch_normalization(x, training=True) + x = layers.batch_norm(x, is_training=True) return x self._testRevBlock(x=x, f=f) + def testReuse(self): + + def f(x): + return core_layers.dense(x, 
self.CHANNELS // 2) + + def g(x): + return core_layers.dense(x, self.CHANNELS // 2) + + x = random_ops.random_uniform( + [self.BATCH_SIZE, self.CHANNELS], dtype=dtypes.float32) + x1, x2 = array_ops.split(x, 2, axis=-1) + + with variable_scope.variable_scope("test"): + y1, y2 = rev_block_lib.rev_block(x1, x2, f, g, num_layers=self.NUM_LAYERS) + + num_vars_before = len(variables.global_variables()) + + with variable_scope.variable_scope("test", reuse=True): + y1, y2 = rev_block_lib.rev_block(x1, x2, f, g, num_layers=self.NUM_LAYERS) + + num_vars_after = len(variables.global_variables()) + self.assertEqual(num_vars_before, num_vars_after) + + loss = math_ops.reduce_mean(y1 + y2) + _ = gradients_impl.gradients(loss, + [x] + variables.trainable_variables()) + + with variable_scope.variable_scope("test", reuse=True): + y1, y2 = rev_block_lib.rev_block(x1, x2, f, g, num_layers=self.NUM_LAYERS) + + num_vars_after = len(variables.global_variables()) + self.assertEqual(num_vars_before, num_vars_after) + class RecomputeTest(test.TestCase): diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index 94920db574e07529c28313a78e0128676fcc7970..abf6e393bb0fbbce4e43f6d209e9b30517df36c3 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -10,7 +10,7 @@ package(default_visibility = [ "//tensorflow:internal", ]) -load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow:tensorflow.bzl", "py_test", "tf_py_test") py_library( name = "learn", @@ -154,12 +154,11 @@ py_test( ], ) -py_test( +tf_py_test( name = "experiment_test", size = "medium", srcs = ["python/learn/experiment_test.py"], - srcs_version = "PY2AND3", - deps = [ + additional_deps = [ ":learn", "//tensorflow/contrib/layers:layers_py", "//tensorflow/core:protos_all_py", @@ -173,6 +172,17 @@ py_test( ], ) +py_test( + name = "export_strategy_test", + size = "small", + srcs = ["python/learn/export_strategy_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":learn", + 
"//tensorflow/python:client_testlib", + ], +) + py_test( name = "graph_actions_test", size = "small", @@ -346,6 +356,7 @@ py_test( srcs = ["python/learn/estimators/dnn_linear_combined_test.py"], shard_count = 4, srcs_version = "PY2AND3", + tags = ["no_oss"], # flaky b/70524820 deps = [ ":learn", "//tensorflow/contrib/layers:layers_py", @@ -377,6 +388,7 @@ py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:lookup_ops", + "//tensorflow/python:math_ops", "//tensorflow/python:session", "//tensorflow/python:sparse_tensor", "//tensorflow/python:variables", @@ -461,6 +473,7 @@ py_test( size = "medium", srcs = ["python/learn/estimators/state_saving_rnn_estimator_test.py"], srcs_version = "PY2AND3", + tags = ["noasan"], deps = [ ":learn", "//tensorflow/contrib/layers:layers_py", @@ -482,7 +495,7 @@ py_test( name = "linear_test", size = "medium", srcs = ["python/learn/estimators/linear_test.py"], - shard_count = 4, + shard_count = 20, srcs_version = "PY2AND3", tags = ["no_pip"], deps = [ @@ -715,12 +728,11 @@ py_test( ], ) -py_test( +tf_py_test( name = "graph_io_test", size = "small", srcs = ["python/learn/learn_io/graph_io_test.py"], - srcs_version = "PY2AND3", - deps = [ + additional_deps = [ ":learn", "//tensorflow/python:client", "//tensorflow/python:client_testlib", @@ -736,20 +748,7 @@ py_test( "//tensorflow/python:training", "//tensorflow/python:variables", ], -) - -py_test( - name = "numpy_io_test", - size = "small", - srcs = ["python/learn/learn_io/numpy_io_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":learn", - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python:training", - "//third_party/py/numpy", - ], + grpc_enabled = True, ) py_test( diff --git a/tensorflow/contrib/learn/python/learn/datasets/__init__.py b/tensorflow/contrib/learn/python/learn/datasets/__init__.py index 
a3521b4109ab40d8478f20afc317cf5154da2b43..7240b0de149051afa045a8113f9e9b212840c311 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/__init__.py +++ b/tensorflow/contrib/learn/python/learn/datasets/__init__.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Dataset utilities and synthetic/reference datasets.""" from __future__ import absolute_import @@ -46,11 +45,12 @@ DATASETS = { # List of all synthetic datasets SYNTHETIC = { - # All of these will return ['data', 'target'] -> base.Dataset - 'circles': synthetic.circles, - 'spirals': synthetic.spirals + # All of these will return ['data', 'target'] -> base.Dataset + 'circles': synthetic.circles, + 'spirals': synthetic.spirals } + def load_dataset(name, size='small', test_with_fake_data=False): """Loads dataset by name. @@ -83,23 +83,28 @@ def make_dataset(name, n_samples=100, noise=None, seed=42, *args, **kwargs): seed: int or None, seed for noise Returns: - Shuffled features and labels for given synthetic dataset of type `base.Dataset` + Shuffled features and labels for given synthetic dataset of type + `base.Dataset` Raises: ValueError: Raised if `name` not found Note: - - This is a generic synthetic data generator - individual generators might have more parameters! + - This is a generic synthetic data generator - individual generators might + have more parameters! See documentation for individual parameters - - Note that the `noise` parameter uses `numpy.random.normal` and depends on `numpy`'s seed + - Note that the `noise` parameter uses `numpy.random.normal` and depends on + `numpy`'s seed TODO: - Support multiclass datasets - - Need shuffling routine. Currently synthetic datasets are reshuffled to avoid train/test correlation, + - Need shuffling routine. 
Currently synthetic datasets are reshuffled to + avoid train/test correlation, but that hurts reprodusability """ # seed = kwargs.pop('seed', None) if name not in SYNTHETIC: raise ValueError('Synthetic dataset not found or not implemeted: %s' % name) else: - return SYNTHETIC[name](n_samples=n_samples, noise=noise, seed=seed, *args, **kwargs) + return SYNTHETIC[name]( + n_samples=n_samples, noise=noise, seed=seed, *args, **kwargs) diff --git a/tensorflow/contrib/learn/python/learn/datasets/base.py b/tensorflow/contrib/learn/python/learn/datasets/base.py index 71978d439449e29c7cb907b18bab5d6659a972b6..ca720ae5ed26e74da12bd6c5a37231b41442f76f 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/base.py +++ b/tensorflow/contrib/learn/python/learn/datasets/base.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Base utilities for loading datasets.""" from __future__ import absolute_import @@ -24,13 +23,11 @@ import csv import os from os import path import random -import tempfile import time import numpy as np from six.moves import urllib -from tensorflow.contrib.framework import deprecated from tensorflow.python.platform import gfile Dataset = collections.namedtuple('Dataset', ['data', 'target']) @@ -100,9 +97,7 @@ def load_iris(data_path=None): module_path = path.dirname(__file__) data_path = path.join(module_path, 'data', 'iris.csv') return load_csv_with_header( - data_path, - target_dtype=np.int, - features_dtype=np.float) + data_path, target_dtype=np.int, features_dtype=np.float) def load_boston(data_path=None): @@ -118,16 +113,10 @@ def load_boston(data_path=None): module_path = path.dirname(__file__) data_path = path.join(module_path, 'data', 'boston_house_prices.csv') return load_csv_with_header( - data_path, - target_dtype=np.float, - features_dtype=np.float) + data_path, target_dtype=np.float, 
features_dtype=np.float) -def retry(initial_delay, - max_delay, - factor=2.0, - jitter=0.25, - is_retriable=None): +def retry(initial_delay, max_delay, factor=2.0, jitter=0.25, is_retriable=None): """Simple decorator for wrapping retriable functions. Args: @@ -152,7 +141,7 @@ def retry(initial_delay, def delays(): delay = initial_delay while delay <= max_delay: - yield delay * random.uniform(1 - jitter, 1 + jitter) + yield delay * random.uniform(1 - jitter, 1 + jitter) delay *= factor def wrap(fn): @@ -172,7 +161,9 @@ def retry(initial_delay, else: raise return fn(*args, **kwargs) + return wrapped_fn + return wrap diff --git a/tensorflow/contrib/learn/python/learn/datasets/mnist.py b/tensorflow/contrib/learn/python/learn/datasets/mnist.py index 1f3295747e141760445b021bf4f59cc47b88b8b2..37f9175015a239f763c7721cf36ab8063c0a3e32 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/mnist.py +++ b/tensorflow/contrib/learn/python/learn/datasets/mnist.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== - """Functions for downloading and reading MNIST data.""" from __future__ import absolute_import @@ -123,8 +122,8 @@ class DataSet(object): numpy.random.seed(seed1 if seed is None else seed2) dtype = dtypes.as_dtype(dtype).base_dtype if dtype not in (dtypes.uint8, dtypes.float32): - raise TypeError('Invalid image dtype %r, expected uint8 or float32' % - dtype) + raise TypeError( + 'Invalid image dtype %r, expected uint8 or float32' % dtype) if fake_data: self._num_examples = 10000 self.one_hot = one_hot @@ -202,7 +201,9 @@ class DataSet(object): end = self._index_in_epoch images_new_part = self._images[start:end] labels_new_part = self._labels[start:end] - return numpy.concatenate((images_rest_part, images_new_part), axis=0) , numpy.concatenate((labels_rest_part, labels_new_part), axis=0) + return numpy.concatenate( + (images_rest_part, images_new_part), axis=0), numpy.concatenate( + (labels_rest_part, labels_new_part), axis=0) else: self._index_in_epoch += batch_size end = self._index_in_epoch @@ -257,16 +258,14 @@ def read_data_sets(train_dir, test_labels = extract_labels(f, one_hot=one_hot) if not 0 <= validation_size <= len(train_images): - raise ValueError( - 'Validation size should be between 0 and {}. Received: {}.' - .format(len(train_images), validation_size)) + raise ValueError('Validation size should be between 0 and {}. Received: {}.' 
+ .format(len(train_images), validation_size)) validation_images = train_images[:validation_size] validation_labels = train_labels[:validation_size] train_images = train_images[validation_size:] train_labels = train_labels[validation_size:] - options = dict(dtype=dtype, reshape=reshape, seed=seed) train = DataSet(train_images, train_labels, **options) diff --git a/tensorflow/contrib/learn/python/learn/datasets/synthetic.py b/tensorflow/contrib/learn/python/learn/datasets/synthetic.py index 907dc0f3dfced7e55c5f46711fbe93f6400e1de7..9a843168c27d9cae3f55efe4fe4c688d86c745f3 100644 --- a/tensorflow/contrib/learn/python/learn/datasets/synthetic.py +++ b/tensorflow/contrib/learn/python/learn/datasets/synthetic.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Synthetic dataset generators.""" from __future__ import absolute_import @@ -23,18 +22,27 @@ import numpy as np from tensorflow.contrib.learn.python.learn.datasets.base import Dataset -def circles(n_samples=100, noise=None, seed=None, factor=0.8, n_classes=2, *args, **kwargs): + +def circles(n_samples=100, + noise=None, + seed=None, + factor=0.8, + n_classes=2, + *args, + **kwargs): """Create circles separated by some value Args: n_samples: int, number of datapoints to generate noise: float or None, standard deviation of the Gaussian noise added seed: int or None, seed for the noise - factor: float, size factor of the inner circles with respect to the outer ones + factor: float, size factor of the inner circles with respect to the outer + ones n_classes: int, number of classes to generate Returns: - Shuffled features and labels for 'circles' synthetic dataset of type `base.Dataset` + Shuffled features and labels for 'circles' synthetic dataset of type + `base.Dataset` Note: The multi-class support might not work as expected if `noise` is enabled @@ -54,7 +62,7 
@@ def circles(n_samples=100, noise=None, seed=None, factor=0.8, n_classes=2, *args if seed is not None: np.random.seed(seed) # Algo: 1) Generate initial circle, 2) For ever class generate a smaller radius circle - linspace = np.linspace(0, 2*np.pi, n_samples // n_classes) + linspace = np.linspace(0, 2 * np.pi, n_samples // n_classes) circ_x = np.empty(0, dtype=np.int32) circ_y = np.empty(0, dtype=np.int32) base_cos = np.cos(linspace) @@ -66,12 +74,12 @@ def circles(n_samples=100, noise=None, seed=None, factor=0.8, n_classes=2, *args circ_y = np.append(circ_y, base_sin) base_cos *= factor base_sin *= factor - y = np.append(y, label*np.ones(n_samples // n_classes, dtype=np.int32)) + y = np.append(y, label * np.ones(n_samples // n_classes, dtype=np.int32)) # Add more points if n_samples is not divisible by n_classes (unbalanced!) extras = n_samples % n_classes - circ_x = np.append(circ_x, np.cos(np.random.rand(extras)*2*np.pi)) - circ_y = np.append(circ_y, np.sin(np.random.rand(extras)*2*np.pi)) + circ_x = np.append(circ_x, np.cos(np.random.rand(extras) * 2 * np.pi)) + circ_y = np.append(circ_y, np.sin(np.random.rand(extras) * 2 * np.pi)) y = np.append(y, np.zeros(extras, dtype=np.int32)) # Reshape the features/labels @@ -85,10 +93,13 @@ def circles(n_samples=100, noise=None, seed=None, factor=0.8, n_classes=2, *args return Dataset(data=X[indices], target=y[indices]) -def spirals(n_samples=100, noise=None, seed=None, - mode = 'archimedes', - n_loops = 2, - *args, **kwargs): +def spirals(n_samples=100, + noise=None, + seed=None, + mode='archimedes', + n_loops=2, + *args, + **kwargs): """Create spirals Currently only binary classification is supported for spiral generation @@ -104,7 +115,8 @@ def spirals(n_samples=100, noise=None, seed=None, 'fermat': a spiral with branch distances decreasing (sqrt) Returns: - Shuffled features and labels for 'spirals' synthetic dataset of type `base.Dataset` + Shuffled features and labels for 'spirals' synthetic dataset of type + 
`base.Dataset` Raises: ValueError: If the generation `mode` is not valid @@ -112,34 +124,35 @@ def spirals(n_samples=100, noise=None, seed=None, TODO: - Generation of unbalanced data """ - n_classes = 2 # I am not sure how to make it multiclass + n_classes = 2 # I am not sure how to make it multiclass _modes = { - 'archimedes': _archimedes_spiral, - 'bernoulli': _bernoulli_spiral, - 'fermat': _fermat_spiral + 'archimedes': _archimedes_spiral, + 'bernoulli': _bernoulli_spiral, + 'fermat': _fermat_spiral } if mode is None or mode not in _modes: - raise ValueError("Cannot generate spiral with mode %s"%mode) + raise ValueError('Cannot generate spiral with mode %s' % mode) if seed is not None: np.random.seed(seed) - linspace = np.linspace(0, 2*n_loops*np.pi, n_samples // n_classes) + linspace = np.linspace(0, 2 * n_loops * np.pi, n_samples // n_classes) spir_x = np.empty(0, dtype=np.int32) spir_y = np.empty(0, dtype=np.int32) y = np.empty(0, dtype=np.int32) for label in range(n_classes): - base_cos, base_sin = _modes[mode](linspace, label*np.pi, *args, **kwargs) + base_cos, base_sin = _modes[mode](linspace, label * np.pi, *args, **kwargs) spir_x = np.append(spir_x, base_cos) spir_y = np.append(spir_y, base_sin) - y = np.append(y, label*np.ones(n_samples // n_classes, dtype=np.int32)) + y = np.append(y, label * np.ones(n_samples // n_classes, dtype=np.int32)) # Add more points if n_samples is not divisible by n_classes (unbalanced!) 
extras = n_samples % n_classes if extras > 0: - x_exrta, y_extra = _modes[mode](np.random.rand(extras)*2*np.pi, *args, **kwargs) + x_extra, y_extra = _modes[mode](np.random.rand(extras) * 2 * np.pi, *args, + **kwargs) spir_x = np.append(spir_x, x_extra) spir_y = np.append(spir_y, y_extra) y = np.append(y, np.zeros(extras, dtype=np.int32)) @@ -162,7 +175,8 @@ def _archimedes_spiral(theta, theta_offset=0., *args, **kwargs): theta: array-like, angles from polar coordinates to be converted theta_offset: float, angle offset in radians (2*pi = 0) """ - x, y = theta*np.cos(theta + theta_offset), theta*np.sin(theta + theta_offset) + x, y = theta * np.cos(theta + theta_offset), theta * np.sin( + theta + theta_offset) x_norm = np.max(np.abs(x)) y_norm = np.max(np.abs(y)) x, y = x / x_norm, y / y_norm @@ -181,7 +195,8 @@ def _bernoulli_spiral(theta, theta_offset=0., *args, **kwargs): """ exp_scale = kwargs.pop('exp_scale', 0.1) - x, y = np.exp(exp_scale*theta)*np.cos(theta + theta_offset), np.exp(exp_scale*theta)*np.sin(theta + theta_offset) + x, y = np.exp(exp_scale * theta) * np.cos(theta + theta_offset), np.exp( + exp_scale * theta) * np.sin(theta + theta_offset) x_norm = np.max(np.abs(x)) y_norm = np.max(np.abs(y)) x, y = x / x_norm, y / y_norm @@ -195,7 +210,8 @@ def _fermat_spiral(theta, theta_offset=0., *args, **kwargs): theta: array-like, angles from polar coordinates to be converted theta_offset: float, angle offset in radians (2*pi = 0) """ - x, y = np.sqrt(theta)*np.cos(theta + theta_offset), np.sqrt(theta)*np.sin(theta + theta_offset) + x, y = np.sqrt(theta) * np.cos(theta + theta_offset), np.sqrt(theta) * np.sin( + theta + theta_offset) x_norm = np.max(np.abs(x)) y_norm = np.max(np.abs(y)) x, y = x / x_norm, y / y_norm diff --git a/tensorflow/contrib/learn/python/learn/datasets/synthetic_test.py b/tensorflow/contrib/learn/python/learn/datasets/synthetic_test.py index 5340afab46eba957d6d612bb583983b627537547..5809995c8c7d8e72eb47ee88a72547bae7fd3594 100644 --- 
a/tensorflow/contrib/learn/python/learn/datasets/synthetic_test.py +++ b/tensorflow/contrib/learn/python/learn/datasets/synthetic_test.py @@ -24,12 +24,14 @@ from tensorflow.python.platform import test from tensorflow.contrib.learn.python.learn import datasets from tensorflow.contrib.learn.python.learn.datasets import synthetic + class SyntheticTest(test.TestCase): """Test synthetic dataset generation""" def test_make_dataset(self): """Test if the synthetic routine wrapper complains about the name""" - self.assertRaises(ValueError, datasets.make_dataset, name='_non_existing_name') + self.assertRaises( + ValueError, datasets.make_dataset, name='_non_existing_name') def test_all_datasets_callable(self): """Test if all methods inside the `SYNTHETIC` are callable""" @@ -52,9 +54,10 @@ class SyntheticTest(test.TestCase): """ n_samples = 100 n_classes = 2 - circ = synthetic.circles(n_samples = n_samples, noise = None, n_classes = n_classes) + circ = synthetic.circles( + n_samples=n_samples, noise=None, n_classes=n_classes) self.assertIsInstance(circ, datasets.base.Dataset) - self.assertTupleEqual(circ.data.shape, (n_samples,2)) + self.assertTupleEqual(circ.data.shape, (n_samples, 2)) self.assertTupleEqual(circ.target.shape, (n_samples,)) self.assertSetEqual(set(circ.target), set(range(n_classes))) @@ -67,17 +70,24 @@ class SyntheticTest(test.TestCase): """ seed = 42 noise = 0.1 - circ0 = synthetic.circles(n_samples = 100, noise = noise, n_classes = 2, seed = seed) - circ1 = synthetic.circles(n_samples = 100, noise = noise, n_classes = 2, seed = seed) + circ0 = synthetic.circles( + n_samples=100, noise=noise, n_classes=2, seed=seed) + circ1 = synthetic.circles( + n_samples=100, noise=noise, n_classes=2, seed=seed) np.testing.assert_array_equal(circ0.data, circ1.data) np.testing.assert_array_equal(circ0.target, circ1.target) - circ1 = synthetic.circles(n_samples = 100, noise = noise, n_classes = 2, seed = seed+1) - self.assertRaises(AssertionError, 
np.testing.assert_array_equal, circ0.data, circ1.data) - self.assertRaises(AssertionError, np.testing.assert_array_equal, circ0.target, circ1.target) + circ1 = synthetic.circles( + n_samples=100, noise=noise, n_classes=2, seed=seed + 1) + self.assertRaises(AssertionError, np.testing.assert_array_equal, circ0.data, + circ1.data) + self.assertRaises(AssertionError, np.testing.assert_array_equal, + circ0.target, circ1.target) - circ1 = synthetic.circles(n_samples = 100, noise = noise/2., n_classes = 2, seed = seed) - self.assertRaises(AssertionError, np.testing.assert_array_equal, circ0.data, circ1.data) + circ1 = synthetic.circles( + n_samples=100, noise=noise / 2., n_classes=2, seed=seed) + self.assertRaises(AssertionError, np.testing.assert_array_equal, circ0.data, + circ1.data) def test_spirals(self): """Test if the circles are generated correctly @@ -89,13 +99,14 @@ class SyntheticTest(test.TestCase): - returned `target` shape is (n_samples,) - set of unique classes range is [0, n_classes) """ - self.assertRaises(ValueError, synthetic.spirals, mode='_unknown_mode_spiral_') + self.assertRaises( + ValueError, synthetic.spirals, mode='_unknown_mode_spiral_') n_samples = 100 modes = ('archimedes', 'bernoulli', 'fermat') for mode in modes: - spir = synthetic.spirals(n_samples = n_samples, noise = None, mode = mode) + spir = synthetic.spirals(n_samples=n_samples, noise=None, mode=mode) self.assertIsInstance(spir, datasets.base.Dataset) - self.assertTupleEqual(spir.data.shape, (n_samples,2)) + self.assertTupleEqual(spir.data.shape, (n_samples, 2)) self.assertTupleEqual(spir.target.shape, (n_samples,)) self.assertSetEqual(set(spir.target), set(range(2))) @@ -110,18 +121,24 @@ class SyntheticTest(test.TestCase): noise = 0.1 modes = ('archimedes', 'bernoulli', 'fermat') for mode in modes: - spir0 = synthetic.spirals(n_samples = 1000, noise = noise, seed = seed) - spir1 = synthetic.spirals(n_samples = 1000, noise = noise, seed = seed) + spir0 = 
synthetic.spirals(n_samples=1000, noise=noise, seed=seed) + spir1 = synthetic.spirals(n_samples=1000, noise=noise, seed=seed) np.testing.assert_array_equal(spir0.data, spir1.data) np.testing.assert_array_equal(spir0.target, spir1.target) - spir1 = synthetic.spirals(n_samples = 1000, noise = noise, seed = seed+1) - self.assertRaises(AssertionError, np.testing.assert_array_equal, spir0.data, spir1.data) - self.assertRaises(AssertionError, np.testing.assert_array_equal, spir0.target, spir1.target) + spir1 = synthetic.spirals(n_samples=1000, noise=noise, seed=seed + 1) + self.assertRaises(AssertionError, np.testing.assert_array_equal, + spir0.data, spir1.data) + self.assertRaises(AssertionError, np.testing.assert_array_equal, + spir0.target, spir1.target) + + spir1 = synthetic.spirals(n_samples=1000, noise=noise / 2., seed=seed) + self.assertRaises(AssertionError, np.testing.assert_array_equal, + spir0.data, spir1.data) - spir1 = synthetic.spirals(n_samples = 1000, noise = noise/2., seed = seed) - self.assertRaises(AssertionError, np.testing.assert_array_equal, spir0.data, spir1.data) + def test_spirals_synthetic(self): + synthetic.spirals(3) -if __name__ == "__main__": +if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/learn/python/learn/estimators/composable_model_test.py b/tensorflow/contrib/learn/python/learn/estimators/composable_model_test.py index 14750961efa30128708430fac038498de0a42118..ef5e620e8f08cffa7c2b945089aa5d150baefefc 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/composable_model_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/composable_model_test.py @@ -18,7 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.framework.python.ops import variables as contrib_variables +from tensorflow.python.training import training_util from tensorflow.contrib.layers.python.layers import feature_column from 
tensorflow.contrib.learn.python.learn.datasets import base from tensorflow.contrib.learn.python.learn.estimators import composable_model @@ -55,7 +55,7 @@ def _base_model_fn(features, labels, mode, params): raise NotImplementedError def _train_op_fn(loss): - global_step = contrib_variables.get_global_step() + global_step = training_util.get_global_step() assert global_step train_step = model.get_train_step(loss) diff --git a/tensorflow/contrib/learn/python/learn/estimators/debug_test.py b/tensorflow/contrib/learn/python/learn/estimators/debug_test.py index 6b125534a42c5cdde69773d99cefd6e7b2d60c9c..b968aeed1b7a11d522b531783f04f0104b37904f 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/debug_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/debug_test.py @@ -44,7 +44,6 @@ from tensorflow.python.ops import math_ops from tensorflow.python.platform import test from tensorflow.python.training import input as input_lib - NUM_EXAMPLES = 100 N_CLASSES = 5 # Cardinality of multiclass labels. LABEL_DIMENSION = 3 # Dimensionality of regression labels. @@ -52,8 +51,10 @@ LABEL_DIMENSION = 3 # Dimensionality of regression labels. 
def _train_test_split(features_and_labels): features, labels = features_and_labels - train_set = (features[:int(len(features) / 2)], labels[:int(len(features) / 2)]) - test_set = (features[int(len(features) / 2):], labels[int(len(features) / 2):]) + train_set = (features[:int(len(features) / 2)], + labels[:int(len(features) / 2)]) + test_set = (features[int(len(features) / 2):], + labels[int(len(features) / 2):]) return train_set, test_set @@ -86,17 +87,17 @@ class DebugClassifierTest(test.TestCase): (train_features, train_labels), (test_features, test_labels) = _train_test_split( [self.features, self.labels]) - majority_class, _ = max(collections.Counter(train_labels).items(), - key=operator.itemgetter(1)) + majority_class, _ = max( + collections.Counter(train_labels).items(), key=operator.itemgetter(1)) expected_prediction = np.vstack( [[majority_class] for _ in range(test_labels.shape[0])]) classifier = debug.DebugClassifier(n_classes=N_CLASSES) - classifier.fit(input_fn=_input_fn_builder(train_features, train_labels), - steps=50) + classifier.fit( + input_fn=_input_fn_builder(train_features, train_labels), steps=50) - pred = classifier.predict_classes(input_fn=_input_fn_builder(test_features, - None)) + pred = classifier.predict_classes( + input_fn=_input_fn_builder(test_features, None)) self.assertAllEqual(expected_prediction, np.vstack(pred)) def testPredictBinary(self): @@ -105,34 +106,34 @@ class DebugClassifierTest(test.TestCase): test_labels) = _train_test_split( [self.features, self.binary_labels]) - majority_class, _ = max(collections.Counter(train_labels).items(), - key=operator.itemgetter(1)) + majority_class, _ = max( + collections.Counter(train_labels).items(), key=operator.itemgetter(1)) expected_prediction = np.vstack( [[majority_class] for _ in range(test_labels.shape[0])]) classifier = debug.DebugClassifier(n_classes=2) - classifier.fit(input_fn=_input_fn_builder(train_features, train_labels), - steps=50) + classifier.fit( + 
input_fn=_input_fn_builder(train_features, train_labels), steps=50) - pred = classifier.predict_classes(input_fn=_input_fn_builder(test_features, - None)) + pred = classifier.predict_classes( + input_fn=_input_fn_builder(test_features, None)) self.assertAllEqual(expected_prediction, np.vstack(pred)) - (train_features, train_labels), ( - test_features, test_labels) = _train_test_split( - [self.features, self.binary_float_labels]) + (train_features, + train_labels), (test_features, test_labels) = _train_test_split( + [self.features, self.binary_float_labels]) - majority_class, _ = max(collections.Counter(train_labels).items(), - key=operator.itemgetter(1)) + majority_class, _ = max( + collections.Counter(train_labels).items(), key=operator.itemgetter(1)) expected_prediction = np.vstack( [[majority_class] for _ in range(test_labels.shape[0])]) classifier = debug.DebugClassifier(n_classes=2) - classifier.fit(input_fn=_input_fn_builder(train_features, train_labels), - steps=50) + classifier.fit( + input_fn=_input_fn_builder(train_features, train_labels), steps=50) - pred = classifier.predict_classes(input_fn=_input_fn_builder(test_features, - None)) + pred = classifier.predict_classes( + input_fn=_input_fn_builder(test_features, None)) self.assertAllEqual(expected_prediction, np.vstack(pred)) def testPredictProba(self): @@ -150,8 +151,8 @@ class DebugClassifierTest(test.TestCase): [class_distribution for _ in range(test_labels.shape[0])]) classifier = debug.DebugClassifier(n_classes=N_CLASSES) - classifier.fit(input_fn=_input_fn_builder(train_features, train_labels), - steps=50) + classifier.fit( + input_fn=_input_fn_builder(train_features, train_labels), steps=50) pred = classifier.predict_proba( input_fn=_input_fn_builder(test_features, None)) @@ -173,17 +174,17 @@ class DebugClassifierTest(test.TestCase): [class_distribution for _ in range(test_labels.shape[0])]) classifier = debug.DebugClassifier(n_classes=2) - 
classifier.fit(input_fn=_input_fn_builder(train_features, train_labels), - steps=50) + classifier.fit( + input_fn=_input_fn_builder(train_features, train_labels), steps=50) pred = classifier.predict_proba( input_fn=_input_fn_builder(test_features, None)) self.assertAllClose(expected_prediction, np.vstack(pred), atol=0.1) - (train_features, train_labels), ( - test_features, test_labels) = _train_test_split( - [self.features, self.binary_float_labels]) + (train_features, + train_labels), (test_features, test_labels) = _train_test_split( + [self.features, self.binary_float_labels]) class_distribution = np.zeros((1, 2)) for label in train_labels: @@ -194,8 +195,8 @@ class DebugClassifierTest(test.TestCase): [class_distribution for _ in range(test_labels.shape[0])]) classifier = debug.DebugClassifier(n_classes=2) - classifier.fit(input_fn=_input_fn_builder(train_features, train_labels), - steps=50) + classifier.fit( + input_fn=_input_fn_builder(train_features, train_labels), steps=50) pred = classifier.predict_proba( input_fn=_input_fn_builder(test_features, None)) @@ -232,13 +233,12 @@ class DebugClassifierTest(test.TestCase): def _input_fn(): iris = test_data.prepare_iris_data_for_logistic_regression() return { - 'feature': constant_op.constant( - iris.data, dtype=dtypes.float32) + 'feature': constant_op.constant(iris.data, dtype=dtypes.float32) }, constant_op.constant( iris.target, shape=[100], dtype=dtypes.int32) - classifier = debug.DebugClassifier(config=run_config.RunConfig( - tf_random_seed=1)) + classifier = debug.DebugClassifier( + config=run_config.RunConfig(tf_random_seed=1)) classifier.fit(input_fn=_input_fn, steps=5) scores = classifier.evaluate(input_fn=_input_fn, steps=1) self.assertIn('loss', scores) @@ -342,8 +342,7 @@ class DebugClassifierTest(test.TestCase): def _input_fn(): iris = base.load_iris() return { - 'feature': constant_op.constant( - iris.data, dtype=dtypes.float32) + 'feature': constant_op.constant(iris.data, dtype=dtypes.float32) }, 
constant_op.constant( iris.target, shape=[150], dtype=dtypes.int32) @@ -387,7 +386,9 @@ class DebugClassifierTest(test.TestCase): # Create 4 rows, one of them (y = x), three of them (y=Not(x)) # The logistic prediction should be (y = 0.25). labels = constant_op.constant([[1], [0], [0], [0]]) - features = {'x': array_ops.ones(shape=[4, 1], dtype=dtypes.float32),} + features = { + 'x': array_ops.ones(shape=[4, 1], dtype=dtypes.float32), + } return features, labels classifier = debug.DebugClassifier(n_classes=2) @@ -404,8 +405,7 @@ class DebugClassifierTest(test.TestCase): # The logistic prediction should be (y = 0.25). labels = constant_op.constant([[1.], [0.], [0.], [0.]]) features = { - 'x': array_ops.ones( - shape=[4, 1], dtype=dtypes.float32), + 'x': array_ops.ones(shape=[4, 1], dtype=dtypes.float32), 'w': constant_op.constant([[1.], [1.], [1.], [1.]]) } return features, labels @@ -414,8 +414,7 @@ class DebugClassifierTest(test.TestCase): # 4 rows, with different weights. labels = constant_op.constant([[1.], [0.], [0.], [0.]]) features = { - 'x': array_ops.ones( - shape=[4, 1], dtype=dtypes.float32), + 'x': array_ops.ones(shape=[4, 1], dtype=dtypes.float32), 'w': constant_op.constant([[7.], [1.], [1.], [1.]]) } return features, labels @@ -438,8 +437,7 @@ class DebugClassifierTest(test.TestCase): # than (y=Not(x)) due to the relative higher weight of the first row. 
labels = constant_op.constant([[1], [0], [0], [0]]) features = { - 'x': array_ops.ones( - shape=[4, 1], dtype=dtypes.float32), + 'x': array_ops.ones(shape=[4, 1], dtype=dtypes.float32), 'w': constant_op.constant([[100.], [3.], [2.], [2.]]) } return features, labels @@ -448,8 +446,7 @@ class DebugClassifierTest(test.TestCase): # Create 4 rows (y = x) labels = constant_op.constant([[1], [1], [1], [1]]) features = { - 'x': array_ops.ones( - shape=[4, 1], dtype=dtypes.float32), + 'x': array_ops.ones(shape=[4, 1], dtype=dtypes.float32), 'w': constant_op.constant([[1.], [1.], [1.], [1.]]) } return features, labels @@ -469,8 +466,7 @@ class DebugClassifierTest(test.TestCase): features = { 'x': input_lib.limit_epochs( - array_ops.ones( - shape=[4, 1], dtype=dtypes.float32), + array_ops.ones(shape=[4, 1], dtype=dtypes.float32), num_epochs=num_epochs), } return features, labels @@ -578,12 +574,11 @@ class DebugClassifierTest(test.TestCase): language = feature_column.sparse_column_with_hash_bucket('language', 100) feature_columns = [ feature_column.real_valued_column('age'), - feature_column.embedding_column( - language, dimension=1) + feature_column.embedding_column(language, dimension=1) ] - classifier = debug.DebugClassifier(config=run_config.RunConfig( - tf_random_seed=1)) + classifier = debug.DebugClassifier( + config=run_config.RunConfig(tf_random_seed=1)) classifier.fit(input_fn=input_fn, steps=5) def default_input_fn(unused_estimator, examples): @@ -614,8 +609,8 @@ class DebugRegressorTest(test.TestCase): classifier.fit( input_fn=_input_fn_builder(train_features, train_labels), steps=50) - pred = classifier.predict_scores(input_fn=_input_fn_builder(test_features, - None)) + pred = classifier.predict_scores( + input_fn=_input_fn_builder(test_features, None)) self.assertAllClose(expected_prediction, np.vstack(pred), atol=0.1) def testExperimentIntegration(self): @@ -698,7 +693,9 @@ class DebugRegressorTest(test.TestCase): # Create 4 rows, one of them (y = x), three of 
them (y=Not(x)) # The algorithm should learn (y = 0.25). labels = constant_op.constant([[1.], [0.], [0.], [0.]]) - features = {'x': array_ops.ones(shape=[4, 1], dtype=dtypes.float32),} + features = { + 'x': array_ops.ones(shape=[4, 1], dtype=dtypes.float32), + } return features, labels regressor = debug.DebugRegressor( @@ -853,5 +850,6 @@ class DebugRegressorTest(test.TestCase): predictions2 = list(regressor2.predict_scores(input_fn=predict_input_fn)) self.assertAllClose(predictions, predictions2) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn.py b/tensorflow/contrib/learn/python/learn/estimators/dnn.py index cb15ef23e95d27c737d8ae08065b804bafd39a07..c17b41c0f767e19d9c3635a8f60347a49b297cfb 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn.py @@ -23,7 +23,7 @@ import six from tensorflow.contrib import layers from tensorflow.contrib.framework import deprecated from tensorflow.contrib.framework import deprecated_arg_values -from tensorflow.contrib.framework.python.ops import variables as contrib_variables +from tensorflow.python.training import training_util from tensorflow.contrib.layers.python.layers import feature_column from tensorflow.contrib.layers.python.layers import optimizers from tensorflow.contrib.learn.python.learn import metric_spec @@ -189,7 +189,7 @@ def _dnn_model_fn(features, labels, mode, params, config=None): """Returns the op to optimize the loss.""" return optimizers.optimize_loss( loss=loss, - global_step=contrib_variables.get_global_step(), + global_step=training_util.get_global_step(), learning_rate=_LEARNING_RATE, optimizer=_get_optimizer(optimizer), gradient_multipliers=( diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py index 
57e70e169ca9d6fb2adc4e50bf387cc7cf330aed..4e65c180d8bee9ab8fe9b1fbf32edc229c31af09 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined_test.py @@ -1046,11 +1046,14 @@ class DNNLinearCombinedClassifierTest(test.TestCase): if global_step == 100: # Expected is 100, but because of the global step increment bug, is 50. - self.assertEqual(50, step_counter.steps) + # Occasionally, step increments one more time due to a race condition, + # reaching 51 steps. + self.assertIn(step_counter.steps, [50, 51]) else: - # Occasionally, training stops when global_step == 101, due to a race - # condition. - self.assertEqual(51, step_counter.steps) + # Occasionally, training stops when global_step == 102, due to a race + # condition. In addition, occasionally step increments one more time due + # to a race condition reaching 52 steps. + self.assertIn(step_counter.steps, [51, 52]) def testGlobalStepDNNLinearCombinedBugFixed(self): """Tests global step update for dnn-linear combined model.""" diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py index 12f9bba531a296a00d17956b8ce32e5d7dead380..2bd57597c2e9444b51b1dacfbe4180b443c95a3d 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py @@ -1224,7 +1224,7 @@ class DNNRegressorTest(test.TestCase): self, predictions, expected_shape): predictions_nparray = np.array(predictions) self.assertAllEqual(expected_shape, predictions_nparray.shape) - self.assertTrue(np.issubdtype(predictions_nparray.dtype, np.float)) + self.assertTrue(np.issubdtype(predictions_nparray.dtype, np.floating)) def testPredict_AsIterableFalse(self): """Tests predict method with as_iterable=False.""" diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py 
b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 788d2d0b1a58fad16712c968593b40de0d3979f0..4b63e08ab3372849309ee5d28d754de82e9632f4 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Base Estimator class.""" from __future__ import absolute_import @@ -30,7 +29,6 @@ import six from google.protobuf import message from tensorflow.contrib import layers -from tensorflow.contrib import metrics as metrics_lib from tensorflow.contrib.framework import deprecated from tensorflow.contrib.framework import deprecated_args from tensorflow.contrib.framework import list_variables @@ -60,6 +58,7 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_util from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import lookup_ops +from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.ops import resources from tensorflow.python.ops import variables from tensorflow.python.platform import gfile @@ -76,7 +75,6 @@ from tensorflow.python.util import compat from tensorflow.python.util import tf_decorator from tensorflow.python.util import tf_inspect - AS_ITERABLE_DATE = '2016-09-15' AS_ITERABLE_INSTRUCTIONS = ( 'The default behavior of predict() is changing. 
The default value for\n' @@ -213,7 +211,7 @@ def _get_replica_device_setter(config): 'Variable', 'VariableV2', 'AutoReloadVariable', 'MutableHashTable', 'MutableHashTableV2', 'MutableHashTableOfTensors', 'MutableHashTableOfTensorsV2', 'MutableDenseHashTable', - 'MutableDenseHashTableV2' + 'MutableDenseHashTableV2', 'VarHandleOp' ] if config.task_type: @@ -223,8 +221,11 @@ def _get_replica_device_setter(config): if config.num_ps_replicas > 0: return device_setter.replica_device_setter( - ps_tasks=config.num_ps_replicas, worker_device=worker_device, - merge_devices=True, ps_ops=ps_ops, cluster=config.cluster_spec) + ps_tasks=config.num_ps_replicas, + worker_device=worker_device, + merge_devices=True, + ps_ops=ps_ops, + cluster=config.cluster_spec) else: return None @@ -284,10 +285,10 @@ def _make_metrics_ops(metrics, features, labels, predictions): raise ValueError('Invalid metric for {}. It returned a tuple with ' 'len {}, expected 2.'.format(name, len(name))) if not isinstance(predictions, dict): - raise ValueError( - 'Metrics passed provide (name, prediction), ' - 'but predictions are not dict. ' - 'Metrics: %s, Predictions: %s.' % (metrics, predictions)) + raise ValueError('Metrics passed provide (name, prediction), ' + 'but predictions are not dict. ' + 'Metrics: %s, Predictions: %s.' % (metrics, + predictions)) # Here are two options: labels are single Tensor or a dict. if isinstance(labels, dict) and name[1] in labels: # If labels are dict and the prediction name is in it, apply metric. @@ -298,10 +299,10 @@ def _make_metrics_ops(metrics, features, labels, predictions): else: # Single head metrics. if isinstance(predictions, dict): - raise ValueError( - 'Metrics passed provide only name, no prediction, ' - 'but predictions are dict. ' - 'Metrics: %s, Labels: %s.' % (metrics, labels_tensor_or_dict)) + raise ValueError('Metrics passed provide only name, no prediction, ' + 'but predictions are dict. ' + 'Metrics: %s, Labels: %s.' 
% (metrics, + labels_tensor_or_dict)) result[name] = metric(predictions, labels_tensor_or_dict) return result @@ -360,10 +361,22 @@ def _write_dict_to_summary(output_dir, dictionary, current_global_step): logging.warn('Skipping summary for %s, cannot parse string to Summary.', key) continue + elif isinstance(dictionary[key], np.ndarray): + value = summary_proto.value.add() + value.tag = key + value.node_name = key + tensor_proto = tensor_util.make_tensor_proto(dictionary[key]) + value.tensor.CopyFrom(tensor_proto) + logging.info( + 'Summary for np.ndarray is not visible in Tensorboard by default. ' + 'Consider using a Tensorboard plugin for visualization (see ' + 'https://github.com/tensorflow/tensorboard-plugin-example/blob/master/README.md' + ' for more information).') else: logging.warn( 'Skipping summary for %s, must be a float, np.float32, np.int64, ' - 'np.int32 or int or a serialized string of Summary.', key) + 'np.int32 or int or np.ndarray or a serialized string of Summary.', + key) summary_writer.add_summary(summary_proto, current_global_step) summary_writer.flush() @@ -372,8 +385,8 @@ GraphRewriteSpec = collections.namedtuple('GraphRewriteSpec', ['tags', 'transforms']) -class BaseEstimator( - sklearn.BaseEstimator, evaluable.Evaluable, trainable.Trainable): +class BaseEstimator(sklearn.BaseEstimator, evaluable.Evaluable, + trainable.Trainable): """Abstract BaseEstimator class to train and evaluate TensorFlow models. Users should not instantiate or subclass this class. Instead, use an @@ -415,7 +428,7 @@ class BaseEstimator( # necessary. # pylint: disable=g-doc-exception raise ValueError( - "model_dir are set both in constructor and RunConfig, but with " + 'model_dir are set both in constructor and RunConfig, but with ' "different values. 
In constructor: '{}', in RunConfig: " "'{}' ".format(model_dir, self._config.model_dir)) # pylint: enable=g-doc-exception @@ -444,12 +457,16 @@ class BaseEstimator( # TODO(wicke): make RunConfig immutable, and then return it without a copy. return copy.deepcopy(self._config) - @deprecated_args( - SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, ('x', None), - ('y', None), ('batch_size', None) - ) - def fit(self, x=None, y=None, input_fn=None, steps=None, batch_size=None, - monitors=None, max_steps=None): + @deprecated_args(SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, + ('x', None), ('y', None), ('batch_size', None)) + def fit(self, + x=None, + y=None, + input_fn=None, + steps=None, + batch_size=None, + monitors=None, + max_steps=None): # pylint: disable=g-doc-args,g-doc-return-or-yield """See `Trainable`. @@ -481,13 +498,15 @@ class BaseEstimator( logging.info('Loss for final step: %s.', loss) return self - @deprecated_args( - SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, ('x', None), - ('y', None), ('batch_size', None) - ) - def partial_fit( - self, x=None, y=None, input_fn=None, steps=1, batch_size=None, - monitors=None): + @deprecated_args(SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, + ('x', None), ('y', None), ('batch_size', None)) + def partial_fit(self, + x=None, + y=None, + input_fn=None, + steps=1, + batch_size=None, + monitors=None): """Incremental fit on a batch of samples. This method is expected to be called several times consecutively @@ -523,13 +542,16 @@ class BaseEstimator( """ logging.warning('The current implementation of partial_fit is not optimized' ' for use in a loop. 
Consider using fit() instead.') - return self.fit(x=x, y=y, input_fn=input_fn, steps=steps, - batch_size=batch_size, monitors=monitors) + return self.fit( + x=x, + y=y, + input_fn=input_fn, + steps=steps, + batch_size=batch_size, + monitors=monitors) - @deprecated_args( - SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, ('x', None), - ('y', None), ('batch_size', None) - ) + @deprecated_args(SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, + ('x', None), ('y', None), ('batch_size', None)) def evaluate(self, x=None, y=None, @@ -571,13 +593,15 @@ class BaseEstimator( eval_results.update({'global_step': global_step}) return eval_results - @deprecated_args( - SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, ('x', None), - ('batch_size', None), ('as_iterable', True) - ) - def predict( - self, x=None, input_fn=None, batch_size=None, outputs=None, - as_iterable=True): + @deprecated_args(SCIKIT_DECOUPLE_DATE, SCIKIT_DECOUPLE_INSTRUCTIONS, + ('x', None), ('batch_size', None), ('as_iterable', True)) + def predict(self, + x=None, + input_fn=None, + batch_size=None, + outputs=None, + as_iterable=True, + iterate_batches=False): """Returns predictions for given features. Args: @@ -593,6 +617,9 @@ class BaseEstimator( for each example until inputs are exhausted. Note: The inputs must terminate if you want the iterable to terminate (e.g. be sure to pass num_epochs=1 if you are using something like read_batch_features). + iterate_batches: If True, yield the whole batch at once instead of + decomposing the batch into individual samples. Only relevant when + as_iterable is True. Returns: A numpy array of predicted classes or regression values if the @@ -612,7 +639,8 @@ class BaseEstimator( input_fn=input_fn, feed_fn=feed_fn, outputs=outputs, - as_iterable=as_iterable) + as_iterable=as_iterable, + iterate_batches=iterate_batches) def get_variable_value(self, name): """Returns value of the variable given by name. 
@@ -638,16 +666,17 @@ class BaseEstimator( return self._model_dir @deprecated('2017-03-25', 'Please use Estimator.export_savedmodel() instead.') - def export(self, - export_dir, - input_fn=export._default_input_fn, # pylint: disable=protected-access - input_feature_key=None, - use_deprecated_input_fn=True, - signature_fn=None, - prediction_key=None, - default_batch_size=1, - exports_to_keep=None, - checkpoint_path=None): + def export( + self, + export_dir, + input_fn=export._default_input_fn, # pylint: disable=protected-access + input_feature_key=None, + use_deprecated_input_fn=True, + signature_fn=None, + prediction_key=None, + default_batch_size=1, + exports_to_keep=None, + checkpoint_path=None): """Exports inference graph into given dir. Args: @@ -785,8 +814,8 @@ class BaseEstimator( logging.debug('Setting feature info to %s.', str(self._features_info)) if labels is not None: if self._labels_info is not None: - logging.debug('Given labels: %s, required signatures: %s.', - str(labels), str(self._labels_info)) + logging.debug('Given labels: %s, required signatures: %s.', str(labels), + str(self._labels_info)) if not tensor_signature.tensors_compatible(labels, self._labels_info): raise ValueError('Labels are incompatible with given information. ' 'Given labels: %s, required signatures: %s.' % @@ -837,13 +866,13 @@ class BaseEstimator( if not checkpoint_path: latest_path = saver.latest_checkpoint(self._model_dir) if not latest_path: - raise NotFittedError("Couldn't find trained model at %s." - % self._model_dir) + raise NotFittedError( + "Couldn't find trained model at %s." % self._model_dir) checkpoint_path = latest_path # Setup output directory. 
- eval_dir = os.path.join(self._model_dir, 'eval' if not name else - 'eval_' + name) + eval_dir = os.path.join(self._model_dir, 'eval' + if not name else 'eval_' + name) with ops.Graph().as_default() as g: random_seed.set_random_seed(self._config.tf_random_seed) @@ -866,8 +895,7 @@ class BaseEstimator( 'Use steps=None if intended.') if steps: hooks.append( - evaluation.StopAfterNEvalsHook( - steps, log_progress=log_progress)) + evaluation.StopAfterNEvalsHook(steps, log_progress=log_progress)) global_step_key = 'global_step' while global_step_key in eval_dict: @@ -903,8 +931,8 @@ class BaseEstimator( # Check that model has been trained. checkpoint_path = saver.latest_checkpoint(self._model_dir) if not checkpoint_path: - raise NotFittedError("Couldn't find trained model at %s." - % self._model_dir) + raise NotFittedError( + "Couldn't find trained model at %s." % self._model_dir) with ops.Graph().as_default() as g: random_seed.set_random_seed(self._config.tf_random_seed) @@ -966,7 +994,8 @@ class BaseEstimator( existing_keys = predictions.keys() predictions = { key: value - for key, value in six.iteritems(predictions) if key in outputs + for key, value in six.iteritems(predictions) + if key in outputs } if not predictions: raise ValueError('Expected to run at least one output from %s, ' @@ -1032,8 +1061,7 @@ class BaseEstimator( chief_only_hooks=chief_hooks + model_fn_ops.training_chief_hooks, save_checkpoint_secs=0, # Saving is handled by a hook. save_summaries_steps=self._config.save_summary_steps, - config=self._session_config - ) as mon_sess: + config=self._session_config) as mon_sess: loss = None while not mon_sess.should_stop(): _, loss = mon_sess.run([model_fn_ops.train_op, model_fn_ops.loss]) @@ -1124,8 +1152,7 @@ class Estimator(BaseEstimator): if params is not None and 'params' not in model_fn_args: raise ValueError('Estimator\'s model_fn (%s) does not have a params ' 'argument, but params (%s) were passed to the ' - 'Estimator\'s constructor.' 
% - (model_fn, params)) + 'Estimator\'s constructor.' % (model_fn, params)) if params is None and 'params' in model_fn_args: logging.warning('Estimator\'s model_fn (%s) includes params ' 'argument, but params are not passed to Estimator.', @@ -1179,8 +1206,9 @@ class Estimator(BaseEstimator): # Custom metrics should overwrite defaults. if metrics: - model_fn_ops.eval_metric_ops.update(_make_metrics_ops( - metrics, features, labels, model_fn_ops.predictions)) + model_fn_ops.eval_metric_ops.update( + _make_metrics_ops(metrics, features, labels, + model_fn_ops.predictions)) return model_fn_ops @@ -1225,12 +1253,12 @@ class Estimator(BaseEstimator): Raises: ValueError: if `metrics` don't match `labels`. """ - model_fn_ops = self._call_model_fn( - features, labels, model_fn_lib.ModeKeys.EVAL, metrics) + model_fn_ops = self._call_model_fn(features, labels, + model_fn_lib.ModeKeys.EVAL, metrics) if metric_key.MetricKey.LOSS not in model_fn_ops.eval_metric_ops: model_fn_ops.eval_metric_ops[metric_key.MetricKey.LOSS] = ( - metrics_lib.streaming_mean(model_fn_ops.loss)) + metrics_lib.mean(model_fn_ops.loss)) return model_fn_ops def _get_predict_ops(self, features): @@ -1250,13 +1278,17 @@ class Estimator(BaseEstimator): self._labels_info) return self._call_model_fn(features, labels, model_fn_lib.ModeKeys.INFER) - def export_savedmodel( - self, export_dir_base, serving_input_fn, - default_output_alternative_key=None, - assets_extra=None, - as_text=False, - checkpoint_path=None, - graph_rewrite_specs=(GraphRewriteSpec((tag_constants.SERVING,), ()),)): + def export_savedmodel(self, + export_dir_base, + serving_input_fn, + default_output_alternative_key=None, + assets_extra=None, + as_text=False, + checkpoint_path=None, + graph_rewrite_specs=(GraphRewriteSpec( + (tag_constants.SERVING,), ()),), + strip_default_attrs=False): + # pylint: disable=line-too-long """Exports inference graph as a SavedModel into given dir. 
Args: @@ -1280,6 +1312,10 @@ class Estimator(BaseEstimator): produce a separate MetaGraphDef within the exported SavedModel, tagged and rewritten as specified. Defaults to a single entry using the default serving tag ("serve") and no rewriting. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. For a detailed guide, see + [Stripping Default-Valued + Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). Returns: The string path to the exported directory. @@ -1287,6 +1323,7 @@ class Estimator(BaseEstimator): Raises: ValueError: if an unrecognized export_type is requested. """ + # pylint: enable=line-too-long if serving_input_fn is None: raise ValueError('serving_input_fn must be defined.') @@ -1294,8 +1331,8 @@ class Estimator(BaseEstimator): # Locate the latest checkpoint checkpoint_path = saver.latest_checkpoint(self._model_dir) if not checkpoint_path: - raise NotFittedError("Couldn't find trained model at %s." - % self._model_dir) + raise NotFittedError( + "Couldn't find trained model at %s." 
% self._model_dir) export_dir = saved_model_export_utils.get_timestamped_export_dir( export_dir_base) @@ -1329,10 +1366,10 @@ class Estimator(BaseEstimator): saved_model_export_utils.get_output_alternatives( model_fn_ops, default_output_alternative_key)) - init_op = control_flow_ops.group( - variables.local_variables_initializer(), - resources.initialize_resources(resources.shared_resources()), - lookup_ops.tables_initializer()) + init_op = control_flow_ops.group(variables.local_variables_initializer(), + resources.initialize_resources( + resources.shared_resources()), + lookup_ops.tables_initializer()) # Build the SignatureDefs from all pairs of input and output alternatives signature_def_map = saved_model_export_utils.build_all_signature_defs( @@ -1362,11 +1399,12 @@ class Estimator(BaseEstimator): # TODO(soergel): switch to main_op or otherwise update when dust settles builder.add_meta_graph_and_variables( - session, untransformed_tags, + session, + untransformed_tags, signature_def_map=signature_def_map, - assets_collection=ops.get_collection( - ops.GraphKeys.ASSET_FILEPATHS), - legacy_init_op=init_op) + assets_collection=ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS), + legacy_init_op=init_op, + strip_default_attrs=strip_default_attrs) # pylint: disable=protected-access base_meta_graph_def = builder._saved_model.meta_graphs[0] @@ -1375,12 +1413,16 @@ class Estimator(BaseEstimator): if graph_rewrite_specs[1:]: # Prepare the input_names and output_names needed for the # meta_graph_transform call below. 
- input_names = [tensor.name - for input_dict in input_alternatives.values() - for tensor in input_dict.values()] - output_names = [tensor.name - for output_alternative in output_alternatives.values() - for tensor in output_alternative[1].values()] + input_names = [ + tensor.name + for input_dict in input_alternatives.values() + for tensor in input_dict.values() + ] + output_names = [ + tensor.name + for output_alternative in output_alternatives.values() + for tensor in output_alternative[1].values() + ] # Write the additional MetaGraphDefs for graph_rewrite_spec in graph_rewrite_specs[1:]: @@ -1399,11 +1441,11 @@ class Estimator(BaseEstimator): # Add the extra assets if assets_extra: - assets_extra_path = os.path.join(compat.as_bytes(temp_export_dir), - compat.as_bytes('assets.extra')) + assets_extra_path = os.path.join( + compat.as_bytes(temp_export_dir), compat.as_bytes('assets.extra')) for dest_relative, source in assets_extra.items(): - dest_absolute = os.path.join(compat.as_bytes(assets_extra_path), - compat.as_bytes(dest_relative)) + dest_absolute = os.path.join( + compat.as_bytes(assets_extra_path), compat.as_bytes(dest_relative)) dest_path = os.path.dirname(dest_absolute) gfile.MakeDirs(dest_path) gfile.Copy(source, dest_absolute) @@ -1423,25 +1465,36 @@ class SKCompat(sklearn.BaseEstimator): def fit(self, x, y, batch_size=128, steps=None, max_steps=None, monitors=None): - input_fn, feed_fn = _get_input_fn(x, y, input_fn=None, feed_fn=None, - batch_size=batch_size, shuffle=True, - epochs=None) + input_fn, feed_fn = _get_input_fn( + x, + y, + input_fn=None, + feed_fn=None, + batch_size=batch_size, + shuffle=True, + epochs=None) all_monitors = [] if feed_fn: all_monitors = [basic_session_run_hooks.FeedFnHook(feed_fn)] if monitors: all_monitors.extend(monitors) - self._estimator.fit(input_fn=input_fn, - steps=steps, - max_steps=max_steps, - monitors=all_monitors) + self._estimator.fit( + input_fn=input_fn, + steps=steps, + max_steps=max_steps, + 
monitors=all_monitors) return self def score(self, x, y, batch_size=128, steps=None, metrics=None, name=None): - input_fn, feed_fn = _get_input_fn(x, y, input_fn=None, - feed_fn=None, batch_size=batch_size, - shuffle=False, epochs=1) + input_fn, feed_fn = _get_input_fn( + x, + y, + input_fn=None, + feed_fn=None, + batch_size=batch_size, + shuffle=False, + epochs=1) if metrics is not None and not isinstance(metrics, dict): raise ValueError('Metrics argument should be None or dict. ' 'Got %s.' % metrics) @@ -1457,8 +1510,13 @@ class SKCompat(sklearn.BaseEstimator): def predict(self, x, batch_size=128, outputs=None): input_fn, feed_fn = _get_input_fn( - x, None, input_fn=None, feed_fn=None, batch_size=batch_size, - shuffle=False, epochs=1) + x, + None, + input_fn=None, + feed_fn=None, + batch_size=batch_size, + shuffle=False, + epochs=1) results = list( self._estimator._infer_model( input_fn=input_fn, @@ -1469,7 +1527,6 @@ class SKCompat(sklearn.BaseEstimator): if not isinstance(results[0], dict): return np.concatenate([output for output in results], axis=0) return { - key: np.concatenate( - [output[key] for output in results], axis=0) + key: np.concatenate([output[key] for output in results], axis=0) for key in results[0] } diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_input_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_input_test.py index 248c6c733ffca351c848ba07110ba89928634a23..d4a46b41d0c93ef58d5db8c433cbf348fec10f5e 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator_input_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_input_test.py @@ -23,7 +23,7 @@ import tempfile import numpy as np -from tensorflow.contrib.framework.python.ops import variables +from tensorflow.python.training import training_util from tensorflow.contrib.layers.python.layers import optimizers from tensorflow.contrib.learn.python.learn import metric_spec from tensorflow.contrib.learn.python.learn 
import models @@ -41,7 +41,6 @@ from tensorflow.python.platform import test from tensorflow.python.training import input as input_lib from tensorflow.python.training import queue_runner_impl - _BOSTON_INPUT_DIM = 13 _IRIS_INPUT_DIM = 4 @@ -93,8 +92,8 @@ def boston_eval_fn(): constant_op.constant(boston.data), [n_examples, _BOSTON_INPUT_DIM]) labels = array_ops.reshape( constant_op.constant(boston.target), [n_examples, 1]) - return array_ops.concat([features, features], 0), array_ops.concat( - [labels, labels], 0) + return array_ops.concat([features, features], + 0), array_ops.concat([labels, labels], 0) def extract(data, key): @@ -114,7 +113,7 @@ def linear_model_params_fn(features, labels, mode, params): prediction, loss = (models.linear_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( loss, - variables.get_global_step(), + training_util.get_global_step(), optimizer='Adagrad', learning_rate=params['learning_rate']) return prediction, loss, train_op @@ -129,7 +128,10 @@ def linear_model_fn(features, labels, mode): (_, features), = features.items() prediction, loss = (models.linear_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( - loss, variables.get_global_step(), optimizer='Adagrad', learning_rate=0.1) + loss, + training_util.get_global_step(), + optimizer='Adagrad', + learning_rate=0.1) return prediction, loss, train_op @@ -139,7 +141,10 @@ def linear_model_fn_with_model_fn_ops(features, labels, mode): model_fn.ModeKeys.INFER) prediction, loss = (models.linear_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( - loss, variables.get_global_step(), optimizer='Adagrad', learning_rate=0.1) + loss, + training_util.get_global_step(), + optimizer='Adagrad', + learning_rate=0.1) return model_fn.ModelFnOps( mode=mode, predictions=prediction, loss=loss, train_op=train_op) @@ -150,7 +155,10 @@ def logistic_model_no_mode_fn(features, labels): labels = array_ops.one_hot(labels, 3, 1, 0) 
prediction, loss = (models.logistic_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( - loss, variables.get_global_step(), optimizer='Adagrad', learning_rate=0.1) + loss, + training_util.get_global_step(), + optimizer='Adagrad', + learning_rate=0.1) return { 'class': math_ops.argmax(prediction, 1), 'prob': prediction @@ -173,7 +181,9 @@ class EstimatorInputTest(test.TestCase): scores = est.evaluate( x=boston_input, y=float64_target, - metrics={'MSE': metric_ops.streaming_mean_squared_error}) + metrics={ + 'MSE': metric_ops.streaming_mean_squared_error + }) del est # Create another estimator object with the same output dir. est2 = estimator.Estimator(model_fn=linear_model_fn, model_dir=output_dir) @@ -182,7 +192,9 @@ class EstimatorInputTest(test.TestCase): scores2 = est2.evaluate( x=boston_input, y=float64_target, - metrics={'MSE': metric_ops.streaming_mean_squared_error}) + metrics={ + 'MSE': metric_ops.streaming_mean_squared_error + }) self.assertAllClose(scores2['MSE'], scores['MSE']) predictions = np.array(list(est2.predict(x=boston_input))) other_score = _sklearn.mean_squared_error(predictions, @@ -197,7 +209,9 @@ class EstimatorInputTest(test.TestCase): scores = est.score( x=boston.data, y=float64_labels, - metrics={'MSE': metric_ops.streaming_mean_squared_error}) + metrics={ + 'MSE': metric_ops.streaming_mean_squared_error + }) predictions = np.array(list(est.predict(x=boston.data))) other_score = _sklearn.mean_squared_error(predictions, boston.target) self.assertAllClose(scores['MSE'], other_score) @@ -213,7 +227,9 @@ class EstimatorInputTest(test.TestCase): scores = est.evaluate( x=boston_input, y=float64_target, - metrics={'MSE': metric_ops.streaming_mean_squared_error}) + metrics={ + 'MSE': metric_ops.streaming_mean_squared_error + }) predictions = np.array(list(est.predict(x=boston_input))) other_score = _sklearn.mean_squared_error(predictions, boston.target) self.assertAllClose(other_score, scores['MSE']) @@ -228,14 +244,15 
@@ class EstimatorInputTest(test.TestCase): scores = est.score( x=iris.data, y=iris.target, - metrics={('accuracy', 'class'): metric_ops.streaming_accuracy}) + metrics={ + ('accuracy', 'class'): metric_ops.streaming_accuracy + }) predictions = est.predict(x=iris.data) predictions_class = est.predict(x=iris.data, outputs=['class'])['class'] self.assertEqual(predictions['prob'].shape[0], iris.target.shape[0]) self.assertAllClose(predictions['class'], predictions_class) - self.assertAllClose( - predictions['class'], np.argmax( - predictions['prob'], axis=1)) + self.assertAllClose(predictions['class'], + np.argmax(predictions['prob'], axis=1)) other_score = _sklearn.accuracy_score(iris.target, predictions['class']) self.assertAllClose(scores['accuracy'], other_score) self.assertTrue('global_step' in scores) @@ -250,17 +267,18 @@ class EstimatorInputTest(test.TestCase): scores = est.evaluate( x=iris_data, y=iris_target, - metrics={('accuracy', 'class'): metric_ops.streaming_accuracy}) + metrics={ + ('accuracy', 'class'): metric_ops.streaming_accuracy + }) predictions = list(est.predict(x=iris_data)) predictions_class = list(est.predict(x=iris_data, outputs=['class'])) self.assertEqual(len(predictions), iris.target.shape[0]) classes_batch = np.array([p['class'] for p in predictions]) self.assertAllClose(classes_batch, np.array([p['class'] for p in predictions_class])) - self.assertAllClose( - classes_batch, - np.argmax( - np.array([p['prob'] for p in predictions]), axis=1)) + self.assertAllClose(classes_batch, + np.argmax( + np.array([p['prob'] for p in predictions]), axis=1)) other_score = _sklearn.accuracy_score(iris.target, classes_batch) self.assertAllClose(other_score, scores['accuracy']) self.assertTrue('global_step' in scores) diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py index be2b0cb3ca959323b4de095ca072278f028be301..d81a534b79bc90fe91ffd3cb97a7865a7cb4c2a9 
100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py @@ -32,7 +32,7 @@ from google.protobuf import text_format from tensorflow.contrib import learn from tensorflow.contrib import lookup -from tensorflow.contrib.framework.python.ops import variables +from tensorflow.python.training import training_util from tensorflow.contrib.layers.python.layers import feature_column as feature_column_lib from tensorflow.contrib.layers.python.layers import optimizers from tensorflow.contrib.learn.python.learn import experiment @@ -111,8 +111,8 @@ def boston_eval_fn(): constant_op.constant(boston.data), [n_examples, _BOSTON_INPUT_DIM]) labels = array_ops.reshape( constant_op.constant(boston.target), [n_examples, 1]) - return array_ops.concat([features, features], 0), array_ops.concat( - [labels, labels], 0) + return array_ops.concat([features, features], + 0), array_ops.concat([labels, labels], 0) def extract(data, key): @@ -132,7 +132,7 @@ def linear_model_params_fn(features, labels, mode, params): prediction, loss = (models.linear_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( loss, - variables.get_global_step(), + training_util.get_global_step(), optimizer='Adagrad', learning_rate=params['learning_rate']) return prediction, loss, train_op @@ -147,7 +147,10 @@ def linear_model_fn(features, labels, mode): (_, features), = features.items() prediction, loss = (models.linear_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( - loss, variables.get_global_step(), optimizer='Adagrad', learning_rate=0.1) + loss, + training_util.get_global_step(), + optimizer='Adagrad', + learning_rate=0.1) return prediction, loss, train_op @@ -157,7 +160,10 @@ def linear_model_fn_with_model_fn_ops(features, labels, mode): model_fn.ModeKeys.INFER) prediction, loss = (models.linear_regression_zero_init(features, labels)) train_op = 
optimizers.optimize_loss( - loss, variables.get_global_step(), optimizer='Adagrad', learning_rate=0.1) + loss, + training_util.get_global_step(), + optimizer='Adagrad', + learning_rate=0.1) return model_fn.ModelFnOps( mode=mode, predictions=prediction, loss=loss, train_op=train_op) @@ -168,7 +174,10 @@ def logistic_model_no_mode_fn(features, labels): labels = array_ops.one_hot(labels, 3, 1, 0) prediction, loss = (models.logistic_regression_zero_init(features, labels)) train_op = optimizers.optimize_loss( - loss, variables.get_global_step(), optimizer='Adagrad', learning_rate=0.1) + loss, + training_util.get_global_step(), + optimizer='Adagrad', + learning_rate=0.1) return { 'class': math_ops.argmax(prediction, 1), 'prob': prediction @@ -184,14 +193,12 @@ def _build_estimator_for_export_tests(tmpdir): def _input_fn(): iris = base.load_iris() return { - 'feature': constant_op.constant( - iris.data, dtype=dtypes.float32) + 'feature': constant_op.constant(iris.data, dtype=dtypes.float32) }, constant_op.constant( iris.target, shape=[150], dtype=dtypes.int32) feature_columns = [ - feature_column_lib.real_valued_column( - 'feature', dimension=4) + feature_column_lib.real_valued_column('feature', dimension=4) ] est = linear.LinearRegressor(feature_columns) @@ -241,7 +248,7 @@ def _build_estimator_for_resource_export_test(): const = constant_op.constant(-1, dtype=dtypes.int64) table = lookup.MutableHashTable( dtypes.string, dtypes.int64, const, name='LookupTableModel') - update_global_step = variables.get_global_step().assign_add(1) + update_global_step = training_util.get_global_step().assign_add(1) if mode in (model_fn.ModeKeys.TRAIN, model_fn.ModeKeys.EVAL): key = constant_op.constant(['key']) value = constant_op.constant([42], dtype=dtypes.int64) @@ -291,8 +298,8 @@ class CheckCallsMonitor(monitors_lib.BaseMonitor): self.begin_calls == self.expect_calls) -def _model_fn_ops( - expected_features, expected_labels, actual_features, actual_labels, mode): +def 
_model_fn_ops(expected_features, expected_labels, actual_features, + actual_labels, mode): assert_ops = tuple([ check_ops.assert_equal( expected_features[k], actual_features[k], name='assert_%s' % k) @@ -306,15 +313,15 @@ def _model_fn_ops( mode=mode, predictions=constant_op.constant(0.), loss=constant_op.constant(0.), - train_op=variables.get_global_step().assign_add(1)) + train_op=training_util.get_global_step().assign_add(1)) def _make_input_fn(features, labels): + def _input_fn(): - return { - k: constant_op.constant(v) - for k, v in six.iteritems(features) - }, constant_op.constant(labels) + return {k: constant_op.constant(v) + for k, v in six.iteritems(features)}, constant_op.constant(labels) + return _input_fn @@ -369,11 +376,13 @@ class EstimatorModelFnTest(test.TestCase): self.assertEqual(expected_params, params) self.assertTrue(config.i_am_test) return _model_fn_ops(features, labels, arg0, arg1, mode) + partial_model_fn = functools.partial( _model_fn, foo=expected_foo, bar=expected_bar) est = estimator.Estimator( - model_fn=partial_model_fn, params=expected_params, + model_fn=partial_model_fn, + params=expected_params, config=expected_config) self.assertEqual(0, model_fn_call_count[0]) est.fit(input_fn=_make_input_fn(features, labels), steps=1) @@ -382,17 +391,24 @@ class EstimatorModelFnTest(test.TestCase): def testModelFnWithModelDir(self): expected_param = {'some_param': 'some_value'} expected_model_dir = tempfile.mkdtemp() - def _argument_checker(features, labels, mode, params, config=None, + + def _argument_checker(features, + labels, + mode, + params, + config=None, model_dir=None): _, _, _ = features, labels, config self.assertEqual(model_fn.ModeKeys.TRAIN, mode) self.assertEqual(expected_param, params) self.assertEqual(model_dir, expected_model_dir) return (constant_op.constant(0.), constant_op.constant(0.), - variables.get_global_step().assign_add(1)) - est = estimator.Estimator(model_fn=_argument_checker, - params=expected_param, - 
model_dir=expected_model_dir) + training_util.get_global_step().assign_add(1)) + + est = estimator.Estimator( + model_fn=_argument_checker, + params=expected_param, + model_dir=expected_model_dir) est.fit(input_fn=boston_input_fn, steps=1) def testInvalidModelFn_no_train_op(self): @@ -400,7 +416,7 @@ class EstimatorModelFnTest(test.TestCase): def _invalid_model_fn(features, labels): # pylint: disable=unused-argument w = variables_lib.Variable(42.0, 'weight') - update_global_step = variables.get_global_step().assign_add(1) + update_global_step = training_util.get_global_step().assign_add(1) with ops.control_dependencies([update_global_step]): loss = 100.0 - w return None, loss, None @@ -415,7 +431,7 @@ class EstimatorModelFnTest(test.TestCase): # pylint: disable=unused-argument w = variables_lib.Variable(42.0, 'weight') loss = 100.0 - w - update_global_step = variables.get_global_step().assign_add(1) + update_global_step = training_util.get_global_step().assign_add(1) with ops.control_dependencies([update_global_step]): train_op = w.assign_add(loss / 100.0) predictions = loss @@ -434,7 +450,7 @@ class EstimatorModelFnTest(test.TestCase): # pylint: disable=unused-argument w = variables_lib.Variable(42.0, 'weight') loss = 100.0 - w - update_global_step = variables.get_global_step().assign_add(1) + update_global_step = training_util.get_global_step().assign_add(1) with ops.control_dependencies([update_global_step]): train_op = w.assign_add(loss / 100.0) return None, loss, train_op @@ -447,8 +463,7 @@ class EstimatorModelFnTest(test.TestCase): est.predict(input_fn=boston_input_fn) with self.assertRaisesRegexp(ValueError, 'Missing prediction'): est.predict( - input_fn=functools.partial( - boston_input_fn, num_epochs=1), + input_fn=functools.partial(boston_input_fn, num_epochs=1), as_iterable=True) def testModelFnScaffoldInTraining(self): @@ -464,7 +479,7 @@ class EstimatorModelFnTest(test.TestCase): mode=mode, predictions=constant_op.constant(0.), 
loss=constant_op.constant(0.), - train_op=variables.get_global_step().assign_add(1), + train_op=training_util.get_global_step().assign_add(1), scaffold=monitored_session.Scaffold(init_fn=_init_fn)) est = estimator.Estimator(model_fn=_model_fn_scaffold) @@ -483,7 +498,7 @@ class EstimatorModelFnTest(test.TestCase): mode=mode, predictions=constant_op.constant([[1.]]), loss=constant_op.constant(0.), - train_op=variables.get_global_step().assign_add(1), + train_op=training_util.get_global_step().assign_add(1), scaffold=monitored_session.Scaffold(saver=self.mock_saver)) def input_fn(): @@ -498,15 +513,17 @@ class EstimatorModelFnTest(test.TestCase): self.assertTrue(self.mock_saver.restore.called) est.predict(input_fn=input_fn) self.assertTrue(self.mock_saver.restore.called) + def serving_input_fn(): - serialized_tf_example = array_ops.placeholder(dtype=dtypes.string, - shape=[None], - name='input_example_tensor') + serialized_tf_example = array_ops.placeholder( + dtype=dtypes.string, shape=[None], name='input_example_tensor') features, labels = input_fn() - return input_fn_utils.InputFnOps( - features, labels, {'examples': serialized_tf_example}) + return input_fn_utils.InputFnOps(features, labels, { + 'examples': serialized_tf_example + }) - est.export_savedmodel(os.path.join(est.model_dir, 'export'), serving_input_fn) + est.export_savedmodel( + os.path.join(est.model_dir, 'export'), serving_input_fn) self.assertTrue(self.mock_saver.restore.called) @@ -550,33 +567,28 @@ class EstimatorTest(test.TestCase): def testRunConfigModelDir(self): config = run_config.RunConfig(model_dir='test_dir') - est = estimator.Estimator(model_fn=linear_model_fn, - config=config) + est = estimator.Estimator(model_fn=linear_model_fn, config=config) self.assertEqual('test_dir', est.config.model_dir) self.assertEqual('test_dir', est.model_dir) def testModelDirAndRunConfigModelDir(self): config = run_config.RunConfig(model_dir='test_dir') - est = estimator.Estimator(model_fn=linear_model_fn, - 
config=config, - model_dir='test_dir') + est = estimator.Estimator( + model_fn=linear_model_fn, config=config, model_dir='test_dir') self.assertEqual('test_dir', est.config.model_dir) with self.assertRaisesRegexp( - ValueError, - 'model_dir are set both in constructor and RunConfig, ' + ValueError, 'model_dir are set both in constructor and RunConfig, ' 'but with different'): - estimator.Estimator(model_fn=linear_model_fn, - config=config, - model_dir='different_dir') + estimator.Estimator( + model_fn=linear_model_fn, config=config, model_dir='different_dir') def testModelDirIsCopiedToRunConfig(self): config = run_config.RunConfig() self.assertIsNone(config.model_dir) - est = estimator.Estimator(model_fn=linear_model_fn, - model_dir='test_dir', - config=config) + est = estimator.Estimator( + model_fn=linear_model_fn, model_dir='test_dir', config=config) self.assertEqual('test_dir', est.config.model_dir) self.assertEqual('test_dir', est.model_dir) @@ -656,25 +668,27 @@ class EstimatorTest(test.TestCase): boston = base.load_boston() output_dir = tempfile.mkdtemp() est = estimator.SKCompat( - estimator.Estimator( - model_fn=linear_model_fn, model_dir=output_dir)) + estimator.Estimator(model_fn=linear_model_fn, model_dir=output_dir)) float64_labels = boston.target.astype(np.float64) est.fit(x=boston.data, y=float64_labels, steps=50) scores = est.score( x=boston.data, y=float64_labels, - metrics={'MSE': metric_ops.streaming_mean_squared_error}) + metrics={ + 'MSE': metric_ops.streaming_mean_squared_error + }) del est # Create another estimator object with the same output dir. est2 = estimator.SKCompat( - estimator.Estimator( - model_fn=linear_model_fn, model_dir=output_dir)) + estimator.Estimator(model_fn=linear_model_fn, model_dir=output_dir)) # Check we can evaluate and predict. 
scores2 = est2.score( x=boston.data, y=float64_labels, - metrics={'MSE': metric_ops.streaming_mean_squared_error}) + metrics={ + 'MSE': metric_ops.streaming_mean_squared_error + }) self.assertAllClose(scores['MSE'], scores2['MSE']) predictions = np.array(list(est2.predict(x=boston.data))) other_score = _sklearn.mean_squared_error(predictions, float64_labels) @@ -685,14 +699,15 @@ class EstimatorTest(test.TestCase): scores3 = est2.score( x=boston.data, y=float64_labels, - metrics={'MSE': metric_ops.streaming_mean_squared_error}) + metrics={ + 'MSE': metric_ops.streaming_mean_squared_error + }) self.assertLess(scores3['MSE'], scores['MSE']) def test_checkpoint_contains_relative_paths(self): tmpdir = tempfile.mkdtemp() est = estimator.Estimator( - model_dir=tmpdir, - model_fn=linear_model_fn_with_model_fn_ops) + model_dir=tmpdir, model_fn=linear_model_fn_with_model_fn_ops) est.fit(input_fn=boston_input_fn, steps=5) checkpoint_file_content = file_io.read_file_to_string( @@ -700,22 +715,20 @@ class EstimatorTest(test.TestCase): ckpt = checkpoint_state_pb2.CheckpointState() text_format.Merge(checkpoint_file_content, ckpt) self.assertEqual(ckpt.model_checkpoint_path, 'model.ckpt-5') - self.assertAllEqual( - ['model.ckpt-1', 'model.ckpt-5'], ckpt.all_model_checkpoint_paths) + self.assertAllEqual(['model.ckpt-1', 'model.ckpt-5'], + ckpt.all_model_checkpoint_paths) def test_train_save_copy_reload(self): tmpdir = tempfile.mkdtemp() model_dir1 = os.path.join(tmpdir, 'model_dir1') est1 = estimator.Estimator( - model_dir=model_dir1, - model_fn=linear_model_fn_with_model_fn_ops) + model_dir=model_dir1, model_fn=linear_model_fn_with_model_fn_ops) est1.fit(input_fn=boston_input_fn, steps=5) model_dir2 = os.path.join(tmpdir, 'model_dir2') os.renames(model_dir1, model_dir2) est2 = estimator.Estimator( - model_dir=model_dir2, - model_fn=linear_model_fn_with_model_fn_ops) + model_dir=model_dir2, model_fn=linear_model_fn_with_model_fn_ops) self.assertEqual(5, 
est2.get_variable_value('global_step')) est2.fit(input_fn=boston_input_fn, steps=5) self.assertEqual(10, est2.get_variable_value('global_step')) @@ -724,7 +737,9 @@ class EstimatorTest(test.TestCase): boston = base.load_boston() est = estimator.SKCompat( estimator.Estimator( - model_fn=linear_model_params_fn, params={'learning_rate': 0.01})) + model_fn=linear_model_params_fn, params={ + 'learning_rate': 0.01 + })) est.fit(x=boston.data, y=boston.target, steps=100) def testHooksNotChanged(self): @@ -824,11 +839,13 @@ class EstimatorTest(test.TestCase): def testMonitorsForFit(self): est = estimator.Estimator(model_fn=linear_model_fn) - est.fit(input_fn=boston_input_fn, - steps=21, - monitors=[CheckCallsMonitor(expect_calls=21)]) + est.fit( + input_fn=boston_input_fn, + steps=21, + monitors=[CheckCallsMonitor(expect_calls=21)]) def testHooksForEvaluate(self): + class CheckCallHook(session_run_hook.SessionRunHook): def __init__(self): @@ -874,7 +891,9 @@ class EstimatorTest(test.TestCase): est.evaluate( input_fn=boston_input_fn, steps=200, - metrics={'MSE': _streaming_mean_squared_error_histogram}) + metrics={ + 'MSE': _streaming_mean_squared_error_histogram + }) events = util_test.latest_events(est.model_dir + '/eval') output_values = {} for e in events: @@ -884,6 +903,37 @@ class EstimatorTest(test.TestCase): self.assertTrue('MSE' in output_values) self.assertTrue(output_values['MSE'].HasField('histo')) + def testSummaryWritingWithTensor(self): + + def _streaming_precition_mean_tensor(predictions, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): + return metric_ops.streaming_mean_tensor( + predictions, + weights=weights, + metrics_collections=metrics_collections, + updates_collections=updates_collections, + name=name) + + est = estimator.Estimator(model_fn=linear_model_fn) + est.fit(input_fn=boston_input_fn, steps=200) + est.evaluate( + input_fn=boston_input_fn, + steps=200, + metrics={ + 'PMT': _streaming_precition_mean_tensor + 
}) + events = util_test.latest_events(est.model_dir + '/eval') + output_values = {} + for e in events: + if e.HasField('summary'): + for v in e.summary.value: + output_values[v.tag] = v + self.assertTrue('PMT' in output_values) + self.assertTrue(output_values['PMT'].HasField('tensor')) + def testLossInGraphCollection(self): class _LossCheckerHook(session_run_hook.SessionRunHook): @@ -927,8 +977,8 @@ class EstimatorTest(test.TestCase): self.assertTrue( gfile.Exists( os.path.join( - compat.as_bytes(export_dir), compat.as_bytes( - 'saved_model.pb')))) + compat.as_bytes(export_dir), + compat.as_bytes('saved_model.pb')))) self.assertTrue( gfile.Exists( os.path.join( @@ -988,11 +1038,11 @@ class EstimatorTest(test.TestCase): self.assertTrue('input_example_tensor' in graph_ops) self.assertTrue('ParseExample/ParseExample' in graph_ops) self.assertTrue('linear/linear/feature/matmul' in graph_ops) - self.assertItemsEqual( - ['bogus_lookup', 'feature'], - [compat.as_str_any(x) for x in graph.get_collection( - constants.COLLECTION_DEF_KEY_FOR_INPUT_FEATURE_KEYS)]) - + self.assertItemsEqual(['bogus_lookup', 'feature'], [ + compat.as_str_any(x) + for x in graph.get_collection( + constants.COLLECTION_DEF_KEY_FOR_INPUT_FEATURE_KEYS) + ]) # cleanup gfile.DeleteRecursively(tmpdir) @@ -1010,8 +1060,8 @@ class EstimatorTest(test.TestCase): self.assertTrue( gfile.Exists( os.path.join( - compat.as_bytes(export_dir), compat.as_bytes( - 'saved_model.pb')))) + compat.as_bytes(export_dir), + compat.as_bytes('saved_model.pb')))) self.assertTrue( gfile.Exists( os.path.join( @@ -1054,19 +1104,22 @@ class EstimatorTest(test.TestCase): export_dir_base = os.path.join( compat.as_bytes(tmpdir), compat.as_bytes('export')) export_dir = est.export_savedmodel( - export_dir_base, serving_input_fn, assets_extra=assets_extra, + export_dir_base, + serving_input_fn, + assets_extra=assets_extra, graph_rewrite_specs=[ estimator.GraphRewriteSpec(['tag_1'], []), estimator.GraphRewriteSpec(['tag_2', 'tag_3'], - 
['strip_unused_nodes'])]) + ['strip_unused_nodes']) + ]) self.assertTrue(gfile.Exists(export_dir_base)) self.assertTrue(gfile.Exists(export_dir)) self.assertTrue( gfile.Exists( os.path.join( - compat.as_bytes(export_dir), compat.as_bytes( - 'saved_model.pb')))) + compat.as_bytes(export_dir), + compat.as_bytes('saved_model.pb')))) self.assertTrue( gfile.Exists( os.path.join( @@ -1179,18 +1232,15 @@ class InferRealValuedColumnsTest(test.TestCase): self.assertEqual(1, len(feature_columns)) feature_column = feature_columns[0] self.assertEqual('', feature_column.name) - self.assertEqual( - { - '': - parsing_ops.FixedLenFeature( - shape=expected_shape, dtype=expected_dtype) - }, - feature_column.config) + self.assertEqual({ + '': + parsing_ops.FixedLenFeature( + shape=expected_shape, dtype=expected_dtype) + }, feature_column.config) def testInt32Input(self): feature_columns = estimator.infer_real_valued_columns_from_input( - np.ones( - shape=[7, 8], dtype=np.int32)) + np.ones(shape=[7, 8], dtype=np.int32)) self._assert_single_feature_column([8], dtypes.int32, feature_columns) def testInt32InputFn(self): @@ -1200,8 +1250,7 @@ class InferRealValuedColumnsTest(test.TestCase): def testInt64Input(self): feature_columns = estimator.infer_real_valued_columns_from_input( - np.ones( - shape=[7, 8], dtype=np.int64)) + np.ones(shape=[7, 8], dtype=np.int64)) self._assert_single_feature_column([8], dtypes.int64, feature_columns) def testInt64InputFn(self): @@ -1211,8 +1260,7 @@ class InferRealValuedColumnsTest(test.TestCase): def testFloat32Input(self): feature_columns = estimator.infer_real_valued_columns_from_input( - np.ones( - shape=[7, 8], dtype=np.float32)) + np.ones(shape=[7, 8], dtype=np.float32)) self._assert_single_feature_column([8], dtypes.float32, feature_columns) def testFloat32InputFn(self): @@ -1222,8 +1270,7 @@ class InferRealValuedColumnsTest(test.TestCase): def testFloat64Input(self): feature_columns = estimator.infer_real_valued_columns_from_input( - np.ones( - 
shape=[7, 8], dtype=np.float64)) + np.ones(shape=[7, 8], dtype=np.float64)) self._assert_single_feature_column([8], dtypes.float64, feature_columns) def testFloat64InputFn(self): @@ -1242,8 +1289,8 @@ class InferRealValuedColumnsTest(test.TestCase): ValueError, 'on integer or non floating types are not supported'): # pylint: disable=g-long-lambda estimator.infer_real_valued_columns_from_input_fn( - lambda: (constant_op.constant(False, shape=[7, 8], dtype=dtypes.bool), - None)) + lambda: (constant_op.constant(False, shape=[7, 8], dtype=dtypes.bool), None) + ) def testStringInput(self): with self.assertRaisesRegexp( @@ -1280,8 +1327,9 @@ class ReplicaDeviceSetterTest(test.TestCase): def testVariablesAreOnPs(self): tf_config = {'cluster': {run_config.TaskType.PS: ['fake_ps_0']}} - with test.mock.patch.dict('os.environ', - {'TF_CONFIG': json.dumps(tf_config)}): + with test.mock.patch.dict('os.environ', { + 'TF_CONFIG': json.dumps(tf_config) + }): config = run_config.RunConfig() with ops.device(estimator._get_replica_device_setter(config)): @@ -1308,14 +1356,14 @@ class ReplicaDeviceSetterTest(test.TestCase): def testMutableHashTableIsOnPs(self): tf_config = {'cluster': {run_config.TaskType.PS: ['fake_ps_0']}} - with test.mock.patch.dict('os.environ', - {'TF_CONFIG': json.dumps(tf_config)}): + with test.mock.patch.dict('os.environ', { + 'TF_CONFIG': json.dumps(tf_config) + }): config = run_config.RunConfig() with ops.device(estimator._get_replica_device_setter(config)): default_val = constant_op.constant([-1, -1], dtypes.int64) - table = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) + table = lookup.MutableHashTable(dtypes.string, dtypes.int64, default_val) input_string = constant_op.constant(['brain', 'salad', 'tank']) output = table.lookup(input_string) self.assertDeviceEqual('/job:ps/task:0', table._table_ref.device) @@ -1325,8 +1373,7 @@ class ReplicaDeviceSetterTest(test.TestCase): with ops.device( 
estimator._get_replica_device_setter(run_config.RunConfig())): default_val = constant_op.constant([-1, -1], dtypes.int64) - table = lookup.MutableHashTable(dtypes.string, dtypes.int64, - default_val) + table = lookup.MutableHashTable(dtypes.string, dtypes.int64, default_val) input_string = constant_op.constant(['brain', 'salad', 'tank']) output = table.lookup(input_string) self.assertDeviceEqual('', table._table_ref.device) @@ -1342,8 +1389,9 @@ class ReplicaDeviceSetterTest(test.TestCase): 'index': 3 } } - with test.mock.patch.dict('os.environ', - {'TF_CONFIG': json.dumps(tf_config)}): + with test.mock.patch.dict('os.environ', { + 'TF_CONFIG': json.dumps(tf_config) + }): config = run_config.RunConfig() with ops.device(estimator._get_replica_device_setter(config)): diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py index 1d89dfb55b10b032cab7dcf434d396404d4eb83b..2113fae3940f14c8ca07e5f76986408ae8a33831 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimators_test.py @@ -22,7 +22,7 @@ import random import numpy as np -from tensorflow.contrib.framework.python.ops import variables +from tensorflow.python.training import training_util from tensorflow.contrib.learn.python import learn from tensorflow.contrib.learn.python.learn import datasets from tensorflow.contrib.learn.python.learn import metric_spec @@ -62,7 +62,7 @@ class FeatureEngineeringFunctionTest(test.TestCase): _ = labels predictions = features["transformed_x"] loss = constant_op.constant([2.]) - update_global_step = variables.get_global_step().assign_add(1) + update_global_step = training_util.get_global_step().assign_add(1) return predictions, loss, update_global_step estimator = estimator_lib.Estimator( @@ -72,9 +72,11 @@ class FeatureEngineeringFunctionTest(test.TestCase): # predictions = transformed_x (9) 
self.assertEqual(9., prediction) metrics = estimator.evaluate( - input_fn=input_fn, steps=1, - metrics={"label": - metric_spec.MetricSpec(lambda predictions, labels: labels)}) + input_fn=input_fn, + steps=1, + metrics={ + "label": metric_spec.MetricSpec(lambda predictions, labels: labels) + }) # labels = transformed_y (99) self.assertEqual(99., metrics["label"]) @@ -82,10 +84,10 @@ class FeatureEngineeringFunctionTest(test.TestCase): def input_fn(): return { - "x": constant_op.constant(["9."]) - }, { - "y": constant_op.constant(["99."]) - } + "x": constant_op.constant(["9."]) + }, { + "y": constant_op.constant(["99."]) + } def feature_engineering_fn(features, labels): # Github #12205: raise a TypeError if called twice. @@ -100,19 +102,21 @@ class FeatureEngineeringFunctionTest(test.TestCase): _ = labels predictions = features["x"] loss = constant_op.constant([2.]) - update_global_step = variables.get_global_step().assign_add(1) + update_global_step = training_util.get_global_step().assign_add(1) return predictions, loss, update_global_step estimator = estimator_lib.Estimator( - model_fn=model_fn, feature_engineering_fn=feature_engineering_fn) + model_fn=model_fn, feature_engineering_fn=feature_engineering_fn) estimator.fit(input_fn=input_fn, steps=1) prediction = next(estimator.predict(input_fn=input_fn, as_iterable=True)) # predictions = transformed_x (9) self.assertEqual(9., prediction) metrics = estimator.evaluate( - input_fn=input_fn, steps=1, - metrics={"label": - metric_spec.MetricSpec(lambda predictions, labels: labels)}) + input_fn=input_fn, + steps=1, + metrics={ + "label": metric_spec.MetricSpec(lambda predictions, labels: labels) + }) # labels = transformed_y (99) self.assertEqual(99., metrics["label"]) @@ -139,7 +143,7 @@ class FeatureEngineeringFunctionTest(test.TestCase): _ = labels predictions = features["x"] loss = constant_op.constant([2.]) - update_global_step = variables.get_global_step().assign_add(1) + update_global_step = 
training_util.get_global_step().assign_add(1) return predictions, loss, update_global_step estimator_with_fe_fn = estimator_lib.Estimator( @@ -150,12 +154,10 @@ class FeatureEngineeringFunctionTest(test.TestCase): # predictions = x prediction_with_fe_fn = next( - estimator_with_fe_fn.predict( - input_fn=input_fn, as_iterable=True)) + estimator_with_fe_fn.predict(input_fn=input_fn, as_iterable=True)) self.assertEqual(9., prediction_with_fe_fn) prediction_without_fe_fn = next( - estimator_without_fe_fn.predict( - input_fn=input_fn, as_iterable=True)) + estimator_without_fe_fn.predict(input_fn=input_fn, as_iterable=True)) self.assertEqual(1., prediction_without_fe_fn) diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index bc0e6fc0091c9b5419ab526855b404eb4a927e97..9b124b2c19f16bbc9b2afeadb82a32006e1a0ae9 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -181,7 +181,8 @@ def regression_head(label_name=None, weight_column_name=None, label_dimension=1, enable_centered_bias=False, - head_name=None): + head_name=None, + link_fn=None): """Creates a `Head` for linear regression. Args: @@ -199,6 +200,8 @@ def regression_head(label_name=None, head_name: name of the head. If provided, predictions, summary and metrics keys will be suffixed by `"/" + head_name` and the default variable scope will be `head_name`. + link_fn: link function to convert logits to predictions. If provided, + this link function will be used instead of identity. Returns: An instance of `Head` for linear regression. 
@@ -210,7 +213,7 @@ def regression_head(label_name=None, enable_centered_bias=enable_centered_bias, head_name=head_name, loss_fn=_mean_squared_loss, - link_fn=array_ops.identity) + link_fn=(link_fn if link_fn is not None else array_ops.identity)) def poisson_regression_head(label_name=None, diff --git a/tensorflow/contrib/learn/python/learn/estimators/head_test.py b/tensorflow/contrib/learn/python/learn/estimators/head_test.py index 3881bf533d642bef68fa9ab4ba908bbb8f7f8091..7c2d9bb0767cb979dae9c84b5342d129225677ed 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head_test.py @@ -33,6 +33,7 @@ from tensorflow.python.client import session from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import lookup_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import variables from tensorflow.python.ops.losses import losses as losses_lib from tensorflow.python.platform import test @@ -153,6 +154,25 @@ class RegressionHeadTest(test.TestCase): _assert_no_variables(self) _assert_metrics(self, 5. / 3, {"loss": 5. / 3}, model_fn_ops) + def testRegressionWithLogitFn(self): + head = head_lib.regression_head(link_fn=math_ops.square) + def _assert_preditions(test_case, expected_predictions, model_fn_ops): + variables.initialize_local_variables().run() + test_case.assertAllClose(expected_predictions, + model_fn_ops.predictions["scores"].eval()) + with ops.Graph().as_default(), session.Session(): + model_fn_ops = head.create_model_fn_ops( + {}, + labels=((0.,), (1.,), (1.,)), + mode=model_fn.ModeKeys.TRAIN, + train_op_fn=head_lib.no_op_train_fn, + logits=((1.,), (1.,), (3.,))) + self._assert_output_alternatives(model_fn_ops) + _assert_summary_tags(self, ["loss"]) + _assert_no_variables(self) + _assert_metrics(self, 5. / 3, {"loss": 5. 
/ 3}, model_fn_ops) + _assert_preditions(self, ([1.0, 1.0, 9.0]), model_fn_ops) + def testRegressionWithInvalidLogits(self): head = head_lib.regression_head() with ops.Graph().as_default(), session.Session(): diff --git a/tensorflow/contrib/learn/python/learn/estimators/kmeans.py b/tensorflow/contrib/learn/python/learn/estimators/kmeans.py index 992b804f59ecd88fedc2fba10d3079f93c4fe83d..8f9d6fc318a357853bdb8e3264f6691b410006b1 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/kmeans.py +++ b/tensorflow/contrib/learn/python/learn/estimators/kmeans.py @@ -28,7 +28,7 @@ import time import numpy as np from tensorflow.contrib.factorization.python.ops import clustering_ops -from tensorflow.contrib.framework.python.ops import variables +from tensorflow.python.training import training_util from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators.model_fn import ModelFnOps from tensorflow.python.framework import ops @@ -128,7 +128,7 @@ def _kmeans_clustering_model_fn(features, labels, mode, params, config): random_seed=params.get('random_seed'), kmeans_plus_plus_num_retries=params.get( 'kmeans_plus_plus_num_retries')).training_graph() - incr_step = state_ops.assign_add(variables.get_global_step(), 1) + incr_step = state_ops.assign_add(training_util.get_global_step(), 1) loss = math_ops.reduce_sum(losses, name=KMeansClustering.LOSS_OP_NAME) summary.scalar('loss/raw', loss) training_op = with_dependencies([training_op, incr_step], loss) diff --git a/tensorflow/contrib/learn/python/learn/estimators/kmeans_test.py b/tensorflow/contrib/learn/python/learn/estimators/kmeans_test.py index ce87b4723d436495e5fb149f0ab8f2eea44d82b8..b28835a809736a099ad2f08d127dc68d7977a3c1 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/kmeans_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/kmeans_test.py @@ -199,15 +199,7 @@ class KMeansTest(KMeansTestBase): 
input_fn=self.input_fn(batch_size=self.num_points), steps=1) self.assertNear(self.true_score, score, self.true_score * 0.01) - def test_infer(self): - kmeans = self._kmeans() - # Make a call to fit to initialize the cluster centers. - max_steps = 1 - kmeans.fit(input_fn=self.input_fn(), max_steps=max_steps) - clusters = kmeans.clusters() - - # Make a small test set - num_points = 10 + def _infer_helper(self, kmeans, clusters, num_points): points, true_assignments, true_offsets = make_random_points( clusters, num_points) # Test predict @@ -231,6 +223,17 @@ class KMeansTest(KMeansTestBase): np.transpose(np.sum(np.square(clusters), axis=1, keepdims=True))) self.assertAllClose(transform, true_transform, rtol=0.05, atol=10) + def test_infer(self): + kmeans = self._kmeans() + # Make a call to fit to initialize the cluster centers. + max_steps = 1 + kmeans.fit(input_fn=self.input_fn(), max_steps=max_steps) + clusters = kmeans.clusters() + + # Run inference on small datasets. + self._infer_helper(kmeans, clusters, num_points=10) + self._infer_helper(kmeans, clusters, num_points=1) + class KMeansTestMultiStageInit(KMeansTestBase): diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py index f5445ad4e728dbd3904279573771de9454b5d17c..37aa8b339622415d082933cdf66d2472a4119b48 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/linear.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py @@ -26,7 +26,7 @@ import six from tensorflow.contrib import layers from tensorflow.contrib.framework import deprecated from tensorflow.contrib.framework import deprecated_arg_values -from tensorflow.contrib.framework.python.ops import variables as contrib_variables +from tensorflow.python.training import training_util from tensorflow.contrib.layers.python.layers import feature_column from tensorflow.contrib.learn.python.learn.estimators import estimator from 
tensorflow.contrib.learn.python.learn.estimators import head as head_lib @@ -170,7 +170,7 @@ def _linear_model_fn(features, labels, mode, params, config=None): weight_collections=[parent_scope]) def _train_op_fn(loss): - global_step = contrib_variables.get_global_step() + global_step = training_util.get_global_step() my_vars = ops.get_collection(parent_scope) grads = gradients.gradients(loss, my_vars) if gradient_clip_norm: @@ -252,7 +252,7 @@ def sdca_model_fn(features, labels, mode, params): _add_bias_column(feature_columns, features, bias, columns_to_variables) def _train_op_fn(unused_loss): - global_step = contrib_variables.get_global_step() + global_step = training_util.get_global_step() sdca_model, train_op = optimizer.get_train_step(columns_to_variables, weight_column_name, loss_type, features, diff --git a/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor_test.py b/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor_test.py index 93c62f87e8495f299a8c456574c7b40534186304..ac2d10011e222eb9c534d7fbae3c0cb5f4820945 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/logistic_regressor_test.py @@ -21,7 +21,7 @@ from __future__ import print_function import numpy as np from tensorflow.contrib import layers -from tensorflow.contrib.framework.python.ops import variables +from tensorflow.python.training import training_util from tensorflow.contrib.layers.python.layers import optimizers from tensorflow.contrib.learn.python.learn.datasets import base from tensorflow.contrib.learn.python.learn.estimators import logistic_regressor @@ -57,7 +57,10 @@ def _logistic_regression_model_fn(features, labels, mode): predictions = math_ops.sigmoid(logits) loss = losses.sigmoid_cross_entropy(labels, logits) train_op = optimizers.optimize_loss( - loss, variables.get_global_step(), optimizer='Adagrad', learning_rate=0.1) + loss, + 
training_util.get_global_step(), + optimizer='Adagrad', + learning_rate=0.1) return predictions, loss, train_op diff --git a/tensorflow/contrib/learn/python/learn/evaluable.py b/tensorflow/contrib/learn/python/learn/evaluable.py index 66e15265171679dcd710fdf05bed3105de6bab99..8f6cd39864b437f163dd7c1140dc88755ce98529 100644 --- a/tensorflow/contrib/learn/python/learn/evaluable.py +++ b/tensorflow/contrib/learn/python/learn/evaluable.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """`Evaluable` interface.""" from __future__ import absolute_import @@ -59,9 +58,12 @@ class Evaluable(object): for which this evaluation was performed. Args: - x: Matrix of shape [n_samples, n_features...] or dictionary of many matrices - containing the input samples for fitting the model. Can be iterator that returns - arrays of features or dictionary of array of features. If set, `input_fn` must + x: Matrix of shape [n_samples, n_features...] or dictionary of many + matrices + containing the input samples for fitting the model. Can be iterator that + returns + arrays of features or dictionary of array of features. If set, + `input_fn` must be `None`. y: Vector or matrix [n_samples] or [n_samples, n_outputs] containing the label values (class labels in classification, real numbers in diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py index fc4bd1f461d7bfbfcfb78201d527959055342f0a..bec976afd2719138117976381669ca3292360480 100644 --- a/tensorflow/contrib/learn/python/learn/experiment.py +++ b/tensorflow/contrib/learn/python/learn/experiment.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== - """Experiment class collecting information needed for a single training run.""" from __future__ import absolute_import @@ -35,6 +34,7 @@ from tensorflow.contrib.learn.python.learn import trainable from tensorflow.contrib.learn.python.learn.estimators import run_config from tensorflow.contrib.tpu.python.tpu import tpu_estimator from tensorflow.python.estimator import estimator as core_estimator +from tensorflow.python.estimator import util as estimator_util from tensorflow.python.framework import ops from tensorflow.python.platform import tf_logging as logging from tensorflow.python.training import basic_session_run_hooks @@ -42,10 +42,21 @@ from tensorflow.python.training import saver from tensorflow.python.training import server_lib from tensorflow.python.util import compat - __all__ = ["Experiment"] +def _get_standardized_predicate_fn(predicate_fn): + pred_fn_args = estimator_util.fn_args(predicate_fn) + if "checkpoint_path" not in pred_fn_args: + # pylint: disable=unused-argument + def _pred_fn_wrapper(eval_results, checkpoint_path): + return predicate_fn(eval_results) + + return _pred_fn_wrapper + else: + return predicate_fn + + class _EvalAndExportListener(basic_session_run_hooks.CheckpointSaverListener): """Listener that evaluates and exports a model after creating a checkpoint. 
@@ -265,8 +276,7 @@ class Experiment(object): self._train_steps_per_iteration = train_steps_per_iteration if (self._train_steps_per_iteration is not None and not isinstance(self._train_steps_per_iteration, int)): - raise ValueError( - "`train_steps_per_iteration` must be an integer.") + raise ValueError("`train_steps_per_iteration` must be an integer.") @property def estimator(self): @@ -346,9 +356,10 @@ class Experiment(object): config.cluster_spec and config.master): self._start_server() elif config.cluster_spec and config.master: - raise ValueError('For distributed runtime, Experiment class only works with' - 'tf.contrib.learn.RunConfig for now, but provided {}' - .format(type(config))) + raise ValueError( + "For distributed runtime, Experiment class only works with" + "tf.contrib.learn.RunConfig for now, but provided {}".format( + type(config))) extra_hooks = [] if delay_secs is None: @@ -401,11 +412,12 @@ class Experiment(object): logging.info("Waiting %d secs before starting eval.", delay_secs) time.sleep(delay_secs) - return self._call_evaluate(input_fn=self._eval_input_fn, - steps=self._eval_steps, - metrics=self._eval_metrics, - name=(name or "one_pass"), - hooks=self._eval_hooks) + return self._call_evaluate( + input_fn=self._eval_input_fn, + steps=self._eval_steps, + metrics=self._eval_metrics, + name=(name or "one_pass"), + hooks=self._eval_hooks) @deprecated( "2016-10-23", @@ -446,22 +458,33 @@ class Experiment(object): evaluate_checkpoint_only_once: Whether to skip evaluation of checkpoints that have already been evaluated. Default is `True`. continuous_eval_predicate_fn: A predicate function determining whether to - continue eval after each iteration. `predicate_fn` takes the evaluation - results as arguments. At the beginning of evaluation, the passed eval - results will be None so it's expected that the predicate function - handles that gracefully. 
When `predicate_fn` is not specified, - continuous eval will run in an infinite loop (if `train_steps` is None) - or exit once global step reaches `train_steps`. + continue eval after each iteration. A `predicate_fn` has one of the + following signatures: + * (eval_results) -> boolean + * (eval_results, checkpoint_path) -> boolean + Where `eval_results` is the dictionary of metric evaluations and + checkpoint_path is the path to the checkpoint containing the parameters + on which that evaluation was based. + At the beginning of evaluation, the passed `eval_results` will be None + so it's expected that the predicate function handles that gracefully. + When `predicate_fn` is not specified, continuous eval will run in an + infinite loop (if `train_steps` is None). or exit once global step + reaches `train_steps`. + export: Whether to export from this step. Default is 'True'. Raises: ValueError: if `continuous_eval_predicate_fn` is neither None nor callable. """ - if (continuous_eval_predicate_fn is not None and - not callable(continuous_eval_predicate_fn)): - raise ValueError( - "`continuous_eval_predicate_fn` must be a callable, or None.") + if continuous_eval_predicate_fn is not None: + if not callable(continuous_eval_predicate_fn): + raise ValueError( + "`continuous_eval_predicate_fn` must be a callable, or None.") + predicate_fn = _get_standardized_predicate_fn( + continuous_eval_predicate_fn) + else: + predicate_fn = None if delay_secs is None: delay_secs = self._eval_delay_secs @@ -475,13 +498,12 @@ class Experiment(object): previous_path = None eval_result = None last_warning_time = 0 - while (not continuous_eval_predicate_fn or - continuous_eval_predicate_fn(eval_result)): + while (not predicate_fn or predicate_fn( + eval_result, checkpoint_path=previous_path if eval_result else None)): # Exit if we have already reached number of steps to train. 
if self._has_training_stopped(eval_result): logging.info("Exiting continuous eval, global_step=%s >= " - "train_step=%s", - eval_result[ops.GraphKeys.GLOBAL_STEP], + "train_step=%s", eval_result[ops.GraphKeys.GLOBAL_STEP], self._train_steps) return @@ -502,12 +524,13 @@ class Experiment(object): logging.warning(error_msg) last_warning_time = time.time() else: - eval_result = self._call_evaluate(input_fn=input_fn, - steps=self._eval_steps, - metrics=self._eval_metrics, - name=name, - checkpoint_path=latest_path, - hooks=self._eval_hooks) + eval_result = self._call_evaluate( + input_fn=input_fn, + steps=self._eval_steps, + metrics=self._eval_metrics, + name=name, + checkpoint_path=latest_path, + hooks=self._eval_hooks) # Ensure eval result is not None for next round of evaluation. if not eval_result: eval_result = {} @@ -532,8 +555,8 @@ class Experiment(object): return False global_step = eval_result.get(ops.GraphKeys.GLOBAL_STEP) - return global_step and self._train_steps and ( - global_step >= self._train_steps) + return global_step and self._train_steps and (global_step >= + self._train_steps) def continuous_eval(self, delay_secs=None, @@ -652,8 +675,7 @@ class Experiment(object): return eval_result, export_results @experimental - def continuous_train_and_eval(self, - continuous_eval_predicate_fn=None): + def continuous_train_and_eval(self, continuous_eval_predicate_fn=None): """Interleaves training and evaluation. The frequency of evaluation is controlled by the `train_steps_per_iteration` @@ -682,11 +704,19 @@ class Experiment(object): Args: continuous_eval_predicate_fn: A predicate function determining whether to - continue after each iteration. `predicate_fn` takes the evaluation - results as its arguments. At the beginning of evaluation, the passed - eval results will be None so it's expected that the predicate function - handles that gracefully. 
When `predicate_fn` is not specified, this will - run in an infinite loop or exit when global_step reaches `train_steps`. + continue eval after each iteration. A `predicate_fn` has one of the + following signatures: + * (eval_results) -> boolean + * (eval_results, checkpoint_path) -> boolean + Where `eval_results` is the dictionary of metric evaluations and + checkpoint_path is the path to the checkpoint containing the parameters + on which that evaluation was based. + At the beginning of evaluation, the passed `eval_results` and + `checkpoint_path` will be None so it's expected that the predicate + function handles that gracefully. + When `predicate_fn` is not specified, continuous eval will run in an + infinite loop (if `train_steps` is None). or exit once global step + reaches `train_steps`. Returns: A tuple of the result of the `evaluate` call to the `Estimator` and the @@ -697,13 +727,18 @@ class Experiment(object): callable. """ - if (continuous_eval_predicate_fn is not None and - not callable(continuous_eval_predicate_fn)): - raise ValueError( - "`continuous_eval_predicate_fn` must be a callable, or None.") + if continuous_eval_predicate_fn is not None: + if not callable(continuous_eval_predicate_fn): + raise ValueError( + "`continuous_eval_predicate_fn` must be a callable, or None.") + predicate_fn = _get_standardized_predicate_fn( + continuous_eval_predicate_fn) + else: + predicate_fn = None - eval_result = None export_results = None + latest_checkpoint = None + eval_result = None # Set the default value for train_steps_per_iteration, which will be # overridden by other settings. 
@@ -713,8 +748,9 @@ class Experiment(object): elif self._train_steps is not None: train_steps_per_iteration = int(self._train_steps / 10) - while (not continuous_eval_predicate_fn or - continuous_eval_predicate_fn(eval_result)): + while (not predicate_fn or predicate_fn( + eval_result, checkpoint_path=latest_checkpoint + if eval_result else None)): if self._has_training_stopped(eval_result): # Exits once max steps of training is satisfied. @@ -729,11 +765,14 @@ class Experiment(object): saving_listeners=self._saving_listeners) logging.info("Evaluating model now.") - eval_result = self._call_evaluate(input_fn=self._eval_input_fn, - steps=self._eval_steps, - metrics=self._eval_metrics, - name="one_pass", - hooks=self._eval_hooks) + latest_checkpoint = saver.latest_checkpoint(self._estimator.model_dir) + eval_result = self._call_evaluate( + input_fn=self._eval_input_fn, + steps=self._eval_steps, + metrics=self._eval_metrics, + name="one_pass", + checkpoint_path=latest_checkpoint, + hooks=self._eval_hooks) export_results = self._maybe_export(eval_result) return eval_result, export_results @@ -741,8 +780,7 @@ class Experiment(object): def _maybe_export(self, eval_result, checkpoint_path=None): """Export the Estimator using export_fn, if defined.""" export_dir_base = os.path.join( - compat.as_bytes(self._estimator.model_dir), - compat.as_bytes("export")) + compat.as_bytes(self._estimator.model_dir), compat.as_bytes("export")) export_results = [] for strategy in self._export_strategies: @@ -780,10 +818,11 @@ class Experiment(object): hooks=self._train_monitors, saving_listeners=self._saving_listeners) - eval_result = self._call_evaluate(input_fn=self._eval_input_fn, - steps=1, - metrics=self._eval_metrics, - name="one_pass") + eval_result = self._call_evaluate( + input_fn=self._eval_input_fn, + steps=1, + metrics=self._eval_metrics, + name="one_pass") _ = self._maybe_export(eval_result) return eval_result @@ -805,9 +844,14 @@ class Experiment(object): server.start() 
return server - def _call_train(self, _sentinel=None, # pylint: disable=invalid-name, - input_fn=None, steps=None, hooks=None, max_steps=None, - saving_listeners=None): + def _call_train( + self, + _sentinel=None, # pylint: disable=invalid-name, + input_fn=None, + steps=None, + hooks=None, + max_steps=None, + saving_listeners=None): if _sentinel is not None: raise ValueError("_call_train should be called with keyword args only") @@ -823,14 +867,18 @@ class Experiment(object): hooks=hooks, saving_listeners=saving_listeners) else: - return self._estimator.fit(input_fn=input_fn, - steps=steps, - max_steps=max_steps, - monitors=hooks) - - def _call_evaluate(self, _sentinel=None, # pylint: disable=invalid-name, - input_fn=None, steps=None, metrics=None, name=None, - checkpoint_path=None, hooks=None): + return self._estimator.fit( + input_fn=input_fn, steps=steps, max_steps=max_steps, monitors=hooks) + + def _call_evaluate( + self, + _sentinel=None, # pylint: disable=invalid-name, + input_fn=None, + steps=None, + metrics=None, + name=None, + checkpoint_path=None, + hooks=None): if _sentinel is not None: raise ValueError("_call_evaluate should be called with keyword args only") @@ -838,18 +886,20 @@ class Experiment(object): if metrics is not None: raise ValueError( "`eval_metrics` must be `None` with `tf.estimator.Estimator`") - return self._estimator.evaluate(input_fn=input_fn, - steps=steps, - name=name, - checkpoint_path=checkpoint_path, - hooks=hooks) + return self._estimator.evaluate( + input_fn=input_fn, + steps=steps, + name=name, + checkpoint_path=checkpoint_path, + hooks=hooks) else: - return self._estimator.evaluate(input_fn=input_fn, - steps=steps, - metrics=metrics, - name=name, - checkpoint_path=checkpoint_path, - hooks=hooks) + return self._estimator.evaluate( + input_fn=input_fn, + steps=steps, + metrics=metrics, + name=name, + checkpoint_path=checkpoint_path, + hooks=hooks) @contextlib.contextmanager diff --git 
a/tensorflow/contrib/learn/python/learn/experiment_test.py b/tensorflow/contrib/learn/python/learn/experiment_test.py index c29c198d094090a59c8c7dd2949c3f069adf49d0..545d7d8924c0c10544e6113e2968b7ae3d2090fc 100644 --- a/tensorflow/contrib/learn/python/learn/experiment_test.py +++ b/tensorflow/contrib/learn/python/learn/experiment_test.py @@ -492,6 +492,33 @@ class ExperimentTest(test.TestCase): self.assertEqual(3, est.eval_count) self.assertEqual([noop_hook], est.eval_hooks) + def test_continuous_eval_predicate_fn_with_checkpoint(self): + for est in self._estimators_for_tests(): + eval_metrics = 'eval_metrics' if not isinstance( + est, core_estimator.Estimator) else None + est.fake_checkpoint() + noop_hook = _NoopHook() + + def _predicate_fn(eval_result, checkpoint_path): + self.assertEqual(not eval_result, + checkpoint_path is None) + return est.eval_count < 3 # pylint: disable=cell-var-from-loop + + ex = experiment.Experiment( + est, + train_input_fn='train_input', + eval_input_fn='eval_input', + eval_metrics=eval_metrics, + eval_hooks=[noop_hook], + eval_delay_secs=0, + continuous_eval_throttle_secs=0) + ex.continuous_eval( + evaluate_checkpoint_only_once=False, + continuous_eval_predicate_fn=_predicate_fn) + self.assertEqual(0, est.fit_count) + self.assertEqual(3, est.eval_count) + self.assertEqual([noop_hook], est.eval_hooks) + def test_run_local(self): for est in self._estimators_for_tests(): eval_metrics = 'eval_metrics' if not isinstance( diff --git a/tensorflow/contrib/learn/python/learn/export_strategy.py b/tensorflow/contrib/learn/python/learn/export_strategy.py index f276aab0e6beb011a21c20fa194dd5212db796d1..55a8b824312b89e0ac66513242191f4201ac212a 100644 --- a/tensorflow/contrib/learn/python/learn/export_strategy.py +++ b/tensorflow/contrib/learn/python/learn/export_strategy.py @@ -26,13 +26,14 @@ __all__ = ['ExportStrategy'] class ExportStrategy( - collections.namedtuple('ExportStrategy', ['name', 'export_fn'])): + 
collections.namedtuple('ExportStrategy', + ['name', 'export_fn', 'strip_default_attrs'])): """A class representing a type of model export. Typically constructed by a utility function specific to the exporter, such as `saved_model_export_utils.make_export_strategy()`. - The fields are: + Attributes: name: The directory name under the export base directory where exports of this type will be written. export_fn: A function that writes an export, given an estimator, a @@ -45,11 +46,20 @@ class ExportStrategy( The signature of this function must be one of: - * `(estimator, export_path) -> export_path` - * `(estimator, export_path, checkpoint_path) -> export_path` - * `(estimator, export_path, checkpoint_path, eval_result) -> export_path` + * `(estimator, export_path) -> export_path` + * `(estimator, export_path, checkpoint_path) -> export_path` + * `(estimator, export_path, checkpoint_path, eval_result) -> export_path` + * `(estimator, export_path, checkpoint_path, eval_result, + strip_default_attrs) -> export_path` + strip_default_attrs: (Optional) Boolean. If set as True, default attrs in + the `GraphDef` will be stripped on write. This is recommended for better + forward compatibility of the resulting `SavedModel`. 
""" + def __new__(cls, name, export_fn, strip_default_attrs=None): + return super(ExportStrategy, cls).__new__( + cls, name, export_fn, strip_default_attrs) + def export(self, estimator, export_path, @@ -83,5 +93,6 @@ class ExportStrategy( raise ValueError('An export_fn accepting eval_result must also accept ' 'checkpoint_path.') kwargs['eval_result'] = eval_result - + if 'strip_default_attrs' in export_fn_args: + kwargs['strip_default_attrs'] = self.strip_default_attrs return self.export_fn(estimator, export_path, **kwargs) diff --git a/tensorflow/contrib/learn/python/learn/export_strategy_test.py b/tensorflow/contrib/learn/python/learn/export_strategy_test.py new file mode 100644 index 0000000000000000000000000000000000000000..43c3551cccc3b8e6b66bd2b36839a3dfc5fe8eea --- /dev/null +++ b/tensorflow/contrib/learn/python/learn/export_strategy_test.py @@ -0,0 +1,89 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for ExportStrategy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.learn.python.learn import export_strategy +from tensorflow.python.platform import test + + +class ExportStrategyTest(test.TestCase): + + def test_no_optional_args_export(self): + model_path = '/path/to/model' + def _export_fn(estimator, export_path): + self.assertTupleEqual((estimator, export_path), (None, None)) + return model_path + + strategy = export_strategy.ExportStrategy('foo', _export_fn) + self.assertTupleEqual(strategy, ('foo', _export_fn, None)) + self.assertIs(strategy.export(None, None), model_path) + + def test_checkpoint_export(self): + ckpt_model_path = '/path/to/checkpoint_model' + def _ckpt_export_fn(estimator, export_path, checkpoint_path): + self.assertTupleEqual((estimator, export_path), (None, None)) + self.assertEqual(checkpoint_path, 'checkpoint') + return ckpt_model_path + + strategy = export_strategy.ExportStrategy('foo', _ckpt_export_fn) + self.assertTupleEqual(strategy, ('foo', _ckpt_export_fn, None)) + self.assertIs(strategy.export(None, None, 'checkpoint'), ckpt_model_path) + + def test_checkpoint_eval_export(self): + ckpt_eval_model_path = '/path/to/checkpoint_eval_model' + def _ckpt_eval_export_fn(estimator, export_path, checkpoint_path, + eval_result): + self.assertTupleEqual((estimator, export_path), (None, None)) + self.assertEqual(checkpoint_path, 'checkpoint') + self.assertEqual(eval_result, 'eval') + return ckpt_eval_model_path + + strategy = export_strategy.ExportStrategy('foo', _ckpt_eval_export_fn) + self.assertTupleEqual(strategy, ('foo', _ckpt_eval_export_fn, None)) + self.assertIs(strategy.export(None, None, 'checkpoint', 'eval'), + ckpt_eval_model_path) + + def test_eval_only_export(self): + def _eval_export_fn(estimator, export_path, eval_result): + del estimator, 
export_path, eval_result + + strategy = export_strategy.ExportStrategy('foo', _eval_export_fn) + self.assertTupleEqual(strategy, ('foo', _eval_export_fn, None)) + with self.assertRaisesRegexp(ValueError, 'An export_fn accepting ' + 'eval_result must also accept ' + 'checkpoint_path'): + strategy.export(None, None, eval_result='eval') + + def test_strip_default_attr_export(self): + strip_default_attrs_model_path = '/path/to/strip_default_attrs_model' + def _strip_default_attrs_export_fn(estimator, export_path, + strip_default_attrs): + self.assertTupleEqual((estimator, export_path), (None, None)) + self.assertTrue(strip_default_attrs) + return strip_default_attrs_model_path + + strategy = export_strategy.ExportStrategy('foo', + _strip_default_attrs_export_fn, + True) + self.assertTupleEqual(strategy, + ('foo', _strip_default_attrs_export_fn, True)) + self.assertIs(strategy.export(None, None), strip_default_attrs_model_path) + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py index 86fad4c5535a918d87e0741687cfebe3afaf9ddf..96be8b1bc402479d5611965f27abb197363cb939 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py @@ -35,6 +35,7 @@ from tensorflow.python.platform import tf_logging as logging # pylint: disable=g-multiple-import,g-bad-import-order from .pandas_io import HAS_PANDAS, extract_pandas_data, extract_pandas_matrix, extract_pandas_labels from .dask_io import HAS_DASK, extract_dask_data, extract_dask_labels + # pylint: enable=g-multiple-import,g-bad-import-order @@ -74,11 +75,11 @@ def _get_in_out_shape(x_shape, y_shape, n_classes, batch_size=None): if not y_is_dict: output_shape = out_el_shape(y_shape, n_classes) else: - output_shape = dict([ - (k, out_el_shape(v, n_classes[k] - if n_classes is not None and k in n_classes else None)) - for k, 
v in list(y_shape.items()) - ]) + output_shape = dict([(k, + out_el_shape(v, n_classes[k] + if n_classes is not None and + k in n_classes else None)) + for k, v in list(y_shape.items())]) return input_shape, output_shape, batch_size @@ -314,23 +315,23 @@ class DataFeeder(object): input_dtype: DType of input (or dictionary of shapes). output_dtype: DType of output (or dictionary of shapes. """ - x_is_dict, y_is_dict = isinstance(x, dict), y is not None and isinstance( - y, dict) + x_is_dict, y_is_dict = isinstance( + x, dict), y is not None and isinstance(y, dict) if isinstance(y, list): y = np.array(y) self._x = dict([(k, check_array(v, v.dtype)) for k, v in list(x.items()) ]) if x_is_dict else check_array(x, x.dtype) - self._y = None if y is None else ( - dict([(k, check_array(v, v.dtype)) for k, v in list(y.items())]) - if y_is_dict else check_array(y, y.dtype)) + self._y = None if y is None else (dict( + [(k, check_array(v, v.dtype)) for k, v in list(y.items())]) + if y_is_dict else check_array(y, y.dtype)) # self.n_classes is not None means we're converting raw target indices # to one-hot. if n_classes is not None: if not y_is_dict: - y_dtype = (np.int64 - if n_classes is not None and n_classes > 1 else np.float32) + y_dtype = ( + np.int64 if n_classes is not None and n_classes > 1 else np.float32) self._y = (None if y is None else check_array(y, dtype=y_dtype)) self.n_classes = n_classes @@ -352,8 +353,8 @@ class DataFeeder(object): # self._output_dtype == np.float32 when y is None self._output_dtype = ( dict([(k, _check_dtype(v.dtype)) for k, v in list(self._y.items())]) - if y_is_dict else ( - _check_dtype(self._y.dtype) if y is not None else np.float32)) + if y_is_dict else (_check_dtype(self._y.dtype) + if y is not None else np.float32)) # self.n_classes is None means we're passing in raw target indices if n_classes is not None and y_is_dict: @@ -478,8 +479,8 @@ class DataFeeder(object): # Assign input features from random indices. 
def extract(data, indices): - return (np.array(_access(data, indices)).reshape((indices.shape[0], 1)) if - len(data.shape) == 1 else _access(data, indices)) + return (np.array(_access(data, indices)).reshape((indices.shape[0], 1)) + if len(data.shape) == 1 else _access(data, indices)) # assign labels from random indices def assign_label(data, shape, dtype, n_classes, indices): @@ -511,16 +512,18 @@ class DataFeeder(object): feed_dict[self._epoch_placeholder.name] = [self.epoch] # Take next batch of indices. - x_len = list(self._x.values())[0].shape[ - 0] if x_is_dict else self._x.shape[0] + x_len = list( + self._x.values())[0].shape[0] if x_is_dict else self._x.shape[0] end = min(x_len, self.offset + self._batch_size) batch_indices = self.indices[self.offset:end] # adding input placeholder feed_dict.update( dict([(self._input_placeholder[k].name, extract(v, batch_indices)) - for k, v in list(self._x.items())]) if x_is_dict else - {self._input_placeholder.name: extract(self._x, batch_indices)}) + for k, v in list(self._x.items())]) if x_is_dict else { + self._input_placeholder.name: + extract(self._x, batch_indices) + }) # move offset and reset it if necessary self.offset += self._batch_size @@ -545,7 +548,8 @@ class DataFeeder(object): assign_label(v, shape, dtype, n_classes, batch_indices) }) else: - shape, dtype, n_classes = self.output_shape, self._output_dtype, self.n_classes + shape, dtype, n_classes = (self.output_shape, self._output_dtype, + self.n_classes) feed_dict.update({ self._output_placeholder.name: assign_label(self._y, shape, dtype, n_classes, batch_indices) @@ -621,8 +625,9 @@ class StreamingDataFeeder(DataFeeder): elif y is None: y_first_el_shape = None else: - y_first_el_shape = ([1] + list(y_first_el[0].shape if isinstance( - y_first_el, list) else y_first_el.shape)) + y_first_el_shape = ( + [1] + list(y_first_el[0].shape + if isinstance(y_first_el, list) else y_first_el.shape)) self.input_shape, self.output_shape, self._batch_size = 
_get_in_out_shape( x_first_el_shape, y_first_el_shape, n_classes, batch_size) @@ -683,8 +688,8 @@ class StreamingDataFeeder(DataFeeder): if shape is None: return None elif isinstance(shape, dict): - return dict([(k, np.zeros(shape[k], dtype[k])) - for k in list(shape.keys())]) + return dict( + [(k, np.zeros(shape[k], dtype[k])) for k in list(shape.keys())]) else: return np.zeros(shape, dtype=dtype) @@ -857,8 +862,8 @@ class DaskDataFeeder(object): """Returns a function, that will sample data and provide it to placeholders. Args: - input_placeholder: tf.Placeholder for input features mini batch. - output_placeholder: tf.Placeholder for output labels. + input_placeholder: tf.placeholder for input features mini batch. + output_placeholder: tf.placeholder for output labels. Returns: A function that when called samples a random subset of batch size diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py index 4b34fc62849766370979bb2002d42ee03ea7161a..3a46c239688017f9204d2c6182a6f81cd325a417 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py @@ -24,6 +24,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.layers import utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import io_ops @@ -280,14 +281,33 @@ def _get_file_names(file_pattern, randomize_input): def _get_examples(file_name_queue, reader, num_threads, read_batch_size, filter_fn, parse_fn): + """Get example filenames matching. + + Args: + file_name_queue: A queue implementation that dequeues elements in + first-in first-out order. 
+ reader: A function or class that returns an object with + `read` method, (filename tensor) -> (example tensor). + num_threads: The number of threads enqueuing examples. + read_batch_size: An int or scalar `Tensor` specifying the number of + records to read at once. + filter_fn: Filtering function, takes both keys as well as an `Example` + Tensors and returns a boolean mask of the same shape as the input Tensors + to be applied for filtering. If `None`, no filtering is done. + parse_fn: Parsing function, takes `Example` Tensor returns parsed + representation. If `None`, no parsing is done. + + Returns: + List of example file names matching `file_name_queue`. + """ with ops.name_scope('read'): example_list = [] for _ in range(num_threads): - if read_batch_size > 1: - keys, examples_proto = reader().read_up_to(file_name_queue, - read_batch_size) - else: - keys, examples_proto = reader().read(file_name_queue) + keys, examples_proto = utils.smart_cond( + read_batch_size > 1, + lambda: reader().read_up_to(file_name_queue, read_batch_size), + lambda: reader().read(file_name_queue)) + if filter_fn: mask = filter_fn(keys, examples_proto) keys = array_ops.boolean_mask(keys, mask) @@ -379,14 +399,15 @@ def _read_keyed_batch_examples_helper(file_pattern, capacity=1, dtypes=[dtypes.string], shapes=[[]]) enqueue_op = file_name_queue.enqueue( input_pipeline_ops.seek_next( - file_names, shuffle=randomize_input, num_epochs=num_epochs, + file_names, + shuffle=randomize_input, + num_epochs=num_epochs, seed=seed)) queue_runner.add_queue_runner( queue_runner.QueueRunner(file_name_queue, [enqueue_op])) else: file_name_queue = input_ops.string_input_producer( - constant_op.constant( - file_names, name='input'), + constant_op.constant(file_names, name='input'), shuffle=randomize_input, num_epochs=num_epochs, name=file_name_queue_scope, @@ -496,7 +517,8 @@ def read_keyed_batch_features(file_pattern, """ with ops.name_scope(name, 'read_batch_features', [file_pattern]) as scope: - if 
read_batch_size is None: read_batch_size = batch_size + if read_batch_size is None: + read_batch_size = batch_size keys, examples = read_keyed_batch_examples( file_pattern, batch_size, diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py index 6f0fd9a2976d37d1c701a96f50c2b987562cb191..e11e8b698adc113486bbb45572c8129e964cc931 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py @@ -204,8 +204,7 @@ class GraphIOTest(test.TestCase): shape = (0,) features = { "feature": - parsing_ops.FixedLenFeature( - shape=shape, dtype=dtypes_lib.float32) + parsing_ops.FixedLenFeature(shape=shape, dtype=dtypes_lib.float32) } with ops.Graph().as_default() as g, self.test_session(graph=g) as sess: @@ -255,8 +254,8 @@ class GraphIOTest(test.TestCase): self.assertAllEqual((None,), inputs.get_shape().as_list()) self.assertEqual("%s:1" % name, inputs.name) file_name_queue_name = "%s/file_name_queue" % name - file_name_queue_limit_name = ("%s/limit_epochs/epochs" % - file_name_queue_name) + file_name_queue_limit_name = ( + "%s/limit_epochs/epochs" % file_name_queue_name) file_names_name = "%s/input" % file_name_queue_name example_queue_name = "%s/random_shuffle_queue" % name op_nodes = test_util.assert_ops_in_graph({ @@ -354,8 +353,8 @@ class GraphIOTest(test.TestCase): json_lines = [ "".join([ '{"features": { "feature": { "sequence": {', - '"bytes_list": { "value": ["', base64.b64encode(l).decode("ascii"), - '"]}}}}}\n' + '"bytes_list": { "value": ["', + base64.b64encode(l).decode("ascii"), '"]}}}}}\n' ]) for l in lines ] return self._create_temp_file("".join(json_lines)) @@ -823,6 +822,31 @@ class GraphIOTest(test.TestCase): coord.request_stop() coord.join(threads) + def test_read_keyed_batch_features_shared_queue(self): + batch_size = 17 + shape = (0,) + fixed_feature = parsing_ops.FixedLenFeature( + 
shape=shape, dtype=dtypes_lib.float32) + feature = {"feature": fixed_feature} + reader = io_ops.TFRecordReader + + _, queued_feature = graph_io.read_keyed_batch_features_shared_queue( + _VALID_FILE_PATTERN, batch_size, feature, reader) + + with ops.Graph().as_default() as g, self.test_session(graph=g) as session: + features_result = graph_io.read_batch_features( + _VALID_FILE_PATTERN, batch_size, feature, reader) + session.run(variables.local_variables_initializer()) + + self.assertAllEqual( + queued_feature.get("feature").get_shape().as_list(), + features_result.get("feature").get_shape().as_list()) + + def test_get_file_names_errors(self): + # Raise bad file_pattern. + with self.assertRaises(ValueError): + graph_io._get_file_names([], True) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/learn/python/learn/learn_io/numpy_io_test.py b/tensorflow/contrib/learn/python/learn/learn_io/numpy_io_test.py deleted file mode 100644 index 6fe8de8705b8854e5861879d2a505fe03fddc7e5..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/learn/python/learn/learn_io/numpy_io_test.py +++ /dev/null @@ -1,280 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== -"""Tests for numpy_io.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.learn.python.learn.learn_io import numpy_io -from tensorflow.python.framework import errors -from tensorflow.python.platform import test -from tensorflow.python.training import coordinator -from tensorflow.python.training import queue_runner_impl - - -class NumpyIoTest(test.TestCase): - - def testNumpyInputFn(self): - a = np.arange(4) * 1.0 - b = np.arange(32, 36) - x = {'a': a, 'b': b} - y = np.arange(-32, -28) - - with self.test_session() as session: - input_fn = numpy_io.numpy_input_fn( - x, y, batch_size=2, shuffle=False, num_epochs=1) - features, target = input_fn() - - coord = coordinator.Coordinator() - threads = queue_runner_impl.start_queue_runners(session, coord=coord) - - res = session.run([features, target]) - self.assertAllEqual(res[0]['a'], [0, 1]) - self.assertAllEqual(res[0]['b'], [32, 33]) - self.assertAllEqual(res[1], [-32, -31]) - - session.run([features, target]) - with self.assertRaises(errors.OutOfRangeError): - session.run([features, target]) - - coord.request_stop() - coord.join(threads) - - def testNumpyInputFnWithVeryLargeBatchSizeAndMultipleEpochs(self): - a = np.arange(2) * 1.0 - b = np.arange(32, 34) - x = {'a': a, 'b': b} - y = np.arange(-32, -30) - - with self.test_session() as session: - input_fn = numpy_io.numpy_input_fn( - x, y, batch_size=128, shuffle=False, num_epochs=2) - features, target = input_fn() - - coord = coordinator.Coordinator() - threads = queue_runner_impl.start_queue_runners(session, coord=coord) - - res = session.run([features, target]) - self.assertAllEqual(res[0]['a'], [0, 1, 0, 1]) - self.assertAllEqual(res[0]['b'], [32, 33, 32, 33]) - self.assertAllEqual(res[1], [-32, -31, -32, -31]) - - with self.assertRaises(errors.OutOfRangeError): - 
session.run([features, target]) - - coord.request_stop() - coord.join(threads) - - def testNumpyInputFnWithZeroEpochs(self): - a = np.arange(4) * 1.0 - b = np.arange(32, 36) - x = {'a': a, 'b': b} - y = np.arange(-32, -28) - - with self.test_session() as session: - input_fn = numpy_io.numpy_input_fn( - x, y, batch_size=2, shuffle=False, num_epochs=0) - features, target = input_fn() - - coord = coordinator.Coordinator() - threads = queue_runner_impl.start_queue_runners(session, coord=coord) - - with self.assertRaises(errors.OutOfRangeError): - session.run([features, target]) - - coord.request_stop() - coord.join(threads) - - def testNumpyInputFnWithBatchSizeNotDividedByDataSize(self): - batch_size = 2 - a = np.arange(5) * 1.0 - b = np.arange(32, 37) - x = {'a': a, 'b': b} - y = np.arange(-32, -27) - - with self.test_session() as session: - input_fn = numpy_io.numpy_input_fn( - x, y, batch_size=batch_size, shuffle=False, num_epochs=1) - features, target = input_fn() - - coord = coordinator.Coordinator() - threads = queue_runner_impl.start_queue_runners(session, coord=coord) - - res = session.run([features, target]) - self.assertAllEqual(res[0]['a'], [0, 1]) - self.assertAllEqual(res[0]['b'], [32, 33]) - self.assertAllEqual(res[1], [-32, -31]) - - res = session.run([features, target]) - self.assertAllEqual(res[0]['a'], [2, 3]) - self.assertAllEqual(res[0]['b'], [34, 35]) - self.assertAllEqual(res[1], [-30, -29]) - - res = session.run([features, target]) - self.assertAllEqual(res[0]['a'], [4]) - self.assertAllEqual(res[0]['b'], [36]) - self.assertAllEqual(res[1], [-28]) - - with self.assertRaises(errors.OutOfRangeError): - session.run([features, target]) - - coord.request_stop() - coord.join(threads) - - def testNumpyInputFnWithBatchSizeNotDividedByDataSizeAndMultipleEpochs(self): - batch_size = 2 - a = np.arange(3) * 1.0 - b = np.arange(32, 35) - x = {'a': a, 'b': b} - y = np.arange(-32, -29) - - with self.test_session() as session: - input_fn = 
numpy_io.numpy_input_fn( - x, y, batch_size=batch_size, shuffle=False, num_epochs=3) - features, target = input_fn() - - coord = coordinator.Coordinator() - threads = queue_runner_impl.start_queue_runners(session, coord=coord) - - res = session.run([features, target]) - self.assertAllEqual(res[0]['a'], [0, 1]) - self.assertAllEqual(res[0]['b'], [32, 33]) - self.assertAllEqual(res[1], [-32, -31]) - - res = session.run([features, target]) - self.assertAllEqual(res[0]['a'], [2, 0]) - self.assertAllEqual(res[0]['b'], [34, 32]) - self.assertAllEqual(res[1], [-30, -32]) - - res = session.run([features, target]) - self.assertAllEqual(res[0]['a'], [1, 2]) - self.assertAllEqual(res[0]['b'], [33, 34]) - self.assertAllEqual(res[1], [-31, -30]) - - res = session.run([features, target]) - self.assertAllEqual(res[0]['a'], [0, 1]) - self.assertAllEqual(res[0]['b'], [32, 33]) - self.assertAllEqual(res[1], [-32, -31]) - - res = session.run([features, target]) - self.assertAllEqual(res[0]['a'], [2]) - self.assertAllEqual(res[0]['b'], [34]) - self.assertAllEqual(res[1], [-30]) - - with self.assertRaises(errors.OutOfRangeError): - session.run([features, target]) - - coord.request_stop() - coord.join(threads) - - def testNumpyInputFnWithBatchSizeLargerThanDataSize(self): - batch_size = 10 - a = np.arange(4) * 1.0 - b = np.arange(32, 36) - x = {'a': a, 'b': b} - y = np.arange(-32, -28) - - with self.test_session() as session: - input_fn = numpy_io.numpy_input_fn( - x, y, batch_size=batch_size, shuffle=False, num_epochs=1) - features, target = input_fn() - - coord = coordinator.Coordinator() - threads = queue_runner_impl.start_queue_runners(session, coord=coord) - - res = session.run([features, target]) - self.assertAllEqual(res[0]['a'], [0, 1, 2, 3]) - self.assertAllEqual(res[0]['b'], [32, 33, 34, 35]) - self.assertAllEqual(res[1], [-32, -31, -30, -29]) - - with self.assertRaises(errors.OutOfRangeError): - session.run([features, target]) - - coord.request_stop() - coord.join(threads) - 
- def testNumpyInputFnWithDifferentDimensionsOfFeatures(self): - a = np.array([[1, 2], [3, 4]]) - b = np.array([5, 6]) - x = {'a': a, 'b': b} - y = np.arange(-32, -30) - - with self.test_session() as session: - input_fn = numpy_io.numpy_input_fn( - x, y, batch_size=2, shuffle=False, num_epochs=1) - features, target = input_fn() - - coord = coordinator.Coordinator() - threads = queue_runner_impl.start_queue_runners(session, coord=coord) - - res = session.run([features, target]) - self.assertAllEqual(res[0]['a'], [[1, 2], [3, 4]]) - self.assertAllEqual(res[0]['b'], [5, 6]) - self.assertAllEqual(res[1], [-32, -31]) - - coord.request_stop() - coord.join(threads) - - def testNumpyInputFnWithXAsNonDict(self): - x = np.arange(32, 36) - y = np.arange(4) - with self.test_session(): - with self.assertRaisesRegexp(TypeError, 'x must be dict'): - failing_input_fn = numpy_io.numpy_input_fn( - x, y, batch_size=2, shuffle=False, num_epochs=1) - failing_input_fn() - - def testNumpyInputFnWithTargetKeyAlreadyInX(self): - array = np.arange(32, 36) - x = {'__target_key__': array} - y = np.arange(4) - - with self.test_session(): - input_fn = numpy_io.numpy_input_fn( - x, y, batch_size=2, shuffle=False, num_epochs=1) - input_fn() - self.assertAllEqual(x['__target_key__'], array) - self.assertItemsEqual(x.keys(), ['__target_key__']) - - def testNumpyInputFnWithMismatchLengthOfInputs(self): - a = np.arange(4) * 1.0 - b = np.arange(32, 36) - x = {'a': a, 'b': b} - x_mismatch_length = {'a': np.arange(1), 'b': b} - y_longer_length = np.arange(10) - - with self.test_session(): - with self.assertRaisesRegexp( - ValueError, 'Length of tensors in x and y is mismatched.'): - failing_input_fn = numpy_io.numpy_input_fn( - x, y_longer_length, batch_size=2, shuffle=False, num_epochs=1) - failing_input_fn() - - with self.assertRaisesRegexp( - ValueError, 'Length of tensors in x and y is mismatched.'): - failing_input_fn = numpy_io.numpy_input_fn( - x=x_mismatch_length, - y=None, - batch_size=2, - 
shuffle=False, - num_epochs=1) - failing_input_fn() - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/learn/python/learn/metric_spec.py b/tensorflow/contrib/learn/python/learn/metric_spec.py index ed6683abedbb8ae76ba364405158eb52cbb6d762..6440bc204b8e339ff51311dcc87b36f556b94092 100644 --- a/tensorflow/contrib/learn/python/learn/metric_spec.py +++ b/tensorflow/contrib/learn/python/learn/metric_spec.py @@ -42,10 +42,8 @@ def _args(fn): """ if hasattr(fn, 'func') and hasattr(fn, 'keywords'): # Handle functools.partial and similar objects. - return tuple([ - arg for arg in tf_inspect.getargspec(fn.func).args - if arg not in set(fn.keywords.keys()) - ]) + return tuple( + [arg for arg in _args(fn.func) if arg not in set(fn.keywords.keys())]) # Handle function. return tuple(tf_inspect.getargspec(fn).args) diff --git a/tensorflow/contrib/learn/python/learn/monitors.py b/tensorflow/contrib/learn/python/learn/monitors.py index 3e0b1ad21a9a4a08fa94c8e9796f2b0dd5f8d622..51381a7427c919592b8e818c4b46dba974992610 100644 --- a/tensorflow/contrib/learn/python/learn/monitors.py +++ b/tensorflow/contrib/learn/python/learn/monitors.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Monitors instrument the training process. @@get_default_monitors @@ -151,8 +150,8 @@ class BaseMonitor(object): ValueError: if we've not begun an epoch, or `epoch` number does not match. """ if self._current_epoch != epoch: - raise ValueError( - "epoch_end expected %s but got %s.", self._current_epoch, epoch) + raise ValueError("epoch_end expected %s but got %s.", self._current_epoch, + epoch) self._current_epoch = None def step_begin(self, step): @@ -171,8 +170,8 @@ class BaseMonitor(object): ValueError: if we've already begun a step, or `step` < 0, or `step` > `max_steps`. 
""" - if (step < 0) or ( - (self._max_steps is not None) and (step > self._max_steps)): + if (step < 0) or ((self._max_steps is not None) and + (step > self._max_steps)): raise ValueError("Invalid step %s." % step) self._current_step = step return [] @@ -203,8 +202,8 @@ class BaseMonitor(object): ValueError: if we've not begun a step, or `step` number does not match. """ if self._current_step != step: - raise ValueError( - "step_end expected %s but got %s.", self._current_step, step) + raise ValueError("step_end expected %s but got %s.", self._current_step, + step) self._current_step = None return False @@ -253,6 +252,7 @@ class EveryN(BaseMonitor): treatment. """ + # TODO(ipolosukhin): Add also every n seconds. def __init__(self, every_n_steps=100, first_n_steps=1): @@ -475,8 +475,8 @@ class LoggingTrainable(EveryN): super(LoggingTrainable, self).every_n_step_begin(step) # Get a list of trainable variables at the beginning of every N steps. # We cannot get this in __init__ because train_op has not been generated. - trainables = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES, - scope=self._scope) + trainables = ops.get_collection( + ops.GraphKeys.TRAINABLE_VARIABLES, scope=self._scope) self._names = {} for var in trainables: self._names[var.name] = var.value().name @@ -561,12 +561,19 @@ class ValidationMonitor(EveryN): provided. """ - def __init__(self, x=None, y=None, input_fn=None, batch_size=None, + def __init__(self, + x=None, + y=None, + input_fn=None, + batch_size=None, eval_steps=None, - every_n_steps=100, metrics=None, hooks=None, + every_n_steps=100, + metrics=None, + hooks=None, early_stopping_rounds=None, early_stopping_metric="loss", - early_stopping_metric_minimize=True, name=None): + early_stopping_metric_minimize=True, + name=None): """Initializes a ValidationMonitor. Args: @@ -597,8 +604,8 @@ class ValidationMonitor(EveryN): Raises: ValueError: If both x and input_fn are provided. 
""" - super(ValidationMonitor, self).__init__(every_n_steps=every_n_steps, - first_n_steps=-1) + super(ValidationMonitor, self).__init__( + every_n_steps=every_n_steps, first_n_steps=-1) # TODO(mdan): Checks like this are already done by evaluate. if x is None and input_fn is None: raise ValueError("Either x or input_fn should be provided.") @@ -654,20 +661,27 @@ class ValidationMonitor(EveryN): def _evaluate_estimator(self): if isinstance(self._estimator, core_estimator.Estimator): - if any((x is not None for x in - [self.x, self.y, self.batch_size, self.metrics])): + if any((x is not None + for x in [self.x, self.y, self.batch_size, self.metrics])): raise ValueError( "tf.estimator.Estimator does not support following " "arguments: x, y, batch_size, metrics. Should set as `None` " "in ValidationMonitor") return self._estimator.evaluate( - input_fn=self.input_fn, steps=self.eval_steps, hooks=self.hooks, + input_fn=self.input_fn, + steps=self.eval_steps, + hooks=self.hooks, name=self.name) else: return self._estimator.evaluate( - x=self.x, y=self.y, input_fn=self.input_fn, - batch_size=self.batch_size, steps=self.eval_steps, - metrics=self.metrics, hooks=self.hooks, name=self.name) + x=self.x, + y=self.y, + input_fn=self.input_fn, + batch_size=self.batch_size, + steps=self.eval_steps, + metrics=self.metrics, + hooks=self.hooks, + name=self.name) def every_n_step_end(self, step, outputs): super(ValidationMonitor, self).every_n_step_end(step, outputs) @@ -700,8 +714,9 @@ class ValidationMonitor(EveryN): # Early stopping logic. if self.early_stopping_rounds is not None: if self.early_stopping_metric not in validation_outputs: - raise ValueError("Metric %s missing from outputs %s." % ( - self.early_stopping_metric, set(validation_outputs.keys()))) + raise ValueError("Metric %s missing from outputs %s." 
% + (self.early_stopping_metric, + set(validation_outputs.keys()))) current_value = validation_outputs[self.early_stopping_metric] if (self._best_value is None or (self.early_stopping_metric_minimize and (current_value < self._best_value)) or @@ -712,9 +727,9 @@ class ValidationMonitor(EveryN): self._best_value_step = step stop_now = (step - self._best_value_step >= self.early_stopping_rounds) if stop_now: - logging.info("Stopping. Best step: {} with {} = {}." - .format(self._best_value_step, - self.early_stopping_metric, self._best_value)) + logging.info("Stopping. Best step: {} with {} = {}.".format( + self._best_value_step, self.early_stopping_metric, + self._best_value)) self._early_stopped = True return True return False @@ -763,8 +778,11 @@ class CaptureVariable(EveryN): self._var_values[step] = _extract_output(outputs, self._var_name) -def get_default_monitors(loss_op=None, summary_op=None, save_summary_steps=100, - output_dir=None, summary_writer=None): +def get_default_monitors(loss_op=None, + summary_op=None, + save_summary_steps=100, + output_dir=None, + summary_writer=None): """Returns a default set of typically-used monitors. Args: @@ -782,9 +800,12 @@ def get_default_monitors(loss_op=None, summary_op=None, save_summary_steps=100, if loss_op is not None: monitors.append(PrintTensor(tensor_names={"loss": loss_op.name})) if summary_op is not None: - monitors.append(SummarySaver(summary_op, save_steps=save_summary_steps, - output_dir=output_dir, - summary_writer=summary_writer)) + monitors.append( + SummarySaver( + summary_op, + save_steps=save_summary_steps, + output_dir=output_dir, + summary_writer=summary_writer)) return monitors @@ -794,8 +815,10 @@ class GraphDump(BaseMonitor): Note, this is very expensive, prefer `PrintTensor` in production. 
""" - IGNORE_OPS = ["Const", "Assign", "Identity", "Placeholder", - "RandomUniform", "Cast", "RestoreSlice"] + IGNORE_OPS = [ + "Const", "Assign", "Identity", "Placeholder", "RandomUniform", "Cast", + "RestoreSlice" + ] def __init__(self, ignore_ops=None): """Initializes GraphDump monitor. @@ -856,7 +879,7 @@ class GraphDump(BaseMonitor): this_output = self.data[step] if step in self.data else {} other_output = other_dump.data[step] if step in other_dump.data else {} for key in this_output: - if not isinstance(key, str) and not isinstance(key, unicode): + if not isinstance(key, six.string_types): continue if key not in other_output: raise ValueError("%s missing at step %s.", (key, step)) @@ -881,8 +904,8 @@ class ExportMonitor(EveryN): """Monitor that exports Estimator every N steps.""" @deprecation.deprecated("2017-03-25", - "ExportMonitor is deprecated. Please pass an " - "ExportStrategy to Experiment instead.") + "ExportMonitor is deprecated. Please pass an " + "ExportStrategy to Experiment instead.") def __init__(self, every_n_steps, export_dir, @@ -1088,8 +1111,7 @@ class CheckpointSaver(BaseMonitor): class StepCounter(EveryN): """Steps per second monitor.""" - def __init__(self, every_n_steps=100, output_dir=None, - summary_writer=None): + def __init__(self, every_n_steps=100, output_dir=None, summary_writer=None): super(StepCounter, self).__init__(every_n_steps=every_n_steps) self._summary_tag = "global_step/sec" self._last_reported_step = None @@ -1101,7 +1123,8 @@ class StepCounter(EveryN): def set_estimator(self, estimator): super(StepCounter, self).set_estimator(estimator) if self._summary_writer is None: - self._summary_writer = core_summary.FileWriterCache.get(estimator.model_dir) + self._summary_writer = core_summary.FileWriterCache.get( + estimator.model_dir) def every_n_step_end(self, current_step, outputs): current_time = time.time() @@ -1109,8 +1132,9 @@ class StepCounter(EveryN): added_steps = current_step - self._last_reported_step elapsed_time 
= current_time - self._last_reported_time steps_per_sec = added_steps / elapsed_time - summary = Summary(value=[Summary.Value(tag=self._summary_tag, - simple_value=steps_per_sec)]) + summary = Summary(value=[ + Summary.Value(tag=self._summary_tag, simple_value=steps_per_sec) + ]) self._summary_writer.add_summary(summary, current_step) self._last_reported_step = current_step self._last_reported_time = current_time diff --git a/tensorflow/contrib/learn/python/learn/ops/ops_test.py b/tensorflow/contrib/learn/python/learn/ops/ops_test.py index d0b9eb8abcbee187b6c53b7b419882f0a1e7da51..80d4923db37feb2a1304218f501ab51f9e0d9a14 100644 --- a/tensorflow/contrib/learn/python/learn/ops/ops_test.py +++ b/tensorflow/contrib/learn/python/learn/ops/ops_test.py @@ -20,7 +20,6 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib.layers import conv2d from tensorflow.contrib.learn.python.learn import ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes diff --git a/tensorflow/contrib/learn/python/learn/trainable.py b/tensorflow/contrib/learn/python/learn/trainable.py index 972fec026f25d39dca75e8c5bafffb57fcd323fa..429b6040be21d8cbe1f2bba58090366552fdfbe7 100644 --- a/tensorflow/contrib/learn/python/learn/trainable.py +++ b/tensorflow/contrib/learn/python/learn/trainable.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== - """`Trainable` interface.""" from __future__ import absolute_import @@ -28,18 +27,31 @@ class Trainable(object): __metaclass__ = abc.ABCMeta @abc.abstractmethod - def fit(self, x=None, y=None, input_fn=None, steps=None, batch_size=None, - monitors=None, max_steps=None): + def fit(self, + x=None, + y=None, + input_fn=None, + steps=None, + batch_size=None, + monitors=None, + max_steps=None): """Trains a model given training data `x` predictions and `y` labels. Args: - x: Matrix of shape [n_samples, n_features...] or the dictionary of Matrices. - Can be iterator that returns arrays of features or dictionary of arrays of features. - The training input samples for fitting the model. If set, `input_fn` must be `None`. - y: Vector or matrix [n_samples] or [n_samples, n_outputs] or the dictionary of same. - Can be iterator that returns array of labels or dictionary of array of labels. - The training label values (class labels in classification, real numbers in regression). - If set, `input_fn` must be `None`. Note: For classification, label values must + x: Matrix of shape [n_samples, n_features...] or the dictionary of + Matrices. + Can be iterator that returns arrays of features or dictionary of arrays + of features. + The training input samples for fitting the model. If set, `input_fn` + must be `None`. + y: Vector or matrix [n_samples] or [n_samples, n_outputs] or the + dictionary of same. + Can be iterator that returns array of labels or dictionary of array of + labels. + The training label values (class labels in classification, real numbers + in regression). + If set, `input_fn` must be `None`. Note: For classification, label + values must be integers representing the class index (i.e. values from 0 to n_classes-1). 
input_fn: Input function returning a tuple of: diff --git a/tensorflow/contrib/learn/python/learn/utils/export.py b/tensorflow/contrib/learn/python/learn/utils/export.py index 6af2287761299f6725f9547917101c18b0cc0164..cb34cb1d26b6812c7f3f39e9f965615de5a8ef07 100644 --- a/tensorflow/contrib/learn/python/learn/utils/export.py +++ b/tensorflow/contrib/learn/python/learn/utils/export.py @@ -20,7 +20,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.framework import deprecated -from tensorflow.contrib.framework.python.ops import variables as contrib_variables +from tensorflow.python.training import training_util from tensorflow.contrib.session_bundle import exporter from tensorflow.contrib.session_bundle import gc from tensorflow.python.client import session as tf_session @@ -78,7 +78,7 @@ def _export_graph(graph, saver, checkpoint_path, export_dir, default_graph_signature=default_graph_signature, named_graph_signatures=named_graph_signatures, assets_collection=ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)) - return export.export(export_dir, contrib_variables.get_global_step(), + return export.export(export_dir, training_util.get_global_step(), session, exports_to_keep=exports_to_keep) @@ -295,7 +295,7 @@ def _export_estimator(estimator, checkpoint_path = (checkpoint_path or tf_saver.latest_checkpoint(estimator._model_dir)) with ops.Graph().as_default() as g: - contrib_variables.create_global_step(g) + training_util.create_global_step(g) if use_deprecated_input_fn: examples = array_ops.placeholder(dtype=dtypes.string, diff --git a/tensorflow/contrib/learn/python/learn/utils/export_test.py b/tensorflow/contrib/learn/python/learn/utils/export_test.py index 95070ada3b9d3ccb00009bd9b885e8163d7fbed4..9bfb1fc952c07bd6c09d1f1074e8dc5539dc0529 100644 --- a/tensorflow/contrib/learn/python/learn/utils/export_test.py +++ b/tensorflow/contrib/learn/python/learn/utils/export_test.py @@ -50,6 +50,7 @@ def _training_input_fn(): 
class ExportTest(test.TestCase): + def _get_default_signature(self, export_meta_filename): """ Gets the default signature from the export.meta file. """ with session.Session(): @@ -69,18 +70,18 @@ class ExportTest(test.TestCase): # Only the written checkpoints are exported. self.assertTrue( saver.checkpoint_exists(os.path.join(export_dir, '00000001', 'export')), - 'Exported checkpoint expected but not found: %s' % - os.path.join(export_dir, '00000001', 'export')) + 'Exported checkpoint expected but not found: %s' % os.path.join( + export_dir, '00000001', 'export')) self.assertTrue( saver.checkpoint_exists(os.path.join(export_dir, '00000010', 'export')), - 'Exported checkpoint expected but not found: %s' % - os.path.join(export_dir, '00000010', 'export')) + 'Exported checkpoint expected but not found: %s' % os.path.join( + export_dir, '00000010', 'export')) self.assertEquals( six.b(os.path.join(export_dir, '00000010')), export_monitor.last_export_dir) # Validate the signature signature = self._get_default_signature( - os.path.join(export_dir, '00000010', 'export.meta')) + os.path.join(export_dir, '00000010', 'export.meta')) self.assertTrue(signature.HasField(expected_signature)) def testExportMonitor_EstimatorProvidesSignature(self): @@ -116,8 +117,7 @@ class ExportTest(test.TestCase): def _serving_input_fn(): return { _X_KEY: - random_ops.random_uniform( - shape=(1,), minval=0.0, maxval=1000.0) + random_ops.random_uniform(shape=(1,), minval=0.0, maxval=1000.0) }, None input_feature_key = 'my_example_key' @@ -160,8 +160,7 @@ class ExportTest(test.TestCase): input_feature_key: None, _X_KEY: - random_ops.random_uniform( - shape=(1,), minval=0.0, maxval=1000.0) + random_ops.random_uniform(shape=(1,), minval=0.0, maxval=1000.0) }, None monitor = learn.monitors.ExportMonitor( @@ -182,8 +181,7 @@ class ExportTest(test.TestCase): def _serving_input_fn(): return { input_feature_key: - array_ops.placeholder( - dtype=dtypes.string, shape=(1,)) + 
array_ops.placeholder(dtype=dtypes.string, shape=(1,)) }, None monitor = learn.monitors.ExportMonitor( @@ -204,11 +202,9 @@ class ExportTest(test.TestCase): def _serving_input_fn(): return { input_feature_key: - array_ops.placeholder( - dtype=dtypes.string, shape=(1,)), + array_ops.placeholder(dtype=dtypes.string, shape=(1,)), _X_KEY: - random_ops.random_uniform( - shape=(1,), minval=0.0, maxval=1000.0) + random_ops.random_uniform(shape=(1,), minval=0.0, maxval=1000.0) }, None export_dir = os.path.join(tempfile.mkdtemp(), 'export') @@ -227,8 +223,8 @@ class ExportTest(test.TestCase): def _regression_signature(examples, unused_features, predictions): signatures = {} - signatures['regression'] = (exporter.regression_signature(examples, - predictions)) + signatures['regression'] = ( + exporter.regression_signature(examples, predictions)) return signatures['regression'], signatures random.seed(42) @@ -248,10 +244,10 @@ class ExportTest(test.TestCase): with self.assertRaises(errors.NotFoundError): saver.checkpoint_exists(os.path.join(export_dir, '00000000', 'export')) self.assertTrue( - saver.checkpoint_exists(os.path.join(export_dir, '00000010', 'export'))) + saver.checkpoint_exists(os.path.join(export_dir, '00000010', 'export'))) # Validate the signature signature = self._get_default_signature( - os.path.join(export_dir, '00000010', 'export.meta')) + os.path.join(export_dir, '00000010', 'export.meta')) self.assertTrue(signature.HasField('regression_signature')) diff --git a/tensorflow/contrib/learn/python/learn/utils/gc_test.py b/tensorflow/contrib/learn/python/learn/utils/gc_test.py index 76cfd88e1d68856907131f7e2bae65d4c9fcc4b1..e7d091e18a8f186f89f5217442c24fb106c5cdab 100644 --- a/tensorflow/contrib/learn/python/learn/utils/gc_test.py +++ b/tensorflow/contrib/learn/python/learn/utils/gc_test.py @@ -34,12 +34,13 @@ def _create_parser(base_dir): # create a simple parser that pulls the export_version from the directory. 
def parser(path): # Modify the path object for RegEx match for Windows Paths - if os.name == 'nt': - match = re.match("^" + compat.as_str_any(base_dir).replace('\\','/') + "/(\\d+)$", - compat.as_str_any(path.path).replace('\\','/')) + if os.name == "nt": + match = re.match( + "^" + compat.as_str_any(base_dir).replace("\\", "/") + "/(\\d+)$", + compat.as_str_any(path.path).replace("\\", "/")) else: match = re.match("^" + compat.as_str_any(base_dir) + "/(\\d+)$", - compat.as_str_any(path.path)) + compat.as_str_any(path.path)) if not match: return None return path._replace(export_version=int(match.group(1))) @@ -63,7 +64,9 @@ class GcTest(test_util.TensorFlowTestCase): def testModExportVersion(self): paths = [ - gc.Path("/foo", 4), gc.Path("/foo", 5), gc.Path("/foo", 6), + gc.Path("/foo", 4), + gc.Path("/foo", 5), + gc.Path("/foo", 6), gc.Path("/foo", 9) ] mod = gc.mod_export_version(2) @@ -73,14 +76,21 @@ class GcTest(test_util.TensorFlowTestCase): def testOneOfEveryNExportVersions(self): paths = [ - gc.Path("/foo", 0), gc.Path("/foo", 1), gc.Path("/foo", 3), - gc.Path("/foo", 5), gc.Path("/foo", 6), gc.Path("/foo", 7), - gc.Path("/foo", 8), gc.Path("/foo", 33) + gc.Path("/foo", 0), + gc.Path("/foo", 1), + gc.Path("/foo", 3), + gc.Path("/foo", 5), + gc.Path("/foo", 6), + gc.Path("/foo", 7), + gc.Path("/foo", 8), + gc.Path("/foo", 33) ] one_of = gc.one_of_every_n_export_versions(3) self.assertEqual( one_of(paths), [ - gc.Path("/foo", 3), gc.Path("/foo", 6), gc.Path("/foo", 8), + gc.Path("/foo", 3), + gc.Path("/foo", 6), + gc.Path("/foo", 8), gc.Path("/foo", 33) ]) @@ -98,13 +108,19 @@ class GcTest(test_util.TensorFlowTestCase): f = gc.union(gc.largest_export_versions(3), gc.mod_export_version(3)) self.assertEqual( f(paths), [ - gc.Path("/foo", 0), gc.Path("/foo", 3), gc.Path("/foo", 6), - gc.Path("/foo", 7), gc.Path("/foo", 8), gc.Path("/foo", 9) + gc.Path("/foo", 0), + gc.Path("/foo", 3), + gc.Path("/foo", 6), + gc.Path("/foo", 7), + gc.Path("/foo", 8), + 
gc.Path("/foo", 9) ]) def testNegation(self): paths = [ - gc.Path("/foo", 4), gc.Path("/foo", 5), gc.Path("/foo", 6), + gc.Path("/foo", 4), + gc.Path("/foo", 5), + gc.Path("/foo", 6), gc.Path("/foo", 9) ] mod = gc.negation(gc.mod_export_version(2)) @@ -121,8 +137,7 @@ class GcTest(test_util.TensorFlowTestCase): gfile.MakeDirs(os.path.join(base_dir, "ignore")) self.assertEqual( - gc.get_paths(base_dir, _create_parser(base_dir)), - [ + gc.get_paths(base_dir, _create_parser(base_dir)), [ gc.Path(os.path.join(base_dir, "0"), 0), gc.Path(os.path.join(base_dir, "1"), 1), gc.Path(os.path.join(base_dir, "2"), 2) @@ -131,10 +146,10 @@ class GcTest(test_util.TensorFlowTestCase): def testMixedStrTypes(self): temp_dir = compat.as_bytes(test.get_temp_dir()) - for sub_dir in ['str', b'bytes', u'unicode']: + for sub_dir in ["str", b"bytes", u"unicode"]: base_dir = os.path.join( - (temp_dir if isinstance(sub_dir, bytes) else temp_dir.decode()), - sub_dir) + (temp_dir + if isinstance(sub_dir, bytes) else temp_dir.decode()), sub_dir) self.assertFalse(gfile.Exists(base_dir)) gfile.MakeDirs(os.path.join(compat.as_str_any(base_dir), "42")) gc.get_paths(base_dir, _create_parser(base_dir)) diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py index 6ffd2a133995a6ff8b35540221fb5676bf5de19f..1593380007b2799fb1d17e92408ab19a7b47fe1e 100644 --- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py +++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py @@ -33,7 +33,6 @@ from __future__ import division from __future__ import print_function import os -import tempfile import time from tensorflow.contrib.layers.python.layers import feature_column @@ -51,6 +50,7 @@ from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import signature_constants from 
tensorflow.python.saved_model import signature_def_utils +from tensorflow.python.summary import summary_iterator from tensorflow.python.training import saver from tensorflow.python.util import compat @@ -391,7 +391,8 @@ def make_export_strategy(serving_input_fn, default_output_alternative_key=None, assets_extra=None, as_text=False, - exports_to_keep=5): + exports_to_keep=5, + strip_default_attrs=None): """Create an ExportStrategy for use with Experiment. Args: @@ -412,12 +413,16 @@ def make_export_strategy(serving_input_fn, exports_to_keep: Number of exports to keep. Older exports will be garbage-collected. Defaults to 5. Set to None to disable garbage collection. + strip_default_attrs: Boolean. If True, default attrs in the + `GraphDef` will be stripped on write. This is recommended for better + forward compatibility of the resulting `SavedModel`. Returns: An ExportStrategy that can be passed to the Experiment constructor. """ - def export_fn(estimator, export_dir_base, checkpoint_path=None): + def export_fn(estimator, export_dir_base, checkpoint_path=None, + strip_default_attrs=False): """Exports the given Estimator as a SavedModel. Args: @@ -426,6 +431,8 @@ def make_export_strategy(serving_input_fn, graph and checkpoints. checkpoint_path: The checkpoint path to export. If None (the default), the most recent checkpoint found within the model directory is chosen. + strip_default_attrs: Boolean. If `True`, default-valued attributes will + be removed from the NodeDefs. Returns: The string path to the exported directory. 
@@ -444,7 +451,8 @@ def make_export_strategy(serving_input_fn, serving_input_fn, assets_extra=assets_extra, as_text=as_text, - checkpoint_path=checkpoint_path) + checkpoint_path=checkpoint_path, + strip_default_attrs=strip_default_attrs) else: export_result = estimator.export_savedmodel( export_dir_base, @@ -452,12 +460,13 @@ def make_export_strategy(serving_input_fn, default_output_alternative_key=default_output_alternative_key, assets_extra=assets_extra, as_text=as_text, - checkpoint_path=checkpoint_path) + checkpoint_path=checkpoint_path, + strip_default_attrs=strip_default_attrs) garbage_collect_exports(export_dir_base, exports_to_keep) return export_result - return export_strategy.ExportStrategy('Servo', export_fn) + return export_strategy.ExportStrategy('Servo', export_fn, strip_default_attrs) def make_parsing_export_strategy(feature_columns, @@ -465,7 +474,8 @@ def make_parsing_export_strategy(feature_columns, assets_extra=None, as_text=False, exports_to_keep=5, - target_core=False): + target_core=False, + strip_default_attrs=None): """Create an ExportStrategy for use with Experiment, using `FeatureColumn`s. Creates a SavedModel export that expects to be fed with a single string @@ -493,6 +503,9 @@ def make_parsing_export_strategy(feature_columns, target_core: If True, prepare an ExportStrategy for use with tensorflow.python.estimator.*. If False (default), prepare an ExportStrategy for use with tensorflow.contrib.learn.python.learn.*. + strip_default_attrs: Boolean. If True, default attrs in the + `GraphDef` will be stripped on write. This is recommended for better + forward compatibility of the resulting `SavedModel`. Returns: An ExportStrategy that can be passed to the Experiment constructor. 
@@ -509,7 +522,8 @@ def make_parsing_export_strategy(feature_columns, default_output_alternative_key=default_output_alternative_key, assets_extra=assets_extra, as_text=as_text, - exports_to_keep=exports_to_keep) + exports_to_keep=exports_to_keep, + strip_default_attrs=strip_default_attrs) def _default_compare_fn(curr_best_eval_result, cand_eval_result): @@ -543,15 +557,16 @@ def _default_compare_fn(curr_best_eval_result, cand_eval_result): class BestModelSelector(object): """A helper that keeps track of export selection candidates.""" - def __init__(self, compare_fn=None): + def __init__(self, event_file_pattern=None, compare_fn=None): """Constructor of this class. Args: + event_file_pattern: absolute event file name pattern. compare_fn: a function that returns true if the candidate is better than the current best model. """ - self._best_eval_result = None self._compare_fn = compare_fn or _default_compare_fn + self._best_eval_result = self._get_best_eval_result(event_file_pattern) def update(self, checkpoint_path, eval_result): """Records a given checkpoint and exports if this is the best model. @@ -581,11 +596,40 @@ class BestModelSelector(object): else: return '', None + def _get_best_eval_result(self, event_files): + """Get the best eval result from event files. -def make_best_model_export_strategy(serving_input_fn, - exports_to_keep=1, - compare_fn=None, - default_output_alternative_key=None): + Args: + event_files: Absolute pattern of event files. + + Returns: + The best eval result. 
+ """ + if not event_files: + return None + + best_eval_result = None + for event_file in gfile.Glob(os.path.join(event_files)): + for event in summary_iterator.summary_iterator(event_file): + if event.HasField('summary'): + event_eval_result = {} + for value in event.summary.value: + if value.HasField('simple_value'): + event_eval_result[value.tag] = value.simple_value + if best_eval_result is None or self._compare_fn( + best_eval_result, event_eval_result): + best_eval_result = event_eval_result + return best_eval_result + + +def make_best_model_export_strategy( + serving_input_fn, + exports_to_keep=1, + model_dir=None, + event_file_pattern=None, + compare_fn=None, + default_output_alternative_key=None, + strip_default_attrs=None): """Creates an custom ExportStrategy for use with tf.contrib.learn.Experiment. Args: @@ -593,10 +637,24 @@ def make_best_model_export_strategy(serving_input_fn, `InputFnOps`. exports_to_keep: an integer indicating how many historical best models need to be preserved. + model_dir: Directory where model parameters, graph etc. are saved. This will + be used to load eval metrics from the directory when the export strategy + is created. So the best metrics would not be lost even if the export + strategy got preempted, which guarantees that only the best model would + be exported regardless of preemption. If None, however, the export + strategy would not be preemption-safe. To be preemption-safe, both + model_dir and event_file_pattern would be needed. + event_file_pattern: event file name pattern relative to model_dir, e.g. + "eval_continuous/*.tfevents.*". If None, however, the export strategy + would not be preemption-safe. To be preemption-safe, both + model_dir and event_file_pattern would be needed. compare_fn: a function that select the 'best' candidate from a dictionary of evaluation result keyed by corresponding checkpoint path. default_output_alternative_key: the key for default serving signature for multi-headed inference graphs. 
+ strip_default_attrs: Boolean. If True, default attrs in the + `GraphDef` will be stripped on write. This is recommended for better + forward compatibility of the resulting `SavedModel`. Returns: An ExportStrategy that can be passed to the Experiment constructor. @@ -604,9 +662,13 @@ def make_best_model_export_strategy(serving_input_fn, best_model_export_strategy = make_export_strategy( serving_input_fn, exports_to_keep=exports_to_keep, - default_output_alternative_key=default_output_alternative_key) + default_output_alternative_key=default_output_alternative_key, + strip_default_attrs=strip_default_attrs) - best_model_selector = BestModelSelector(compare_fn) + full_event_file_pattern = os.path.join( + model_dir, + event_file_pattern) if model_dir and event_file_pattern else None + best_model_selector = BestModelSelector(full_event_file_pattern, compare_fn) def export_fn(estimator, export_dir_base, checkpoint_path, eval_result=None): """Exports the given Estimator as a SavedModel. @@ -682,22 +744,36 @@ def extend_export_strategy(base_export_strategy, ValueError: If `estimator` is a ${tf.estimator.Estimator} instance and `default_output_alternative_key` was specified or if post_export_fn does not return a valid directory. + RuntimeError: If unable to create temporary or final export directory. 
""" - tmp_base_export_dir = tempfile.mkdtemp() + tmp_base_export_folder = 'temp-base-export-' + str(int(time.time())) + tmp_base_export_dir = os.path.join(export_dir_base, tmp_base_export_folder) + if gfile.Exists(tmp_base_export_dir): + raise RuntimeError('Failed to obtain base export directory') + gfile.MakeDirs(tmp_base_export_dir) tmp_base_export = base_export_strategy.export( estimator, tmp_base_export_dir, checkpoint_path) - tmp_post_export_dir = tempfile.mkdtemp() + + tmp_post_export_folder = 'temp-post-export-' + str(int(time.time())) + tmp_post_export_dir = os.path.join(export_dir_base, tmp_post_export_folder) + if gfile.Exists(tmp_post_export_dir): + raise RuntimeError('Failed to obtain temp export directory') + + gfile.MakeDirs(tmp_post_export_dir) tmp_post_export = post_export_fn(tmp_base_export, tmp_post_export_dir) if not tmp_post_export.startswith(tmp_post_export_dir): raise ValueError('post_export_fn must return a sub-directory of {}' .format(tmp_post_export_dir)) - export_relpath = os.path.relpath(tmp_post_export, tmp_post_export_dir) - - gfile.Rename( - os.path.join(tmp_post_export_dir, export_relpath), - os.path.join(export_dir_base, export_relpath)) - return os.path.join(export_dir_base, export_relpath) + post_export_relpath = os.path.relpath(tmp_post_export, tmp_post_export_dir) + post_export = os.path.join(export_dir_base, post_export_relpath) + if gfile.Exists(post_export): + raise RuntimeError('Failed to obtain final export directory') + gfile.Rename(tmp_post_export, post_export) + + gfile.DeleteRecursively(tmp_base_export_dir) + gfile.DeleteRecursively(tmp_post_export_dir) + return post_export name = post_export_name if post_export_name else base_export_strategy.name return export_strategy.ExportStrategy(name, export_fn) diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py index 
ec3a88003f01b3b62591c13472029601b11ba491..14bf1136e8e9ab1488c4850d458382028ec5583d 100644 --- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py +++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py @@ -24,13 +24,14 @@ import time from tensorflow.contrib.layers.python.layers import feature_column as fc from tensorflow.contrib.learn.python.learn import export_strategy as export_strategy_lib from tensorflow.contrib.learn.python.learn.estimators import constants -from tensorflow.contrib.learn.python.learn.estimators import estimator as core_estimator +from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import model_fn from tensorflow.contrib.learn.python.learn.utils import input_fn_utils from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils from tensorflow.core.framework import tensor_shape_pb2 from tensorflow.core.framework import types_pb2 from tensorflow.core.protobuf import meta_graph_pb2 +from tensorflow.python.estimator import estimator as core_estimator from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops @@ -41,7 +42,7 @@ from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.util import compat -class TestEstimator(core_estimator.Estimator): +class TestEstimator(estimator.Estimator): def __init__(self, *args, **kwargs): super(TestEstimator, self).__init__(*args, **kwargs) @@ -55,7 +56,8 @@ class TestEstimator(core_estimator.Estimator): default_output_alternative_key=None, assets_extra=None, as_text=False, - checkpoint_path=None): + checkpoint_path=None, + strip_default_attrs=False): if not os.path.exists(export_dir): os.makedirs(export_dir) @@ -93,9 +95,9 @@ class SavedModelExportUtilsTest(test.TestCase): name="input-tensor-1:0", dtype=dtype_string, tensor_shape=shape)) 
expected_signature_def.outputs[ signature_constants.REGRESS_OUTPUTS].CopyFrom( - meta_graph_pb2.TensorInfo(name="output-tensor-1:0", - dtype=dtype_float, - tensor_shape=shape)) + meta_graph_pb2.TensorInfo( + name="output-tensor-1:0", dtype=dtype_float, + tensor_shape=shape)) expected_signature_def.method_name = signature_constants.REGRESS_METHOD_NAME self.assertEqual(actual_signature_def, expected_signature_def) @@ -506,7 +508,9 @@ class SavedModelExportUtilsTest(test.TestCase): input_example = constant_op.constant(["input string"]) input_ops = input_fn_utils.InputFnOps({ "features": input_features - }, None, {"default input": input_example}) + }, None, { + "default input": input_example + }) input_alternatives, _ = ( saved_model_export_utils.get_input_alternatives(input_ops)) output_1 = constant_op.constant([1.0]) @@ -527,8 +531,9 @@ class SavedModelExportUtilsTest(test.TestCase): model_fn.ModeKeys.INFER, predictions={"some_output": constant_op.constant(["4"])}, output_alternatives=provided_output_alternatives) - output_alternatives, _ = (saved_model_export_utils.get_output_alternatives( - model_fn_ops, "head-1")) + output_alternatives, _ = ( + saved_model_export_utils.get_output_alternatives( + model_fn_ops, "head-1")) signature_defs = saved_model_export_utils.build_all_signature_defs( input_alternatives, output_alternatives, "head-1") @@ -546,7 +551,9 @@ class SavedModelExportUtilsTest(test.TestCase): "default_input_alternative:head-3": signature_def_utils.predict_signature_def({ "default input": input_example - }, {"some_output_3": output_3}), + }, { + "some_output_3": output_3 + }), # "features_input_alternative:head-1": # signature_def_utils.regression_signature_def(input_features, # output_1), @@ -589,8 +596,9 @@ class SavedModelExportUtilsTest(test.TestCase): model_fn.ModeKeys.INFER, predictions={"some_output": constant_op.constant(["4"])}, output_alternatives=provided_output_alternatives) - output_alternatives, _ = 
(saved_model_export_utils.get_output_alternatives( - model_fn_ops, "head-1")) + output_alternatives, _ = ( + saved_model_export_utils.get_output_alternatives( + model_fn_ops, "head-1")) with self.assertRaisesRegexp( ValueError, "A default input_alternative must be provided"): @@ -706,25 +714,72 @@ class SavedModelExportUtilsTest(test.TestCase): self.assertNotEqual("", export_strategy.export(test_estimator, export_dir_base, - "fake_ckpt_0", {"loss": 100})) + "fake_ckpt_0", { + "loss": 100 + })) self.assertNotEqual("", test_estimator.last_exported_dir) self.assertNotEqual("", test_estimator.last_exported_checkpoint) self.assertEqual("", export_strategy.export(test_estimator, export_dir_base, - "fake_ckpt_1", {"loss": 101})) + "fake_ckpt_1", { + "loss": 101 + })) self.assertEqual(test_estimator.last_exported_dir, os.path.join(export_dir_base, "fake_ckpt_0")) self.assertNotEqual("", export_strategy.export(test_estimator, export_dir_base, - "fake_ckpt_2", {"loss": 10})) + "fake_ckpt_2", { + "loss": 10 + })) + self.assertEqual(test_estimator.last_exported_dir, + os.path.join(export_dir_base, "fake_ckpt_2")) + + self.assertEqual("", + export_strategy.export(test_estimator, export_dir_base, + "fake_ckpt_3", { + "loss": 20 + })) + self.assertEqual(test_estimator.last_exported_dir, + os.path.join(export_dir_base, "fake_ckpt_2")) + + def test_make_best_model_export_strategy_with_preemption(self): + model_dir = self.get_temp_dir() + eval_dir_base = os.path.join(model_dir, "eval_continuous") + core_estimator._write_dict_to_summary(eval_dir_base, {"loss": 50}, 1) + core_estimator._write_dict_to_summary(eval_dir_base, {"loss": 60}, 2) + + test_estimator = TestEstimator() + export_strategy = saved_model_export_utils.make_best_model_export_strategy( + serving_input_fn=None, + exports_to_keep=3, + model_dir=model_dir, + event_file_pattern="eval_continuous/*.tfevents.*", + compare_fn=None) + + export_dir_base = os.path.join(self.get_temp_dir(), "export") + self.assertEqual("", + 
export_strategy.export(test_estimator, export_dir_base, + "fake_ckpt_0", { + "loss": 100 + })) + self.assertEqual("", test_estimator.last_exported_dir) + self.assertEqual("", test_estimator.last_exported_checkpoint) + + self.assertNotEqual("", + export_strategy.export(test_estimator, export_dir_base, + "fake_ckpt_2", { + "loss": 10 + })) self.assertEqual(test_estimator.last_exported_dir, os.path.join(export_dir_base, "fake_ckpt_2")) self.assertEqual("", export_strategy.export(test_estimator, export_dir_base, - "fake_ckpt_3", {"loss": 20})) + "fake_ckpt_3", { + "loss": 20 + })) self.assertEqual(test_estimator.last_exported_dir, os.path.join(export_dir_base, "fake_ckpt_2")) @@ -766,10 +821,11 @@ class SavedModelExportUtilsTest(test.TestCase): test_estimator = TestEstimator() tmpdir = tempfile.mkdtemp() - final_path = final_export_strategy.export(test_estimator, tmpdir, - os.path.join( - tmpdir, "checkpoint")) - self.assertEqual(os.path.join(tmpdir, "rewrite"), final_path) + export_model_dir = os.path.join(tmpdir, "model") + checkpoint_path = os.path.join(tmpdir, "checkpoint") + final_path = final_export_strategy.export(test_estimator, export_model_dir, + checkpoint_path) + self.assertEqual(os.path.join(export_model_dir, "rewrite"), final_path) def test_extend_export_strategy_same_name(self): @@ -795,10 +851,11 @@ class SavedModelExportUtilsTest(test.TestCase): test_estimator = TestEstimator() tmpdir = tempfile.mkdtemp() - final_path = final_export_strategy.export(test_estimator, tmpdir, - os.path.join( - tmpdir, "checkpoint")) - self.assertEqual(os.path.join(tmpdir, "rewrite"), final_path) + export_model_dir = os.path.join(tmpdir, "model") + checkpoint_path = os.path.join(tmpdir, "checkpoint") + final_path = final_export_strategy.export(test_estimator, export_model_dir, + checkpoint_path) + self.assertEqual(os.path.join(export_model_dir, "rewrite"), final_path) def test_extend_export_strategy_raises_error(self): diff --git 
a/tensorflow/contrib/legacy_seq2seq/python/__init__.py b/tensorflow/contrib/legacy_seq2seq/python/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..52e83069cb0c68b510da46149248369dce376647 100644 --- a/tensorflow/contrib/legacy_seq2seq/python/__init__.py +++ b/tensorflow/contrib/legacy_seq2seq/python/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function diff --git a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/__init__.py b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..52e83069cb0c68b510da46149248369dce376647 100644 --- a/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/__init__.py +++ b/tensorflow/contrib/legacy_seq2seq/python/kernel_tests/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function diff --git a/tensorflow/contrib/libsvm/BUILD b/tensorflow/contrib/libsvm/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..df96402a4ffd51840f77d58d8066487030362340 --- /dev/null +++ b/tensorflow/contrib/libsvm/BUILD @@ -0,0 +1,102 @@ +package( + default_visibility = ["//visibility:private"], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") +load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs") +load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") +load("//tensorflow:tensorflow.bzl", "tf_kernel_library") +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") +load("//tensorflow:tensorflow.bzl", "tf_py_test") + +tf_custom_op_library( + name = "python/ops/_libsvm_ops.so", + srcs = [ + "kernels/decode_libsvm_op.cc", + "ops/libsvm_ops.cc", + ], + deps = [ + "//tensorflow/core/kernels:bounds_check_lib", + ], +) + +tf_kernel_library( + name = "libsvm_kernels", + srcs = ["kernels/decode_libsvm_op.cc"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core/kernels:bounds_check_lib", + ], +) + +tf_gen_op_libs( + op_lib_names = ["libsvm_ops"], + deps = [ + "//tensorflow/core:lib", + ], +) + +tf_gen_op_wrapper_py( + name = "libsvm_ops", + 
deps = [":libsvm_ops_op_lib"], +) + +tf_custom_op_py_library( + name = "libsvm", + srcs = [ + "__init__.py", + "python/ops/libsvm_ops.py", + ], + dso = [ + ":python/ops/_libsvm_ops.so", + ], + kernels = [ + ":libsvm_kernels", + ":libsvm_ops_op_lib", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":libsvm_ops", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + ], +) + +tf_py_test( + name = "decode_libsvm_op_test", + srcs = ["python/kernel_tests/decode_libsvm_op_test.py"], + additional_deps = [ + ":libsvm", + "//third_party/py/numpy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/libsvm/__init__.py b/tensorflow/contrib/libsvm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a875863caab29eb59a1834ca9184a5e272cb6656 --- /dev/null +++ b/tensorflow/contrib/libsvm/__init__.py @@ -0,0 +1,32 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Libsvm decoder. + +@@decode_libsvm +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.libsvm.python.ops.libsvm_ops import decode_libsvm + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + "decode_libsvm", +] + +remove_undocumented(__name__) diff --git a/tensorflow/contrib/libsvm/kernels/decode_libsvm_op.cc b/tensorflow/contrib/libsvm/kernels/decode_libsvm_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..720c74e3de5907fa006227d1278c45fd2175fe5f --- /dev/null +++ b/tensorflow/contrib/libsvm/kernels/decode_libsvm_op.cc @@ -0,0 +1,168 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace tensorflow { + +template +class DecodeLibsvmOp : public OpKernel { + public: + explicit DecodeLibsvmOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("num_features", &num_features_)); + OP_REQUIRES(ctx, (num_features_ >= 1), + errors::InvalidArgument("Invalid number of features \"", + num_features_, "\"")); + } + + void Compute(OpKernelContext* ctx) override { + const Tensor* input_tensor; + OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor)); + const auto& input_flat = input_tensor->flat(); + + Tensor* label_tensor; + OP_REQUIRES_OK( + ctx, ctx->allocate_output(0, input_tensor->shape(), &label_tensor)); + auto label = label_tensor->flat(); + + std::vector out_values; + std::vector> out_indices; + for (int i = 0; i < input_flat.size(); ++i) { + StringPiece line(input_flat(i)); + str_util::RemoveWhitespaceContext(&line); + + StringPiece piece; + OP_REQUIRES(ctx, str_util::ConsumeNonWhitespace(&line, &piece), + errors::InvalidArgument("No label found for input[", i, + "]: \"", input_flat(i), "\"")); + + Tlabel label_value; + OP_REQUIRES(ctx, + strings::SafeStringToNumeric(piece, &label_value), + errors::InvalidArgument("Label format incorrect: ", piece)); + + label(i) = label_value; + + str_util::RemoveLeadingWhitespace(&line); + while (str_util::ConsumeNonWhitespace(&line, &piece)) { + size_t p = piece.find(':'); + OP_REQUIRES(ctx, (p != StringPiece::npos), + errors::InvalidArgument("Invalid feature \"", piece, "\"")); + + int64 feature_index; + OP_REQUIRES( + ctx, 
strings::safe_strto64(piece.substr(0, p), &feature_index), + errors::InvalidArgument("Feature format incorrect: ", piece)); + OP_REQUIRES(ctx, (feature_index >= 0), + errors::InvalidArgument( + "Feature index should be >= 0, got ", feature_index)); + + T feature_value; + OP_REQUIRES( + + ctx, + strings::SafeStringToNumeric(piece.substr(p + 1), + &feature_value), + errors::InvalidArgument("Feature format incorrect: ", piece)); + + out_values.emplace_back(feature_value); + out_indices.emplace_back(std::pair(i, feature_index)); + + str_util::RemoveLeadingWhitespace(&line); + } + } + + Tensor* indices_tensor; + OP_REQUIRES_OK(ctx, ctx->allocate_output( + 1, + TensorShape({static_cast(out_indices.size()), + input_tensor->shape().dims() + 1}), + &indices_tensor)); + auto indices = indices_tensor->matrix(); + // Translate flat index to shaped index like np.unravel_index + // Calculate factors for each dimension + std::vector factors(input_tensor->shape().dims()); + factors[input_tensor->shape().dims() - 1] = 1; + for (int j = input_tensor->shape().dims() - 2; j >= 0; j--) { + factors[j] = factors[j + 1] * input_tensor->shape().dim_size(j + 1); + } + for (int i = 0; i < out_indices.size(); i++) { + indices(i, 0) = out_indices[i].first; + int64 value = out_indices[i].first; + for (int j = 0; j < input_tensor->shape().dims(); j++) { + indices(i, j) = value / factors[j]; + value = value % factors[j]; + } + indices(i, input_tensor->shape().dims()) = out_indices[i].second; + } + + Tensor* values_tensor; + OP_REQUIRES_OK(ctx, + ctx->allocate_output( + 2, TensorShape({static_cast(out_values.size())}), + &values_tensor)); + auto values = values_tensor->vec(); + std::copy_n(out_values.begin(), out_values.size(), &values(0)); + + Tensor* shape_tensor; + OP_REQUIRES_OK(ctx, ctx->allocate_output( + 3, TensorShape({input_tensor->shape().dims() + 1}), + &shape_tensor)); + auto shape = shape_tensor->flat(); + for (int i = 0; i < input_tensor->shape().dims(); i++) { + shape(i) = 
input_tensor->shape().dim_size(i); + } + shape(input_tensor->shape().dims()) = num_features_; + } + + private: + int64 num_features_; +}; + +#define REGISTER_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("DecodeLibsvm") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("dtype") \ + .TypeConstraint("label_dtype"), \ + DecodeLibsvmOp); \ + REGISTER_KERNEL_BUILDER(Name("DecodeLibsvm") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("dtype") \ + .TypeConstraint("label_dtype"), \ + DecodeLibsvmOp); \ + REGISTER_KERNEL_BUILDER(Name("DecodeLibsvm") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("dtype") \ + .TypeConstraint("label_dtype"), \ + DecodeLibsvmOp); \ + REGISTER_KERNEL_BUILDER(Name("DecodeLibsvm") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("dtype") \ + .TypeConstraint("label_dtype"), \ + DecodeLibsvmOp); + +REGISTER_KERNEL(float); +REGISTER_KERNEL(double); +REGISTER_KERNEL(int32); +REGISTER_KERNEL(int64); +#undef REGISTER_KERNEL + +} // namespace tensorflow diff --git a/tensorflow/contrib/libsvm/ops/libsvm_ops.cc b/tensorflow/contrib/libsvm/ops/libsvm_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..dec946189e3cd67e2557b83806c0db79a46e5f82 --- /dev/null +++ b/tensorflow/contrib/libsvm/ops/libsvm_ops.cc @@ -0,0 +1,58 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +using shape_inference::InferenceContext; + +REGISTER_OP("DecodeLibsvm") + .Input("input: string") + .Output("label: label_dtype") + .Output("feature_indices: int64") + .Output("feature_values: dtype") + .Output("feature_shape: int64") + .Attr("dtype: {float, double, int32, int64} = DT_FLOAT") + .Attr("label_dtype: {float, double, int32, int64} = DT_INT64") + .Attr("num_features: int >= 1") + .SetShapeFn([](InferenceContext* c) { + c->set_output(0, c->input(0)); + + c->set_output(1, c->Matrix(InferenceContext::kUnknownDim, + InferenceContext::kUnknownDim)); + c->set_output(2, c->Vector(InferenceContext::kUnknownDim)); + c->set_output(3, c->Vector(InferenceContext::kUnknownDim)); + + return Status::OK(); + }) + + .Doc(R"doc( +Convert LibSVM input to tensors. The output consists of +a label and a feature tensor. The shape of the label tensor +is the same as input and the shape of the feature tensor is +`[input_shape, num_features]`. + +input: Each string is a record in the LibSVM. +label: A tensor of the same shape as input. +feature_indices: A 2-D int64 tensor of dense_shape [N, ndims]. +feature_values: A 1-D tensor of any type and dense_shape [N]. +feature_shape: A 1-D int64 tensor of dense_shape [ndims]. +num_features: The number of features. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/contrib/libsvm/python/kernel_tests/decode_libsvm_op_test.py b/tensorflow/contrib/libsvm/python/kernel_tests/decode_libsvm_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..423dcce8de9b9c77fcfdc8c90c909e2918852905 --- /dev/null +++ b/tensorflow/contrib/libsvm/python/kernel_tests/decode_libsvm_op_test.py @@ -0,0 +1,71 @@ +# Copyright 2017 The TensorFlow Authors. 
All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for DecodeLibsvm op.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.libsvm.python.ops import libsvm_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import sparse_ops +from tensorflow.python.platform import test + + +class DecodeLibsvmOpTest(test.TestCase): + + def testBasic(self): + with self.test_session() as sess: + content = [ + "1 1:3.4 2:0.5 4:0.231", "1 2:2.5 3:inf 5:0.503", + "2 3:2.5 2:nan 1:0.105" + ] + sparse_features, labels = libsvm_ops.decode_libsvm( + content, num_features=6) + features = sparse_ops.sparse_tensor_to_dense( + sparse_features, validate_indices=False) + + self.assertAllEqual(labels.get_shape().as_list(), [3]) + + features, labels = sess.run([features, labels]) + self.assertAllEqual(labels, [1, 1, 2]) + self.assertAllClose( + features, [[0, 3.4, 0.5, 0, 0.231, 0], [0, 0, 2.5, np.inf, 0, 0.503], + [0, 0.105, np.nan, 2.5, 0, 0]]) + + def testNDimension(self): + with self.test_session() as sess: + content = [["1 1:3.4 2:0.5 4:0.231", "1 1:3.4 2:0.5 4:0.231"], + ["1 2:2.5 3:inf 5:0.503", "1 2:2.5 3:inf 5:0.503"], + ["2 3:2.5 2:nan 1:0.105", "2 3:2.5 2:nan 1:0.105"]] + sparse_features, labels = libsvm_ops.decode_libsvm( + content, 
num_features=6, label_dtype=dtypes.float64) + features = sparse_ops.sparse_tensor_to_dense( + sparse_features, validate_indices=False) + + self.assertAllEqual(labels.get_shape().as_list(), [3, 2]) + + features, labels = sess.run([features, labels]) + self.assertAllEqual(labels, [[1, 1], [1, 1], [2, 2]]) + self.assertAllClose( + features, [[[0, 3.4, 0.5, 0, 0.231, 0], [0, 3.4, 0.5, 0, 0.231, 0]], [ + [0, 0, 2.5, np.inf, 0, 0.503], [0, 0, 2.5, np.inf, 0, 0.503] + ], [[0, 0.105, np.nan, 2.5, 0, 0], [0, 0.105, np.nan, 2.5, 0, 0]]]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/libsvm/python/ops/libsvm_ops.py b/tensorflow/contrib/libsvm/python/ops/libsvm_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..b3022505635bca81625cf7abd2be5628a4760970 --- /dev/null +++ b/tensorflow/contrib/libsvm/python/ops/libsvm_ops.py @@ -0,0 +1,50 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Libsvm decoder.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.libsvm.ops import gen_libsvm_ops +from tensorflow.contrib.util import loader +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.platform import resource_loader + + +_libsvm_ops_so = loader.load_op_library( + resource_loader.get_path_to_datafile("_libsvm_ops.so")) + + +def decode_libsvm(content, num_features, dtype=None, label_dtype=None): + """Convert Libsvm records to a tensor of label and a tensor of feature. + + Args: + content: A `Tensor` of type `string`. Each string is a record/row in + the Libsvm format. + num_features: The number of features. + dtype: The type of the output feature tensor. Default to tf.float32. + label_dtype: The type of the output label tensor. Default to tf.int64. + + Returns: + features: A `SparseTensor` of the shape `[input_shape, num_features]`. + labels: A `Tensor` of the same shape as content. 
+ """ + labels, indices, values, shape = gen_libsvm_ops.decode_libsvm( + content, num_features, dtype=dtype, label_dtype=label_dtype) + return sparse_tensor.SparseTensor(indices, values, shape), labels + + +ops.NotDifferentiable("DecodeLibSVM") diff --git a/tensorflow/contrib/linear_optimizer/BUILD b/tensorflow/contrib/linear_optimizer/BUILD index fe2f183ac970cef4ebf6ca1a927b5a48eefb7d7b..cea3627ed565f0de86d8d9bb6b45c4b19c5b5558 100644 --- a/tensorflow/contrib/linear_optimizer/BUILD +++ b/tensorflow/contrib/linear_optimizer/BUILD @@ -126,6 +126,7 @@ py_library( py_test( name = "sdca_estimator_test", srcs = ["python/sdca_estimator_test.py"], + shard_count = 4, srcs_version = "PY2AND3", deps = [ ":sdca_estimator_py", diff --git a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py index 7526f3ae0dbdb3d6827e9d7f690090b8438e4f6e..3f5fdc18bb8f47cceee8f81dd5ded02059344b8b 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sdca_ops.py @@ -211,9 +211,8 @@ class SdcaModel(object): sums.append( math_ops.reduce_sum( math_ops.abs(math_ops.cast(weights, dtypes.float64)))) - sum = math_ops.add_n(sums) # SDCA L1 regularization cost is: l1 * sum(|weights|) - return self._options['symmetric_l1_regularization'] * sum + return self._options['symmetric_l1_regularization'] * math_ops.add_n(sums) def _l2_loss(self, l2): """Computes the (un-normalized) l2 loss of the model.""" @@ -225,9 +224,8 @@ class SdcaModel(object): sums.append( math_ops.reduce_sum( math_ops.square(math_ops.cast(weights, dtypes.float64)))) - sum = math_ops.add_n(sums) # SDCA L2 regularization cost is: l2 * sum(weights^2) / 2 - return l2 * sum / 2.0 + return l2 * math_ops.add_n(sums) / 2.0 def _convert_n_to_tensor(self, input_list, as_ref=False): """Converts input list to a set of tensors.""" diff --git 
a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py index 7e214905b13db6a7e2f54f15873f5a9aedb4f44f..ec726bbed41a86eb314e3591ecaedaa6bf0e5e9b 100644 --- a/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py +++ b/tensorflow/contrib/linear_optimizer/python/ops/sharded_mutable_dense_hashtable.py @@ -102,7 +102,7 @@ class ShardedMutableDenseHashTable(lookup.LookupInterface): keys.get_shape()) def lookup(self, keys, name=None): - if keys.dtype != self._key_dtype: + if keys.dtype.base_dtype != self._key_dtype: raise TypeError('Signature mismatch. Keys must be dtype %s, got %s.' % (self._key_dtype, keys.dtype)) self._check_keys(keys) diff --git a/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py b/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py index 701fc1c0597d1de0b0189e86feafbd1c5bbdc818..05794a42c5f2d0eece6adab36fb5610078cece31 100644 --- a/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py +++ b/tensorflow/contrib/linear_optimizer/python/sdca_estimator.py @@ -19,7 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib import layers -from tensorflow.contrib.framework.python.ops import variables as contrib_variables +from tensorflow.python.training import training_util from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import head as head_lib from tensorflow.contrib.learn.python.learn.estimators import prediction_key @@ -154,7 +154,7 @@ def sdca_model_fn(features, labels, mode, params, config=None): _add_bias_column(feature_columns, features, bias, columns_to_variables) def _train_op_fn(unused_loss): - global_step = contrib_variables.get_global_step() + global_step = training_util.get_global_step() sdca_model, train_op = optimizer.get_train_step( columns_to_variables, 
weight_column_name, loss_type, features, labels, global_step) diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD index 52460123cc10ec9b2ee13043fd43f84508b05000..44c4a7e2ca8d019ca602c7f2b492cd1e70b17561 100644 --- a/tensorflow/contrib/lite/BUILD +++ b/tensorflow/contrib/lite/BUILD @@ -6,8 +6,11 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "gen_selected_ops") +exports_files(["LICENSE"]) + exports_files(glob([ "testdata/*.bin", + "testdata/*.pb", "models/testdata/*", ])) @@ -25,16 +28,35 @@ config_setting( }, ) -load( - "//tensorflow:tensorflow.bzl", - "tf_cc_test", -) - cc_library( name = "schema_fbs_version", hdrs = ["version.h"], ) +cc_library( + name = "arena_planner", + srcs = ["arena_planner.cc"], + hdrs = ["arena_planner.h"], + deps = [ + ":context", + ":graph_info", + ":memory_planner", + ":simple_memory_arena", + ], +) + +cc_test( + name = "arena_planner_test", + size = "small", + srcs = ["arena_planner_test.cc"], + deps = [ + ":arena_planner", + "//tensorflow/contrib/lite/testing:util", + "//tensorflow/core:lib", + "@com_google_googletest//:gtest", + ], +) + # Main library. No ops are included here. # TODO(aselle): Resolve problems preventing C99 usage. 
cc_library( @@ -43,6 +65,25 @@ cc_library( hdrs = ["context.h"], ) +cc_library( + name = "graph_info", + hdrs = ["graph_info.h"], + deps = [":context"], +) + +cc_library( + name = "memory_planner", + hdrs = ["memory_planner.h"], + deps = [":context"], +) + +cc_library( + name = "simple_memory_arena", + srcs = ["simple_memory_arena.cc"], + hdrs = ["simple_memory_arena.h"], + deps = [":context"], +) + cc_library( name = "builtin_op_data", hdrs = [ @@ -66,27 +107,31 @@ cc_library( srcs = [ "allocation.cc", "error_reporter.cc", + "graph_info.cc", "interpreter.cc", "model.cc", "nnapi_delegate.cc", "optional_debug_tools.cc", - "simple_memory_arena.cc", ], hdrs = [ "allocation.h", "context.h", "error_reporter.h", + "graph_info.h", "interpreter.h", "model.h", "nnapi_delegate.h", "optional_debug_tools.h", - "simple_memory_arena.h", ], copts = tflite_copts(), deps = [ + ":arena_planner", ":builtin_op_data", ":context", + ":graph_info", + ":memory_planner", ":schema_fbs_version", + ":simple_memory_arena", "//tensorflow/contrib/lite/kernels:gemm_support", "//tensorflow/contrib/lite/nnapi:nnapi_lib", "//tensorflow/contrib/lite/schema:schema_fbs", @@ -111,6 +156,7 @@ cc_test( deps = [ ":framework", ":string_util", + "//tensorflow/contrib/lite/testing:util", "@com_google_googletest//:gtest", ], ) @@ -123,6 +169,22 @@ cc_test( deps = [ ":framework", ":string_util", + "//tensorflow/contrib/lite/kernels/internal:tensor_utils", + "//tensorflow/contrib/lite/schema:schema_fbs", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + +# Test graph utils +cc_test( + name = "graph_info_test", + size = "small", + srcs = ["graph_info_test.cc"], + deps = [ + ":framework", + ":string_util", + "//tensorflow/contrib/lite/testing:util", "@com_google_googletest//:gtest", ], ) @@ -133,7 +195,8 @@ cc_test( size = "small", srcs = ["simple_memory_arena_test.cc"], deps = [ - ":framework", + ":simple_memory_arena", + "//tensorflow/contrib/lite/testing:util", 
"@com_google_googletest//:gtest", ], ) @@ -152,6 +215,7 @@ cc_test( ], deps = [ ":framework", + "//tensorflow/contrib/lite/testing:util", "@com_google_googletest//:gtest", ], ) @@ -163,6 +227,7 @@ cc_test( srcs = ["context_test.cc"], deps = [ ":framework", + "//tensorflow/contrib/lite/testing:util", "@com_google_googletest//:gtest", ], ) @@ -171,18 +236,18 @@ cc_test( # Model tests -cc_library( - name = "models_test_utils", - testonly = 1, - hdrs = ["models/test_utils.h"], - deps = select({ - "//tensorflow:android": [], - "//conditions:default": [ - "@com_google_absl//absl/strings", - "//tensorflow/core:test", - ], - }), -) +#cc_library( +# name = "models_test_utils", +# testonly = 1, +# hdrs = ["models/test_utils.h"], +# deps = select({ +# "//tensorflow:android": [], +# "//conditions:default": [ +# "@com_google_absl//absl/strings", +# "//tensorflow/core:test", +# ], +# }), +#) filegroup( name = "all_files", diff --git a/tensorflow/contrib/lite/Makefile b/tensorflow/contrib/lite/Makefile index 78402727abdd2742ffff54bf59ca076d8b97b042..7f316292724ea0baaf034d4e914773ad97a957d4 100644 --- a/tensorflow/contrib/lite/Makefile +++ b/tensorflow/contrib/lite/Makefile @@ -56,7 +56,7 @@ LIBS := \ -lz # If we're on Linux, also link in the dl library. 
-ifeq ($(OS),LINUX) +ifeq ($(HOST_OS),LINUX) LIBS += -ldl -lpthread endif diff --git a/tensorflow/contrib/lite/README.md b/tensorflow/contrib/lite/README.md index c7464bcc9d39b0e884e76f5a3ffa152e98bb0f47..3e55d2a496c1d83ec0501df27deee4e19a5012a7 100644 --- a/tensorflow/contrib/lite/README.md +++ b/tensorflow/contrib/lite/README.md @@ -4,7 +4,7 @@ TensorFlow Lite is TensorFlow's lightweight solution for mobile and embedded dev TensorFlow Lite uses many techniques for achieving low latency like optimizing the kernels for specific mobile apps, pre-fused activations, quantized kernels that allow smaller and faster (fixed-point math) models, and in the future, leverage specialized machine learning hardware to get the best possible performance for a particular model on a particular device. ![image](g3doc/TFLite-Architecture.jpg) -# Getting Started with a Demo App +# Getting Started with an Android Demo App This section contains an example application using TensorFlow Lite for Android devices. The demo is a sample camera app that classifies images continuously using a quantized Mobilenet model. A device running Android 5.0 ( API 21) or higher is required to run the demo. @@ -17,7 +17,7 @@ There are 3 ways to get the demo app to your device In the demo app, inference is done using the TensorFlow Lite Java API. The demo app classifies frames in real-time, displaying the top most probable classifications. It also displays the time taken to detect the object. ## Downloading the pre-built binary -The fastest path to trying the demo, is to download the pre-built binary +The fastest path to trying the demo, is to download the pre-built binary [TfLiteCameraDemo.apk](https://storage.googleapis.com/download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk) Once the apk is installed, click the app icon to start the app. The first-time the app is opened, the app asks for runtime permissions to access the device camera. 
The demo app opens the back-camera of the device and recognizes the objects in the camera's field of view. At the bottom of the image (or at the left of the image if the device is in landscape mode), it shows the latency of classification and the top three objects classified. @@ -69,7 +69,7 @@ android_ndk_repository( Additional details on building with Android can be found [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/README.md). -### Build the source code +### Build the source code Run bazel with the following command to build the demo. Build the demo app: @@ -86,6 +86,17 @@ environment (due to a Bazel bug). ### More about the demo The demo is resizing each camera image frame to (224 width * 224 height) to match the quantized Mobilenet model being used. The resized image is converted into a ByteBuffer row by row of size 1 * 224 * 224 * 3 bytes, where 1 is the number of images in a batch 224 * 224 is the width and height of the image 3 bytes represents three colors of a pixel. This demo uses the TensorFlow Lite Java inference API for models which take a single input and provide a single output. This outputs a two-dimensional array, with the first dimension being the category index and the second dimension being the confidence of classification. The Mobilenet model has 1001 unique categories and the app sorts the probabilities of all the categories and displays the top three. The Mobilenet quantized model is bundled within the assets directory of the app. +# iOS Demo App + +Similar to the Android demo app, there's an iOS camera app that uses exactly the same model (224 * 224 quantized Mobilenet). + +This demo app requires a camera so it doesn't work with simulators. It need to be executed on a real iOS device. Follow the instructions to build and run the demo app: + +1. Run `third_party/tensorflow/contrib/lite/examples/ios/download_models.sh` to download the model files used by the demo app. +1. 
Install [CocoaPods](https://cocoapods.org/) if it wasn't installed yet: `sudo gem install cocoapods`. +1. Run `pod install` in `tensorflow/contrib/lite/examples/ios/camera` to generate the workspace file. +1. Open the project by running `open tflite_camera_example.xcworkspace`, and build the app in XCode. + # TensorFlow Lite Quick Start ## Step 1. Decide which GraphDef to use @@ -131,7 +142,7 @@ Since we employ several formats, the following definitions may be useful: - SavedModel - A collection of GraphDef and CheckPoint together with a signature that labels input and output arguments to a model. A GraphDef and Checkpoint can be extracted from a saved model. - - TensorFlow lite model (.lite) - a serialized flatbuffer, containing TensorFlow lite operators and Tensors for the TensorFlow lite interpreter. This is most analogous to TensorFlow frozen GraphDefs. + - TensorFlow lite model (.tflite) - a serialized flatbuffer, containing TensorFlow lite operators and Tensors for the TensorFlow lite interpreter. This is most analogous to TensorFlow frozen GraphDefs. ### Freeze Graph To use this .pb GraphDef file within TensorFlow Lite, the application developer will need checkpoints containing trained weight parameters. The .pb contains only the structure of the graph. The process of merging the checkpoint values with the graph structure is known as "freezing" the graph. @@ -153,17 +164,18 @@ bazel-bin/tensorflow/python/tools/freeze_graph\ The user has to first build the freeze_graph script using bazel and then run the script. The input_binary flag has to be enabled to ensure that the protobuf is read and written in binary format. The user has to input the .pb and the .ckpt files to freeze the graph The output_node_names may not be obvious outside of the code that built the model. The easiest way to find them is to visualize the graph, either with graphviz, or [in tensorboard](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets-2/#3). 
-This frozen Graphdef is now ready to be converted to flatbuffer format (.lite) for use on Android or iOS. On Android users have the flexibility to use either the float or quantized versions of the frozen graphdef, if available, using the Tensorflow Optimizing Converter tool. +This frozen Graphdef is now ready to be converted to flatbuffer format (.tflite) for use on Android or iOS. On Android users have the flexibility to use either the float or quantized versions of the frozen graphdef, if available, using the Tensorflow Optimizing Converter tool. -Here is a sample command line to convert the frozen Graphdef to '.lite' format for The Tensorflow Optimizing Converter supports both float and quantized models, however, different configuration parameters are needed depending on whether a FLOAT or QUANTIZED mode is being used. +Here is a sample command line to convert the frozen Graphdef to '.tflite' format for The Tensorflow Optimizing Converter supports both float and quantized models, however, different configuration parameters are needed depending on whether a FLOAT or QUANTIZED mode is being used. +(Here is a link to the pb [file](https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_1.0_224_frozen.tgz)). ``` bazel build tensorflow/contrib/lite/toco:toco -bazel-bin/tensorflow/contrib/lite/toco/toco -- \ +bazel-bin/tensorflow/contrib/lite/toco/toco \ --input_file=$(pwd)/mobilenet_v1_1.0_224/frozen_graph.pb \ --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE \ - --output_file=/tmp/mobilenet_v1_1.0_224.lite --inference_type=FLOAT \ + --output_file=/tmp/mobilenet_v1_1.0_224.tflite --inference_type=FLOAT \ --input_type=FLOAT --input_arrays=input \ --output_arrays=MobilenetV1/Predictions/Reshape_1 --input_shapes=1,224,224,3 ``` @@ -174,9 +186,9 @@ bazel-bin/tensorflow/contrib/lite/toco/toco -- \ - Setting the input_array, output_array and input_shape arguments are a bit trickier. 
The easiest way to find these values is to explore the graph in tensorboard . The user should reuse the arguments that were used for specifying the output nodes for inference in the `freeze_graph`step. Note, it is also possible to use the Tensorflow Optimizing Converter through protos either from Python or from the command line see the -documentation [here](https://github.com/tensorflow/tensorflow/tree/mastertensorflow/contrib/lite/python:toco_from_protos target) A developer can then integrate the conversion step into their model design workflow to ensure that a model will be easily convertible to a mobile inference graph. For example, +documentation [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/toco/python/toco_from_protos.py). A developer can then integrate the conversion step into their model design workflow to ensure that a model will be easily convertible to a mobile inference graph. For example, -``` +```python import tensorflow as tf img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3)) @@ -191,9 +203,15 @@ For detailed instructions on how to use the Tensorflow Optimizing Converter, ple You may refer to the [Ops compatibility guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md) for troubleshooting help. If that doesn't help, please file an [issue](https://github.com/tensorflow/tensorflow/issues). +If you would like to see a visual description of your TensorFlow Lite model after conversion, you can use tensorflow/contrib/lite/tools/visualize.py by running +```sh +bazel run tensorflow/contrib/lite/tools:visualize -- model.tflite model_viz.html +``` +and then visualize the resulting HTML file in a browser. + ## Step 3. Use the TensorFlow Lite model for inference in a mobile app -After completion of Step 2 the developer should have a .lite model. +After completion of Step 2 the developer should have a .tflite model. 
### For Android Because Android apps need to be written in Java, and core TensorFlow is in C++, a JNI library is provided to interface between the two. Its interface is aimed only at inference, so it provides the ability to load a graph, set up inputs, and run the model to calculate particular outputs. The full documentation for the set of methods can be seen [here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/). The demo app is also open sourced on [github](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/app). @@ -204,3 +222,7 @@ Note that you'd need to follow instructions for installing TensorFlow on Android ### For iOS Follow the documentation [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/ios.md) to get integrate a TFLite model into your app. + +## Core ML support + +Core ML is a machine learning framework used across Apple products. In addition to using Tensorflow Lite models directly in their applications, developers have the option to convert their trained Tensorflow models to the [CoreML](https://developer.apple.com/machine-learning/) format for use on Apple devices. For information on how to use the converter please refer to the [Tensorflow-CoreML converter documentation](https://github.com/tf-coreml/tf-coreml). diff --git a/tensorflow/contrib/lite/allocation.h b/tensorflow/contrib/lite/allocation.h index ee8a7ccd0b232f9e48095567fd4aefe94f595bc3..68aee2e64473320c461ec8b3f194904e7b8da43c 100644 --- a/tensorflow/contrib/lite/allocation.h +++ b/tensorflow/contrib/lite/allocation.h @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ // Main abstraction controlling the tflite interpreter. // See context.h for the API for defining operations (TfLiteRegistration). 
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ALLOCATION_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ALLOCATION_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_ALLOCATION_H_ +#define TENSORFLOW_CONTRIB_LITE_ALLOCATION_H_ #include #include @@ -91,4 +91,4 @@ class MemoryAllocation : public Allocation { } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ALLOCATION_H_ +#endif // TENSORFLOW_CONTRIB_LITE_ALLOCATION_H_ diff --git a/tensorflow/contrib/lite/arena_planner.cc b/tensorflow/contrib/lite/arena_planner.cc new file mode 100644 index 0000000000000000000000000000000000000000..87b17c338e7afc33d32dd9688cc0825ac319dd19 --- /dev/null +++ b/tensorflow/contrib/lite/arena_planner.cc @@ -0,0 +1,251 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/arena_planner.h" + +namespace tflite { + +namespace { + +// Memory allocation tuning +constexpr const int kDefaultArenaAlignment = 64; +constexpr const int kDefaultTensorAlignment = 4; + +} // namespace + +struct AllocationInfo { + // The node index requesting this allocation. + int node; + // The tensor index to be allocated or deallocated. 
+ int tensor; + // Whether to allocate or deallocate + enum { ALLOC, DEALLOC } type; +}; + +ArenaPlanner::ArenaPlanner(TfLiteContext* context, + std::unique_ptr graph_info) + : context_(context), + graph_info_(std::move(graph_info)), + arena_(kDefaultArenaAlignment), + persistent_arena_(kDefaultArenaAlignment) {} + +ArenaPlanner::~ArenaPlanner() {} + +int64_t ArenaPlanner::BasePointer(TfLiteAllocationType type) { + if (type == kTfLiteArenaRwPersistent) { + return persistent_arena_.BasePointer(); + } + if (type == kTfLiteArenaRw) { + return arena_.BasePointer(); + } + return 0; +} + +TfLiteStatus ArenaPlanner::ResetAllocations() { + TF_LITE_ENSURE_STATUS(arena_.Clear()); + TF_LITE_ENSURE_STATUS(persistent_arena_.Clear()); + allocs_.clear(); + allocs_.resize(graph_info_->num_tensors()); + return kTfLiteOk; +} + +TfLiteStatus ArenaPlanner::PlanAllocations() { + // Invalidate any existing data. + TF_LITE_ENSURE_STATUS(ResetAllocations()); + + // Keeps track of references to each tensor. + std::vector refcounts(graph_info_->num_tensors(), 0); + + // There will be an entry in alloc_queue_ for the allocation of each tensor + // and another for their deallocation. + alloc_queue_.reserve(2 * graph_info_->num_tensors()); + + // We must make sure the output tensors are never overwritten. We do that by + // artificially adding one to their ref-counts so they are never selected + // for deallocation. + for (int tensor_index : graph_info_->outputs()) { + refcounts[tensor_index]++; + } + + // Count references to node input tensors. + for (int i = 0; i < graph_info_->num_nodes(); ++i) { + const TfLiteNode& node = graph_info_->node(i); + TfLiteIntArray* node_inputs = node.inputs; + for (int j = 0; j < node_inputs->size; ++j) { + int tensor_index = node_inputs->data[j]; + if (tensor_index != kOptionalTensor) { + refcounts[tensor_index]++; + } + } + } + + // Queue all graph inputs for allocation. 
+ for (int tensor_index : graph_info_->inputs()) { + if (tensor_index != kOptionalTensor) { + alloc_queue_.push_back({0, tensor_index, AllocationInfo::ALLOC}); + } + } + + // Go through the graph in execution order. + for (int i = 0; i < graph_info_->num_nodes(); ++i) { + const TfLiteNode& node = graph_info_->node(i); + + // First queue output tensors for allocation. + TfLiteIntArray* node_outputs = node.outputs; + for (int j = 0; j < node_outputs->size; ++j) { + int tensor_index = node_outputs->data[j]; + alloc_queue_.push_back({i, tensor_index, AllocationInfo::ALLOC}); + } + + // Then update the ref-counts of the node's inputs, and if necessary queue + // them for deallocation. + TfLiteIntArray* node_inputs = node.inputs; + for (int j = 0; j < node_inputs->size; ++j) { + int tensor_index = node_inputs->data[j]; + if (tensor_index != kOptionalTensor) { + refcounts[tensor_index]--; + if (refcounts[tensor_index] == 0) { + alloc_queue_.push_back({i, tensor_index, AllocationInfo::DEALLOC}); + } + } + } + } + + // Note that graph outputs will never be scheduled for deallocation. We + // could do that here for completeness, but it won't have any effect. + return kTfLiteOk; +} + +TfLiteStatus ArenaPlanner::ExecuteAllocations(int first_node, int last_node) { + TF_LITE_ENSURE_STATUS(CalculateAllocations(first_node, last_node)); + TF_LITE_ENSURE_STATUS(Commit()); + + for (int i = 0; i < graph_info_->num_tensors(); ++i) { + // TODO(ahentz): we could do this only for the tensors that were modified + // in CalculateAllocations(), instead of redoing it for tensors that + // already had proper pointers. However we must be very careful, because + // SimpleMemoryArena::Commit() could move the base pointer. 
+ TF_LITE_ENSURE_STATUS(ResolveTensorAllocation(i)); + } + + return kTfLiteOk; +} + +TfLiteStatus ArenaPlanner::Commit() { + TF_LITE_ENSURE_STATUS(arena_.Commit(context_)); + TF_LITE_ENSURE_STATUS(persistent_arena_.Commit(context_)); + return kTfLiteOk; +} + +TfLiteStatus ArenaPlanner::CalculateAllocations(int first_node, int last_node) { + int active_node = first_node; + // When dynamic tensors are present this method is called multiple times. + // The items in the alloc_queue_ referring to nodes before first_node were + // processed previously and should be skipped. Entries after last_node are + // not yet ready to be handled. + for (const auto& alloc_info : alloc_queue_) { + if (alloc_info.node < first_node) continue; + if (alloc_info.node > last_node) break; + if (alloc_info.node == active_node) { + // This is the first allocation/deallocation for a given node. It is + // time to deallocate the previous temporaries and allocate new ones. + if (active_node != first_node) { + TF_LITE_ENSURE_STATUS( + CalculateDeallocationOfInternalTensors(active_node - 1)); + } + TF_LITE_ENSURE_STATUS(CalculateAllocationOfInternalTensors(active_node)); + ++active_node; + } + // Handle the current item. + if (alloc_info.type == AllocationInfo::ALLOC) { + TF_LITE_ENSURE_STATUS(CalculateTensorAllocation(alloc_info.tensor)); + } else { + TF_LITE_ENSURE_STATUS(CalculateTensorDeallocation(alloc_info.tensor)); + } + } + + // Don't forget to deallocate temporaries of last node. + TF_LITE_ENSURE_STATUS( + CalculateDeallocationOfInternalTensors(active_node - 1)); + + return kTfLiteOk; +} + +TfLiteStatus ArenaPlanner::ResolveTensorAllocation(int tensor_index) { + TfLiteTensor& tensor = *graph_info_->tensor(tensor_index); + if (tensor.allocation_type == kTfLiteArenaRw) { + // Skip resolution if the size of the tensor is zero, leaving it as a + // nullptr. 
+ if (allocs_[tensor_index].size != 0) { + TF_LITE_ENSURE_STATUS(arena_.ResolveAlloc(context_, allocs_[tensor_index], + &tensor.data.raw)); + } + } + if (tensor.allocation_type == kTfLiteArenaRwPersistent) { + TF_LITE_ENSURE_STATUS(persistent_arena_.ResolveAlloc( + context_, allocs_[tensor_index], &tensor.data.raw)); + } + return kTfLiteOk; +} + +TfLiteStatus ArenaPlanner::CalculateTensorAllocation(int tensor_index) { + TfLiteTensor& tensor = *graph_info_->tensor(tensor_index); + if (tensor.allocation_type == kTfLiteArenaRw) { + TF_LITE_ENSURE_STATUS(arena_.Allocate(context_, kDefaultTensorAlignment, + tensor.bytes, + &allocs_[tensor_index])); + } + if (tensor.allocation_type == kTfLiteArenaRwPersistent) { + TF_LITE_ENSURE_STATUS( + persistent_arena_.Allocate(context_, kDefaultTensorAlignment, + tensor.bytes, &allocs_[tensor_index])); + } + return kTfLiteOk; +} + +TfLiteStatus ArenaPlanner::CalculateTensorDeallocation(int tensor_index) { + TfLiteTensor& tensor = *graph_info_->tensor(tensor_index); + if (tensor.allocation_type == kTfLiteArenaRw) { + TF_LITE_ENSURE_STATUS(arena_.Deallocate(context_, allocs_[tensor_index])); + } + return kTfLiteOk; +} + +TfLiteStatus ArenaPlanner::CalculateAllocationOfInternalTensors( + int node_index) { + if (node_index < graph_info_->num_nodes()) { + const TfLiteNode& node = graph_info_->node(node_index); + TfLiteIntArray* node_temporaries = node.temporaries; + for (int i = 0; i < node_temporaries->size; ++i) { + int tensor_index = node_temporaries->data[i]; + TF_LITE_ENSURE_STATUS(CalculateTensorAllocation(tensor_index)); + } + } + return kTfLiteOk; +} + +TfLiteStatus ArenaPlanner::CalculateDeallocationOfInternalTensors( + int node_index) { + if (node_index < graph_info_->num_nodes()) { + const TfLiteNode& node = graph_info_->node(node_index); + TfLiteIntArray* node_temporaries = node.temporaries; + for (int i = 0; i < node_temporaries->size; ++i) { + int tensor_index = node_temporaries->data[i]; + 
TF_LITE_ENSURE_STATUS(CalculateTensorDeallocation(tensor_index)); + } + } + return kTfLiteOk; +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/arena_planner.h b/tensorflow/contrib/lite/arena_planner.h new file mode 100644 index 0000000000000000000000000000000000000000..58bc164619c2c053b9492e9a0e5de2da30e199af --- /dev/null +++ b/tensorflow/contrib/lite/arena_planner.h @@ -0,0 +1,107 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_ARENA_PLANNER_H_ +#define TENSORFLOW_CONTRIB_LITE_ARENA_PLANNER_H_ + +#include +#include + +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/graph_info.h" +#include "tensorflow/contrib/lite/memory_planner.h" +#include "tensorflow/contrib/lite/simple_memory_arena.h" + +namespace tflite { + +class AllocationInfo; + +// A memory planner that makes all the allocations using arenas. +// +// Before a model is executed by the interpreter, this class determines when +// each tensor needs to be allocated and deallocated, and preallocates all the +// necessary memory (the PlanAllocations phase). It then assigns portions of +// this memory buffer to each tensor (the ExecuteAllocations phase). Tensors may +// share some of the buffer if a tensor B is to be allocated after another tensor +// A has been deallocated.
+// +// If dynamic tensors are used the planning steps can be repeated during model +// execution. Since dynamic tensors don't have sizes until after the +// corresponding operation is executed, this class supports incremental +// planning. +class ArenaPlanner : public MemoryPlanner { + public: + // Ownership of 'context' is not taken and it must remain valid until the + // ArenaPlanner is destroyed. + ArenaPlanner(TfLiteContext* context, std::unique_ptr graph_info); + ~ArenaPlanner() override; + ArenaPlanner(const ArenaPlanner&) = delete; + ArenaPlanner& operator=(const ArenaPlanner&) = delete; + + TfLiteStatus ResetAllocations() override; + TfLiteStatus PlanAllocations() override; + TfLiteStatus ExecuteAllocations(int first_node, int last_node) override; + + // Returns the base arena location for a given allocation type. + int64_t BasePointer(TfLiteAllocationType type); + + private: + // Make sure all the arenas have reserved enough memory to store all their + // tensors. + TfLiteStatus Commit(); + + // Traverse the allocation queue and reserve space in the appropriate arena + // for all tensors affected by ops in the interval [first_node, last_node]. + TfLiteStatus CalculateAllocations(int first_node, int last_node); + + // Assign absolute memory location to a tensor, based on its relative + // position inside the corresponding arena buffer. + TfLiteStatus ResolveTensorAllocation(int tensor_index); + + // Register an allocation for the given tensor. + TfLiteStatus CalculateTensorAllocation(int tensor_index); + + // Register a deallocation for the given tensor. + TfLiteStatus CalculateTensorDeallocation(int tensor_index); + + // Register an allocation for all internal (temporary) tensors of + // 'node_index'. + TfLiteStatus CalculateAllocationOfInternalTensors(int node_index); + + // Register a deallocation for all internal (temporary) tensors of + // 'node_index'.
+ TfLiteStatus CalculateDeallocationOfInternalTensors(int node_index); + + TfLiteContext* context_; + std::unique_ptr graph_info_; + + // Stores allocation data for all tensors. + std::vector allocs_; + + // A chronological list of instructions to allocate and deallocate tensors, + // reflecting the way they are used in the graph. + std::vector alloc_queue_; + + // Raw memory buffer that is allocated for all temporary and graph outputs + // that are declared kTfLiteArenaRw. + SimpleMemoryArena arena_; + + // Raw memory buffer that is allocated for persistent tensors that are + // declared as kTfLiteArenaRwPersistent. + SimpleMemoryArena persistent_arena_; +}; + +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_ARENA_PLANNER_H_ diff --git a/tensorflow/contrib/lite/arena_planner_test.cc b/tensorflow/contrib/lite/arena_planner_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a8a8755e2c9e81474f2ff9cd2b85c0eb3d5c3441 --- /dev/null +++ b/tensorflow/contrib/lite/arena_planner_test.cc @@ -0,0 +1,468 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ +#include "tensorflow/contrib/lite/arena_planner.h" + +#include + +#include +#include +#include "tensorflow/contrib/lite/testing/util.h" +#include "tensorflow/core/platform/logging.h" + +namespace tflite { +namespace { + +// A simple op to be used in tests, as syntactic sugar. +class TestOp { + public: + TestOp(std::initializer_list inputs, std::initializer_list outputs, + std::initializer_list temporaries) + : inputs_(inputs), outputs_(outputs), temporaries_(temporaries) {} + + const std::vector& inputs() const { return inputs_; } + const std::vector& outputs() const { return outputs_; } + const std::vector& temporaries() const { return temporaries_; } + + private: + std::vector inputs_; + std::vector outputs_; + std::vector temporaries_; +}; + +// A test graph where inputs are processed by the given nodes to produce +// outputs. +class TestGraph { + public: + TestGraph(std::initializer_list inputs, + std::initializer_list nodes, + std::initializer_list outputs) + : inputs_(inputs), outputs_(outputs) { + int max_tensor_index = 0; + + for (int t : inputs) { + max_tensor_index = std::max(max_tensor_index, t); + } + for (int t : outputs) { + max_tensor_index = std::max(max_tensor_index, t); + } + for (const auto& node : nodes) { + auto int_array = [](const std::vector& x) { + TfLiteIntArray* lite = TfLiteIntArrayCreate(x.size()); + for (size_t i = 0; i < x.size(); i++) lite->data[i] = x[i]; + return lite; + }; + + nodes_.push_back(TfLiteNode()); + nodes_.back().inputs = int_array(node.inputs()); + for (int t : node.inputs()) { + max_tensor_index = std::max(max_tensor_index, t); + } + nodes_.back().outputs = int_array(node.outputs()); + for (int t : node.outputs()) { + max_tensor_index = std::max(max_tensor_index, t); + } + nodes_.back().temporaries = int_array(node.temporaries()); + for (int t : node.temporaries()) { + max_tensor_index = std::max(max_tensor_index, t); + } + } + + for 
(int i = 0; i <= max_tensor_index; ++i) { + tensors_.push_back(TfLiteTensor()); + // Set some default values for allocation_type and bytes, which are the + // only fields used by the arena planner. + tensors_.back().allocation_type = kTfLiteArenaRw; + tensors_.back().bytes = (i + 1) * 3; + } + } + + ~TestGraph() { + for (auto node : nodes_) { + TfLiteIntArrayFree(node.inputs); + TfLiteIntArrayFree(node.outputs); + TfLiteIntArrayFree(node.temporaries); + } + } + + const std::vector& nodes() { return nodes_; } + std::vector* tensors() { return &tensors_; } + const std::vector& inputs() { return inputs_; } + const std::vector& outputs() { return outputs_; } + + private: + std::vector nodes_; + std::vector tensors_; + std::vector inputs_; + std::vector outputs_; +}; + +// The GraphInfo for a TestGraph. +class TestGraphInfo : public GraphInfo { + public: + explicit TestGraphInfo(TestGraph* graph) : graph_(graph) {} + + size_t num_tensors() const override { return graph_->tensors()->size(); } + TfLiteTensor* tensor(size_t index) override { + return &graph_->tensors()->at(index); + } + size_t num_nodes() const override { return graph_->nodes().size(); } + const TfLiteNode& node(size_t index) const override { + return graph_->nodes()[index]; + } + const std::vector& inputs() const override { return graph_->inputs(); } + const std::vector& outputs() const override { return graph_->outputs(); } + + private: + TestGraph* graph_; +}; + +void ReportError(TfLiteContext* context, const char* format, ...) 
{ + const size_t kBufferSize = 1024; + char temp_buffer[kBufferSize]; + + va_list args; + va_start(args, format); + vsnprintf(temp_buffer, kBufferSize, format, args); + va_end(args); + + LOG(INFO) << temp_buffer; +} + +class ArenaPlannerTest : public ::testing::Test { + protected: + void SetGraph(TestGraph* graph) { + graph_ = graph; + context_.ReportError = ReportError; + planner_.reset(new ArenaPlanner( + &context_, std::unique_ptr(new TestGraphInfo(graph)))); + CHECK(planner_->ResetAllocations() == kTfLiteOk); + CHECK(planner_->PlanAllocations() == kTfLiteOk); + } + + void Execute(int start, int end) { + CHECK(planner_->ExecuteAllocations(start, end) == kTfLiteOk); + } + + // Returns the actual offset of a given tensor, relative to the start of its + // arena. + int64_t GetOffset(int tensor_index) { + const TfLiteTensor& tensor = (*graph_->tensors())[tensor_index]; + return reinterpret_cast(tensor.data.raw) - + planner_->BasePointer(tensor.allocation_type); + } + + // Returns the first aligned offset after a given tensor. + int64_t GetOffsetAfter(int tensor_index) { + const TfLiteTensor& tensor = (*graph_->tensors())[tensor_index]; + int64_t offset = GetOffset(tensor_index) + tensor.bytes; + // We must make sure the offset is aligned to kDefaultTensorAlignment. + if (offset % 4 != 0) { + offset += 4 - offset % 4; + } + return offset; + }; + + TfLiteContext context_; + TestGraph* graph_; + std::unique_ptr planner_; +}; + +TEST_F(ArenaPlannerTest, EmptyGraph) { + TestGraph graph({}, {}, {}); + SetGraph(&graph); + Execute(0, 10); +} + +TEST_F(ArenaPlannerTest, GraphWithNoOps) { + TestGraph graph({0, 10}, {}, {5, 11}); + SetGraph(&graph); + Execute(0, 10); + EXPECT_EQ(GetOffset(0), 0); + EXPECT_EQ(GetOffset(10), GetOffsetAfter(0)); + // The outputs are never allocated because they are not connected to any + // inputs.
+ EXPECT_TRUE((*graph.tensors())[5].data.raw == nullptr); + EXPECT_TRUE((*graph.tensors())[11].data.raw == nullptr); +} + +TEST_F(ArenaPlannerTest, GraphWithOneOp) { + TestGraph graph({1}, {{{1}, {2}, {}}}, {2}); + SetGraph(&graph); + Execute(0, 10); + EXPECT_EQ(GetOffset(1), 0); + EXPECT_EQ(GetOffset(2), GetOffsetAfter(1)); +} + +TEST_F(ArenaPlannerTest, ZeroSizedTensors) { + TestGraph graph({1}, {{{1}, {2}, {}}}, {2}); + (*graph.tensors())[1].bytes = 0; + SetGraph(&graph); + // TODO(ahentz): this is currently broken because the arena finds two + // allocations with the same offset and returns an error. + ASSERT_FALSE(planner_->ExecuteAllocations(0, 10) == kTfLiteOk); + // EXPECT_EQ(GetOffset(1), 0); + // EXPECT_EQ(GetOffset(2), GetOffsetAfter(1)); +} + +TEST_F(ArenaPlannerTest, SimpleGraph) { + TestGraph graph({0, 1}, + { + /* in, out, tmp */ + {{0, 1}, {2}, {}}, // First op + {{2, 0}, {4, 5}, {}}, // Second op + {{4, 5}, {3}, {}} // Third op + }, + {3}); + SetGraph(&graph); + Execute(0, 10); + + // Alloc(+) and dealloc(-) order: +0 +1 +2 -1 +4 +5 -2 -0 +3 -4 -5 + EXPECT_EQ(GetOffset(0), 0); + EXPECT_EQ(GetOffset(1), GetOffsetAfter(0)); + EXPECT_EQ(GetOffset(2), GetOffsetAfter(1)); + EXPECT_EQ(GetOffset(4), GetOffsetAfter(2)); + EXPECT_EQ(GetOffset(5), GetOffsetAfter(4)); + EXPECT_EQ(GetOffset(3), 0); +} + +TEST_F(ArenaPlannerTest, SimpleGraphWithTemporary) { + TestGraph graph({0, 1}, + { + /* in, out, tmp */ + {{0, 1}, {2}, {}}, // First op + {{2, 0}, {4}, {5}}, // Second op, with temporary + {{4}, {3}, {}} // Third op + }, + {3}); + SetGraph(&graph); + Execute(0, 10); + + // Alloc(+) and dealloc(-) order: +0 +1 +2 -1 +5 +4 -2 -0 -5 +3 -4 + EXPECT_EQ(GetOffset(0), 0); + EXPECT_EQ(GetOffset(1), GetOffsetAfter(0)); + EXPECT_EQ(GetOffset(2), GetOffsetAfter(1)); + EXPECT_EQ(GetOffset(5), GetOffsetAfter(2)); + EXPECT_EQ(GetOffset(4), GetOffsetAfter(5)); + EXPECT_EQ(GetOffset(3), 0); +} + +TEST_F(ArenaPlannerTest, SimpleGraphWithOptionals) { + TestGraph graph({0, -1, 
1}, + { + /* in, out, tmp */ + {{0, 1}, {2}, {}}, // First op + {{2, 0}, {4, 5}, {}}, // Second op + {{4, -1, 5}, {3}, {}} // Third op, with optional + }, + {3}); + SetGraph(&graph); + Execute(0, 10); + + // Alloc(+) and dealloc(-) order: +0 +1 +2 -1 +4 +5 -2 -0 +3 -4 -5 + EXPECT_EQ(GetOffset(0), 0); + EXPECT_EQ(GetOffset(1), GetOffsetAfter(0)); + EXPECT_EQ(GetOffset(2), GetOffsetAfter(1)); + EXPECT_EQ(GetOffset(4), GetOffsetAfter(2)); + EXPECT_EQ(GetOffset(5), GetOffsetAfter(4)); + EXPECT_EQ(GetOffset(3), 0); +} + +TEST_F(ArenaPlannerTest, SimpleGraphWithLargeTensor) { + TestGraph graph({0, -1, 1}, + { + /* in, out, tmp */ + {{0, 1}, {2}, {}}, // First op + {{2, 0}, {4}, {5}}, // Second op, with temporary + {{4, -1}, {3}, {}} // Third op, with optional + }, + {3}); + + // Make #1 very large so its vacancy can be filled with #5 and #4. + (*graph.tensors())[1].bytes = 40; + + SetGraph(&graph); + Execute(0, 10); + + // Alloc(+) and dealloc(-) order: +0 +1 +2 -1 +5 +4 -2 -0 -5 +3 -4 + EXPECT_EQ(GetOffset(0), 0); + EXPECT_EQ(GetOffset(1), GetOffsetAfter(0)); + EXPECT_EQ(GetOffset(2), GetOffsetAfter(1)); + EXPECT_EQ(GetOffset(5), GetOffsetAfter(0)); + EXPECT_EQ(GetOffset(4), GetOffsetAfter(5)); + EXPECT_EQ(GetOffset(3), 0); +} + +TEST_F(ArenaPlannerTest, SimpleGraphWithPersistentTensor) { + TestGraph graph({0, -1, 1}, + { + /* in, out, tmp */ + {{0, 1}, {2}, {}}, // First op + {{2, 0}, {4}, {5}}, // Second op, with temporary + {{4, -1}, {3}, {}} // Third op, with optional + }, + {3}); + + // Make #1 persistent so it goes into its own arena. + (*graph.tensors())[1].allocation_type = kTfLiteArenaRwPersistent; + + SetGraph(&graph); + Execute(0, 10); + + // Make sure #0 and #1 were given different memory locations (because they + // will both have offset=0, in different arenas.) 
+ EXPECT_NE((*graph.tensors())[0].data.raw, (*graph.tensors())[1].data.raw); + + // Alloc(+) and dealloc(-) order: +0 +1 +2 -1 +5 +4 -2 -0 -5 +3 -4 + EXPECT_EQ(GetOffset(0), 0); + EXPECT_EQ(GetOffset(1), 0); + EXPECT_EQ(GetOffset(2), GetOffsetAfter(0)); + EXPECT_EQ(GetOffset(5), GetOffsetAfter(2)); + EXPECT_EQ(GetOffset(4), GetOffsetAfter(5)); + EXPECT_EQ(GetOffset(3), 0); +} + +TEST_F(ArenaPlannerTest, SimpleGraphWithDynamicTensor) { + TestGraph graph({0, -1, 1}, + { + /* in, out, tmp */ + {{0, 1}, {2}, {}}, // First op + {{2, 0}, {4}, {5}}, // Second op, with temporary + {{4, -1}, {3}, {}} // Third op, with optional + }, + {3}); + + // Make #1 dynamic so it does not get allocated. + (*graph.tensors())[1].allocation_type = kTfLiteDynamic; + + SetGraph(&graph); + Execute(0, 10); + + EXPECT_EQ((*graph.tensors())[1].data.raw, nullptr); + + // Alloc(+) and dealloc(-) order: +0 +1 +2 -1 +5 +4 -2 -0 -5 +3 -4 + EXPECT_EQ(GetOffset(0), 0); + EXPECT_EQ(GetOffset(2), GetOffsetAfter(0)); + EXPECT_EQ(GetOffset(5), GetOffsetAfter(2)); + EXPECT_EQ(GetOffset(4), GetOffsetAfter(5)); + EXPECT_EQ(GetOffset(3), 0); +} + +TEST_F(ArenaPlannerTest, LargerGraphAndStepwiseAllocation) { + TestGraph graph({0, 1}, + { + /* in, out, tmp */ + {{0, 1}, {2, 3}, {}}, + {{2, 0}, {4, 5}, {6}}, + {{1, -1}, {7}, {}}, + {{7, 3}, {8}, {9}}, + {{4, 5, 8}, {10}, {}}, + }, + {10}); + SetGraph(&graph); + + auto is_unallocated = [&](int tensor_index) { + return (*graph.tensors())[tensor_index].data.raw == nullptr; + }; + + // The allocation plan is made at the beginning and is independent of + // the execution steps.
Here's the allocation order: + // Op0: +0 +1 +2 +3 + // Op1: +6 +4 +5 -6 -0 -2 + // Op2: +7 -1 + // Op3: +9 +8 -9 -3 -7 + // Op4: +10 -4 -5 -8 + + Execute(0, 0); + EXPECT_EQ(GetOffset(0), 0); + EXPECT_EQ(GetOffset(1), GetOffsetAfter(0)); + EXPECT_EQ(GetOffset(2), GetOffsetAfter(1)); + EXPECT_EQ(GetOffset(3), GetOffsetAfter(2)); + EXPECT_TRUE(is_unallocated(6)); + EXPECT_TRUE(is_unallocated(4)); + EXPECT_TRUE(is_unallocated(5)); + EXPECT_TRUE(is_unallocated(7)); + EXPECT_TRUE(is_unallocated(9)); + EXPECT_TRUE(is_unallocated(8)); + EXPECT_TRUE(is_unallocated(10)); + + Execute(1, 1); + EXPECT_EQ(GetOffset(0), 0); + EXPECT_EQ(GetOffset(1), GetOffsetAfter(0)); + EXPECT_EQ(GetOffset(2), GetOffsetAfter(1)); + EXPECT_EQ(GetOffset(3), GetOffsetAfter(2)); + EXPECT_EQ(GetOffset(6), GetOffsetAfter(3)); + EXPECT_EQ(GetOffset(4), GetOffsetAfter(6)); + EXPECT_EQ(GetOffset(5), GetOffsetAfter(4)); + EXPECT_TRUE(is_unallocated(7)); + EXPECT_TRUE(is_unallocated(9)); + EXPECT_TRUE(is_unallocated(8)); + EXPECT_TRUE(is_unallocated(10)); + + Execute(2, 2); + EXPECT_EQ(GetOffset(0), 0); + EXPECT_EQ(GetOffset(1), GetOffsetAfter(0)); + EXPECT_EQ(GetOffset(2), GetOffsetAfter(1)); + EXPECT_EQ(GetOffset(3), GetOffsetAfter(2)); + EXPECT_EQ(GetOffset(6), GetOffsetAfter(3)); + EXPECT_EQ(GetOffset(4), GetOffsetAfter(6)); + EXPECT_EQ(GetOffset(5), GetOffsetAfter(4)); + // Here's an interesting allocation. Even though #6 requires only 21 bytes, + // its deallocation freed up 24 bytes due to the alignment requirements in + // the arena. That means we can fit #7 in the same space! 
+ EXPECT_EQ(GetOffset(7), GetOffsetAfter(3)); + EXPECT_TRUE(is_unallocated(9)); + EXPECT_TRUE(is_unallocated(8)); + EXPECT_TRUE(is_unallocated(10)); + + Execute(3, 3); + EXPECT_EQ(GetOffset(0), 0); + EXPECT_EQ(GetOffset(1), GetOffsetAfter(0)); + EXPECT_EQ(GetOffset(2), GetOffsetAfter(1)); + EXPECT_EQ(GetOffset(3), GetOffsetAfter(2)); + EXPECT_EQ(GetOffset(6), GetOffsetAfter(3)); + EXPECT_EQ(GetOffset(4), GetOffsetAfter(6)); + EXPECT_EQ(GetOffset(5), GetOffsetAfter(4)); + EXPECT_EQ(GetOffset(7), GetOffsetAfter(3)); + // The deallocation of #0, #1 and #2 freed up 24 bytes but that's not enough + // for #9, so it goes at the end. + EXPECT_EQ(GetOffset(9), GetOffsetAfter(5)); + EXPECT_EQ(GetOffset(8), GetOffsetAfter(9)); + EXPECT_TRUE(is_unallocated(10)); + + Execute(4, 4); + EXPECT_EQ(GetOffset(0), 0); + EXPECT_EQ(GetOffset(1), GetOffsetAfter(0)); + EXPECT_EQ(GetOffset(2), GetOffsetAfter(1)); + EXPECT_EQ(GetOffset(3), GetOffsetAfter(2)); + EXPECT_EQ(GetOffset(6), GetOffsetAfter(3)); + EXPECT_EQ(GetOffset(4), GetOffsetAfter(6)); + EXPECT_EQ(GetOffset(5), GetOffsetAfter(4)); + EXPECT_EQ(GetOffset(7), GetOffsetAfter(3)); + EXPECT_EQ(GetOffset(9), GetOffsetAfter(5)); + EXPECT_EQ(GetOffset(8), GetOffsetAfter(9)); + // There's just enough space at the beginning for #10 due to the + // deallocation of #0, #1, #2 and #3 (total 36 bytes, #10 needs + // only 33.) 
+ EXPECT_EQ(GetOffset(10), 0); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl index e3c9cdd99beb93e356c148298dcbe6498fbe0306..19829e4991651111e13fc1805f97daef8bc016a7 100644 --- a/tensorflow/contrib/lite/build_def.bzl +++ b/tensorflow/contrib/lite/build_def.bzl @@ -5,25 +5,25 @@ def tflite_copts(): copts = [ "-DFARMHASH_NO_CXX_STRING", ] + select({ - "//tensorflow:android_arm64": [ + str(Label("//tensorflow:android_arm64")): [ "-std=c++11", "-O3", ], - "//tensorflow:android_arm": [ + str(Label("//tensorflow:android_arm")): [ "-mfpu=neon", "-mfloat-abi=softfp", "-std=c++11", "-O3", ], - "//tensorflow:android_x86": [ + str(Label("//tensorflow:android_x86")): [ "-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK", ], - "//tensorflow:ios_x86_64": [ + str(Label("//tensorflow:ios_x86_64")): [ "-msse4.1", ], "//conditions:default": [], }) + select({ - "//tensorflow:with_default_optimizations": [], + str(Label("//tensorflow:with_default_optimizations")): [], "//conditions:default": ["-DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK"], }) @@ -89,11 +89,11 @@ def tflite_jni_linkopts(): return tflite_jni_linkopts_unstripped() + select({ "//tensorflow:android": [ "-s", # Omit symbol table. + "-latomic", # Required for some uses of ISO C++11 in x86. 
], "//conditions:default": [], }) - def tflite_jni_binary(name, copts=tflite_copts(), linkopts=tflite_jni_linkopts(), @@ -223,11 +223,12 @@ def gen_selected_ops(name, model): """ out = name + "_registration.cc" tool = "//tensorflow/contrib/lite/tools:generate_op_registrations" + tflite_path = "//tensorflow/contrib/lite" native.genrule( name = name, srcs = [model], outs = [out], - cmd = ("$(location %s) --input_model=$(location %s) --output_registration=$(location %s)") - % (tool, model, out), + cmd = ("$(location %s) --input_model=$(location %s) --output_registration=$(location %s) --tflite_path=%s") + % (tool, model, out, tflite_path[2:]), tools = [tool], ) diff --git a/tensorflow/contrib/lite/build_ios_universal_lib.sh b/tensorflow/contrib/lite/build_ios_universal_lib.sh index e0f2ef768bfed544ed8acd6c0e3a5823e61a1e8c..4a9023ff33de15dd384531d51e39de4ffeecdb8b 100755 --- a/tensorflow/contrib/lite/build_ios_universal_lib.sh +++ b/tensorflow/contrib/lite/build_ios_universal_lib.sh @@ -1,5 +1,24 @@ #!/bin/bash -x +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$SCRIPT_DIR/../../.." 
+ make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=x86_64 -j 8 make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=i386 -j 8 make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=armv7 -j 8 diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h index 93072bf90bd8a18d9011a74c2eec95d86dbdce8a..5dbeadd16582ec586adab100b8a46e10182bd5ee 100644 --- a/tensorflow/contrib/lite/builtin_op_data.h +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_ +#define TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_ #include @@ -83,7 +83,14 @@ typedef struct { TfLiteFusedActivation activation; } TfLiteRNNParams; -typedef struct { TfLiteFusedActivation activation; } TfLiteFullyConnectedParams; +typedef struct { + bool time_major; + TfLiteFusedActivation activation; +} TfLiteSequenceRNNParams; + +typedef struct { + TfLiteFusedActivation activation; +} TfLiteFullyConnectedParams; typedef enum { kTfLiteLshProjectionUnknown = 0, @@ -91,9 +98,13 @@ typedef enum { kTfLiteLshProjectionDense = 2, } TfLiteLSHProjectionType; -typedef struct { TfLiteLSHProjectionType type; } TfLiteLSHProjectionParams; +typedef struct { + TfLiteLSHProjectionType type; +} TfLiteLSHProjectionParams; -typedef struct { float beta; } TfLiteSoftmaxParams; +typedef struct { + float beta; +} TfLiteSoftmaxParams; typedef struct { int axis; @@ -104,10 +115,24 @@ typedef struct { TfLiteFusedActivation activation; } TfLiteAddParams; +typedef struct { +} TfLiteSpaceToBatchNDParams; + +typedef struct { +} 
TfLiteBatchToSpaceNDParams; + typedef struct { TfLiteFusedActivation activation; } TfLiteMulParams; +typedef struct { + TfLiteFusedActivation activation; +} TfLiteSubParams; + +typedef struct { + TfLiteFusedActivation activation; +} TfLiteDivParams; + typedef struct { TfLiteFusedActivation activation; } TfLiteL2NormParams; @@ -126,10 +151,12 @@ typedef struct { } TfLiteLSTMParams; typedef struct { - int new_height; - int new_width; + bool align_corners; } TfLiteResizeBilinearParams; +typedef struct { +} TfLitePadParams; + typedef struct { // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. // For now we will fix the maximum possible number of dimensions. @@ -157,8 +184,34 @@ typedef struct { TfLiteCombinerType combiner; } TfLiteEmbeddingLookupSparseParams; +typedef struct { + int axis; +} TfLiteGatherParams; + +typedef struct { +} TfLiteTransposeParams; + +typedef struct { + bool keep_dims; +} TfLiteMeanParams; + +typedef struct { + // TODO(ahentz): We can't have dynamic data in this struct, at least not yet. + // For now we will fix the maximum possible number of dimensions. + int squeeze_dims[8]; + int num_squeeze_dims; +} TfLiteSqueezeParams; + +typedef struct { + int begin_mask; + int end_mask; + int ellipsis_mask; + int new_axis_mask; + int shrink_axis_mask; +} TfLiteStridedSliceParams; + #ifdef __cplusplus } // extern "C" #endif // __cplusplus -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_ +#endif // TENSORFLOW_CONTRIB_LITE_BUILTIN_OP_DATA_H_ diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h index 41257a53b145cbe7e252c9d4de6ea7ef654431b5..b0c4d3431f9a67bc87d51ada91ed73f1661023a2 100644 --- a/tensorflow/contrib/lite/context.h +++ b/tensorflow/contrib/lite/context.h @@ -26,8 +26,8 @@ limitations under the License. // TfLiteRegistration - the implementation of a conceptual operation. // // Some abstractions in this file are created and managed by Interpreter. 
-#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ +#define TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ #include #include @@ -38,6 +38,9 @@ extern "C" { typedef enum { kTfLiteOk = 0, kTfLiteError = 1 } TfLiteStatus; +// Forward declare so GetNode can use this is in Context. +typedef struct _TfLiteRegistration TfLiteRegistration; + #define kOptionalTensor (-1) // Fixed size list of integers. Used for dimensions and inputs/outputs tensor @@ -141,6 +144,7 @@ typedef struct { // A union of points that points to memory for a given tensor. typedef union { int* i32; + int64_t* i64; float* f; char* raw; const char* raw_const; @@ -204,9 +208,56 @@ void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims, // Resize the allocated data of a (dynamic) tensor. void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor); +// A structure representing an instance of a node. +// This structure only exhibits the inputs, outputs and user defined data, not +// other features like the type. +typedef struct { + // Inputs to this node expressed as indices into the simulator's tensors. + TfLiteIntArray* inputs; + + // Outputs to this node expressed as indices into the simulator's tensors. + TfLiteIntArray* outputs; + + // Temporary tensors uses during the computations. This usually contains no + // tensors, but ops are allowed to change that if they need scratch space of + // any sort. + TfLiteIntArray* temporaries; + + // Opaque data provided by the node implementer through `Registration.init`. + void* user_data; + + // Opaque data provided to the node if the node is a builtin. This is usually + // a structure defined in builtin_op_data.h + void* builtin_data; + + // Custom initial data. This is the opaque data provided in the flatbuffer. + // WARNING: This is an experimental interface that is subject to change. 
+ const void* custom_initial_data; + int custom_initial_data_size; +} TfLiteNode; + typedef struct TfLiteContext { // Number of tensors in the context. int tensors_size; + + // The execution plan contains a list of the node indices in execution + // order. execution_plan->size is the current number of nodes. And, + // execution_plan->data[0] is the first node that needs to be run. + // TfLiteDelegates can traverse the current execution plan by iterating + // through each member of this array and using GetNodeAndRegistration() to + // access details about a node. i.e. + // TfLiteIntArray* execution_plan; + // TF_LITE_ENSURE_STATUS(context->GetExecutionPlan(context, &execution_plan)); + // for (int exec_index = 0; exec_index < execution_plan->size; exec_index++) { + // int node_index = execution_plan->data[exec_index]; + // TfLiteNode* node; + // TfLiteRegistration* reg; + // context->GetNodeAndRegistration(context, node_index, &node, &reg); + // } + // WARNING: This is an experimental interface that is subject to change. + TfLiteStatus (*GetExecutionPlan)(struct TfLiteContext* context, + TfLiteIntArray** execution_plan); + // An tensor of tensors in the interpreter context (of length `tensors_size`) TfLiteTensor* tensors; @@ -226,34 +277,23 @@ typedef struct TfLiteContext { TfLiteStatus (*AddTensors)(struct TfLiteContext*, int tensors_to_add, int* first_new_tensor_index); + // Get a Tensor node by node_index. + // WARNING: This is an experimental interface that is subject to change. + TfLiteStatus (*GetNodeAndRegistration)(struct TfLiteContext*, int node_index, + TfLiteNode** node, + TfLiteRegistration** registration); + + // Replace ops with delegate. + TfLiteStatus (*ReplaceSubgraphsWithDelegateKernels)( + struct TfLiteContext*, TfLiteRegistration registration, + const TfLiteIntArray* nodes_to_replace); + // TODO(ahentz): we should create a more general mechanism for this sort of // library-global objects.
void* gemm_context; } TfLiteContext; -// A structure representing an instance of a node. -// This structure only exhibits the inputs, outputs and user defined data, not -// other features like the type. -typedef struct { - // Inputs to this node expressed as indices into the simulator's tensors. - TfLiteIntArray* inputs; - - // Outputs to this node expressed as indices into the simulator's tensors. - TfLiteIntArray* outputs; - - // Temporary tensors uses during the computations. This usually contains no - // tensors, but ops are allowed to change that if they need scratch space of - // any sort. - TfLiteIntArray* temporaries; - - // Opaque data provided by the node implementer through `Registration.init`. - void* user_data; - - // Opaque data provided to the node if the node is a builtin. - void* builtin_data; -} TfLiteNode; - -typedef struct { +typedef struct _TfLiteRegistration { // Initializes the op from serialized data. // If a built-in op: // `buffer` is the op's params data (TfLiteLSTMParams*). @@ -290,9 +330,27 @@ typedef struct { // NN API. Note, it is the responsibility of the registration binder to // set this properly. int32_t builtin_code; + + // Custom op name. If the op is a builtin, this will be null. + // WARNING: This is an experimental interface that is subject to change. + const char* custom_name; } TfLiteRegistration; +// WARNING: This is an experimental interface that is subject to change. +typedef struct { + // Data that delegate needs to identify itself. This data is owned by the + // delegate. The delegate is owned in the user code, so the delegate is + // responsible for doing this when it is destroyed. + void* data_; + // Invoked by ModifyGraphWithDelegate. This prepare is called, giving the + // delegate a view of the current graph through TfLiteContext*. 
It typically + // will look at the nodes and call ReplaceSubgraphsWithDelegateKernels() + // to ask the TensorFlow lite runtime to create macro-nodes to represent + // delegated subgraphs of the original graph. + TfLiteStatus (*Prepare)(TfLiteContext* context, void* data); +} TfLiteDelegate; + #ifdef __cplusplus } // extern "C" #endif // __cplusplus -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ +#endif // TENSORFLOW_CONTRIB_LITE_CONTEXT_H_ diff --git a/tensorflow/contrib/lite/context_test.cc b/tensorflow/contrib/lite/context_test.cc index d0a104f43d9b9d148d80ce26b8ecf732d51ef110..20d6f69a25e9f0bb4323cf5d067b8ebd37bb3c23 100644 --- a/tensorflow/contrib/lite/context_test.cc +++ b/tensorflow/contrib/lite/context_test.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/contrib/lite/context.h" #include +#include "tensorflow/contrib/lite/testing/util.h" namespace tflite { @@ -68,7 +69,7 @@ TEST(IntArray, TestIntArrayEqual) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/download_dependencies.sh b/tensorflow/contrib/lite/download_dependencies.sh index 571d857be7292998996a4fb8101f0070064aa6be..a93ed201d647ddf2359a57254a959871c13fb94f 100755 --- a/tensorflow/contrib/lite/download_dependencies.sh +++ b/tensorflow/contrib/lite/download_dependencies.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,6 +16,9 @@ set -e +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$SCRIPT_DIR/../../.." + DOWNLOADS_DIR=tensorflow/contrib/lite/downloads BZL_FILE_PATH=tensorflow/workspace.bzl @@ -26,15 +29,13 @@ if [ ! 
-f $BZL_FILE_PATH ]; then exit 1; fi -EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)" +EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v mirror.bazel | head -n1)" GEMMLOWP_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz" ABSL_URL="$(grep -o 'https://github.com/abseil/abseil-cpp/.*tar.gz' "${BZL_FILE_PATH}" | head -n1)" NEON_2_SSE_URL="https://github.com/intel/ARM_NEON_2_x86_SSE/archive/master.zip" FARMHASH_URL="https://mirror.bazel.build/github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz" FLATBUFFERS_URL="https://github.com/google/flatbuffers/archive/master.zip" -MODELS_URL="https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_1.0_224_ios_lite_float_2017_11_08.zip" -QUANTIZED_MODELS_URL="https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip" # TODO(petewarden): Some new code in Eigen triggers a clang bug with iOS arm64, # so work around it by patching the source. 
@@ -90,8 +91,6 @@ download_and_extract "${ABSL_URL}" "${DOWNLOADS_DIR}/absl" download_and_extract "${NEON_2_SSE_URL}" "${DOWNLOADS_DIR}/neon_2_sse" download_and_extract "${FARMHASH_URL}" "${DOWNLOADS_DIR}/farmhash" download_and_extract "${FLATBUFFERS_URL}" "${DOWNLOADS_DIR}/flatbuffers" -download_and_extract "${MODELS_URL}" "${DOWNLOADS_DIR}/models" -download_and_extract "${QUANTIZED_MODELS_URL}" "${DOWNLOADS_DIR}/quantized_models" replace_by_sed 's#static uint32x4_t p4ui_CONJ_XOR = vld1q_u32( conj_XOR_DATA );#static uint32x4_t p4ui_CONJ_XOR; // = vld1q_u32( conj_XOR_DATA ); - Removed by script#' \ "${DOWNLOADS_DIR}/eigen/Eigen/src/Core/arch/NEON/Complex.h" @@ -100,7 +99,4 @@ replace_by_sed 's#static uint32x2_t p2ui_CONJ_XOR = vld1_u32( conj_XOR_DATA );#s replace_by_sed 's#static uint64x2_t p2ul_CONJ_XOR = vld1q_u64( p2ul_conj_XOR_DATA );#static uint64x2_t p2ul_CONJ_XOR;// = vld1q_u64( p2ul_conj_XOR_DATA ); - Removed by script#' \ "${DOWNLOADS_DIR}/eigen/Eigen/src/Core/arch/NEON/Complex.h" -cp ${DOWNLOADS_DIR}/models/models/* tensorflow/contrib/lite/examples/ios/simple/data/ -cp ${DOWNLOADS_DIR}/quantized_models/* tensorflow/contrib/lite/examples/ios/camera/data/ - echo "download_dependencies.sh completed successfully." >&2 diff --git a/tensorflow/contrib/lite/error_reporter.cc b/tensorflow/contrib/lite/error_reporter.cc index 6ba5384a94dbf9de03fb2e4e2f63074525eafa2d..03fcd5409ceab1895cea3b9e0e4fcb5a127e6a45 100644 --- a/tensorflow/contrib/lite/error_reporter.cc +++ b/tensorflow/contrib/lite/error_reporter.cc @@ -39,7 +39,9 @@ int ErrorReporter::ReportError(void*, const char* format, ...) 
{ } int StderrReporter::Report(const char* format, va_list args) { - return vfprintf(stderr, format, args); + const int result = vfprintf(stderr, format, args); + fputc('\n', stderr); + return result; } ErrorReporter* DefaultErrorReporter() { diff --git a/tensorflow/contrib/lite/error_reporter.h b/tensorflow/contrib/lite/error_reporter.h index 637d456ce7a754c7da34e551869e49b4efd18e3b..da193d2586e9123341b9a41be049ee2a4382017a 100644 --- a/tensorflow/contrib/lite/error_reporter.h +++ b/tensorflow/contrib/lite/error_reporter.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_ +#define TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_ #include #include "tensorflow/contrib/lite/context.h" @@ -25,10 +25,10 @@ namespace tflite { // // Usage: // ErrorReporter foo; -// foo.Report("test %d\n", 5); +// foo.Report("test %d", 5); // or // va_list args; -// foo.Report("test %d\n", args); // where args is va_list +// foo.Report("test %d", args); // where args is va_list // // Sublclass ErrorReporter to provide another reporting destination. 
// For example, if you have a GUI program, you might redirect to a buffer @@ -51,4 +51,4 @@ ErrorReporter* DefaultErrorReporter(); } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_ +#endif // TENSORFLOW_CONTRIB_LITE_ERROR_REPORTER_H_ diff --git a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm index ea398ad14e8be4c5a0021befc7cc076549b47e23..d74e275f0439b1ce56b29e0eadff5f211f6a4faa 100644 --- a/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm +++ b/tensorflow/contrib/lite/examples/ios/camera/CameraExampleViewController.mm @@ -123,7 +123,11 @@ static void GetTopN(const uint8_t* prediction, const int prediction_size, const AVCaptureDevice* device = [AVCaptureDevice defaultDeviceWithMediaType:AVMediaTypeVideo]; AVCaptureDeviceInput* deviceInput = [AVCaptureDeviceInput deviceInputWithDevice:device error:&error]; - assert(error == nil); + + if (error != nil) { + NSLog(@"Failed to initialize AVCaptureDeviceInput. 
Note: This app doesn't work with simulator"); + assert(NO); + } if ([session canAddInput:deviceInput]) [session addInput:deviceInput]; @@ -221,14 +225,8 @@ static void GetTopN(const uint8_t* prediction, const int prediction_size, const assert(pixelBuffer != NULL); OSType sourcePixelFormat = CVPixelBufferGetPixelFormatType(pixelBuffer); - int doReverseChannels; - if (kCVPixelFormatType_32ARGB == sourcePixelFormat) { - doReverseChannels = 1; - } else if (kCVPixelFormatType_32BGRA == sourcePixelFormat) { - doReverseChannels = 0; - } else { - assert(false); // Unknown source format - } + assert(sourcePixelFormat == kCVPixelFormatType_32ARGB || + sourcePixelFormat == kCVPixelFormatType_32BGRA); const int sourceRowBytes = (int)CVPixelBufferGetBytesPerRow(pixelBuffer); const int image_width = (int)CVPixelBufferGetWidth(pixelBuffer); diff --git a/tensorflow/contrib/lite/examples/ios/camera/Podfile b/tensorflow/contrib/lite/examples/ios/camera/Podfile index 4ae6fb6b94e4489f63506b05a2f348b7daafd3b7..c7d3b1c966eaa0de71f5c37a6a77b3881e30ddd7 100644 --- a/tensorflow/contrib/lite/examples/ios/camera/Podfile +++ b/tensorflow/contrib/lite/examples/ios/camera/Podfile @@ -2,4 +2,4 @@ platform :ios, '8.0' inhibit_all_warnings! 
target 'tflite_camera_example' - pod 'TensorFlow-experimental' + pod 'TensorFlowLite' diff --git a/tensorflow/contrib/lite/examples/ios/camera/tflite_camera_example.xcodeproj/project.pbxproj b/tensorflow/contrib/lite/examples/ios/camera/tflite_camera_example.xcodeproj/project.pbxproj index c98183276bd60d2a0ad023ba26aad12572a02786..b0236e9c608ec35437bcfe79c51149a76f9f416e 100644 --- a/tensorflow/contrib/lite/examples/ios/camera/tflite_camera_example.xcodeproj/project.pbxproj +++ b/tensorflow/contrib/lite/examples/ios/camera/tflite_camera_example.xcodeproj/project.pbxproj @@ -16,7 +16,6 @@ 1CDB2D4E1ED3AA35007929E9 /* Info.plist in Resources */ = {isa = PBXBuildFile; fileRef = 1CDB2D4D1ED3AA35007929E9 /* Info.plist */; }; 54DC6C3C5F734F3A58069F0C /* libPods-tflite_camera_example.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 3BA8BF92C84895BFE59D8236 /* libPods-tflite_camera_example.a */; }; AC1F82661FBA3CBD0052BA77 /* labels.txt in Resources */ = {isa = PBXBuildFile; fileRef = AC1F82641FBA3CBD0052BA77 /* labels.txt */; }; - AC1F82691FBA3F930052BA77 /* libtensorflow-lite.a in Frameworks */ = {isa = PBXBuildFile; fileRef = AC1F82681FBA3F930052BA77 /* libtensorflow-lite.a */; }; ACA1A4CA1FBB6C28009B8D86 /* mobilenet_quant_v1_224.tflite in Resources */ = {isa = PBXBuildFile; fileRef = ACA1A4C91FBB6C28009B8D86 /* mobilenet_quant_v1_224.tflite */; }; /* End PBXBuildFile section */ @@ -38,7 +37,6 @@ 3BC5BE4BBD09374D3E98F082 /* Pods-tflite_camera_example.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-tflite_camera_example.debug.xcconfig"; path = "Pods/Target Support Files/Pods-tflite_camera_example/Pods-tflite_camera_example.debug.xcconfig"; sourceTree = ""; }; 55ED318E8D29C8AFEF03DF1E /* Pods-tflite_camera_example.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-tflite_camera_example.release.xcconfig"; path = "Pods/Target Support 
Files/Pods-tflite_camera_example/Pods-tflite_camera_example.release.xcconfig"; sourceTree = ""; }; AC1F82641FBA3CBD0052BA77 /* labels.txt */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = labels.txt; sourceTree = ""; }; - AC1F82681FBA3F930052BA77 /* libtensorflow-lite.a */ = {isa = PBXFileReference; lastKnownFileType = archive.ar; name = "libtensorflow-lite.a"; path = "../../../gen/lib/libtensorflow-lite.a"; sourceTree = ""; }; ACA1A4C91FBB6C28009B8D86 /* mobilenet_quant_v1_224.tflite */ = {isa = PBXFileReference; lastKnownFileType = file; path = mobilenet_quant_v1_224.tflite; sourceTree = ""; }; /* End PBXFileReference section */ @@ -47,7 +45,6 @@ isa = PBXFrameworksBuildPhase; buildActionMask = 2147483647; files = ( - AC1F82691FBA3F930052BA77 /* libtensorflow-lite.a in Frameworks */, 1CB47D491ED3AD1700DF7666 /* AVFoundation.framework in Frameworks */, 1CA5EB931ED3ABFB00247A34 /* CoreMedia.framework in Frameworks */, 54DC6C3C5F734F3A58069F0C /* libPods-tflite_camera_example.a in Frameworks */, @@ -60,7 +57,6 @@ 24D7686C331131624F4454A0 /* Frameworks */ = { isa = PBXGroup; children = ( - AC1F82681FBA3F930052BA77 /* libtensorflow-lite.a */, 1CB47D481ED3AD1700DF7666 /* AVFoundation.framework */, 1CA5EB921ED3ABFB00247A34 /* CoreMedia.framework */, 1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */, @@ -336,7 +332,6 @@ ../../../downloads/, ); IPHONEOS_DEPLOYMENT_TARGET = 8.0; - LIBRARY_SEARCH_PATHS = ../../../gen/lib/; MTL_ENABLE_DEBUG_INFO = YES; ONLY_ACTIVE_ARCH = YES; SDKROOT = iphoneos; @@ -384,7 +379,6 @@ ../../../downloads/, ); IPHONEOS_DEPLOYMENT_TARGET = 8.0; - LIBRARY_SEARCH_PATHS = ../../../gen/lib/; MTL_ENABLE_DEBUG_INFO = NO; SDKROOT = iphoneos; TARGETED_DEVICE_FAMILY = "1,2"; diff --git a/tensorflow/contrib/lite/examples/ios/download_models.sh b/tensorflow/contrib/lite/examples/ios/download_models.sh new file mode 100755 index 0000000000000000000000000000000000000000..ccd163758c5830dc9367e023dcb3a604e07ca5db --- /dev/null 
+++ b/tensorflow/contrib/lite/examples/ios/download_models.sh @@ -0,0 +1,57 @@ +#!/bin/bash +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +set -ex + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +MODELS_URL="https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_1.0_224_ios_lite_float_2017_11_08.zip" +QUANTIZED_MODELS_URL="https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip" +DOWNLOADS_DIR=$(mktemp -d) + +cd $SCRIPT_DIR + +download_and_extract() { + local usage="Usage: download_and_extract URL DIR" + local url="${1:?${usage}}" + local dir="${2:?${usage}}" + echo "downloading ${url}" >&2 + mkdir -p "${dir}" + tempdir=$(mktemp -d) + tempdir2=$(mktemp -d) + + curl -L ${url} > ${tempdir}/zipped.zip + unzip ${tempdir}/zipped.zip -d ${tempdir2} + + # If the zip file contains nested directories, extract the files from the + # inner directory. + if ls ${tempdir2}/*/* 1> /dev/null 2>&1; then + # unzip has no strip components, so unzip to a temp dir, and move the + # files we want from the tempdir to destination. 
+ cp -R ${tempdir2}/*/* ${dir}/ + else + cp -R ${tempdir2}/* ${dir}/ + fi + rm -rf ${tempdir2} ${tempdir} +} + +download_and_extract "${MODELS_URL}" "${DOWNLOADS_DIR}/models" +download_and_extract "${QUANTIZED_MODELS_URL}" "${DOWNLOADS_DIR}/quantized_models" + +file ${DOWNLOADS_DIR}/models + +cp ${DOWNLOADS_DIR}/models/models/* simple/data/ +cp ${DOWNLOADS_DIR}/quantized_models/* camera/data/ + diff --git a/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.h b/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.h index 75b1f1da384b527e8332dfba08fec87c65eff8b1..94046d9728258901091f018fd0d081651145f400 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.h +++ b/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.h @@ -14,8 +14,8 @@ #import -@interface AppDelegate : UIResponder +@interface AppDelegate : UIResponder -@property (strong, nonatomic) UIWindow *window; +@property(strong, nonatomic) UIWindow *window; @end diff --git a/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.mm b/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.mm index 1e808eb976ff3eeda4cf6f81b3c1794c6a037dc8..d1215fa0bffd978b4aaadbd8bc13b07723703c9a 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.mm +++ b/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.mm @@ -22,8 +22,7 @@ didFinishLaunchingWithOptions:(NSDictionary *)launchOptions { UITabBarController *bar = [[UITabBarController alloc] init]; - [bar setViewControllers: - @[[[RunModelViewController alloc] init]]]; + [bar setViewControllers:@[ [[RunModelViewController alloc] init] ]]; bar.selectedIndex = 0; self.window = [[UIWindow alloc] initWithFrame:[[UIScreen mainScreen] bounds]]; self.window.rootViewController = bar; @@ -31,14 +30,19 @@ return YES; } -- (void)applicationWillResignActive:(UIApplication *)application {} +- (void)applicationWillResignActive:(UIApplication *)application { +} -- (void)applicationDidEnterBackground:(UIApplication *)application {} +- 
(void)applicationDidEnterBackground:(UIApplication *)application { +} -- (void)applicationWillEnterForeground:(UIApplication *)application {} +- (void)applicationWillEnterForeground:(UIApplication *)application { +} -- (void)applicationDidBecomeActive:(UIApplication *)application {} +- (void)applicationDidBecomeActive:(UIApplication *)application { +} -- (void)applicationWillTerminate:(UIApplication *)application {} +- (void)applicationWillTerminate:(UIApplication *)application { +} @end diff --git a/tensorflow/contrib/lite/examples/ios/simple/Podfile b/tensorflow/contrib/lite/examples/ios/simple/Podfile index 1740ad64573a84fae6de0fcf284eb06afec67e25..e4aca2be82d437a0225d2c15d3e486b0344aa978 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/Podfile +++ b/tensorflow/contrib/lite/examples/ios/simple/Podfile @@ -1,5 +1,5 @@ platform :ios, '8.0' inhibit_all_warnings! -target 'tf_simple_example' - pod 'TensorFlow-experimental' +target 'tflite_simple_example' + pod 'TensorFlowLite' diff --git a/tensorflow/contrib/lite/examples/ios/simple/RunModel-Info.plist b/tensorflow/contrib/lite/examples/ios/simple/RunModel-Info.plist index 1a3eaa8a2c18d1cd24dfd475d396b00ec4d86c9d..a19a43a7541e3d751116e868dbcbdd607d15ab4a 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/RunModel-Info.plist +++ b/tensorflow/contrib/lite/examples/ios/simple/RunModel-Info.plist @@ -7,7 +7,7 @@ CFBundleDisplayName tflite-simple-example CFBundleExecutable - tf_simple_example + tflite_simple_example CFBundleIdentifier $(PRODUCT_BUNDLE_IDENTIFIER) CFBundleInfoDictionaryVersion diff --git a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.h b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.h index 4e1a83ccf5a12c609baadab7359c55ec4f464ed8..a4b358b4eb7f6ba109638405091b798d30bd1768 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.h +++ b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.h @@ -18,7 +18,7 @@ - 
(IBAction)getUrl:(id)sender; -@property (weak, nonatomic) IBOutlet UITextView *urlContentTextView; -@property (weak, nonatomic) IBOutlet UITextField *urlTextField; +@property(weak, nonatomic) IBOutlet UITextView *urlContentTextView; +@property(weak, nonatomic) IBOutlet UITextField *urlTextField; @end diff --git a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm index 965d83010516c6db72c9e8b1c33079b3eda204de..0ab7aa25d0b4e6d2c02e61ec1d82b85258b3dfbc 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm +++ b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm @@ -14,10 +14,10 @@ #import "RunModelViewController.h" -#include -#include #include #include +#include +#include #include #include #include @@ -29,9 +29,6 @@ #include "ios_image_load.h" -#define LOG(x) std::cerr -#define CHECK(x) if (!(x)) { LOG(ERROR) << #x << "failed"; exit(1); } - NSString* RunInferenceOnImage(); @interface RunModelViewController () @@ -49,15 +46,12 @@ NSString* RunInferenceOnImage(); // Returns the top N confidence values over threshold in the provided vector, // sorted by confidence in descending order. -static void GetTopN( - const float* prediction, - const int prediction_size, - const int num_results, const float threshold, - std::vector >* top_results) { +static void GetTopN(const float* prediction, const int prediction_size, const int num_results, + const float threshold, std::vector >* top_results) { // Will contain top N results in ascending order. 
- std::priority_queue, - std::vector >, - std::greater > > top_result_pq; + std::priority_queue, std::vector >, + std::greater > > + top_result_pq; const long count = prediction_size; for (int i = 0; i < count; ++i) { @@ -88,27 +82,29 @@ static void GetTopN( NSString* FilePathForResourceName(NSString* name, NSString* extension) { NSString* file_path = [[NSBundle mainBundle] pathForResource:name ofType:extension]; if (file_path == NULL) { - LOG(FATAL) << "Couldn't find '" << [name UTF8String] << "." - << [extension UTF8String] << "' in bundle."; + NSLog(@"Couldn't find '%@.%@' in bundle.", name, extension); + exit(-1); } return file_path; } NSString* RunInferenceOnImage() { - std::string graph; + NSString* graph = @"mobilenet_v1_1.0_224"; const int num_threads = 1; std::string input_layer_type = "float"; std::vector sizes = {1, 224, 224, 3}; - NSString* graph_path = FilePathForResourceName(@"mobilenet_v1_1.0_224", @"tflite"); + const NSString* graph_path = FilePathForResourceName(graph, @"tflite"); - std::unique_ptr model(tflite::FlatBufferModel::BuildFromFile([graph_path UTF8String])); + std::unique_ptr model( + tflite::FlatBufferModel::BuildFromFile([graph_path UTF8String])); if (!model) { - LOG(FATAL) << "Failed to mmap model " << graph; + NSLog(@"Failed to mmap model %@.", graph); + exit(-1); } - LOG(INFO) << "Loaded model " << graph; + NSLog(@"Loaded model %@.", graph); model->error_reporter(); - LOG(INFO) << "resolved reporter"; + NSLog(@"Resolved reporter."); #ifdef TFLITE_CUSTOM_OPS_HEADER tflite::MutableOpResolver resolver; @@ -120,7 +116,8 @@ NSString* RunInferenceOnImage() { std::unique_ptr interpreter; tflite::InterpreterBuilder(*model, resolver)(&interpreter); if (!interpreter) { - LOG(FATAL) << "Failed to construct interpreter"; + NSLog(@"Failed to construct interpreter."); + exit(-1); } if (num_threads != -1) { @@ -134,7 +131,8 @@ NSString* RunInferenceOnImage() { } if (interpreter->AllocateTensors() != kTfLiteOk) { - LOG(FATAL) << "Failed to allocate 
tensors!"; + NSLog(@"Failed to allocate tensors."); + exit(-1); } // Read the label list @@ -143,7 +141,7 @@ NSString* RunInferenceOnImage() { std::ifstream t; t.open([labels_path UTF8String]); std::string line; - while(t){ + while (t) { std::getline(t, line); label_strings.push_back(line); } @@ -154,7 +152,8 @@ NSString* RunInferenceOnImage() { int image_width; int image_height; int image_channels; - std::vector image_data = LoadImageFromFile([image_path UTF8String], &image_width, &image_height, &image_channels); + std::vector image_data = + LoadImageFromFile([image_path UTF8String], &image_width, &image_height, &image_channels); const int wanted_width = 224; const int wanted_height = 224; const int wanted_channels = 3; @@ -178,7 +177,8 @@ NSString* RunInferenceOnImage() { } if (interpreter->Invoke() != kTfLiteOk) { - LOG(FATAL) << "Failed to invoke!"; + NSLog(@"Failed to invoke!"); + exit(-1); } float* output = interpreter->typed_output_tensor(0); @@ -208,12 +208,9 @@ NSString* RunInferenceOnImage() { ss << "\n"; } - LOG(INFO) << "Predictions: " << ss.str(); - std::string predictions = ss.str(); NSString* result = @""; - result = [NSString stringWithFormat: @"%@ - %s", result, - predictions.c_str()]; - + result = [NSString stringWithFormat:@"%@ - %s", result, predictions.c_str()]; + NSLog(@"Predictions: %@", result); return result; } diff --git a/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h b/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h index 7287d0d63d5b4c0b9c9a528578b6341cdb9c9954..98934ce41d349b33d4fc010a39a956e52f3d5721 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h +++ b/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h @@ -17,9 +17,7 @@ #include -std::vector LoadImageFromFile(const char* file_name, - int* out_width, - int* out_height, - int* out_channels); +std::vector LoadImageFromFile(const char* file_name, int* out_width, + int* out_height, int* out_channels); #endif // 
TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_ diff --git a/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.mm b/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.mm index 789522d2a9900b136f91f77c4ada682f1a316848..cb0fe1a7650c572d3745066431f2759daa94ffc9 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.mm +++ b/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.mm @@ -14,17 +14,16 @@ #include "ios_image_load.h" -#include -#include #include #include +#include +#include #import #import -std::vector LoadImageFromFile(const char* file_name, - int* out_width, int* out_height, - int* out_channels) { +std::vector LoadImageFromFile(const char* file_name, int* out_width, int* out_height, + int* out_channels) { FILE* file_handle = fopen(file_name, "rb"); fseek(file_handle, 0, SEEK_END); const size_t bytes_in_file = ftell(file_handle); @@ -32,11 +31,10 @@ std::vector LoadImageFromFile(const char* file_name, std::vector file_data(bytes_in_file); fread(file_data.data(), 1, bytes_in_file, file_handle); fclose(file_handle); - CFDataRef file_data_ref = CFDataCreateWithBytesNoCopy(NULL, file_data.data(), - bytes_in_file, - kCFAllocatorNull); - CGDataProviderRef image_provider = - CGDataProviderCreateWithCFData(file_data_ref); + + CFDataRef file_data_ref = + CFDataCreateWithBytesNoCopy(NULL, file_data.data(), bytes_in_file, kCFAllocatorNull); + CGDataProviderRef image_provider = CGDataProviderCreateWithCFData(file_data_ref); const char* suffix = strrchr(file_name, '.'); if (!suffix || suffix == file_name) { @@ -44,12 +42,10 @@ std::vector LoadImageFromFile(const char* file_name, } CGImageRef image; if (strcasecmp(suffix, ".png") == 0) { - image = CGImageCreateWithPNGDataProvider(image_provider, NULL, true, - kCGRenderingIntentDefault); - } else if ((strcasecmp(suffix, ".jpg") == 0) || - (strcasecmp(suffix, ".jpeg") == 0)) { - image = CGImageCreateWithJPEGDataProvider(image_provider, NULL, true, - kCGRenderingIntentDefault); + image = 
CGImageCreateWithPNGDataProvider(image_provider, NULL, true, kCGRenderingIntentDefault); + } else if ((strcasecmp(suffix, ".jpg") == 0) || (strcasecmp(suffix, ".jpeg") == 0)) { + image = + CGImageCreateWithJPEGDataProvider(image_provider, NULL, true, kCGRenderingIntentDefault); } else { CFRelease(image_provider); CFRelease(file_data_ref); @@ -68,9 +64,10 @@ std::vector LoadImageFromFile(const char* file_name, const int bytes_in_image = (bytes_per_row * height); std::vector result(bytes_in_image); const int bits_per_component = 8; - CGContextRef context = CGBitmapContextCreate(result.data(), width, height, - bits_per_component, bytes_per_row, color_space, - kCGImageAlphaPremultipliedLast | kCGBitmapByteOrder32Big); + + CGContextRef context = + CGBitmapContextCreate(result.data(), width, height, bits_per_component, bytes_per_row, + color_space, kCGImageAlphaPremultipliedLast | kCGBitmapByteOrder32Big); CGColorSpaceRelease(color_space); CGContextDrawImage(context, CGRectMake(0, 0, width, height), image); CGContextRelease(context); diff --git a/tensorflow/contrib/lite/examples/ios/simple/main.mm b/tensorflow/contrib/lite/examples/ios/simple/main.mm index d70550a730720e5d6799a186c1beb3cfa04b0b9d..05cb55ddd7a230593863e64b351f6aac31a1b4d7 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/main.mm +++ b/tensorflow/contrib/lite/examples/ios/simple/main.mm @@ -14,7 +14,7 @@ #import -int main(int argc, char * argv[]) { +int main(int argc, char *argv[]) { @autoreleasepool { NSString *delegateClassName = @"AppDelegate"; return UIApplicationMain(argc, argv, nil, delegateClassName); diff --git a/tensorflow/contrib/lite/examples/ios/simple/simple.xcodeproj/project.pbxproj b/tensorflow/contrib/lite/examples/ios/simple/simple.xcodeproj/project.pbxproj index 9277c230b8cce1b5673a50d32d7640d52e2e8f9d..f5b8382d5ae4ac80a7edb52c34ebaf12ad65f4db 100644 --- a/tensorflow/contrib/lite/examples/ios/simple/simple.xcodeproj/project.pbxproj +++ 
b/tensorflow/contrib/lite/examples/ios/simple/simple.xcodeproj/project.pbxproj @@ -9,7 +9,7 @@ /* Begin PBXBuildFile section */ 1C0D734B1ECCC460008C1DAB /* CoreGraphics.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */; }; 1CA45FFF1ECCC356002FA6A4 /* UIKit.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 1CA45FFE1ECCC356002FA6A4 /* UIKit.framework */; }; - 594C14AE1FB8F9B500EE8BFE /* libtensorflow-lite.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 594C14AD1FB8F9B500EE8BFE /* libtensorflow-lite.a */; }; + 1E6F42DBB39A4A3871D4F848 /* libPods-tflite_simple_example.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 73DBC33C5DD9A526EE6D1EF2 /* libPods-tflite_simple_example.a */; }; 594C14B11FB9037100EE8BFE /* labels.txt in Resources */ = {isa = PBXBuildFile; fileRef = 594C14AF1FB9037100EE8BFE /* labels.txt */; }; 594C14B21FB9037100EE8BFE /* mobilenet_v1_1.0_224.tflite in Resources */ = {isa = PBXBuildFile; fileRef = 594C14B01FB9037100EE8BFE /* mobilenet_v1_1.0_224.tflite */; }; 59A3D0011CF4E68100C4259F /* AppDelegate.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFF21CF4E68100C4259F /* AppDelegate.mm */; }; @@ -24,8 +24,7 @@ 1C0D73481ECCC41B008C1DAB /* CoreImage.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreImage.framework; path = System/Library/Frameworks/CoreImage.framework; sourceTree = SDKROOT; }; 1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreGraphics.framework; path = System/Library/Frameworks/CoreGraphics.framework; sourceTree = SDKROOT; }; 1CA45FFE1ECCC356002FA6A4 /* UIKit.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = UIKit.framework; path = System/Library/Frameworks/UIKit.framework; sourceTree = SDKROOT; }; - 5911579B1CF4011C00C31E3A /* tf_simple_example.app */ = {isa = PBXFileReference; explicitFileType = 
wrapper.application; includeInIndex = 0; path = tf_simple_example.app; sourceTree = BUILT_PRODUCTS_DIR; }; - 594C14AD1FB8F9B500EE8BFE /* libtensorflow-lite.a */ = {isa = PBXFileReference; lastKnownFileType = archive.ar; name = "libtensorflow-lite.a"; path = "../../../gen/lib/libtensorflow-lite.a"; sourceTree = ""; }; + 5911579B1CF4011C00C31E3A /* tflite_simple_example.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = tflite_simple_example.app; sourceTree = BUILT_PRODUCTS_DIR; }; 594C14AF1FB9037100EE8BFE /* labels.txt */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = labels.txt; sourceTree = ""; }; 594C14B01FB9037100EE8BFE /* mobilenet_v1_1.0_224.tflite */ = {isa = PBXFileReference; lastKnownFileType = file; path = mobilenet_v1_1.0_224.tflite; sourceTree = ""; }; 59A3CFF11CF4E68100C4259F /* AppDelegate.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = AppDelegate.h; sourceTree = ""; }; @@ -38,7 +37,9 @@ 59A3CFFE1CF4E68100C4259F /* RunModelViewController.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = RunModelViewController.h; sourceTree = ""; }; 59A3CFFF1CF4E68100C4259F /* RunModelViewController.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = RunModelViewController.mm; sourceTree = ""; }; 59A3D0001CF4E68100C4259F /* RunModelViewController.xib */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = file.xib; path = RunModelViewController.xib; sourceTree = ""; }; - 73DBC33C5DD9A526EE6D1EF2 /* libPods-tf_simple_example.a */ = {isa = PBXFileReference; explicitFileType = archive.ar; includeInIndex = 0; path = "libPods-tf_simple_example.a"; sourceTree = BUILT_PRODUCTS_DIR; }; + 5D6203B9FAEEB9824194DBE8 /* Pods-tflite_simple_example.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = 
"Pods-tflite_simple_example.release.xcconfig"; path = "Pods/Target Support Files/Pods-tflite_simple_example/Pods-tflite_simple_example.release.xcconfig"; sourceTree = ""; }; + 73DBC33C5DD9A526EE6D1EF2 /* libPods-tflite_simple_example.a */ = {isa = PBXFileReference; explicitFileType = archive.ar; includeInIndex = 0; path = "libPods-tflite_simple_example.a"; sourceTree = BUILT_PRODUCTS_DIR; }; + 987DD5BCAB2DD8B682674E20 /* Pods-tflite_simple_example.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-tflite_simple_example.debug.xcconfig"; path = "Pods/Target Support Files/Pods-tflite_simple_example/Pods-tflite_simple_example.debug.xcconfig"; sourceTree = ""; }; /* End PBXFileReference section */ /* Begin PBXFrameworksBuildPhase section */ @@ -46,9 +47,9 @@ isa = PBXFrameworksBuildPhase; buildActionMask = 2147483647; files = ( - 594C14AE1FB8F9B500EE8BFE /* libtensorflow-lite.a in Frameworks */, 1C0D734B1ECCC460008C1DAB /* CoreGraphics.framework in Frameworks */, 1CA45FFF1ECCC356002FA6A4 /* UIKit.framework in Frameworks */, + 1E6F42DBB39A4A3871D4F848 /* libPods-tflite_simple_example.a in Frameworks */, ); runOnlyForDeploymentPostprocessing = 0; }; @@ -58,11 +59,10 @@ 24D7686C331131624F4454A0 /* Frameworks */ = { isa = PBXGroup; children = ( - 594C14AD1FB8F9B500EE8BFE /* libtensorflow-lite.a */, 1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */, 1C0D73481ECCC41B008C1DAB /* CoreImage.framework */, 1CA45FFE1ECCC356002FA6A4 /* UIKit.framework */, - 73DBC33C5DD9A526EE6D1EF2 /* libPods-tf_simple_example.a */, + 73DBC33C5DD9A526EE6D1EF2 /* libPods-tflite_simple_example.a */, ); name = Frameworks; sourceTree = ""; @@ -82,13 +82,14 @@ 59A3D0001CF4E68100C4259F /* RunModelViewController.xib */, 5911579C1CF4011C00C31E3A /* Products */, 24D7686C331131624F4454A0 /* Frameworks */, + 5CE7E4179B26BF77944D8637 /* Pods */, ); sourceTree = ""; }; 5911579C1CF4011C00C31E3A /* Products */ = { isa = PBXGroup; children = ( - 
5911579B1CF4011C00C31E3A /* tf_simple_example.app */, + 5911579B1CF4011C00C31E3A /* tflite_simple_example.app */, ); name = Products; sourceTree = ""; @@ -103,24 +104,36 @@ path = data; sourceTree = ""; }; + 5CE7E4179B26BF77944D8637 /* Pods */ = { + isa = PBXGroup; + children = ( + 987DD5BCAB2DD8B682674E20 /* Pods-tflite_simple_example.debug.xcconfig */, + 5D6203B9FAEEB9824194DBE8 /* Pods-tflite_simple_example.release.xcconfig */, + ); + name = Pods; + sourceTree = ""; + }; /* End PBXGroup section */ /* Begin PBXNativeTarget section */ - 5911579A1CF4011C00C31E3A /* tf_simple_example */ = { + 5911579A1CF4011C00C31E3A /* tflite_simple_example */ = { isa = PBXNativeTarget; - buildConfigurationList = 591157B21CF4011D00C31E3A /* Build configuration list for PBXNativeTarget "tf_simple_example" */; + buildConfigurationList = 591157B21CF4011D00C31E3A /* Build configuration list for PBXNativeTarget "tflite_simple_example" */; buildPhases = ( + A507411BCC70190B9ABD2721 /* [CP] Check Pods Manifest.lock */, 591157971CF4011C00C31E3A /* Sources */, 591157981CF4011C00C31E3A /* Frameworks */, 591157991CF4011C00C31E3A /* Resources */, + 25E1671BDC7334C678FB5DFB /* [CP] Embed Pods Frameworks */, + 10976C49D86B7F8A59157601 /* [CP] Copy Pods Resources */, ); buildRules = ( ); dependencies = ( ); - name = tf_simple_example; + name = tflite_simple_example; productName = tf_ios_makefile_example; - productReference = 5911579B1CF4011C00C31E3A /* tf_simple_example.app */; + productReference = 5911579B1CF4011C00C31E3A /* tflite_simple_example.app */; productType = "com.apple.product-type.application"; }; /* End PBXNativeTarget section */ @@ -152,7 +165,7 @@ projectDirPath = ""; projectRoot = ""; targets = ( - 5911579A1CF4011C00C31E3A /* tf_simple_example */, + 5911579A1CF4011C00C31E3A /* tflite_simple_example */, ); }; /* End PBXProject section */ @@ -171,6 +184,57 @@ }; /* End PBXResourcesBuildPhase section */ +/* Begin PBXShellScriptBuildPhase section */ + 10976C49D86B7F8A59157601 /* [CP] 
Copy Pods Resources */ = { + isa = PBXShellScriptBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + inputPaths = ( + ); + name = "[CP] Copy Pods Resources"; + outputPaths = ( + ); + runOnlyForDeploymentPostprocessing = 0; + shellPath = /bin/sh; + shellScript = "\"${SRCROOT}/Pods/Target Support Files/Pods-tflite_simple_example/Pods-tflite_simple_example-resources.sh\"\n"; + showEnvVarsInLog = 0; + }; + 25E1671BDC7334C678FB5DFB /* [CP] Embed Pods Frameworks */ = { + isa = PBXShellScriptBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + inputPaths = ( + ); + name = "[CP] Embed Pods Frameworks"; + outputPaths = ( + ); + runOnlyForDeploymentPostprocessing = 0; + shellPath = /bin/sh; + shellScript = "\"${SRCROOT}/Pods/Target Support Files/Pods-tflite_simple_example/Pods-tflite_simple_example-frameworks.sh\"\n"; + showEnvVarsInLog = 0; + }; + A507411BCC70190B9ABD2721 /* [CP] Check Pods Manifest.lock */ = { + isa = PBXShellScriptBuildPhase; + buildActionMask = 2147483647; + files = ( + ); + inputPaths = ( + "${PODS_PODFILE_DIR_PATH}/Podfile.lock", + "${PODS_ROOT}/Manifest.lock", + ); + name = "[CP] Check Pods Manifest.lock"; + outputPaths = ( + "$(DERIVED_FILE_DIR)/Pods-tflite_simple_example-checkManifestLockResult.txt", + ); + runOnlyForDeploymentPostprocessing = 0; + shellPath = /bin/sh; + shellScript = "diff \"${PODS_PODFILE_DIR_PATH}/Podfile.lock\" \"${PODS_ROOT}/Manifest.lock\" > /dev/null\nif [ $? != 0 ] ; then\n # print error to STDERR\n echo \"error: The sandbox is not in sync with the Podfile.lock. 
Run 'pod install' or update your CocoaPods installation.\" >&2\n exit 1\nfi\n# This output is used by Xcode 'outputs' to avoid re-running this script phase.\necho \"SUCCESS\" > \"${SCRIPT_OUTPUT_FILE_0}\"\n"; + showEnvVarsInLog = 0; + }; +/* End PBXShellScriptBuildPhase section */ + /* Begin PBXSourcesBuildPhase section */ 591157971CF4011C00C31E3A /* Sources */ = { isa = PBXSourcesBuildPhase; @@ -274,6 +338,7 @@ }; 591157B31CF4011D00C31E3A /* Debug */ = { isa = XCBuildConfiguration; + baseConfigurationReference = 987DD5BCAB2DD8B682674E20 /* Pods-tflite_simple_example.debug.xcconfig */; buildSettings = { CLANG_DEBUG_INFORMATION_LEVEL = default; CODE_SIGN_IDENTITY = "iPhone Developer"; @@ -283,15 +348,10 @@ GCC_ENABLE_CPP_RTTI = YES; HEADER_SEARCH_PATHS = ( "$(inherited)", - ../../../../../../, - ../../../downloads/flatbuffers/include/, - ../../../downloads/eigen/, - ../../../downloads/, ); INFOPLIST_FILE = "$(SRCROOT)/RunModel-Info.plist"; IPHONEOS_DEPLOYMENT_TARGET = 9.2; LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; - LIBRARY_SEARCH_PATHS = ../../../gen/lib/; OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)"; OTHER_LDFLAGS = "$(inherited)"; PRODUCT_BUNDLE_IDENTIFIER = "com.google.tflite-simple-example"; @@ -304,6 +364,7 @@ }; 591157B41CF4011D00C31E3A /* Release */ = { isa = XCBuildConfiguration; + baseConfigurationReference = 5D6203B9FAEEB9824194DBE8 /* Pods-tflite_simple_example.release.xcconfig */; buildSettings = { CLANG_DEBUG_INFORMATION_LEVEL = default; CODE_SIGN_IDENTITY = "iPhone Developer"; @@ -313,15 +374,10 @@ GCC_ENABLE_CPP_RTTI = YES; HEADER_SEARCH_PATHS = ( "$(inherited)", - ../../../../../../, - ../../../downloads/flatbuffers/include/, - ../../../downloads/eigen/, - ../../../downloads/, ); INFOPLIST_FILE = "$(SRCROOT)/RunModel-Info.plist"; IPHONEOS_DEPLOYMENT_TARGET = 9.2; LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; - LIBRARY_SEARCH_PATHS = ../../../gen/lib/; ONLY_ACTIVE_ARCH = YES; OTHER_CPLUSPLUSFLAGS = 
"$(OTHER_CFLAGS)"; OTHER_LDFLAGS = "$(inherited)"; @@ -344,7 +400,7 @@ defaultConfigurationIsVisible = 0; defaultConfigurationName = Release; }; - 591157B21CF4011D00C31E3A /* Build configuration list for PBXNativeTarget "tf_simple_example" */ = { + 591157B21CF4011D00C31E3A /* Build configuration list for PBXNativeTarget "tflite_simple_example" */ = { isa = XCConfigurationList; buildConfigurations = ( 591157B31CF4011D00C31E3A /* Debug */, diff --git a/tensorflow/contrib/lite/examples/label_image/BUILD b/tensorflow/contrib/lite/examples/label_image/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..959347b5491514ddc13af57ea6f7385a0d39e418 --- /dev/null +++ b/tensorflow/contrib/lite/examples/label_image/BUILD @@ -0,0 +1,83 @@ +# Description: +# TensorFlow Lite Example Label Image. + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "tf_cc_binary") +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_linkopts") + +exports_files(glob([ + "testdata/*.bmp", +])) + +tf_cc_binary( + name = "label_image", + srcs = [ + "get_top_n.h", + "get_top_n_impl.h", + "label_image.cc", + ], + linkopts = tflite_linkopts() + select({ + "//tensorflow:android": [ + "-pie", # Android 5.0 and later supports only PIE + "-lm", # some builtin ops, e.g., tanh, need -lm + ], + "//conditions:default": [], + }), + deps = [ + ":bitmap_helpers", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + ], +) + +cc_library( + name = "bitmap_helpers", + srcs = ["bitmap_helpers.cc"], + hdrs = [ + "bitmap_helpers.h", + "bitmap_helpers_impl.h", + "label_image.h", + ], + deps = [ + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:schema_fbs_version", + "//tensorflow/contrib/lite:string", + "//tensorflow/contrib/lite:string_util", + 
"//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/schema:schema_fbs", + ], +) + +# TODO(ahentz): Test disabled as it has a memory leek from read_bmp +# cc_test( +# name = "label_image_test", +# srcs = [ +# "get_top_n.h", +# "get_top_n_impl.h", +# "label_image_test.cc", +# ], +# data = [ +# "testdata/grace_hopper.bmp", +# ], +# deps = [ +# ":bitmap_helpers", +# "//testing/base/public:gunit", +# ], +# ) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.cc b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.cc new file mode 100644 index 0000000000000000000000000000000000000000..0b38cd38c83927c65d251b9356301b6bef7521f2 --- /dev/null +++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.cc @@ -0,0 +1,120 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include +#include + +#include // NOLINT(build/include_order) + +#include "tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h" + +#define LOG(x) std::cerr + +namespace tflite { +namespace label_image { + +uint8_t* decode_bmp(const uint8_t* input, int row_size, uint8_t* const output, + int width, int height, int channels, bool top_down) { + for (int i = 0; i < height; i++) { + int src_pos; + int dst_pos; + + for (int j = 0; j < width; j++) { + if (!top_down) { + src_pos = ((height - 1 - i) * row_size) + j * channels; + } else { + src_pos = i * row_size + j * channels; + } + + dst_pos = (i * width + j) * channels; + + switch (channels) { + case 1: + output[dst_pos] = input[src_pos]; + break; + case 3: + // BGR -> RGB + output[dst_pos] = input[src_pos + 2]; + output[dst_pos + 1] = input[src_pos + 1]; + output[dst_pos + 2] = input[src_pos]; + break; + case 4: + // BGRA -> RGBA + output[dst_pos] = input[src_pos + 2]; + output[dst_pos + 1] = input[src_pos + 1]; + output[dst_pos + 2] = input[src_pos]; + output[dst_pos + 3] = input[src_pos + 3]; + break; + default: + LOG(FATAL) << "Unexpected number of channels: " << channels; + break; + } + } + } + + return output; +} + +uint8_t* read_bmp(const std::string& input_bmp_name, int* width, int* height, + int* channels, Settings* s) { + int begin, end; + + std::ifstream file(input_bmp_name, std::ios::in | std::ios::binary); + if (!file) { + LOG(FATAL) << "input file " << input_bmp_name << " not found\n"; + exit(-1); + } + + begin = file.tellg(); + file.seekg(0, std::ios::end); + end = file.tellg(); + size_t len = end - begin; + + if (s->verbose) LOG(INFO) << "len: " << len << "\n"; + + const uint8_t* img_bytes = new uint8_t[len]; + file.seekg(0, std::ios::beg); + file.read((char*)img_bytes, len); + const int32_t header_size = + *(reinterpret_cast(img_bytes + 10)); + *width = *(reinterpret_cast(img_bytes 
+ 18)); + *height = *(reinterpret_cast(img_bytes + 22)); + const int32_t bpp = *(reinterpret_cast(img_bytes + 28)); + *channels = bpp / 8; + + if (s->verbose) + LOG(INFO) << "width, height, channels: " << *width << ", " << *height + << ", " << *channels << "\n"; + + // there may be padding bytes when the width is not a multiple of 4 bytes + // 8 * channels == bits per pixel + const int row_size = (8 * *channels * *width + 31) / 32 * 4; + + // if height is negative, data layout is top down + // otherwise, it's bottom up + bool top_down = (*height < 0); + + // Decode image, allocating tensor once the image size is known + uint8_t* output = new uint8_t[abs(*height) * *width * *channels]; + const uint8_t* bmp_pixels = &img_bytes[header_size]; + return decode_bmp(bmp_pixels, row_size, output, *width, abs(*height), + *channels, top_down); +} + +} // namespace label_image +} // namespace tflite diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h new file mode 100644 index 0000000000000000000000000000000000000000..97343dde6b31694e5b2de20b35a7083fb8fe4a0e --- /dev/null +++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h @@ -0,0 +1,42 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_H_ +#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_H_ + +#include "tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h" +#include "tensorflow/contrib/lite/examples/label_image/label_image.h" + +namespace tflite { +namespace label_image { + +uint8_t* read_bmp(const std::string& input_bmp_name, int* width, int* height, + int* channels, Settings* s); + +template +void resize(T* out, uint8_t* in, int image_height, int image_width, + int image_channels, int wanted_height, int wanted_width, + int wanted_channels, Settings* s); + +// explicit instantiation +template void resize(uint8_t*, unsigned char*, int, int, int, int, int, + int, Settings*); +template void resize(float*, unsigned char*, int, int, int, int, int, + int, Settings*); + +} // namespace label_image +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_H diff --git a/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..2a64c1de725b601e9b6e9325d9faacb37df0e626 --- /dev/null +++ b/tensorflow/contrib/lite/examples/label_image/bitmap_helpers_impl.h @@ -0,0 +1,103 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H_ +#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H_ + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/contrib/lite/version.h" + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/contrib/lite/version.h" + +#include "tensorflow/contrib/lite/examples/label_image/label_image.h" + +namespace tflite { +namespace label_image { + +template +void resize(T* out, uint8_t* in, int image_height, int image_width, + int image_channels, int wanted_height, int wanted_width, + int wanted_channels, Settings* s) { + int number_of_pixels = image_height * image_width * image_channels; + std::unique_ptr interpreter(new Interpreter); + + int base_index = 0; + + // two inputs: input and new_sizes + interpreter->AddTensors(2, &base_index); + // one output + interpreter->AddTensors(1, &base_index); + // set input and output tensors + interpreter->SetInputs({0, 1}); + interpreter->SetOutputs({2}); + + // set parameters of tensors + TfLiteQuantizationParams quant; + interpreter->SetTensorParametersReadWrite( + 0, kTfLiteFloat32, "input", + {1, image_height, image_width, image_channels}, quant); + interpreter->SetTensorParametersReadWrite(1, kTfLiteInt32, "new_size", {2}, + quant); + interpreter->SetTensorParametersReadWrite( + 2, kTfLiteFloat32, "output", + {1, wanted_height, wanted_width, wanted_channels}, quant); + + ops::builtin::BuiltinOpResolver 
resolver; + TfLiteRegistration* resize_op = + resolver.FindOp(BuiltinOperator_RESIZE_BILINEAR); + auto* params = reinterpret_cast( + malloc(sizeof(TfLiteResizeBilinearParams))); + params->align_corners = false; + interpreter->AddNodeWithParameters({0, 1}, {2}, nullptr, 0, params, resize_op, + nullptr); + + interpreter->AllocateTensors(); + + // fill input image + // in[] are integers, cannot do memcpy() directly + auto input = interpreter->typed_tensor(0); + for (int i = 0; i < number_of_pixels; i++) { + input[i] = in[i]; + } + + // fill new_sizes + interpreter->typed_tensor(1)[0] = wanted_height; + interpreter->typed_tensor(1)[1] = wanted_width; + + interpreter->Invoke(); + + auto output = interpreter->typed_tensor(2); + auto output_number_of_pixels = + wanted_height * wanted_height * wanted_channels; + + for (int i = 0; i < output_number_of_pixels; i++) { + if (s->input_floating) + out[i] = (output[i] - s->input_mean) / s->input_std; + else + out[i] = (uint8_t)output[i]; + } +} + +} // namespace label_image +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_BITMAP_HELPERS_IMPL_H_ diff --git a/tensorflow/contrib/lite/examples/label_image/get_top_n.h b/tensorflow/contrib/lite/examples/label_image/get_top_n.h new file mode 100644 index 0000000000000000000000000000000000000000..70a7586fe6a008f0da20a7bac928ca676e5914ab --- /dev/null +++ b/tensorflow/contrib/lite/examples/label_image/get_top_n.h @@ -0,0 +1,38 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_GET_TOP_N_H +#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_GET_TOP_N_H + +#include "tensorflow/contrib/lite/examples/label_image/get_top_n_impl.h" + +namespace tflite { +namespace label_image { + +template +void get_top_n(T* prediction, int prediction_size, size_t num_results, + float threshold, std::vector>* top_results, + bool input_floating); + +// explicit instantiation so that we can use them otherwhere +template void get_top_n(uint8_t*, int, size_t, float, + std::vector>*, bool); +template void get_top_n(float*, int, size_t, float, + std::vector>*, bool); + +} // namespace label_image +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_GET_TOP_N_H diff --git a/tensorflow/contrib/lite/examples/label_image/get_top_n_impl.h b/tensorflow/contrib/lite/examples/label_image/get_top_n_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..e416fbd39b125ea65d1155b19ab0967a9062e71a --- /dev/null +++ b/tensorflow/contrib/lite/examples/label_image/get_top_n_impl.h @@ -0,0 +1,70 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_GET_TOP_N_IMPL_H +#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_GET_TOP_N_IMPL_H + +#include +#include + +namespace tflite { +namespace label_image { + +extern bool input_floating; + +// Returns the top N confidence values over threshold in the provided vector, +// sorted by confidence in descending order. +template +void get_top_n(T* prediction, int prediction_size, size_t num_results, + float threshold, std::vector>* top_results, + bool input_floating) { + // Will contain top N results in ascending order. + std::priority_queue, std::vector>, + std::greater>> + top_result_pq; + + const long count = prediction_size; // NOLINT(runtime/int) + for (int i = 0; i < count; ++i) { + float value; + if (input_floating) + value = prediction[i]; + else + value = prediction[i] / 255.0; + // Only add it if it beats the threshold and has a chance at being in + // the top N. + if (value < threshold) { + continue; + } + + top_result_pq.push(std::pair(value, i)); + + // If at capacity, kick the smallest value out. + if (top_result_pq.size() > num_results) { + top_result_pq.pop(); + } + } + + // Copy to output vector and reverse into descending order. + while (!top_result_pq.empty()) { + top_results->push_back(top_result_pq.top()); + top_result_pq.pop(); + } + std::reverse(top_results->begin(), top_results->end()); +} + +} // namespace label_image +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_GET_TOP_N_IMPL_H diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.cc b/tensorflow/contrib/lite/examples/label_image/label_image.cc new file mode 100644 index 0000000000000000000000000000000000000000..a91467d345fdce1268635a69a96939921dc170e8 --- /dev/null +++ b/tensorflow/contrib/lite/examples/label_image/label_image.cc @@ -0,0 +1,308 @@ +/* Copyright 2017 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include // NOLINT(build/include_order) +#include // NOLINT(build/include_order) +#include // NOLINT(build/include_order) +#include // NOLINT(build/include_order) +#include // NOLINT(build/include_order) +#include // NOLINT(build/include_order) + +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/model.h" +#include "tensorflow/contrib/lite/optional_debug_tools.h" +#include "tensorflow/contrib/lite/string_util.h" + +#include "tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h" +#include "tensorflow/contrib/lite/examples/label_image/get_top_n.h" + +#define LOG(x) std::cerr + +namespace tflite { +namespace label_image { + +double get_us(struct timeval t) { return (t.tv_sec * 1000000 + t.tv_usec); } + +// Takes a file name, and loads a list of labels from it, one per line, and +// returns a vector of the strings. It pads with empty strings so the length +// of the result is a multiple of 16, because our model expects that. 
+TfLiteStatus ReadLabelsFile(const string& file_name, + std::vector* result, + size_t* found_label_count) { + std::ifstream file(file_name); + if (!file) { + LOG(FATAL) << "Labels file " << file_name << " not found\n"; + return kTfLiteError; + } + result->clear(); + string line; + while (std::getline(file, line)) { + result->push_back(line); + } + *found_label_count = result->size(); + const int padding = 16; + while (result->size() % padding) { + result->emplace_back(); + } + return kTfLiteOk; +} + +void RunInference(Settings* s) { + if (!s->model_name.c_str()) { + LOG(ERROR) << "no model file name\n"; + exit(-1); + } + + std::unique_ptr model; + std::unique_ptr interpreter; + model = tflite::FlatBufferModel::BuildFromFile(s->model_name.c_str()); + if (!model) { + LOG(FATAL) << "\nFailed to mmap model " << s->model_name << "\n"; + exit(-1); + } + LOG(INFO) << "Loaded model " << s->model_name << "\n"; + model->error_reporter(); + LOG(INFO) << "resolved reporter\n"; + + tflite::ops::builtin::BuiltinOpResolver resolver; + + tflite::InterpreterBuilder(*model, resolver)(&interpreter); + if (!interpreter) { + LOG(FATAL) << "Failed to construct interpreter\n"; + exit(-1); + } + + interpreter->UseNNAPI(s->accel); + + if (s->verbose) { + LOG(INFO) << "tensors size: " << interpreter->tensors_size() << "\n"; + LOG(INFO) << "nodes size: " << interpreter->nodes_size() << "\n"; + LOG(INFO) << "inputs: " << interpreter->inputs().size() << "\n"; + LOG(INFO) << "input(0) name: " << interpreter->GetInputName(0) << "\n"; + + int t_size = interpreter->tensors_size(); + for (int i = 0; i < t_size; i++) { + if (interpreter->tensor(i)->name) + LOG(INFO) << i << ": " << interpreter->tensor(i)->name << ", " + << interpreter->tensor(i)->bytes << ", " + << interpreter->tensor(i)->type << ", " + << interpreter->tensor(i)->params.scale << ", " + << interpreter->tensor(i)->params.zero_point << "\n"; + } + } + + if (s->number_of_threads != -1) { + 
interpreter->SetNumThreads(s->number_of_threads); + } + + int image_width = 224; + int image_height = 224; + int image_channels = 3; + uint8_t* in = read_bmp(s->input_bmp_name, &image_width, &image_height, + &image_channels, s); + + int input = interpreter->inputs()[0]; + if (s->verbose) LOG(INFO) << "input: " << input << "\n"; + + const std::vector inputs = interpreter->inputs(); + const std::vector outputs = interpreter->outputs(); + + if (s->verbose) { + LOG(INFO) << "number of inputs: " << inputs.size() << "\n"; + LOG(INFO) << "number of outputs: " << outputs.size() << "\n"; + } + + if (interpreter->AllocateTensors() != kTfLiteOk) { + LOG(FATAL) << "Failed to allocate tensors!"; + } + + if (s->verbose) PrintInterpreterState(interpreter.get()); + + // get input dimension from the input tensor metadata + // assuming one input only + TfLiteIntArray* dims = interpreter->tensor(input)->dims; + int wanted_height = dims->data[1]; + int wanted_width = dims->data[2]; + int wanted_channels = dims->data[3]; + + switch (interpreter->tensor(input)->type) { + case kTfLiteFloat32: + s->input_floating = true; + resize(interpreter->typed_tensor(input), in, image_height, + image_width, image_channels, wanted_height, wanted_width, + wanted_channels, s); + break; + case kTfLiteUInt8: + resize(interpreter->typed_tensor(input), in, + image_height, image_width, image_channels, wanted_height, + wanted_width, wanted_channels, s); + break; + default: + LOG(FATAL) << "cannot handle input type " + << interpreter->tensor(input)->type << " yet"; + exit(-1); + } + + struct timeval start_time, stop_time; + gettimeofday(&start_time, NULL); + for (int i = 0; i < s->loop_count; i++) { + if (interpreter->Invoke() != kTfLiteOk) { + LOG(FATAL) << "Failed to invoke tflite!\n"; + } + } + gettimeofday(&stop_time, NULL); + LOG(INFO) << "invoked \n"; + LOG(INFO) << "average time: " + << (get_us(stop_time) - get_us(start_time)) / (s->loop_count * 1000) + << " ms \n"; + + const int output_size = 1000; + 
const size_t num_results = 5; + const float threshold = 0.001f; + + std::vector> top_results; + + int output = interpreter->outputs()[0]; + switch (interpreter->tensor(output)->type) { + case kTfLiteFloat32: + get_top_n(interpreter->typed_output_tensor(0), output_size, + num_results, threshold, &top_results, true); + break; + case kTfLiteUInt8: + get_top_n(interpreter->typed_output_tensor(0), + output_size, num_results, threshold, &top_results, + false); + break; + default: + LOG(FATAL) << "cannot handle output type " + << interpreter->tensor(input)->type << " yet"; + exit(-1); + } + + std::vector labels; + size_t label_count; + + if (ReadLabelsFile(s->labels_file_name, &labels, &label_count) != kTfLiteOk) + exit(-1); + + for (const auto& result : top_results) { + const float confidence = result.first; + const int index = result.second; + LOG(INFO) << confidence << ": " << index << " " << labels[index] << "\n"; + } +} + +void display_usage() { + LOG(INFO) << "label_image\n" + << "--accelerated, -a: [0|1], use Android NNAPI or note\n" + << "--count, -c: loop interpreter->Invoke() for certain times\n" + << "--input_mean, -b: input mean\n" + << "--input_std, -s: input standard deviation\n" + << "--image, -i: image_name.bmp\n" + << "--labels, -l: labels for the model\n" + << "--tflite_model, -m: model_name.tflite\n" + << "--threads, -t: number of threads\n" + << "--verbose, -v: [0|1] print more information\n" + << "\n"; +} + +int Main(int argc, char** argv) { + Settings s; + + int c; + while (1) { + static struct option long_options[] = { + {"accelerated", required_argument, 0, 'a'}, + {"count", required_argument, 0, 'c'}, + {"verbose", required_argument, 0, 'v'}, + {"image", required_argument, 0, 'i'}, + {"labels", required_argument, 0, 'l'}, + {"tflite_model", required_argument, 0, 'm'}, + {"threads", required_argument, 0, 't'}, + {"input_mean", required_argument, 0, 'b'}, + {"input_std", required_argument, 0, 's'}, + {0, 0, 0, 0}}; + + /* getopt_long stores the 
option index here. */ + int option_index = 0; + + c = getopt_long(argc, argv, "a:b:c:f:i:l:m:s:t:v:", long_options, + &option_index); + + /* Detect the end of the options. */ + if (c == -1) break; + + switch (c) { + case 'a': + s.accel = strtol( // NOLINT(runtime/deprecated_fn) + optarg, (char**)NULL, 10); + break; + case 'b': + s.input_mean = strtod(optarg, NULL); + break; + case 'c': + s.loop_count = strtol( // NOLINT(runtime/deprecated_fn) + optarg, (char**)NULL, 10); + break; + case 'i': + s.input_bmp_name = optarg; + break; + case 'l': + s.labels_file_name = optarg; + break; + case 'm': + s.model_name = optarg; + break; + case 's': + s.input_std = strtod(optarg, NULL); + break; + case 't': + s.number_of_threads = strtol( // NOLINT(runtime/deprecated_fn) + optarg, (char**)NULL, 10); + break; + case 'v': + s.verbose = strtol( // NOLINT(runtime/deprecated_fn) + optarg, (char**)NULL, 10); + break; + case 'h': + case '?': + /* getopt_long already printed an error message. */ + display_usage(); + exit(-1); + default: + exit(-1); + } + } + RunInference(&s); + return 0; +} + +} // namespace label_image +} // namespace tflite + +int main(int argc, char** argv) { + return tflite::label_image::Main(argc, argv); +} diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.h b/tensorflow/contrib/lite/examples/label_image/label_image.h new file mode 100644 index 0000000000000000000000000000000000000000..4de32e33fb4ef2ab5d0e111886cdc737398147e9 --- /dev/null +++ b/tensorflow/contrib/lite/examples/label_image/label_image.h @@ -0,0 +1,41 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H +#define TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H + +#include "tensorflow/contrib/lite/string.h" + +namespace tflite { +namespace label_image { + +struct Settings { + bool verbose = false; + bool accel = false; + bool input_floating = false; + int loop_count = 1; + float input_mean = 127.5f; + float input_std = 127.5f; + string model_name = "./mobilenet_quant_v1_224.tflite"; + string input_bmp_name = "./grace_hopper.bmp"; + string labels_file_name = "./labels.txt"; + string input_layer_type = "uint8_t"; + int number_of_threads = 4; +}; + +} // namespace label_image +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_EXAMPLES_LABEL_IMAGE_LABEL_IMAGE_H diff --git a/tensorflow/contrib/lite/examples/label_image/label_image.md b/tensorflow/contrib/lite/examples/label_image/label_image.md new file mode 100644 index 0000000000000000000000000000000000000000..9ce32cf101897f2d41cd14a485aeb432344928a0 --- /dev/null +++ b/tensorflow/contrib/lite/examples/label_image/label_image.md @@ -0,0 +1,78 @@ +label_image for TensorFlow Lite inspired by TensorFlow's label_image. + +To build label_image for Android, run $TENSORFLOW_ROOT/configure +and set Android NDK or configure NDK setting in +$TENSORFLOW_ROOT/WORKSPACE first. 
+ +To build it for android ARMv8: +``` +> bazel build --config monolithic --cxxopt=-std=c++11 \ + --crosstool_top=//external:android/crosstool \ + --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ + --cpu=arm64-v8a \ + //tensorflow/contrib/lite/examples/label_image:label_image +``` +or +``` +> bazel build --config android_arm64 --config monolithic --cxxopt=-std=c++11 \ + //tensorflow/contrib/lite/examples/label_image:label_image +``` + +To build it for android arm-v7a: +``` +> bazel build --config monolithic --cxxopt=-std=c++11 \ + --crosstool_top=//external:android/crosstool \ + --host_crosstool_top=@bazel_tools//tools/cpp:toolchain \ + --cpu=armeabi-v7a \ + //tensorflow/contrib/lite/examples/label_image:label_image +``` +or +``` +> bazel build --config android_arm --config monolithic --cxxopt=-std=c++11 \ + //tensorflow/contrib/lite/examples/label_image:label_image +``` + +Build it for desktop machines (tested on Ubuntu and OS X) +``` +> bazel build --config opt --cxxopt=-std=c++11 //tensorflow/contrib/lite/examples/label_image:label_image +``` +To run it. Prepare `./mobilenet_quant_v1_224.tflite`, `./grace_hopper.bmp`, and `./labels.txt`. + +Run it: +``` +> ./label_image +Loaded model ./mobilenet_quant_v1_224.tflite +resolved reporter +invoked +average time: 100.986 ms +0.439216: 653 military uniform +0.372549: 458 bow tie +0.0705882: 466 bulletproof vest +0.0235294: 514 cornet +0.0196078: 835 suit +``` +Run `interpreter->Invoker()` 100 times: +``` +> ./label_image -c 100 +Loaded model ./mobilenet_quant_v1_224.tflite +resolved reporter +invoked +average time: 33.4694 ms +... 
+``` + +Run a floating point (`mobilenet_v1_1.0_224.tflite`) model, +``` +> ./label_image -f 1 -m mobilenet_v1_1.0_224.tflite +Loaded model mobilenet_v1_1.0_224.tflite +resolved reporter +invoked +average time: 263.493 ms +0.88615: 653 military uniform +0.0422316: 440 bearskin +0.0109948: 466 bulletproof vest +0.0105327: 401 academic gown +0.00947104: 723 ping-pong bal +``` + +See the source code for other command line options. diff --git a/tensorflow/contrib/lite/examples/label_image/label_image_test.cc b/tensorflow/contrib/lite/examples/label_image/label_image_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ce35483f76e8f40ced79e1ee30774c62d0eba94e --- /dev/null +++ b/tensorflow/contrib/lite/examples/label_image/label_image_test.cc @@ -0,0 +1,61 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "tensorflow/contrib/lite/examples/label_image/bitmap_helpers.h" +#include "tensorflow/contrib/lite/examples/label_image/get_top_n.h" +#include "tensorflow/contrib/lite/examples/label_image/label_image.h" + +using ::testing::ElementsAreArray; + +namespace tflite { +namespace label_image { + +TEST(LabelImageTest, GraceHopper) { + std::string lena_file = + "tensorflow/contrib/lite/examples/label_image/testdata/grace_hopper.bmp"; + int height, width, channels; + Settings s; + uint8_t *data; + + data = read_bmp(lena_file, &width, &height, &channels, &s); + ASSERT_EQ(height, 606); + ASSERT_EQ(width, 517); + ASSERT_EQ(channels, 3); + + uint8_t *out = new uint8_t[606 * 517 * 3]; + downsize(out, data, 606, 517, 3, 214, 214, 3, &s); + ASSERT_EQ(out[0], 0x15); + ASSERT_EQ(out[214 * 214 * 3 - 1], 0x12); +} + +TEST(LabelImageTest, GetTopN) { + uint8_t in[] = {1, 1, 2, 2, 4, 4, 16, 32, 128, 64}; + + std::vector> top_results; + get_top_n(in, 10, 5, 0.025, &top_results, false); + ASSERT_EQ(top_results.size(), 4); + ASSERT_EQ(top_results[0].second, 8); +} + +} // namespace label_image +} // namespace tflite + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/examples/label_image/testdata/grace_hopper.bmp b/tensorflow/contrib/lite/examples/label_image/testdata/grace_hopper.bmp new file mode 100644 index 0000000000000000000000000000000000000000..0d94cd3e930a138b7c20308f5ba375576484d48b Binary files /dev/null and b/tensorflow/contrib/lite/examples/label_image/testdata/grace_hopper.bmp differ diff --git a/tensorflow/contrib/lite/g3doc/custom_operators.md b/tensorflow/contrib/lite/g3doc/custom_operators.md index 204a489a93519309bb09238f1b2c8bbd4f1f19e4..d7cc854ebac08e79d346df0aca6e1fa56b490156 100644 --- a/tensorflow/contrib/lite/g3doc/custom_operators.md +++ 
b/tensorflow/contrib/lite/g3doc/custom_operators.md @@ -73,7 +73,7 @@ TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) { } TfLiteRegistration* Register_SIN() { - static TfLiteRegistration r = {nullptr, nullptr, SinResize, SinEval}; + static TfLiteRegistration r = {nullptr, nullptr, SinPrepare, SinEval}; return &r; } ``` diff --git a/tensorflow/contrib/lite/g3doc/ios.md b/tensorflow/contrib/lite/g3doc/ios.md index ce8b37fbf9b0db5dee60784e85a3cbf0326fddb6..a359b8d4b481dbc15cc86db14eabda5433722b8b 100644 --- a/tensorflow/contrib/lite/g3doc/ios.md +++ b/tensorflow/contrib/lite/g3doc/ios.md @@ -45,6 +45,10 @@ into a universal file containing armv7, armv7s, arm64, i386, and x86_64 architectures. The resulting library is in `tensorflow/contrib/lite/gen/lib/libtensorflow-lite.a`. +If you get an error such as `no such file or directory: 'x86_64'` when running +`build_ios_universal_lib.sh`: open Xcode > Preferences > Locations, and ensure +a value is selected in the "Command Line Tools" dropdown. + ## Using in your own application You'll need to update various settings in your app to link against TensorFlow diff --git a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md index 9ade04eb8c696d7e0e39a8104e02b6e5feec95eb..b1bbb7c67013acfb575cc1e9f9390ba191cbd08e 100644 --- a/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md +++ b/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md @@ -1,4 +1,4 @@ -# TensorFlow Compatibility Guide +# TensorFlow Lite & TensorFlow Compatibility Guide TensorFlow Lite supports a number of TensorFlow operations used in common inference models. 
As they are processed by the TensorFlow Lite Optimizing @@ -329,18 +329,18 @@ Inputs { 0: a tensor } Outputs { - 0: a tensor equivalent to max(0, min(input, 1) + 0: a tensor equivalent to max(0, input) } ``` -**RELU1** +**RELU_N1_TO_1** ``` Inputs { 0: a tensor } Outputs { - 0: a tensor equivalent to max(-1, min(input, 6) + 0: a tensor equivalent to max(-1, min(input, 1) } ``` diff --git a/tensorflow/contrib/lite/graph_info.cc b/tensorflow/contrib/lite/graph_info.cc new file mode 100644 index 0000000000000000000000000000000000000000..e60ed2c2463cb621015ba725ca030e8d8c02f3c7 --- /dev/null +++ b/tensorflow/contrib/lite/graph_info.cc @@ -0,0 +1,224 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/graph_info.h" +#include + +namespace tflite { + +namespace { + +// Provide a range iterable wrapper for TfLiteIntArray* (C lists that TfLite +// C api uses. Can't use the google array_view, since we can't depend on even +// absl for embedded device reasons. +// TODO(aselle): Move this into central utilities. +class TfLiteIntArrayView { + public: + // Construct a view of a TfLiteIntArray*. Note, `int_array` should be non-null + // and this view does not take ownership of it. 
+ explicit TfLiteIntArrayView(const TfLiteIntArray* int_array) + : int_array_(int_array) {} + + typedef const int* const_iterator; + const_iterator begin() const { return int_array_->data; } + const_iterator end() const { return &int_array_->data[int_array_->size]; } + + TfLiteIntArrayView(const TfLiteIntArrayView&) = default; + TfLiteIntArrayView& operator=(const TfLiteIntArrayView& rhs) = default; + + private: + const TfLiteIntArray* int_array_; +}; + +// Helper class that actually performs partitioning by subgraph. +// Outputs to a provided `subgraphs` structure. +// +// Example usage: +// PartitionGraphIntoIndependentSubgraphsImpl partitioner( +// info, nodes_to_part, subgraphs); +// partitioner.Partition(); +class PartitionGraphIntoIndependentSubgraphsImpl { + public: + PartitionGraphIntoIndependentSubgraphsImpl( + const GraphInfo* info, const TfLiteIntArray* nodes_to_partition, + std::vector* subgraphs) + : info_(info), + subgraphs_(subgraphs), + node_type_(info->num_nodes(), Subgraph::kTfNonPartition) { + // Populate the node_type_ map. + for (auto node_index : TfLiteIntArrayView(nodes_to_partition)) { + node_type_[node_index] = Subgraph::kTfPartition; + } + } + + // Actually partition the graph. + void Partition() { + // Initialize here to make Partition() re-entrant. + subgraphs_->clear(); + tensor_epochs_.clear(); + tensor_epochs_.resize(info_->num_tensors(), kEpochAlwaysReady); + node_epochs_.clear(); + node_epochs_.resize(info_->num_nodes(), kEpochNotReady); + // Set computed tensors to be kEpochNotReady (initializer set everything to + // AlwaysReady). 
+ for (int node_index = 0; node_index < info_->num_nodes(); node_index++) { + const TfLiteNode& node = info_->node(node_index); + for (int output_tensor_index : TfLiteIntArrayView(node.outputs)) { + tensor_epochs_[output_tensor_index] = kEpochNotReady; + } + } + + // Do a graph traversal where each iteration in the loop is an epoch + // that corresponds to a subgraph that only contains nodes that are of + // the same node_type_. + while (true) { + BuildSubgraph(); + if (subgraphs_->back().nodes.empty()) { + subgraphs_->pop_back(); + break; + } + } + + // Mark model outputs as subgraph outputs. All the rest have already been + // identified. + for (int output_index : info_->outputs()) { + int output_epoch = tensor_epochs_[output_index]; + Subgraph& output_subgraph = (*subgraphs_)[output_epoch]; + output_subgraph.output_tensors.push_back(output_index); + } + // Make sure every subgraph's inputs and outputs are unique. Since the + // list of inputs and outputs is generated in a way that produces + // duplicates. + for (Subgraph& subgraph : *subgraphs_) { + // Sort and uniquefy using standard library algorithms. + auto uniquefy = [](std::vector* items) { + std::sort(items->begin(), items->end()); + auto last = std::unique(items->begin(), items->end()); + items->erase(last, items->end()); + }; + uniquefy(&subgraph.input_tensors); + uniquefy(&subgraph.output_tensors); + } + } + + private: + // Special integer values needed for tensor_epochs_ and node_epochs_. + enum { + // The node or tensor is not ready to be assigned an epoch. e.g. a node's + // inputs have not all been assigned epochs. + kEpochNotReady = -1, + // Used for tensor_epochs_. This means that the tensor is always ready. + // e.g. an input to the whole model or a constant that has no dependencies. + kEpochAlwaysReady = -2 + }; + + // Updates the node `node_index` and returns true if it is assigned to an + // epoch. 
False is returned if the node is already set to an epoch, its inputs + // are not all assigned to epochs, or if it cannot be assigned to the current + // epoch since the epoch's node_type doesn't match. + bool UpdateNode(int node_index) { + const TfLiteNode& node = info_->node(node_index); + Subgraph& current_subgraph = subgraphs_->back(); + int current_epoch = subgraphs_->size() - 1; + // Check if node is already done. + if (node_epochs_[node_index] != kEpochNotReady) { + return false; + } + // See if all dependencies of this node are already assigned to a + // subgraph. + for (int input_tensor_index : TfLiteIntArrayView(node.inputs)) { + if (tensor_epochs_[input_tensor_index] == kEpochNotReady) { + return false; + } + } + // When we are starting a new epoch, the first ready node defines + // the type of that epoch. + if (current_subgraph.type == Subgraph::kTfUnexplored) { + current_subgraph.type = node_type_[node_index]; + } + // The node gets assigned to this epoch if it is the same type as + // the epoch's assigned type. Note, if this is the current ready + // node encountered during this epoch, this condition will be + // automatically true. + if (current_subgraph.type == node_type_[node_index]) { + node_epochs_[node_index] = current_epoch; + current_subgraph.nodes.push_back(node_index); + // All outputs of this node now are assigned to this epoch as + // well. + for (int output_tensor_index : TfLiteIntArrayView(node.outputs)) { + tensor_epochs_[output_tensor_index] = current_epoch; + } + // Look at our inputs one more time to update that tensor's + // epochs' outputs + for (int input_tensor_index : TfLiteIntArrayView(node.inputs)) { + int input_epoch = tensor_epochs_[input_tensor_index]; + int node_epoch = current_epoch; + if (input_epoch != node_epoch) { + current_subgraph.input_tensors.push_back(input_tensor_index); + // Set inputs to be outputs of the subgraph where they reside. 
+ // the if condition makes sure inputs to the whole computation + // are not included (i.e. those initialized to -2 above). + if (input_epoch >= 0) { + Subgraph& input_subgraph = (*subgraphs_)[input_epoch]; + input_subgraph.output_tensors.push_back(input_tensor_index); + } + } + } + return true; + } else { + return false; + } + } + + // Completely populates the current subgraph by doing graph traversal + void BuildSubgraph() { + subgraphs_->emplace_back(Subgraph()); + // loop until no more nodes can be updated. + while (true) { + bool did_something = false; + for (int node_index = 0; node_index < info_->num_nodes(); node_index++) { + if (UpdateNode(node_index)) { + did_something = true; + } + } + if (!did_something) return; + } + } + + // Temporary data needed for partitioning. + const GraphInfo* info_; + // List of subgraphs to populate + std::vector* subgraphs_; + std::vector node_type_; + // Maps from tensor index to the epoch in which it is assigned. Also special + // negative values of kEpochNotAssigned if not assigned, kEpochNotReady if it + // is an input or constant. + std::vector tensor_epochs_; + // Maps from tensor index to the epoch in which it is assigned. Also special + // negative values of kEpochNotAssigned if not assigned. + std::vector node_epochs_; +}; + +} // namespace + +TfLiteStatus PartitionGraphIntoIndependentSubgraphs( + const GraphInfo* info, const TfLiteIntArray* nodes_to_partition, + std::vector* subgraphs) { + PartitionGraphIntoIndependentSubgraphsImpl(info, nodes_to_partition, + subgraphs) + .Partition(); + return kTfLiteOk; +} + +} // namespace tflite diff --git a/tensorflow/contrib/lite/graph_info.h b/tensorflow/contrib/lite/graph_info.h new file mode 100644 index 0000000000000000000000000000000000000000..313af5fb7574b42bcdd53b4baad06e4ccfb34053 --- /dev/null +++ b/tensorflow/contrib/lite/graph_info.h @@ -0,0 +1,79 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_GRAPH_INFO_H_ +#define TENSORFLOW_CONTRIB_LITE_GRAPH_INFO_H_ + +#include + +#include "tensorflow/contrib/lite/context.h" + +namespace tflite { + +// Basic information about an inference graph, where execution nodes +// are connected via tensors. +class GraphInfo { + public: + virtual ~GraphInfo() {} + + // Total number of tensors in the graph. + virtual size_t num_tensors() const = 0; + + // Returns a tensor given its index which is expected to be between 0 and + // num_tensors(). + virtual TfLiteTensor* tensor(size_t index) = 0; + + // Total number of nodes in the graph. + virtual size_t num_nodes() const = 0; + + // Returns a node given its index which is expected to be between 0 and + // num_nodes(). + virtual const TfLiteNode& node(size_t index) const = 0; + + // Returns the indices of the input tensors. + virtual const std::vector& inputs() const = 0; + + // Returns the indices of the output tensors. + virtual const std::vector& outputs() const = 0; +}; + +// Represents a subgraph of a TensorFlow Lite graph. 
+struct Subgraph { + enum Type { + kTfUnexplored = 0, // temporarily used during creation + kTfPartition, + kTfNonPartition + }; + Type type = kTfUnexplored; + // Nodes within the subgraph + std::vector nodes; + // Tensors that stride output from another subgraph that this depends on, + // or global inputs to the TensorFlow Lite full graph. + std::vector input_tensors; + // Outputs that are consumed by other subgraphs or are global output tensors. + // All output tensors of the nodes in the subgraph that do not appear in this + // list are intermediate results that can be potentially elided. + std::vector output_tensors; +}; + +// Partitions a list of node indices `nodes_to_partition` into subgraphs. +// Each subgraph is in dependency order (i.e. all members of the subgraph). +// `subgraphs` is assumed to be empty. +TfLiteStatus PartitionGraphIntoIndependentSubgraphs( + const GraphInfo* info, const TfLiteIntArray* nodes_to_partition, + std::vector* subgraphs); + +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_GRAPH_INFO_H_ diff --git a/tensorflow/contrib/lite/graph_info_test.cc b/tensorflow/contrib/lite/graph_info_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ea38b43993fef71c6820c7a978351d92d5420287 --- /dev/null +++ b/tensorflow/contrib/lite/graph_info_test.cc @@ -0,0 +1,270 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "tensorflow/contrib/lite/graph_info.h" +#include "tensorflow/contrib/lite/testing/util.h" + +namespace tflite { +namespace { + +// Makes a TfLiteIntArray* from std::vector, must free with TfLiteIntFree(). +TfLiteIntArray* ConvertVector(const std::vector& x) { + TfLiteIntArray* lite = TfLiteIntArrayCreate(x.size()); + for (size_t i = 0; i < x.size(); i++) lite->data[i] = x[i]; + return lite; +} + +// A very simple test graph that supports setting in/out tensors on nodes. +class SimpleTestGraph : public GraphInfo { + public: + ~SimpleTestGraph() override { + for (auto& node : nodes_) { + TfLiteIntArrayFree(node.inputs); + TfLiteIntArrayFree(node.outputs); + } + } + + size_t num_tensors() const override { return tensors_.size(); } + size_t num_nodes() const override { return nodes_.size(); } + const TfLiteNode& node(size_t index) const override { return nodes_[index]; } + TfLiteTensor* tensor(size_t index) override { return &tensors_[index]; } + const std::vector& inputs() const override { return inputs_; } + const std::vector& outputs() const override { return outputs_; } + + void AddNode(const std::vector& inputs, + const std::vector& outputs) { + nodes_.push_back(TfLiteNode()); + TfLiteNode& node = nodes_.back(); + node.inputs = ConvertVector(inputs); + node.outputs = ConvertVector(outputs); + } + + void AddTensors(int count) { tensors_.resize(count + tensors_.size()); } + + void SetInputsAndOutputs(const std::vector& inputs, + const std::vector& outputs) { + inputs_ = inputs; + outputs_ = outputs; + } + + private: + std::vector nodes_; + std::vector tensors_; + std::vector inputs_; + std::vector outputs_; +}; + +// Partition a graph to generate a list of subgraphs. This wraps the API call +// we are testing and handles memory management and conversion to +// TfLiteIntArray. Populates `subgraphs` with resulting generated subgraphs. 
+void PartitionGraph(const SimpleTestGraph& graph, + const std::vector& nodes_to_partition, + std::vector* subgraphs) { + TfLiteIntArray* nodes_to_partition_int_array = + ConvertVector(nodes_to_partition); + PartitionGraphIntoIndependentSubgraphs(&graph, nodes_to_partition_int_array, + subgraphs); + TfLiteIntArrayFree(nodes_to_partition_int_array); +} + +// Check a generated list of subgraphs against the expected list of subgraphs. +void CheckPartitionSubgraphs(const std::vector& generated_subgraphs, + const std::vector& expected_subgraphs) { + ASSERT_EQ(generated_subgraphs.size(), expected_subgraphs.size()); + for (int subgraph_index = 0; subgraph_index < generated_subgraphs.size(); + subgraph_index++) { + EXPECT_EQ(generated_subgraphs[subgraph_index].nodes, + expected_subgraphs[subgraph_index].nodes); + EXPECT_EQ(generated_subgraphs[subgraph_index].input_tensors, + expected_subgraphs[subgraph_index].input_tensors); + EXPECT_EQ(generated_subgraphs[subgraph_index].output_tensors, + expected_subgraphs[subgraph_index].output_tensors); + } +} + +// Test an empty trivial graph with no partitions. +TEST(PartitionTest, Nodes0_PartitionNodes0) { + SimpleTestGraph graph; + std::vector nodes_to_partition = {}; + std::vector generated_subgraphs; + PartitionGraph(graph, nodes_to_partition, &generated_subgraphs); + CheckPartitionSubgraphs(generated_subgraphs, {}); +} + +// Test a 1 node graph with no partitions. 
+// Input: tensor(0) -> node(0) -> tensor(1), nodes_to_partition=[] +// Output: [kTfNonPartition, tensor(0) -> node(0) -> tensor(1)] +TEST(PartitionTest, Nodes1PartitionNodes0) { + SimpleTestGraph graph; + graph.AddTensors(2); + graph.AddNode({0}, {1}); + graph.SetInputsAndOutputs({0}, {1}); + std::vector nodes_to_partition = {}; + std::vector generated_subgraphs; + PartitionGraph(graph, nodes_to_partition, &generated_subgraphs); + + Subgraph expected_subgraph; + expected_subgraph.type = Subgraph::kTfNonPartition; + expected_subgraph.nodes = {0}; + expected_subgraph.input_tensors = {0}; + expected_subgraph.output_tensors = {1}; + CheckPartitionSubgraphs(generated_subgraphs, {expected_subgraph}); +} + +// Test a 1 node graph with no inputs that is fully partitioned. +// Input: node(0) -> tensor(0), nodes_to_partition=[node0] +// Output: [kTfPartition, node(0) -> tensor(0)] +TEST(PartitionTest, Nodes1PartitionNodes0Inputs0) { + SimpleTestGraph graph; + graph.AddTensors(1); + graph.AddNode({}, {0}); + graph.SetInputsAndOutputs({}, {0}); + std::vector generated_subgraphs; + std::vector nodes_to_partition = {0}; + PartitionGraph(graph, nodes_to_partition, &generated_subgraphs); + + Subgraph expected_subgraph; + expected_subgraph.type = Subgraph::kTfPartition; + expected_subgraph.nodes = {0}; + expected_subgraph.input_tensors = {}; + expected_subgraph.output_tensors = {0}; + CheckPartitionSubgraphs(generated_subgraphs, {expected_subgraph}); +} + +// Test a 1 node graph that is partitioned completely.
+// Input: tensor(0) -> node(0) -> tensor(1), nodes_to_partition=[node0] +// Output: [kTfPartition, tensor(0) -> node(0) -> tensor(1)] +TEST(PartitionTest, Nodes1PartitionNodes1) { + SimpleTestGraph graph; + graph.AddTensors(2); + graph.AddNode({0}, {1}); + graph.SetInputsAndOutputs({0}, {1}); + std::vector nodes_to_partition = {0}; + std::vector generated_subgraphs; + PartitionGraph(graph, nodes_to_partition, &generated_subgraphs); + + Subgraph expected_subgraph; + expected_subgraph.type = Subgraph::kTfPartition; + expected_subgraph.nodes = {0}; + expected_subgraph.input_tensors = {0}; + expected_subgraph.output_tensors = {1}; + CheckPartitionSubgraphs(generated_subgraphs, {expected_subgraph}); +} + +// Test a 2 node graph where 1 node is partitioned and the other is not. +// Input: tensor(0) -> node(0) -> tensor(1) -> node(1) -> tensor(2), +// nodes_to_partition = [1] +// Output: [kTfNonPartition, tensor(0) -> node(0) -> tensor(1), +// kTfPartition, tensor(1) -> node(1), tensor(2)] +TEST(PartitionTest, Nodes2PartitionNodes1) { + SimpleTestGraph graph; + graph.AddTensors(3); + graph.AddNode({0}, {1}); + graph.AddNode({1}, {2}); + graph.SetInputsAndOutputs({0}, {2}); + std::vector nodes_to_partition = {1}; + std::vector generated_subgraphs; + PartitionGraph(graph, nodes_to_partition, &generated_subgraphs); + + Subgraph expected_subgraph0; + expected_subgraph0.type = Subgraph::kTfPartition; + expected_subgraph0.nodes = {0}; + expected_subgraph0.input_tensors = {0}; + expected_subgraph0.output_tensors = {1}; + Subgraph expected_subgraph1; + expected_subgraph1.type = Subgraph::kTfPartition; + expected_subgraph1.nodes = {1}; + expected_subgraph1.input_tensors = {1}; + expected_subgraph1.output_tensors = {2}; + CheckPartitionSubgraphs(generated_subgraphs, + {expected_subgraph0, expected_subgraph1}); +} + +// Test a 2 node graph where both nodes are fully partitioned. 
+// Input: tensor(0) -> node(0) -> tensor(1) -> node(1) -> tensor(2), +// nodes_to_partition = [0, 1] +// Output: [kTfPartition, tensor(0) -> node(0) -> node(1) -> tensor(2)] +TEST(PartitionTest, Nodes2PartitionNodes2) { + SimpleTestGraph graph; + graph.AddTensors(3); + graph.AddNode({0}, {1}); + graph.AddNode({1}, {2}); + graph.SetInputsAndOutputs({0}, {2}); + std::vector nodes_to_partition = {0, 1}; + std::vector generated_subgraphs; + PartitionGraph(graph, nodes_to_partition, &generated_subgraphs); + + Subgraph expected_subgraph0; + expected_subgraph0.type = Subgraph::kTfPartition; + expected_subgraph0.nodes = {0, 1}; + expected_subgraph0.input_tensors = {0}; + expected_subgraph0.output_tensors = {2}; + CheckPartitionSubgraphs(generated_subgraphs, {expected_subgraph0}); +} + +// Test a three node model where we want to partition node 0 and node +// 2, but node 0 and node 2 cannot be in the same subgraph since node 2 +// depends on node 1 which depends on node 0. Thus, we need to produce three +// subgraphs.
+// +// Input: tensor(0) -> node(0) -> tensor(1) +// tensor(1) -> node(1) -> tensor(2) +// [tensor(2), tensor(1)] -> node(2) -> tensor(3) +// nodes_to_partition = [0, 2] +// Output: [[kTfPartition, tensor(0) -> node(0) -> tensor(1)], +// [kTfNonPartition, tensor(1) -> node(1) -> tensor(2)], +// [kTfPartition, [tensor(2), tensor(1)] -> node(2) -> tensor(3)]] +TEST(PartitionTest, Nodes3PartitionNodes2) { + SimpleTestGraph graph; + graph.AddTensors(4); + graph.AddNode({0}, {1}); + graph.AddNode({1}, {2}); + graph.AddNode({1, 2}, {3}); + graph.SetInputsAndOutputs({0}, {3}); + std::vector nodes_to_partition = {0, 2}; + std::vector generated_subgraphs; + PartitionGraph(graph, nodes_to_partition, &generated_subgraphs); + + Subgraph expected_subgraph0; + expected_subgraph0.type = Subgraph::kTfPartition; + expected_subgraph0.nodes = {0}; + expected_subgraph0.input_tensors = {0}; + expected_subgraph0.output_tensors = {1}; + Subgraph expected_subgraph1; + expected_subgraph1.type = Subgraph::kTfNonPartition; + expected_subgraph1.nodes = {1}; + expected_subgraph1.input_tensors = {1}; + expected_subgraph1.output_tensors = {2}; + Subgraph expected_subgraph2; + expected_subgraph2.type = Subgraph::kTfPartition; + expected_subgraph2.nodes = {2}; + expected_subgraph2.input_tensors = {1, 2}; + expected_subgraph2.output_tensors = {3}; + CheckPartitionSubgraphs( + generated_subgraphs, + {expected_subgraph0, expected_subgraph1, expected_subgraph2}); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc index 954e236ac8f0c8c59a9d20d62e66b3aa1164ecc1..028449211b8108d004df4d1cd8a58b4a08df6604 100644 --- a/tensorflow/contrib/lite/interpreter.cc +++ b/tensorflow/contrib/lite/interpreter.cc @@ -18,16 +18,16 @@ limitations under the License.
#include #include #include +#include "tensorflow/contrib/lite/arena_planner.h" #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/graph_info.h" #include "tensorflow/contrib/lite/kernels/gemm_support.h" +#include "tensorflow/contrib/lite/memory_planner.h" #include "tensorflow/contrib/lite/nnapi_delegate.h" namespace { -// Memory allocation tuning -constexpr const int kDefaultArenaAlignment = 64; -constexpr const int kDefaultTensorAlignment = 4; // std::vector preallocation tuning. constexpr const int kSlotsToReserve = 128; @@ -35,10 +35,40 @@ constexpr const int kSlotsToReserve = 128; namespace tflite { +// A trivial implementation of GraphInfo around the Interpreter. +// NOTE: this interpreter info represents the subset of the +// graph that is executed according to execution plan. Thus, +// the indices are execution plan indices rather than raw node +// indices. +class InterpreterInfo : public GraphInfo { + public: + explicit InterpreterInfo(Interpreter* interpreter) + : interpreter_(interpreter) {} + + size_t num_tensors() const override { return interpreter_->tensors_size(); } + TfLiteTensor* tensor(size_t index) override { + return interpreter_->tensor(index); + } + size_t num_nodes() const override { + return interpreter_->execution_plan().size(); + } + const TfLiteNode& node(size_t index) const override { + int node_index = interpreter_->execution_plan()[index]; + return interpreter_->node_and_registration(node_index)->first; + } + const std::vector& inputs() const override { + return interpreter_->inputs(); + } + const std::vector& outputs() const override { + return interpreter_->outputs(); + } + + public: + Interpreter* interpreter_; +}; + Interpreter::Interpreter(ErrorReporter* error_reporter) - : arena_(kDefaultArenaAlignment), - persistent_arena_(kDefaultArenaAlignment), - error_reporter_(error_reporter ? error_reporter + : error_reporter_(error_reporter ? 
error_reporter : DefaultErrorReporter()) { context_.impl_ = static_cast(this); context_.ResizeTensor = ResizeTensor; @@ -47,10 +77,16 @@ Interpreter::Interpreter(ErrorReporter* error_reporter) context_.tensors = nullptr; context_.tensors_size = 0; context_.gemm_context = nullptr; + + // Invalid to call these except from TfLiteDelegate + context_.GetNodeAndRegistration = nullptr; + context_.ReplaceSubgraphsWithDelegateKernels = nullptr; + context_.GetExecutionPlan = nullptr; + // Reserve some space for the tensors to avoid excessive resizing. tensors_.reserve(kSlotsToReserve); nodes_and_registration_.reserve(kSlotsToReserve); - next_allocate_node_id_ = 0; + next_execution_plan_index_to_prepare_ = 0; UseNNAPI(false); } @@ -70,6 +106,78 @@ Interpreter::~Interpreter() { } } +TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels( + TfLiteContext* context, TfLiteRegistration registration, + const TfLiteIntArray* nodes_to_replace) { + return static_cast(context->impl_) + ->ReplaceSubgraphsWithDelegateKernels(registration, nodes_to_replace); +} + +TfLiteStatus Interpreter::ReplaceSubgraphsWithDelegateKernels( + TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace) { + // Analyze the graph to find all independent subgraphs that are either + // fully not-this-delegate or this-delegate computation. + InterpreterInfo info(this); + std::vector subgraphs; + PartitionGraphIntoIndependentSubgraphs(&info, nodes_to_replace, &subgraphs); + + execution_plan_.clear(); + for (auto& subgraph : subgraphs) { + // Turn subgraph.nodes into a TfLiteIntArray compatible data structure. + // TODO(aselle): Avoid this copy by constructing subgraph.nodes that way + // in the first place + subgraph.nodes.insert(subgraph.nodes.begin(), + static_cast(subgraph.nodes.size())); + // Subgraphs claimed by the delegate should have a "macro" op created, the + // other subgraphs (kTfNonPartition) just have their nodes added back to + // the execution plan.
+ switch (subgraph.type) { + case Subgraph::kTfNonPartition: + for (auto it = subgraph.nodes.begin() + 1; it != subgraph.nodes.end(); + ++it) { + execution_plan_.push_back(*it); + } + break; + case Subgraph::kTfPartition: { + void* builtin_data = nullptr; + int node_index; + // Create a node that represents computation of this subgraph. + AddNodeWithParameters( + subgraph.input_tensors, subgraph.output_tensors, + reinterpret_cast(subgraph.nodes.data()), + subgraph.nodes.size() * sizeof(subgraph.nodes[0]), builtin_data, + &registration, &node_index); + } break; + case Subgraph::kTfUnexplored: + return kTfLiteError; + break; + } + } + return kTfLiteOk; +} + +// Gets a TfLiteIntArray* representing the execution plan. The interpreter owns +// this memory and it is only guaranteed to exist during the invocation of the +// delegate prepare. +TfLiteStatus Interpreter::GetExecutionPlan(TfLiteIntArray** execution_plan) { + // TODO(aselle): Do not make a copy here + plan_cache_.reset(TfLiteIntArrayCreate(execution_plan_.size())); + *execution_plan = plan_cache_.get(); + static_assert(sizeof(plan_cache_->data[0]) == sizeof(execution_plan_[0]), + "TfLiteIntArray and execution_plan do not contain same type."); + memcpy(plan_cache_->data, execution_plan_.data(), + sizeof(plan_cache_->data[0]) * execution_plan_.size()); + return kTfLiteOk; +} + +// WARNING: This is an experimental interface that is subject to change.
+// Entry point for C node plugin API to get the execution plan +TfLiteStatus Interpreter::GetExecutionPlan(struct TfLiteContext* context, + TfLiteIntArray** execution_plan) { + return static_cast(context->impl_) + ->GetExecutionPlan(execution_plan); +} + TfLiteStatus Interpreter::SetInputs(std::vector inputs) { TF_LITE_ENSURE_OK(&context_, CheckTensorIndices("inputs", inputs.data(), inputs.size())); @@ -128,181 +236,6 @@ TfLiteStatus Interpreter::BytesRequired(TfLiteType type, const int* dims, return kTfLiteOk; } -TfLiteStatus Interpreter::AllocateTensorsWhoseSizesAreKnown() { - if (!consistent_) { - ReportError(&context_, "AllocateTensors() called on inconsistent model."); - return kTfLiteError; - } - if (next_allocate_node_id_ == nodes_and_registration_.size() && invokable_) { - return kTfLiteOk; - } - allocs_and_refcounts_.resize(context_.tensors_size); - - int new_next_allocate_node_id = next_allocate_node_id_; - invokable_ = false; - - // Allocate graph input nodes. - if (next_allocate_node_id_ == 0) { - for (int i = 0; i < inputs_.size(); ++i) { - int tensor_index = inputs_[i]; - if (tensor_index == kOptionalTensor) { - continue; - } - TfLiteTensor& tensor = context_.tensors[tensor_index]; - if (tensor.allocation_type == kTfLiteArenaRw) { - TF_LITE_ENSURE_OK( - &context_, - arena_.Allocate(&context_, kDefaultTensorAlignment, tensor.bytes, - &allocs_and_refcounts_[tensor_index].alloc)); - } - } - // Add 1 to output tensors, so they will not get overwritten. - for (int i = 0; i < outputs_.size(); ++i) { - allocs_and_refcounts_[outputs_[i]].count++; - } - } - - // Count references to node input tensors, and resize node-referenced tensors - // until we encounter a node that has a dynamic output tensor. 
- for (int k = next_allocate_node_id_; k < nodes_and_registration_.size(); - k++) { - new_next_allocate_node_id++; - TfLiteNode& node = nodes_and_registration_[k].first; - const TfLiteRegistration& registration = nodes_and_registration_[k].second; - if (OpPrepare(registration, &node) == kTfLiteError) { - return kTfLiteError; - } - - TfLiteIntArray* node_inputs = node.inputs; - for (int i = 0; i < node_inputs->size; ++i) { - int tensor_index = node_inputs->data[i]; - if (tensor_index != kOptionalTensor) { - allocs_and_refcounts_[node_inputs->data[i]].count++; - } - } - - // Discontinue if the node has dynamic outputs. - bool has_unallocated_dynamic_tensor = false; - TfLiteIntArray* node_outputs = node.outputs; - for (int i = 0; i < node_outputs->size; ++i) { - TfLiteTensor& tensor = context_.tensors[node_outputs->data[i]]; - if (tensor.allocation_type == kTfLiteDynamic) { - has_unallocated_dynamic_tensor = true; - break; - } - } - if (has_unallocated_dynamic_tensor) { - break; - } - } - - // Allocate graph persistent outputs, e.g. RNN cell states, etc. - for (int k = next_allocate_node_id_; k < new_next_allocate_node_id; k++) { - TfLiteNode& node = nodes_and_registration_[k].first; - - // Go through output tensors and allocate the persistent ones first. - TfLiteIntArray* node_outputs = node.outputs; - for (int i = 0; i < node_outputs->size; ++i) { - int tensor_index = node_outputs->data[i]; - TfLiteTensor& tensor = context_.tensors[tensor_index]; - if (tensor.allocation_type == kTfLiteArenaRwPersistent) { - TF_LITE_ENSURE_OK(&context_, - persistent_arena_.Allocate( - &context_, kDefaultTensorAlignment, tensor.bytes, - &allocs_and_refcounts_[tensor_index].alloc)); - } - } - } - - // Go through the graph in execution order. - for (int k = next_allocate_node_id_; k < new_next_allocate_node_id; k++) { - TfLiteNode& node = nodes_and_registration_[k].first; - - // First allocate output tensors. 
- TfLiteIntArray* node_outputs = node.outputs; - for (int i = 0; i < node_outputs->size; ++i) { - int tensor_index = node_outputs->data[i]; - TfLiteTensor& tensor = context_.tensors[tensor_index]; - if (tensor.allocation_type == kTfLiteArenaRw) { - TF_LITE_ENSURE_OK( - &context_, - arena_.Allocate(&context_, kDefaultTensorAlignment, tensor.bytes, - &allocs_and_refcounts_[tensor_index].alloc)); - } - } - // Then the temporaries, in two passes. First allocate them all, them - // deallocate them. - TfLiteIntArray* node_temporaries = node.temporaries; - for (int i = 0; i < node_temporaries->size; ++i) { - int tensor_index = node_temporaries->data[i]; - TfLiteTensor& tensor = context_.tensors[tensor_index]; - if (tensor.allocation_type == kTfLiteArenaRw) { - TF_LITE_ENSURE_OK( - &context_, - arena_.Allocate(&context_, kDefaultTensorAlignment, tensor.bytes, - &allocs_and_refcounts_[tensor_index].alloc)); - } - } - for (int i = 0; i < node_temporaries->size; ++i) { - int tensor_index = node_temporaries->data[i]; - TfLiteTensor& tensor = context_.tensors[tensor_index]; - allocs_and_refcounts_[tensor_index].count--; - if (tensor.allocation_type == kTfLiteArenaRw && - allocs_and_refcounts_[tensor_index].count == 0) { - TF_LITE_ENSURE_OK( - &context_, - arena_.Deallocate(&context_, - allocs_and_refcounts_[tensor_index].alloc)); - } - } - - // Then process the node's inputs. - TfLiteIntArray* node_inputs = node.inputs; - for (int i = 0; i < node_inputs->size; ++i) { - int tensor_index = node_inputs->data[i]; - if (tensor_index == kOptionalTensor) { - continue; - } - TfLiteTensor& tensor = context_.tensors[tensor_index]; - - // Decrease reference count and deallocate if not needed anymore. 
- allocs_and_refcounts_[tensor_index].count--; - if (tensor.allocation_type == kTfLiteArenaRw && - allocs_and_refcounts_[tensor_index].count == 0) { - TF_LITE_ENSURE_OK( - &context_, - arena_.Deallocate(&context_, - allocs_and_refcounts_[tensor_index].alloc)); - } - } - } - - // Resize the buffer and commit the arena. - TF_LITE_ENSURE_OK(&context_, arena_.Commit(&context_)); - TF_LITE_ENSURE_OK(&context_, persistent_arena_.Commit(&context_)); - - // Rewire the tensors to use the underlying arena buffer. - for (int i = 0; i < context_.tensors_size; ++i) { - TfLiteTensor& tensor = context_.tensors[i]; - if (tensor.allocation_type == kTfLiteArenaRw) { - TF_LITE_ENSURE_OK( - &context_, - arena_.ResolveAlloc(&context_, allocs_and_refcounts_[i].alloc, - &tensor.data.raw)); - } - if (tensor.allocation_type == kTfLiteArenaRwPersistent) { - TF_LITE_ENSURE_OK( - &context_, - persistent_arena_.ResolveAlloc( - &context_, allocs_and_refcounts_[i].alloc, &tensor.data.raw)); - } - } - - invokable_ = true; - next_allocate_node_id_ = new_next_allocate_node_id; - return kTfLiteOk; -} - namespace { TfLiteIntArray* convertVectorToTfLiteIntArray(const std::vector& x) { TfLiteIntArray* lite = TfLiteIntArrayCreate(x.size()); @@ -312,11 +245,19 @@ TfLiteIntArray* convertVectorToTfLiteIntArray(const std::vector& x) { } // namespace TfLiteStatus Interpreter::AllocateTensors() { - next_allocate_node_id_ = 0; - TF_LITE_ENSURE_OK(&context_, arena_.Clear()); - TF_LITE_ENSURE_OK(&context_, persistent_arena_.Clear()); - allocs_and_refcounts_.clear(); - return AllocateTensorsWhoseSizesAreKnown(); + next_execution_plan_index_to_prepare_ = 0; + if (memory_planner_) { + TF_LITE_ENSURE_STATUS(memory_planner_->ResetAllocations()); + } + + if (!consistent_) { + ReportError(&context_, "AllocateTensors() called on inconsistent model."); + return kTfLiteError; + } + + TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors()); + invokable_ = true; + return kTfLiteOk; } TfLiteStatus Interpreter::AddNodeWithParameters( 
@@ -334,8 +275,10 @@ TfLiteStatus Interpreter::AddNodeWithParameters( &context_, CheckTensorIndices("node outputs", outputs.data(), outputs.size())); - if (node_index) *node_index = nodes_and_registration_.size(); + int new_node_index = nodes_and_registration_.size(); + if (node_index) *node_index = new_node_index; nodes_and_registration_.resize(nodes_and_registration_.size() + 1); + auto& node_and_reg = nodes_and_registration_.back(); TfLiteNode& node = node_and_reg.first; if (node.inputs) TfLiteIntArrayFree(node.inputs); @@ -357,6 +300,7 @@ TfLiteStatus Interpreter::AddNodeWithParameters( } node.builtin_data = builtin_data_deleter.release(); node_and_reg.second = *registration; + execution_plan_.push_back(new_node_index); return kTfLiteOk; } @@ -372,6 +316,60 @@ TfLiteStatus Interpreter::ResizeInputTensor(int tensor_index, return ResizeTensorImpl(&context_.tensors[tensor_index], dims_lite); } +// Returns true if at least one tensor in the given list is kTfLiteDynamic. +bool HasDynamicTensor(const TfLiteContext& context, + const TfLiteIntArray* tensors) { + for (int i = 0; i < tensors->size; ++i) { + const TfLiteTensor& tensor = context.tensors[tensors->data[i]]; + if (tensor.allocation_type == kTfLiteDynamic) { + return true; + } + } + return false; +} + +TfLiteStatus Interpreter::PrepareOpsStartingAt( + int first_execution_plan_index, int* last_execution_plan_index_prepared) { + for (int execution_plan_index = first_execution_plan_index; + execution_plan_index < execution_plan_.size(); execution_plan_index++) { + int node_index = execution_plan_[execution_plan_index]; + TfLiteNode& node = nodes_and_registration_[node_index].first; + const TfLiteRegistration& registration = + nodes_and_registration_[node_index].second; + if (OpPrepare(registration, &node) == kTfLiteError) { + return kTfLiteError; + } + + *last_execution_plan_index_prepared = execution_plan_index; + + // Discontinue if the node has dynamic outputs. 
Note that we don't + // stop for dynamic temporary tensors since they won't affect the + // sizes of other tensors in the graph. + if (HasDynamicTensor(context_, node.outputs)) { + break; + } + } + return kTfLiteOk; +} + +TfLiteStatus Interpreter::PrepareOpsAndTensors() { + if (!memory_planner_) { + memory_planner_.reset(new ArenaPlanner( + &context_, std::unique_ptr(new InterpreterInfo(this)))); + memory_planner_->PlanAllocations(); + } + + int last_exec_plan_index_prepared = 0; + + TF_LITE_ENSURE_STATUS(PrepareOpsStartingAt( + next_execution_plan_index_to_prepare_, &last_exec_plan_index_prepared)); + TF_LITE_ENSURE_STATUS(memory_planner_->ExecuteAllocations( + next_execution_plan_index_to_prepare_, last_exec_plan_index_prepared)); + + next_execution_plan_index_to_prepare_ = last_exec_plan_index_prepared + 1; + return kTfLiteOk; +} + TfLiteStatus Interpreter::Invoke() { if (!consistent_) { ReportError(&context_, "Invoke called on model that is not consistent."); @@ -384,10 +382,7 @@ TfLiteStatus Interpreter::Invoke() { TfLiteStatus status = kTfLiteOk; if (nnapi_delegate_) { - if (AllocateTensorsWhoseSizesAreKnown() == kTfLiteError) { - return kTfLiteError; - } - if (next_allocate_node_id_ == nodes_and_registration_.size()) { + if (next_execution_plan_index_to_prepare_ == execution_plan_.size()) { TF_LITE_ENSURE_OK(&context_, nnapi_delegate_->Invoke(this)); return kTfLiteOk; } else { @@ -400,17 +395,24 @@ TfLiteStatus Interpreter::Invoke() { } } - for (int i = 0; i < nodes_and_registration_.size(); i++) { - // Ensure we have allocated up to this node. The point of this is to - // allocate as much as possible before running any evaluation, but - // dynamic shapes can prevent this from being possible. - if (i >= next_allocate_node_id_) { - if (AllocateTensorsWhoseSizesAreKnown() == kTfLiteError) { - return kTfLiteError; - } + // Invocations are always done in node order. 
+ // Note that calling Invoke repeatedly will cause the original memory plan to + // be reused, unless either ResizeInputTensor() or AllocateTensors() has been + // called. + // TODO(b/71913981): we should force recalculation in the presence of dynamic + // tensors, because they may have new value which in turn may affect shapes + // and allocations. + for (int execution_plan_index = 0; + execution_plan_index < execution_plan_.size(); execution_plan_index++) { + if (execution_plan_index == next_execution_plan_index_to_prepare_) { + TF_LITE_ENSURE_STATUS(PrepareOpsAndTensors()); + TF_LITE_ENSURE(&context_, next_execution_plan_index_to_prepare_ >= + execution_plan_index); } - TfLiteNode& node = nodes_and_registration_[i].first; - const TfLiteRegistration& registration = nodes_and_registration_[i].second; + int node_index = execution_plan_[execution_plan_index]; + TfLiteNode& node = nodes_and_registration_[node_index].first; + const TfLiteRegistration& registration = + nodes_and_registration_[node_index].second; if (OpInvoke(registration, &node) == kTfLiteError) { status = kTfLiteError; } @@ -465,6 +467,22 @@ TfLiteStatus Interpreter::AddTensors(TfLiteContext* context, int tensors_to_add, ->AddTensors(tensors_to_add, first_new_tensor_index); } +TfLiteStatus Interpreter::GetNodeAndRegistration( + int node_index, TfLiteNode** node, TfLiteRegistration** registration) { + TF_LITE_ENSURE(&context_, node_index < nodes_size() && node_index >= 0); + TF_LITE_ENSURE(&context_, node != nullptr && registration != nullptr); + *node = &nodes_and_registration_[node_index].first; + *registration = &nodes_and_registration_[node_index].second; + return kTfLiteOk; +} + +TfLiteStatus Interpreter::GetNodeAndRegistration( + struct TfLiteContext* context, int node_index, TfLiteNode** node, + TfLiteRegistration** registration) { + return static_cast(context->impl_) + ->GetNodeAndRegistration(node_index, node, registration); +} + TfLiteStatus Interpreter::SetTensorParametersReadOnly( int 
tensor_index, TfLiteType type, const char* name, const std::vector& dims, TfLiteQuantizationParams quantization, @@ -514,6 +532,14 @@ TfLiteStatus Interpreter::SetTensorParametersReadWrite( return kTfLiteOk; } +TfLiteStatus Interpreter::SetExecutionPlan(const std::vector& new_plan) { + for (int node_index : new_plan) { + TF_LITE_ENSURE(&context_, node_index >= 0 && node_index < nodes_size()); + } + execution_plan_ = new_plan; + return kTfLiteOk; +} + TfLiteStatus Interpreter::ResizeTensorImpl(TfLiteTensor* tensor, TfLiteIntArray* new_size) { // Note that in theory we could resize kTfLiteArenaRwPersistent tensors too. @@ -527,6 +553,9 @@ TfLiteStatus Interpreter::ResizeTensorImpl(TfLiteTensor* tensor, TfLiteIntArrayFree(new_size); return kTfLiteError; } + + // Realloc space for kTfLiteDynamic tensors. + TfLiteTensorRealloc(bytesRequired, tensor); tensor->bytes = bytesRequired; } if (tensor->dims) TfLiteIntArrayFree(tensor->dims); @@ -564,4 +593,20 @@ void Interpreter::SetNumThreads(int num_threads) { tflite::gemm_support::SetMaxNumThreads(&context_, num_threads); } +TfLiteStatus Interpreter::ModifyGraphWithDelegate(TfLiteDelegate* delegate) { + // TODO(aselle): Consider if it is worth storing pointers to delegates. + // Setup additional context interface + context_.GetNodeAndRegistration = GetNodeAndRegistration; + context_.ReplaceSubgraphsWithDelegateKernels = + ReplaceSubgraphsWithDelegateKernels; + context_.GetExecutionPlan = GetExecutionPlan; + + TfLiteStatus status = delegate->Prepare(&context_, delegate->data_); + // Remove additional context info. 
+ context_.GetNodeAndRegistration = nullptr; + context_.ReplaceSubgraphsWithDelegateKernels = nullptr; + context_.GetExecutionPlan = nullptr; + return status; +} + } // namespace tflite diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h index 65c61e44bee48535f884a3afaddc691972f5e04b..bab56a9d72f8992a9d8af23f92133c7c918fd46d 100644 --- a/tensorflow/contrib/lite/interpreter.h +++ b/tensorflow/contrib/lite/interpreter.h @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ // Main abstraction controlling the tflite interpreter. // See context.h for the API for defining operations (TfLiteRegistration). -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_ +#define TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_ #include #include @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/contrib/lite/allocation.h" #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/error_reporter.h" -#include "tensorflow/contrib/lite/simple_memory_arena.h" +#include "tensorflow/contrib/lite/memory_planner.h" namespace tflite { @@ -49,13 +49,6 @@ constexpr TfLiteType typeToTfLiteType() { return kTfLiteUInt8; } -struct ArenaAllocRefCount { - ArenaAllocRefCount() : alloc(), count(0) {} - - ArenaAlloc alloc; - int count; -}; - // Forward declare since NNAPIDelegate uses Interpreter. class NNAPIDelegate; @@ -87,6 +80,12 @@ class NNAPIDelegate; // foo.Invoke(); // +struct TfLiteIntArrayDeleter { + void operator()(TfLiteIntArray* a) { + if (a) TfLiteIntArrayFree(a); + } +}; + class Interpreter { public: // Instantiate an interpreter. All errors associated with reading and @@ -115,7 +114,7 @@ class Interpreter { // Adds a node with the given parameters and returns the index of the new // node in `node_index` (optionally). 
Interpreter will take ownership of - // `builtin_data` and destroy it with `delete`. Ownership of 'init_data' + // `builtin_data` and destroy it with `free`. Ownership of 'init_data' // remains with the caller. TfLiteStatus AddNodeWithParameters(const std::vector& inputs, const std::vector& outputs, @@ -173,12 +172,19 @@ class Interpreter { // Return the number of ops in the model. int nodes_size() const { return nodes_and_registration_.size(); } + // WARNING: Experimental interface, subject to change + const std::vector& execution_plan() const { return execution_plan_; } + + // WARNING: Experimental interface, subject to change + // Overrides execution plan. This bounds checks indices sent in. + TfLiteStatus SetExecutionPlan(const std::vector& new_plan); + // Get a tensor data structure. // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this // read/write access to structure TfLiteTensor* tensor(int tensor_index) { if (tensor_index >= context_.tensors_size || tensor_index < 0) - return nullptr; + return nullptr; return &context_.tensors[tensor_index]; } @@ -247,6 +253,11 @@ class Interpreter { // Set the number of threads available to the interpreter. void SetNumThreads(int num_threads); + // Allow a delegate to look at the graph and modify the graph to handle + // parts of the graph itself. After this is called, the graph may + // contain new nodes that replace 1 or more nodes. + TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegate* delegate); + private: // Give 'op_reg' a chance to initialize itself using the contents of // 'buffer'. @@ -276,9 +287,18 @@ class Interpreter { return op_reg.invoke(&context_, node); } - // Allocate tensors whose sizes are known in order of nodes. Discontinue when - // we encounter a node that has a dynamic output tensor. - TfLiteStatus AllocateTensorsWhoseSizesAreKnown(); + // Call OpPrepare() for as many ops as possible, allocating memory for their + // tensors.
If an op containing dynamic tensors is found, preparation will be + // postponed until this function is called again. This allows the interpreter + // to wait until Invoke() to resolve the sizes of dynamic tensors. + TfLiteStatus PrepareOpsAndTensors(); + + // Call OpPrepare() for all ops starting at 'first_execution_plan_index'. + // Stop when an op with dynamic output tensors is found or all ops have been + // prepared. Fill 'last_execution_plan_index_prepared' with the execution plan + // index of that op, or of the last op in the plan. + TfLiteStatus PrepareOpsStartingAt(int first_execution_plan_index, + int* last_execution_plan_index_prepared); // Tensors needed by the interpreter. Use `AddTensors` to add more blank // tensor entries. Note, `tensors_.data()` needs to be synchronized to the @@ -298,7 +318,8 @@ class Interpreter { TfLiteStatus BytesRequired(TfLiteType type, const int* dims, int dims_size, size_t* bytes); - // Request an tensor be resized implementation. + // Request an tensor be resized implementation. If the given tensor is of + // type kTfLiteDynamic it will also be allocated new memory. TfLiteStatus ResizeTensorImpl(TfLiteTensor* tensor, TfLiteIntArray* new_size); // Report a detailed error string (will be printed to stderr). @@ -315,6 +336,40 @@ class Interpreter { static TfLiteStatus AddTensors(TfLiteContext* context, int tensors_to_add, int* first_new_tensor_index); + // WARNING: This is an experimental API and subject to change. + // Entry point for C API ReplaceSubgraphsWithDelegateKernels + static TfLiteStatus ReplaceSubgraphsWithDelegateKernels( + TfLiteContext* context, TfLiteRegistration registration, + const TfLiteIntArray* nodes_to_replace); + + // Update the execution graph to replace some of the nodes with stub + // nodes. Specifically any node index that has `nodes[index]==1` will be + // slated for replacement with a delegate kernel specified by registration. + // WARNING: This is an experimental interface that is subject to change.
+ TfLiteStatus ReplaceSubgraphsWithDelegateKernels( + TfLiteRegistration registration, const TfLiteIntArray* nodes_to_replace); + + // WARNING: This is an experimental interface that is subject to change. + // Gets the internal pointer to a TensorFlow lite node by node_index. + TfLiteStatus GetNodeAndRegistration(int node_index, TfLiteNode** node, + TfLiteRegistration** registration); + + // WARNING: This is an experimental interface that is subject to change. + // Entry point for C node plugin API to get a node by index. + static TfLiteStatus GetNodeAndRegistration(struct TfLiteContext*, + int node_index, TfLiteNode** node, + TfLiteRegistration** registration); + + // WARNING: This is an experimental interface that is subject to change. + // Gets an TfLiteIntArray* representing the execution plan. The caller owns + // this memory and must free it with TfLiteIntArrayFree(). + TfLiteStatus GetExecutionPlan(TfLiteIntArray** execution_plan); + + // WARNING: This is an experimental interface that is subject to change. + // Entry point for C node plugin API to get the execution plan + static TfLiteStatus GetExecutionPlan(struct TfLiteContext* context, + TfLiteIntArray** execution_plan); + // A pure C data structure used to communicate with the pure C plugin // interface. To avoid copying tensor metadata, this is also the definitive // structure to store tensors. @@ -325,17 +380,6 @@ class Interpreter { std::vector> nodes_and_registration_; - // Raw memory buffer that is allocated for all temporary and graph outputs. - // that are declared kTfLiteArenaRw. - SimpleMemoryArena arena_; - - // Raw memory buffer that is allocated for persistent tensors that are - // declared as kTfLiteArenaRwPersistent. - SimpleMemoryArena persistent_arena_; - - // Stores allocation and reference counts of all tensors. - std::vector allocs_and_refcounts_; - // Whether the model is consistent. 
That is to say if the inputs and outputs // of every node and the global inputs and outputs are valid indexes into // the tensor array. @@ -356,7 +400,7 @@ class Interpreter { // The error reporter delegate that tflite will forward queries errors to. ErrorReporter* error_reporter_; - // Next node to allocate output tensors. + // Index of the next node to prepare. // During Invoke(), Interpreter will allocate input tensors first, which are // known to be fixed size. Then it will allocate outputs from nodes as many // as possible. When there is a node that produces dynamic sized tensor. @@ -364,11 +408,24 @@ class Interpreter { // node id, and execute the node to generate the output tensor before continue // to allocate successors. This process repeats until all nodes are executed. // NOTE: this relies on the order of nodes that is in topological order. - int next_allocate_node_id_; + int next_execution_plan_index_to_prepare_; + + // WARNING: This is an experimental interface that is subject to change. + // This is a list of node indices (to index into nodes_and_registration). + // This represents a valid topological sort (dependency ordered) execution + // plan. In particular, it is valid for this ordering to contain only a + // subset of the node indices. + std::vector execution_plan_; + + // In the future, we'd like a TfLiteIntArray compatible representation. + // TODO(aselle): replace execution_plan_ with this. 
+ std::unique_ptr plan_cache_; // Whether to delegate to NN API std::unique_ptr nnapi_delegate_; + + std::unique_ptr memory_planner_; }; } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_ +#endif // TENSORFLOW_CONTRIB_LITE_INTERPRETER_H_ diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc index edff2109430c6e1ec6c481619ed7772237a3301d..28c96e5dde6ffa62bb073db9716a00f91c6e0bdf 100644 --- a/tensorflow/contrib/lite/interpreter_test.cc +++ b/tensorflow/contrib/lite/interpreter_test.cc @@ -16,8 +16,10 @@ limitations under the License. #include "tensorflow/contrib/lite/interpreter.h" #include #include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" +#include "tensorflow/contrib/lite/schema/schema_generated.h" #include "tensorflow/contrib/lite/string_util.h" - +#include "tensorflow/contrib/lite/testing/util.h" namespace tflite { namespace { @@ -282,6 +284,51 @@ TEST(BasicInterpreter, NoOpInterpreter) { ASSERT_EQ(interpreter.Invoke(), kTfLiteOk); } +TEST(BasicInterpreter, ResizingTensors) { + Interpreter interpreter; + ASSERT_EQ(interpreter.AddTensors(1), kTfLiteOk); + ASSERT_EQ(interpreter.SetInputs({0}), kTfLiteOk); + ASSERT_EQ(interpreter.SetOutputs({0}), kTfLiteOk); + + ASSERT_EQ(interpreter.SetTensorParametersReadWrite( + 0, kTfLiteFloat32, "", {3}, TfLiteQuantizationParams()), + kTfLiteOk); + + int t = interpreter.inputs()[0]; + TfLiteTensor* tensor = interpreter.tensor(t); + + ASSERT_EQ(interpreter.ResizeInputTensor(t, {1, 2, 3}), kTfLiteOk); + EXPECT_EQ(tensor->bytes, 6 * sizeof(float)); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + + tensor->data.f[5] = 0.123f; + + // Changing from kTfLiteArenaRw to kTfLiteDynamic is quite complicate: we need + // to unset data.raw, otherwise Realloc will try to free that memory. 
+ tensor->data.raw = nullptr; + tensor->allocation_type = kTfLiteDynamic; + + ASSERT_EQ(interpreter.ResizeInputTensor(t, {1, 2, 4}), kTfLiteOk); + EXPECT_EQ(tensor->bytes, 8 * sizeof(float)); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + + // TODO(ahentz): We shouldn't have to force reallocation, but + // ResizeInputTensor doesn't realloc dynamic tensors. Also note that + // TfLiteTensorRealloc(tensor->bytes, tensor) is a no-op. + TfLiteTensorRealloc(9 * sizeof(float), tensor); + tensor->data.f[7] = 0.123f; + + ASSERT_EQ(interpreter.ResizeInputTensor(t, {2, 2, 4}), kTfLiteOk); + EXPECT_EQ(tensor->bytes, 16 * sizeof(float)); + ASSERT_EQ(interpreter.AllocateTensors(), kTfLiteOk); + + // TODO(ahentz): We shouldn't have to force reallocation, but + // ResizeInputTensor doesn't realloc dynamic tensors. Also note that + // TfLiteTensorRealloc(tensor->bytes, tensor) is a no-op. + TfLiteTensorRealloc(17 * sizeof(float), tensor); + tensor->data.f[15] = 0.123f; +} + TEST(BasicInterpreter, OneOpInterpreter) { Interpreter interpreter; ASSERT_EQ(interpreter.AddTensors(2), kTfLiteOk); @@ -514,13 +561,283 @@ TEST(BasicInterpreter, TestCustomErrorReporter) { ASSERT_EQ(reporter.calls, 1); } +// Test fixture that allows playing with execution plans. It creates a two +// node graph that can be executed in either [0,1] order or [1,0] order. +// The CopyOp records when it is invoked in the class member run_order_ +// so we can test whether the execution plan was honored. +class TestExecutionPlan : public ::testing::Test { + // Encapsulates the node ids and provides them to a C primitive data type + // Allocatable with placement new, but never destructed, so make sure this + // doesn't own any heap allocated data. This is then is used as op local + // data to allow access to the test fixture data. 
+ class CallReporting { + public: + CallReporting(int node_id, std::vector* run_order) + : node_id_(node_id), run_order_(run_order) {} + + void Record() { run_order_->push_back(node_id_); } + + private: + // The node id for this particular node + int node_id_; + // A pointer to the global run-order + std::vector* run_order_; + }; + + // Build a kernel registration for an op that copies its one input + // to an output + TfLiteRegistration CopyOpRegistration() { + TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; + + reg.prepare = [](TfLiteContext* context, TfLiteNode* node) { + // Set output size to input size + TfLiteTensor* tensor0 = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* tensor1 = &context->tensors[node->outputs->data[0]]; + TfLiteIntArray* newSize = TfLiteIntArrayCopy(tensor0->dims); + return context->ResizeTensor(context, tensor1, newSize); + }; + + reg.invoke = [](TfLiteContext* context, TfLiteNode* node) { + CallReporting* call_reporting = + reinterpret_cast(node->builtin_data); + // Copy input data to output data. + TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* a1 = &context->tensors[node->outputs->data[0]]; + int num = a0->dims->data[0]; + for (int i = 0; i < num; i++) { + a1->data.f[i] = a0->data.f[i]; + } + call_reporting->Record(); + return kTfLiteOk; + }; + return reg; + } + + // Adds a copy node going from tensor `input` to output tensor `output`. + // Note, input is used as the node_id. Inject run_order as op accessible + // data. Note: this is a little strange of a way to do this, but it is + // using op functionality to avoid static global variables. + void MakeCopyNode(int input, int output) { + // Ownership of call_reporting is taken by interpreter (malloc is used due + // to nodes being a C99 interface so free() is used). 
+ TfLiteRegistration copy_op = CopyOpRegistration(); + CallReporting* call_reporting_1 = + reinterpret_cast(malloc(sizeof(CallReporting))); + new (call_reporting_1) CallReporting(input, &run_order_); + ASSERT_EQ(interpreter_.AddNodeWithParameters( + {0}, {2}, nullptr, 0, + reinterpret_cast(call_reporting_1), ©_op), + kTfLiteOk); + ASSERT_EQ(interpreter_.ResizeInputTensor(input, {3}), kTfLiteOk); + } + + void SetUp() final { + // Add two inputs and two outputs that don't depend on each other + ASSERT_EQ(interpreter_.AddTensors(4), kTfLiteOk); + interpreter_.SetInputs({0, 1}); + interpreter_.SetOutputs({2, 3}); + TfLiteQuantizationParams quantized; + for (int tensor_index = 0; tensor_index < 4; tensor_index++) { + ASSERT_EQ(interpreter_.SetTensorParametersReadWrite( + tensor_index, kTfLiteFloat32, "", {3}, quantized), + kTfLiteOk); + } + + // Define two copy functions that also use the user_data to report that + // they were called. + // i.e. tensor[2] = copy(tensor[0]); tensor[3] = copy(tensor[1]); + // thus we can reorder the two nodes arbitrary and still satisfy dependency + // order. 
+ MakeCopyNode(0, 2); + MakeCopyNode(1, 3); + + ASSERT_EQ(interpreter_.AllocateTensors(), kTfLiteOk); + } + + protected: + Interpreter interpreter_; + + // list of node_ids that were run + std::vector run_order_; +}; + +TEST_F(TestExecutionPlan, DefaultExecutionPlan) { + // Check default order + ASSERT_EQ(interpreter_.Invoke(), kTfLiteOk); + ASSERT_EQ(run_order_, std::vector({0, 1})); +} + +TEST_F(TestExecutionPlan, ReversedExecutionPlan) { + // Check reversed order + interpreter_.SetExecutionPlan({1, 0}); + ASSERT_EQ(interpreter_.Invoke(), kTfLiteOk); + ASSERT_EQ(run_order_, std::vector({1, 0})); +} + +TEST_F(TestExecutionPlan, SubsetExecutionPlan) { + // Check running only node index 1 + interpreter_.SetExecutionPlan({1}); + ASSERT_EQ(interpreter_.Invoke(), kTfLiteOk); + ASSERT_EQ(run_order_, std::vector({1})); +} + +TEST_F(TestExecutionPlan, NullExecutionPlan) { + // Check nothing executed. + interpreter_.SetExecutionPlan({}); + ASSERT_EQ(interpreter_.Invoke(), kTfLiteOk); + ASSERT_EQ(run_order_, std::vector()); +} + +// Build a kernel registration for an op that copies its one input +// to an output +TfLiteRegistration AddOpRegistration() { + TfLiteRegistration reg = {nullptr, nullptr, nullptr, nullptr}; + + reg.custom_name = "my_add"; + reg.builtin_code = tflite::BuiltinOperator_CUSTOM; + + reg.prepare = [](TfLiteContext* context, TfLiteNode* node) { + // Set output size to input size + TfLiteTensor* tensor0 = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* tensor1 = &context->tensors[node->inputs->data[1]]; + TfLiteTensor* tensor2 = &context->tensors[node->outputs->data[0]]; + TfLiteIntArray* newSize = TfLiteIntArrayCopy(tensor0->dims); + TfLiteIntArray* newSizeOther = TfLiteIntArrayCopy(tensor1->dims); + TF_LITE_ENSURE_EQ(context, newSize->size, newSizeOther->size); + TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, tensor2, newSize)); + return kTfLiteOk; + }; + + reg.invoke = [](TfLiteContext* context, TfLiteNode* node) { + // Copy input data 
to output data. + TfLiteTensor* a0 = &context->tensors[node->inputs->data[0]]; + TfLiteTensor* a1 = &context->tensors[node->inputs->data[1]]; + TfLiteTensor* out = &context->tensors[node->outputs->data[0]]; + int num = a0->dims->data[0]; + for (int i = 0; i < num; i++) { + out->data.f[i] = a0->data.f[i] + a1->data.f[i]; + } + return kTfLiteOk; + }; + return reg; +} + +class TestDelegate : public ::testing::Test { + public: + TestDelegate() { + interpreter_.AddTensors(5); + interpreter_.SetInputs({0, 1}); + interpreter_.SetOutputs({3, 4}); + TfLiteQuantizationParams quant; + interpreter_.SetTensorParametersReadWrite(0, kTfLiteFloat32, "", {3}, + quant); + interpreter_.SetTensorParametersReadWrite(1, kTfLiteFloat32, "", {3}, + quant); + interpreter_.SetTensorParametersReadWrite(2, kTfLiteFloat32, "", {3}, + quant); + interpreter_.SetTensorParametersReadWrite(3, kTfLiteFloat32, "", {3}, + quant); + TfLiteRegistration reg = AddOpRegistration(); + interpreter_.AddNodeWithParameters({0, 0}, {2}, nullptr, 0, nullptr, ®); + interpreter_.AddNodeWithParameters({1, 1}, {3}, nullptr, 0, nullptr, ®); + interpreter_.AddNodeWithParameters({2, 1}, {4}, nullptr, 0, nullptr, ®); + } + + protected: + class SimpleDelegate { + public: + // Create a simple implementation of a TfLiteDelegate. We use the C++ class + // SimpleDelegate and it can produce a handle TfLiteDelegate that is + // value-copyable and compatible with TfLite. + explicit SimpleDelegate(const std::vector& nodes) : nodes_(nodes) { + delegate_.Prepare = [](TfLiteContext* context, + void* data) -> TfLiteStatus { + auto* simple = reinterpret_cast(data); + TfLiteIntArray* nodes_to_separate = + TfLiteIntArrayCreate(simple->nodes_.size()); + // Mark nodes that we want in TfLiteIntArray* structure. 
+ int index = 0; + for (auto node_index : simple->nodes_) { + nodes_to_separate->data[index++] = node_index; + // make sure node is add + TfLiteNode* node; + TfLiteRegistration* reg; + context->GetNodeAndRegistration(context, node_index, &node, ®); + TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_CUSTOM); + TFLITE_CHECK_EQ(strcmp(reg->custom_name, "my_add"), 0); + } + // Check that all nodes are available + TfLiteIntArray* execution_plan; + TF_LITE_ENSURE_STATUS( + context->GetExecutionPlan(context, &execution_plan)); + for (int exec_index = 0; exec_index < execution_plan->size; + exec_index++) { + int node_index = execution_plan->data[exec_index]; + // Check that we are an identity map to start. + TFLITE_CHECK_EQ(exec_index, node_index); + TfLiteNode* node; + TfLiteRegistration* reg; + context->GetNodeAndRegistration(context, node_index, &node, ®); + TFLITE_CHECK_EQ(reg->builtin_code, tflite::BuiltinOperator_CUSTOM); + TFLITE_CHECK_EQ(strcmp(reg->custom_name, "my_add"), 0); + } + + context->ReplaceSubgraphsWithDelegateKernels( + context, FakeFusedRegistration(), nodes_to_separate); + TfLiteIntArrayFree(nodes_to_separate); + return kTfLiteOk; + }; + // Store type-punned data SimpleDelegate structure. 
+ delegate_.data_ = reinterpret_cast(this); + } + + static TfLiteRegistration FakeFusedRegistration() { + TfLiteRegistration reg = {nullptr}; + reg.custom_name = "fake_fused_op"; + return reg; + } + + TfLiteDelegate* get_tf_lite_delegate() { return &delegate_; } + + private: + std::vector nodes_; + TfLiteDelegate delegate_; + }; + Interpreter interpreter_; +}; + +TEST_F(TestDelegate, BasicDelegate) { + interpreter_.Invoke(); + SimpleDelegate simple({0, 1, 2}); + interpreter_.ModifyGraphWithDelegate(simple.get_tf_lite_delegate()); + + ASSERT_EQ(interpreter_.execution_plan().size(), 1); + int node = interpreter_.execution_plan()[0]; + const auto* node_and_reg = interpreter_.node_and_registration(node); + ASSERT_EQ(node_and_reg->second.custom_name, + SimpleDelegate::FakeFusedRegistration().custom_name); +} + +TEST_F(TestDelegate, ComplexDeligate) { + interpreter_.Invoke(); + SimpleDelegate simple({1, 2}); + interpreter_.ModifyGraphWithDelegate(simple.get_tf_lite_delegate()); + + ASSERT_EQ(interpreter_.execution_plan().size(), 2); + // 0th should be a non-delegated original op + ASSERT_EQ(interpreter_.execution_plan()[0], 0); + // 1st should be a new macro op (3) which didn't exist) + ASSERT_EQ(interpreter_.execution_plan()[1], 3); + const auto* node_and_reg = interpreter_.node_and_registration(3); + ASSERT_EQ(node_and_reg->second.custom_name, + SimpleDelegate::FakeFusedRegistration().custom_name); +} + } // namespace } // namespace tflite int main(int argc, char** argv) { -#ifdef OS_LINUX - FLAGS_logtostderr = true; -#endif + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/ios_makefile.inc b/tensorflow/contrib/lite/ios_makefile.inc index bcff7ed9889e95c13294b6cf0d0f4788991a04df..fc6594c3a04ba6aabba99bb631f85737baf389f1 100644 --- a/tensorflow/contrib/lite/ios_makefile.inc +++ b/tensorflow/contrib/lite/ios_makefile.inc @@ -22,6 +22,7 @@ ifeq ($(TARGET), IOS) IOS_ARCH := x86_64 CXXFLAGS += 
-miphoneos-version-min=$(MIN_SDK_VERSION) \ -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK \ + -DTFLITE_USE_APPLE_ACCELERATE_FOR_CONV \ -fembed-bitcode \ -Wno-c++11-narrowing \ -mno-thumb \ @@ -30,6 +31,9 @@ ifeq ($(TARGET), IOS) ${IPHONEOS_SYSROOT} \ -arch $(IOS_ARCH) \ -O3 + ifeq ($(IOS_ARCH), x86_64) + CXXFLAGS += -msse4.1 + endif CCFLAGS += -miphoneos-version-min=$(MIN_SDK_VERSION) \ -fembed-bitcode \ -mno-thumb \ @@ -39,6 +43,7 @@ ifeq ($(TARGET), IOS) -O3 LDFLAGS := -fembed-bitcode \ -miphoneos-version-min=${MIN_SDK_VERSION} \ + -framework Accelerate \ -arch $(IOS_ARCH) OBJDIR := $(OBJDIR)ios_$(IOS_ARCH)/ LIBDIR := $(LIBDIR)ios_$(IOS_ARCH)/ diff --git a/tensorflow/contrib/lite/java/AndroidManifest.xml b/tensorflow/contrib/lite/java/AndroidManifest.xml new file mode 100644 index 0000000000000000000000000000000000000000..f705feacbec38ab5152ce52b701320d8f1cd8d3d --- /dev/null +++ b/tensorflow/contrib/lite/java/AndroidManifest.xml @@ -0,0 +1,7 @@ + + + + + + diff --git a/tensorflow/contrib/lite/java/BUILD b/tensorflow/contrib/lite/java/BUILD index 1de28eb52ddb458df0be0a8f9ef453f7caf68654..35aacb70002d1d454f675484e4398bcdffc4acf1 100644 --- a/tensorflow/contrib/lite/java/BUILD +++ b/tensorflow/contrib/lite/java/BUILD @@ -7,6 +7,16 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow/java:build_defs.bzl", "JAVACOPTS") load("//tensorflow/contrib/lite:build_def.bzl", "tflite_jni_binary") +load("//tensorflow/contrib/lite/java:aar_with_jni.bzl", "aar_with_jni") + +# Building tensorflow-lite.aar including 4 variants of .so +# To build an aar for release, run below command: +# bazel build --cxxopt='--std=c++11' -c opt --fat_apk_cpu=x86,x86_64,arm64-v8a,armeabi-v7a \ +# tensorflow/contrib/lite/java:tensorflow-lite +aar_with_jni( + name = "tensorflow-lite", + android_library = ":tensorflowlite", +) android_library( name = "tensorflowlite", @@ -15,6 +25,7 @@ android_library( "src/main/java/org/tensorflow/lite/*.java", ], ), + manifest = "AndroidManifest.xml", visibility = 
["//visibility:public"], deps = [ ":tflite_runtime", @@ -100,6 +111,26 @@ java_test( ], ) +# TODO: generate large models at runtime, instead of storing them. +java_test( + name = "InterpreterTest", + size = "small", + srcs = ["src/test/java/org/tensorflow/lite/InterpreterTest.java"], + data = [ + "src/testdata/add.bin", + "src/testdata/mobilenet.tflite.bin", + ], + javacopts = JAVACOPTS, + test_class = "org.tensorflow.lite.InterpreterTest", + visibility = ["//visibility:private"], + deps = [ + ":libtensorflowlite_jni.so", + ":tensorflowlitelib", + "@com_google_truth", + "@junit", + ], +) + java_test( name = "TensorTest", size = "small", diff --git a/tensorflow/contrib/lite/java/aar_with_jni.bzl b/tensorflow/contrib/lite/java/aar_with_jni.bzl new file mode 100644 index 0000000000000000000000000000000000000000..4450bc9085555b3416f51bac07ea94a1240e919c --- /dev/null +++ b/tensorflow/contrib/lite/java/aar_with_jni.bzl @@ -0,0 +1,47 @@ +"""Generate zipped aar file including different variants of .so in jni folder.""" + +def aar_with_jni(name, android_library): + # Generate dummy AndroidManifest.xml for dummy apk usage + # (dummy apk is generated by _dummy_app_for_so target below) + native.genrule( + name = name + "_binary_manifest_generator", + outs = [name + "_generated_AndroidManifest.xml"], + cmd = """ +cat > $(OUTS) < + + +EOF +""", + ) + + # Generate dummy apk including .so files and later we extract out + # .so files and throw away the apk. + native.android_binary( + name = name + "_dummy_app_for_so", + manifest = name + "_generated_AndroidManifest.xml", + custom_package = "dummy.package.for.so", + deps = [android_library], + # In some platforms we don't have an Android SDK/NDK and this target + # can't be built. We need to prevent the build system from trying to + # use the target in that case. 
+ tags = ["manual"], + ) + + native.genrule( + name = name, + srcs = [android_library + ".aar", name + "_dummy_app_for_so_unsigned.apk"], + outs = [name + ".aar"], + tags = ["manual"], + cmd = """ +cp $(location {}.aar) $(location :{}.aar) +chmod +w $(location :{}.aar) +origdir=$$PWD +cd $$(mktemp -d) +unzip $$origdir/$(location :{}_dummy_app_for_so_unsigned.apk) "lib/*" +cp -r lib jni +zip -r $$origdir/$(location :{}.aar) jni/*/*.so +""".format(android_library, name, name, name, name), + ) diff --git a/tensorflow/contrib/lite/java/build_aar_for_release.sh b/tensorflow/contrib/lite/java/build_aar_for_release.sh new file mode 100755 index 0000000000000000000000000000000000000000..fbcb1e7db9a3f9b885505e989b7ff7224f2d2b15 --- /dev/null +++ b/tensorflow/contrib/lite/java/build_aar_for_release.sh @@ -0,0 +1,66 @@ +#!/bin/bash +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +set -e +set -x + +TMPDIR=`mktemp -d` +trap "rm -rf $TMPDIR" EXIT + +VERSION=1.0 + +BUILDER=bazel +BASEDIR=tensorflow/contrib/lite +CROSSTOOL="//external:android/crosstool" +HOST_CROSSTOOL="@bazel_tools//tools/cpp:toolchain" + +BUILD_OPTS="--cxxopt=--std=c++11 -c opt" +CROSSTOOL_OPTS="--crosstool_top=$CROSSTOOL --host_crosstool_top=$HOST_CROSSTOOL" + +test -d $BASEDIR || (echo "Aborting: not at top-level build directory"; exit 1) + +function build_basic_aar() { + local OUTDIR=$1 + $BUILDER build $BUILD_OPTS $BASEDIR/java:tensorflowlite.aar + unzip -d $OUTDIR $BUILDER-bin/$BASEDIR/java/tensorflowlite.aar + # targetSdkVersion is here to prevent the app from requesting spurious + # permissions, such as permission to make phone calls. It worked for v1.0, + # but minSdkVersion might be the preferred way to handle this. + sed -i -e 's///' $OUTDIR/AndroidManifest.xml +} + +function build_arch() { + local ARCH=$1 + local CONFIG=$2 + local OUTDIR=$3 + mkdir -p $OUTDIR/jni/$ARCH/ + $BUILDER build $BUILD_OPTS $CROSSTOOL_OPTS --cpu=$CONFIG \ + $BASEDIR/java:libtensorflowlite_jni.so + cp $BUILDER-bin/$BASEDIR/java/libtensorflowlite_jni.so $OUTDIR/jni/$ARCH/ +} + +rm -rf $TMPDIR +mkdir -p $TMPDIR/jni + +build_basic_aar $TMPDIR +build_arch arm64-v8a arm64-v8a $TMPDIR +build_arch armeabi-v7a armeabi-v7a $TMPDIR +build_arch x86 x86 $TMPDIR +build_arch x86_64 x86_64 $TMPDIR + +AAR_FILE=`realpath tflite-${VERSION}.aar` +(cd $TMPDIR && zip $AAR_FILE -r *) +echo "New AAR file is $AAR_FILE" + diff --git a/tensorflow/contrib/lite/java/demo/README.md b/tensorflow/contrib/lite/java/demo/README.md index 71b633c5774d93684f651821adad13c378a8243c..2e818f728ef208d30b0eeb27ffd7e3fa0c7c1a2d 100644 --- a/tensorflow/contrib/lite/java/demo/README.md +++ b/tensorflow/contrib/lite/java/demo/README.md @@ -8,7 +8,12 @@ It's easiest with Android Studio. - You'll need at least SDK version 23. 
+ - Make sure to install the latest version of Bazel. Some distributions + ship with Bazel 0.5.4, which is too old. - Bazel requires Android Build Tools `26.0.1` or higher. + - **Bazel is incompatible with NDK revisions 15 and above,** with revision + 16 being a compile-breaking change. [Download an older version manually + instead of using the SDK Manager.](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#install-bazel-and-android-prerequisites) - You also need to install the Android Support Repository, available through Android Studio under `Android SDK Manager -> SDK Tools -> Android Support Repository`. @@ -16,10 +21,15 @@ 2. [Edit your `WORKSPACE`](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#edit-workspace) to add SDK and NDK targets. + NOTE: As long as you have the SDK and NDK installed, the `./configure` + script will create these rules for you. Answer "Yes" when the script asks + to automatically configure the `./WORKSPACE`. + - Make sure the `api_level` in `WORKSPACE` is set to an SDK version that you have installed. - By default, Android Studio will install the SDK to `~/Android/Sdk` and - the NDK to `~/Android/Sdk/ndk-bundle`. + the NDK to `~/Android/Sdk/ndk-bundle` (but the NDK should be a manual + download until Bazel supports NDK 16. See bullet points under (1)). 2. Build the app with Bazel. 
The demo needs C++11: diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java index e7bad4637041d003c1e507d81c0c30404c587653..e44c5ae6b48eda187079dd3a0a1bc563276d816e 100644 --- a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java +++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java @@ -73,6 +73,11 @@ public class ImageClassifier { /** An array to hold inference results, to be feed into Tensorflow Lite as outputs. */ private byte[][] labelProbArray = null; + /** multi-stage low pass filter * */ + private float[][] filterLabelProbArray = null; + + private static final int FILTER_STAGES = 3; + private static final float FILTER_FACTOR = 0.4f; private PriorityQueue> sortedLabels = new PriorityQueue<>( @@ -93,6 +98,7 @@ public class ImageClassifier { DIM_BATCH_SIZE * DIM_IMG_SIZE_X * DIM_IMG_SIZE_Y * DIM_PIXEL_SIZE); imgData.order(ByteOrder.nativeOrder()); labelProbArray = new byte[1][labelList.size()]; + filterLabelProbArray = new float[FILTER_STAGES][labelList.size()]; Log.d(TAG, "Created a Tensorflow Lite Image Classifier."); } @@ -108,11 +114,38 @@ public class ImageClassifier { tflite.run(imgData, labelProbArray); long endTime = SystemClock.uptimeMillis(); Log.d(TAG, "Timecost to run model inference: " + Long.toString(endTime - startTime)); + + // Smooth the results across frames. + applyFilter(); + + // Print the results. String textToShow = printTopKLabels(); textToShow = Long.toString(endTime - startTime) + "ms" + textToShow; return textToShow; } + void applyFilter() { + int numLabels = labelList.size(); + + // Low pass filter `labelProbArray` into the first stage of the filter. 
+ for (int j = 0; j < numLabels; ++j) { + filterLabelProbArray[0][j] += + FILTER_FACTOR * (labelProbArray[0][j] - filterLabelProbArray[0][j]); + } + // Low pass filter each stage into the next. + for (int i = 1; i < FILTER_STAGES; ++i) { + for (int j = 0; j < numLabels; ++j) { + filterLabelProbArray[i][j] += + FILTER_FACTOR * (filterLabelProbArray[i - 1][j] - filterLabelProbArray[i][j]); + } + } + + // Copy the last stage filter output back to `labelProbArray`. + for (int j = 0; j < numLabels; ++j) { + labelProbArray[0][j] = (byte)filterLabelProbArray[FILTER_STAGES - 1][j]; + } + } + /** Closes tflite to release resources. */ public void close() { tflite.close(); @@ -177,7 +210,7 @@ public class ImageClassifier { final int size = sortedLabels.size(); for (int i = 0; i < size; ++i) { Map.Entry label = sortedLabels.poll(); - textToShow = "\n" + label.getKey() + ":" + Float.toString(label.getValue()) + textToShow; + textToShow = String.format("\n%s: %4.2f", label.getKey(), label.getValue()) + textToShow; } return textToShow; } diff --git a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java index 1939a078ad8031b99620773c9b91335c4e8f7b22..5ee594dec492ad2fee22e603a6de311b3fed4cac 100644 --- a/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java +++ b/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/NativeInterpreterWrapper.java @@ -34,7 +34,7 @@ final class NativeInterpreterWrapper implements AutoCloseable { NativeInterpreterWrapper(String modelPath) { errorHandle = createErrorReporter(ERROR_BUFFER_SIZE); modelHandle = createModel(modelPath, errorHandle); - interpreterHandle = createInterpreter(modelHandle); + interpreterHandle = createInterpreter(modelHandle, errorHandle); } /** @@ -46,7 +46,7 @@ final class NativeInterpreterWrapper implements AutoCloseable { modelByteBuffer = 
mappedByteBuffer; errorHandle = createErrorReporter(ERROR_BUFFER_SIZE); modelHandle = createModelWithBuffer(modelByteBuffer, errorHandle); - interpreterHandle = createInterpreter(modelHandle); + interpreterHandle = createInterpreter(modelHandle, errorHandle); } /** Releases resources associated with this {@code NativeInterpreterWrapper}. */ @@ -103,11 +103,22 @@ final class NativeInterpreterWrapper implements AutoCloseable { return outputs; } + private static native long[] run( + long interpreterHandle, + long errorHandle, + Object[] sizes, + int[] dtypes, + int[] numsOfBytes, + Object[] values); + /** Resizes dimensions of a specific input. */ void resizeInput(int idx, int[] dims) { resizeInput(interpreterHandle, errorHandle, idx, dims); } + private static native void resizeInput( + long interpreterHandle, long errorHandle, int inputIdx, int[] dims); + void setUseNNAPI(boolean useNNAPI) { useNNAPI(interpreterHandle, useNNAPI); } @@ -245,9 +256,6 @@ final class NativeInterpreterWrapper implements AutoCloseable { private static native String[] getOutputNames(long interpreterHandle); - private static native void resizeInput( - long interpreterHandle, long errorHandle, int inputIdx, int[] dims); - private static native void useNNAPI(long interpreterHandle, boolean state); private static native long createErrorReporter(int size); @@ -256,15 +264,7 @@ final class NativeInterpreterWrapper implements AutoCloseable { private static native long createModelWithBuffer(MappedByteBuffer modelBuffer, long errorHandle); - private static native long createInterpreter(long modelHandle); - - private static native long[] run( - long interpreterHandle, - long errorHandle, - Object[] sizes, - int[] dtypes, - int[] numsOfBytes, - Object[] values); + private static native long createInterpreter(long modelHandle, long errorHandle); private static native void delete(long errorHandle, long modelHandle, long interpreterHandle); diff --git 
a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc index bc6462eb5466e14769f94c5103984f5201b4b8dc..c346f9f92e360c0722ebac440d790da6441ceecf 100644 --- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc +++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.cc @@ -200,6 +200,12 @@ TfLiteStatus setInputs(JNIEnv* env, tflite::Interpreter* interpreter, return kTfLiteOk; } +// TODO(yichengfan): evaluate the benefit to use tflite verifier. +bool VerifyModel(const void* buf, size_t len) { + flatbuffers::Verifier verifier(static_cast(buf), len); + return tflite::VerifyModelBuffer(verifier); +} + } // namespace JNIEXPORT jobjectArray JNICALL @@ -271,6 +277,17 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createModel( convertLongToErrorReporter(env, error_handle); if (error_reporter == nullptr) return 0; const char* path = env->GetStringUTFChars(model_file, nullptr); + + { + tflite::FileCopyAllocation allocation(path, nullptr); + if (!VerifyModel(allocation.base(), allocation.bytes())) { + throwException(env, kIllegalArgumentException, + "Contents of %s is not a valid flatbuffer model", path); + env->ReleaseStringUTFChars(model_file, path); + return 0; + } + } + auto model = tflite::FlatBufferModel::BuildFromFile(path, error_reporter); if (!model) { throwException(env, kIllegalArgumentException, @@ -293,6 +310,12 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer( const char* buf = static_cast(env->GetDirectBufferAddress(model_buffer)); jlong capacity = env->GetDirectBufferCapacity(model_buffer); + if (!VerifyModel(buf, capacity)) { + throwException(env, kIllegalArgumentException, + "MappedByteBuffer is not a valid flatbuffer model"); + return 0; + } + auto model = tflite::FlatBufferModel::BuildFromBuffer( buf, static_cast(capacity), error_reporter); if (!model) { @@ -307,12 +330,21 @@ 
Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer( JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter( - JNIEnv* env, jclass clazz, jlong model_handle) { + JNIEnv* env, jclass clazz, jlong model_handle, jlong error_handle) { tflite::FlatBufferModel* model = convertLongToModel(env, model_handle); if (model == nullptr) return 0; + BufferErrorReporter* error_reporter = + convertLongToErrorReporter(env, error_handle); + if (error_reporter == nullptr) return 0; auto resolver = ::tflite::CreateOpResolver(); std::unique_ptr interpreter; - tflite::InterpreterBuilder(*model, *(resolver.get()))(&interpreter); + TfLiteStatus status = + tflite::InterpreterBuilder(*model, *(resolver.get()))(&interpreter); + if (status != kTfLiteOk) { + throwException(env, kIllegalArgumentException, + "Cannot create interpreter: %s", + error_reporter->CachedErrorMessage()); + } return reinterpret_cast(interpreter.release()); } diff --git a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h index 430886b7cc04a356d1826843acc1bbebf4189bf7..c52a7e4e439936344be26d5761fb5747db64794a 100644 --- a/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h +++ b/tensorflow/contrib/lite/java/src/main/native/nativeinterpreterwrapper_jni.h @@ -95,11 +95,11 @@ Java_org_tensorflow_lite_NativeInterpreterWrapper_createModelWithBuffer( /* * Class: org_tensorflow_lite_NativeInterpreterWrapper * Method: - * Signature: (J)J + * Signature: (JJ)J */ JNIEXPORT jlong JNICALL Java_org_tensorflow_lite_NativeInterpreterWrapper_createInterpreter( - JNIEnv* env, jclass clazz, jlong model_handle); + JNIEnv* env, jclass clazz, jlong model_handle, jlong error_handle); /* * Class: org_tensorflow_lite_NativeInterpreterWrapper diff --git a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java 
b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java index 9a6894f49c0b7278511717d2671648c6d1763e00..90323555d88419d837a76bca7de6d9998e388fca 100644 --- a/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java +++ b/tensorflow/contrib/lite/java/src/test/java/org/tensorflow/lite/NativeInterpreterWrapperTest.java @@ -25,6 +25,7 @@ import org.junit.runner.RunWith; import org.junit.runners.JUnit4; /** Unit tests for {@link org.tensorflow.lite.NativeInterpreterWrapper}. */ +// TODO(b/71818425): Generates model files dynamically. @RunWith(JUnit4.class) public final class NativeInterpreterWrapperTest { @@ -43,6 +44,9 @@ public final class NativeInterpreterWrapperTest { private static final String INVALID_MODEL_PATH = "tensorflow/contrib/lite/java/src/testdata/invalid_model.bin"; + private static final String MODEL_WITH_CUSTOM_OP_PATH = + "tensorflow/contrib/lite/java/src/testdata/with_custom_op.lite"; + @Test public void testConstructor() { NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(FLOAT_MODEL_PATH); @@ -55,10 +59,20 @@ public final class NativeInterpreterWrapperTest { try { NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(INVALID_MODEL_PATH); fail(); + } catch (IllegalArgumentException e) { + assertThat(e).hasMessageThat().contains("is not a valid flatbuffer model"); + } + } + + @Test + public void testConstructorWithUnresolableCustomOp() { + try { + NativeInterpreterWrapper wrapper = new NativeInterpreterWrapper(MODEL_WITH_CUSTOM_OP_PATH); + fail(); } catch (IllegalArgumentException e) { assertThat(e) .hasMessageThat() - .contains("Model provided has model identifier ' is ', should be 'TFL3'"); + .contains("Cannot create interpreter: Didn't find custom op for name 'Assign'"); } } diff --git a/tensorflow/contrib/lite/java/src/testdata/with_custom_op.lite b/tensorflow/contrib/lite/java/src/testdata/with_custom_op.lite new file mode 100644 index 
0000000000000000000000000000000000000000..e775d56d88854ecdf70233262ff5884d224f4373 Binary files /dev/null and b/tensorflow/contrib/lite/java/src/testdata/with_custom_op.lite differ diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD index bbbfa3e7415bfd7a34dfc7d764da55cac22e7d42..5d553def0a213da2350cdedc159de43b4d8cff04 100644 --- a/tensorflow/contrib/lite/kernels/BUILD +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -32,6 +32,7 @@ cc_library( "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:schema_fbs_version", "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/testing:util", "//tensorflow/core:lib", "@com_google_googletest//:gtest", ], @@ -49,7 +50,7 @@ cc_library( deps = [ ":op_macros", "//tensorflow/contrib/lite:context", - "@gemmlowp//:gemmlowp", + "@gemmlowp", ], ) @@ -70,35 +71,73 @@ cc_library( ], ) +cc_library( + name = "kernel_util", + srcs = [ + "kernel_util.cc", + ], + hdrs = [ + "kernel_util.h", + ], + deps = [ + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:context", + "//tensorflow/contrib/lite/kernels/internal:round", + ], +) + +tf_cc_test( + name = "kernel_util_test", + size = "small", + srcs = ["kernel_util_test.cc"], + deps = [ + ":kernel_util", + "//tensorflow/contrib/lite/testing:util", + "@com_google_googletest//:gtest", + ], +) + cc_library( name = "builtin_ops", srcs = [ "activations.cc", "add.cc", "basic_rnn.cc", + "batch_to_space_nd.cc", + "bidirectional_sequence_rnn.cc", "concatenation.cc", "conv.cc", "depthwise_conv.cc", + "div.cc", "embedding_lookup.cc", "embedding_lookup_sparse.cc", + "exp.cc", "fully_connected.cc", + "gather.cc", "hashtable_lookup.cc", - "kernel_util.cc", "l2norm.cc", "local_response_norm.cc", "lsh_projection.cc", "lstm.cc", + "mean.cc", "mul.cc", + "pad.cc", "pooling.cc", "register.cc", "reshape.cc", "resize_bilinear.cc", "skip_gram.cc", + "space_to_batch_nd.cc", "space_to_depth.cc", + "squeeze.cc", + 
"strided_slice.cc", + "sub.cc", "svdf.cc", + "transpose.cc", + "unidirectional_sequence_lstm.cc", + "unidirectional_sequence_rnn.cc", ], hdrs = [ - "kernel_util.h", "padding.h", "register.h", ], @@ -112,11 +151,13 @@ cc_library( }), deps = [ ":activation_functor", + ":kernel_util", ":op_macros", "//tensorflow/contrib/lite:builtin_op_data", "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite:string_util", "//tensorflow/contrib/lite/kernels:gemm_support", + "//tensorflow/contrib/lite/kernels/internal:kernel_utils", "//tensorflow/contrib/lite/kernels/internal:optimized", "//tensorflow/contrib/lite/kernels/internal:optimized_base", "//tensorflow/contrib/lite/kernels/internal:quantization_util", @@ -152,6 +193,44 @@ tf_cc_test( ], ) +tf_cc_test( + name = "transpose_test", + size = "small", + srcs = ["transpose_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "//tensorflow/contrib/lite/kernels/internal:reference", + "//tensorflow/contrib/lite/kernels/internal:reference_base", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "space_to_batch_nd_test", + size = "small", + srcs = ["space_to_batch_nd_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "batch_to_space_nd_test", + size = "small", + srcs = ["batch_to_space_nd_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + tf_cc_test( name = "concatenation_test", size = "small", @@ -172,6 +251,7 @@ tf_cc_test( ":builtin_ops", "//tensorflow/contrib/lite:framework", "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_absl//absl/memory", "@com_google_googletest//:gtest", ], ) @@ -200,6 +280,42 @@ tf_cc_test( ], ) +tf_cc_test( + name = 
"unidirectional_sequence_lstm_test", + size = "small", + srcs = ["unidirectional_sequence_lstm_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "bidirectional_sequence_rnn_test", + size = "small", + srcs = ["bidirectional_sequence_rnn_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "unidirectional_sequence_rnn_test", + size = "small", + srcs = ["unidirectional_sequence_rnn_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + tf_cc_test( name = "l2norm_test", size = "small", @@ -212,6 +328,30 @@ tf_cc_test( ], ) +tf_cc_test( + name = "exp_test", + size = "small", + srcs = ["exp_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "mean_test", + size = "small", + srcs = ["mean_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + tf_cc_test( name = "mul_test", size = "small", @@ -224,6 +364,18 @@ tf_cc_test( ], ) +tf_cc_test( + name = "pad_test", + size = "small", + srcs = ["pad_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + tf_cc_test( name = "reshape_test", size = "small", @@ -236,6 +388,19 @@ tf_cc_test( ], ) +tf_cc_test( + name = "gather_test", + size = "small", + srcs = ["gather_test.cc"], + deps = [ + ":builtin_ops", + 
"//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + tf_cc_test( name = "resize_bilinear_test", size = "small", @@ -395,6 +560,30 @@ tf_cc_test( ], ) +tf_cc_test( + name = "squeeze_test", + size = "small", + srcs = ["squeeze_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( + name = "strided_slice_test", + size = "small", + srcs = ["strided_slice_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/lite/kernels/activation_functor.h b/tensorflow/contrib/lite/kernels/activation_functor.h index cfb3369e991a474315424423fe655ba214edabbc..41ec3cca33ae1c6bb3f7c43dd1923f104c2ab6a2 100644 --- a/tensorflow/contrib/lite/kernels/activation_functor.h +++ b/tensorflow/contrib/lite/kernels/activation_functor.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_ACTIVATION_FUNCTOR_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_ACTIVATION_FUNCTOR_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_ACTIVATION_FUNCTOR_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_ACTIVATION_FUNCTOR_H_ #include #include @@ -55,4 +55,4 @@ class ActivationFunctor { } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_ACTIVATION_FUNCTOR_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_ACTIVATION_FUNCTOR_H_ diff --git a/tensorflow/contrib/lite/kernels/activations.cc b/tensorflow/contrib/lite/kernels/activations.cc index 7ab60a33e5e2ff61bae5f4c6db85ab9c47a391bc..3c5c77815d0f2592ab549152b4d77f45b967a660 100644 --- a/tensorflow/contrib/lite/kernels/activations.cc +++ b/tensorflow/contrib/lite/kernels/activations.cc @@ -15,8 +15,8 @@ limitations under the License. #include #include #include -#include #include +#include #include #include @@ -134,8 +134,7 @@ TfLiteStatus ReluEval(TfLiteContext* context, TfLiteNode* node) { float* out = output->data.f; for (; in < in_end; in++, out++) *out = std::max(0.f, *in); return kTfLiteOk; - } - break; + } break; default: context->ReportError(context, "Only float32 supported currently."); return kTfLiteError; @@ -173,8 +172,7 @@ TfLiteStatus Relu6Eval(TfLiteContext* context, TfLiteNode* node) { float* out = output->data.f; for (; in < in_end; in++, out++) *out = std::min(std::max(0.f, *in), 6.f); return kTfLiteOk; - } - break; + } break; default: context->ReportError(context, "Only float32 supported currently."); return kTfLiteError; @@ -192,8 +190,7 @@ TfLiteStatus TanhEval(TfLiteContext* context, TfLiteNode* node) { float* out = output->data.f; for (; in < in_end; in++, out++) *out = std::tanh(*in); return kTfLiteOk; - } - break; + } break; default: context->ReportError(context, "Only float32 supported currently."); return kTfLiteError; @@ -349,7 +346,7 @@ 
TfLiteRegistration* Register_RELU() { return &r; } -TfLiteRegistration* Register_RELU1() { +TfLiteRegistration* Register_RELU_N1_TO_1() { static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, activations::GenericPrepare, activations::Relu1Eval}; diff --git a/tensorflow/contrib/lite/kernels/activations_test.cc b/tensorflow/contrib/lite/kernels/activations_test.cc index f10aee70170d4a94ed54376fa410b22a60f109af..68d49944e51b043b6b82aa1589d22f6ebed37574 100644 --- a/tensorflow/contrib/lite/kernels/activations_test.cc +++ b/tensorflow/contrib/lite/kernels/activations_test.cc @@ -102,7 +102,7 @@ TEST(FloatActivationsOpTest, Relu) { } TEST(FloatActivationsOpTest, Relu1) { - FloatActivationsOpModel m(BuiltinOperator_RELU1, + FloatActivationsOpModel m(BuiltinOperator_RELU_N1_TO_1, /*input=*/{TensorType_FLOAT32, {1, 2, 4, 1}}); m.SetInput({ 0.0, -0.6, 0.2, -0.4, // @@ -317,7 +317,7 @@ TEST(QuantizedActivationsOpTest, Softmax2D) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/add.cc b/tensorflow/contrib/lite/kernels/add.cc index 0e10a249abac3ba19cf107e055aa71d1eee00122..63ea89df56bafa995950afec3a58267681af304f 100644 --- a/tensorflow/contrib/lite/kernels/add.cc +++ b/tensorflow/contrib/lite/kernels/add.cc @@ -37,7 +37,23 @@ constexpr int kInputTensor1 = 0; constexpr int kInputTensor2 = 1; constexpr int kOutputTensor = 0; +struct OpData { + bool requires_broadcast; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* data = new OpData; + data->requires_broadcast = false; + return data; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + OpData* data = reinterpret_cast(node->user_data); + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); 
TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); @@ -45,43 +61,56 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - TF_LITE_ENSURE_EQ(context, NumDimensions(input1), NumDimensions(input2)); - for (int i = 0; i < NumDimensions(input1); ++i) { - TF_LITE_ENSURE_EQ(context, SizeOfDimension(input1, i), - SizeOfDimension(input2, i)); - } + TF_LITE_ENSURE_EQ(context, input1->type, input2->type); + output->type = input2->type; - TF_LITE_ENSURE_EQ(context, input1->type, output->type); - TF_LITE_ENSURE_EQ(context, input2->type, output->type); + data->requires_broadcast = !HaveSameShapes(input1, input2); + + TfLiteIntArray* output_size = nullptr; + if (data->requires_broadcast) { + TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast( + context, input1, input2, &output_size)); + } else { + output_size = TfLiteIntArrayCopy(input1->dims); + } - TfLiteIntArray* output_size = TfLiteIntArrayCopy(input1->dims); return context->ResizeTensor(context, output, output_size); } template void EvalAddFloat(TfLiteContext* context, TfLiteNode* node, - TfLiteAddParams* params, TfLiteTensor* input1, - TfLiteTensor* input2, TfLiteTensor* output) { + TfLiteAddParams* params, const OpData* data, + TfLiteTensor* input1, TfLiteTensor* input2, + TfLiteTensor* output) { float output_activation_min, output_activation_max; CalculateActivationRangeFloat(params->activation, &output_activation_min, &output_activation_max); -#define TF_LITE_ADD(type) \ - type::Add(GetTensorData(input1), GetTensorDims(input1), \ - GetTensorData(input2), GetTensorDims(input2), \ - output_activation_min, output_activation_max, \ - GetTensorData(output), GetTensorDims(output)) - if (kernel_type == kReference) { - TF_LITE_ADD(reference_ops); +#define TF_LITE_ADD(type, opname) \ + type::opname(GetTensorData(input1), GetTensorDims(input1), \ + GetTensorData(input2), 
GetTensorDims(input2), \ + output_activation_min, output_activation_max, \ + GetTensorData(output), GetTensorDims(output)) + if (kernel_type == kReference) { + if (data->requires_broadcast) { + TF_LITE_ADD(reference_ops, BroadcastAdd); } else { - TF_LITE_ADD(optimized_ops); + TF_LITE_ADD(reference_ops, Add); + } + } else { + if (data->requires_broadcast) { + TF_LITE_ADD(optimized_ops, BroadcastAdd); + } else { + TF_LITE_ADD(optimized_ops, Add); + } } #undef TF_LITE_ADD } template void EvalAddQuantized(TfLiteContext* context, TfLiteNode* node, - TfLiteAddParams* params, TfLiteTensor* input1, - TfLiteTensor* input2, TfLiteTensor* output) { + TfLiteAddParams* params, const OpData* data, + TfLiteTensor* input1, TfLiteTensor* input2, + TfLiteTensor* output) { auto input1_offset = -input1->params.zero_point; auto input2_offset = -input2->params.zero_point; auto output_offset = output->params.zero_point; @@ -112,19 +141,20 @@ void EvalAddQuantized(TfLiteContext* context, TfLiteNode* node, CalculateActivationRangeUint8(params->activation, output, &output_activation_min, &output_activation_max); -#define TF_LITE_ADD(type) \ - type::BroadcastAdd( \ - left_shift, GetTensorData(input1), GetTensorDims(input1), \ - input1_offset, input1_multiplier, input1_shift, \ - GetTensorData(input2), GetTensorDims(input2), input2_offset, \ - input2_multiplier, input2_shift, output_offset, output_multiplier, \ - output_shift, output_activation_min, output_activation_max, \ - GetTensorData(output), GetTensorDims(output)); - +#define TF_LITE_ADD(type, opname) \ + type::opname(left_shift, GetTensorData(input1), \ + GetTensorDims(input1), input1_offset, input1_multiplier, \ + input1_shift, GetTensorData(input2), \ + GetTensorDims(input2), input2_offset, input2_multiplier, \ + input2_shift, output_offset, output_multiplier, output_shift, \ + output_activation_min, output_activation_max, \ + GetTensorData(output), GetTensorDims(output)); + // The quantized version of Add doesn't support 
activations, so we + // always use BroadcastAdd. if (kernel_type == kReference) { - TF_LITE_ADD(reference_ops); + TF_LITE_ADD(reference_ops, BroadcastAdd); } else { - TF_LITE_ADD(optimized_ops); + TF_LITE_ADD(optimized_ops, BroadcastAdd); } #undef TF_LITE_ADD } @@ -132,15 +162,17 @@ void EvalAddQuantized(TfLiteContext* context, TfLiteNode* node, template TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); if (output->type == kTfLiteFloat32) { - EvalAddFloat(context, node, params, input1, input2, output); + EvalAddFloat(context, node, params, data, input1, input2, + output); } else if (output->type == kTfLiteUInt8) { - EvalAddQuantized(context, node, params, input1, input2, + EvalAddQuantized(context, node, params, data, input1, input2, output); } else { context->ReportError(context, @@ -154,19 +186,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace add TfLiteRegistration* Register_ADD_REF() { - static TfLiteRegistration r = {nullptr, nullptr, add::Prepare, + static TfLiteRegistration r = {add::Init, add::Free, add::Prepare, add::Eval}; return &r; } TfLiteRegistration* Register_ADD_GENERIC_OPT() { - static TfLiteRegistration r = {nullptr, nullptr, add::Prepare, + static TfLiteRegistration r = {add::Init, add::Free, add::Prepare, add::Eval}; return &r; } TfLiteRegistration* Register_ADD_NEON_OPT() { - static TfLiteRegistration r = {nullptr, nullptr, add::Prepare, + static TfLiteRegistration r = {add::Init, add::Free, add::Prepare, add::Eval}; return &r; } diff --git a/tensorflow/contrib/lite/kernels/add_test.cc b/tensorflow/contrib/lite/kernels/add_test.cc index 
8e12a837c4954832ff37a6d1ab377bee9e8d5763..956d05bed5162f6ce59705d59aad77ff056dda77 100644 --- a/tensorflow/contrib/lite/kernels/add_test.cc +++ b/tensorflow/contrib/lite/kernels/add_test.cc @@ -25,10 +25,11 @@ using ::testing::ElementsAreArray; class BaseAddOpModel : public SingleOpModel { public: - BaseAddOpModel(const TensorData& input, const TensorData& output, + BaseAddOpModel(const TensorData& input1, const TensorData& input2, + const TensorData& output, ActivationFunctionType activation_type) { - input1_ = AddInput(input); - input2_ = AddInput(input); + input1_ = AddInput(input1); + input2_ = AddInput(input2); output_ = AddOutput(output); SetBuiltinOp(BuiltinOperator_ADD, BuiltinOptions_AddOptions, CreateAddOptions(builder_, activation_type).Union()); @@ -70,6 +71,7 @@ float GetTolerance(int min, int max) { TEST(FloatAddOpModel, NoActivation) { FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); @@ -77,9 +79,10 @@ TEST(FloatAddOpModel, NoActivation) { EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1.9, 0.4, 1.0, 1.3})); } -TEST(FloatAddOpModel, ActivationRELU1) { - FloatAddOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, - {TensorType_FLOAT32, {}}, ActivationFunctionType_RELU1); +TEST(FloatAddOpModel, ActivationRELU_N1_TO_1) { + FloatAddOpModel m( + {TensorType_FLOAT32, {1, 2, 2, 1}}, {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_RELU_N1_TO_1); m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); m.Invoke(); @@ -91,6 +94,7 @@ TEST(FloatAddOpModel, VariousInputShapes) { {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; for (int i = 0; i < test_shapes.size(); ++i) { FloatAddOpModel m({TensorType_FLOAT32, test_shapes[i]}, + {TensorType_FLOAT32, test_shapes[i]}, {TensorType_FLOAT32, {}}, 
ActivationFunctionType_NONE); m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5, 1.1, 0.1}); @@ -101,6 +105,23 @@ TEST(FloatAddOpModel, VariousInputShapes) { } } +TEST(FloatAddOpModel, WithBroadcast) { + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + FloatAddOpModel m({TensorType_FLOAT32, test_shapes[i]}, + {TensorType_FLOAT32, {}}, // always a scalar + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); + m.PopulateTensor(m.input2(), {0.1}); + m.Invoke(); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-1.9, 0.3, 0.8, 0.9, 1.2, 2.1}))) + << "With shape number " << i; + } +} + TEST(QuantizedAddOpModel, QuantizedTestsNoActivation) { float kQuantizedTolerance = GetTolerance(-1.0, 1.0); std::vector> inputs1 = { @@ -111,6 +132,7 @@ TEST(QuantizedAddOpModel, QuantizedTestsNoActivation) { {0.7, 0.6, 0.6, 0.5}, {-0.2, 0.6, 0.9, -0.1}, {-0.2, 0.6, -0.1, 0.8}}; for (int i = 0; i < inputs1.size(); ++i) { QuantizedAddOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, + {TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, {TensorType_UINT8, {}, -1.0, 1.0}, ActivationFunctionType_NONE); m.QuantizeAndPopulate(m.input1(), inputs1[i]); @@ -122,7 +144,7 @@ TEST(QuantizedAddOpModel, QuantizedTestsNoActivation) { } } -TEST(QuantizedAddOpModel, QuantizedTestsActivationRELU1) { +TEST(QuantizedAddOpModel, QuantizedTestsActivationRELU_N1_TO_1) { float kQuantizedTolerance = GetTolerance(-1.0, 1.0); std::vector> inputs1 = {{-0.8, 0.2, 0.9, 0.7}, {-0.8, 0.2, 0.7, 0.3}}; @@ -132,8 +154,9 @@ TEST(QuantizedAddOpModel, QuantizedTestsActivationRELU1) { {-0.2, 0.6, -0.1, 0.8}}; for (int i = 0; i < inputs1.size(); ++i) { QuantizedAddOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, + {TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, {TensorType_UINT8, {}, -1.0, 1.0}, - 
ActivationFunctionType_RELU1); + ActivationFunctionType_RELU_N1_TO_1); m.QuantizeAndPopulate(m.input1(), inputs1[i]); m.QuantizeAndPopulate(m.input2(), inputs2[i]); m.Invoke(); @@ -149,6 +172,7 @@ TEST(QuantizedAddOpModel, QuantizedVariousInputShapes) { {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; for (int i = 0; i < test_shapes.size(); ++i) { QuantizedAddOpModel m({TensorType_UINT8, test_shapes[i], -3.0, 3.0}, + {TensorType_UINT8, test_shapes[i], -3.0, 3.0}, {TensorType_UINT8, {}, -3.0, 3.0}, ActivationFunctionType_NONE); m.QuantizeAndPopulate(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); @@ -161,11 +185,29 @@ TEST(QuantizedAddOpModel, QuantizedVariousInputShapes) { } } +TEST(QuantizedAddOpModel, QuantizedWithBroadcast) { + float kQuantizedTolerance = GetTolerance(-3.0, 3.0); + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + QuantizedAddOpModel m({TensorType_UINT8, test_shapes[i], -3.0, 3.0}, + {TensorType_UINT8, {}, -3.0, 3.0}, + {TensorType_UINT8, {}, -3.0, 3.0}, + ActivationFunctionType_NONE); + m.QuantizeAndPopulate(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); + m.QuantizeAndPopulate(m.input2(), {0.1}); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear({-1.9, 0.3, 0.8, 0.9, 1.2, 2.1}, + kQuantizedTolerance))) + << "With shape number " << i; + } +} + } // namespace } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); - tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/basic_rnn.cc b/tensorflow/contrib/lite/kernels/basic_rnn.cc index 3cee43c68b2a0af5a3fd84b33a980b74bb8f0cb4..2c5074eca3176c7f33a6f051b492dc41333257ed 100644 --- a/tensorflow/contrib/lite/kernels/basic_rnn.cc +++ b/tensorflow/contrib/lite/kernels/basic_rnn.cc @@ -15,14 +15,15 @@ limitations under the License. 
#include #include #include -#include #include +#include #include #include #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" #include "tensorflow/contrib/lite/kernels/op_macros.h" namespace tflite { @@ -76,8 +77,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(2); output_size_array->data[0] = batch_size; output_size_array->data[1] = num_units; - TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, output, - output_size_array)); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, output, output_size_array)); return kTfLiteOk; } @@ -101,50 +102,20 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const int batch_size = input->dims->data[0]; const int num_units = input_weights->dims->data[0]; const int input_size = input->dims->data[1]; - const int input_weights_stride = input_weights->dims->data[1]; - const int recurrent_weights_stride = recurrent_weights->dims->data[1]; - - // For each batch - for (int b = 0; b < batch_size; b++) { - // Initialize the pointer to input, output and bias. - const float* input_ptr_batch = input->data.f + b * input_size; - float* output_ptr_batch = output->data.f + b * num_units; - float* hidden_state_ptr_batch = hidden_state->data.f + b * num_units; - - // Initialize input_weights and recurrent_weights. 
- const float* input_weights_ptr = input_weights->data.f; - const float* recurrent_weights_ptr = recurrent_weights->data.f; - - // Output = bias - for (int o = 0; o < num_units; o++) { - output_ptr_batch[o] = bias_ptr[o]; - } - - // Output += input * input_weights - for (int o = 0; o < num_units; o++) { - for (int i = 0; i < input_size; i++) { - output_ptr_batch[o] += input_ptr_batch[i] * input_weights_ptr[i]; - } - input_weights_ptr += input_weights_stride; - } - - // Output += recurrent_weights * hidden_state - for (int o = 0; o < num_units; o++) { - for (int h = 0; h < num_units; h++) { - output_ptr_batch[o] += - hidden_state_ptr_batch[h] * recurrent_weights_ptr[h]; - } - recurrent_weights_ptr += recurrent_weights_stride; - } - - // Output = activation(Output) and update hidden_state - for (int o = 0; o < num_units; o++) { - output_ptr_batch[o] = - (ActivationFunctor(params->activation))(output_ptr_batch[o]); - hidden_state_ptr_batch[o] = output_ptr_batch[o]; - } - } + // Initialize the pointer to hidden state. + float* hidden_state_ptr_batch = hidden_state->data.f; + // Initialize the pointer to input and output. + const float* input_ptr_batch = input->data.f; + float* output_ptr_batch = output->data.f; + // Initialize input_weights and recurrent_weights. + const float* input_weights_ptr = input_weights->data.f; + const float* recurrent_weights_ptr = recurrent_weights->data.f; + + kernel_utils::RnnBatchStep(input_ptr_batch, input_weights_ptr, + recurrent_weights_ptr, bias_ptr, input_size, + num_units, batch_size, params->activation, + hidden_state_ptr_batch, output_ptr_batch); return kTfLiteOk; } diff --git a/tensorflow/contrib/lite/kernels/basic_rnn_test.cc b/tensorflow/contrib/lite/kernels/basic_rnn_test.cc index dfa75655bcfe7762c6cc4c9a98a71d529028c03a..fa7ef525db47c93f98951604cd04da66196422d7 100644 --- a/tensorflow/contrib/lite/kernels/basic_rnn_test.cc +++ b/tensorflow/contrib/lite/kernels/basic_rnn_test.cc @@ -14,8 +14,8 @@ limitations under the License. 
==============================================================================*/ // Unit test for TFLite RNN op. -#include #include +#include #include #include @@ -120,8 +120,7 @@ static float rnn_golden_output[] = { 0.415153, 0.210318, 0, 0, 0, 0, 0, 2.02616, 0, 0.728256, 0.84183, 0.0907453, - 0.628881, 3.58099, 1.49974, 0 -}; + 0.628881, 3.58099, 1.49974, 0}; class RNNOpModel : public SingleOpModel { public: @@ -261,7 +260,7 @@ TEST(FullyConnectedOpTest, BlackBoxTest) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc new file mode 100644 index 0000000000000000000000000000000000000000..bc438f99c6a72fdbc2794dee03524db6a7523834 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc @@ -0,0 +1,188 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace batch_to_space_nd { + +// This file has two implementations of BatchToSpaceND. +enum KernelType { + kReference, + kGenericOptimized, +}; + +struct BatchToSpaceNDContext { + BatchToSpaceNDContext(TfLiteContext* context, TfLiteNode* node) { + input = GetInput(context, node, 0); + block_shape = GetInput(context, node, 1); + crops = GetInput(context, node, 2); + output = GetOutput(context, node, 0); + } + TfLiteTensor* input; + TfLiteTensor* block_shape; + TfLiteTensor* crops; + TfLiteTensor* output; +}; + +// Currently, only 4D NHWC input/output op_context are supported. +// The 4D array need to have exactly 2 spatial dimensions. +// TODO(ycling): Support arbitrary dimension in BatchToSpaceND. 
+const int kInputDimensionNum = 4; +const int kBlockSizeDimensionNum = 1; +const int kSpatialDimensionNum = 2; + +TfLiteStatus ResizeOutputTensor(TfLiteContext* context, + BatchToSpaceNDContext* op_context) { + TfLiteIntArray* input_size = op_context->input->dims; + const int* block_shape = GetTensorData(op_context->block_shape); + const int* crops = GetTensorData(op_context->crops); + + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->block_shape), + kBlockSizeDimensionNum); + TF_LITE_ENSURE_EQ(context, op_context->block_shape->dims->data[0], + kSpatialDimensionNum); + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->crops), + kSpatialDimensionNum); + + // TODO(ycling): Add crops as part of the calculation. Remove the check that + // crops contains only zeroes. + TF_LITE_ENSURE_EQ(context, crops[0], 0); + TF_LITE_ENSURE_EQ(context, crops[1], 0); + TF_LITE_ENSURE_EQ(context, crops[2], 0); + TF_LITE_ENSURE_EQ(context, crops[3], 0); + + // The batch count must be a multiple of (block_shape[0] * block_shape[1]). 
+ TF_LITE_ENSURE_EQ(context, + input_size->data[0] % (block_shape[0] * block_shape[1]), 0); + + const int output_batch_size = + input_size->data[0] / (block_shape[0] * block_shape[1]); + const int output_height = input_size->data[1] * block_shape[0]; + const int output_width = input_size->data[2] * block_shape[1]; + const int output_channel_size = input_size->data[3]; + + TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size); + output_size->data[0] = output_batch_size; + output_size->data[1] = output_height; + output_size->data[2] = output_width; + output_size->data[3] = output_channel_size; + + return context->ResizeTensor(context, op_context->output, output_size); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + BatchToSpaceNDContext op_context(context, node); + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.input), + kInputDimensionNum); + TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); + + if (!IsConstantTensor(op_context.block_shape) || + !IsConstantTensor(op_context.crops)) { + SetTensorToDynamic(op_context.output); + return kTfLiteOk; + } + return ResizeOutputTensor(context, &op_context); +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + BatchToSpaceNDContext op_context(context, node); + + // Resize the output tensor if the output tensor is dynamic. + if (IsDynamicTensor(op_context.output)) { + TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); + } + +#define TF_LITE_BATCH_TO_SPACE_ND(type, scalar) \ + type::BatchToSpaceND(GetTensorData(op_context.input), \ + GetTensorDims(op_context.input), \ + GetTensorData(op_context.block_shape), \ + GetTensorDims(op_context.block_shape), \ + GetTensorData(op_context.output), \ + GetTensorDims(op_context.output)) + switch (op_context.input->type) { // Already know in/out types are same. 
+ case kTfLiteFloat32: + if (kernel_type == kReference) { + TF_LITE_BATCH_TO_SPACE_ND(reference_ops, float); + } else { + TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, float); + } + break; + case kTfLiteUInt8: + if (kernel_type == kReference) { + TF_LITE_BATCH_TO_SPACE_ND(reference_ops, uint8_t); + } else { + TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, uint8_t); + } + break; + case kTfLiteInt32: + if (kernel_type == kReference) { + TF_LITE_BATCH_TO_SPACE_ND(reference_ops, int32_t); + } else { + TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, int32_t); + } + break; + case kTfLiteInt64: + if (kernel_type == kReference) { + TF_LITE_BATCH_TO_SPACE_ND(reference_ops, int64_t); + } else { + TF_LITE_BATCH_TO_SPACE_ND(optimized_ops, int64_t); + } + break; + default: + context->ReportError(context, + "Type is currently not supported by BatchToSpace."); + return kTfLiteError; + } +#undef TF_LITE_BATCH_TO_SPACE_ND + return kTfLiteOk; +} + +} // namespace batch_to_space_nd + +TfLiteRegistration* Register_BATCH_TO_SPACE_ND_REF() { + static TfLiteRegistration r = { + nullptr, nullptr, batch_to_space_nd::Prepare, + batch_to_space_nd::Eval}; + return &r; +} + +TfLiteRegistration* Register_BATCH_TO_SPACE_ND_GENERIC_OPT() { + static TfLiteRegistration r = { + nullptr, nullptr, batch_to_space_nd::Prepare, + batch_to_space_nd::Eval}; + return &r; +} + +TfLiteRegistration* Register_BATCH_TO_SPACE_ND() { + return Register_BATCH_TO_SPACE_ND_GENERIC_OPT(); +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc b/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..8485cde1b40066f2070855bca91ea78a9f80e83c --- /dev/null +++ b/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc @@ -0,0 +1,142 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class BatchToSpaceNDOpModel : public SingleOpModel { + public: + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetBlockShape(std::initializer_list data) { + PopulateTensor(block_shape_, data); + } + + void SetCrops(std::initializer_list data) { + PopulateTensor(crops_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + protected: + int input_; + int block_shape_; + int crops_; + int output_; +}; + +// Tests case where block_shape and crops are const tensors. 
+// +// Example usage is as follows: +// BatchToSpaceNDOpConstModel m(input_shape, block_shape, crops); +// m.SetInput(input_data); +// m.Invoke(); +class BatchToSpaceNDOpConstModel : public BatchToSpaceNDOpModel { + public: + BatchToSpaceNDOpConstModel(std::initializer_list input_shape, + std::initializer_list block_shape, + std::initializer_list crops) { + input_ = AddInput(TensorType_FLOAT32); + block_shape_ = AddConstInput(TensorType_INT32, block_shape, {2}); + crops_ = AddConstInput(TensorType_INT32, crops, {2, 2}); + output_ = AddOutput(TensorType_FLOAT32); + + SetBuiltinOp(BuiltinOperator_BATCH_TO_SPACE_ND, + BuiltinOptions_BatchToSpaceNDOptions, + CreateBatchToSpaceNDOptions(builder_).Union()); + BuildInterpreter({input_shape}); + } +}; + +// Tests case where block_shape and crops are non-const tensors. +// +// Example usage is as follows: +// BatchToSpaceNDOpDynamicModel m(input_shape); +// m.SetInput(input_data); +// m.SetBlockShape(block_shape); +// m.SetCrops(crops); +// m.Invoke(); +class BatchToSpaceNDOpDynamicModel : public BatchToSpaceNDOpModel { + public: + BatchToSpaceNDOpDynamicModel(std::initializer_list input_shape) { + input_ = AddInput(TensorType_FLOAT32); + block_shape_ = AddInput(TensorType_INT32); + crops_ = AddInput(TensorType_INT32); + output_ = AddOutput(TensorType_FLOAT32); + + SetBuiltinOp(BuiltinOperator_BATCH_TO_SPACE_ND, + BuiltinOptions_BatchToSpaceNDOptions, + CreateBatchToSpaceNDOptions(builder_).Union()); + BuildInterpreter({input_shape, {2}, {2, 2}}); + } +}; + +TEST(BatchToSpaceNDOpTest, SimpleConstTest) { + BatchToSpaceNDOpConstModel m({4, 2, 2, 1}, {2, 2}, {0, 0, 0, 0}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 5, 2, 6, 9, 13, 10, 14, 3, 7, + 4, 8, 11, 15, 12, 16})); +} + +TEST(BatchToSpaceNDOpTest, SimpleDynamicTest) { + BatchToSpaceNDOpDynamicModel m({4, 2, 2, 
1}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + m.SetBlockShape({2, 2}); + m.SetCrops({0, 0, 0, 0}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 5, 2, 6, 9, 13, 10, 14, 3, 7, + 4, 8, 11, 15, 12, 16})); +} + +TEST(BatchToSpaceNDOpTest, InvalidShapeTest) { + EXPECT_DEATH(BatchToSpaceNDOpConstModel({3, 2, 2, 1}, {2, 2}, {0, 0, 0, 0}), + "Cannot allocate tensors"); +} + +TEST(BatchToSpaceNDOpTest, InvalidCropsConstTest) { + EXPECT_DEATH(BatchToSpaceNDOpConstModel({3, 2, 2, 1}, {2, 2}, {0, 0, 0, 1}), + "1 != 0"); +} + +TEST(BatchToSpaceNDOpTest, InvalidCropsDynamicTest) { + BatchToSpaceNDOpDynamicModel m({4, 2, 2, 1}); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); + m.SetBlockShape({2, 2}); + m.SetCrops({0, 0, 1, 0}); + EXPECT_DEATH(m.Invoke(), "1 != 0"); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc new file mode 100644 index 0000000000000000000000000000000000000000..aa24c1f34cd1e8c02a6a75b62fbe5f3c629498ca --- /dev/null +++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc @@ -0,0 +1,205 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace bidirectional_sequence_rnn { + +constexpr int kInputTensor = 0; +// Forward and backward cell tensors. +constexpr int kFwWeightsTensor = 1; +constexpr int kFwRecurrentWeightsTensor = 2; +constexpr int kFwBiasTensor = 3; +constexpr int kBwWeightsTensor = 4; +constexpr int kBwRecurrentWeightsTensor = 5; +constexpr int kBwBiasTensor = 6; +// State and output tensors. +constexpr int kFwHiddenStateTensor = 0; +constexpr int kFwOutputTensor = 1; +constexpr int kBwHiddenStateTensor = 2; +constexpr int kBwOutputTensor = 3; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + // Check we have all the inputs and outputs we need. 
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 7); + TF_LITE_ENSURE_EQ(context, node->outputs->size, 4); + + TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; + TfLiteTensor* fw_input_weights = + &context->tensors[node->inputs->data[kFwWeightsTensor]]; + TfLiteTensor* fw_recurrent_weights = + &context->tensors[node->inputs->data[kFwRecurrentWeightsTensor]]; + TfLiteTensor* fw_bias = &context->tensors[node->inputs->data[kFwBiasTensor]]; + TfLiteTensor* bw_input_weights = + &context->tensors[node->inputs->data[kBwWeightsTensor]]; + TfLiteTensor* bw_recurrent_weights = + &context->tensors[node->inputs->data[kBwRecurrentWeightsTensor]]; + TfLiteTensor* bw_bias = &context->tensors[node->inputs->data[kBwBiasTensor]]; + + // Check all the parameters of tensor match within themselves and match the + // input configuration. + const int batch_size = input->dims->data[0]; + const int max_time = input->dims->data[1]; + const int fw_num_units = fw_input_weights->dims->data[0]; + const int bw_num_units = bw_input_weights->dims->data[0]; + TF_LITE_ASSERT_EQ(input->dims->data[2], fw_input_weights->dims->data[1]); + TF_LITE_ASSERT_EQ(input->dims->data[2], bw_input_weights->dims->data[1]); + TF_LITE_ASSERT_EQ(fw_input_weights->dims->data[0], fw_bias->dims->data[0]); + TF_LITE_ASSERT_EQ(bw_input_weights->dims->data[0], bw_bias->dims->data[0]); + TF_LITE_ASSERT_EQ(fw_recurrent_weights->dims->data[0], + fw_bias->dims->data[0]); + TF_LITE_ASSERT_EQ(bw_recurrent_weights->dims->data[1], + bw_bias->dims->data[0]); + + TfLiteTensor* fw_output = + &context->tensors[node->outputs->data[kFwOutputTensor]]; + TfLiteTensor* bw_output = + &context->tensors[node->outputs->data[kBwOutputTensor]]; + + // Resize hidden states. 
+ TfLiteIntArray* fw_hidden_state_size_array = TfLiteIntArrayCreate(2); + fw_hidden_state_size_array->data[0] = batch_size; + fw_hidden_state_size_array->data[1] = fw_num_units; + TfLiteTensor* fw_hidden_state = + &context->tensors[node->outputs->data[kFwHiddenStateTensor]]; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_hidden_state, + fw_hidden_state_size_array)); + + TfLiteIntArray* bw_hidden_state_size_array = TfLiteIntArrayCreate(2); + bw_hidden_state_size_array->data[0] = batch_size; + bw_hidden_state_size_array->data[1] = bw_num_units; + TfLiteTensor* bw_hidden_state = + &context->tensors[node->outputs->data[kBwHiddenStateTensor]]; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_hidden_state, + bw_hidden_state_size_array)); + + // Mark hidden states as a persistent tensor. + fw_hidden_state->allocation_type = kTfLiteArenaRwPersistent; + bw_hidden_state->allocation_type = kTfLiteArenaRwPersistent; + + // Resize outputs. + TfLiteIntArray* fw_output_size_array = TfLiteIntArrayCreate(3); + fw_output_size_array->data[0] = batch_size; + fw_output_size_array->data[1] = max_time; + fw_output_size_array->data[2] = fw_num_units; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, fw_output, fw_output_size_array)); + TfLiteIntArray* bw_output_size_array = TfLiteIntArrayCreate(3); + bw_output_size_array->data[0] = batch_size; + bw_output_size_array->data[1] = max_time; + bw_output_size_array->data[2] = bw_num_units; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, bw_output, bw_output_size_array)); + + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; + TfLiteTensor* fw_input_weights = + &context->tensors[node->inputs->data[kFwWeightsTensor]]; + TfLiteTensor* fw_recurrent_weights = + &context->tensors[node->inputs->data[kFwRecurrentWeightsTensor]]; 
+ TfLiteTensor* fw_bias = &context->tensors[node->inputs->data[kFwBiasTensor]]; + TfLiteTensor* fw_hidden_state = + &context->tensors[node->outputs->data[kFwHiddenStateTensor]]; + TfLiteTensor* fw_output = + &context->tensors[node->outputs->data[kFwOutputTensor]]; + + TfLiteTensor* bw_input_weights = + &context->tensors[node->inputs->data[kBwWeightsTensor]]; + TfLiteTensor* bw_recurrent_weights = + &context->tensors[node->inputs->data[kBwRecurrentWeightsTensor]]; + TfLiteTensor* bw_bias = &context->tensors[node->inputs->data[kBwBiasTensor]]; + TfLiteTensor* bw_hidden_state = + &context->tensors[node->outputs->data[kBwHiddenStateTensor]]; + TfLiteTensor* bw_output = + &context->tensors[node->outputs->data[kBwOutputTensor]]; + + const int batch_size = input->dims->data[0]; + const int max_time = input->dims->data[1]; + const int input_size = input->dims->data[2]; + + const int fw_num_units = fw_input_weights->dims->data[0]; + const float* fw_bias_ptr = fw_bias->data.f; + const float* fw_input_weights_ptr = fw_input_weights->data.f; + const float* fw_recurrent_weights_ptr = fw_recurrent_weights->data.f; + + const int bw_num_units = bw_input_weights->dims->data[0]; + const float* bw_bias_ptr = bw_bias->data.f; + const float* bw_input_weights_ptr = bw_input_weights->data.f; + const float* bw_recurrent_weights_ptr = bw_recurrent_weights->data.f; + + for (int b = 0; b < batch_size; b++) { + // Forward cell. + float* fw_hidden_state_ptr_batch = + fw_hidden_state->data.f + b * fw_num_units; + for (int s = 0; s < max_time; s++) { + const float* input_ptr_batch = + input->data.f + b * input_size * max_time + s * input_size; + float* output_ptr_batch = + fw_output->data.f + b * fw_num_units * max_time + s * fw_num_units; + + kernel_utils::RnnBatchStep( + input_ptr_batch, fw_input_weights_ptr, fw_recurrent_weights_ptr, + fw_bias_ptr, input_size, fw_num_units, /*batch_size=*/1, + params->activation, fw_hidden_state_ptr_batch, output_ptr_batch); + } + // Backward cell. 
+ float* bw_hidden_state_ptr_batch = + bw_hidden_state->data.f + b * bw_num_units; + for (int s = max_time - 1; s >= 0; s--) { + const float* input_ptr_batch = + input->data.f + b * input_size * max_time + s * input_size; + float* output_ptr_batch = + bw_output->data.f + b * bw_num_units * max_time + s * bw_num_units; + + kernel_utils::RnnBatchStep( + input_ptr_batch, bw_input_weights_ptr, bw_recurrent_weights_ptr, + bw_bias_ptr, input_size, bw_num_units, /*batch_size=*/1, + params->activation, bw_hidden_state_ptr_batch, output_ptr_batch); + } + } + return kTfLiteOk; +} + +} // namespace bidirectional_sequence_rnn + +TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_RNN() { + static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, + bidirectional_sequence_rnn::Prepare, + bidirectional_sequence_rnn::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..12f4ff97cfd90e3a6894a24d15fcbc356f96cde2 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc @@ -0,0 +1,931 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +// Unit test for TFLite Bidirectional RNN op. + +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +static float rnn_input[] = { + 0.23689353, 0.285385, 0.037029743, -0.19858193, -0.27569133, + 0.43773448, 0.60379338, 0.35562468, -0.69424844, -0.93421471, + -0.87287879, 0.37144363, -0.62476718, 0.23791671, 0.40060222, + 0.1356622, -0.99774903, -0.98858172, -0.38952237, -0.47685933, + 0.31073618, 0.71511042, -0.63767755, -0.31729108, 0.33468103, + 0.75801885, 0.30660987, -0.37354088, 0.77002847, -0.62747043, + -0.68572164, 0.0069220066, 0.65791464, 0.35130811, 0.80834007, + -0.61777675, -0.21095741, 0.41213346, 0.73784804, 0.094794154, + 0.47791874, 0.86496925, -0.53376222, 0.85315156, 0.10288584, + 0.86684, -0.011186242, 0.10513687, 0.87825835, 0.59929144, + 0.62827742, 0.18899453, 0.31440187, 0.99059987, 0.87170351, + -0.35091716, 0.74861872, 0.17831337, 0.2755419, 0.51864719, + 0.55084288, 0.58982027, -0.47443086, 0.20875752, -0.058871567, + -0.66609079, 0.59098077, 0.73017097, 0.74604273, 0.32882881, + -0.17503482, 0.22396147, 0.19379807, 0.29120302, 0.077113032, + -0.70331609, 0.15804303, -0.93407321, 0.40182066, 0.036301374, + 0.66521823, 0.0300982, -0.7747041, -0.02038002, 0.020698071, + -0.90300065, 0.62870288, -0.23068321, 0.27531278, -0.095755219, + -0.712036, -0.17384434, -0.50593495, -0.18646687, -0.96508682, + 0.43519354, 0.14744234, 0.62589407, 0.1653645, -0.10651493, + -0.045277178, 0.99032974, -0.88255352, -0.85147917, 0.28153265, + 0.19455957, -0.55479527, -0.56042433, 0.26048636, 0.84702539, + 0.47587705, -0.074295521, -0.12287641, 0.70117295, 0.90532446, + 0.89782166, 0.79817224, 0.53402734, 
-0.33286154, 0.073485017, + -0.56172788, -0.044897556, 0.89964068, -0.067662835, 0.76863563, + 0.93455386, -0.6324693, -0.083922029}; + +static float rnn_golden_fw_output[] = { + 0.496726, 0, 0.965996, 0, 0.0584254, 0, + 0, 0.12315, 0, 0, 0.612266, 0.456601, + 0, 0.52286, 1.16099, 0.0291232, + + 0, 0, 0.524901, 0, 0, 0, + 0, 1.02116, 0, 1.35762, 0, 0.356909, + 0.436415, 0.0355727, 0, 0, + + 0, 0, 0, 0.262335, 0, 0, + 0, 1.33992, 0, 2.9739, 0, 0, + 1.31914, 2.66147, 0, 0, + + 0.942568, 0, 0, 0, 0.025507, 0, + 0, 0, 0.321429, 0.569141, 1.25274, 1.57719, + 0.8158, 1.21805, 0.586239, 0.25427, + + 1.04436, 0, 0.630725, 0, 0.133801, 0.210693, + 0.363026, 0, 0.533426, 0, 1.25926, 0.722707, + 0, 1.22031, 1.30117, 0.495867, + + 0.222187, 0, 0.72725, 0, 0.767003, 0, + 0, 0.147835, 0, 0, 0, 0.608758, + 0.469394, 0.00720298, 0.927537, 0, + + 0.856974, 0.424257, 0, 0, 0.937329, 0, + 0, 0, 0.476425, 0, 0.566017, 0.418462, + 0.141911, 0.996214, 1.13063, 0, + + 0.967899, 0, 0, 0, 0.0831304, 0, + 0, 1.00378, 0, 0, 0, 1.44818, + 1.01768, 0.943891, 0.502745, 0, + + 0.940135, 0, 0, 0, 0, 0, + 0, 2.13243, 0, 0.71208, 0.123918, 1.53907, + 1.30225, 1.59644, 0.70222, 0, + + 0.804329, 0, 0.430576, 0, 0.505872, 0.509603, + 0.343448, 0, 0.107756, 0.614544, 1.44549, 1.52311, + 0.0454298, 0.300267, 0.562784, 0.395095, + + 0.228154, 0, 0.675323, 0, 1.70536, 0.766217, + 0, 0, 0, 0.735363, 0.0759267, 1.91017, + 0.941888, 0, 0, 0, + + 0, 0, 1.5909, 0, 0, 0, + 0, 0.5755, 0, 0.184687, 0, 1.56296, + 0.625285, 0, 0, 0, + + 0, 0, 0.0857888, 0, 0, 0, + 0, 0.488383, 0.252786, 0, 0, 0, + 1.02817, 1.85665, 0, 0, + + 0.00981836, 0, 1.06371, 0, 0, 0, + 0, 0, 0, 0.290445, 0.316406, 0, + 0.304161, 1.25079, 0.0707152, 0, + + 0.986264, 0.309201, 0, 0, 0, 0, + 0, 1.64896, 0.346248, 0, 0.918175, 0.78884, + 0.524981, 1.92076, 2.07013, 0.333244, + + 0.415153, 0.210318, 0, 0, 0, 0, + 0, 2.02616, 0, 0.728256, 0.84183, 0.0907453, + 0.628881, 3.58099, 1.49974, 0}; + +static float rnn_golden_bw_output[] = { + 0.496726, 
0, 1.00883, 0, 0.0584256, 0, 0, + 0.236412, 0, 0, 0.612267, 0.487726, 0, 0.54883, + 1.16099, 0.0291233, 0, 0, 0.428302, 0, 0, + 0, 0, 1.13262, 0, 1.64415, 0, 0.311249, + 0.570804, 0.259696, 0, 0, 0, 0, 0, + 0.262334, 0, 0, 0, 1.23781, 0, 2.86532, + 0, 0, 1.34389, 2.76409, 0, 0, 1.03969, + 0, 0.00410865, 0, 0.0470295, 0, 0, 0, + 0.371556, 0.27175, 1.36614, 1.63956, 0.683887, 1.06176, 0.719552, + 0.301314, 0.971195, 0, 0.697143, 0, 0.215219, 0.210693, + 0.363027, 0, 0.501283, 0, 1.13399, 0.623774, 0, + 1.09851, 1.33313, 0.470441, 0.210965, 0, 0.664178, 0, + 0.839686, 0, 0, 0.147834, 0, 0, 0, + 0.58786, 0.490128, 0, 0.905806, 0, 0.932134, 0.424257, + 0, 0, 0.860629, 0, 0, 0, 0.476425, + 0, 0.566017, 0.513721, 0.207341, 1.09508, 1.08385, 0, + 0.973787, 0, 0, 0, 0, 0, 0, + 1.20698, 0, 0, 0, 1.56135, 1.12369, 0.99588, + 0.459803, 0, 0.915854, 0, 0, 0, 0, + 0, 0, 2.03206, 0, 0.773264, 0.267228, 1.55012, + 1.202, 1.51611, 0.701202, 0, 0.725088, 0, 0.509069, + 0, 0.671349, 0.581129, 0.343447, 0, 0.107755, 0.611838, + 1.4331, 1.55871, 0.015242, 0.140624, 0.492562, 0.395095, 0.147722, + 0, 0.784925, 0, 1.65477, 0.715257, 0, 0, + 0, 0.685024, 0, 1.89505, 1.00037, 0, 0, + 0, 0, 0, 1.52659, 0, 0, 0, + 0, 0.618583, 0, 0.11115, 0, 1.37194, 0.630225, + 0, 0, 0, 0, 0, 0.0322124, 0, + 0, 0, 0, 0.430834, 0.252786, 0, 0, + 0, 0.991297, 1.98451, 0, 0, 0.111511, 0, + 1.05513, 0, 0, 0, 0, 0, 0, + 0.290445, 0.412559, 0.0429958, 0.256564, 1.27858, 0.289948, 0, + 1.01693, 0.327141, 0, 0, 0, 0, 0, + 1.83508, 0.346248, 0, 0.961535, 0.790026, 0.552203, 2.13457, + 2.19233, 0.333244, 0.316526, 0.179398, 0, 0, 0, + 0, 0, 1.86126, 0, 0.728256, 0.750013, 0.011861, + 0.576383, 3.38891, 1.29273, 0}; + +constexpr std::initializer_list weights = { + 0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346, + 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399, + 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113, + -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512, + 
-0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188, + -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158, + -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241, + 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183, + 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303, + 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884, + -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726, + 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644, + -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461, + -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158, + 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042, + 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012, + 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345, + -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884, + 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274, + 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934, + -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077, + 0.277308, 0.415818}; + +static float endtoend_input[] = { + 0.996808, 0.060710, 0.981855, 0.570017, 0.525164, 0.796859, 0.696547, + 0.505925, 0.991844, 0.461208, 0.949371, 0.027624, 0.539236, 0.841854, + 0.915222, 0.538569, 0.069375, 0.237905, 0.903700, 0.441703, 0.536196, + 0.402724, 0.761635, 0.025063, 0.082592, 0.688245, 0.239310, 0.256931, + 0.658900, 0.105695, 0.301983, 0.655708, 0.166405, 0.283837, 0.225725, + 0.691569, 0.080696, 0.922272, 0.197494, 0.072540, 0.383481, 0.146865, + 0.100163, 0.922717, 0.988720, 0.015386, 0.461286, 0.058095, 0.253290, + 0.364986, 0.499797, 0.789487, 0.767709, 0.261433, 0.814549, 0.850302, + 0.949678, 0.053859, 0.107233, 0.608577, 0.159554, 0.409215, 0.264285, + 0.325960, 0.693053, 0.490011, 0.017529, 0.773749, 0.412283, 0.215023, + 0.846288, 0.795764, 0.361889, 0.946452, 0.718481, 0.350608, 0.961837, + 0.179767, 0.408703, 0.215128, 
0.544753, 0.908500, 0.004614, 0.312462, + 0.169933, 0.819163, 0.162764, 0.119611, 0.873022, 0.269997, 0.728188, + 0.032576, 0.679212, 0.992474, 0.358536, 0.372265, 0.482484, 0.376065, + 0.146014, 0.894767, 0.591088, 0.992302, 0.690531, 0.952977, 0.938754, + 0.409012, 0.303585, 0.900591, 0.588780, 0.712287, 0.115719, 0.133533, + 0.620788, 0.120334, 0.445995, 0.790720, 0.939497, 0.608759, 0.910331, + 0.812519, 0.878756, 0.638519, 0.845096, 0.557968, 0.630993, 0.203632, + 0.930233, 0.113477, 0.579697, 0.076247, 0.008244, 0.170785, 0.068549, + 0.698776, 0.123761, 0.007303, 0.107788, 0.427346, 0.907894, 0.696568, + 0.139633, 0.023613, 0.830100, 0.760421, 0.143947, 0.276096, 0.551141, + 0.083444, 0.884855, 0.461472, 0.895963, 0.763611, 0.099992, 0.741059, + 0.321579, 0.730984, 0.944691, 0.251812, 0.844461, 0.524388, 0.328059, + 0.852706, 0.695172, 0.396607, 0.551482, 0.818934, 0.403910, 0.659270, + 0.246280, 0.311804, 0.355838, 0.385913, 0.335418, 0.185938, 0.146334, + 0.479364, 0.462034, 0.697475, 0.562808, 0.346888, 0.158948, 0.458771, + 0.110499, 0.258939, 0.199830, 0.432078, 0.989924, 0.144521, 0.683890, + 0.834385, 0.668908, 0.011949, 0.687091, 0.364081, 0.408556, 0.238572, + 0.183015, 0.812466, 0.897842, 0.429294, 0.124271, 0.253680, 0.815207, + 0.459688, 0.439618, 0.961541, 0.939053, 0.901651, 0.659016, 0.501861, + 0.248539, 0.817964, 0.960632, 0.359038, 0.076903, 0.160462, 0.791117, + 0.066826, 0.304983, 0.475007, 0.901211, 0.973891, 0.486955, 0.588302, + 0.337972, 0.895512, 0.826874, 0.520987, 0.707978, 0.724716, 0.950281, + 0.832249, 0.978396, 0.765488, 0.291937, 0.418014, 0.727029, 0.230990, + 0.319665, 0.386045, 0.732850, 0.568204, 0.204009, 0.693482, 0.927242, + 0.280912, 0.853944, 0.718359, 0.347738, 0.158927, 0.193366, 0.248950, + 0.132818, 0.680321, 0.837252, 0.470790, 0.575833, 0.664126, 0.991777, + 0.283811, 0.388843, 0.942058, 0.116060, 0.367239, 0.707546, 0.407997, + 0.785253, 0.434575, 0.638986, 0.104917, 0.820620, 0.371837, 0.673121, + 0.024629, 
0.065319, 0.600363, 0.305541, 0.919263, 0.318722, 0.653279, + 0.078190, 0.512088, 0.902229, 0.211009, 0.192409, 0.739480, 0.681799, + 0.768242, 0.403607, 0.673576, 0.052052, 0.792450, 0.615634, 0.168112, + 0.159689, 0.323180, 0.576109, 0.944941, 0.757755, 0.215095, 0.049858, + 0.578375, 0.586932, 0.722979, 0.603003, 0.652251, 0.323343, 0.908544, + 0.571514, 0.642065, 0.561823, 0.649704, 0.154153, 0.464051, 0.860713, + 0.346562, 0.203532, 0.542512, 0.114804, 0.607139, 0.216088, 0.166856, + 0.399588, 0.831722, 0.334968, 0.559277, 0.154902, 0.911077, 0.504218, + 0.912656, 0.126172, 0.554076, 0.491031, 0.713104, 0.277055, 0.094034, + 0.365355, 0.600398, 0.002578, 0.936869, 0.242463, 0.564401, 0.586574, + 0.396616, 0.028452, 0.447287, 0.743178, 0.231984, 0.989799, 0.857982, + 0.839122, 0.205887, 0.024838, 0.238711, 0.037608, 0.359806, 0.797987, + 0.192510, 0.270883, 0.302205, 0.105166, 0.397055, 0.856281, 0.596197, + 0.110160, 0.133336, 0.690231, 0.475515, 0.733734, 0.692809, 0.412384, + 0.976196, 0.257209, 0.998958, 0.372812, 0.285661, 0.446245, 0.115990, + 0.517645, 0.436044, 0.973972, 0.356767, 0.641930, 0.998810, 0.595478, + 0.679539, 0.358617, 0.393465, 0.872049, 0.629500, 0.695670, 0.977215, + 0.026555, 0.551951, 0.573412, 0.136715, 0.685287, 0.263643, 0.612229, + 0.419020, 0.956451, 0.024613, 0.395216, 0.213661, 0.023572, 0.768029, + 0.499322, 0.469816, 0.884019, 0.016967, 0.905860, 0.857991, 0.373734, + 0.547791, 0.856802, 0.969211, 0.227330, 0.215418, 0.362676, 0.099378, + 0.844918, 0.058346, 0.076594, 0.871473, 0.610297, 0.650006, 0.008188, + 0.295583, 0.913648, 0.620417, 0.714603, 0.870100, 0.645031, 0.109820, + 0.083760, 0.668602, 0.877849, 0.583082, 0.138419, 0.761868, 0.600049, + 0.044279, 0.619859, 0.973783, 0.592069, 0.476661, 0.942994, 0.819399, + 0.692079, 0.305670, 0.918778, 0.536997, 0.364016, 0.995371, 0.408470, + 0.974313, 0.645377, 0.416658, 0.269896, 0.559025, 0.037075, 0.984499, + 0.429125, 0.682105, 0.094319, 0.512885, 0.350707, 0.972168, 
0.095967, + 0.489126, 0.734035, 0.696016, 0.533405, 0.353894, 0.669799, 0.125474, + 0.830555, 0.612793, 0.944873, 0.522634, 0.918463, 0.863651, 0.059631, + 0.282479, 0.859022, 0.468101, 0.256791, 0.504398, 0.884758, 0.526687, + 0.063423, 0.921833, 0.511186, 0.492548, 0.603939, 0.605505, 0.005433, + 0.954646, 0.577673, 0.101400, 0.443772, 0.311708, 0.797417, 0.977176, + 0.665602, 0.467216, 0.102650, 0.496157, 0.080009, 0.047524, 0.018791, + 0.998471, 0.911174, 0.078422, 0.280950, 0.770196, 0.546523, 0.537741, + 0.274594, 0.431281, 0.064428, 0.338017, 0.353115, 0.575615, 0.830565, + 0.957053, 0.181120, 0.835998, 0.911699, 0.758793, 0.937398, 0.355471, + 0.070501, 0.734815, 0.332647, 0.736103, 0.202031, 0.435297, 0.232261, + 0.282039, 0.482821, 0.251052, 0.280511, 0.393995, 0.329474, 0.561460, + 0.164191, 0.875997, 0.099202, 0.438785, 0.307278, 0.163630, 0.776802, + 0.660393, 0.739244, 0.607367, 0.617446, 0.920364, 0.443365, 0.529145, + 0.679157, 0.380763, 0.884616, 0.749658, 0.115578, 0.217263, 0.485761, + 0.317609, 0.652560, 0.718021, 0.599648, 0.135381, 0.969073, 0.880159, + 0.529376, 0.298547, 0.441619, 0.693567, 0.174544, 0.540821, 0.132351, + 0.481822, 0.704450, 0.909153, 0.142215, 0.443695, 0.516520, 0.759661, + 0.364059, 0.959885, 0.288806, 0.043216, 0.340648, 0.173422, 0.792874, + 0.456226, 0.390685, 0.278634, 0.773834, 0.043245, 0.996656, 0.373483, + 0.178625, 0.965729, 0.253641, 0.708001, 0.264276, 0.695260, 0.401568, + 0.438820, 0.236081, 0.533919, 0.920642, 0.940531, 0.443072, 0.062857, + 0.384226, 0.959592, 0.822518, 0.748285, 0.919477, 0.111325, 0.791501, + 0.260124, 0.284747, 0.584375, 0.716350, 0.675431, 0.863009, 0.490184, + 0.718676, 0.859665, 0.863666, 0.897301, 0.825393, 0.117308, 0.605302, + 0.089669, 0.812568, 0.006870, 0.528489, 0.048649, 0.540788, 0.449131, + 0.989180, 0.983860, 0.511988, 0.373407, 0.943452, 0.334506, 0.121692, + 0.862929, 0.445831, 0.913193, 0.123053, 0.730578, 0.497568, 0.839402, + 0.406009, 0.360577, 0.329586, 0.124685, 
0.220241, 0.193253, 0.021986, + 0.045634, 0.310560, 0.627288, 0.135303, 0.123128, 0.634158, 0.663792, + 0.171777, 0.174946, 0.112923, 0.160958, 0.158806, 0.624911, 0.534364, + 0.102259, 0.959418, 0.656056, 0.965187, 0.405249, 0.569249, 0.088240, + 0.135827, 0.066817, 0.927642, 0.541836, 0.427393, 0.257229, 0.666520, + 0.647634, 0.450481, 0.688506, 0.693269, 0.761042, 0.315794, 0.828572, + 0.884170, 0.949952, 0.492364, 0.055947, 0.124898, 0.605288, 0.216905, + 0.283705, 0.230199, 0.751269, 0.385963, 0.189616, 0.407326, 0.351151, + 0.594865, 0.976575, 0.439391, 0.730692, 0.043392, 0.367033, 0.272527, + 0.470785, 0.624261, 0.939048, 0.118419, 0.074743, 0.627554, 0.811688, + 0.835784, 0.943348, 0.640260, 0.719954, 0.893300, 0.132625, 0.775901, + 0.018199, 0.737913, 0.992806, 0.301903, 0.968111, 0.744076, 0.687867, + 0.157728, 0.151401, 0.039017, 0.752593, 0.127976, 0.478408, 0.483284, + 0.171368, 0.845441, 0.755811, 0.642153, 0.469702, 0.694859, 0.760572, + 0.544445, 0.322413, 0.572260, 0.380229, 0.265761, 0.212521, 0.100183, + 0.159062, 0.345146, 0.876084, 0.177261, 0.083058, 0.868891, 0.479164, + 0.051169, 0.612966, 0.167030, 0.208897, 0.764367, 0.206048, 0.961490, + 0.892343, 0.684456, 0.444774, 0.063711, 0.529896, 0.200585, 0.705863, + 0.999598, 0.895444, 0.466435, 0.544043, 0.217857, 0.038696, 0.924272, + 0.483618, 0.251217, 0.024455, 0.642680, 0.596362, 0.900539, 0.819941, + 0.679420, 0.769430, 0.299105, 0.730590, 0.382396, 0.466135, 0.939487, + 0.146763, 0.672183, 0.900977, 0.039106, 0.356638, 0.345750, 0.102817, + 0.886535, 0.546336, 0.808681, 0.886133, 0.441780, 0.275116, 0.430176, + 0.659637, 0.313812, 0.354448, 0.143255, 0.565028, 0.378903, 0.785935, + 0.161391, 0.279443, 0.605876, 0.840811, 0.048873, 0.904980, 0.571401, + 0.431269, 0.371115, 0.510887, 0.578032, 0.043298, 0.411864, 0.617138, + 0.399936, 0.757614, 0.719955, 0.286471, 0.303950, 0.528636, 0.172604, + 0.745730, 0.803752, 0.602780, 0.405367, 0.117564, 0.957228, 0.548622, + 0.682592, 0.336131, 
0.334557, 0.843983, 0.615574, 0.940433, 0.684794, + 0.664447, 0.845413, 0.256194, 0.095715, 0.216529, 0.767082, 0.673747, + 0.259827, 0.178946, 0.290885, 0.659763, 0.936560, 0.010840, 0.946234, + 0.240510, 0.539476, 0.118838, 0.986240, 0.343228, 0.721618, 0.391606, + 0.460792, 0.678846, 0.940228, 0.143384, 0.014977, 0.274785, 0.987367, + 0.630551, 0.215218, 0.672161, 0.294998, 0.060631, 0.928355, 0.390713, + 0.277160, 0.695436, 0.064460, 0.536987, 0.874382, 0.355345, 0.196751, + 0.810942, 0.366185, 0.142985, 0.051452, 0.905661, 0.261823, 0.037691, + 0.248889, 0.983441, 0.429297, 0.709681, 0.662286, 0.369525, 0.853066, + 0.677263, 0.644310, 0.840433, 0.307814, 0.859528, 0.512593, 0.602812, + 0.920160, 0.440948, 0.993525, 0.197320, 0.136384, 0.057984, 0.734307, + 0.010766, 0.413329, 0.931058, 0.821707, 0.779514, 0.074043, 0.873159, + 0.685175, 0.335865, 0.910850, 0.934065, 0.319306, 0.340147, 0.643746, + 0.981592, 0.709673, 0.496812, 0.658856, 0.353983, 0.337245, 0.966670, + 0.213511, 0.849838, 0.569482, 0.133671, 0.290786, 0.563007, 0.330991, + 0.427170, 0.620991, 0.065299, 0.437936, 0.034320, 0.996356, 0.259643, + 0.813834, 0.070399, 0.132802, 0.499009, 0.406265, 0.043652, 0.433074, + 0.725570, 0.383800, 0.076820, 0.707163, 0.093473, 0.573632, 0.366018, + 0.447456, 0.910877, 0.332688, 0.660967, 0.760714, 0.902170, 0.794638, + 0.051500, 0.465177, 0.125630, 0.478670, 0.086168, 0.190928, 0.916605, + 0.120488, 0.187285, 0.176248, 0.934322, 0.257684, 0.309050, 0.433331, + 0.663949, 0.352703, 0.866405, 0.389519, 0.736502, 0.943226, 0.096682, + 0.829975, 0.516858, 0.462700, 0.277430, 0.427734, 0.795388, 0.938398, + 0.188449, 0.697558, 0.733036, 0.239948, 0.162735, 0.858666, 0.718618, + 0.248903, 0.049594, 0.635223, 0.369391, 0.236879, 0.811472, 0.303713, + 0.494563, 0.120522, 0.737044, 0.158511, 0.473225, 0.603450, 0.548030, + 0.209727, 0.546675, 0.644712, 0.039702, 0.063533, 0.107412, 0.317132, + 0.491267, 0.902800, 0.255530, 0.679716, 0.600359, 0.988566, 0.919664, + 
0.763094, 0.847232, 0.638283, 0.011997, 0.896825, 0.273506, 0.381388, + 0.133704, 0.084978, 0.685101, 0.628267, 0.205500, 0.422145, 0.786778, + 0.678725, 0.025595, 0.334808, 0.888452, 0.572271, 0.979520, 0.928154, + 0.635804, 0.086932, 0.245286, 0.127071, 0.989732, 0.500816, 0.806787, + 0.590091, 0.489382, 0.726451, 0.353185, 0.336614, 0.364734, 0.365182, + 0.233439, 0.638240, 0.746570, 0.367143, 0.723218, 0.431671, 0.995410, + 0.928718, 0.853816, 0.782188, 0.607442, 0.879411, 0.116995, 0.495894, + 0.451682, 0.096515, 0.424048, 0.087485, 0.183447, 0.669334, 0.214556, + 0.173179, 0.170151, 0.021343, 0.763269, 0.659533, 0.747794, 0.116454, + 0.996147, 0.112528, 0.481635, 0.229586, 0.750768, 0.228205, 0.596730, + 0.473985, 0.659876, 0.592139, 0.402703, 0.513692, 0.374327, 0.010145, + 0.393103, 0.491322, 0.506039, 0.844785, 0.587837, 0.930088, 0.932270, + 0.771284, 0.599422, 0.146826, 0.944463, 0.769573, 0.168169, 0.707732, + 0.429106, 0.915964, 0.824186, 0.425253, 0.028492, 0.305821, 0.654839, + 0.779259, 0.534026, 0.251569, 0.253245, 0.193901, 0.843708, 0.655947, + 0.707593, 0.218035, 0.666093, 0.100696, 0.709357, 0.172132, 0.945481, + 0.297195, 0.102220, 0.877751, 0.068479, 0.701642, 0.024577, 0.012941, + 0.471215, 0.192747, 0.720673, 0.900321, 0.108710, 0.544859, 0.325574, + 0.137202, 0.850679, 0.980413, 0.916462, 0.384705, 0.231982, 0.169706, + 0.578607, 0.075690, 0.825654, 0.286200, 0.293725, 0.491746, 0.386896, + 0.003083, 0.663878, 0.332377, 0.300278, 0.766098, 0.210128, 0.368756, + 0.467740, 0.234705, 0.381697, 0.938955, 0.427451, 0.102370, 0.839275, + 0.536162, 0.647229, 0.164849, 0.673364, 0.497908, 0.145262, 0.589825, + 0.882613, 0.377244, 0.759532, 0.461220, 0.452934, 0.585185, 0.747420, + 0.746660, 0.076932, 0.134316, 0.749743, 0.740810, 0.466692, 0.050020, + 0.506908, 0.676820, 0.418776, 0.974648, 0.911525, 0.800474, 0.913602, + 0.338976, 0.902844, 0.752878, 0.875138, 0.550072, 0.917727, 0.548502, + 0.047981, 0.062989, 0.138327, 0.930594, 0.440233, 
0.897859, 0.391814, + 0.893168, 0.483044, 0.139234, 0.639828, 0.559975, 0.273549, 0.389570, + 0.300785, 0.740242, 0.439590, 0.807693, 0.417062, 0.858367, 0.782341, + 0.328586, 0.658840, 0.695943, 0.667562, 0.561684, 0.448821, 0.542700, + 0.111756, 0.366548, 0.091202, 0.159737, 0.429537, 0.229529, 0.090331, + 0.869770, 0.127388, 0.482145, 0.762938, 0.610432, 0.621379, 0.402765, + 0.170407, 0.894928, 0.792336, 0.471192, 0.635170, 0.231926, 0.278886, + 0.052232, 0.090293, 0.061226, 0.380818, 0.749133, 0.757170, 0.048380, + 0.310817, 0.205990, 0.591080, 0.422573, 0.572538, 0.682282, 0.582310, + 0.002075, 0.911812, 0.672641, 0.871845, 0.039199, 0.154786, 0.634783, + 0.649631, 0.776165, 0.037548, 0.820038, 0.671093, 0.829884, 0.291231, + 0.306263, 0.061810, 0.570116, 0.358495, 0.152103, 0.631343, 0.739313, + 0.901236, 0.388512, 0.787693, 0.212053, 0.594503, 0.378773, 0.634626, + 0.167040, 0.061056, 0.216937, 0.169115, 0.972867, 0.889578, 0.040960, + 0.012067, 0.044364, 0.675743, 0.661698, 0.820529, 0.713291, 0.481736, + 0.491623, 0.543175, 0.772966, 0.797886, 0.604985, 0.343083, 0.156380, + 0.757088, 0.974425, 0.895693, 0.658324, 0.362938, 0.683386, 0.870376, + 0.957440, 0.062159, 0.505002, 0.124481, 0.123215, 0.721939, 0.293596, + 0.096082, 0.611517, 0.334556, 0.108149, 0.655881, 0.010299, 0.769846, + 0.476411, 0.723590, 0.251582, 0.968033, 0.266765, 0.024548, 0.765919, + 0.871750, 0.367631, 0.922299, 0.628838, 0.342056, 0.817992, 0.287162, + 0.704994, 0.501378, 0.157538, 0.662434, 0.563537, 0.662541, 0.786915, + 0.686752, 0.384480, 0.080511, 0.782834, 0.995997, 0.415067, 0.890983, + 0.651878, 0.425365, 0.660829, 0.128289, 0.148956, 0.912411, 0.096322, + 0.415721, 0.936959, 0.862241, 0.287471, 0.304590, 0.784540, 0.916309, + 0.646646, 0.602533, 0.203471, 0.351640, 0.103911, 0.361009, 0.014074, + 0.667448, 0.023550, 0.800989, 0.354200, 0.408030, 0.881500, 0.137034, + 0.404026, 0.296566, 0.028017, 0.055904, 0.721932, 0.688846, 0.184193, + 0.870887, 0.601257, 0.280515, 
0.286608, 0.538216, 0.142755, 0.574079, + 0.842806, 0.927296, 0.490388, 0.489452, 0.529828, 0.693859, 0.841092, + 0.633739, 0.054869, 0.855167, 0.301187, 0.078419, 0.656156, 0.655388, + 0.486448, 0.537656, 0.792422, 0.890475, 0.834222, 0.820439, 0.946379, + 0.556153, 0.509285, 0.130571, 0.427041, 0.110542, 0.411086, 0.713648, + 0.648758, 0.553842, 0.287727, 0.491563, 0.481137, 0.778116, 0.981015, + 0.010966, 0.471975, 0.822107, 0.644705, 0.526844, 0.677274, 0.945892, + 0.605263, 0.333430, 0.601280, 0.091711, 0.871086, 0.393702, 0.982186, + 0.705307, 0.214141, 0.928564, 0.261461, 0.723426, 0.059136, 0.688501, + 0.833968, 0.470222, 0.402150, 0.482725, 0.024063, 0.689877, 0.974289, + 0.505201, 0.467993, 0.955304, 0.516166, 0.939968, 0.777411, 0.160871, + 0.466812, 0.454685, 0.106763, 0.072075, 0.788115, 0.708043, 0.163786, + 0.659201, 0.101744, 0.145971, 0.364508, 0.315885, 0.074536, 0.625969, + 0.039311, 0.133672, 0.314471, 0.873279, 0.603893, 0.716620, 0.356004, + 0.627957, 0.406498, 0.330292, 0.133157, 0.874490, 0.285596, 0.649324, + 0.814458, 0.063007, 0.810195, 0.281270, 0.517693, 0.916958, 0.353345, + 0.305808, 0.625000, 0.517131, 0.965009, 0.726745, 0.663102, 0.329518, + 0.042630, 0.737638, 0.955487, 0.081940, 0.871310, 0.269957, 0.955219, + 0.475203, 0.986578, 0.311223, 0.103160, 0.393075, 0.641515, 0.236317, + 0.267566, 0.927112, 0.885641, 0.082024, 0.990119, 0.695835, 0.363295, + 0.507812, 0.612793, 0.716640, 0.813620, 0.237793, 0.233770, 0.778629, + 0.964538, 0.896872, 0.108147, 0.007167, 0.634510, 0.063633, 0.089108, + 0.505820, 0.333591, 0.044327, 0.981023, 0.320168, 0.355550, 0.084182, + 0.713244, 0.997065, 0.320499, 0.980810, 0.924177, 0.206140, 0.062834, + 0.914296, 0.901975, 0.426129, 0.422107, 0.514768, 0.142768, 0.235727, + 0.752561, 0.376539, 0.014356, 0.717099, 0.273411, 0.122502, 0.724266, + 0.907921, 0.186136, 0.813374, 0.413741, 0.519726, 0.857701, 0.394764, + 0.839895, 0.213251, 0.478946, 0.553139, 0.210317, 0.799446, 0.533948, + 0.134493, 
0.005586, 0.596782, 0.048789, 0.907561, 0.022911, 0.470896, + 0.422329, 0.165679, 0.706623, 0.174890, 0.542218, 0.720979, 0.891989, + 0.815629, 0.843481, 0.616255, 0.723551, 0.029617, 0.429630, 0.137292, + 0.549343, 0.287331, 0.532056, 0.389238, 0.500583, 0.011002, 0.942377, + 0.710899, 0.810448, 0.476326, 0.845392, 0.816033, 0.073108, 0.894181, + 0.723594, 0.096019, 0.365077, 0.145923, 0.261699, 0.071700, 0.320813, + 0.803917, 0.792679, 0.212802, 0.619546, 0.636160, 0.829057, 0.343096, + 0.665777, 0.258687, 0.480388, 0.215121, 0.546018, 0.012444, 0.604359, + 0.046601, 0.023446, 0.546736, 0.757500, 0.833893, 0.023062, 0.602892, + 0.649927, 0.096170, 0.497074, 0.373521, 0.192189, 0.862151, 0.519444, + 0.453887, 0.933851, 0.840257, 0.257804, 0.726531, 0.053058, 0.877350, + 0.362691, 0.882115, 0.220446, 0.028468, 0.140802, 0.700834, 0.243589, + 0.686821, 0.713278, 0.847948, 0.733421, 0.736723, 0.394684, 0.490921, + 0.570617, 0.417746, 0.093813, 0.220543, 0.513916, 0.590887, 0.594064, + 0.706105, 0.453038, 0.113508, 0.159992, 0.386889, 0.953765, 0.417796, + 0.113420, 0.006823, 0.295146, 0.476111, 0.888938, 0.515592, 0.504579, + 0.029741, 0.216426, 0.748168, 0.716561, 0.929703, 0.596117, 0.449982, + 0.666427, 0.990801, 0.940903, 0.237043, 0.408547, 0.034717, 0.457587, + 0.922463, 0.625603, 0.051651, 0.628568, 0.078641, 0.165159, 0.788560, + 0.465530, 0.118923, 0.206356, 0.578950, 0.125746, 0.501502, 0.055060, + 0.014685, 0.017094, 0.559640, 0.044425, 0.233519, 0.307808, 0.760986, + 0.163223, 0.903925, 0.210969, 0.829650, 0.894726, 0.151872, 0.066693, + 0.303273, 0.186589, 0.524279, 0.225736, 0.812192, 0.575930, 0.854304, + 0.890833, 0.741089, 0.642864, 0.356363, 0.860012, 0.849220, 0.935313, + 0.985758, 0.350722, 0.990373, 0.000443, 0.367815, 0.550013, 0.044868, + 0.601335, 0.857820, 0.805855, 0.764557, 0.761745, 0.016823, 0.594207, + 0.656471, 0.168696, 0.660900, 0.959744, 0.355284, 0.185179, 0.185480, + 0.167477, 0.761110, 0.039784, 0.058310, 0.502199, 0.682648, 
0.414673, + 0.362211, 0.531868, 0.349985, 0.347969, 0.882589, 0.340358, 0.348412, + 0.250404, 0.890371, 0.393280, 0.851739, 0.748191, 0.199135, 0.616297, + 0.509936, 0.215958, 0.210504, 0.166407, 0.384654, 0.871404, 0.126151, + 0.739938, 0.056583, 0.311631, 0.907415, 0.817693, 0.351415, 0.965724, + 0.319891, 0.034062, 0.380397, 0.682102, 0.565930, 0.730382, 0.030072, + 0.448519, 0.070741, 0.378484, 0.698924, 0.961112, 0.771764, 0.550663, + 0.709303, 0.970899, 0.166959, 0.219239, 0.186857, 0.377463, 0.385647, + 0.571511, 0.248867, 0.511798, 0.311449, 0.305450, 0.823429, 0.218864, + 0.123142, 0.174844, 0.184588, 0.443034, 0.208906, 0.564986, 0.125136, + 0.774836, 0.295368, 0.155207, 0.223355, 0.366109, 0.533691, 0.922279, + 0.327221, 0.305455, 0.472942, 0.036524, 0.276354, 0.639901, 0.255763, + 0.463211, 0.017364, 0.641410, 0.034722, 0.266231, 0.153207, 0.346171, + 0.571680, 0.976636, 0.565036, 0.694822, 0.151480, 0.749624, 0.137856, + 0.360386, 0.314610, 0.262992, 0.135222, 0.609978, 0.418200, 0.358578, + 0.976087, 0.951891, 0.280856, 0.303307, 0.257346, 0.753798, 0.339831, + 0.533700, 0.393699, 0.595594, 0.996911, 0.411063, 0.237003, 0.031634, + 0.677294, 0.390211, 0.377805, 0.248974, 0.366847, 0.942841, 0.943796, + 0.518327, 0.692465, 0.081653, 0.878713, 0.007074, 0.344645, 0.013936, + 0.617052, 0.762845, 0.372513, 0.593138, 0.714736, 0.653370, 0.896446, + 0.972082, 0.407168, 0.236276, 0.505782, 0.800867, 0.831870, 0.502693, + 0.211930, 0.068873, 0.534327, 0.889224, 0.459084, 0.912132, 0.138197, + 0.825931, 0.854972, 0.081994, 0.344259, 0.547437, 0.163646, 0.222972, + 0.554511, 0.508291, 0.236908, 0.171563, 0.271135, 0.609421, 0.764701, + 0.985871, 0.262790, 0.661147, 0.957953, 0.669958, 0.897423, 0.463734, + 0.470825, 0.729293, 0.966427, 0.682755, 0.798166, 0.500754, 0.571978, + 0.257251, 0.412886, 0.710176, 0.083182, 0.267858, 0.792169, 0.427441, + 0.815295, 0.955815, 0.650413, 0.369805, 0.464106, 0.887320, 0.541368, + 0.735242, 0.496741, 0.306069, 0.721113, 
0.759531, 0.967216, 0.679065, + 0.429489, 0.864639, 0.142799, 0.900314, 0.593932, 0.109227, 0.583069, + 0.392098, 0.609981, 0.155047, 0.649349, 0.022867, 0.865222, 0.732531, + 0.290725, 0.657392, 0.159972, 0.106019, 0.613207, 0.810384, 0.475824, + 0.077313, 0.697704, 0.017192, 0.812555}; + +static float golden_endtoend_output[] = { + -1.881211, -0.028385, -3.585066, 1.939770, -3.461155, 1.280415, -4.408978, + 0.608663, -2.704937, 1.859742, -5.777429, 2.691839, -1.049012, 1.640870, + -4.856245, 1.604236, 0.992707, 0.422858, -4.307465, 1.887332, -0.884831, + -0.154277, -2.634801, 0.586827, -1.849960, 1.399608, -4.531559, 1.943591, + 0.271676, -2.893054, -2.066826, 0.235467, -1.248263, -1.164534, -2.640174, + -0.112878, -4.386484, 1.253024, -4.135623, 1.068984, -0.043579, -0.832957, + -3.257258, -0.514396, -1.651174, 0.638630, -4.364372, 1.548441, -0.289455, + 0.539845, -4.097627, 0.635001, -0.465071, -0.927701, -2.481498, 0.356616, + -2.355012, 0.728806, -3.340283, 1.609038, -4.786268, -0.532272, -1.886150, + 0.254797, 0.746620, -1.657134, -3.264265, 0.525551, -1.756837, 0.845446, + -5.572190, 1.715797, -2.856942, 3.394245, -5.803662, 2.281806, -3.014739, + 2.616136, -4.728482, 1.659984, -2.106307, 2.711709, -6.173832, 1.352869, + -0.038035, 0.107619, -4.279774, 2.341930, -0.980413, -0.119538, -4.049717, + 1.172128, -3.477744, 2.602274, -6.231380, 2.537300, -0.862214, 0.568722, + -3.858362, 0.197867, -1.725885, 3.687312, -7.067363, 2.403544, -0.944963, + 0.235639, -3.250094, 0.659117, -1.459576, 0.426128, -3.637207, 1.030386, + -4.224351, 3.516220, -6.053367, 0.993473, -2.182416, -0.762625, -1.884405, + -0.113736, -2.572602, 0.329290, -1.913233, 0.517418, -0.019757, 0.203176, + -3.715881, 0.482136, -1.912823, 1.357907, -5.473043, 1.714658, -3.177160, + 0.089285, -3.127669, 1.268076, 0.772498, -1.622712, -3.850314, 0.436124, + -1.495983, 3.439982, -7.623405, 1.726721, -0.423979, 0.180201, -2.902406, + 0.986457, -1.845638, 0.460903, -5.359343, -1.133931, -1.074456, 
0.717304, + -3.519856, 1.012126, -0.562301, 1.881967, -6.716627, 2.525036, 0.945480, + 0.337081, -5.210562, 2.572035, -0.943370, 0.442026, -2.666313, 0.411296, + 0.002787, -0.000735, -2.498933, 0.771719, -3.568153, 3.833721, -6.617026, + 2.813922, -0.573970, 1.025208, -3.909923, 1.722648, -1.406849, 0.719783, + -5.207438, 1.819442, -0.530895, -0.010887, -2.939614, 0.971225, -1.660297, + 1.345243, -4.454571, 2.244876, -2.021213, 1.756090, -4.880947, 0.364597, + -2.380270, 2.763117, -5.613013, 2.137534, 0.289101, -2.279400, -3.365582, + 0.170028, -1.142254, -0.709604, -3.656223, 1.804870, -0.854690, 0.592102, + -5.010415, 2.462687, -1.474710, 0.566002, -3.621819, -0.391946, -0.423524, + -0.631428, -3.513310, 0.962825, -1.480262, 0.319791, -3.610137, 1.842339, + -0.250073, 1.182022, -6.249267, 1.604172, 1.153759, -0.734054, -4.620415, + -0.030858, 0.050911, 1.524406, -4.724010, 1.451846, -3.277104, 2.414182, + -4.605285, 1.846092, -1.503047, -0.618200, -2.746546, -0.459332, -0.980326, + -1.199977, -2.043865, -0.165793, -2.214698, 3.108281, -7.127830, -0.123065, + 1.244948, -3.039923, -4.660061, -0.225957, -0.307210, -1.513205, -2.456005, + 0.840048, -0.741445, 2.328635, -6.015267, 2.723240, -1.381171, -0.728878, + -5.114925, -0.362034, -0.574923, 0.518080, -3.892457, 1.798948, 0.435119, + -0.371696, -2.807571, 1.302864, -2.063052, 1.036388, -4.232038, 1.397059, + -1.615668, -1.511019, -3.095508, 1.290955, -3.428723, 2.000287, -4.196487, + 1.566983, 0.196957, 0.224343, -4.926359, -0.691975, -0.214941, 1.546821, + -5.384868, 2.290820, -1.878865, 0.493692, -4.129823, 2.112036, 0.516558, + -2.553077, -2.717338, 0.017146, -2.016057, 1.628995, -4.240602, 1.189533, + -5.460220, 1.254738, -4.214903, 0.755659, -2.893235, 2.937762, -6.169453, + 2.035456, -5.613212, -0.122254, -1.973646, -0.060619, -2.119598, 1.413512, + -4.938738, 1.890244, 0.544169, -2.062413, -3.329637, -0.062515, -1.855805, + -0.791297, -2.570353, 0.607615, 0.305812, 0.338930, -4.150270, 2.274937, + 
0.042653, 0.133825, -3.538155, 1.523639, -3.173690, -1.496599, -2.414655, + 0.464687, -1.448998, -0.368907, -3.520129, 0.203382, -2.443626, 1.266233, + -3.393848, 0.605911, -0.015353, 1.402006, -4.441003, 1.419281, 0.603587, + 0.434146, -4.966566, 2.171872, -0.688264, -0.009981, -4.461103, 1.538354, + -5.029816, -0.264424, -1.713510, -0.315258, -1.891606, 0.252074, -2.419428, + 0.043970, -1.291143, 2.048704, -4.590105, 0.524734, -1.889576, 0.134836, + -3.462745, 1.390663, -0.112773, 0.402735, -4.203784, 1.381043, -1.201634, + -1.968277, -1.425637, -0.181725, -1.250742, -2.102041, -3.925464, -1.256797, + -3.701354, -1.754610, -1.917231, -1.455910, -1.838006, 2.041781, -5.666212, + 2.752957, -2.659553, 2.553637, -4.872212, 1.443437, -2.081846, 3.311263, + -5.912457, 1.871049, 0.196148, -0.307044, -4.024967, 2.149149, 0.361809, + 0.620415, -5.939984, 0.180672, -1.209180, -0.269122, -3.240285, 1.460315, + -1.040803, 1.125700, -6.060366, 0.887767, -3.214111, 1.314368, -3.026808, + 1.023640, -3.815175, 1.795642, -4.355603, 1.064454, -0.046472, 0.618463, + -5.941646, 2.861891, -2.852155, -0.990457, -2.624445, 1.794494, -1.176747, + -0.358159, -3.206776, 1.138721, -2.819523, -1.825522, -1.450902, -0.187312, + -0.808727, 0.636872, -4.120567, 1.192623, 0.810731, -1.768519, -3.699450, + 1.527116, -2.772720, 3.012835, -5.912736, 1.599365, -4.696381, 2.234591, + -4.139552, 1.061768, -1.880089, 3.596274, -7.006379, 2.382152, -3.158115, + 3.844430, -7.044156, 2.307596, -2.473970, 1.312644, -5.467269, 0.197154, + -1.530040, 1.762275, -5.550757, 0.630276, -3.048947, 1.043777, -3.096658, + 1.345893, -1.329494, 2.065748, -4.711032, 2.227600, -0.413321, -0.032428, + -4.599650, 1.668734, -4.351490, -0.200022, -2.359903, 0.021997, 0.116028, + 1.159718, -5.093972, -0.142951, -2.409895, 0.906133, -2.728812, 0.809932, + -2.597363, 0.494130, -2.357861, 0.369825, -2.165235, 1.148522, -3.130562, + 0.759034, 0.646335, -1.463660, -3.508299, 1.059679, -1.485465, 1.007319, + -4.340716, 1.789864, 
-1.590654, 1.612324, -4.452007, 2.389805, -5.200148, + -1.068398, -1.306923, -0.472408, -0.392165, -0.524996, -2.933478, 1.518430, + -1.287781, 0.113422, -3.020525, 1.338359, -0.105982, 0.936014, -4.132197, + 1.836807, -0.616589, -1.029716, -3.271347, 0.284889, -2.653359, 2.135829, + -4.643613, 1.627981, 0.287733, -2.017263, -2.776574, 1.184792, 1.004161, + -1.483019, -4.339290, -0.787322, 0.582420, 1.137839, -5.673941, -0.001862, + -1.219142, 0.532561, -4.457245, 1.826807, -3.343291, 3.034610, -6.179855, + 2.235917, -4.369989, 4.018128, -6.632714, 0.926585, -0.485469, 0.536073, + -4.179557, 1.489637, -0.521762, 1.636089, -6.137912, 1.500867, -4.086009, + 1.961372, -3.688977, 1.358220, -1.544034, 1.763837, -4.357567, 1.852201, + -2.018725, 1.046264, -6.211127, 1.609419, -0.118441, 1.602284, -6.242423, + 1.518578, -0.604078, 1.106613, -5.393445, 2.595629, 0.142712, -1.903953, + -2.821177, 0.032758, -0.009152, 0.184628, -4.227636, 2.046843, -2.240138, + 1.256176, -5.108516, -0.308447, -2.998571, 4.657396, -7.582112, 2.510951, + -3.535784, 1.704560, -5.068484, 1.318466, -3.058265, 3.073172, -6.998089, + 3.178849, -2.420286, 2.277806, -4.999528, 1.423890, -1.672914, 0.447460, + -4.088940, 1.351087, -1.051546, -0.417955, -4.042147, 1.604102, -1.700931, + 2.796663, -6.497579, 2.857974, -0.240828, 0.858001, -5.778933, 2.778508, + -0.406211, 1.300766, -5.073671, 2.089362, -0.201673, 1.588396, -6.000150, + 2.185055, -2.332125, 0.768216, -2.609184, 0.327277, -3.358943, -1.020736, + -2.389984, 0.315512, -0.561905, 1.948740, -6.408485, 2.231985, -0.603652, + 0.661829, -5.070386, -1.063058, -0.624796, 1.375772, -4.379606, 1.929358, + -1.047263, 0.739100, -5.217857, 2.127625, -5.025338, 0.650344, -2.068460, + 0.076936, -0.457505, -1.050984, -1.917765, 1.150908, 0.782625, 0.855595, + -5.321719, 0.787209, -0.460232, 1.106736, -5.552326, 2.801043, -0.360217, + -0.434432, -4.273378, 0.967556, -0.972652, 0.874811, -5.429918, -0.331039, + 0.115477, 0.111883, -5.418786, 1.240546, 
-1.842794, 0.505880, -3.676064, + -0.682369, 1.858984, -0.742566, -5.784060, 0.673239, -1.280398, 0.280842, + -4.848077, 2.214860, -0.785100, -0.588488, -2.438206, 0.786651, -1.568752, + 1.935400, -6.320256, 2.125338, -1.476457, -1.651941, -2.695734, 0.007338, + -3.280860, 2.310385, -5.319578, 1.890123, -0.775723, 0.630606, -4.321582, + 1.085521, -1.847371, 1.188521, -4.596577, 2.056443, -2.340172, -0.108501, + -3.156392, 0.933279, -0.495331, 0.122405, -5.171133, 1.763245, -0.796913, + 2.310487, -7.247197, 2.401678, -1.908860, 0.043798, -2.393796, 0.573806, + -0.608531, 0.154710, -4.669001, 0.750680, 0.468380, 0.392591, -4.755001, + 2.615217, -1.957774, 1.153513, -4.530099, 1.124362, -3.569415, 1.697154, + -3.536335, 0.910758, -2.976264, 1.833129, -4.287203, -0.547050, -2.409768, + 0.061585, -1.324116, 0.268497, -2.962222, -1.524245, -2.063413, 0.442058, + -4.292337, 3.538863, -6.699603, 1.718664, -2.290363, 1.994596, -6.245037, + -0.433084, -0.367059, 1.020297, -4.940721, 2.902264, -0.577056, -0.709887, + -5.001413, -0.268316, -1.112048, -1.083307, -1.753492, 0.209973, 0.139540, + 0.917602, -5.232745, 2.538467, -2.139234, -0.187388, -1.837249, -0.478582, + -0.731653, -0.481550, -2.531261, 1.044770, 0.707750, 0.279971, -3.221119, + 1.552074, -2.373144, 0.859518, -3.665156, 1.620278, -1.440871, -0.525581, + -2.758271, 1.491873, -2.302013, 1.119935, -5.257080, 2.627170, -3.174739, + 1.363282, -4.831639, 1.101076, -4.337008, 2.689639, -5.165915, 1.069201, + -1.882078, -0.120370, -2.287967, 1.147619, -1.403616, 1.077150, -5.084296, + 1.658236, -0.919642, 0.487423, -3.001075, 0.741268, 0.107300, 0.943556, + -3.544311, 1.000239, -1.627171, 2.871253, -5.179172, 1.429893, -0.826040, + 0.188670, -4.499894, 1.013447, -2.101299, 0.317516, -3.452141, -0.833776, + -1.362144, 1.272437, -4.449355, 1.613591, -2.039873, 2.613175, -6.229640, + 1.659790, -1.595520, -0.237462, -2.744997, 0.337841, 0.148981, -1.703771, + -2.388023, 1.276469, 1.058508, -0.401642, -4.680769, 0.861881, 
-1.336381, + 1.153080, -2.834378, 0.721075, 0.900115, 1.360511, -5.573611, 0.949182, + -2.970844, 2.017563, -5.186108, -0.201038, -1.192824, 0.610142, -4.450919, + -0.897114, -1.812093, 0.422310, -5.245487, 0.256549, 0.320275, -2.324150, + -2.967040, -0.260536, -0.721467, 0.454148, -5.058031, 0.526370, -0.895656, + 0.732240, -3.327363, 1.353953, -1.277912, -0.483171, -1.926713, 0.065044, + -2.167506, -0.196606, -1.923437, 0.604962, -2.088319, 1.406834, -5.227296, + 2.247351, -4.421744, 1.729791, -5.007922, 1.264769, -0.897019, 0.922902, + -3.887108, 2.087432, -1.310226, -0.101938, -3.359082, -0.079662, -0.514988, + -0.963179, -4.038209, 2.223278, -0.590083, -2.310458, -1.748338, 0.363406, + -0.540731, -0.885913, -4.179595, 2.216781, -3.044339, -0.447100, -2.446098, + 0.931101, -1.676190, 2.096175, -4.980755, 2.262151, -1.095047, 1.897516, + -5.996138, 2.191038, 0.297128, -0.780974, -2.884299, 1.195408, -0.521065, + -1.955837, -3.091064, -0.404183, -1.961519, 4.076096, -7.521851, 2.242064, + -1.988043, 0.303300, -2.422585, 0.322230, -3.377634, 3.499955, -7.084434, + 2.375587, -0.718851, 2.150076, -5.412241, 2.374280, -2.006088, 2.229828, + -5.848188, 2.543077, -2.171042, 2.096026, -5.300007, 0.141405, -1.187745, + 0.105340, -4.003816, 1.034281, -3.980804, 1.856709, -5.103042, 0.623737, + -2.080307, 0.896140, -3.104050, 0.983158, -0.424898, -1.154270, -3.805728, + 1.978917, -1.314387, 1.235096, -3.148906, 1.113173, 0.111713, 2.055213, + -7.565283, 2.100342}; +constexpr std::initializer_list biases = { + 0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068, -0.23566568, + -0.389184, 0.47481549, -0.4791103, 0.29931796, 0.10463274, 0.83918178, + 0.37197268, 0.61957061, 0.3956964, -0.37609905}; + +constexpr std::initializer_list recurrent_weights = { + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1}; + +class BidirectionalRNNOpModel : public SingleOpModel { + public: + BidirectionalRNNOpModel(int batches, int sequence_len, int fw_units, + int bw_units, int input_size) + : batches_(batches), + sequence_len_(sequence_len), + fw_units_(fw_units), + bw_units_(bw_units), + input_size_(input_size) { + input_ = AddInput(TensorType_FLOAT32); + fw_weights_ = AddInput(TensorType_FLOAT32); + fw_recurrent_weights_ = AddInput(TensorType_FLOAT32); + fw_bias_ = AddInput(TensorType_FLOAT32); + fw_hidden_state_ = AddOutput(TensorType_FLOAT32); + fw_output_ = AddOutput(TensorType_FLOAT32); + bw_weights_ = AddInput(TensorType_FLOAT32); + bw_recurrent_weights_ = AddInput(TensorType_FLOAT32); + bw_bias_ = AddInput(TensorType_FLOAT32); + bw_hidden_state_ = AddOutput(TensorType_FLOAT32); + bw_output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN, + BuiltinOptions_SequenceRNNOptions, + CreateSequenceRNNOptions(builder_, /*time_major=*/false, + ActivationFunctionType_RELU) + .Union()); + BuildInterpreter({ + {batches_, sequence_len_, input_size_}, // input + {fw_units_, input_size_}, // fw_weights + {fw_units_, fw_units_}, // fw_recurrent_weights + {fw_units_}, // fw_bias + {bw_units_, input_size_}, // bw_weights + {bw_units_, bw_units_}, // bw_recurrent_weights + {bw_units_} // bw_bias + }); + } + + void 
SetFwBias(std::initializer_list f) { + PopulateTensor(fw_bias_, f); + } + + void SetBwBias(std::initializer_list f) { + PopulateTensor(bw_bias_, f); + } + + void SetFwWeights(std::initializer_list f) { + PopulateTensor(fw_weights_, f); + } + + void SetBwWeights(std::initializer_list f) { + PopulateTensor(bw_weights_, f); + } + + void SetFwRecurrentWeights(std::initializer_list f) { + PopulateTensor(fw_recurrent_weights_, f); + } + + void SetBwRecurrentWeights(std::initializer_list f) { + PopulateTensor(bw_recurrent_weights_, f); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + void ResetHiddenStates() { + const int fw_zero_buffer_size = fw_units_ * batches_; + std::unique_ptr fw_zero_buffer(new float[fw_zero_buffer_size]); + memset(fw_zero_buffer.get(), 0, fw_zero_buffer_size * sizeof(float)); + PopulateTensor(fw_hidden_state_, 0, fw_zero_buffer.get(), + fw_zero_buffer.get() + fw_zero_buffer_size); + const int bw_zero_buffer_size = bw_units_ * batches_; + std::unique_ptr bw_zero_buffer(new float[bw_zero_buffer_size]); + memset(bw_zero_buffer.get(), 0, bw_zero_buffer_size * sizeof(float)); + PopulateTensor(bw_hidden_state_, 0, bw_zero_buffer.get(), + bw_zero_buffer.get() + bw_zero_buffer_size); + } + + std::vector GetFwOutput() { return ExtractVector(fw_output_); } + std::vector GetBwOutput() { return ExtractVector(bw_output_); } + + int input_size() { return input_size_; } + int num_fw_units() { return fw_units_; } + int num_bw_units() { return bw_units_; } + int num_batches() { return batches_; } + int sequence_len() { return sequence_len_; } + + private: + int input_; + int fw_weights_; + int fw_recurrent_weights_; + int fw_bias_; + int fw_hidden_state_; + int fw_output_; + int bw_weights_; + int bw_recurrent_weights_; + int bw_bias_; + int bw_hidden_state_; + int bw_output_; + + int batches_; + int sequence_len_; 
+ int fw_units_; + int bw_units_; + int input_size_; +}; + +// TODO(mirkov): add another test which directly compares to TF once TOCO +// supports the conversion from dynamic_rnn with BasicRNNCell. +TEST(BidirectionalRNNOpTest, BlackBoxTest) { + BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, + /*fw_units=*/16, /*bw_units=*/16, + /*input_size=*/8); + rnn.SetFwWeights(weights); + rnn.SetBwWeights(weights); + rnn.SetFwBias(biases); + rnn.SetBwBias(biases); + rnn.SetFwRecurrentWeights(recurrent_weights); + rnn.SetBwRecurrentWeights(recurrent_weights); + + rnn.ResetHiddenStates(); + const int input_sequence_size = rnn.input_size() * rnn.sequence_len(); + float* batch_start = rnn_input; + float* batch_end = batch_start + input_sequence_size; + rnn.SetInput(0, batch_start, batch_end); + rnn.SetInput(input_sequence_size, batch_start, batch_end); + + rnn.Invoke(); + + float* golden_fw_start = rnn_golden_fw_output; + float* golden_fw_end = + golden_fw_start + rnn.num_fw_units() * rnn.sequence_len(); + std::vector fw_expected; + fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end); + fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end); + EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected))); + + float* golden_bw_start = rnn_golden_bw_output; + float* golden_bw_end = + golden_bw_start + rnn.num_bw_units() * rnn.sequence_len(); + std::vector bw_expected; + bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end); + bw_expected.insert(bw_expected.end(), golden_bw_start, golden_bw_end); + EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected))); +} + +// Check that if the input sequence is reversed the outputs are the same just +// forward and backward are swapped (and reversed). 
+TEST(BidirectionalRNNOpTest, BlackBoxTestReverseInputs) { + BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, + /*fw_units=*/16, /*bw_units=*/16, + /*input_size=*/8); + rnn.SetFwWeights(weights); + rnn.SetBwWeights(weights); + rnn.SetFwBias(biases); + rnn.SetBwBias(biases); + rnn.SetFwRecurrentWeights(recurrent_weights); + rnn.SetBwRecurrentWeights(recurrent_weights); + + rnn.ResetHiddenStates(); + + // Reverse inputs in each batch: in_1, in_2,..., in_k is inserted in the + // following order: [in_k,..., in_2, in_1, in_k,...,in_2, in_1]. + for (int i = 0; i < rnn.sequence_len(); i++) { + float* batch_start = rnn_input + i * rnn.input_size(); + float* batch_end = batch_start + rnn.input_size(); + const int reverse_idx = rnn.sequence_len() - i - 1; + rnn.SetInput(reverse_idx * rnn.input_size(), batch_start, batch_end); + rnn.SetInput((rnn.sequence_len() + reverse_idx) * rnn.input_size(), + batch_start, batch_end); + } + + rnn.Invoke(); + + // The forward and backward outputs are swapped. + std::vector fw_expected; // consider using std::deque instead. 
+ for (int i = 0; i < rnn.sequence_len(); i++) { + float* golden_fw_start = rnn_golden_bw_output + i * rnn.num_fw_units(); + float* golden_fw_end = golden_fw_start + rnn.num_fw_units(); + fw_expected.insert(fw_expected.begin(), golden_fw_start, golden_fw_end); + } + fw_expected.insert(fw_expected.end(), fw_expected.begin(), fw_expected.end()); + EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected))); + + std::vector bw_expected; + for (int i = 0; i < rnn.sequence_len(); i++) { + float* golden_bw_start = rnn_golden_fw_output + i * rnn.num_bw_units(); + float* golden_bw_end = golden_bw_start + rnn.num_bw_units(); + bw_expected.insert(bw_expected.begin(), golden_bw_start, golden_bw_end); + } + bw_expected.insert(bw_expected.end(), bw_expected.begin(), bw_expected.end()); + EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected))); +} + +// Tests an end-to-end neural network with a Bidirectional RNN followed by a +// DNN that aggregates the outputs from the two sequences. 
+TEST(BidirectionalRNNOpTest, EndToEndTest) { + BidirectionalRNNOpModel rnn(/*batches=*/1, /*sequence_len=*/4, + /*fw_units=*/16, /*bw_units=*/16, + /*input_size=*/8); + const int output_size = 4; + float dnn_weights[] = { + -0.5782342, -0.052212059, 0.73036242, -0.81216097, -0.80088139, + -0.23420811, -0.39647382, 0.31423986, 0.61819065, -0.73659575, + -0.89698344, -0.8931554, -0.0845688, 0.5617367, 0.38415289, + -0.11487955, -0.7617774, 0.17927337, 0.15726972, 0.059798479, + 0.19009054, -0.27616632, -0.39142907, 0.77744663, -0.046830714, + -0.6603595, 0.21945822, 0.051494241, 0.23785079, 0.19239247, + -0.53268754, 0.65961659, -0.85981959, -0.80232513, 0.84745562, + -0.66070104, -0.036533296, -0.54901814, 0.65353882, -0.41834265, + -0.28561389, 0.75655544, -0.31149811, 0.62981737, 0.31829214, + -0.92734522, -0.48506218, 0.55651462, 0.25192821, 0.67220747, + -0.3836869, -0.55798125, -0.60395885, 0.22488403, -0.78053463, + 0.3492105, 0.56452453, 0.4389236, -0.59929526, -0.19762468, + -0.36868393, -0.13198286, -0.53800809, -0.22850353}; + + std::initializer_list dnn_biases = { + 0.29177809, -0.98799044, 0.065919638, 0.68781924}; + + rnn.SetFwWeights(weights); + rnn.SetBwWeights(weights); + rnn.SetFwBias(biases); + rnn.SetBwBias(biases); + rnn.SetFwRecurrentWeights(recurrent_weights); + rnn.SetBwRecurrentWeights(recurrent_weights); + + rnn.ResetHiddenStates(); + + const int input_sequence_size = rnn.input_size() * rnn.sequence_len(); + const int output_sequence_size = output_size * rnn.sequence_len(); + const int num_examples = 64; + for (int k = 0; k < num_examples; k++) { + float* batch_start = endtoend_input + k * input_sequence_size; + float* batch_end = batch_start + input_sequence_size; + rnn.SetInput(0, batch_start, batch_end); + + rnn.Invoke(); + + std::vector fw_output = rnn.GetFwOutput(); + std::vector bw_output = rnn.GetBwOutput(); + EXPECT_EQ(fw_output.size(), bw_output.size()); + + std::transform(fw_output.begin(), fw_output.end(), bw_output.begin(), + 
fw_output.begin(), std::plus()); + + std::vector sequence_result; + for (int s = 0; s < rnn.sequence_len(); s++) { + const float* rnn_output = fw_output.data() + s * rnn.num_fw_units(); + std::vector results(dnn_biases); + for (int i = 0; i < output_size; i++) { + for (int j = 0; j < rnn.num_fw_units(); j++) { + results[i] += *(rnn_output + j) * dnn_weights[output_size * j + i]; + } + } + sequence_result.insert(sequence_result.end(), results.begin(), + results.end()); + } + + float* golden_start = golden_endtoend_output + k * output_sequence_size; + float* golden_end = golden_start + output_sequence_size; + + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + EXPECT_THAT(sequence_result, ElementsAreArray(ArrayFloatNear(expected))); + } +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/concatenation.cc b/tensorflow/contrib/lite/kernels/concatenation.cc index 9e7a1233dac0f3cd02dc386f9d194597f38ca3b8..a619ada86af64c299f8e518a7493db20f1011a50 100644 --- a/tensorflow/contrib/lite/kernels/concatenation.cc +++ b/tensorflow/contrib/lite/kernels/concatenation.cc @@ -49,6 +49,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // dimensions except 'axis' must be equal. 
TfLiteTensor* t0 = &context->tensors[node->inputs->data[0]]; TfLiteType input_type = t0->type; + if (axis < 0) axis += t0->dims->size; TF_LITE_ENSURE(context, axis >= 0); TF_LITE_ENSURE(context, axis < t0->dims->size); @@ -95,53 +96,22 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { return context->ResizeTensor(context, output, output_size); } -template -class VectorOfInputs { - public: - VectorOfInputs(const TfLiteContext& context, const TfLiteIntArray& inputs) { - int num_inputs = inputs.size; - - all_data_.reserve(num_inputs); - all_dims_.reserve(num_inputs); - all_dims_ptr_.reserve(num_inputs); - - for (int i = 0; i < num_inputs; ++i) { - TfLiteTensor* input = &context.tensors[inputs.data[i]]; - all_data_.push_back(GetTensorData(input)); - all_dims_.push_back(GetTensorDims(input)); - } - - // Taking the pointer from inside a std::vector is only OK if the vector is - // never modified, so we populate all_dims in the previous loop and then we - // are free to grab iterators here. - for (int i = 0; i < num_inputs; ++i) { - all_dims_ptr_.push_back(&all_dims_[i]); - } - } - const T* const* data() const { return all_data_.data(); } - const Dims<4>* const* dims() const { return all_dims_ptr_.data(); } - - private: - std::vector all_data_; - std::vector> all_dims_; - std::vector*> all_dims_ptr_; -}; - template TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); - + int axis = params->axis; TfLiteTensor* output = &context->tensors[node->outputs->data[0]]; + if (axis < 0) axis += output->dims->size; // TODO(ahentz): Creating 'all_inputs' below is not very efficient. We should // allocate and populate these during Prepare(). // TODO(ycling): Activation function parameter is ignored. For now we dont have // a model with a Concatenation with fused activation function. 
#define TF_LITE_CONCATENATION(type, scalar) \ - VectorOfInputs all_inputs(*context, *node->inputs); \ + VectorOfTensors all_inputs(*context, *node->inputs); \ type::Concatenation( \ - RemapDim(NumDimensions(output), params->axis), all_inputs.data(), \ + RemapDim(NumDimensions(output), axis), all_inputs.data(), \ all_inputs.dims(), node->inputs->size, GetTensorData(output), \ GetTensorDims(output)) diff --git a/tensorflow/contrib/lite/kernels/concatenation_test.cc b/tensorflow/contrib/lite/kernels/concatenation_test.cc index 94e5b2acdcabeedb4652baa1a008b22bf6bc8433..ba1ffc5f8423b9626c9c8e2a1086ea0dcca43f50 100644 --- a/tensorflow/contrib/lite/kernels/concatenation_test.cc +++ b/tensorflow/contrib/lite/kernels/concatenation_test.cc @@ -94,7 +94,7 @@ TEST(ConcatenationOpTest, TwoDimensionalOneInput) { EXPECT_THAT(m0.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); } -TEST(ConcatenationOpTest, TwoInputsTwoAxis) { +TEST(ConcatenationOpTest, TwoInputsTwoAxesNegativeAxes) { // We will concatenate two tensors along different dimensions. 
auto tensor0 = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; auto tensor1 = {7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f}; @@ -107,6 +107,14 @@ TEST(ConcatenationOpTest, TwoInputsTwoAxis) { EXPECT_THAT(m0.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12})); + ConcatenationOpModel m0_negative({TensorType_FLOAT32, {2, 3}}, /*axis=*/-2, + /*num_inputs=*/2); + m0_negative.SetInput(0, tensor0); + m0_negative.SetInput(1, tensor1); + m0_negative.Invoke(); + EXPECT_THAT(m0_negative.GetOutput(), + ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12})); + ConcatenationOpModel m1({TensorType_FLOAT32, {2, 3}}, /*axis=*/1, /*num_inputs=*/2); m1.SetInput(0, tensor0); @@ -114,6 +122,14 @@ TEST(ConcatenationOpTest, TwoInputsTwoAxis) { m1.Invoke(); EXPECT_THAT(m1.GetOutput(), ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12})); + + ConcatenationOpModel m1_negative({TensorType_FLOAT32, {2, 3}}, /*axis=*/-1, + /*num_inputs=*/2); + m1_negative.SetInput(0, tensor0); + m1_negative.SetInput(1, tensor1); + m1_negative.Invoke(); + EXPECT_THAT(m1_negative.GetOutput(), + ElementsAreArray({1, 2, 3, 7, 8, 9, 4, 5, 6, 10, 11, 12})); } TEST(ConcatenationOpTest, FourInputs) { @@ -156,7 +172,7 @@ TEST(ConcatenationOpTest, FourInputsQuantized) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/conv.cc b/tensorflow/contrib/lite/kernels/conv.cc index c75c04baeac2ce53c6261d677dca8d72fafa0da5..66d2c04bba4a164bbcdcf4b1a097d9aac0b3aeeb 100644 --- a/tensorflow/contrib/lite/kernels/conv.cc +++ b/tensorflow/contrib/lite/kernels/conv.cc @@ -24,6 +24,7 @@ limitations under the License. 
#include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/context.h" #include "tensorflow/contrib/lite/kernels/gemm_support.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" #include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" @@ -38,11 +39,16 @@ namespace ops { namespace builtin { namespace conv { -// This file has three implementation of Conv. +// This file has 4 implementation of Conv. enum KernelType { kReference, kGenericOptimized, // Neon-free - kNeonOptimized, + kMultithreadOptimized, + // The kernel uses use CBLAS interface for matrix multiplication. + // It's fast when an optimized CBLAS implementation is available (e.g. Apple + // Accelerate Framework), and it's slow when falling back to naive + // implementation. + kCblasOptimized, }; struct OpData { @@ -265,10 +271,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { free(hwcn_weights->data.raw); hwcn_weights->data.raw = nullptr; } + + // Note that hwcn_weights_status is a kTfLiteDynamic tensor, and + // ResizeTensor will actually allocate space for it. The would be more + // efficient if we placed hwcn_weights_status in the persistent arena. auto hwcn_weights_status = context->ResizeTensor(context, hwcn_weights, hwcn_weights_size); if (hwcn_weights_status != kTfLiteOk) return hwcn_weights_status; - hwcn_weights->data.raw = static_cast(malloc(hwcn_weights->bytes)); // TODO(petewarden): If Resize() is called when the size hasn't actually // changed, this will do extra redundant work. 
@@ -290,26 +299,34 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node, auto filter_offset = -filter->params.zero_point; auto output_offset = output->params.zero_point; - if (kernel_type == kReference) { - reference_ops::Conv( - GetTensorData(input), GetTensorDims(input), input_offset, - GetTensorData(filter), GetTensorDims(filter), filter_offset, - GetTensorData(bias), GetTensorDims(bias), params->stride_width, - params->stride_height, data->padding.width, data->padding.height, - output_offset, data->output_multiplier, data->output_shift, - data->output_activation_min, data->output_activation_max, - GetTensorData(output), GetTensorDims(output), - GetTensorData(im2col), GetTensorDims(im2col), gemm_context); - } else { - optimized_ops::Conv( - GetTensorData(input), GetTensorDims(input), input_offset, - GetTensorData(filter), GetTensorDims(filter), filter_offset, - GetTensorData(bias), GetTensorDims(bias), params->stride_width, - params->stride_height, data->padding.width, data->padding.height, - output_offset, data->output_multiplier, data->output_shift, - data->output_activation_min, data->output_activation_max, - GetTensorData(output), GetTensorDims(output), - GetTensorData(im2col), GetTensorDims(im2col), gemm_context); + switch (kernel_type) { + case kReference: + reference_ops::Conv( + GetTensorData(input), GetTensorDims(input), input_offset, + GetTensorData(filter), GetTensorDims(filter), filter_offset, + GetTensorData(bias), GetTensorDims(bias), + params->stride_width, params->stride_height, data->padding.width, + data->padding.height, output_offset, data->output_multiplier, + data->output_shift, data->output_activation_min, + data->output_activation_max, GetTensorData(output), + GetTensorDims(output), GetTensorData(im2col), + GetTensorDims(im2col), gemm_context); + break; + case kGenericOptimized: + case kMultithreadOptimized: + case kCblasOptimized: + // There is only one optimized implementation for Quantized Conv. 
+ optimized_ops::Conv( + GetTensorData(input), GetTensorDims(input), input_offset, + GetTensorData(filter), GetTensorDims(filter), filter_offset, + GetTensorData(bias), GetTensorDims(bias), + params->stride_width, params->stride_height, data->padding.width, + data->padding.height, output_offset, data->output_multiplier, + data->output_shift, data->output_activation_min, + data->output_activation_max, GetTensorData(output), + GetTensorDims(output), GetTensorData(im2col), + GetTensorDims(im2col), gemm_context); + break; } } @@ -322,30 +339,57 @@ void EvalFloat(TfLiteContext* context, TfLiteNode* node, CalculateActivationRangeFloat(params->activation, &output_activation_min, &output_activation_max); - const float* filter_data; - if (data->need_hwcn_weights) { - filter_data = GetTensorData(hwcn_weights); - } else { - filter_data = GetTensorData(filter); - } - - if (kernel_type == kReference) { - reference_ops::Conv( - GetTensorData(input), GetTensorDims(input), filter_data, - GetTensorDims(filter), GetTensorData(bias), GetTensorDims(bias), - params->stride_width, params->stride_height, data->padding.width, - data->padding.height, output_activation_min, output_activation_max, - GetTensorData(output), GetTensorDims(output), - GetTensorData(im2col), GetTensorDims(im2col)); - } else { - multithreaded_ops::Conv( - GetTensorData(input), GetTensorDims(input), filter_data, - GetTensorDims(filter), GetTensorData(bias), GetTensorDims(bias), - params->stride_width, params->stride_height, data->padding.width, - data->padding.height, params->padding, output_activation_min, - output_activation_max, GetTensorData(output), - GetTensorDims(output), GetTensorData(im2col), - GetTensorDims(im2col)); + switch (kernel_type) { + case kReference: { + reference_ops::Conv(GetTensorData(input), GetTensorDims(input), + GetTensorData(filter), GetTensorDims(filter), + GetTensorData(bias), GetTensorDims(bias), + params->stride_width, params->stride_height, + data->padding.width, 
data->padding.height, + output_activation_min, output_activation_max, + GetTensorData(output), GetTensorDims(output), + GetTensorData(im2col), GetTensorDims(im2col)); + break; + } + case kGenericOptimized: { + optimized_ops::Conv(GetTensorData(input), GetTensorDims(input), + GetTensorData(filter), GetTensorDims(filter), + GetTensorData(bias), GetTensorDims(bias), + params->stride_width, params->stride_height, + data->padding.width, data->padding.height, + output_activation_min, output_activation_max, + GetTensorData(output), GetTensorDims(output), + GetTensorData(im2col), GetTensorDims(im2col)); + break; + } + case kMultithreadOptimized: { + const float* filter_data; + if (data->need_hwcn_weights) { + filter_data = GetTensorData(hwcn_weights); + } else { + filter_data = GetTensorData(filter); + } + multithreaded_ops::Conv( + GetTensorData(input), GetTensorDims(input), filter_data, + GetTensorDims(filter), GetTensorData(bias), + GetTensorDims(bias), params->stride_width, params->stride_height, + data->padding.width, data->padding.height, params->padding, + output_activation_min, output_activation_max, + GetTensorData(output), GetTensorDims(output), + GetTensorData(im2col), GetTensorDims(im2col)); + break; + } + case kCblasOptimized: { + cblas_ops::Conv(GetTensorData(input), GetTensorDims(input), + GetTensorData(filter), GetTensorDims(filter), + GetTensorData(bias), GetTensorDims(bias), + params->stride_width, params->stride_height, + data->padding.width, data->padding.height, + output_activation_min, output_activation_max, + GetTensorData(output), GetTensorDims(output), + GetTensorData(im2col), GetTensorDims(im2col)); + break; + } } } @@ -406,17 +450,23 @@ TfLiteRegistration* Register_CONVOLUTION_GENERIC_OPT() { return &r; } -TfLiteRegistration* Register_CONVOLUTION_NEON_OPT() { +TfLiteRegistration* Register_CONVOLUTION_MULTITHREADED_OPT() { + static TfLiteRegistration r = {conv::Init, conv::Free, conv::Prepare, + conv::Eval}; + return &r; +} + +TfLiteRegistration* 
Register_CONVOLUTION_CBLAS_OPT() { static TfLiteRegistration r = {conv::Init, conv::Free, conv::Prepare, - conv::Eval}; + conv::Eval}; return &r; } TfLiteRegistration* Register_CONV_2D() { -#ifdef USE_NEON - return Register_CONVOLUTION_NEON_OPT(); +#ifdef TFLITE_USE_APPLE_ACCELERATE_FOR_CONV + return Register_CONVOLUTION_CBLAS_OPT(); #else - return Register_CONVOLUTION_GENERIC_OPT(); + return Register_CONVOLUTION_MULTITHREADED_OPT(); #endif } diff --git a/tensorflow/contrib/lite/kernels/conv_test.cc b/tensorflow/contrib/lite/kernels/conv_test.cc index 18d7a31d594efb6a05fe7292a0194ea17599a65b..d2393c3c97bb9516e2b8a6c8ae037dc0dfdfe64b 100644 --- a/tensorflow/contrib/lite/kernels/conv_test.cc +++ b/tensorflow/contrib/lite/kernels/conv_test.cc @@ -15,12 +15,25 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "tensorflow/contrib/lite/interpreter.h" #include "tensorflow/contrib/lite/kernels/register.h" #include "tensorflow/contrib/lite/kernels/test_util.h" #include "tensorflow/contrib/lite/model.h" namespace tflite { + +namespace ops { +namespace builtin { + +TfLiteRegistration* Register_CONVOLUTION_REF(); +TfLiteRegistration* Register_CONVOLUTION_GENERIC_OPT(); +TfLiteRegistration* Register_CONVOLUTION_MULTITHREADED_OPT(); +TfLiteRegistration* Register_CONVOLUTION_CBLAS_OPT(); + +} // namespace builtin +} // namespace ops + namespace { using ::testing::ElementsAreArray; @@ -30,9 +43,9 @@ class BaseConvolutionOpModel : public SingleOpModel { // TODO(ahentz): Also test different activation types, bias, padding types, // stride values. 
BaseConvolutionOpModel( - const TensorData& input, const TensorData& filter, - const TensorData& output, int stride_width = 2, int stride_height = 2, - enum Padding padding = Padding_VALID, + TfLiteRegistration* registration, const TensorData& input, + const TensorData& filter, const TensorData& output, int stride_width = 2, + int stride_height = 2, enum Padding padding = Padding_VALID, enum ActivationFunctionType activation = ActivationFunctionType_NONE) { input_ = AddInput(input); filter_ = AddInput(filter); @@ -62,6 +75,8 @@ class BaseConvolutionOpModel : public SingleOpModel { stride_height, activation) .Union()); + resolver_ = absl::make_unique(BuiltinOperator_CONV_2D, + registration); BuildInterpreter({GetShape(input_), GetShape(filter_), GetShape(bias_)}); } @@ -83,12 +98,26 @@ class ConvolutionOpModel : public BaseConvolutionOpModel { void SetInput(std::initializer_list data) { PopulateTensor(input_, data); } - std::vector GetOutput() { return ExtractVector(output_); } }; -TEST(ConvolutionOpTest, SimpleTestFloat32) { - ConvolutionOpModel m({TensorType_FLOAT32, {2, 2, 4, 1}}, +const auto kKernelMap = new std::map({ + {"Reference", ops::builtin::Register_CONVOLUTION_REF()}, + {"GenericOptimized", ops::builtin::Register_CONVOLUTION_GENERIC_OPT()}, + {"MultithreadedOptimized", + ops::builtin::Register_CONVOLUTION_MULTITHREADED_OPT()}, + {"CblasOptimized", ops::builtin::Register_CONVOLUTION_CBLAS_OPT()}, +}); + +class ConvolutionOpTest : public SingleOpTest { + protected: + const std::map& GetKernelMap() override { + return *kKernelMap; + } +}; + +TEST_P(ConvolutionOpTest, SimpleTestFloat32) { + ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {2, 2, 4, 1}}, {TensorType_FLOAT32, {3, 2, 2, 1}}, {TensorType_FLOAT32, {}}); @@ -117,8 +146,8 @@ TEST(ConvolutionOpTest, SimpleTestFloat32) { })); } -TEST(ConvolutionOpTest, SimpleTestFloat32WithAnisotropicStrides) { - ConvolutionOpModel m({TensorType_FLOAT32, {1, 3, 6, 1}}, +TEST_P(ConvolutionOpTest, 
SimpleTestFloat32WithAnisotropicStrides) { + ConvolutionOpModel m(GetRegistration(), {TensorType_FLOAT32, {1, 3, 6, 1}}, {TensorType_FLOAT32, {1, 2, 2, 1}}, {TensorType_FLOAT32, {}}, /*stride_width=*/3, /*stride_height=*/1); @@ -139,7 +168,7 @@ TEST(ConvolutionOpTest, SimpleTestFloat32WithAnisotropicStrides) { })); } -TEST(ConvolutionOpTest, HandCalculatedFloat32) { +TEST_P(ConvolutionOpTest, HandCalculatedFloat32) { const int depth = 1; const int image_width = 4; const int image_height = 3; @@ -150,6 +179,7 @@ TEST(ConvolutionOpTest, HandCalculatedFloat32) { const int stride_height = 1; const Padding padding = Padding_SAME; ConvolutionOpModel m( + GetRegistration(), {TensorType_FLOAT32, {image_batch_count, image_height, image_width, depth}}, {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}}, @@ -192,7 +222,7 @@ TEST(ConvolutionOpTest, HandCalculatedFloat32) { 178, 187, 234, 261, 121})); } -TEST(ConvolutionOpTest, HandCalculatedWithBiasFloat32) { +TEST_P(ConvolutionOpTest, HandCalculatedWithBiasFloat32) { const int depth = 1; const int image_width = 4; const int image_height = 3; @@ -203,6 +233,7 @@ TEST(ConvolutionOpTest, HandCalculatedWithBiasFloat32) { const int stride_height = 1; const Padding padding = Padding_SAME; ConvolutionOpModel m( + GetRegistration(), {TensorType_FLOAT32, {image_batch_count, image_height, image_width, depth}}, {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}}, @@ -245,7 +276,7 @@ TEST(ConvolutionOpTest, HandCalculatedWithBiasFloat32) { 367, 188, 197, 244, 271, 131})); } -TEST(ConvolutionOpTest, HandCalculatedWithReluFloat32) { +TEST_P(ConvolutionOpTest, HandCalculatedWithReluFloat32) { const int depth = 1; const int image_width = 4; const int image_height = 3; @@ -256,6 +287,7 @@ TEST(ConvolutionOpTest, HandCalculatedWithReluFloat32) { const int stride_height = 1; const Padding padding = Padding_SAME; ConvolutionOpModel m( + GetRegistration(), {TensorType_FLOAT32, {image_batch_count, image_height, 
image_width, depth}}, {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}}, @@ -300,7 +332,7 @@ TEST(ConvolutionOpTest, HandCalculatedWithReluFloat32) { ElementsAreArray({0, 0, 0, 0, 35, 112, 157, 0, 0, 34, 61, 0})); } -TEST(ConvolutionOpTest, HandCalculatedValidFloat32) { +TEST_P(ConvolutionOpTest, HandCalculatedValidFloat32) { const int depth = 1; const int image_width = 4; const int image_height = 3; @@ -311,6 +343,7 @@ TEST(ConvolutionOpTest, HandCalculatedValidFloat32) { const int stride_height = 1; const Padding padding = Padding_VALID; ConvolutionOpModel m( + GetRegistration(), {TensorType_FLOAT32, {image_batch_count, image_height, image_width, depth}}, {TensorType_FLOAT32, {depth, filter_size, filter_size, filter_count}}, @@ -366,8 +399,9 @@ class QuantizedConvolutionOpModel : public BaseConvolutionOpModel { // In this tests we set the input and output scales so that the results // match exactly the 'non-quantized' version. -TEST(ConvolutionOpTest, SimpleTestQuantized) { - QuantizedConvolutionOpModel m({TensorType_UINT8, {2, 2, 4, 1}, -63.5, 64}, +TEST_P(ConvolutionOpTest, SimpleTestQuantized) { + QuantizedConvolutionOpModel m(GetRegistration(), + {TensorType_UINT8, {2, 2, 4, 1}, -63.5, 64}, {TensorType_UINT8, {3, 2, 2, 1}, -63.5, 64}, {TensorType_UINT8, {}, -127, 128}); m.SetInput({ @@ -405,8 +439,9 @@ TEST(ConvolutionOpTest, SimpleTestQuantized) { })); } -TEST(ConvolutionOpTest, SimpleTestQuantizedWithAnisotropicStrides) { - QuantizedConvolutionOpModel m({TensorType_UINT8, {1, 3, 6, 1}, -63.5, 64}, +TEST_P(ConvolutionOpTest, SimpleTestQuantizedWithAnisotropicStrides) { + QuantizedConvolutionOpModel m(GetRegistration(), + {TensorType_UINT8, {1, 3, 6, 1}, -63.5, 64}, {TensorType_UINT8, {1, 2, 2, 1}, -63.5, 64}, {TensorType_UINT8, {}, -127, 128}, /*stride_width=*/3, /*stride_height=*/1); @@ -430,11 +465,16 @@ TEST(ConvolutionOpTest, SimpleTestQuantizedWithAnisotropicStrides) { 167, 93, // })); } + +INSTANTIATE_TEST_CASE_P( + 
ConvolutionOpTest, ConvolutionOpTest, + ::testing::ValuesIn(SingleOpTest::GetKernelTags(*kKernelMap))); + } // namespace } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc index 39227b2811e2be719a0be77f89793bcf9366d513..1439c8bce14ad127ed68dc54991aed8b8bb39383 100644 --- a/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc +++ b/tensorflow/contrib/lite/kernels/depthwise_conv_test.cc @@ -180,7 +180,7 @@ TEST(QuantizedDepthwiseConvolutionOpTest, SimpleTestQuantized) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/div.cc b/tensorflow/contrib/lite/kernels/div.cc new file mode 100644 index 0000000000000000000000000000000000000000..44bd0dc85d50c98ec6b6888e05064a8f2e2731c0 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/div.cc @@ -0,0 +1,129 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace div { + +// This file has three implementation of Div. +enum KernelType { + kReference, + kGenericOptimized, // Neon-free + kNeonOptimized, +}; + +constexpr int kInputTensor1 = 0; +constexpr int kInputTensor2 = 1; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE_EQ(context, NumDimensions(input1), NumDimensions(input2)); + for (int i = 0; i < NumDimensions(input1); ++i) { + TF_LITE_ENSURE_EQ(context, SizeOfDimension(input1, i), + SizeOfDimension(input2, i)); + } + + TF_LITE_ENSURE_EQ(context, input1->type, output->type); + TF_LITE_ENSURE_EQ(context, input2->type, output->type); + + TfLiteIntArray* output_size = TfLiteIntArrayCopy(input1->dims); + return context->ResizeTensor(context, output, output_size); +} + +template +void EvalDivFloat(TfLiteContext* context, TfLiteNode* node, + TfLiteDivParams* params, TfLiteTensor* input1, + TfLiteTensor* input2, TfLiteTensor* output) { + float output_activation_min, output_activation_max; + 
CalculateActivationRangeFloat(params->activation, &output_activation_min, + &output_activation_max); +#define TF_LITE_DIV(type) \ + type::Div(GetTensorData(input1), GetTensorDims(input1), \ + GetTensorData(input2), GetTensorDims(input2), \ + output_activation_min, output_activation_max, \ + GetTensorData(output), GetTensorDims(output)) + if (kernel_type == kReference) { + TF_LITE_DIV(reference_ops); + } else { + TF_LITE_DIV(optimized_ops); + } +#undef TF_LITE_DIV +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + if (output->type == kTfLiteFloat32) { + EvalDivFloat(context, node, params, input1, input2, output); + } else { + context->ReportError(context, "Inputs and outputs not all float types."); + return kTfLiteError; + } + + return kTfLiteOk; +} + +} // namespace div + +TfLiteRegistration* Register_DIV_REF() { + static TfLiteRegistration r = {nullptr, nullptr, div::Prepare, + div::Eval}; + return &r; +} + +TfLiteRegistration* Register_DIV_GENERIC_OPT() { + static TfLiteRegistration r = {nullptr, nullptr, div::Prepare, + div::Eval}; + return &r; +} + +TfLiteRegistration* Register_DIV_NEON_OPT() { + static TfLiteRegistration r = {nullptr, nullptr, div::Prepare, + div::Eval}; + return &r; +} + +TfLiteRegistration* Register_DIV() { +#ifdef USE_NEON + return Register_DIV_NEON_OPT(); +#else + return Register_DIV_GENERIC_OPT(); +#endif +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc index 69d9c5cc7dec13a65f1c5050f2f1c56812ad5aa1..ef2b5422253ea880a9ded4d3c0efc5cec07178a9 100644 --- 
a/tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc +++ b/tensorflow/contrib/lite/kernels/embedding_lookup_sparse_test.cc @@ -123,18 +123,16 @@ TEST(EmbeddingLookupOpTest, SimpleTestSqrtn) { [](int i, int j, int k) { return i + j / 10.0f + k / 100.0f; }); m.Invoke(); - EXPECT_THAT( - m.GetOutput(), - ElementsAreArray(ArrayFloatNear({ - 1.00, 1.01, 1.10, 1.11, 1.20, 1.21, // Row 1 - 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, // - - 6.00f / std::sqrt(20.0f), 6.06f / std::sqrt(20.0f), - 6.60f / std::sqrt(20.0f), 6.66f / std::sqrt(20.0f), - 7.20f / std::sqrt(20.0f), - 7.26f / - std::sqrt( - 20.0f), // 2 * Row 3 + 4 * Row 0, // 2 * Row 3 + 4 * Row 0 - }))); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({ + 1.00, 1.01, 1.10, 1.11, 1.20, 1.21, // Row 1 + 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, // - + 6.00f / std::sqrt(20.0f), 6.06f / std::sqrt(20.0f), + 6.60f / std::sqrt(20.0f), 6.66f / std::sqrt(20.0f), + 7.20f / std::sqrt(20.0f), + 7.26f / std::sqrt(20.0f), // 2 * Row 3 + 4 * Row 0, // 2 * + // Row 3 + 4 * Row 0 + }))); } TEST(EmbeddingLookupOpTest, Indices3DTest) { @@ -158,9 +156,7 @@ TEST(EmbeddingLookupOpTest, Indices3DTest) { } // namespace tflite int main(int argc, char** argv) { -#ifdef OS_LINUX - tflite::LogToStderr(); -#endif + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc index 8c030b06772ac0c6af34a45897f03ebc4637d4de..9b501878f196216a61568bfa36e6615f4dd07478 100644 --- a/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc +++ b/tensorflow/contrib/lite/kernels/embedding_lookup_test.cc @@ -88,7 +88,7 @@ TEST(EmbeddingLookupOpTest, SimpleTest) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git 
a/tensorflow/contrib/lite/kernels/exp.cc b/tensorflow/contrib/lite/kernels/exp.cc new file mode 100644 index 0000000000000000000000000000000000000000..a9e79b742dc2c80ce4ed9a3aa786814265dcb660 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/exp.cc @@ -0,0 +1,92 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace exp { + +// This file has reference implementation of Exp. 
+enum KernelType { + kReference, +}; + +struct ExpContext { + ExpContext(TfLiteContext* context, TfLiteNode* node) { + input = GetInput(context, node, 0); + output = GetOutput(context, node, 0); + } + TfLiteTensor* input; + TfLiteTensor* output; +}; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + ExpContext op_context(context, node); + TfLiteIntArray* output_dims = TfLiteIntArrayCopy(op_context.input->dims); + op_context.output->type = op_context.input->type; + return context->ResizeTensor(context, op_context.output, output_dims); +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + ExpContext op_context(context, node); + +#define TF_LITE_EXP(kernel_type, data_type) \ + kernel_type::Exp(GetTensorData(op_context.input), \ + NumElements(op_context.input), \ + GetTensorData(op_context.output)) + + // TODO(kanlig): supports half, bfloat16, float64, complex64, and complex128. + if (kernel_type == kReference) { + switch (op_context.input->type) { + case kTfLiteFloat32: + TF_LITE_EXP(reference_ops, float); + break; + default: + context->ReportError(context, + "Type %d is currently not supported by Exp.", + op_context.input->type); + return kTfLiteError; + } + } +#undef TF_LITE_EXP + return kTfLiteOk; +} + +} // namespace exp + +TfLiteRegistration* Register_EXP_REF() { + static TfLiteRegistration r = {nullptr, nullptr, exp::Prepare, + exp::Eval}; + return &r; +} + +// TODO(kanlig): add optimized implementation of Exp. 
+TfLiteRegistration* Register_EXP() { return Register_EXP_REF(); } + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/exp_test.cc b/tensorflow/contrib/lite/kernels/exp_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..eed67369a1f30e57cd29a3975a899db41938def0 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/exp_test.cc @@ -0,0 +1,70 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class ExpOpModel : public SingleOpModel { + public: + ExpOpModel(const TensorData& input, const TensorType& output) { + input_ = AddInput(input); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_EXP, BuiltinOptions_ExpOptions, + CreateExpOptions(builder_).Union()); + BuildInterpreter({GetShape(input_)}); + } + + template + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + template + std::vector GetOutput() { + return ExtractVector(output_); + } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + protected: + int input_; + int output_; +}; + +TEST(ExpOpTest, FloatTest) { + std::initializer_list data = {1.0, 0.0, -1.0, 1.0, 1.0, -1.0}; + ExpOpModel m({TensorType_FLOAT32, {3, 1, 2}}, TensorType_FLOAT32); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 2})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {2.71828, 1, 0.367879, 2.71828, 2.71828, 0.367879}))); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/fully_connected_test.cc b/tensorflow/contrib/lite/kernels/fully_connected_test.cc index 112e3f1ba01a428023eea5ee8410fb76c1d67de6..a0f766c4f4580d7679275c0b63aa200410fcb5ad 100644 --- a/tensorflow/contrib/lite/kernels/fully_connected_test.cc +++ b/tensorflow/contrib/lite/kernels/fully_connected_test.cc @@ -370,8 +370,7 @@ TEST(FullyConnectedOpTest, BlackBoxTest) { } // namespace tflite int main(int argc, 
char** argv) { - // On Linux, add: tflite::LogToStderr(); - tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/gather.cc b/tensorflow/contrib/lite/kernels/gather.cc new file mode 100644 index 0000000000000000000000000000000000000000..0e4187d1eac64636a2e2b25e9a1cc45c3a4da557 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/gather.cc @@ -0,0 +1,131 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/contrib/lite/string_util.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace gather { +constexpr int kInputTensor = 0; +constexpr int kInputPositions = 1; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + const auto* params = + reinterpret_cast(node->builtin_data); + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* positions = GetInput(context, node, kInputPositions); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + // Only INT32 positions are supported. + TF_LITE_ENSURE_EQ(context, positions->type, kTfLiteInt32); + // Check that input and output types match. + TF_LITE_ENSURE_EQ(context, input->type, output->type); + // TODO(mgubin): only 0D or 1D positions are currently supported. + TF_LITE_ENSURE(context, NumDimensions(positions) <= 1); + // TODO(mgubin): Only default axis == 0 is supported. + TF_LITE_ENSURE_EQ(context, params->axis, 0); + // Check conditions for different types. + switch (input->type) { + case kTfLiteFloat32: + case kTfLiteUInt8: + case kTfLiteInt32: { + // Fully supported by reference_ops::Gather. + } break; + + case kTfLiteString: { + // Only 1D input is supported. 
+ TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1); + } break; + default: + context->ReportError(context, + "Only float32 and string types are supported"); + return kTfLiteError; + } + const int num_dimensions = + NumDimensions(input) + NumDimensions(positions) - 1; + TF_LITE_ENSURE(context, params->axis <= num_dimensions); + TfLiteIntArray* output_shape = TfLiteIntArrayCreate(num_dimensions); + int output_index = 0; + for (int i = 0; i < params->axis; ++i) { + output_shape->data[output_index++] = input->dims->data[i]; + } + for (int i = 0; i < positions->dims->size; ++i) { + output_shape->data[output_index++] = positions->dims->data[i]; + } + for (int i = params->axis + 1; i < input->dims->size; ++i) { + output_shape->data[output_index++] = input->dims->data[i]; + } + return context->ResizeTensor(context, output, output_shape); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* positions = GetInput(context, node, kInputPositions); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + const int input_rank = NumDimensions(input); +#define TF_LITE_GATHER(data_type, index_type) \ + optimized_ops::Gather( \ + GetTensorData(input), GetTensorDims(input), input_rank, \ + GetTensorData(positions), GetTensorDims(positions), \ + GetTensorData(output), GetTensorDims(output)); + switch (input->type) { + case kTfLiteFloat32: + TF_LITE_GATHER(float, int32_t); + break; + case kTfLiteUInt8: + TF_LITE_GATHER(uint8_t, int32_t); + break; + case kTfLiteInt32: + TF_LITE_GATHER(int32_t, int32_t); + break; + case kTfLiteString: { + DynamicBuffer buffer; + const int32* indexes = positions->data.i32; + const int num_strings = GetStringCount(input); + for (int i = 0; i < positions->dims->data[0]; ++i) { + const int pos = indexes[i]; + TF_LITE_ENSURE(context, pos < num_strings); + const auto string_ref = GetString(input, pos); + buffer.AddString(string_ref.str, string_ref.len); + 
} + buffer.WriteToTensor(output); + } break; + default: + return kTfLiteError; + } +#undef TF_LITE_GATHER + return kTfLiteOk; +} +} // namespace gather + +TfLiteRegistration* Register_GATHER() { + static TfLiteRegistration r = {nullptr, nullptr, gather::Prepare, + gather::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/gather_test.cc b/tensorflow/contrib/lite/kernels/gather_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..cdadbeda1884ba0186846826dd16be6ff69878d9 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/gather_test.cc @@ -0,0 +1,141 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class GatherOpModel : public SingleOpModel { + public: + GatherOpModel(std::initializer_list input_shape, TensorType input_type, + std::initializer_list positions_shape) { + input_ = AddInput(input_type); + positions_ = AddInput(TensorType_INT32); + output_ = AddOutput(input_type); + SetBuiltinOp(BuiltinOperator_GATHER, BuiltinOptions_GatherOptions, + CreateGatherOptions(builder_, 0).Union()); + BuildInterpreter({input_shape, positions_shape}); + } + + void SetInputFloat(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInputUint8(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInput(std::initializer_list data) { + PopulateStringTensor(input_, data); + } + + void SetPositions(std::initializer_list data) { + PopulateTensor(positions_, data); + } + + std::vector GetOutputFloat() { return ExtractVector(output_); } + std::vector GetOutputUint8() { + return ExtractVector(output_); + } + std::vector GetOutputString() { + return ExtractVector(output_); + } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + protected: + int input_; + int positions_; + int output_; +}; + +TEST(GatherOpTest, Shuffle) { + GatherOpModel m({2, 2}, TensorType_FLOAT32, {2}); + m.SetInputFloat({-2.0, 0.2, 0.7, 0.8}); + m.SetPositions({1, 0}); + m.Invoke(); + EXPECT_THAT(m.GetOutputFloat(), + ElementsAreArray(ArrayFloatNear({0.7, 0.8, -2, 0.2}))); +} + +TEST(GatherOpTest, Test0DIndex) { + GatherOpModel m({2, 2}, TensorType_FLOAT32, {}); + m.SetInputFloat({-2.0, 0.2, 0.7, 0.8}); + m.SetPositions({1}); + 
m.Invoke(); + EXPECT_THAT(m.GetOutputFloat(), ElementsAreArray(ArrayFloatNear({0.7, 0.8}))); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); +} + +TEST(GatherOpTest, Test0DIndexWith0DResult) { + // 0D tensor is special case in current TFLite. Test it once to make sure + // existing workarounds are fine with it. + GatherOpModel m({3}, TensorType_FLOAT32, {}); + m.SetInputFloat({1.0, 2.0, 3.0}); + m.SetPositions({1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputFloat(), ElementsAreArray(ArrayFloatNear({2.0}))); + EXPECT_TRUE(m.GetOutputShape().empty()); +} + +TEST(FloatGatherOpTest, Duplicate) { + GatherOpModel m({1, 2, 2}, TensorType_FLOAT32, {2}); + m.SetInputFloat({-2.0, 0.2, 0.7, 0.8}); + m.SetPositions({0, 0}); + m.Invoke(); + EXPECT_THAT( + m.GetOutputFloat(), + ElementsAreArray(ArrayFloatNear({-2, 0.2, 0.7, 0.8, -2, 0.2, 0.7, 0.8}))); +} + +TEST(FloatGatherOpTest, Slice) { + GatherOpModel m({4, 1}, TensorType_FLOAT32, {2}); + m.SetInputFloat({-2.0, 0.2, 0.7, 0.8}); + m.SetPositions({1, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutputFloat(), ElementsAreArray(ArrayFloatNear({0.2, 0.8}))); +} + +TEST(Uint8tGatherOpTest, Shuffle) { + GatherOpModel m({2, 2}, TensorType_UINT8, {2}); + m.SetInputUint8({133, 134, 14, 15}); + m.SetPositions({1, 0}); + m.Invoke(); + + EXPECT_THAT(m.GetOutputUint8(), ElementsAreArray({14, 15, 133, 134})); +} + +TEST(GatherOpTest, SimpleString) { + GatherOpModel m({3}, TensorType_STRING, {2}); + m.SetInput({"A", "B", "C"}); + m.SetPositions({0, 2}); + m.Invoke(); + ASSERT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutputString(), ElementsAreArray({"A", "C"})); +} +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/gemm_support.h b/tensorflow/contrib/lite/kernels/gemm_support.h index 
b531959ffb143c774ee715743480b03ebfbdc114..466781cbcecc7fb851d9078c450cc6c12364d2bb 100644 --- a/tensorflow/contrib/lite/kernels/gemm_support.h +++ b/tensorflow/contrib/lite/kernels/gemm_support.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_ #include "public/gemmlowp.h" #include "tensorflow/contrib/lite/context.h" @@ -51,4 +51,4 @@ void SetMaxNumThreads(TfLiteContext* context, int num_threads); } // namespace gemm_support } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_GEMM_SUPPORT_H_ diff --git a/tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc b/tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc index 916a23225e2ad3c5645a7809169677a7a8880535..ba0ed5ce06392613238b757308dddc2b22e7eb30 100644 --- a/tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc +++ b/tensorflow/contrib/lite/kernels/hashtable_lookup_test.cc @@ -116,7 +116,10 @@ TEST(HashtableLookupOpTest, Test2DInput) { 1.0, 1.1, // 1-st item }))); EXPECT_THAT(m.GetHit(), ElementsAreArray({ - 1, 0, 1, 1, + 1, + 0, + 1, + 1, })); } @@ -170,7 +173,7 @@ TEST(HashtableLookupOpTest, TestString) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/internal/BUILD b/tensorflow/contrib/lite/kernels/internal/BUILD index 
288534099b9e090ce0c223a401b4152ca6ffb61f..f47fb04cbaa688b75e763ff9d3cb7df44ac3f166 100644 --- a/tensorflow/contrib/lite/kernels/internal/BUILD +++ b/tensorflow/contrib/lite/kernels/internal/BUILD @@ -124,6 +124,20 @@ config_setting( }, ) +config_setting( + name = "darwin_x86_64", + values = { + "cpu": "darwin_x86_64", + }, +) + +config_setting( + name = "freebsd", + values = { + "cpu": "freebsd", + }, +) + cc_library( name = "optimized_base", srcs = [], @@ -138,7 +152,7 @@ cc_library( ":types", ":round", "//third_party/eigen3", - "@gemmlowp//:gemmlowp", + "@gemmlowp", "//tensorflow/contrib/lite:builtin_op_data", ] + select({ ":haswell": tflite_deps_intel, @@ -147,6 +161,8 @@ cc_library( ":x86": tflite_deps_intel, ":x86_64": tflite_deps_intel, ":darwin": tflite_deps_intel, + ":darwin_x86_64": tflite_deps_intel, + ":freebsd": tflite_deps_intel, "//conditions:default": [], }), ) @@ -154,6 +170,8 @@ cc_library( cc_library( name = "optimized", hdrs = [ + "optimized/cblas_conv.h", + "optimized/cblas_reference.h", "optimized/eigen_spatial_convolutions.h", "optimized/eigen_tensor_reduced_instantiations_oss.h", "optimized/multithreaded_conv.h", @@ -215,7 +233,7 @@ cc_library( ":round", ":types", "//third_party/eigen3", - "@gemmlowp//:gemmlowp", + "@gemmlowp", "//tensorflow/contrib/lite:builtin_op_data", ] + select({ ":haswell": tflite_deps_intel, @@ -224,6 +242,8 @@ cc_library( ":x86": tflite_deps_intel, ":x86_64": tflite_deps_intel, ":darwin": tflite_deps_intel, + ":darwin_x86_64": tflite_deps_intel, + ":freebsd": tflite_deps_intel, "//conditions:default": [], }), ) @@ -258,6 +278,8 @@ cc_library( "optimized/neon_tensor_utils.cc", ], hdrs = [ + "common.h", + "optimized/cpu_check.h", "optimized/neon_tensor_utils.h", "optimized/tensor_utils_impl.h", ], @@ -265,8 +287,21 @@ cc_library( deps = [ ":cpu_check", ":portable_tensor_utils", + ":types", "//tensorflow/contrib/lite:builtin_op_data", "//tensorflow/contrib/lite/kernels:activation_functor", + "@arm_neon_2_x86_sse", + 
"@gemmlowp", + ], +) + +cc_library( + name = "kernel_utils", + srcs = ["kernel_utils.cc"], + hdrs = ["kernel_utils.h"], + deps = [ + ":tensor_utils", + "//tensorflow/contrib/lite:builtin_op_data", ], ) @@ -276,14 +311,21 @@ cc_library( "tensor_utils.cc", ], hdrs = [ + "common.h", + "compatibility.h", + "optimized/cpu_check.h", + "optimized/neon_tensor_utils.h", "optimized/tensor_utils_impl.h", "reference/portable_tensor_utils.h", "tensor_utils.h", + "types.h", ], copts = NEON_FLAGS_IF_APPLICABLE, deps = [ "//tensorflow/contrib/lite/kernels:activation_functor", "//tensorflow/contrib/lite:builtin_op_data", + "@arm_neon_2_x86_sse", + "@gemmlowp", ] + select({ ":arm": [ ":neon_tensor_utils", @@ -303,6 +345,21 @@ cc_library( ":ios_arm64": [ ":neon_tensor_utils", ], + ":ios_x86_64": [ + ":neon_tensor_utils", + ], + ":x86_64": [ + ":neon_tensor_utils", + ], + ":x86": [ + ":neon_tensor_utils", + ], + ":k8": [ + ":neon_tensor_utils", + ], + ":darwin": [ + ":neon_tensor_utils", + ], "//conditions:default": [ ":portable_tensor_utils", ], diff --git a/tensorflow/contrib/lite/kernels/internal/common.h b/tensorflow/contrib/lite/kernels/internal/common.h index 28f19a250629aec4d03aa71df57d31d8a5014e9f..18601df22c1894dea6ce51f46ba815cd12dab095 100644 --- a/tensorflow/contrib/lite/kernels/internal/common.h +++ b/tensorflow/contrib/lite/kernels/internal/common.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_ #ifndef ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK #ifdef GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK @@ -102,6 +102,17 @@ inline int32 MultiplyByQuantizedMultiplierGreaterThanOne( quantized_multiplier); } +inline int32 MultiplyByQuantizedMultiplier(int32 x, int32 quantized_multiplier, + int shift) { + using gemmlowp::RoundingDivideByPOT; + using gemmlowp::SaturatingRoundingDoublingHighMul; + int left_shift = shift > 0 ? shift : 0; + int right_shift = shift > 0 ? 0 : -shift; + return RoundingDivideByPOT(SaturatingRoundingDoublingHighMul( + x * (1 << left_shift), quantized_multiplier), + right_shift); +} + } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMMON_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/compatibility.h b/tensorflow/contrib/lite/kernels/internal/compatibility.h index 796a03566a4bf971294dd2375f590dfd20d600f7..51426bb1c584b82af7b1a2ffaf5a675a1dd9a6fd 100644 --- a/tensorflow/contrib/lite/kernels/internal/compatibility.h +++ b/tensorflow/contrib/lite/kernels/internal/compatibility.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_ #include #include @@ -27,6 +27,10 @@ limitations under the License. #define TFLITE_DCHECK_EQ(x, y) ((x) == (y)) ? (void)0 : assert(false) #endif +#ifndef TFLITE_DCHECK_NE +#define TFLITE_DCHECK_NE(x, y) ((x) != (y)) ? (void)0 : assert(false) +#endif + #ifndef TFLITE_DCHECK_GE #define TFLITE_DCHECK_GE(x, y) ((x) >= (y)) ? (void)0 : assert(false) #endif @@ -52,6 +56,10 @@ limitations under the License. #define TFLITE_CHECK_EQ(x, y) ((x) == (y)) ? (void)0 : abort() #endif +#ifndef TFLITE_CHECK_NE +#define TFLITE_CHECK_NE(x, y) ((x) != (y)) ? (void)0 : abort() +#endif + #ifndef TFLITE_CHECK_GE #define TFLITE_CHECK_GE(x, y) ((x) >= (y)) ? (void)0 : abort() #endif @@ -75,4 +83,4 @@ using uint16 = std::uint16_t; using int32 = std::int32_t; using uint32 = std::uint32_t; -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_COMPATIBILITY_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..510395126ce3785b1d44fec1e0eb994c29ff0db7 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.cc @@ -0,0 +1,44 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" + +namespace tflite { +namespace kernel_utils { + +void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr, + const float* recurrent_weights_ptr, const float* bias_ptr, + int input_size, int num_units, int batch_size, + TfLiteFusedActivation activation, + float* hidden_state_ptr_batch, float* output_ptr_batch) { + // Output = bias + tensor_utils::VectorBatchVectorAssign(bias_ptr, num_units, batch_size, + output_ptr_batch); + // Output += input * input_weights + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + input_weights_ptr, num_units, input_size, input_ptr_batch, batch_size, + output_ptr_batch, /*result_stride=*/1); + // Output += recurrent_weights * hidden_state + tensor_utils::MatrixBatchVectorMultiplyAccumulate( + recurrent_weights_ptr, num_units, num_units, hidden_state_ptr_batch, + batch_size, output_ptr_batch, /*result_stride=*/1); + // Output = activation(Output) and update hidden_state + tensor_utils::ApplyActivationToVector( + output_ptr_batch, num_units * batch_size, activation, output_ptr_batch); + tensor_utils::VectorBatchVectorAssign(output_ptr_batch, num_units, batch_size, + hidden_state_ptr_batch); +} + +} // namespace kernel_utils +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/internal/kernel_utils.h b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h new file mode 100644 index 
0000000000000000000000000000000000000000..9872d4500b862388ed4b96c97e3755f548e35d35 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/kernel_utils.h @@ -0,0 +1,40 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_ + +#include "tensorflow/contrib/lite/builtin_op_data.h" + +namespace tflite { +namespace kernel_utils { + +// Performs an RNN batch inference step for inputs specified by input_ptr_batch. +// The RNN cell is specified by the pointers to its input and recurrent weights, +// and biases, along with the input size, number of units, activation. +// +// The pointers to the hidden state and the output are updated as a result. +// +// The pointers with the suffix "_batch" point to data aligned in batch_major +// order, and each step processes batch_size many inputs from input_ptr_batch, +// and updates batch_size many outputs and hidden states. 
+void RnnBatchStep(const float* input_ptr_batch, const float* input_weights_ptr, + const float* recurrent_weights_ptr, const float* bias_ptr, + int input_size, int num_units, int batch_size, + TfLiteFusedActivation activation, + float* hidden_state_ptr_batch, float* output_ptr_batch); + +} // namespace kernel_utils +} // namespace tflite +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_KERNEL_UTILS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h b/tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h new file mode 100644 index 0000000000000000000000000000000000000000..4a90e7e640ef29b675c236d8bbb479aa16560761 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/cblas_conv.h @@ -0,0 +1,92 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CBLAS_CONV_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CBLAS_CONV_H_ + +// The Conv implementation based on CBLAS interface. This is only used on iOS +// for now, utilizing Apple's Accelerate framework. 
+ +#if TFLITE_USE_APPLE_ACCELERATE_FOR_CONV +#include +#else +#include "tensorflow/contrib/lite/kernels/internal/optimized/cblas_reference.h" +#endif + +#include "tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" + +namespace tflite { +namespace cblas_ops { + +inline void Conv(const float* input_data, const Dims<4>& input_dims, + const float* filter_data, const Dims<4>& filter_dims, + const float* bias_data, const Dims<4>& bias_dims, + int stride_width, int stride_height, int pad_width, + int pad_height, float output_activation_min, + float output_activation_max, float* output_data, + const Dims<4>& output_dims, float* im2col_data, + const Dims<4>& im2col_dims) { + gemmlowp::ScopedProfilingLabel label("Conv/cblas"); + + const float* gemm_input_data = nullptr; + const Dims<4>* gemm_input_dims = nullptr; + const int filter_width = ArraySize(filter_dims, 1); + const int filter_height = ArraySize(filter_dims, 2); + const bool need_im2col = stride_width != 1 || stride_height != 1 || + filter_width != 1 || filter_height != 1; + if (need_im2col) { + TFLITE_DCHECK(im2col_data); + optimized_ops::Im2col(input_data, input_dims, stride_width, stride_height, + pad_width, pad_height, filter_height, filter_width, 0, + im2col_data, im2col_dims); + gemm_input_data = im2col_data; + gemm_input_dims = &im2col_dims; + } else { + TFLITE_DCHECK(!im2col_data); + gemm_input_data = input_data; + gemm_input_dims = &input_dims; + } + + // The following code computes matrix multiplication c = a * transpose(b) + // with CBLAS, where: + // * `a` is a matrix with dimensions (m, k). + // * `b` is a matrix with dimensions (n, k), so transpose(b) is (k, n). + // * `c` is a matrix with dimensions (m, n). + // The naming of variables is aligned with CBLAS specification here.
+ const float* a = gemm_input_data; + const float* b = filter_data; + float* c = output_data; + int m = gemm_input_dims->sizes[1] * gemm_input_dims->sizes[2] * + gemm_input_dims->sizes[3]; + int n = output_dims.sizes[0]; + int k = gemm_input_dims->sizes[0]; + // The stride of matrix a, b and c respectively. + int stride_a = k; + int stride_b = k; + int stride_c = n; + + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, m, n, k, 1.0f, a, + stride_a, b, stride_b, 0.0f, c, stride_c); + + optimized_ops::AddBiasAndEvalActivationFunction( + bias_data, bias_dims, output_data, output_dims, output_activation_min, + output_activation_max); +} + +} // namespace cblas_ops +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CBLAS_CONV_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/cblas_reference.h b/tensorflow/contrib/lite/kernels/internal/optimized/cblas_reference.h new file mode 100644 index 0000000000000000000000000000000000000000..6acc513805c9398c304f3e24175d3bd6c96938f6 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/internal/optimized/cblas_reference.h @@ -0,0 +1,69 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CBLAS_REFERENCE_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CBLAS_REFERENCE_H_ + +#include "tensorflow/contrib/lite/kernels/internal/compatibility.h" + +// The reference implementation for a small subset of CBLAS interface. +// This is only used for testing CBLAS implementation, and should never be used +// in production code. + +namespace tflite { +namespace cblas_ops { + +// The following code follows the original CBLAS specification, and it might +// conflict with the TensorFlow naming convention. +// TODO(ycling): Find another way to test CBLAS with bazel, without writing +// a reference implementation by ourselves. +enum CBLAS_ORDER { CblasRowMajor = 0, CblasColMajor = 1 }; + +enum CBLAS_TRANSPOSE { CblasNoTrans = 0, CblasTrans = 1, CblasConjTrans = 2 }; + +// A reference implementation for matrix multiplication. +// The following code computes, c = a * transpose(b) matrix multiplication +// with CBLAS, where: +// * `a` is a matrix with dimensions (m, k). +// * `b` is a matrix with dimensions (n, k), so transpose(b) is (k, n). +// * `c` is a matrix with dimensions (m, n). +// The naming of variables is aligned with CBLAS specification here. +void cblas_sgemm(const enum CBLAS_ORDER order, + const enum CBLAS_TRANSPOSE trans_a, + const enum CBLAS_TRANSPOSE trans_b, const int m, const int n, + const int k, const float alpha, const float *a, + const int stride_a, const float *b, const int stride_b, + const float beta, float *c, const int stride_c) { + TFLITE_DCHECK(order == CblasRowMajor); + TFLITE_DCHECK(trans_a == CblasNoTrans); + TFLITE_DCHECK(trans_b == CblasTrans); + TFLITE_DCHECK(beta == 0.0f); + for (int row = 0; row < m; ++row) { + for (int col = 0; col < n; ++col) { + // If `beta` is non-zero, multiply it with the original values in output.
+ // Otherwise, ignore the original value in output completely. + float value = 0.0f; + for (int idx = 0; idx < k; ++idx) { + value += alpha * a[stride_a * row + idx] * b[stride_b * col + idx]; + } + c[stride_c * row + col] = value; + } + } +} + +} // namespace cblas_ops +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_CBLAS_REFERENCE_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h b/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h index dea46cc12065ed34cf681916a46a55bd7a86f463..3a53d3ab07faf63250fc18fc846e0b8f5a39d9c4 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/cpu_check.h @@ -34,17 +34,13 @@ inline bool TestCPUFeatureNeon() { #endif // __aarch64__ } -#elif __ARM_NEON +#elif defined USE_NEON || defined __ARM_NEON -inline bool TestCPUFeatureNeon() { - return true; -} +inline bool TestCPUFeatureNeon() { return true; } #else -inline bool TestCPUFeatureNeon() { - return false; -} +inline bool TestCPUFeatureNeon() { return false; } #endif diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h index 974611f52ac74cec275f978c5af5bd561688db78..7f6eea2d5d1cfd6f4e2a569760ecbe0d96f754c8 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_float.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_ #include "public/gemmlowp.h" #include "tensorflow/contrib/lite/kernels/internal/common.h" @@ -311,6 +311,9 @@ struct FloatDepthwiseConvKernel { } }; +// Note this implementation is very slow for input_depths < 8 +// (e.g. comparable to reference implementation) see, specializations for +// input_depth=3 below. template <> struct FloatDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, @@ -417,6 +420,74 @@ struct FloatDepthwiseConvKernel { } }; +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Load the filters + float32x2_t filter[3]; + for (int i = 0; i < 3; i++) { + filter[i] = vld1_f32(filter_ptr + 2 * i); + } + // Handle one output pixel at a time. 
+ for (int outp = 0; outp < num_output_pixels; outp++) { + const float32x2_t input01 = vld1_f32(input_ptr); + const float32x2_t input2 = vld1_dup_f32(input_ptr + 2); + // Load the accumulators from acc_buffer + float32x2_t acc[3]; + for (int i = 0; i < 3; i++) { + acc[i] = vld1_f32(acc_buffer_ptr + 2 * i); + } + // Multiply-accumulate for each input channel there 2 outputs + acc[0] = vmla_lane_f32(acc[0], filter[0], input01, 0); + acc[1] = vmla_lane_f32(acc[1], filter[1], input01, 1); + acc[2] = vmla_lane_f32(acc[2], filter[2], input2, 0); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 3; i++) { + vst1_f32(acc_buffer_ptr + 2 * i, acc[i]); + } + acc_buffer_ptr += 6; + input_ptr += input_ptr_increment; + } + } +}; + +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Load the filters + float32x4_t filter[3]; + for (int i = 0; i < 3; i++) { + filter[i] = vld1q_f32(filter_ptr + 4 * i); + } + // Handle one output pixel at a time. + for (int outp = 0; outp < num_output_pixels; outp++) { + // NOTE: we only want 3 values, so we read it as two ops where + // the second op just duplicates the lane + const float32x2_t input01 = vld1_f32(input_ptr); + const float32x2_t input2 = vld1_dup_f32(input_ptr + 2); + // Load the accumulators from acc_buffer + float32x4_t acc[3]; + for (int i = 0; i < 3; i++) { + acc[i] = vld1q_f32(acc_buffer_ptr + 4 * i); + } + // Multiply-accumulate all outputs. 
+ acc[0] = vmlaq_lane_f32(acc[0], filter[0], input01, 0); + acc[1] = vmlaq_lane_f32(acc[1], filter[1], input01, 1); + acc[2] = vmlaq_lane_f32(acc[2], filter[2], input2, 0); + // Store the accumulators back to acc_buffer + for (int i = 0; i < 3; i++) { + vst1q_f32(acc_buffer_ptr + 4 * i, acc[i]); + } + acc_buffer_ptr += 12; + input_ptr += input_ptr_increment; + } + } +}; + template <> struct FloatDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, @@ -502,6 +573,46 @@ struct FloatDepthwiseConvKernel { } }; +template <> +struct FloatDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const float* input_ptr, int input_ptr_increment, + const float* filter_ptr, float* acc_buffer_ptr) { + // Load the filters + float32x4_t filter_0 = vld1q_f32(filter_ptr + 4 * 0); + float32x4_t filter_1 = vld1q_f32(filter_ptr + 4 * 1); + float32x4_t filter_2 = vld1q_f32(filter_ptr + 4 * 2); + float32x4_t filter_3 = vld1q_f32(filter_ptr + 4 * 3); + float32x4_t filter_4 = vld1q_f32(filter_ptr + 4 * 4); + + // Handle one output pixel at a time. 
+ for (int outp = 0; outp < num_output_pixels; outp++) { + // Load the inputs + const float input_val = *input_ptr; + input_ptr += input_ptr_increment; + // Load the accumulators from acc_buffer + float32x4_t acc_0 = vld1q_f32(acc_buffer_ptr + 4 * 0); + float32x4_t acc_1 = vld1q_f32(acc_buffer_ptr + 4 * 1); + float32x4_t acc_2 = vld1q_f32(acc_buffer_ptr + 4 * 2); + float32x4_t acc_3 = vld1q_f32(acc_buffer_ptr + 4 * 3); + float32x4_t acc_4 = vld1q_f32(acc_buffer_ptr + 4 * 4); + // Multiply-accumulate + acc_0 = vmlaq_n_f32(acc_0, filter_0, input_val); + acc_1 = vmlaq_n_f32(acc_1, filter_1, input_val); + acc_2 = vmlaq_n_f32(acc_2, filter_2, input_val); + acc_3 = vmlaq_n_f32(acc_3, filter_3, input_val); + acc_4 = vmlaq_n_f32(acc_4, filter_4, input_val); + // Store the accumulators back to acc_buffer + vst1q_f32(acc_buffer_ptr + 4 * 0, acc_0); + vst1q_f32(acc_buffer_ptr + 4 * 1, acc_1); + vst1q_f32(acc_buffer_ptr + 4 * 2, acc_2); + vst1q_f32(acc_buffer_ptr + 4 * 3, acc_3); + vst1q_f32(acc_buffer_ptr + 4 * 4, acc_4); + acc_buffer_ptr += 20; + } + } +}; + template <> struct FloatDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, @@ -855,8 +966,11 @@ inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, TFMINI_USE_DEPTHWISECONV_KERNEL(true, 8, 1) TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 8) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 20) TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 32) TFMINI_USE_DEPTHWISECONV_KERNEL(true, 2, 1) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 3, 2) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 3, 4) TFMINI_USE_DEPTHWISECONV_KERNEL(true, 4, 1) // Finally, the kernels allowing a variable input depth, @@ -919,11 +1033,11 @@ inline void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, for (int k = 0; k < 4; k++) { acc[k] = vld1q_f32(acc_buffer + i + 4 * k); } - for (int k = 0; k < 4; k++) { - acc[k] = vmaxq_f32( - vdupq_n_f32(output_activation_min), - 
vminq_f32(vdupq_n_f32(output_activation_max), acc[k])); - } + for (int k = 0; k < 4; k++) { + acc[k] = vmaxq_f32( + vdupq_n_f32(output_activation_min), + vminq_f32(vdupq_n_f32(output_activation_max), acc[k])); + } for (int k = 0; k < 4; k++) { vst1q_f32(output_ptr + 4 * k, acc[k]); } @@ -984,4 +1098,4 @@ void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, } // namespace optimized_ops } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_FLOAT_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h index 051ed2a2c44a04f0473dfd26637e53865a5a51ac..dbc4f0d6fdca8279072d6ea225334722d6a89eb2 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/depthwiseconv_uint8.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_ #include "fixedpoint/fixedpoint.h" #include "public/gemmlowp.h" @@ -1205,6 +1205,55 @@ struct QuantizedDepthwiseConvKernel { } }; +template <> +struct QuantizedDepthwiseConvKernel { + static void Run(int num_output_pixels, int input_depth, int depth_multiplier, + const uint8* input_ptr, int16 input_offset, + int input_ptr_increment, const uint8* filter_ptr, + int16 filter_offset, int32* acc_buffer_ptr) { + // Load the filters, add filter_offset. + // NEON wants to load 8 bytes at a time, but 20 is not divisible by 8. + // We load the first 16 bytes into filter_u8_{0,1} as usual. + // Then we load the 8 last bytes into filter_u8_x (x for 'extra'). + // This is redundant: the first 4 bytes of filter_u8_x are the same + // as the last 4 bytes of filter_u8_1. + uint8x8_t filter_u8_0 = vld1_u8(filter_ptr + 8 * 0); + uint8x8_t filter_u8_1 = vld1_u8(filter_ptr + 8 * 1); + uint8x8_t filter_u8_x = vld1_u8(filter_ptr + 8 * 1 + 4); + int16x8_t filter_0 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_0)); + int16x8_t filter_1 = vreinterpretq_s16_u16(vmovl_u8(filter_u8_1)); + int16x8_t filter_x = vreinterpretq_s16_u16(vmovl_u8(filter_u8_x)); + filter_0 = vaddq_s16(filter_0, vdupq_n_s16(filter_offset)); + filter_1 = vaddq_s16(filter_1, vdupq_n_s16(filter_offset)); + filter_x = vaddq_s16(filter_x, vdupq_n_s16(filter_offset)); + // Handle one output pixel at a time.
+ for (int outp = 0; outp < num_output_pixels; outp++) { + uint8 input_u8 = *input_ptr; + input_ptr += input_ptr_increment; + int16 input = static_cast(input_u8 + input_offset); + // Load the accumulators from acc_buffer + int32x4_t acc_0 = vld1q_s32(acc_buffer_ptr + 4 * 0); + int32x4_t acc_1 = vld1q_s32(acc_buffer_ptr + 4 * 1); + int32x4_t acc_2 = vld1q_s32(acc_buffer_ptr + 4 * 2); + int32x4_t acc_3 = vld1q_s32(acc_buffer_ptr + 4 * 3); + int32x4_t acc_4 = vld1q_s32(acc_buffer_ptr + 4 * 4); + // Multiply-accumulate + acc_0 = vmlal_n_s16(acc_0, vget_low_s16(filter_0), input); + acc_1 = vmlal_n_s16(acc_1, vget_high_s16(filter_0), input); + acc_2 = vmlal_n_s16(acc_2, vget_low_s16(filter_1), input); + acc_3 = vmlal_n_s16(acc_3, vget_high_s16(filter_1), input); + acc_4 = vmlal_n_s16(acc_4, vget_high_s16(filter_x), input); + // Store the accumulators back to acc_buffer + vst1q_s32(acc_buffer_ptr + 4 * 0, acc_0); + vst1q_s32(acc_buffer_ptr + 4 * 1, acc_1); + vst1q_s32(acc_buffer_ptr + 4 * 2, acc_2); + vst1q_s32(acc_buffer_ptr + 4 * 3, acc_3); + vst1q_s32(acc_buffer_ptr + 4 * 4, acc_4); + acc_buffer_ptr += 20; + } + } +}; + template <> struct QuantizedDepthwiseConvKernel { static void Run(int num_output_pixels, int input_depth, int depth_multiplier, @@ -1504,7 +1553,7 @@ inline void QuantizedDepthwiseConvAccumRowGeneric( << "*\n" << "* If you would like to carry on with the slow code, compile\n" << "* with this preprocessor token defined:\n" - << "* TFLITE_ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK.\n" + << "* ALLOW_SLOW_GENERIC_DEPTHWISECONV_FALLBACK.\n" << "*\n" << "* The right thing to do, if you care about performance, is to add\n" << "* a new DepthwiseConv kernel to tfmini to cover your case.\n" @@ -1691,6 +1740,7 @@ inline void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, TFMINI_USE_DEPTHWISECONV_KERNEL(true, 8, 2) TFMINI_USE_DEPTHWISECONV_KERNEL(true, 16, 1) TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 16) + TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 
20) TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 32) TFMINI_USE_DEPTHWISECONV_KERNEL(true, 1, 8) TFMINI_USE_DEPTHWISECONV_KERNEL(true, 8, 1) @@ -1913,4 +1963,4 @@ void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, } // namespace optimized_ops } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_DEPTHWISECONV_UINT8_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h index 8004c24a9914e216974539930853d0aadf61e324..ce3cde76999c77e1f9bf1eaccdba7e84ed508dda 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_spatial_convolutions.h @@ -16,8 +16,8 @@ limitations under the License. // Copied from tensorflow/core/kernels/eigen_spatial_convolutions.h. // TODO(petewarden) - move this to a common location in Eigen itself. -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_ #define EIGEN_USE_CUSTOM_THREAD_POOL #define EIGEN_USE_THREADS @@ -39,7 +39,6 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #endif - namespace Eigen { /** SpatialConvolution @@ -215,17 +214,16 @@ EIGEN_DEVICE_FUNC } // TODO(yangke): choose() is defined in TensorContraction.h -- consider // moving it to somewhere more "common". 
- return - input - .extract_image_patches(kernelRows, kernelCols, row_stride, col_stride, - row_in_stride, col_in_stride, padding_type) - .reshape(pre_contract_dims) - .contract(kernel.reshape(kernel_dims), contract_dims) - .reshape(post_contract_dims); + return input + .extract_image_patches(kernelRows, kernelCols, row_stride, col_stride, + row_in_stride, col_in_stride, padding_type) + .reshape(pre_contract_dims) + .contract(kernel.reshape(kernel_dims), contract_dims) + .reshape(post_contract_dims); } } // end namespace Eigen // clang-format on -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_SPATIAL_CONVOLUTIONS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h index 7f78f69360b1ebbfb08600c8bc427f1ba9d5244d..d85e06a5d5af8d23235a08592d49754e4f493d34 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_google.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_ #define EIGEN_USE_CUSTOM_THREAD_POOL #define EIGEN_USE_THREADS @@ -140,4 +140,4 @@ limitations under the License. #include "third_party/eigen3/unsupported/Eigen/CXX11/src/Tensor/TensorIO.h" #include "Eigen/src/Core/util/ReenableStupidWarnings.h" -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_H +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_GOOGLE_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h index 1d5c316194df0b87ee7eecbdd04bd5ce9e2e40b5..d34708b8fd0c0732c13ddbd8d70c87a278c40ff8 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/eigen_tensor_reduced_instantiations_oss.h @@ -19,8 +19,8 @@ limitations under the License.
// clang-format off -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_OSS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_OSS_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_OSS_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_OSS_H_ #include "Eigen/Core" @@ -164,4 +164,4 @@ typedef unsigned __int64 uint64_t; #include "Eigen/src/Core/util/ReenableStupidWarnings.h" -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_OSS_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_EIGEN_TENSOR_REDUCED_INSTANTIATIONS_OSS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h index b3615f4658a1a70284cc9d386a868a87aa09819b..0bfb4e9b1f8ee4167cfb629645a38538be1d73d4 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/multithreaded_conv.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MULTITHREAD_CONV -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MULTITHREAD_CONV +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MULTITHREAD_CONV +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MULTITHREAD_CONV #include #include @@ -192,4 +192,4 @@ inline void Conv(const float* input_data, const Dims<4>& input_dims, } // namespace multithreaded_ops } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MULTITHREAD_CONV +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_MULTITHREAD_CONV diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc index bf0bdfb1fb875c4b54c55e25d4a17541507ecd4c..780401e052733cccae0cc34f495df090c1530624 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc +++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.cc @@ -15,12 +15,13 @@ limitations under the License. 
#include #include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/kernels/internal/common.h" #include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/internal/common.h" #include "tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h" #ifdef USE_NEON -#include #define kFloatWeightsPerNeonLane 4 namespace tflite { diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h index 3a4af87304eaf33489b38bd9b15ad9789e091d24..b7e317dc60e2c68e9e993ff45c9090a01bd13b94 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/neon_tensor_utils.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_ // TODO(ghodrat): Remove this header file and the dependency to internal data // structure. 
@@ -110,4 +110,4 @@ void ReductionSumVector(const float* input_vector, float* output_vector, } // namespace tensor_utils } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_NEON_TENSOR_UTILS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h index cd565c16a1ee7226f83c19f0020beed75e401497..cd52385f417b469a24b6aa2b15f54ddad5fa9731 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_ #include #include @@ -1538,9 +1538,10 @@ void Add(const int32* input1_data, const Dims<4>& input1_dims, // reference_ops.h. Once an optimized version is implemented and NdArrayDesc // is no longer referenced in this file, move NdArrayDesc from types.h to // reference_ops.h. 
-template +template void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, const Dims<4>& input2_dims, + T output_activation_min, T output_activation_max, T* output_data, const Dims<4>& output_dims) { gemmlowp::ScopedProfilingLabel label("BroadcastAdd"); @@ -1563,15 +1564,30 @@ void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims, for (int y = 0; y < ArraySize(output_dims, 2); ++y) { for (int x = 0; x < ArraySize(output_dims, 1); ++x) { for (int c = 0; c < ArraySize(output_dims, 0); ++c) { - output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction( - input1_data[SubscriptToIndex(desc1, c, x, y, b)] + - input2_data[SubscriptToIndex(desc2, c, x, y, b)]); + output_data[Offset(output_dims, c, x, y, b)] = + ActivationFunctionWithMinMax( + input1_data[SubscriptToIndex(desc1, c, x, y, b)] + + input2_data[SubscriptToIndex(desc2, c, x, y, b)], + output_activation_min, output_activation_max); } } } } } +// legacy, for compatibility with old checked-in code +template +void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T* output_data, const Dims<4>& output_dims) { + T output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + BroadcastAdd(input1_data, input1_dims, input2_data, input2_dims, + output_activation_min, output_activation_max, output_data, + output_dims); +} + inline void BroadcastAdd(int left_shift, const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset, int32 input1_multiplier, int input1_shift, @@ -1772,9 +1788,10 @@ void Mul(const int32* input1_data, const Dims<4>& input1_dims, // reference_ops.h. Once an optimized version is implemented and NdArrayDesc // is no longer referenced in this file, move NdArrayDesc from types.h to // reference_ops.h. 
-template +template void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims, const T* input2_data, const Dims<4>& input2_dims, + T output_activation_min, T output_activation_max, T* output_data, const Dims<4>& output_dims) { gemmlowp::ScopedProfilingLabel label("BroadcastMul"); @@ -1797,15 +1814,30 @@ void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims, for (int y = 0; y < ArraySize(output_dims, 2); ++y) { for (int x = 0; x < ArraySize(output_dims, 1); ++x) { for (int c = 0; c < ArraySize(output_dims, 0); ++c) { - output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction( - input1_data[SubscriptToIndex(desc1, c, x, y, b)] * - input2_data[SubscriptToIndex(desc2, c, x, y, b)]); + output_data[Offset(output_dims, c, x, y, b)] = + ActivationFunctionWithMinMax( + input1_data[SubscriptToIndex(desc1, c, x, y, b)] * + input2_data[SubscriptToIndex(desc2, c, x, y, b)], + output_activation_min, output_activation_max); } } } } } +// legacy, for compatibility with old checked-in code +template +void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T* output_data, const Dims<4>& output_dims) { + T output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + BroadcastMul(input1_data, input1_dims, input2_data, input2_dims, + output_activation_min, output_activation_max, output_data, + output_dims); +} + inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset, const uint8* input2_data, const Dims<4>& input2_dims, int32 input2_offset, @@ -1868,6 +1900,61 @@ inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, output_data, output_dims); } +// TODO(aselle): This is not actually optimized yet. 
+inline void Div(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + const int batches = + MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3); + const int height = + MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2); + const int width = + MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1); + const int depth = + MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + ActivationFunctionWithMinMax( + input1_data[Offset(input1_dims, c, x, y, b)] / + input2_data[Offset(input2_dims, c, x, y, b)], + output_activation_min, output_activation_max); + } + } + } + } +} + +// TODO(aselle): This is not actually optimized yet. 
+inline void Sub(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + const int batches = + MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3); + const int height = + MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2); + const int width = + MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1); + const int depth = + MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + ActivationFunctionWithMinMax( + input1_data[Offset(input1_dims, c, x, y, b)] - + input2_data[Offset(input2_dims, c, x, y, b)], + output_activation_min, output_activation_max); + } + } + } + } +} template void Concatenation(int concat_dim, const Scalar* const* input_data, const Dims<4>* const* input_dims, int inputs_count, @@ -1994,6 +2081,166 @@ inline void LstmCell(const float* input_data, const Dims<4>& input_dims, output_state_map.tanh(); } +// Quantized LSTM cell. Currently just a copy of the reference impl in +// reference_ops.h. See the big function comment there, not replicating it +// here. 
+template +void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims, + const uint8* prev_activ_data_uint8, + const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8, + const Dims<4>& weights_dims, const int32* bias_data_int32, + const Dims<4>& bias_dims, const int16* prev_state_data_int16, + const Dims<4>& prev_state_dims, int16* output_state_data_int16, + const Dims<4>& output_state_dims, uint8* output_activ_data_uint8, + const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8, + const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16, + const Dims<4>& activ_temp_dims, int32 weights_zero_point, + int32 accum_multiplier, int accum_shift) { + gemmlowp::ScopedProfilingLabel label( + "LstmCell/quantized (8bit external, 16bit internal)"); + // Gather dimensions information, and perform consistency checks. + const int batches = + MatchingArraySize(input_dims, 3, prev_activ_dims, 3, prev_state_dims, 3, + output_state_dims, 3, output_activ_dims, 3); + const int height = + MatchingArraySize(input_dims, 2, prev_activ_dims, 2, prev_state_dims, 2, + output_state_dims, 2, output_activ_dims, 2); + const int width = + MatchingArraySize(input_dims, 1, prev_activ_dims, 1, prev_state_dims, 1, + output_state_dims, 1, output_activ_dims, 1); + TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1); + TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1); + const int input_depth = ArraySize(input_dims, 0); + const int prev_activ_depth = ArraySize(prev_activ_dims, 0); + const int total_input_depth = prev_activ_depth + input_depth; + TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth); + TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3), + 1); + const int intern_activ_depth = + MatchingArraySize(weights_dims, 1, bias_dims, 0); + TFLITE_CHECK_EQ(intern_activ_depth % 4, 0); + const int output_depth = + MatchingArraySize(prev_state_dims, 0, prev_activ_dims, 0, + output_state_dims, 0, output_activ_dims, 0); + 
TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4); + const int fc_batches = ArraySize(activ_temp_dims, 1) * + ArraySize(activ_temp_dims, 2) * + ArraySize(activ_temp_dims, 3); + const int fc_output_depth = + MatchingArraySize(weights_dims, 1, activ_temp_dims, 0); + const int fc_accum_depth = ArraySize(weights_dims, 0); + TFLITE_CHECK_EQ(fc_output_depth, 4 * output_depth); + + // Depth-concatenate prev_activ and input data together. + uint8 const* concat_input_arrays_data[2] = {input_data_uint8, + prev_activ_data_uint8}; + Dims<4> const* concat_input_arrays_dims[2] = {&input_dims, &prev_activ_dims}; + Concatenation( + 0, concat_input_arrays_data, concat_input_arrays_dims, 2, + concat_temp_data_uint8, concat_temp_dims); + + // Implementation of the fully connected node inside the LSTM cell. + // The operands are 8-bit integers, the accumulators are internally 32bit + // integers, and the output is 16-bit fixed-point with 3 integer bits so + // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that + // is explained in the function comment above. + for (int b = 0; b < fc_batches; ++b) { + for (int out_c = 0; out_c < fc_output_depth; ++out_c) { + // Internal accumulation. + // Initialize accumulator with the bias-value. + int32 accum = bias_data_int32[out_c]; + // Accumulation loop. + for (int d = 0; d < fc_accum_depth; ++d) { + int16 input_val = concat_temp_data_uint8[b * fc_accum_depth + d] - 128; + int16 weights_val = + weights_data_uint8[out_c * fc_accum_depth + d] - weights_zero_point; + accum += input_val * weights_val; + } + // Down-scale the final int32 accumulator to the scale used by our + // (16-bit, using 3 integer bits) fixed-point format. The quantized + // multiplier and shift here have been pre-computed offline + // (e.g. by toco). 
+ // Note that the implicit assumption here, that this multiplier is smaller + // than one, is equivalent to the assumption that the fully-connected + // weights min-max is enclosed within [-4, 4] (it may be narrower). + // If that eventually fails, offline tools (e.g. toco) will fail early + // and that will be easy to support as needed. For now, assuming that + // this multiplier is less than one allows us to use a simpler, more + // accurate implementation. + accum = + MultiplyByQuantizedMultiplier(accum, accum_multiplier, accum_shift); + // Saturate, cast to int16, and store to the temporary activations array. + accum = std::max(-32768, std::min(32767, accum)); + activ_temp_data_int16[out_c + fc_output_depth * b] = accum; + } + } + + // Rest of the LSTM cell: tanh and logistic math functions, and some adds + // and muls, all done in 16-bit fixed-point. + const int outer_size = batches * width * height; + for (int b = 0; b < outer_size; ++b) { + for (int c = 0; c < output_depth; ++c) { + // Define the fixed-point data types that we will use here. All use + // int16 as the underlying integer type i.e. all are 16-bit fixed-point. + // They only differ by the number of integral vs. fractional bits, + // determining the range of values that they can represent. + // + // F0 uses 0 integer bits, range [-1, 1]. + // This is the return type of math functions such as tanh, logistic, + // whose range is in [-1, 1]. + using F0 = gemmlowp::FixedPoint; + // F3 uses 3 integer bits, range [-8, 8]. + // This is the range of the previous fully-connected node's output, + // which is our input here. + using F3 = gemmlowp::FixedPoint; + // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits, + // 2^StateIntegerBits]. It's used to represent the internal state, whose + // number of integer bits is currently dictated by the model. See comment + // on the StateIntegerBits template parameter above. 
+ using FS = gemmlowp::FixedPoint; + // Implementation of input gate, using fixed-point logistic function. + F3 input_gate_input = F3::FromRaw( + activ_temp_data_int16[b * fc_output_depth + 0 * output_depth + c]); + F0 input_gate_output = gemmlowp::logistic(input_gate_input); + // Implementation of input modulation gate, using fixed-point tanh + // function. + F3 input_modulation_gate_input = F3::FromRaw( + activ_temp_data_int16[b * fc_output_depth + 1 * output_depth + c]); + F0 input_modulation_gate_output = + gemmlowp::tanh(input_modulation_gate_input); + // Implementation of forget gate, using fixed-point logistic function. + F3 forget_gate_input = F3::FromRaw( + activ_temp_data_int16[b * fc_output_depth + 2 * output_depth + c]); + F0 forget_gate_output = gemmlowp::logistic(forget_gate_input); + // Implementation of output gate, using fixed-point logistic function. + F3 output_gate_input = F3::FromRaw( + activ_temp_data_int16[b * fc_output_depth + 3 * output_depth + c]); + F0 output_gate_output = gemmlowp::logistic(output_gate_input); + // Implementation of internal multiplication nodes, still in fixed-point. + F0 input_times_input_modulation = + input_gate_output * input_modulation_gate_output; + FS prev_state = FS::FromRaw(prev_state_data_int16[b * output_depth + c]); + FS prev_state_times_forget_state = forget_gate_output * prev_state; + // Implementation of internal addition node, saturating. + FS new_state = gemmlowp::SaturatingAdd( + gemmlowp::Rescale(input_times_input_modulation), + prev_state_times_forget_state); + // Implementation of last internal tanh node, still in fixed-point. + F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state); + // Store the new internal state back to memory, as 16-bit integers. + output_state_data_int16[b * output_depth + c] = new_state.raw(); + // Down-scale the output activations to 8-bit integers, saturating, + // and store back to memory. 
+ int16 rescaled_output_activ = + gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8); + int16 clamped_output_activ = + std::max(-128, std::min(127, rescaled_output_activ)); + output_activ_data_uint8[b * output_depth + c] = + 128 + clamped_output_activ; + } + } +} + template void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims, int outputs_count, Scalar* const* output_data, @@ -2851,6 +3098,156 @@ inline void Tanh(const float* input_data, const Dims<4>& input_dims, output_map.array() = input_map.array().tanh(); } +inline void Tanh(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, int32 input_range_radius, + int32 input_multiplier, int input_left_shift, + uint8* output_data, const Dims<4>& output_dims) { + // Note that this is almost the exact same code as in Logistic(). + gemmlowp::ScopedProfilingLabel label("Tanh"); + /* batches */ MatchingArraySize(input_dims, 3, output_dims, 3); + /* height */ MatchingArraySize(input_dims, 2, output_dims, 2); + /* width */ MatchingArraySize(input_dims, 1, output_dims, 1); + /* depth */ MatchingArraySize(input_dims, 0, output_dims, 0); + const int size = RequiredBufferSizeForDims(input_dims); + + int c = 0; + int32_t output_zero_point = 128; +#ifdef USE_NEON + // Handle 16 values at a time + for (; c <= size - 16; c += 16) { + // Read input uint8 values, cast to int16 and subtract input_zero_point + uint8x16_t input_val_u8 = vld1q_u8(input_data + c); + int16x8_t input_val_centered_0 = + vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(input_val_u8))), + vdupq_n_s16(input_zero_point)); + int16x8_t input_val_centered_1 = + vsubq_s16(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(input_val_u8))), + vdupq_n_s16(input_zero_point)); + + // Prepare the bit masks that we will use at the end to implement the logic + // that was expressed in the scalar code with branching: + // if (input_val_centered < -input_range_radius) { + // output_val = 0; + // } else if (input_val_centered > 
input_range_radius) { + // output_val = 255; + // } else { + // ... + uint16x8_t mask_rightclamp_0 = + vcgtq_s16(input_val_centered_0, vdupq_n_s16(input_range_radius)); + uint16x8_t mask_rightclamp_1 = + vcgtq_s16(input_val_centered_1, vdupq_n_s16(input_range_radius)); + uint16x8_t mask_leftclamp_0 = + vcgeq_s16(input_val_centered_0, vdupq_n_s16(-input_range_radius)); + uint16x8_t mask_leftclamp_1 = + vcgeq_s16(input_val_centered_1, vdupq_n_s16(-input_range_radius)); + uint8x16_t mask_rightclamp = vcombine_u8(vshrn_n_u16(mask_rightclamp_0, 8), + vshrn_n_u16(mask_rightclamp_1, 8)); + uint8x16_t mask_leftclamp = vcombine_u8(vshrn_n_u16(mask_leftclamp_0, 8), + vshrn_n_u16(mask_leftclamp_1, 8)); + + // This performs what is expressed in the scalar code as + // const int32 input_val_rescaled = + // MultiplyByQuantizedMultiplierGreaterThanOne( + // input_val_centered, input_multiplier, input_left_shift); + int32x4_t input_val_rescaled_0 = + vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_0)), + vdupq_n_s32(input_left_shift)); + int32x4_t input_val_rescaled_1 = + vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_0)), + vdupq_n_s32(input_left_shift)); + int32x4_t input_val_rescaled_2 = + vshlq_s32(vmovl_s16(vget_low_s16(input_val_centered_1)), + vdupq_n_s32(input_left_shift)); + int32x4_t input_val_rescaled_3 = + vshlq_s32(vmovl_s16(vget_high_s16(input_val_centered_1)), + vdupq_n_s32(input_left_shift)); + input_val_rescaled_0 = + vqrdmulhq_n_s32(input_val_rescaled_0, input_multiplier); + input_val_rescaled_1 = + vqrdmulhq_n_s32(input_val_rescaled_1, input_multiplier); + input_val_rescaled_2 = + vqrdmulhq_n_s32(input_val_rescaled_2, input_multiplier); + input_val_rescaled_3 = + vqrdmulhq_n_s32(input_val_rescaled_3, input_multiplier); + + // Invoke gemmlowp::tanh on FixedPoint wrapping int32x4_t + using FixedPoint4 = gemmlowp::FixedPoint; + using FixedPoint0 = gemmlowp::FixedPoint; + const FixedPoint4 input_val_f4_0 = + FixedPoint4::FromRaw(input_val_rescaled_0); + 
const FixedPoint4 input_val_f4_1 = + FixedPoint4::FromRaw(input_val_rescaled_1); + const FixedPoint4 input_val_f4_2 = + FixedPoint4::FromRaw(input_val_rescaled_2); + const FixedPoint4 input_val_f4_3 = + FixedPoint4::FromRaw(input_val_rescaled_3); + const FixedPoint0 output_val_f0_0 = gemmlowp::tanh(input_val_f4_0); + const FixedPoint0 output_val_f0_1 = gemmlowp::tanh(input_val_f4_1); + const FixedPoint0 output_val_f0_2 = gemmlowp::tanh(input_val_f4_2); + const FixedPoint0 output_val_f0_3 = gemmlowp::tanh(input_val_f4_3); + + // Divide by 2^24 as in the scalar code + using gemmlowp::RoundingDivideByPOT; + int32x4_t output_val_s32_0 = RoundingDivideByPOT(output_val_f0_0.raw(), 24); + int32x4_t output_val_s32_1 = RoundingDivideByPOT(output_val_f0_1.raw(), 24); + int32x4_t output_val_s32_2 = RoundingDivideByPOT(output_val_f0_2.raw(), 24); + int32x4_t output_val_s32_3 = RoundingDivideByPOT(output_val_f0_3.raw(), 24); + + // Add the output zero point + int32x4_t output_zero_point_s32 = vdupq_n_s32(output_zero_point); + output_val_s32_0 = vaddq_s32(output_val_s32_0, output_zero_point_s32); + output_val_s32_1 = vaddq_s32(output_val_s32_1, output_zero_point_s32); + output_val_s32_2 = vaddq_s32(output_val_s32_2, output_zero_point_s32); + output_val_s32_3 = vaddq_s32(output_val_s32_3, output_zero_point_s32); + + // Cast output values to uint8, saturating + int16x8_t output_val_s16_0 = vcombine_s16(vqmovn_s32(output_val_s32_0), + vqmovn_s32(output_val_s32_1)); + int16x8_t output_val_s16_1 = vcombine_s16(vqmovn_s32(output_val_s32_2), + vqmovn_s32(output_val_s32_3)); + uint8x16_t output_val_u8 = vcombine_u8(vqmovun_s16(output_val_s16_0), + vqmovun_s16(output_val_s16_1)); + + // Perform the bit-masking with the bit masks computed at the beginning, + // see the comment there. 
+ output_val_u8 = vorrq_u8(output_val_u8, mask_rightclamp); + output_val_u8 = vandq_u8(output_val_u8, mask_leftclamp); + + // Store back to memory + vst1q_u8(output_data + c, output_val_u8); + } +#endif + // Leftover loop: handle one value at a time with scalar code. + for (; c < size; ++c) { + const uint8 input_val_u8 = input_data[c]; + const int32 input_val_centered = + static_cast(input_val_u8) - input_zero_point; + uint8 output_val; + if (input_val_centered < -input_range_radius) { + output_val = 0; + } else if (input_val_centered > input_range_radius) { + output_val = 255; + } else { + const int32 input_val_rescaled = + MultiplyByQuantizedMultiplierGreaterThanOne( + input_val_centered, input_multiplier, input_left_shift); + using FixedPoint4 = gemmlowp::FixedPoint; + using FixedPoint0 = gemmlowp::FixedPoint; + const FixedPoint4 input_val_f4 = FixedPoint4::FromRaw(input_val_rescaled); + const FixedPoint0 output_val_f0 = gemmlowp::tanh(input_val_f4); + using gemmlowp::RoundingDivideByPOT; + int32 output_val_s32 = RoundingDivideByPOT(output_val_f0.raw(), 24); + output_val_s32 += output_zero_point; + if (output_val_s32 == 256) { + output_val_s32 = 255; + } + TFLITE_DCHECK_GE(output_val_s32, 0); + TFLITE_DCHECK_LE(output_val_s32, 255); + output_val = static_cast(output_val_s32); + } + output_data[c] = output_val; + } +} inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims, int32 zero_point, double scale, float* output_data, const Dims<4>& output_dims) { @@ -3323,7 +3720,7 @@ inline void ResizeBilinearGeneric(const float* input_data, inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, const int32* output_size_data, const Dims<4>& output_size_dims, float* output_data, - const Dims<4>& output_dims) { + const Dims<4>& output_dims, bool align_corners) { gemmlowp::ScopedProfilingLabel label("ResizeBilinear"); int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3); int32 input_height = ArraySize(input_dims, 2); @@ 
-3338,13 +3735,20 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, int32 output_width = output_size_data[Offset(output_size_dims, 1, 0, 0, 0)]; // Specialize for 2x2 upsample. - if (output_height == 2 * input_height && output_width == 2 * input_width) { + if (!align_corners && output_height == 2 * input_height && + output_width == 2 * input_width) { ResizeBilinear2x2(input_data, input_dims, output_data, output_dims, batches, input_height, input_width, depth, output_height, output_width); } else { float height_scale = static_cast(input_height) / output_height; float width_scale = static_cast(input_width) / output_width; + if (align_corners && output_height > 1) { + height_scale = static_cast(input_height - 1) / (output_height - 1); + } + if (align_corners && output_width > 1) { + width_scale = static_cast(input_width - 1) / (output_width - 1); + } ResizeBilinearGeneric(input_data, input_dims, output_data, output_dims, batches, input_height, input_width, depth, @@ -3353,6 +3757,15 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, } } +// legacy, for compatibility with old checked-in code +inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, + const int32* output_size_data, + const Dims<4>& output_size_dims, float* output_data, + const Dims<4>& output_dims) { + ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims, + output_data, output_dims, /*align_corners=*/false); +} + template inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, const int32* block_shape_data, @@ -3381,10 +3794,11 @@ inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, for (int out_h = 0; out_h < output_height; ++out_h) { for (int out_w = 0; out_w < output_width; ++out_w) { T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_b); - if (out_h * block_shape_height < padding_top || - out_h * block_shape_height >= padding_top + 
input_height || - out_w * block_shape_width < padding_left || - out_w * block_shape_width >= padding_left + input_width) { + if (out_h * block_shape_height + shift_h < padding_top || + out_h * block_shape_height + shift_h >= + padding_top + input_height || + out_w * block_shape_width + shift_w < padding_left || + out_w * block_shape_width + shift_w >= padding_left + input_width) { memset(out, 0, depth * sizeof(T)); } else { const T* in = @@ -3704,6 +4118,43 @@ void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims, auto max_value = input2_data[0]; output_map.array() = input1_map.array().max(max_value); } + +template +void ArgMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims, + T2* output_data, const Dims<4>& output_dims) { + gemmlowp::ScopedProfilingLabel label("ArgMax"); + + // The current ArgMax implemention can only determine the index of the maximum + // value in the last dimension. So the axis argument is ignored. + TFLITE_DCHECK_EQ(axis[0], 3); + + // For ArgMax, the number of output dimensions = (number of input dimensions - + // 1). For the sake of simplicity, the output dimensions are equal to the + // input dimensions here. We enforce the constraint that the last dimension + // must always be 1. 
+ TFLITE_DCHECK_EQ(ArraySize(output_dims, 0), 1); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = ArraySize(input_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + auto max_value = input_data[Offset(input_dims, 0, x, y, b)]; + int max_index = 0; + for (int d = 1; d < depth; ++d) { + const auto& curr_value = input_data[Offset(input_dims, d, x, y, b)]; + if (curr_value > max_value) { + max_value = curr_value; + max_index = d; + } + } + output_data[Offset(output_dims, 0, x, y, b)] = max_index; + } + } + } +} + } // namespace optimized_ops } // namespace tflite @@ -3712,4 +4163,4 @@ void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims, #pragma GCC diagnostic pop #endif -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_OPS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h index f8be99e82fb8721ced7a3e5da686b20ce241ea2d..4e324a5e107cf5a90c0042331899edab831c8e51 100644 --- a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h +++ b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h @@ -15,7 +15,7 @@ limitations under the License. #ifndef TF_LITE_KERNELS_INTERNAL_OPTIMIZED_TENSOR_UTILS_IMPL_H_ #define TF_LITE_KERNELS_INTERNAL_OPTIMIZED_TENSOR_UTILS_IMPL_H_ -// TDOD(ghodrat): Remove this header file and the dependency to internal data +// TODO(ghodrat): Remove this header file and the dependency to internal data // structure. 
#include "tensorflow/contrib/lite/builtin_op_data.h" diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc index 98f2e365c5249a6c28673fc185ebec34cc2105b2..18be6777a5caeb45a4ffabd8b7f1793de7b053f8 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util.cc +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.cc @@ -22,27 +22,20 @@ limitations under the License. namespace tflite { -void QuantizeMultiplierSmallerThanOne(double double_multiplier, - int32_t* quantized_multiplier, - int* right_shift) { - TFLITE_CHECK(double_multiplier >= 0.); - TFLITE_CHECK(double_multiplier < 1.); +void QuantizeMultiplier(double double_multiplier, int32_t* quantized_multiplier, + int* shift) { if (double_multiplier == 0.) { *quantized_multiplier = 0; - *right_shift = 0; + *shift = 0; return; } - TFLITE_CHECK(double_multiplier > 0.); - const double q = std::frexp(double_multiplier, right_shift); - *right_shift *= -1; - + const double q = std::frexp(double_multiplier, shift); auto q_fixed = static_cast(TfLiteRound(q * (1ll << 31))); TFLITE_CHECK(q_fixed <= (1ll << 31)); if (q_fixed == (1ll << 31)) { q_fixed /= 2; - --*right_shift; + ++*shift; } - TFLITE_CHECK_GE(*right_shift, 0); TFLITE_CHECK_LE(q_fixed, std::numeric_limits::max()); *quantized_multiplier = static_cast(q_fixed); } @@ -50,17 +43,20 @@ void QuantizeMultiplierSmallerThanOne(double double_multiplier, void QuantizeMultiplierGreaterThanOne(double double_multiplier, int32_t* quantized_multiplier, int* left_shift) { - TFLITE_CHECK(double_multiplier > 1.); - const double q = std::frexp(double_multiplier, left_shift); - auto q_fixed = static_cast(TfLiteRound(q * (1ll << 31))); - TFLITE_CHECK(q_fixed <= (1ll << 31)); - if (q_fixed == (1ll << 31)) { - q_fixed /= 2; - ++*left_shift; - } + TFLITE_CHECK_GT(double_multiplier, 1.); + QuantizeMultiplier(double_multiplier, quantized_multiplier, left_shift); 
TFLITE_CHECK_GE(*left_shift, 0); - TFLITE_CHECK_LE(q_fixed, std::numeric_limits::max()); - *quantized_multiplier = static_cast(q_fixed); +} + +void QuantizeMultiplierSmallerThanOne(double double_multiplier, + int32_t* quantized_multiplier, + int* right_shift) { + TFLITE_CHECK_LT(double_multiplier, 1.); + TFLITE_CHECK_GT(double_multiplier, 0.); + int shift; + QuantizeMultiplier(double_multiplier, quantized_multiplier, &shift); + TFLITE_CHECK_LE(shift, 0); + *right_shift = -shift; } void PreprocessSoftmaxScaling(double beta, double input_scale, diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util.h b/tensorflow/contrib/lite/kernels/internal/quantization_util.h index efb7191c8deb2a23ea5473ab131d2b6537202765..ba06bc0975b6847b24592daa60efe99983d03707 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util.h +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util.h @@ -20,7 +20,8 @@ limitations under the License. namespace tflite { // Decompose a double multiplier into a Q0.31 int32 representation of its -// significand, and shift representation of its exponent. +// significand, and shift representation of NEGATIVE its exponent --- +// this is intended as a RIGHT-shift. // // Restricted to the case where the multiplier < 1 (and non-negative). void QuantizeMultiplierSmallerThanOne(double double_multiplier, @@ -35,6 +36,16 @@ void QuantizeMultiplierGreaterThanOne(double double_multiplier, int32_t* quantized_multiplier, int* left_shift); +// Decompose a double multiplier into a Q0.31 int32 representation of its +// significand, and shift representation of its exponent. +// +// Handles an arbitrary positive multiplier. 
The 'shift' output-value is +// basically the 'floating-point exponent' of the multiplier: +// Negative for a right-shift (when the multiplier is <1), positive for a +// left-shift (when the multiplier is >1) +void QuantizeMultiplier(double double_multiplier, int32_t* quantized_multiplier, + int* shift); + // This first creates a multiplier in a double equivalent of // Q(input_integer_bits).(31-input_integer_bits) representation, with extra // precision in the double's fractional bits. It then splits the result into diff --git a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc index d6f306e2cbae3c780b3d773638ba46cd2abf02f5..19b1b408ec74b0939065b0ad10b91ecfc2cd4765 100644 --- a/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc +++ b/tensorflow/contrib/lite/kernels/internal/quantization_util_test.cc @@ -31,7 +31,7 @@ TEST(QuantizationUtilTest, QuantizeMultiplierSmallerThanOne) { }; EXPECT_DEATH(quantize(-0.1), ""); - EXPECT_THAT(quantize(0.0), Pair(0, 0)); + EXPECT_DEATH(quantize(0.0), ""); EXPECT_THAT(quantize(0.25), Pair(1073741824, 1)); // Around 0.5 we can see the change in exponent and how we try hard to diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h index 8e0f234545e43dd8b2412e065aaecad8325a1182..9aabee5000c29ed97fcf7e874d661e72fd768f84 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_float.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_FLOAT_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_FLOAT_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_FLOAT_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_FLOAT_H_ #include "tensorflow/contrib/lite/kernels/internal/common.h" #include "tensorflow/contrib/lite/kernels/internal/compatibility.h" @@ -112,4 +112,4 @@ void DepthwiseConv(const float* input_data, const Dims<4>& input_dims, } // end namespace reference_ops } // end namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_FLOAT_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_FLOAT_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h index 8a80558b32f2858778460956cd9f57617674e21e..e9b6baeaee87d22aef238410bc9f447509a81c47 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/depthwiseconv_uint8.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_ #include @@ -135,4 +135,4 @@ void DepthwiseConv(const uint8* input_data, const Dims<4>& input_dims, } // end namespace reference_ops } // end namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_DEPTHWISECONV_UINT8_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h index c2ab78000b81485f037c507933cd024e70f39850..c05c21b472b05f2cbe133adf94d91ab0c6d9ef40 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/portable_tensor_utils.h @@ -12,17 +12,17 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_ -// TDOD(ghodrat): Remove this header file and the dependency to internal data +// TODO(ghodrat): Remove this header file and the dependency to internal data // structure. #include "tensorflow/contrib/lite/builtin_op_data.h" namespace tflite { namespace tensor_utils { -// Limit a float input f betweeen +abs_limit and -abs_limit. +// Limit a float input f between +abs_limit and -abs_limit. float PortableClip(float f, float abs_limit); // Multiply a matrix by a batch vector, and store results in a batch-size @@ -186,4 +186,4 @@ void ReductionSumVector(const float* input_vector, float* output_vector, } // namespace tensor_utils } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h index b9ca3d5c626dff4ea8ba52949e8fea8e9b43689f..2e0376656ac286585ce967c37cbbeb66a7e29172 100644 --- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h +++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_ #include #include @@ -889,10 +889,11 @@ inline void Add(int left_shift, const uint8* input1_data, // dimensionality if the runtime code does a single loop over one dimension // that handles broadcasting as the base case. The code generator would then // generate max(D1, D2) nested for loops. -template -void BroadcastAdd(const float* input1_data, const Dims<4>& input1_dims, - const float* input2_data, const Dims<4>& input2_dims, - float* output_data, const Dims<4>& output_dims) { +template +void BroadcastAdd(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T output_activation_min, T output_activation_max, + T* output_data, const Dims<4>& output_dims) { gemmlowp::ScopedProfilingLabel label("BroadcastAdd"); NdArrayDesc<4> desc1; @@ -914,15 +915,30 @@ void BroadcastAdd(const float* input1_data, const Dims<4>& input1_dims, for (int y = 0; y < ArraySize(output_dims, 2); ++y) { for (int x = 0; x < ArraySize(output_dims, 1); ++x) { for (int c = 0; c < ArraySize(output_dims, 0); ++c) { - output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction( - input1_data[SubscriptToIndex(desc1, c, x, y, b)] + - input2_data[SubscriptToIndex(desc2, c, x, y, b)]); + output_data[Offset(output_dims, c, x, y, b)] = + ActivationFunctionWithMinMax( + input1_data[SubscriptToIndex(desc1, c, x, y, b)] + + input2_data[SubscriptToIndex(desc2, c, x, y, b)], + output_activation_min, output_activation_max); } } } } } +// legacy, for compatibility with old checked-in code +template +void BroadcastAdd(const T* input1_data, const 
Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T* output_data, const Dims<4>& output_dims) { + T output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + BroadcastAdd(input1_data, input1_dims, input2_data, input2_dims, + output_activation_min, output_activation_max, output_data, + output_dims); +} + inline void BroadcastAdd(int left_shift, const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset, int32 input1_multiplier, int input1_shift, @@ -1053,10 +1069,11 @@ void Mul(const float* input1_data, const Dims<4>& input1_dims, // dimensionality if the runtime code does a single loop over one dimension // that handles broadcasting as the base case. The code generator would then // generate max(D1, D2) nested for loops. -template -void BroadcastMul(const float* input1_data, const Dims<4>& input1_dims, - const float* input2_data, const Dims<4>& input2_dims, - float* output_data, const Dims<4>& output_dims) { +template +void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T output_activation_min, T output_activation_max, + T* output_data, const Dims<4>& output_dims) { gemmlowp::ScopedProfilingLabel label("BroadcastMul"); NdArrayDesc<4> desc1; @@ -1078,15 +1095,30 @@ void BroadcastMul(const float* input1_data, const Dims<4>& input1_dims, for (int y = 0; y < ArraySize(output_dims, 2); ++y) { for (int x = 0; x < ArraySize(output_dims, 1); ++x) { for (int c = 0; c < ArraySize(output_dims, 0); ++c) { - output_data[Offset(output_dims, c, x, y, b)] = ActivationFunction( - input1_data[SubscriptToIndex(desc1, c, x, y, b)] * - input2_data[SubscriptToIndex(desc2, c, x, y, b)]); + output_data[Offset(output_dims, c, x, y, b)] = + ActivationFunctionWithMinMax( + input1_data[SubscriptToIndex(desc1, c, x, y, b)] * + input2_data[SubscriptToIndex(desc2, c, x, y, b)], + output_activation_min, 
output_activation_max); } } } } } +// legacy, for compatibility with old checked-in code +template +void BroadcastMul(const T* input1_data, const Dims<4>& input1_dims, + const T* input2_data, const Dims<4>& input2_dims, + T* output_data, const Dims<4>& output_dims) { + T output_activation_min, output_activation_max; + GetActivationMinMax(Ac, &output_activation_min, &output_activation_max); + + BroadcastMul(input1_data, input1_dims, input2_data, input2_dims, + output_activation_min, output_activation_max, output_data, + output_dims); +} + inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, int32 input1_offset, const uint8* input2_data, const Dims<4>& input2_dims, int32 input2_offset, @@ -1149,6 +1181,60 @@ inline void BroadcastMul(const uint8* input1_data, const Dims<4>& input1_dims, output_data, output_dims); } +inline void Div(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + const int batches = + MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3); + const int height = + MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2); + const int width = + MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1); + const int depth = + MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + ActivationFunctionWithMinMax( + input1_data[Offset(input1_dims, c, x, y, b)] / + input2_data[Offset(input2_dims, c, x, y, b)], + output_activation_min, output_activation_max); + } + } + } + } +} + +inline void Sub(const float* input1_data, const Dims<4>& input1_dims, + const float* input2_data, const Dims<4>& input2_dims, + float 
output_activation_min, float output_activation_max, + float* output_data, const Dims<4>& output_dims) { + const int batches = + MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3); + const int height = + MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2); + const int width = + MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1); + const int depth = + MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + output_data[Offset(output_dims, c, x, y, b)] = + ActivationFunctionWithMinMax( + input1_data[Offset(input1_dims, c, x, y, b)] - + input2_data[Offset(input2_dims, c, x, y, b)], + output_activation_min, output_activation_max); + } + } + } + } +} + template void Concatenation(int concat_dim, const Scalar* const* input_data, const Dims<4>* const* input_dims, int inputs_count, @@ -1272,6 +1358,238 @@ inline void LstmCell(const float* input_data, const Dims<4>& input_dims, } } +// Quantized LSTM cell implementation. +// The quantization of the input, output arrays is as follows: +// - The input activations are quantized as uint8 on the interval +// [-1, 127/128]. +// The rationale for that is that that is the natural interval for output +// activations (see next point) and these need to be concatenated together. +// We could accommodate different ranges by re-scaling, but we empirically +// found that setting the input activations range to be [-1, 127/128] in the +// first place, removing the need for re-scaling, greatly improves accuracy. +// - The output activations are quantized as uint8 on the interval +// [-1, 127/128]. +// The rationale for that is that the definition of a LSTM cell makes them +// intrinsically constrained in [-1, 1]; tweaking that to [-1, 127/128] +// makes for simpler, more accurate fixed-point arithmetic. 
+// - The output-at-previous-timestep state array is obviously quantized as +// the output activations. +// - The internal LSTM memory (not the output-at-previous-timestep, the other +// internal state array) is int16-quantized and may use any power-of-two, +// symmetric range i.e. [-2^N, 2^N * 32767/32768] for any N, which we call +// StateIntegerBits below, see the below discussion of that template +// parameter ("The StateIntegerBits template parameter"). +// - The output of the internal fully-connected node is int16-quantized +// on the interval [-8, 8 * 32767/32768], the rationale for which is +// explained just below ("Why [-8, 8] for fully-connected output?"). +// +// +// === The StateIntegerBits template parameter === +// +// The StateIntegerBits template parameter controls the fixed-point format used +// to represent the internal memory of the LSTM cell (not the +// output-at-previous-timestep, the other internal state array). It's currently +// a template parameter so that the model can control that. The most typical +// value for StateIntegerBits is 4. Other plausible values are anywhere between +// 3 and 5. We might eventually standardize on a single supported value, e.g. 4, +// and drop that template parameter. The reason why it can't be a runtime +// parameter is that this controls the fixed-point format used, i.e. we need to +// generate actually different code based on it. In particular, we generate code +// for a fixed-point tanh() implementation for that format, which internally +// uses a fixed-point exp() implementation, which internally uses a +// barrel-shifter with a number of steps that depends on StateIntegerBits. +// Another consequence of that is that a higher value of StateIntegerBits +// results in a more expensive implementation (more barrel shifter steps +// needed). +// +// +// === Why [-8, 8] for fully-connected output? 
=== +// +// This array is only fed to Logistic and Tanh functions, for which +// the quantized implementation will want to use fixed-point arithmetic, +// requiring a power-of-two representation interval. Thus, we should right +// away quantize this array to a power-of-two interval; otherwise, +// implementation will need to rescale that, losing any benefit that a tighter +// representation interval might otherwise yield, while introducing some +// numerical error and computational overhead. +// +// Now, Logistic and Tanh +// are nearly constant (nearly equal to their horizontal asymptotes) +// outside of a small bounded interval around 0: +// +// Logistic(4) = 1 - 1.8e-2 Tanh(4) = 1 - 6.7e-4 +// Logistic(8) = 1 - 3.4e-4 Tanh(8) = 1 - 2.3e-7 +// Logistic(16) = 1 - 1.1e-7 Tanh(16) = 1 - 2.5e-14 +// +// From this, we see that clamping to [-4, 4] would be too inaccurate +// (the error of 1.8e-2 on Logistic would be felt even in 8bit precision) +// while clamping to [-16, 16] would make no difference even in float32. +// However, for a fixed-point implementation in 16-bit integers, using 5 +// integer bits to represent the [-16, 16] range would leave only 11 +// fractional bits, giving an increment of 2^-11 = 4.9e-4 between consecutive +// representable values. Notice that that is higher than the +// worst-case clamping error with clamping to [-8, 8]: 3.4e-4 for Logistic. +// Using [-8, 8] thus seems like the better compromise overall, enjoying +// an increment of 2.4e-4 between representable values and a worst-case +// clamping error of 3.4e-4, both better than the increment of 4.9e-4 with +// [-16, 16]. +// +// Moreover, all other things being equal, it is nice to choose the narrower +// representation range, as that makes the implementation of fixed-point +// math functions a little cheaper (each integer bit requires an additional +// barrel-shifter step in the implementation of exp(-x)). That is further +// reason to prefer [-8, 8] over [-16, 16]. 
The choice of [-16, 16] would make +// sense for 32-bit float or 32-bit fixed-point quantization, but we are +// aiming for 16-bit fixed-point quantization of these internal nodes here. +// +template +void LstmCell(const uint8* input_data_uint8, const Dims<4>& input_dims, + const uint8* prev_activ_data_uint8, + const Dims<4>& prev_activ_dims, const uint8* weights_data_uint8, + const Dims<4>& weights_dims, const int32* bias_data_int32, + const Dims<4>& bias_dims, const int16* prev_state_data_int16, + const Dims<4>& prev_state_dims, int16* output_state_data_int16, + const Dims<4>& output_state_dims, uint8* output_activ_data_uint8, + const Dims<4>& output_activ_dims, uint8* concat_temp_data_uint8, + const Dims<4>& concat_temp_dims, int16* activ_temp_data_int16, + const Dims<4>& activ_temp_dims, int32 weights_zero_point, + int32 accum_multiplier, int accum_shift) { + // Gather dimensions information, and perform consistency checks. + const int batches = + MatchingArraySize(input_dims, 3, prev_activ_dims, 3, prev_state_dims, 3, + output_state_dims, 3, output_activ_dims, 3); + const int height = + MatchingArraySize(input_dims, 2, prev_activ_dims, 2, prev_state_dims, 2, + output_state_dims, 2, output_activ_dims, 2); + const int width = + MatchingArraySize(input_dims, 1, prev_activ_dims, 1, prev_state_dims, 1, + output_state_dims, 1, output_activ_dims, 1); + TFLITE_CHECK_EQ(ArraySize(weights_dims, 2), 1); + TFLITE_CHECK_EQ(ArraySize(weights_dims, 3), 1); + const int input_depth = ArraySize(input_dims, 0); + const int prev_activ_depth = ArraySize(prev_activ_dims, 0); + const int total_input_depth = prev_activ_depth + input_depth; + TFLITE_CHECK_EQ(ArraySize(weights_dims, 0), total_input_depth); + TFLITE_CHECK_EQ(MatchingArraySize(bias_dims, 1, bias_dims, 2, bias_dims, 3), + 1); + const int intern_activ_depth = + MatchingArraySize(weights_dims, 1, bias_dims, 0); + TFLITE_CHECK_EQ(intern_activ_depth % 4, 0); + const int output_depth = + MatchingArraySize(prev_state_dims, 0, 
prev_activ_dims, 0, + output_state_dims, 0, output_activ_dims, 0); + TFLITE_CHECK_EQ(output_depth, intern_activ_depth / 4); + const int fc_batches = ArraySize(activ_temp_dims, 1) * + ArraySize(activ_temp_dims, 2) * + ArraySize(activ_temp_dims, 3); + const int fc_output_depth = + MatchingArraySize(weights_dims, 1, activ_temp_dims, 0); + const int fc_accum_depth = ArraySize(weights_dims, 0); + TFLITE_CHECK_EQ(fc_output_depth, 4 * output_depth); + + // Depth-concatenate prev_activ and input data together. + uint8 const* concat_input_arrays_data[2] = {input_data_uint8, + prev_activ_data_uint8}; + Dims<4> const* concat_input_arrays_dims[2] = {&input_dims, &prev_activ_dims}; + Concatenation( + 0, concat_input_arrays_data, concat_input_arrays_dims, 2, + concat_temp_data_uint8, concat_temp_dims); + + // Implementation of the fully connected node inside the LSTM cell. + // The operands are 8-bit integers, the accumulators are internally 32bit + // integers, and the output is 16-bit fixed-point with 3 integer bits so + // the output range is [-2^3, 2^3] == [-8, 8]. The rationale for that + // is explained in the function comment above. + for (int b = 0; b < fc_batches; ++b) { + for (int out_c = 0; out_c < fc_output_depth; ++out_c) { + // Internal accumulation. + // Initialize accumulator with the bias-value. + int32 accum = bias_data_int32[out_c]; + // Accumulation loop. + for (int d = 0; d < fc_accum_depth; ++d) { + int16 input_val = concat_temp_data_uint8[b * fc_accum_depth + d] - 128; + int16 weights_val = + weights_data_uint8[out_c * fc_accum_depth + d] - weights_zero_point; + accum += input_val * weights_val; + } + // Down-scale the final int32 accumulator to the scale used by our + // (16-bit, using 3 integer bits) fixed-point format. The quantized + // multiplier and shift here have been pre-computed offline + // (e.g. by toco). 
+ accum = + MultiplyByQuantizedMultiplier(accum, accum_multiplier, accum_shift); + // Saturate, cast to int16, and store to the temporary activations array. + accum = std::max(-32768, std::min(32767, accum)); + activ_temp_data_int16[out_c + fc_output_depth * b] = accum; + } + } + + // Rest of the LSTM cell: tanh and logistic math functions, and some adds + // and muls, all done in 16-bit fixed-point. + const int outer_size = batches * width * height; + for (int b = 0; b < outer_size; ++b) { + for (int c = 0; c < output_depth; ++c) { + // Define the fixed-point data types that we will use here. All use + // int16 as the underlying integer type i.e. all are 16-bit fixed-point. + // They only differ by the number of integral vs. fractional bits, + // determining the range of values that they can represent. + // + // F0 uses 0 integer bits, range [-1, 1]. + // This is the return type of math functions such as tanh, logistic, + // whose range is in [-1, 1]. + using F0 = gemmlowp::FixedPoint; + // F3 uses 3 integer bits, range [-8, 8]. + // This is the range of the previous fully-connected node's output, + // which is our input here. + using F3 = gemmlowp::FixedPoint; + // FS uses StateIntegerBits integer bits, range [-2^StateIntegerBits, + // 2^StateIntegerBits]. It's used to represent the internal state, whose + // number of integer bits is currently dictated by the model. See comment + // on the StateIntegerBits template parameter above. + using FS = gemmlowp::FixedPoint; + // Implementation of input gate, using fixed-point logistic function. + F3 input_gate_input = F3::FromRaw( + activ_temp_data_int16[b * fc_output_depth + 0 * output_depth + c]); + F0 input_gate_output = gemmlowp::logistic(input_gate_input); + // Implementation of input modulation gate, using fixed-point tanh + // function. 
+ F3 input_modulation_gate_input = F3::FromRaw( + activ_temp_data_int16[b * fc_output_depth + 1 * output_depth + c]); + F0 input_modulation_gate_output = + gemmlowp::tanh(input_modulation_gate_input); + // Implementation of forget gate, using fixed-point logistic function. + F3 forget_gate_input = F3::FromRaw( + activ_temp_data_int16[b * fc_output_depth + 2 * output_depth + c]); + F0 forget_gate_output = gemmlowp::logistic(forget_gate_input); + // Implementation of output gate, using fixed-point logistic function. + F3 output_gate_input = F3::FromRaw( + activ_temp_data_int16[b * fc_output_depth + 3 * output_depth + c]); + F0 output_gate_output = gemmlowp::logistic(output_gate_input); + // Implementation of internal multiplication nodes, still in fixed-point. + F0 input_times_input_modulation = + input_gate_output * input_modulation_gate_output; + FS prev_state = FS::FromRaw(prev_state_data_int16[b * output_depth + c]); + FS prev_state_times_forget_state = forget_gate_output * prev_state; + // Implementation of internal addition node, saturating. + FS new_state = gemmlowp::SaturatingAdd( + gemmlowp::Rescale(input_times_input_modulation), + prev_state_times_forget_state); + // Implementation of last internal tanh node, still in fixed-point. + F0 output_activ_int16 = output_gate_output * gemmlowp::tanh(new_state); + // Store the new internal state back to memory, as 16-bit integers. + output_state_data_int16[b * output_depth + c] = new_state.raw(); + // Down-scale the output activations to 8-bit integers, saturating, + // and store back to memory. 
+ int16 rescaled_output_activ = + gemmlowp::RoundingDivideByPOT(output_activ_int16.raw(), 8); + int16 clamped_output_activ = + std::max(-128, std::min(127, rescaled_output_activ)); + output_activ_data_uint8[b * output_depth + c] = + 128 + clamped_output_activ; + } + } +} + template void TensorFlowSplit(const Scalar* input_data, const Dims<4>& input_dims, int outputs_count, Scalar* const* output_data, @@ -1957,6 +2275,54 @@ inline void Tanh(const float* input_data, const Dims<4>& input_dims, } } +inline void Tanh(const uint8* input_data, const Dims<4>& input_dims, + int32 input_zero_point, int32 input_range_radius, + int32 input_multiplier, int input_left_shift, + uint8* output_data, const Dims<4>& output_dims) { + const int32 output_zero_point = 128; + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = MatchingArraySize(input_dims, 0, output_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + for (int c = 0; c < depth; ++c) { + const uint8 input_val_u8 = input_data[Offset(input_dims, c, x, y, b)]; + const int32 input_val_centered = + static_cast(input_val_u8) - input_zero_point; + uint8 output_val; + if (input_val_centered <= -input_range_radius) { + output_val = 0; + } else if (input_val_centered >= input_range_radius) { + output_val = 255; + } else { + const int32 input_val_rescaled = + MultiplyByQuantizedMultiplierGreaterThanOne( + input_val_centered, input_multiplier, input_left_shift); + using FixedPoint4 = gemmlowp::FixedPoint; + using FixedPoint0 = gemmlowp::FixedPoint; + const FixedPoint4 input_val_f4 = + FixedPoint4::FromRaw(input_val_rescaled); + const FixedPoint0 output_val_f0 = gemmlowp::tanh(input_val_f4); + + using gemmlowp::RoundingDivideByPOT; + int32 output_val_s32 = 
RoundingDivideByPOT(output_val_f0.raw(), 24); + output_val_s32 += output_zero_point; + if (output_val_s32 == 256) { + output_val_s32 = 255; + } + TFLITE_DCHECK_GE(output_val_s32, 0); + TFLITE_DCHECK_LE(output_val_s32, 255); + output_val = static_cast(output_val_s32); + } + output_data[Offset(output_dims, c, x, y, b)] = output_val; + } + } + } + } +} + inline void Dequantize(const uint8* input_data, const Dims<4>& input_dims, int32 zero_point, double scale, float* output_data, const Dims<4>& output_dims) { @@ -2116,7 +2482,7 @@ inline void Gather(const T* input_data, const Dims<4>& input_dims, inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, const int32* output_size_data, const Dims<4>& output_size_dims, float* output_data, - const Dims<4>& output_dims) { + const Dims<4>& output_dims, bool align_corners) { int32 batches = MatchingArraySize(input_dims, 3, output_dims, 3); int32 input_height = ArraySize(input_dims, 2); int32 input_width = ArraySize(input_dims, 1); @@ -2130,6 +2496,12 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, int32 output_width = output_size_data[Offset(output_size_dims, 1, 0, 0, 0)]; float height_scale = static_cast(input_height) / output_height; float width_scale = static_cast(input_width) / output_width; + if (align_corners && output_height > 1) { + height_scale = static_cast(input_height - 1) / (output_height - 1); + } + if (align_corners && output_width > 1) { + width_scale = static_cast(input_width - 1) / (output_width - 1); + } for (int b = 0; b < batches; ++b) { for (int y = 0; y < output_height; ++y) { @@ -2157,6 +2529,15 @@ inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, } } +// legacy, for compatibility with old checked-in code +inline void ResizeBilinear(const float* input_data, const Dims<4>& input_dims, + const int32* output_size_data, + const Dims<4>& output_size_dims, float* output_data, + const Dims<4>& output_dims) { + 
ResizeBilinear(input_data, input_dims, output_size_data, output_size_dims, + output_data, output_dims, /*align_corners=*/false); +} + template inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, const int32* block_shape_data, @@ -2183,10 +2564,11 @@ inline void SpaceToBatchND(const T* input_data, const Dims<4>& input_dims, for (int out_h = 0; out_h < output_height; ++out_h) { for (int out_w = 0; out_w < output_width; ++out_w) { T* out = output_data + Offset(output_dims, 0, out_w, out_h, out_b); - if (out_h * block_shape_height < padding_top || - out_h * block_shape_height >= padding_top + input_height || - out_w * block_shape_width < padding_left || - out_w * block_shape_width >= padding_left + input_width) { + if (out_h * block_shape_height + shift_h < padding_top || + out_h * block_shape_height + shift_h >= + padding_top + input_height || + out_w * block_shape_width + shift_w < padding_left || + out_w * block_shape_width + shift_w >= padding_left + input_width) { memset(out, 0, depth * sizeof(T)); } else { const T* in = @@ -2275,27 +2657,60 @@ inline void Pad(const T* input_data, const Dims<4>& input_dims, } } +inline bool LoopCondition(int index, int stop, int stride) { + return stride > 0 ? index < stop : index > stop; +} + +inline int StartIndex(int start, int stride, int dim, bool masked) { + return masked ? (stride > 0 ? 0 : dim - 1) : start; +} + +inline int StopIndex(int start, int stop, int stride, int dim, bool masked, + bool shrink_axis_masked) { + return shrink_axis_masked ? stride > 0 ? start + 1 : start - 1 + : masked ? (stride > 0 ? dim : -1) : stop; +} + template inline void StridedSlice(const T* input_data, const Dims<4>& input_dims, - int begin_mask, int end_mask, + int begin_mask, int end_mask, int shrink_axis_mask, const std::vector& starts, const std::vector& stops, const std::vector& strides, T* output_data, const Dims<4>& output_dims) { - const int start_b = (begin_mask & 8) ? 
0 : starts[3]; - const int stop_b = (end_mask & 8) ? input_dims.sizes[3] : stops[3]; - const int start_h = (begin_mask & 4) ? 0 : starts[2]; - const int stop_h = (end_mask & 4) ? input_dims.sizes[2] : stops[2]; - const int start_w = (begin_mask & 2) ? 0 : starts[1]; - const int stop_w = (end_mask & 2) ? input_dims.sizes[1] : stops[1]; - const int start_d = (begin_mask & 1) ? 0 : starts[0]; - const int stop_d = (end_mask & 1) ? input_dims.sizes[0] : stops[0]; + TFLITE_DCHECK_EQ(starts.size(), 4); + TFLITE_DCHECK_EQ(stops.size(), 4); + TFLITE_DCHECK_EQ(strides.size(), 4); + const int start_b = + StartIndex(starts[3], strides[3], input_dims.sizes[3], begin_mask & 8); + const int stop_b = + StopIndex(start_b, stops[3], strides[3], input_dims.sizes[3], + end_mask & 8, shrink_axis_mask & 8); + const int start_h = + StartIndex(starts[2], strides[2], input_dims.sizes[2], begin_mask & 4); + const int stop_h = + StopIndex(start_h, stops[2], strides[2], input_dims.sizes[2], + end_mask & 4, shrink_axis_mask & 4); + const int start_w = + StartIndex(starts[1], strides[1], input_dims.sizes[1], begin_mask & 2); + const int stop_w = + StopIndex(start_w, stops[1], strides[1], input_dims.sizes[1], + end_mask & 2, shrink_axis_mask & 2); + const int start_d = + StartIndex(starts[0], strides[0], input_dims.sizes[0], begin_mask & 1); + const int stop_d = + StopIndex(start_d, stops[0], strides[0], input_dims.sizes[0], + end_mask & 1, shrink_axis_mask & 1); T* out_ptr = output_data; - for (int in_b = start_b; in_b < stop_b; in_b += strides[3]) { - for (int in_h = start_h; in_h < stop_h; in_h += strides[2]) { - for (int in_w = start_w; in_w < stop_w; in_w += strides[1]) { - for (int in_d = start_d; in_d < stop_d; in_d += strides[0]) { + for (int in_b = start_b; LoopCondition(in_b, stop_b, strides[3]); + in_b += strides[3]) { + for (int in_h = start_h; LoopCondition(in_h, stop_h, strides[2]); + in_h += strides[2]) { + for (int in_w = start_w; LoopCondition(in_w, stop_w, strides[1]); + in_w 
+= strides[1]) { + for (int in_d = start_d; LoopCondition(in_d, stop_d, strides[0]); + in_d += strides[0]) { *out_ptr++ = input_data[Offset(input_dims, in_d, in_w, in_h, in_b)]; } } @@ -2303,6 +2718,18 @@ inline void StridedSlice(const T* input_data, const Dims<4>& input_dims, } } +template +inline void StridedSlice(const T* input_data, const Dims<4>& input_dims, + int begin_mask, int end_mask, + const std::vector& starts, + const std::vector& stops, + const std::vector& strides, T* output_data, + const Dims<4>& output_dims) { + StridedSlice(input_data, input_dims, begin_mask, end_mask, + /*shrink_axis_mask=*/0, starts, stops, strides, output_data, + output_dims); +} + template inline void Slice(const T* input_data, const Dims<4>& input_dims, const std::vector& begin, const std::vector& size, @@ -2335,6 +2762,72 @@ inline void Slice(const T* input_data, const Dims<4>& input_dims, } } +template +inline void Exp(const T* input_data, const size_t num_elements, + T* output_data) { + for (size_t idx = 0; idx < num_elements; ++idx) { + output_data[idx] = exp(input_data[idx]); + } +} + +template +inline void Mean(T* input_data, const int* input_dims, const int input_num_dims, + T* output_data, const int* output_dims, + const int output_num_dims, const int* axis, + const int num_axis_dimensions, bool keep_dims, int* temp_index, + int* resolved_axis) { + // resets output data. + size_t num_outputs = 1; + for (int idx = 0; idx < output_num_dims; ++idx) { + num_outputs *= static_cast(output_dims[idx]); + } + for (size_t idx = 0; idx < num_outputs; ++idx) { + output_data[idx] = 0; + } + // resets temp index. + for (int idx = 0; idx < input_num_dims; ++idx) { + temp_index[idx] = 0; + } + // resolves axis. 
+ int num_resolved_axis = 0; + for (int idx = 0; idx < num_axis_dimensions; ++idx) { + int current = axis[idx]; + TFLITE_DCHECK(current < input_num_dims && current + input_num_dims >= 0); + if (current < 0) { + current += input_num_dims; + } + bool is_dup = false; + for (int j = 0; j < num_resolved_axis; ++j) { + if (resolved_axis[j] == current) { + is_dup = true; + break; + } + } + if (!is_dup) { + resolved_axis[num_resolved_axis++] = current; + } + } + // iterates through input_data. + for (bool has_next = true; has_next; + has_next = NextIndex(input_num_dims, input_dims, temp_index)) { + size_t input_offset = + ReducedOutputOffset(input_num_dims, input_dims, temp_index, 0, nullptr); + size_t output_offset = + ReducedOutputOffset(input_num_dims, input_dims, temp_index, + num_resolved_axis, resolved_axis); + output_data[output_offset] += input_data[input_offset]; + } + // takes average by num of elements added to get mean. + size_t num_elements_in_axis = 1; + for (int idx = 0; idx < num_resolved_axis; ++idx) { + num_elements_in_axis *= static_cast(input_dims[resolved_axis[idx]]); + } + for (size_t idx = 0; idx < num_outputs; ++idx) { + output_data[idx] = static_cast(static_cast(output_data[idx]) / + num_elements_in_axis); + } +} + template inline void Mean(const T* input_data, const Dims<4>& input_dims, const std::vector& reduction_indices, T* output_data, @@ -2449,7 +2942,70 @@ void TensorFlowMaximum(const T* input1_data, const Dims<4>& input1_dims, } } +template +void ArgMax(const T3* axis, const T1* input_data, const Dims<4>& input_dims, + T2* output_data, const Dims<4>& output_dims) { + // The current ArgMax implemention can only determine the index of the maximum + // value in the last dimension. So the axis argument is ignored. + TFLITE_DCHECK_EQ(axis[0], 3); + + // For ArgMax, the number of output dimensions = (number of input dimensions - + // 1). For the sake of simplicity, the output dimensions are equal to the + // input dimensions here. 
We enforce the constraint that the last dimension + // must always be 1. + TFLITE_DCHECK_EQ(ArraySize(output_dims, 0), 1); + const int batches = MatchingArraySize(input_dims, 3, output_dims, 3); + const int height = MatchingArraySize(input_dims, 2, output_dims, 2); + const int width = MatchingArraySize(input_dims, 1, output_dims, 1); + const int depth = ArraySize(input_dims, 0); + for (int b = 0; b < batches; ++b) { + for (int y = 0; y < height; ++y) { + for (int x = 0; x < width; ++x) { + auto max_value = input_data[Offset(input_dims, 0, x, y, b)]; + int max_index = 0; + for (int d = 1; d < depth; ++d) { + const auto& curr_value = input_data[Offset(input_dims, d, x, y, b)]; + if (curr_value > max_value) { + max_value = curr_value; + max_index = d; + } + } + output_data[Offset(output_dims, 0, x, y, b)] = max_index; + } + } + } +} + +template +void Transpose(const T* input, const Dims<4>& input_dims, T* output, + const Dims<4>& output_dims, int* permuted_axes) { + int out_sizes[4]; + // Compute the inverse permutation array so we can do an output centered + // transpose. Also, check to make sure output_dims is matching input_dims. + for (int k = 0; k < 4; k++) { + out_sizes[k] = + MatchingArraySize(input_dims, permuted_axes[k], output_dims, k); + } + + // Naive transpose loop (iterate on output index and compute input index). + int o[4]; // loop index (on output). 
+ int i[4]; + for (o[3] = 0; o[3] < out_sizes[3]; o[3]++) { + i[permuted_axes[3]] = o[3]; + for (o[2] = 0; o[2] < out_sizes[2]; o[2]++) { + i[permuted_axes[2]] = o[2]; + for (o[1] = 0; o[1] < out_sizes[1]; o[1]++) { + i[permuted_axes[1]] = o[1]; + for (o[0] = 0; o[0] < out_sizes[0]; o[0]++) { + i[permuted_axes[0]] = o[0]; + output[Offset(output_dims, o)] = input[Offset(input_dims, i)]; + } + } + } + } +} + } // namespace reference_ops } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_REFERENCE_REFERENCE_OPS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/round.h b/tensorflow/contrib/lite/kernels/internal/round.h index 38525b0e208b852343849096ac68cbfc9ef3e389..f299d0bd8733dc603c4950091c8ac3d7890548a7 100644 --- a/tensorflow/contrib/lite/kernels/internal/round.h +++ b/tensorflow/contrib/lite/kernels/internal/round.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_ROUND_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_ROUND_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_ROUND_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_ROUND_H_ #include @@ -36,4 +36,4 @@ inline T TfLiteRound(const T x) { } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_ROUND_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_ROUND_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/tensor.h b/tensorflow/contrib/lite/kernels/internal/tensor.h index ee4111e0416560d94d513c528971bdf3bf819662..62e38e0d4c3e023d0ed2242fc9438b096b86dc59 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor.h +++ b/tensorflow/contrib/lite/kernels/internal/tensor.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_ #include #include "tensorflow/contrib/lite/context.h" @@ -41,8 +41,7 @@ inline int32_t* GetTensorData(TfLiteTensor* tensor) { template <> inline int64_t* GetTensorData(TfLiteTensor* tensor) { - return tensor != nullptr ? reinterpret_cast(tensor->data.raw) - : nullptr; + return tensor != nullptr ? 
tensor->data.i64 : nullptr; } inline int RemapDim(int max_dimensions, int d) { @@ -82,6 +81,51 @@ inline Dims<4> GetTensorDims(const TfLiteTensor* tensor) { return GetTensorDims(dims->data, dims->size); } +// A list of tensors in a format that can be used by kernels like split and +// concatenation. +template +class VectorOfTensors { + public: + // Build with the tensors in 'tensor_list'. + VectorOfTensors(const TfLiteContext& context, + const TfLiteIntArray& tensor_list) { + int num_tensors = tensor_list.size; + + all_data_.reserve(num_tensors); + all_dims_.reserve(num_tensors); + all_dims_ptr_.reserve(num_tensors); + + for (int i = 0; i < num_tensors; ++i) { + TfLiteTensor* t = &context.tensors[tensor_list.data[i]]; + all_data_.push_back(GetTensorData(t)); + all_dims_.push_back(GetTensorDims(t)); + } + + // Taking the pointer from inside a std::vector is only OK if the vector is + // never modified, so we populate all_dims in the previous loop and then we + // are free to grab iterators here. + for (int i = 0; i < num_tensors; ++i) { + all_dims_ptr_.push_back(&all_dims_[i]); + } + } + // Return a pointer to the data pointers of all tensors in the list. For + // example: + // float* const* f = v.data(); + // f[0][1] is the second element of the first tensor. + T* const* data() const { return all_data_.data(); } + + // Return a pointer the dim pointers of all tensors in the list. For + // example: + // const Dims<4>* const* d = v.dims(); + // dims[1] are the dimensions of the second tensor in the list. 
+ const Dims<4>* const* dims() const { return all_dims_ptr_.data(); } + + private: + std::vector all_data_; + std::vector> all_dims_; + std::vector*> all_dims_ptr_; +}; + } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.cc b/tensorflow/contrib/lite/kernels/internal/tensor_utils.cc index 904a97803a6a9ba369c1e64c711b12d19ffc10c4..f4181b18a8f46fd9bef4b81a210a6b8134a4e9d0 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor_utils.cc +++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h" +#include "tensorflow/contrib/lite/kernels/internal/common.h" #ifndef USE_NEON #if defined(__ARM_NEON__) || defined(__ARM_NEON) diff --git a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h index 0e69ef5982f01e364d865684652d1dfecab6fee3..40d144979b2f965725db86ff311e90f39438802f 100644 --- a/tensorflow/contrib/lite/kernels/internal/tensor_utils.h +++ b/tensorflow/contrib/lite/kernels/internal/tensor_utils.h @@ -12,15 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_ #include "tensorflow/contrib/lite/builtin_op_data.h" namespace tflite { namespace tensor_utils { -// Limit a float input f betweeen +abs_limit and -abs_limit. +// Limit a float input f between +abs_limit and -abs_limit. float Clip(float f, float abs_limit); // Multiply a matrix by a batch vector, and store results in a batch-size @@ -113,4 +113,4 @@ void ReductionSumVector(const float* input_vector, float* output_vector, } // namespace tensor_utils } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TENSOR_UTILS_H_ diff --git a/tensorflow/contrib/lite/kernels/internal/types.h b/tensorflow/contrib/lite/kernels/internal/types.h index 07f1cb40045fff3ae47ed4efa6ec43b0cb88a0a7..afe131b06ec41201395e80aa5415fd7db990f8d4 100644 --- a/tensorflow/contrib/lite/kernels/internal/types.h +++ b/tensorflow/contrib/lite/kernels/internal/types.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ #include "tensorflow/contrib/lite/kernels/internal/compatibility.h" @@ -27,6 +27,58 @@ struct Dims { int strides[N]; }; +// Gets next index to iterate through a multidimensional array. +inline bool NextIndex(const int num_dims, const int* dims, int* current) { + TFLITE_DCHECK_GT(num_dims, 0); + TFLITE_DCHECK(dims != nullptr); + TFLITE_DCHECK(current != nullptr); + int carry = 1; + for (int idx = num_dims - 1; idx >= 0; --idx) { + int current_val = current[idx] + carry; + TFLITE_DCHECK_GE(dims[idx], current_val); + if (dims[idx] == current_val) { + current[idx] = 0; + } else { + current[idx] = current_val; + carry = 0; + break; + } + } + return (carry == 0); +} + +// Gets offset of index if reducing on axis. When reducing, the flattened offset +// will not change, if the input index changes on the given axis. For example, +// if you have a 3D tensor and you are reducing to 2D by eliminating axis 0, +// then index (0, 1, 2) and index (1, 1, 2) will map to the same flattened +// offset. +// TODO(kanlig): uses Dims to represent dimensions. 
+inline size_t ReducedOutputOffset(const int num_dims, const int* dims, + const int* index, const int num_axis, + const int* axis) { + TFLITE_DCHECK_GT(num_dims, 0); + TFLITE_DCHECK(dims != nullptr); + TFLITE_DCHECK(index != nullptr); + size_t offset = 0; + for (int idx = 0; idx < num_dims; ++idx) { + // if we need to skip this axis + bool is_axis = false; + if (axis != nullptr) { + for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) { + if (idx == axis[axis_idx]) { + is_axis = true; + break; + } + } + } + if (!is_axis) { + offset = offset * static_cast(dims[idx]) + + static_cast(index[idx]); + } + } + return offset; +} + inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) { TFLITE_DCHECK(i0 >= 0 && i0 < dims.sizes[0]); TFLITE_DCHECK(i1 >= 0 && i1 < dims.sizes[1]); @@ -36,6 +88,10 @@ inline int Offset(const Dims<4>& dims, int i0, int i1, int i2, int i3) { i3 * dims.strides[3]; } +inline int Offset(const Dims<4>& dims, int* index) { + return Offset(dims, index[0], index[1], index[2], index[3]); +} + // Get array size, DCHECKing that the dim index is in range. template int ArraySize(const Dims& array, int index) { @@ -78,4 +134,4 @@ bool IsPackedWithoutStrides(const Dims& dims) { } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_TYPES_H_ diff --git a/tensorflow/contrib/lite/kernels/kernel_util.cc b/tensorflow/contrib/lite/kernels/kernel_util.cc index b0546c00cf977af5f722a802866448b0cb293b8d..955e8c5764c6adad37a0009f4ddf8accb437b174 100644 --- a/tensorflow/contrib/lite/kernels/kernel_util.cc +++ b/tensorflow/contrib/lite/kernels/kernel_util.cc @@ -13,8 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ #include "tensorflow/contrib/lite/kernels/kernel_util.h" + #include #include +#include + #include "tensorflow/contrib/lite/kernels/internal/round.h" namespace tflite { @@ -84,4 +87,27 @@ void CalculateActivationRangeFloat(TfLiteFusedActivation activation, } } +bool HaveSameShapes(TfLiteTensor* input1, TfLiteTensor* input2) { + return TfLiteIntArrayEqual(input1->dims, input2->dims); +} + +TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context, + TfLiteTensor* input1, + TfLiteTensor* input2, + TfLiteIntArray** output_shape) { + int64_t dims1 = NumDimensions(input1); + int64_t dims2 = NumDimensions(input2); + int64_t out_dims = std::max(dims1, dims2); + std::unique_ptr shape( + TfLiteIntArrayCreate(out_dims), TfLiteIntArrayFree); + for (int i = 0; i < out_dims; ++i) { + int64_t d1 = i >= dims1 ? 1 : SizeOfDimension(input1, dims1 - i - 1); + int64_t d2 = i >= dims2 ? 1 : SizeOfDimension(input2, dims2 - i - 1); + TF_LITE_ENSURE(context, d1 == d2 || d1 == 1 || d2 == 1); + shape->data[out_dims - i - 1] = std::max(d1, d2); + } + *output_shape = shape.release(); + return kTfLiteOk; +} + } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/kernel_util.h b/tensorflow/contrib/lite/kernels/kernel_util.h index 25556ae4567aca45b3bfe4ba02b1cb58331d239d..28f53b9fbbc5620f2fab5c73e40bed8af4af5f1e 100644 --- a/tensorflow/contrib/lite/kernels/kernel_util.h +++ b/tensorflow/contrib/lite/kernels/kernel_util.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_ #include "tensorflow/contrib/lite/builtin_op_data.h" #include "tensorflow/contrib/lite/context.h" @@ -35,6 +35,14 @@ inline TfLiteTensor* GetOutput(TfLiteContext* context, TfLiteNode* node, inline int NumInputs(const TfLiteNode* node) { return node->inputs->size; } inline int NumOutputs(const TfLiteNode* node) { return node->outputs->size; } +inline int64_t NumElements(const TfLiteTensor* t) { + int64_t count = 1; + for (int i = 0; i < NumDimensions(t); ++i) { + count *= SizeOfDimension(t, i); + } + return count; +} + inline TfLiteTensor* GetOptionalInputTensor(TfLiteContext* context, const TfLiteNode* node, int index) { const bool use_tensor = node->inputs->data[index] != kOptionalTensor; @@ -44,6 +52,25 @@ inline TfLiteTensor* GetOptionalInputTensor(TfLiteContext* context, return nullptr; } +// Determines whether tensor is constant. +inline bool IsConstantTensor(TfLiteTensor* tensor) { + return tensor->allocation_type == kTfLiteMmapRo; +} + +// Determines whether tensor is dynamic. Note that a tensor can be non-const and +// not dynamic. This function specificially checks for a dynamic tensor. +inline bool IsDynamicTensor(TfLiteTensor* tensor) { + return tensor->allocation_type == kTfLiteDynamic; +} + +// Sets tensor to dynamic. +inline void SetTensorToDynamic(TfLiteTensor* tensor) { + if (tensor->allocation_type != kTfLiteDynamic) { + tensor->allocation_type = kTfLiteDynamic; + tensor->data.raw = nullptr; + } +} + // Calculates the multiplication factor for a quantized convolution (or // quantized depthwise convolution) involving the given tensors. Returns an // error if the scales of the tensors are not compatible. 
@@ -60,6 +87,15 @@ void CalculateActivationRangeFloat(TfLiteFusedActivation activation, float* activation_min, float* activation_max); +// Return true if the given tensors have the same shape. +bool HaveSameShapes(TfLiteTensor* input1, TfLiteTensor* input2); + +// Calculate the output_shape that is necessary for element-wise operations +// with broadcasting involving the two input tensors. +TfLiteStatus CalculateShapeForBroadcast(TfLiteContext* context, + TfLiteTensor* input1, + TfLiteTensor* input2, + TfLiteIntArray** output_shape); } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_KERNEL_UTIL_H_ diff --git a/tensorflow/contrib/lite/kernels/kernel_util_test.cc b/tensorflow/contrib/lite/kernels/kernel_util_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..c65b68970f6853e17af3a70aad7a2bc982a1ee60 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/kernel_util_test.cc @@ -0,0 +1,152 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/kernels/kernel_util.h" + +#include +#include +#include "tensorflow/contrib/lite/testing/util.h" + +namespace tflite { +namespace { + +void ReportError(TfLiteContext* context, const char* format, ...) 
{} + +class KernelUtilTest : public ::testing::Test { + public: + KernelUtilTest() { + context_.ReportError = ReportError; + + tensor1_.dims = nullptr; + tensor2_.dims = nullptr; + tensor1_.allocation_type = kTfLiteMmapRo; + tensor2_.allocation_type = kTfLiteMmapRo; + } + ~KernelUtilTest() { + TfLiteTensorFree(&tensor1_); + TfLiteTensorFree(&tensor2_); + } + + void SetShape(TfLiteTensor* tensor, std::initializer_list dims) { + TfLiteTensorFree(tensor); + tensor->dims = TfLiteIntArrayCreate(dims.size()); + int i = 0; + for (int d : dims) { + tensor->dims->data[i] = d; + ++i; + } + } + + std::vector GetShape(TfLiteIntArray* dims) { + std::vector result; + for (int i = 0; i < dims->size; ++i) { + result.push_back(dims->data[i]); + } + return result; + } + + protected: + TfLiteContext context_; + TfLiteTensor tensor1_; + TfLiteTensor tensor2_; +}; + +TEST_F(KernelUtilTest, SameShapeEmpty) { + EXPECT_TRUE(HaveSameShapes(&tensor1_, &tensor2_)); + + SetShape(&tensor1_, {1, 2, 3}); + EXPECT_FALSE(HaveSameShapes(&tensor1_, &tensor2_)); + + SetShape(&tensor2_, {1, 2}); + EXPECT_FALSE(HaveSameShapes(&tensor1_, &tensor2_)); + + SetShape(&tensor2_, {1, 2, 3, 4}); + EXPECT_FALSE(HaveSameShapes(&tensor1_, &tensor2_)); + + SetShape(&tensor2_, {1, 2, 3}); + EXPECT_TRUE(HaveSameShapes(&tensor1_, &tensor2_)); + + SetShape(&tensor2_, {}); + EXPECT_FALSE(HaveSameShapes(&tensor1_, &tensor2_)); + + SetShape(&tensor1_, {}); + EXPECT_TRUE(HaveSameShapes(&tensor1_, &tensor2_)); +} + +TEST_F(KernelUtilTest, BroadcastShapeIncompatibleDim) { + TfLiteIntArray* output = nullptr; + SetShape(&tensor1_, {1, 2}); + SetShape(&tensor2_, {1, 3}); + EXPECT_NE(kTfLiteOk, CalculateShapeForBroadcast(&context_, &tensor1_, + &tensor2_, &output)); + EXPECT_EQ(output, nullptr); +} + +TEST_F(KernelUtilTest, BroadcastShapeOnes) { + TfLiteIntArray* output = nullptr; + SetShape(&tensor1_, {1, 1}); + SetShape(&tensor2_, {1, 3}); + EXPECT_EQ(kTfLiteOk, CalculateShapeForBroadcast(&context_, &tensor1_, + &tensor2_, 
&output)); + TfLiteIntArrayFree(output); + + SetShape(&tensor1_, {1, 2}); + SetShape(&tensor2_, {1, 1}); + EXPECT_EQ(kTfLiteOk, CalculateShapeForBroadcast(&context_, &tensor1_, + &tensor2_, &output)); + TfLiteIntArrayFree(output); +} + +TEST_F(KernelUtilTest, BroadcastShapeScalars) { + TfLiteIntArray* output = nullptr; + SetShape(&tensor1_, {1, 2}); + SetShape(&tensor2_, {}); + EXPECT_EQ(kTfLiteOk, CalculateShapeForBroadcast(&context_, &tensor1_, + &tensor2_, &output)); + EXPECT_THAT(GetShape(output), ::testing::ElementsAre(1, 2)); + TfLiteIntArrayFree(output); + + SetShape(&tensor1_, {}); + SetShape(&tensor2_, {2}); + EXPECT_EQ(kTfLiteOk, CalculateShapeForBroadcast(&context_, &tensor1_, + &tensor2_, &output)); + EXPECT_THAT(GetShape(output), ::testing::ElementsAre(2)); + TfLiteIntArrayFree(output); +} + +TEST_F(KernelUtilTest, BroadcastShapeDifferentSizes) { + TfLiteIntArray* output = nullptr; + SetShape(&tensor1_, {1, 2}); + SetShape(&tensor2_, {3, 1, 1}); + EXPECT_EQ(kTfLiteOk, CalculateShapeForBroadcast(&context_, &tensor1_, + &tensor2_, &output)); + EXPECT_THAT(GetShape(output), ::testing::ElementsAre(3, 1, 2)); + TfLiteIntArrayFree(output); + + SetShape(&tensor1_, {1, 2, 3, 4}); + SetShape(&tensor2_, {1, 3, 1}); + EXPECT_EQ(kTfLiteOk, CalculateShapeForBroadcast(&context_, &tensor1_, + &tensor2_, &output)); + EXPECT_THAT(GetShape(output), ::testing::ElementsAre(1, 2, 3, 4)); + TfLiteIntArrayFree(output); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/l2norm.cc b/tensorflow/contrib/lite/kernels/l2norm.cc index f43aa372b6398a38e57dd38f3d7c7db2bd3aefc1..ee8bfe56d95e9f383ef49b40b8f58b63d61da3e1 100644 --- a/tensorflow/contrib/lite/kernels/l2norm.cc +++ b/tensorflow/contrib/lite/kernels/l2norm.cc @@ -43,8 +43,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* 
input = GetInput(context, node, kInputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - // TODO(ahentz): Our current implementations rely on the inputs being 4D. - TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); + TF_LITE_ENSURE(context, NumDimensions(input) <= 4); // TODO(ahentz): Our current implementations only support float32. TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32); @@ -54,12 +53,7 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // activations. TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone); - TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); - output_size->data[0] = input->dims->data[0]; - output_size->data[1] = input->dims->data[1]; - output_size->data[2] = input->dims->data[2]; - output_size->data[3] = input->dims->data[3]; - + TfLiteIntArray* output_size = TfLiteIntArrayCopy(input->dims); return context->ResizeTensor(context, output, output_size); } diff --git a/tensorflow/contrib/lite/kernels/l2norm_test.cc b/tensorflow/contrib/lite/kernels/l2norm_test.cc index b1db89b8bd3474ac868d7215e4a0de12088c48ef..30e103f3303484c339ef98e6a68e0438291c102f 100644 --- a/tensorflow/contrib/lite/kernels/l2norm_test.cc +++ b/tensorflow/contrib/lite/kernels/l2norm_test.cc @@ -57,7 +57,7 @@ TEST(L2NormOpTest, SimpleTest) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/local_response_norm_test.cc b/tensorflow/contrib/lite/kernels/local_response_norm_test.cc index 63a8b0a3d0186def7da2c9f31481721f1a55281c..d75ce258a04c820d8f82735988c01d0154ef36f2 100644 --- a/tensorflow/contrib/lite/kernels/local_response_norm_test.cc +++ b/tensorflow/contrib/lite/kernels/local_response_norm_test.cc @@ -95,7 +95,7 @@ TEST(LocalResponseNormOpTest, SmallRadius) { } // namespace tflite int main(int argc, char** argv) { - // 
On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/lsh_projection_test.cc b/tensorflow/contrib/lite/kernels/lsh_projection_test.cc index 1011927848d586c8541fb694914b5eee123cb8dc..414d728dfc153058ec878d3c766f58e86815cd3f 100644 --- a/tensorflow/contrib/lite/kernels/lsh_projection_test.cc +++ b/tensorflow/contrib/lite/kernels/lsh_projection_test.cc @@ -117,7 +117,7 @@ TEST(LSHProjectionOpTest2, Sparse3DInputs) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/lstm_test.cc b/tensorflow/contrib/lite/kernels/lstm_test.cc index be4c7ddbf88fc902368cda13aff72f5aecb9dac4..c068286b0d84bcb51ebb0e239350a42863de6523 100644 --- a/tensorflow/contrib/lite/kernels/lstm_test.cc +++ b/tensorflow/contrib/lite/kernels/lstm_test.cc @@ -1081,8 +1081,7 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); - tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/mean.cc b/tensorflow/contrib/lite/kernels/mean.cc new file mode 100644 index 0000000000000000000000000000000000000000..aff19581ea56f94c08638b7b388ae181f566cf4f --- /dev/null +++ b/tensorflow/contrib/lite/kernels/mean.cc @@ -0,0 +1,233 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace mean { + +// This file has reference implementation of Mean. +enum KernelType { + kReference, +}; + +struct MeanContext { + MeanContext(TfLiteContext* context, TfLiteNode* node) { + params = reinterpret_cast(node->builtin_data); + input = GetInput(context, node, 0); + axis = GetInput(context, node, 1); + output = GetOutput(context, node, 0); + } + TfLiteMeanParams* params; + TfLiteTensor* input; + TfLiteTensor* axis; + TfLiteTensor* output; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + // Creates two temp tensors to store index and axis for internal + // implementation only. + auto* scratch_tensor_index = new int; + context->AddTensors(context, 2, scratch_tensor_index); + return scratch_tensor_index; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + +// Resizes the temp tensor that stores resolved axis. 
+TfLiteStatus ResizeTempAxis(TfLiteContext* context, MeanContext* op_context, + TfLiteTensor* resolved_axis) { + TfLiteIntArray* axis_size = TfLiteIntArrayCreate(1); + axis_size->data[0] = static_cast(NumElements(op_context->axis)); + return context->ResizeTensor(context, resolved_axis, axis_size); +} + +// Resizes output array based on the input size and resolved axis. +TfLiteStatus ResizeOutputTensor(TfLiteContext* context, + MeanContext* op_context) { + size_t num_axis = NumElements(op_context->axis); + const TfLiteIntArray* input_dims = op_context->input->dims; + int input_num_dims = NumDimensions(op_context->input); + const int* axis = GetTensorData(op_context->axis); + if (op_context->params->keep_dims) { + TfLiteIntArray* output_dims = TfLiteIntArrayCreate(input_num_dims); + for (int idx = 0; idx < input_num_dims; ++idx) { + bool is_axis = false; + for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) { + if (axis[axis_idx] == idx || axis[axis_idx] + input_num_dims == idx) { + is_axis = true; + break; + } + } + if (is_axis) { + output_dims->data[idx] = 1; + } else { + output_dims->data[idx] = input_dims->data[idx]; + } + } + return context->ResizeTensor(context, op_context->output, output_dims); + } else { + // Calculates size of reducing axis. + int num_reduce_axis = num_axis; + for (int i = 0; i < num_axis; ++i) { + int current = axis[i]; + if (current < 0) { + current += input_num_dims; + } + TF_LITE_ENSURE(context, current >= 0 && current < input_num_dims); + for (int j = 0; j < i; ++j) { + int previous = axis[j]; + if (previous < 0) { + previous += input_num_dims; + } + if (current == previous) { + --num_reduce_axis; + break; + } + } + } + // Determines output dimensions. 
+ TfLiteIntArray* output_dims = + TfLiteIntArrayCreate(input_num_dims - num_reduce_axis); + int num_skip_axis = 0; + for (int idx = 0; idx < input_num_dims; ++idx) { + bool is_axis = false; + for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx) { + if (axis[axis_idx] == idx || axis[axis_idx] + input_num_dims == idx) { + ++num_skip_axis; + is_axis = true; + break; + } + } + if (!is_axis) { + output_dims->data[idx - num_skip_axis] = input_dims->data[idx]; + } + } + return context->ResizeTensor(context, op_context->output, output_dims); + } +} + +// Initializes temp tensors to store index and resolved axis. +TfLiteStatus InitializeTemporaries(TfLiteContext* context, TfLiteNode* node, + MeanContext* op_context) { + // Creates a temp index to iterate through input data. + int* scratch_tensor_index = reinterpret_cast(node->user_data); + TfLiteIntArrayFree(node->temporaries); + node->temporaries = TfLiteIntArrayCreate(2); + node->temporaries->data[0] = *scratch_tensor_index; + TfLiteTensor* scratch_tensor = &context->tensors[node->temporaries->data[0]]; + scratch_tensor->type = kTfLiteInt32; + scratch_tensor->allocation_type = kTfLiteArenaRw; + TfLiteIntArray* index_size = TfLiteIntArrayCreate(1); + index_size->data[0] = NumDimensions(op_context->input); + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, scratch_tensor, index_size)); + + // Creates a temp tensor to store resolved axis given input data. 
+ node->temporaries->data[1] = *scratch_tensor_index + 1; + TfLiteTensor* resolved_axis = &context->tensors[node->temporaries->data[1]]; + resolved_axis->type = kTfLiteInt32; + return kTfLiteOk; +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + MeanContext op_context(context, node); + TF_LITE_ENSURE_OK(context, InitializeTemporaries(context, node, &op_context)); + + TfLiteTensor* resolved_axis = &context->tensors[node->temporaries->data[1]]; + // Leaves work to Eval if axis is not constant; else resizes output. + if (!IsConstantTensor(op_context.axis)) { + SetTensorToDynamic(op_context.output); + SetTensorToDynamic(resolved_axis); + return kTfLiteOk; + } + resolved_axis->allocation_type = kTfLiteArenaRw; + TF_LITE_ENSURE_OK(context, + ResizeTempAxis(context, &op_context, resolved_axis)); + return ResizeOutputTensor(context, &op_context); +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + MeanContext op_context(context, node); + int num_axis = static_cast(NumElements(op_context.axis)); + TfLiteTensor* temp_index = &context->tensors[node->temporaries->data[0]]; + TfLiteTensor* resolved_axis = &context->tensors[node->temporaries->data[1]]; + // Resize the output tensor if the output tensor is dynamic. 
+ if (IsDynamicTensor(op_context.output)) { + TF_LITE_ENSURE_OK(context, + ResizeTempAxis(context, &op_context, resolved_axis)); + TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); + } + +#define TF_LITE_MEAN(kernel_type, data_type) \ + kernel_type::Mean<>( \ + GetTensorData(op_context.input), \ + op_context.input->dims->data, op_context.input->dims->size, \ + GetTensorData(op_context.output), \ + op_context.output->dims->data, op_context.output->dims->size, \ + GetTensorData(op_context.axis), num_axis, \ + op_context.params->keep_dims, GetTensorData(temp_index), \ + GetTensorData(resolved_axis)) + + if (kernel_type == kReference) { + switch (op_context.input->type) { + case kTfLiteFloat32: + TF_LITE_MEAN(reference_ops, float); + break; + case kTfLiteInt32: + TF_LITE_MEAN(reference_ops, int); + break; + case kTfLiteUInt8: + TF_LITE_MEAN(reference_ops, uint8_t); + break; + case kTfLiteInt64: + TF_LITE_MEAN(reference_ops, int64_t); + break; + default: + return kTfLiteError; + } + } +#undef TF_LITE_MEAN + return kTfLiteOk; +} + +} // namespace mean + +TfLiteRegistration* Register_MEAN_REF() { + static TfLiteRegistration r = {mean::Init, mean::Free, mean::Prepare, + mean::Eval}; + return &r; +} + +// TODO(kanlig): add optimized implementation of Mean. +TfLiteRegistration* Register_MEAN() { return Register_MEAN_REF(); } + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/mean_test.cc b/tensorflow/contrib/lite/kernels/mean_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..c4c53c2ded351849e7c458fc754c36395a25ebd0 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/mean_test.cc @@ -0,0 +1,140 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class BaseMeanOpModel : public SingleOpModel { + public: + void SetAxis(std::initializer_list data) { PopulateTensor(axis_, data); } + + template + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + template + std::vector GetOutput() { + return ExtractVector(output_); + } + + std::vector GetOutputShape() { return GetTensorShape(output_); } + + protected: + int input_; + int axis_; + int output_; +}; + +// Model for the tests case where axis is a const tensor. +class MeanOpConstModel : public BaseMeanOpModel { + public: + MeanOpConstModel(const TensorData& input, const TensorData& output, + std::initializer_list axis_shape, + std::initializer_list axis, bool keep_dims) { + input_ = AddInput(input); + axis_ = AddConstInput(TensorType_INT32, axis, axis_shape); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_MeanOptions, + CreateMeanOptions(builder_, keep_dims).Union()); + BuildInterpreter({GetShape(input_)}); + } +}; + +// Model for the tests case where axis is a dynamic tensor. 
+class MeanOpDynamicModel : public BaseMeanOpModel { + public: + MeanOpDynamicModel(const TensorData& input, const TensorData& output, + const TensorData& axis, bool keep_dims) { + input_ = AddInput(input); + axis_ = AddInput(axis); + output_ = AddOutput(output); + SetBuiltinOp(BuiltinOperator_MEAN, BuiltinOptions_MeanOptions, + CreateMeanOptions(builder_, keep_dims).Union()); + BuildInterpreter({GetShape(input_)}); + } +}; + +TEST(ConstMeanOpTest, NotKeepDims) { + std::initializer_list data = { + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + MeanOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {2}}, + {4}, {1, 0, -3, -3}, false); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({12, 13}))); +} + +TEST(ConstMeanOpTest, KeepDims) { + std::initializer_list data = { + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + MeanOpConstModel m({TensorType_FLOAT32, {4, 3, 2}}, {TensorType_FLOAT32, {3}}, + {2}, {0, 2}, true); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({10.5, 12.5, 14.5}))); +} + +TEST(DynamicMeanOpTest, NotKeepDims) { + std::initializer_list data = { + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + MeanOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}}, + {TensorType_FLOAT32, {2}}, {TensorType_INT32, {4}}, + false); + std::initializer_list axis = {1, 0, -3, -3}; + m.SetAxis(axis); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({12, 13}))); +} + 
+TEST(DynamicMeanOpTest, KeepDims) { + std::initializer_list data = { + 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, + 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}; + MeanOpDynamicModel m({TensorType_FLOAT32, {4, 3, 2}}, + {TensorType_FLOAT32, {3}}, {TensorType_INT32, {2}}, + true); + std::initializer_list axis = {0, 2}; + m.SetAxis(axis); + m.SetInput(data); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3, 1})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear({10.5, 12.5, 14.5}))); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/mul.cc b/tensorflow/contrib/lite/kernels/mul.cc index 81c73f2523186c2d4072d56bdc8980fcdbb588a3..54575019de4c678ce25561cf2ac8dc80c9973363 100644 --- a/tensorflow/contrib/lite/kernels/mul.cc +++ b/tensorflow/contrib/lite/kernels/mul.cc @@ -37,7 +37,23 @@ constexpr int kInputTensor1 = 0; constexpr int kInputTensor2 = 1; constexpr int kOutputTensor = 0; +struct OpData { + bool requires_broadcast; +}; + +void* Init(TfLiteContext* context, const char* buffer, size_t length) { + auto* data = new OpData; + data->requires_broadcast = false; + return data; +} + +void Free(TfLiteContext* context, void* buffer) { + delete reinterpret_cast(buffer); +} + TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + OpData* data = reinterpret_cast(node->user_data); + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); @@ -45,43 +61,56 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); - TF_LITE_ENSURE_EQ(context, NumDimensions(input1), NumDimensions(input2)); - for (int i = 0; i < NumDimensions(input1); ++i) { 
- TF_LITE_ENSURE_EQ(context, SizeOfDimension(input1, i), - SizeOfDimension(input2, i)); - } + TF_LITE_ENSURE_EQ(context, input1->type, input2->type); + output->type = input2->type; + + data->requires_broadcast = !HaveSameShapes(input1, input2); - TF_LITE_ENSURE_EQ(context, input1->type, output->type); - TF_LITE_ENSURE_EQ(context, input2->type, output->type); + TfLiteIntArray* output_size = nullptr; + if (data->requires_broadcast) { + TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast( + context, input1, input2, &output_size)); + } else { + output_size = TfLiteIntArrayCopy(input1->dims); + } - TfLiteIntArray* output_size = TfLiteIntArrayCopy(input1->dims); return context->ResizeTensor(context, output, output_size); } template void EvalFloat(TfLiteContext* context, TfLiteNode* node, - TfLiteMulParams* params, TfLiteTensor* input1, - TfLiteTensor* input2, TfLiteTensor* output) { + TfLiteMulParams* params, const OpData* data, + TfLiteTensor* input1, TfLiteTensor* input2, + TfLiteTensor* output) { float output_activation_min, output_activation_max; CalculateActivationRangeFloat(params->activation, &output_activation_min, &output_activation_max); -#define TF_LITE_MUL(type) \ - type::Mul(GetTensorData(input1), GetTensorDims(input1), \ - GetTensorData(input2), GetTensorDims(input2), \ - output_activation_min, output_activation_max, \ - GetTensorData(output), GetTensorDims(output)) +#define TF_LITE_MUL(type, opname) \ + type::opname(GetTensorData(input1), GetTensorDims(input1), \ + GetTensorData(input2), GetTensorDims(input2), \ + output_activation_min, output_activation_max, \ + GetTensorData(output), GetTensorDims(output)) if (kernel_type == kReference) { - TF_LITE_MUL(reference_ops); + if (data->requires_broadcast) { + TF_LITE_MUL(reference_ops, BroadcastMul); + } else { + TF_LITE_MUL(reference_ops, Mul); + } } else { - TF_LITE_MUL(optimized_ops); + if (data->requires_broadcast) { + TF_LITE_MUL(optimized_ops, BroadcastMul); + } else { + TF_LITE_MUL(optimized_ops, 
Mul); + } } #undef TF_LITE_MUL } template void EvalQuantized(TfLiteContext* context, TfLiteNode* node, - TfLiteMulParams* params, TfLiteTensor* input1, - TfLiteTensor* input2, TfLiteTensor* output) { + TfLiteMulParams* params, const OpData* data, + TfLiteTensor* input1, TfLiteTensor* input2, + TfLiteTensor* output) { auto input1_offset = -input1->params.zero_point; auto input2_offset = -input2->params.zero_point; auto output_offset = output->params.zero_point; @@ -98,17 +127,19 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node, CalculateActivationRangeUint8(params->activation, output, &output_activation_min, &output_activation_max); -#define TF_LITE_MUL(type) \ - type::BroadcastMul(GetTensorData(input1), GetTensorDims(input1), \ - input1_offset, GetTensorData(input2), \ - GetTensorDims(input2), input2_offset, output_offset, \ - output_multiplier, output_shift, output_activation_min, \ - output_activation_max, GetTensorData(output), \ - GetTensorDims(output)); +#define TF_LITE_MUL(type, opname) \ + type::opname(GetTensorData(input1), GetTensorDims(input1), \ + input1_offset, GetTensorData(input2), \ + GetTensorDims(input2), input2_offset, output_offset, \ + output_multiplier, output_shift, output_activation_min, \ + output_activation_max, GetTensorData(output), \ + GetTensorDims(output)); + // The quantized version of Mul doesn't support activations, so we + // always use BroadcastMul. 
if (kernel_type == kReference) { - TF_LITE_MUL(reference_ops); + TF_LITE_MUL(reference_ops, BroadcastMul); } else { - TF_LITE_MUL(optimized_ops); + TF_LITE_MUL(optimized_ops, BroadcastMul); } #undef TF_LITE_MUL } @@ -116,15 +147,17 @@ void EvalQuantized(TfLiteContext* context, TfLiteNode* node, template TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { auto* params = reinterpret_cast(node->builtin_data); + OpData* data = reinterpret_cast(node->user_data); TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); if (output->type == kTfLiteFloat32) { - EvalFloat(context, node, params, input1, input2, output); + EvalFloat(context, node, params, data, input1, input2, output); } else if (output->type == kTfLiteUInt8) { - EvalQuantized(context, node, params, input1, input2, output); + EvalQuantized(context, node, params, data, input1, input2, + output); } else { context->ReportError(context, "Mul only supports FLOAT32 and quantized UINT8 now."); @@ -137,19 +170,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } // namespace mul TfLiteRegistration* Register_MUL_REF() { - static TfLiteRegistration r = {nullptr, nullptr, mul::Prepare, + static TfLiteRegistration r = {mul::Init, mul::Free, mul::Prepare, mul::Eval}; return &r; } TfLiteRegistration* Register_MUL_GENERIC_OPT() { - static TfLiteRegistration r = {nullptr, nullptr, mul::Prepare, + static TfLiteRegistration r = {mul::Init, mul::Free, mul::Prepare, mul::Eval}; return &r; } TfLiteRegistration* Register_MUL_NEON_OPT() { - static TfLiteRegistration r = {nullptr, nullptr, mul::Prepare, + static TfLiteRegistration r = {mul::Init, mul::Free, mul::Prepare, mul::Eval}; return &r; } diff --git a/tensorflow/contrib/lite/kernels/mul_test.cc b/tensorflow/contrib/lite/kernels/mul_test.cc index 
4b858e1f396252e7f7bdc231bc1e00f47277f08a..f1a30f82634631ba8320421d5b36ffe446f443fa 100644 --- a/tensorflow/contrib/lite/kernels/mul_test.cc +++ b/tensorflow/contrib/lite/kernels/mul_test.cc @@ -25,10 +25,11 @@ using ::testing::ElementsAreArray; class BaseMulOpModel : public SingleOpModel { public: - BaseMulOpModel(TensorData input, TensorData output, + BaseMulOpModel(const TensorData& input1, const TensorData& input2, + const TensorData& output, ActivationFunctionType activation_type) { - input1_ = AddInput(input); - input2_ = AddInput(input); + input1_ = AddInput(input1); + input2_ = AddInput(input2); output_ = AddOutput(output); SetBuiltinOp(BuiltinOperator_MUL, BuiltinOptions_MulOptions, CreateMulOptions(builder_, activation_type).Union()); @@ -70,6 +71,7 @@ class QuantizedMulOpModel : public BaseMulOpModel { TEST(FloatMulOpTest, NoActivation) { FloatMulOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {1, 2, 2, 1}}, {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5}); @@ -78,9 +80,10 @@ TEST(FloatMulOpTest, NoActivation) { ElementsAreArray(ArrayFloatNear({-0.2, 0.04, 0.21, 0.4}))); } -TEST(FloatMulOpTest, ActivationRELU1) { - FloatMulOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}, - {TensorType_FLOAT32, {}}, ActivationFunctionType_RELU1); +TEST(FloatMulOpTest, ActivationRELU_N1_TO_1) { + FloatMulOpModel m( + {TensorType_FLOAT32, {1, 2, 2, 1}}, {TensorType_FLOAT32, {1, 2, 2, 1}}, + {TensorType_FLOAT32, {}}, ActivationFunctionType_RELU_N1_TO_1); m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8}); m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 5}); m.Invoke(); @@ -93,6 +96,7 @@ TEST(FloatMulOpTest, VariousInputShapes) { {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; for (int i = 0; i < test_shapes.size(); ++i) { FloatMulOpModel m({TensorType_FLOAT32, test_shapes[i]}, + {TensorType_FLOAT32, test_shapes[i]}, {TensorType_FLOAT32, {}}, 
ActivationFunctionType_NONE); m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); m.PopulateTensor(m.input2(), {0.1, 0.2, 0.3, 0.5, 1.1, 0.1}); @@ -104,8 +108,26 @@ TEST(FloatMulOpTest, VariousInputShapes) { } } +TEST(FloatMulOpTest, WithBroadcast) { + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + FloatMulOpModel m({TensorType_FLOAT32, test_shapes[i]}, + {TensorType_FLOAT32, {}}, // always a scalar + {TensorType_FLOAT32, {}}, ActivationFunctionType_NONE); + m.PopulateTensor(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); + m.PopulateTensor(m.input2(), {0.1}); + m.Invoke(); + EXPECT_THAT( + m.GetOutput(), + ElementsAreArray(ArrayFloatNear({-0.2, 0.02, 0.07, 0.08, 0.11, 0.2}))) + << "With shape number " << i; + } +} + TEST(QuantizedMulOpTest, NoActivation) { QuantizedMulOpModel m({TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, + {TensorType_UINT8, {1, 2, 2, 1}, -1.0, 1.0}, {TensorType_UINT8, {}, -1.0, 1.0}, ActivationFunctionType_NONE); m.QuantizeAndPopulate(m.input1(), {-0.8, 0.2, 0.9, 0.7}); @@ -116,12 +138,37 @@ TEST(QuantizedMulOpTest, NoActivation) { kQuantizedTolerance))); } +// for quantized Mul, the error shouldn't exceed 2*step +float GetTolerance(int min, int max) { + float kQuantizedStep = (max - min) / 255.0; + float kQuantizedTolerance = 2.0 * kQuantizedStep; + return kQuantizedTolerance; +} + +TEST(QuantizedMulOpTest, WithBroadcast) { + float kQuantizedTolerance = GetTolerance(-3.0, 3.0); + std::vector> test_shapes = { + {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}}; + for (int i = 0; i < test_shapes.size(); ++i) { + QuantizedMulOpModel m({TensorType_UINT8, test_shapes[i], -3.0, 3.0}, + {TensorType_UINT8, {}, -3.0, 3.0}, // always a scalar + {TensorType_UINT8, {}, -3.0, 3.0}, + ActivationFunctionType_NONE); + m.QuantizeAndPopulate(m.input1(), {-2.0, 0.2, 0.7, 0.8, 1.1, 2.0}); + m.QuantizeAndPopulate(m.input2(), {0.1}); + m.Invoke(); + EXPECT_THAT(m.GetDequantizedOutput(), + 
ElementsAreArray(ArrayFloatNear( + {-0.2, 0.02, 0.07, 0.08, 0.11, 0.2}, kQuantizedTolerance))) + << "With shape number " << i; + } +} + } // namespace } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); - tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/op_macros.h b/tensorflow/contrib/lite/kernels/op_macros.h index 7535afaf8ea52d855e2e4773e56ce2118a16447c..7568eaa88edfa3260964e16f03299aecb97da6be 100644 --- a/tensorflow/contrib/lite/kernels/op_macros.h +++ b/tensorflow/contrib/lite/kernels/op_macros.h @@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_ + +#include #define TF_LITE_FATAL(msg) \ do { \ @@ -29,4 +31,4 @@ limitations under the License. 
if ((x) != (y)) TF_LITE_FATAL(#x " didn't equal " #y); \ } while (0) -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_OP_UTIL_H_ diff --git a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc index 8e9cc07656c8bea83f7cb78ca0b6cc5de7ad1b73..cee3ec6197c698a11004d42dccdfe2bcca088015 100644 --- a/tensorflow/contrib/lite/kernels/optional_tensor_test.cc +++ b/tensorflow/contrib/lite/kernels/optional_tensor_test.cc @@ -243,7 +243,6 @@ class LSTMOpModel : public SingleOpModel { int n_output_; }; - TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { const int n_batch = 1; const int n_input = 2; @@ -282,7 +281,6 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { {0}, // projection_bias tensor }); - lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781, 0.04717243, 0.48944736, -0.38535351, -0.17212132}); @@ -334,8 +332,7 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); - tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/pad.cc b/tensorflow/contrib/lite/kernels/pad.cc new file mode 100644 index 0000000000000000000000000000000000000000..c29da3862e84d6756bf5ef34b2ca06307b0a065d --- /dev/null +++ b/tensorflow/contrib/lite/kernels/pad.cc @@ -0,0 +1,183 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace pad { + +// This file has two implementations of Pad. +enum KernelType { + kReference, + kGenericOptimized, +}; + +struct PadContext { + PadContext(TfLiteContext* context, TfLiteNode* node) { + input = GetInput(context, node, 0); + paddings = GetInput(context, node, 1); + output = GetOutput(context, node, 0); + dims = NumDimensions(input); + } + TfLiteTensor* input; + TfLiteTensor* paddings; + TfLiteTensor* output; + int dims; +}; + +// Resizes output array based on the input size and padding size. This function +// is callable from both Prepare() and Eval() as long as the caller ensures the +// paddings data is present. +TfLiteStatus ResizeOutputTensor(TfLiteContext* context, + PadContext* op_context) { + // Ensures the paddings array is dims x 2. + TF_LITE_ENSURE_EQ(context, SizeOfDimension(op_context->paddings, 0), + op_context->dims); + TF_LITE_ENSURE_EQ(context, SizeOfDimension(op_context->paddings, 1), 2); + + // Determines the size of the output tensor. 
+ TfLiteIntArray* input_size = op_context->input->dims; + TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size); + const int32* paddings_data = GetTensorData(op_context->paddings); + + for (int idx = 0; idx < op_context->dims; ++idx) { + int before_padding = *paddings_data++; + int after_padding = *paddings_data++; + + TF_LITE_ENSURE_MSG(context, (before_padding >= 0 && after_padding >= 0), + "Pad value has to be greater than equal to 0."); + + output_size->data[idx] = + (input_size->data[idx] + before_padding + after_padding); + } + + return context->ResizeTensor(context, op_context->output, output_size); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + PadContext op_context(context, node); + TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); + + // TODO(nupurgarg): Our current implementations rely on the inputs being 4D. + TF_LITE_ENSURE_EQ(context, op_context.dims, 4); + + // Exit early if paddings is a non-const tensor. Set output tensor to + // dynamic so output size can be determined in Eval. + if (!IsConstantTensor(op_context.paddings)) { + SetTensorToDynamic(op_context.output); + return kTfLiteOk; + } + return ResizeOutputTensor(context, &op_context); +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + PadContext op_context(context, node); + + // Resize the output tensor if the output tensor is dynamic. + if (IsDynamicTensor(op_context.output)) { + TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); + } + + // TODO(nupurgarg): Change kernel implementation to take in int* instead of + // vector to remove malloc from Eval(). + // Create before and after padding arrays that are accepted by the kernel. 
+ std::vector before_padding; + std::vector after_padding; + const int32* paddings_data = GetTensorData(op_context.paddings); + + // TODO(nupurgarg): Change kernel implementation to use padding arrays in + // forward order (depth, width, height, batch). + // Build paddings in order of int[] = {batch, height, width, depth} to match + // kernel implementation of Pad in referenced_ops.h and optimized_ops.h. + for (int idx = op_context.dims - 1; idx >= 0; --idx) { + before_padding.push_back(paddings_data[idx * 2]); + after_padding.push_back(paddings_data[idx * 2 + 1]); + } + +#define TF_LITE_PAD(type, scalar) \ + type::Pad(GetTensorData(op_context.input), \ + GetTensorDims(op_context.input), before_padding, after_padding, \ + GetTensorData(op_context.output), \ + GetTensorDims(op_context.output)) + + switch (op_context.input->type) { + case kTfLiteFloat32: + if (kernel_type == kReference) { + TF_LITE_PAD(reference_ops, float); + } else if (kernel_type == kGenericOptimized) { + TF_LITE_PAD(optimized_ops, float); + } + break; + case kTfLiteUInt8: + if (kernel_type == kReference) { + TF_LITE_PAD(reference_ops, uint8_t); + } else if (kernel_type == kGenericOptimized) { + TF_LITE_PAD(optimized_ops, uint8_t); + } + break; + case kTfLiteInt32: + if (kernel_type == kReference) { + TF_LITE_PAD(reference_ops, int32_t); + } else if (kernel_type == kGenericOptimized) { + TF_LITE_PAD(optimized_ops, int32_t); + } + break; + case kTfLiteInt64: + if (kernel_type == kReference) { + TF_LITE_PAD(reference_ops, int64_t); + } else if (kernel_type == kGenericOptimized) { + TF_LITE_PAD(optimized_ops, int64_t); + } + break; + default: + context->ReportError(context, "Type is currently not supported by Pad."); + return kTfLiteError; + } +#undef TF_LITE_PAD + return kTfLiteOk; +} + +} // namespace pad + +TfLiteRegistration* Register_PAD_REF() { + static TfLiteRegistration r = {nullptr, nullptr, pad::Prepare, + pad::Eval}; + return &r; +} + +TfLiteRegistration* Register_PAD_GENERIC_OPT() { + 
static TfLiteRegistration r = {nullptr, nullptr, pad::Prepare, + pad::Eval}; + return &r; +} + +TfLiteRegistration* Register_PAD() { return Register_PAD_GENERIC_OPT(); } + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/pad_test.cc b/tensorflow/contrib/lite/kernels/pad_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..28834ad0719291b2e868bca2d86a6685e6eb9962 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/pad_test.cc @@ -0,0 +1,154 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class PadOpModel : public SingleOpModel { + public: + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetPaddings(std::initializer_list paddings) { + PopulateTensor(paddings_, paddings); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + protected: + int input_; + int output_; + int paddings_; +}; + +// Tests case where paddings is a const tensor. 
+// +// Example usage is as follows: +// PadOpDynamicModel m(input_shape, paddings_shape, paddings_data); +// m.SetInput(input_data); +// m.Invoke(); +class PadOpConstModel : public PadOpModel { + public: + PadOpConstModel(std::initializer_list input_shape, + std::initializer_list paddings_shape, + std::initializer_list paddings) { + input_ = AddInput(TensorType_FLOAT32); + paddings_ = AddConstInput(TensorType_INT32, paddings, paddings_shape); + output_ = AddOutput(TensorType_FLOAT32); + + SetBuiltinOp(BuiltinOperator_PAD, BuiltinOptions_PadOptions, + CreatePadOptions(builder_).Union()); + BuildInterpreter({input_shape}); + } +}; + +// Test case where paddings is a non-const tensor. +// +// Example usage is as follows: +// PadOpDynamicModel m(input_shape, paddings_shape); +// m.SetInput(input_data); +// m.SetPaddings(paddings_data); +// m.Invoke(); +class PadOpDynamicModel : public PadOpModel { + public: + PadOpDynamicModel(std::initializer_list input_shape, + std::initializer_list paddings_shape) { + input_ = AddInput(TensorType_FLOAT32); + paddings_ = AddInput(TensorType_INT32); + output_ = AddOutput(TensorType_FLOAT32); + + SetBuiltinOp(BuiltinOperator_PAD, BuiltinOptions_PadOptions, + CreatePadOptions(builder_).Union()); + BuildInterpreter({input_shape, paddings_shape}); + } +}; + +TEST(PadOpTest, TooManyDimensions) { + EXPECT_DEATH( + PadOpConstModel({1, 2, 3, 4, 5, 6, 7, 8, 9}, {9, 2}, + {1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 9}), + "dims != 4"); +} + +TEST(PadOpTest, UnequalDimensions) { + EXPECT_DEATH(PadOpConstModel({1, 1, 2, 1}, {3, 2}, {1, 1, 2, 2, 3, 3}), + "3 != 4"); +} + +TEST(PadOpTest, InvalidPadValue) { + EXPECT_DEATH( + PadOpConstModel({1, 1, 2, 1}, {4, 2}, {0, 0, 1, -1, 2, -1, 0, 0}), + "Pad value has to be greater than equal to 0."); +} + +TEST(PadOpTest, SimpleConstTest) { + // Padding is represented as four 2-D lists representing above padding and + // below padding (i.e. {{0, 0}, {1, 1}, {1, 1}, {0, 0}}). 
+ PadOpConstModel m({1, 2, 2, 1}, {4, 2}, {0, 0, 1, 1, 1, 1, 0, 0}); + m.SetInput({1, 2, 3, 4}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, + 0, 0, 0, 0, 0})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); +} + +TEST(PadOpTest, SimpleDynamicTest) { + PadOpDynamicModel m({1, 2, 2, 1}, {4, 2}); + m.SetInput({1, 2, 3, 4}); + m.SetPaddings({0, 0, 1, 1, 1, 1, 0, 0}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 0, 0, 0, 0, 1, 2, 0, 0, 3, 4, + 0, 0, 0, 0, 0})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); +} + +TEST(PadOpTest, AdvancedConstTest) { + PadOpConstModel m({1, 2, 3, 1}, {4, 2}, {0, 0, 0, 2, 1, 3, 0, 0}); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 6, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1})); +} + +TEST(PadOpTest, AdvancedDynamicTest) { + PadOpDynamicModel m({1, 2, 3, 1}, {4, 2}); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetPaddings({0, 0, 0, 2, 1, 3, 0, 0}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({0, 1, 2, 3, 0, 0, 0, 0, 4, 5, 6, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0})); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1})); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/padding.h b/tensorflow/contrib/lite/kernels/padding.h index 3a60274524c468ef29e522de5569e0d8354974c2..40b8476b3779c66e31a04856bce8aebd378f1e5f 100644 --- a/tensorflow/contrib/lite/kernels/padding.h +++ b/tensorflow/contrib/lite/kernels/padding.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_ namespace tflite { @@ -25,4 +25,4 @@ inline int ComputePadding(int stride, int in_size, int filter_size, } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_PADDING_H_ diff --git a/tensorflow/contrib/lite/kernels/pooling_test.cc b/tensorflow/contrib/lite/kernels/pooling_test.cc index e1b51ec7d5141bf2a41e7ede3e90ff20ec523819..01c91b2ba905e249c36af19f175c68a7e7f17f6d 100644 --- a/tensorflow/contrib/lite/kernels/pooling_test.cc +++ b/tensorflow/contrib/lite/kernels/pooling_test.cc @@ -155,7 +155,7 @@ TEST(FloatPoolingOpTest, L2Pool) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc index ca7a0dd1949a3a31d26be770a7df781cc5fe7533..0f365078cdf4f43d545a69cf5b4ac4d353615106 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -20,7 +20,7 @@ namespace ops { namespace builtin { TfLiteRegistration* Register_RELU(); -TfLiteRegistration* Register_RELU1(); +TfLiteRegistration* Register_RELU_N1_TO_1(); TfLiteRegistration* Register_RELU6(); TfLiteRegistration* Register_TANH(); TfLiteRegistration* Register_LOGISTIC(); @@ -31,6 +31,8 @@ TfLiteRegistration* Register_CONV_2D(); TfLiteRegistration* Register_DEPTHWISE_CONV_2D(); TfLiteRegistration* Register_SVDF(); TfLiteRegistration* Register_RNN(); 
+TfLiteRegistration* Register_BIDIRECTIONAL_SEQUENCE_RNN(); +TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_RNN(); TfLiteRegistration* Register_EMBEDDING_LOOKUP(); TfLiteRegistration* Register_EMBEDDING_LOOKUP_SPARSE(); TfLiteRegistration* Register_FULLY_CONNECTED(); @@ -39,18 +41,30 @@ TfLiteRegistration* Register_HASHTABLE_LOOKUP(); TfLiteRegistration* Register_SOFTMAX(); TfLiteRegistration* Register_CONCATENATION(); TfLiteRegistration* Register_ADD(); +TfLiteRegistration* Register_SPACE_TO_BATCH_ND(); +TfLiteRegistration* Register_DIV(); +TfLiteRegistration* Register_SUB(); +TfLiteRegistration* Register_BATCH_TO_SPACE_ND(); TfLiteRegistration* Register_MUL(); TfLiteRegistration* Register_L2_NORMALIZATION(); TfLiteRegistration* Register_LOCAL_RESPONSE_NORMALIZATION(); TfLiteRegistration* Register_LSTM(); +TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_LSTM(); +TfLiteRegistration* Register_PAD(); TfLiteRegistration* Register_RESHAPE(); TfLiteRegistration* Register_RESIZE_BILINEAR(); TfLiteRegistration* Register_SKIP_GRAM(); TfLiteRegistration* Register_SPACE_TO_DEPTH(); +TfLiteRegistration* Register_GATHER(); +TfLiteRegistration* Register_TRANSPOSE(); +TfLiteRegistration* Register_MEAN(); +TfLiteRegistration* Register_SQUEEZE(); +TfLiteRegistration* Register_STRIDED_SLICE(); +TfLiteRegistration* Register_EXP(); BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_RELU, Register_RELU()); - AddBuiltin(BuiltinOperator_RELU1, Register_RELU1()); + AddBuiltin(BuiltinOperator_RELU_N1_TO_1, Register_RELU_N1_TO_1()); AddBuiltin(BuiltinOperator_RELU6, Register_RELU6()); AddBuiltin(BuiltinOperator_TANH, Register_TANH()); AddBuiltin(BuiltinOperator_LOGISTIC, Register_LOGISTIC()); @@ -61,6 +75,10 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_DEPTHWISE_CONV_2D, Register_DEPTHWISE_CONV_2D()); AddBuiltin(BuiltinOperator_SVDF, Register_SVDF()); AddBuiltin(BuiltinOperator_RNN, Register_RNN()); + 
AddBuiltin(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN, + Register_BIDIRECTIONAL_SEQUENCE_RNN()); + AddBuiltin(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN, + Register_UNIDIRECTIONAL_SEQUENCE_RNN()); AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP, Register_EMBEDDING_LOOKUP()); AddBuiltin(BuiltinOperator_EMBEDDING_LOOKUP_SPARSE, Register_EMBEDDING_LOOKUP_SPARSE()); @@ -70,15 +88,28 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_SOFTMAX, Register_SOFTMAX()); AddBuiltin(BuiltinOperator_CONCATENATION, Register_CONCATENATION()); AddBuiltin(BuiltinOperator_ADD, Register_ADD()); + AddBuiltin(BuiltinOperator_SPACE_TO_BATCH_ND, Register_SPACE_TO_BATCH_ND()); + AddBuiltin(BuiltinOperator_BATCH_TO_SPACE_ND, Register_BATCH_TO_SPACE_ND()); AddBuiltin(BuiltinOperator_MUL, Register_MUL()); AddBuiltin(BuiltinOperator_L2_NORMALIZATION, Register_L2_NORMALIZATION()); AddBuiltin(BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION, Register_LOCAL_RESPONSE_NORMALIZATION()); AddBuiltin(BuiltinOperator_LSTM, Register_LSTM()); + AddBuiltin(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, + Register_UNIDIRECTIONAL_SEQUENCE_LSTM()); + AddBuiltin(BuiltinOperator_PAD, Register_PAD()); AddBuiltin(BuiltinOperator_RESHAPE, Register_RESHAPE()); AddBuiltin(BuiltinOperator_RESIZE_BILINEAR, Register_RESIZE_BILINEAR()); AddBuiltin(BuiltinOperator_SKIP_GRAM, Register_SKIP_GRAM()); AddBuiltin(BuiltinOperator_SPACE_TO_DEPTH, Register_SPACE_TO_DEPTH()); + AddBuiltin(BuiltinOperator_GATHER, Register_GATHER()); + AddBuiltin(BuiltinOperator_TRANSPOSE, Register_TRANSPOSE()); + AddBuiltin(BuiltinOperator_MEAN, Register_MEAN()); + AddBuiltin(BuiltinOperator_DIV, Register_DIV()); + AddBuiltin(BuiltinOperator_SUB, Register_SUB()); + AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE()); + AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE()); + AddBuiltin(BuiltinOperator_EXP, Register_EXP()); } TfLiteRegistration* BuiltinOpResolver::FindOp( diff --git 
a/tensorflow/contrib/lite/kernels/register.h b/tensorflow/contrib/lite/kernels/register.h index 28f5e0fcc80a14cf9fb6fb19b795d0c0d55e0df9..b9cff0ae21086b44e0c920095d5f6c9668346f38 100644 --- a/tensorflow/contrib/lite/kernels/register.h +++ b/tensorflow/contrib/lite/kernels/register.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_REGISTER_H_ #include #include "tensorflow/contrib/lite/context.h" @@ -47,4 +47,4 @@ class BuiltinOpResolver : public OpResolver { } // namespace ops } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_BUILTIN_KERNELS_H +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_BUILTIN_KERNELS_H diff --git a/tensorflow/contrib/lite/kernels/reshape_test.cc b/tensorflow/contrib/lite/kernels/reshape_test.cc index 59ce7d5648c04f78123b16a195d3a4928d28394b..0fbcf6e6aa311d2cac491336ee54ccf58bbda8fd 100644 --- a/tensorflow/contrib/lite/kernels/reshape_test.cc +++ b/tensorflow/contrib/lite/kernels/reshape_test.cc @@ -83,8 +83,7 @@ TEST(ReshapeOpTest, WithStretchDimension) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); - tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear.cc b/tensorflow/contrib/lite/kernels/resize_bilinear.cc index 1613c9a89faa3579b913408cc09cdad7f942cb99..9e3e19c09a4012ebdadbc2a7c2ba06c4bfefd206 100644 --- a/tensorflow/contrib/lite/kernels/resize_bilinear.cc +++ 
b/tensorflow/contrib/lite/kernels/resize_bilinear.cc @@ -33,32 +33,44 @@ enum KernelType { }; constexpr int kInputTensor = 0; +constexpr int kSizeTensor = 1; constexpr int kOutputTensor = 0; -TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { - auto* params = - reinterpret_cast(node->builtin_data); +TfLiteStatus ResizeOutputTensor(TfLiteContext* context, TfLiteTensor* input, + TfLiteTensor* size, TfLiteTensor* output) { + TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); + output_size->data[0] = input->dims->data[0]; + const int32* size_data = GetTensorData(size); + output_size->data[1] = size_data[0]; + output_size->data[2] = size_data[1]; + output_size->data[3] = input->dims->data[3]; + return context->ResizeTensor(context, output, output_size); +} - TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* size = GetInput(context, node, kSizeTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); // TODO(ahentz): Our current implementations rely on the inputs being 4D. TF_LITE_ENSURE_EQ(context, NumDimensions(input), 4); + TF_LITE_ENSURE_EQ(context, NumDimensions(size), 1); // TODO(ahentz): Our current implementations only support float32. 
- TF_LITE_ENSURE_EQ(context, output->type, kTfLiteFloat32); - TF_LITE_ENSURE_EQ(context, input->type, output->type); - - TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); - output_size->data[0] = input->dims->data[0]; - output_size->data[1] = params->new_height; - output_size->data[2] = params->new_width; - output_size->data[3] = input->dims->data[3]; - - return context->ResizeTensor(context, output, output_size); + TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32); + TF_LITE_ENSURE_EQ(context, size->type, kTfLiteInt32); + // ResizeBilinear creates a float tensor even when the input is made of + // integers. + output->type = kTfLiteFloat32; + + if (!IsConstantTensor(size)) { + SetTensorToDynamic(output); + return kTfLiteOk; + } + return ResizeOutputTensor(context, input, size, output); } template @@ -68,15 +80,19 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* input = GetInput(context, node, kInputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + TfLiteTensor* size = GetInput(context, node, kSizeTensor); - // We have to fake a tensor here, to satisfy ResizeBilinear(). 
- int32 output_size_data[2] = {params->new_height, params->new_width}; + if (IsDynamicTensor(output)) { + TF_LITE_ENSURE_OK(context, + ResizeOutputTensor(context, input, size, output)); + } if (output->type == kTfLiteFloat32) { -#define TF_LITE_RESIZE_BILINEAR(type) \ - type::ResizeBilinear(GetTensorData(input), GetTensorDims(input), \ - output_size_data, GetTensorDims({1, 1, 1, 2}), \ - GetTensorData(output), GetTensorDims(output)) +#define TF_LITE_RESIZE_BILINEAR(type) \ + type::ResizeBilinear(GetTensorData(input), GetTensorDims(input), \ + GetTensorData(size), GetTensorDims(size), \ + GetTensorData(output), GetTensorDims(output), \ + params->align_corners) if (kernel_type == kReference) { TF_LITE_RESIZE_BILINEAR(reference_ops); diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc index 0257c0b557feb352413bcc33cb4e2ecdb32c5111..4e03f3820a5c14ee1692c553db61e385716b1723 100644 --- a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc +++ b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc @@ -25,63 +25,101 @@ using ::testing::ElementsAreArray; class ResizeBilinearOpModel : public SingleOpModel { public: - ResizeBilinearOpModel(std::initializer_list input_shape, int new_height, - int new_width) { - input_ = AddInput(TensorType_FLOAT32); - output_ = AddOutput(TensorType_FLOAT32); - SetBuiltinOp( - BuiltinOperator_RESIZE_BILINEAR, BuiltinOptions_ResizeBilinearOptions, - CreateResizeBilinearOptions(builder_, new_height, new_width).Union()); - BuildInterpreter({input_shape}); + ResizeBilinearOpModel(const TensorData& input, + std::initializer_list size_data = {}) { + bool const_size = size_data.size() != 0; + input_ = AddInput(input); + if (const_size) { + size_ = AddConstInput(TensorType_INT32, size_data, {2}); + } else { + size_ = AddInput({TensorType_INT32, {2}}); + } + output_ = AddOutput(TensorType_FLOAT32); // Always float. 
+ SetBuiltinOp(BuiltinOperator_RESIZE_BILINEAR, + BuiltinOptions_ResizeBilinearOptions, + CreateResizeBilinearOptions(builder_).Union()); + if (const_size) { + BuildInterpreter({GetShape(input_)}); + } else { + BuildInterpreter({GetShape(input_), GetShape(size_)}); + } } void SetInput(std::initializer_list data) { PopulateTensor(input_, data); } + void SetSize(std::initializer_list data) { PopulateTensor(size_, data); } std::vector GetOutput() { return ExtractVector(output_); } private: int input_; + int size_; int output_; }; TEST(ResizeBilinearOpTest, HorizontalResize) { - ResizeBilinearOpModel m({1, 1, 2, 1}, 1, 3); + ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 1, 2, 1}}); m.SetInput({3, 6}); + m.SetSize({1, 3}); m.Invoke(); EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 5, 6}))); + + ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 1, 2, 1}}, {1, 3}); + const_m.SetInput({3, 6}); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 5, 6}))); } TEST(ResizeBilinearOpTest, VerticalResize) { - ResizeBilinearOpModel m({1, 2, 1, 1}, 3, 1); + ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 1, 1}}); m.SetInput({3, 9}); + m.SetSize({3, 1}); m.Invoke(); EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 7, 9}))); + + ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 1, 1}}, {3, 1}); + const_m.SetInput({3, 9}); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 7, 9}))); } TEST(ResizeBilinearOpTest, TwoDimensionalResize) { - ResizeBilinearOpModel m({1, 2, 2, 1}, 3, 3); + ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}); m.SetInput({ 3, 6, // 9, 12 // }); + m.SetSize({3, 3}); m.Invoke(); EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ 3, 5, 6, // 7, 9, 10, // 9, 11, 12, // }))); + + ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 1}}, {3, 3}); + const_m.SetInput({ + 3, 6, // + 9, 12 // + }); + const_m.Invoke(); 
+ EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + }))); } TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) { - ResizeBilinearOpModel m({2, 2, 2, 1}, 3, 3); + ResizeBilinearOpModel m({TensorType_FLOAT32, {2, 2, 2, 1}}); m.SetInput({ 3, 6, // 9, 12, // 4, 10, // 10, 16 // }); + m.SetSize({3, 3}); m.Invoke(); EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ 3, 5, 6, // @@ -91,27 +129,57 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) { 8, 12, 14, // 10, 14, 16, // }))); + + ResizeBilinearOpModel const_m({TensorType_FLOAT32, {2, 2, 2, 1}}, {3, 3}); + const_m.SetInput({ + 3, 6, // + 9, 12, // + 4, 10, // + 10, 16 // + }); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + 4, 8, 10, // + 8, 12, 14, // + 10, 14, 16, // + }))); } TEST(ResizeBilinearOpTest, ThreeDimensionalResize) { - ResizeBilinearOpModel m({1, 2, 2, 2}, 3, 3); + ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 2}}); m.SetInput({ 3, 4, 6, 10, // 9, 10, 12, 16, // }); + m.SetSize({3, 3}); m.Invoke(); EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({ 3, 4, 5, 8, 6, 10, // 7, 8, 9, 12, 10, 14, // 9, 10, 11, 14, 12, 16, // }))); + + ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 2}}, {3, 3}); + const_m.SetInput({ + 3, 4, 6, 10, // + 9, 10, 12, 16, // + }); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 4, 5, 8, 6, 10, // + 7, 8, 9, 12, 10, 14, // + 9, 10, 11, 14, 12, 16, // + }))); } } // namespace } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/skip_gram_test.cc b/tensorflow/contrib/lite/kernels/skip_gram_test.cc index 
e7f6bc904be5e4c23a88f5b4ae7e199346c78ab2..185b64cb44969b57588ea5d0b40f55b6ddf8e11f 100644 --- a/tensorflow/contrib/lite/kernels/skip_gram_test.cc +++ b/tensorflow/contrib/lite/kernels/skip_gram_test.cc @@ -251,7 +251,7 @@ TEST(SkipGramTest, TestInputWithExtraSpace) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/softmax_test.cc b/tensorflow/contrib/lite/kernels/softmax_test.cc index ec8ec03b0d0279cad8543352b1dbaf34c88a7957..6c5338ff0fd26337c9adc8e0b94a0a88edfde37f 100644 --- a/tensorflow/contrib/lite/kernels/softmax_test.cc +++ b/tensorflow/contrib/lite/kernels/softmax_test.cc @@ -136,8 +136,7 @@ TEST(SoftmaxOpTest, CompareWithTFminiBetaNotEq1) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); - tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc new file mode 100644 index 0000000000000000000000000000000000000000..d8c9e352f00627eee45ae836b720f2af77140538 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/space_to_batch_nd.cc @@ -0,0 +1,186 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace space_to_batch_nd { + +// This file has two implementations of SpaceToBatchND. +enum KernelType { + kReference, + kGenericOptimized, +}; + +struct SpaceToBatchNDContext { + SpaceToBatchNDContext(TfLiteContext* context, TfLiteNode* node) { + input = GetInput(context, node, 0); + block_shape = GetInput(context, node, 1); + paddings = GetInput(context, node, 2); + output = GetOutput(context, node, 0); + } + TfLiteTensor* input; + TfLiteTensor* block_shape; + TfLiteTensor* paddings; + TfLiteTensor* output; +}; + +// Currently, only 4D NHWC input/output op_context are supported. +// The 4D array need to have exactly 2 spatial dimensions. +// TODO(nupurgarg): Support arbitrary dimension in SpaceToBatchND. 
+const int kInputDimensionNum = 4; +const int kBlockSizeDimensionNum = 1; +const int kSpatialDimensionNum = 2; + +TfLiteStatus ResizeOutputTensor(TfLiteContext* context, + SpaceToBatchNDContext* op_context) { + TfLiteIntArray* input_size = op_context->input->dims; + const int32* block_shape = GetTensorData(op_context->block_shape); + const int32* paddings_data = GetTensorData(op_context->paddings); + + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->block_shape), + kBlockSizeDimensionNum); + TF_LITE_ENSURE_EQ(context, op_context->block_shape->dims->data[0], + kSpatialDimensionNum); + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->paddings), + kSpatialDimensionNum); + + TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size); + + // Ensures the input height and width (with padding) is a multiple of block + // shape height and width. + for (int dim = 0; dim < kSpatialDimensionNum; ++dim) { + int final_dim_size = (input_size->data[dim + 1] + paddings_data[dim * 2] + + paddings_data[dim * 2 + 1]); + TF_LITE_ENSURE_EQ(context, final_dim_size % block_shape[dim], 0); + output_size->data[dim + 1] = final_dim_size / block_shape[dim]; + } + + const int output_batch_size = + input_size->data[0] * block_shape[0] * block_shape[1]; + const int output_channel_size = input_size->data[3]; + + output_size->data[0] = output_batch_size; + output_size->data[3] = output_channel_size; + + return context->ResizeTensor(context, op_context->output, output_size); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + SpaceToBatchNDContext op_context(context, node); + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.input), + kInputDimensionNum); + TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); + + if (!IsConstantTensor(op_context.block_shape) || + !IsConstantTensor(op_context.paddings)) { + 
SetTensorToDynamic(op_context.output); + return kTfLiteOk; + } + return ResizeOutputTensor(context, &op_context); +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + SpaceToBatchNDContext op_context(context, node); + + // Resize the output tensor if the output tensor is dynamic. + if (IsDynamicTensor(op_context.output)) { + TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); + } + +#define TF_LITE_SPACE_TO_BATCH_ND(type, scalar) \ + type::SpaceToBatchND(GetTensorData(op_context.input), \ + GetTensorDims(op_context.input), \ + GetTensorData(op_context.block_shape), \ + GetTensorDims(op_context.block_shape), \ + GetTensorData(op_context.paddings), \ + GetTensorDims(op_context.paddings), \ + GetTensorData(op_context.output), \ + GetTensorDims(op_context.output)) + switch (op_context.input->type) { // Already know in/out types are same. + case kTfLiteFloat32: + if (kernel_type == kReference) { + TF_LITE_SPACE_TO_BATCH_ND(reference_ops, float); + } else { + TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, float); + } + break; + case kTfLiteUInt8: + if (kernel_type == kReference) { + TF_LITE_SPACE_TO_BATCH_ND(reference_ops, uint8_t); + } else { + TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, uint8_t); + } + break; + case kTfLiteInt32: + if (kernel_type == kReference) { + TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int32_t); + } else { + TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int32_t); + } + break; + case kTfLiteInt64: + if (kernel_type == kReference) { + TF_LITE_SPACE_TO_BATCH_ND(reference_ops, int64_t); + } else { + TF_LITE_SPACE_TO_BATCH_ND(optimized_ops, int64_t); + } + break; + default: + context->ReportError(context, + "Type is currently not supported by SpaceToBatch."); + return kTfLiteError; + } +#undef TF_LITE_SPACE_TO_BATCH_ND + return kTfLiteOk; +} + +} // namespace space_to_batch_nd + +TfLiteRegistration* Register_SPACE_TO_BATCH_ND_REF() { + static TfLiteRegistration r = { + nullptr, nullptr, space_to_batch_nd::Prepare, + 
space_to_batch_nd::Eval}; + return &r; +} + +TfLiteRegistration* Register_SPACE_TO_BATCH_ND_GENERIC_OPT() { + static TfLiteRegistration r = { + nullptr, nullptr, space_to_batch_nd::Prepare, + space_to_batch_nd::Eval}; + return &r; +} + +TfLiteRegistration* Register_SPACE_TO_BATCH_ND() { + // return Register_SPACE_TO_BATCH_ND_REF(); + return Register_SPACE_TO_BATCH_ND_GENERIC_OPT(); +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/space_to_batch_nd_test.cc b/tensorflow/contrib/lite/kernels/space_to_batch_nd_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..92a4a037d5873e608ee7bdbdfc5eaa5e9b62bc8c --- /dev/null +++ b/tensorflow/contrib/lite/kernels/space_to_batch_nd_test.cc @@ -0,0 +1,199 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class SpaceToBatchNDOpModel : public SingleOpModel { + public: + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetBlockShape(std::initializer_list data) { + PopulateTensor(block_shape_, data); + } + + void SetPaddings(std::initializer_list data) { + PopulateTensor(paddings_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + protected: + int input_; + int block_shape_; + int paddings_; + int output_; +}; + +// Tests case where block_shape and paddings are const tensors. +// +// Example usage is as follows: +// SpaceToBatchNDOpConstModel m(input_shape, block_shape, paddings); +// m.SetInput(input_data); +// m.Invoke(); +class SpaceToBatchNDOpConstModel : public SpaceToBatchNDOpModel { + public: + SpaceToBatchNDOpConstModel(std::initializer_list input_shape, + std::initializer_list block_shape, + std::initializer_list paddings) { + input_ = AddInput(TensorType_FLOAT32); + block_shape_ = AddConstInput(TensorType_INT32, block_shape, {2}); + paddings_ = AddConstInput(TensorType_INT32, paddings, {2, 2}); + output_ = AddOutput(TensorType_FLOAT32); + + SetBuiltinOp(BuiltinOperator_SPACE_TO_BATCH_ND, + BuiltinOptions_SpaceToBatchNDOptions, + CreateSpaceToBatchNDOptions(builder_).Union()); + BuildInterpreter({input_shape}); + } +}; + +// Tests case where block_shape and paddings are non-const tensors. 
+//
+// Example usage is as follows:
+//    SpaceToBatchNDOpDynamicModel m(input_shape);
+//    m.SetInput(input_data);
+//    m.SetBlockShape(block_shape);
+//    m.SetPaddings(paddings);
+//    m.Invoke();
+class SpaceToBatchNDOpDynamicModel : public SpaceToBatchNDOpModel {
+ public:
+  SpaceToBatchNDOpDynamicModel(std::initializer_list input_shape) {
+    input_ = AddInput(TensorType_FLOAT32);
+    block_shape_ = AddInput(TensorType_INT32);
+    paddings_ = AddInput(TensorType_INT32);
+    output_ = AddOutput(TensorType_FLOAT32);
+
+    SetBuiltinOp(BuiltinOperator_SPACE_TO_BATCH_ND,
+                 BuiltinOptions_SpaceToBatchNDOptions,
+                 CreateSpaceToBatchNDOptions(builder_).Union());
+    // block_shape is given shape {2} and paddings shape {2, 2}.
+    BuildInterpreter({input_shape, {2}, {2, 2}});
+  }
+};
+
+// Constructing the const model with a 3x3 spatial input and a 2x2 block is
+// expected to die with "Cannot allocate tensors".
+TEST(SpaceToBatchNDOpTest, InvalidShapeTest) {
+  EXPECT_DEATH(SpaceToBatchNDOpConstModel({1, 3, 3, 1}, {2, 2}, {0, 0, 0, 0}),
+               "Cannot allocate tensors");
+}
+
+// 4x4 input, 2x2 block, no padding: expects four 2x2 batches.
+TEST(SpaceToBatchNDOpTest, SimpleConstTest) {
+  SpaceToBatchNDOpConstModel m({1, 4, 4, 1}, {2, 2}, {0, 0, 0, 0});
+  m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 2, 2, 1}));
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 9, 11, 2, 4, 10, 12, 5, 7,
+                                               13, 15, 6, 8, 14, 16}));
+}
+
+// Same expectation as SimpleConstTest, with block_shape and paddings supplied
+// at run time.
+TEST(SpaceToBatchNDOpTest, SimpleDynamicTest) {
+  SpaceToBatchNDOpDynamicModel m({1, 4, 4, 1});
+  m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
+  m.SetBlockShape({2, 2});
+  m.SetPaddings({0, 0, 0, 0});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 2, 2, 1}));
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 9, 11, 2, 4, 10, 12, 5, 7,
+                                               13, 15, 6, 8, 14, 16}));
+}
+
+// Input with two batches of 2x4: expects eight 1x2 batches.
+TEST(SpaceToBatchNDOpTest, MultipleInputBatchesConstTest) {
+  SpaceToBatchNDOpConstModel m({2, 2, 4, 1}, {2, 2}, {0, 0, 0, 0});
+  m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({8, 1, 2, 1}));
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 9, 11, 2, 4, 10, 12, 5, 7,
+                                               13, 15, 6, 8, 14, 16}));
+}
+
+// Run-time block_shape/paddings variant of MultipleInputBatchesConstTest.
+TEST(SpaceToBatchNDOpTest, MultipleInputBatchesDynamicTest) {
+  SpaceToBatchNDOpDynamicModel m({2, 2, 4, 1});
+  m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
+  m.SetBlockShape({2, 2});
+  m.SetPaddings({0, 0, 0, 0});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({8, 1, 2, 1}));
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 9, 11, 2, 4, 10, 12, 5, 7,
+                                               13, 15, 6, 8, 14, 16}));
+}
+
+// Paddings {1, 0, 2, 0} with a 3x2 block: padded positions are expected to
+// read back as 0.
+TEST(SpaceToBatchNDOpTest, SimplePaddingConstTest) {
+  SpaceToBatchNDOpConstModel m({1, 5, 2, 1}, {3, 2}, {1, 0, 2, 0});
+  m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 2, 1}));
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+                                 0, 0, 0, 5, 0, 0, 0, 6, 0, 1, 0, 7,
+                                 0, 2, 0, 8, 0, 3, 0, 9, 0, 4, 0, 10,
+                             }));
+}
+
+// Run-time block_shape/paddings variant of SimplePaddingConstTest.
+TEST(SpaceToBatchNDOpTest, SimplePaddingDynamicTest) {
+  SpaceToBatchNDOpDynamicModel m({1, 5, 2, 1});
+  m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10});
+  m.SetBlockShape({3, 2});
+  m.SetPaddings({1, 0, 2, 0});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 2, 1}));
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+                                 0, 0, 0, 5, 0, 0, 0, 6, 0, 1, 0, 7,
+                                 0, 2, 0, 8, 0, 3, 0, 9, 0, 4, 0, 10,
+                             }));
+}
+
+// Paddings {1, 1, 2, 4} with a 3x2 block on a 4x2 input.
+TEST(SpaceToBatchNDOpTest, ComplexPaddingConstTest) {
+  SpaceToBatchNDOpConstModel m({1, 4, 2, 1}, {3, 2}, {1, 1, 2, 4});
+  m.SetInput({1, 2, 3, 4, 5, 6, 7, 8});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 4, 1}));
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+                                 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0,
+                                 0, 1, 0, 0, 0, 7, 0, 0, 0, 2, 0, 0, 0, 8, 0, 0,
+                                 0, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0,
+                             }));
+}
+
+// Run-time block_shape/paddings variant of ComplexPaddingConstTest.
+TEST(SpaceToBatchNDOpTest, ComplexPaddingDynamicTest) {
+  SpaceToBatchNDOpDynamicModel m({1, 4, 2, 1});
+  m.SetInput({1, 2, 3, 4, 5, 6, 7, 8});
+  m.SetBlockShape({3, 2});
+  m.SetPaddings({1, 1, 2, 4});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({6, 2, 4, 1}));
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({
+                                 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 6, 0, 0,
+                                 0, 1, 0, 0, 0, 7, 0, 0, 0, 2, 0, 0, 0, 8, 0, 0,
+                                 0, 3, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0,
+                             }));
+}
+
+}  // namespace
+}  // namespace tflite
+
+int main(int argc, char** argv) {
+  ::tflite::LogToStderr();
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/space_to_depth_test.cc b/tensorflow/contrib/lite/kernels/space_to_depth_test.cc
index 911f08a92ccd6a97bee414c87bd79091808f0ed1..997f354861a235fb511235e4d64544dc8c3ddb34 100644
--- a/tensorflow/contrib/lite/kernels/space_to_depth_test.cc
+++ b/tensorflow/contrib/lite/kernels/space_to_depth_test.cc
@@ -95,8 +95,7 @@ TEST(SpaceToDepthOpModel, Int64) {
 }  // namespace tflite
 
 int main(int argc, char** argv) {
-  // On Linux, add: tflite::LogToStderr();
-  tflite::LogToStderr();
+  ::tflite::LogToStderr();
   ::testing::InitGoogleTest(&argc, argv);
   return RUN_ALL_TESTS();
 }
diff --git a/tensorflow/contrib/lite/kernels/squeeze.cc b/tensorflow/contrib/lite/kernels/squeeze.cc
new file mode 100644
index 0000000000000000000000000000000000000000..29447ab021c7b68ff51070d35262402e08dc7ab9
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/squeeze.cc
@@ -0,0 +1,99 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include
+#include
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace squeeze {
+
+// Bundles the node's builtin parameters with its single input and single
+// output tensor so Prepare() and Eval() can access them uniformly.
+struct SqueezeContext {
+  SqueezeContext(TfLiteContext* context, TfLiteNode* node) {
+    params = reinterpret_cast(node->builtin_data);
+    input = GetInput(context, node, 0);
+    output = GetOutput(context, node, 0);
+  }
+  TfLiteSqueezeParams* params;
+  TfLiteTensor* input;
+  TfLiteTensor* output;
+};
+
+// Validates the node (one input, one output, at most 8 input dimensions),
+// works out which dimensions are removed and resizes the output tensor to
+// the remaining dimensions.
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+  SqueezeContext op_context(context, node);
+  int input_num_dims = NumDimensions(op_context.input);
+  int num_squeeze_dims = op_context.params->num_squeeze_dims;
+
+  // Determines number of dimensions of output tensor after squeeze.
+  const TfLiteIntArray* input_dims = op_context.input->dims;
+  const int* squeeze_dims = op_context.params->squeeze_dims;
+  // should_squeeze below has a fixed capacity of 8 entries.
+  TF_LITE_ENSURE(context, input_num_dims <= 8);
+  bool should_squeeze[8] = {false};
+  int num_squeezed_dims = 0;
+  if (num_squeeze_dims == 0) {
+    // No explicit axes: every dimension of size 1 is squeezed.
+    for (int idx = 0; idx < input_num_dims; ++idx) {
+      if (input_dims->data[idx] == 1) {
+        should_squeeze[idx] = true;
+        ++num_squeezed_dims;
+      }
+    }
+  } else {
+    for (int idx = 0; idx < num_squeeze_dims; ++idx) {
+      // Negative axes are offset by the input rank.
+      int current = squeeze_dims[idx] < 0 ? squeeze_dims[idx] + input_num_dims
+                                          : squeeze_dims[idx];
+      // Each requested axis must be in range and have size 1.
+      TF_LITE_ENSURE(context, current >= 0 && current < input_num_dims &&
+                                  input_dims->data[current] == 1);
+      // An axis listed more than once is only counted once.
+      if (!should_squeeze[current]) ++num_squeezed_dims;
+      should_squeeze[current] = true;
+    }
+  }
+  // Sets output dimensions.
+  TfLiteIntArray* output_dims =
+      TfLiteIntArrayCreate(input_num_dims - num_squeezed_dims);
+  for (int in_idx = 0, out_idx = 0; in_idx < input_num_dims; ++in_idx) {
+    if (!should_squeeze[in_idx]) {
+      output_dims->data[out_idx++] = input_dims->data[in_idx];
+    }
+  }
+  return context->ResizeTensor(context, op_context.output, output_dims);
+}
+
+// Copies the input bytes to the output unchanged after checking that both
+// tensors have the same byte size.
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  SqueezeContext op_context(context, node);
+  TF_LITE_ENSURE_EQ(context, op_context.input->bytes, op_context.output->bytes);
+  memcpy(op_context.output->data.raw, op_context.input->data.raw,
+         op_context.input->bytes);
+  return kTfLiteOk;
+}
+
+}  // namespace squeeze
+
+// Returns the registration for SQUEEZE; only Prepare and Eval are set, the
+// first two registration fields are null.
+TfLiteRegistration* Register_SQUEEZE() {
+  static TfLiteRegistration r = {nullptr, nullptr, squeeze::Prepare,
+                                 squeeze::Eval};
+  return &r;
+}
+
+}  // namespace builtin
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/squeeze_test.cc b/tensorflow/contrib/lite/kernels/squeeze_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..a8aab88357cacbb72784a4bc6e860aeb47783eb3
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/squeeze_test.cc
@@ -0,0 +1,124 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+using ::testing::IsEmpty;
+
+// Single-op SQUEEZE model: one input, one output, and the list of axes to
+// squeeze passed through the builtin options.
+class BaseSqueezeOpModel : public SingleOpModel {
+ public:
+  BaseSqueezeOpModel(const TensorData& input, const TensorData& output,
+                     std::initializer_list axis) {
+    input_ = AddInput(input);
+    output_ = AddOutput(output);
+    SetBuiltinOp(
+        BuiltinOperator_SQUEEZE, BuiltinOptions_SqueezeOptions,
+        CreateSqueezeOptions(builder_, builder_.CreateVector(axis))
+            .Union());
+    BuildInterpreter({GetShape(input_)});
+  }
+
+  int input() { return input_; }
+
+ protected:
+  int input_;
+  int output_;
+};
+
+// Adds input/output accessors on top of BaseSqueezeOpModel.
+class FloatSqueezeOpModel : public BaseSqueezeOpModel {
+ public:
+  using BaseSqueezeOpModel::BaseSqueezeOpModel;
+
+  void SetInput(std::initializer_list data) {
+    PopulateTensor(input_, data);
+  }
+
+  std::vector GetOutput() { return ExtractVector(output_); }
+  std::vector GetOutputShape() { return GetTensorShape(output_); }
+};
+
+// Empty axis list: both size-1 dimensions of {1, 24, 1} are removed.
+TEST(FloatSqueezeOpTest, SqueezeAll) {
+  std::initializer_list data = {
+      1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0,  9.0,  10.0, 11.0, 12.0,
+      13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+  FloatSqueezeOpModel m({TensorType_FLOAT32, {1, 24, 1}},
+                        {TensorType_FLOAT32, {24}}, {});
+  m.SetInput(data);
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({24}));
+  EXPECT_THAT(
+      m.GetOutput(),
+      ElementsAreArray({1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0,
+                        9.0,  10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+                        17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}));
+}
+
+// Only axis 2 is squeezed, so the expected output shape is {1, 24}.
+// NOTE(review): the output TensorData is declared with shape {24} while the
+// expectation is {1, 24} - presumably the declared shape is overridden when
+// the kernel resizes the output; confirm.
+TEST(FloatSqueezeOpTest, SqueezeSelectedAxis) {
+  std::initializer_list data = {
+      1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0,  9.0,  10.0, 11.0, 12.0,
+      13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+  FloatSqueezeOpModel m({TensorType_FLOAT32, {1, 24, 1}},
+                        {TensorType_FLOAT32, {24}}, {2});
+  m.SetInput(data);
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 24}));
+  EXPECT_THAT(
+      m.GetOutput(),
+      ElementsAreArray({1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0,
+                        9.0,  10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+                        17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}));
+}
+
+// Axes {-1, 0}: the negative axis refers to the last dimension.
+TEST(FloatSqueezeOpTest, SqueezeNegativeAxis) {
+  std::initializer_list data = {
+      1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0,  9.0,  10.0, 11.0, 12.0,
+      13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0};
+  FloatSqueezeOpModel m({TensorType_FLOAT32, {1, 24, 1}},
+                        {TensorType_FLOAT32, {24}}, {-1, 0});
+  m.SetInput(data);
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({24}));
+  EXPECT_THAT(
+      m.GetOutput(),
+      ElementsAreArray({1.0,  2.0,  3.0,  4.0,  5.0,  6.0,  7.0,  8.0,
+                        9.0,  10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0,
+                        17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0}));
+}
+
+// Every dimension has size 1, so the expected output shape is empty.
+TEST(FloatSqueezeOpTest, SqueezeAllDims) {
+  std::initializer_list data = {3.85};
+  FloatSqueezeOpModel m({TensorType_FLOAT32, {1, 1, 1, 1, 1, 1, 1}},
+                        {TensorType_FLOAT32, {1}}, {});
+  m.SetInput(data);
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), IsEmpty());
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({3.85}));
+}
+
+}  // namespace
+}  // namespace tflite
+
+int main(int argc, char** argv) {
+  ::tflite::LogToStderr();
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/strided_slice.cc b/tensorflow/contrib/lite/kernels/strided_slice.cc
new file mode 100644
index 0000000000000000000000000000000000000000..fb1e11e0ca00abb36d7f29d562711a7bbcbeca1c
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/strided_slice.cc
@@ -0,0 +1,259 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include
+#include
+#include
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace strided_slice {
+
+// Selects the kernel implementation used by Eval(); only the reference
+// implementation exists so far.
+enum KernelType {
+  kReference,
+  // TODO(soroosh): add kGenericOptimized
+};
+
+// Positions of the tensors in the node's input and output lists.
+constexpr int kInputTensor = 0;
+constexpr int kBeginTensor = 1;
+constexpr int kEndTensor = 2;
+constexpr int kStridesTensor = 3;
+constexpr int kOutputTensor = 0;
+
+// Gathers the node's builtin parameters, its four inputs (data, begin, end,
+// strides), its output, and the rank of the data input.
+struct StridedSliceContext {
+  StridedSliceContext(TfLiteContext* context, TfLiteNode* node) {
+    params = reinterpret_cast(node->builtin_data);
+    input = GetInput(context, node, kInputTensor);
+    begin = GetInput(context, node, kBeginTensor);
+    end = GetInput(context, node, kEndTensor);
+    strides = GetInput(context, node, kStridesTensor);
+    output = GetOutput(context, node, kOutputTensor);
+    dims = NumDimensions(input);
+  }
+  TfLiteStridedSliceParams* params;
+  TfLiteTensor* input;
+  TfLiteTensor* begin;
+  TfLiteTensor* end;
+  TfLiteTensor* strides;
+  TfLiteTensor* output;
+  // Number of dimensions of `input`.
+  int dims;
+};
+
+// Reverse order of bits in the mask to match the expected
+// order in kernel
+//
+// Returns `mask` with its lowest `num_dimensions` bits mirrored: bit 0 moves
+// to bit (num_dimensions - 1) and so on. Higher bits of `mask` are dropped.
+inline int ReverseMaskBits(int mask, int num_dimensions) {
+  int out = 0;
+  for (int dim = 0; dim < num_dimensions; dim++) {
+    out <<= 1;
+    out += (mask & 1);
+    mask >>= 1;
+  }
+  return out;
+}
+
+// This Op only supports 1-4D cases and since we use the reference 4D
+// implementation, the 1-3D tensors are mapped to 4D.
+const int kMaxDim = 4;
+
+// Returns `dividend` modulo `divisor` mapped into [0, divisor).
+// `divisor` must be positive; a zero divisor is undefined behaviour.
+inline int32_t PositiveRemainder(int32_t dividend, int32_t divisor) {
+  return (divisor + (dividend % divisor)) % divisor;
+}
+
+// Clamps a possibly negative or out-of-range `index` for a dimension of size
+// `dim`.
+//   * positive stride: the result lies in [0, dim]; `dim` means "one past
+//     the end".
+//   * negative stride: the result lies in [-1, dim - 1]; -1 means "one before
+//     the start".
+// Negative indices count from the end of the dimension.
+inline int32_t ClampedIndex(int32_t index, int dim, bool pos_stride) {
+  // An empty dimension has no valid element. Return the matching sentinel
+  // directly, because the general path below would compute `x % 0`.
+  if (dim <= 0) {
+    return pos_stride ? 0 : -1;
+  }
+  return pos_stride
+             ? (index >= dim ? dim
+                             : PositiveRemainder(
+                                   std::min(std::max(index, -dim), dim), dim))
+             : (index < -dim
+                    ? -1
+                    : PositiveRemainder(
+                          std::min(std::max(index, -dim), dim - 1), dim));
+}
+
+// Returns the effective begin index for dimension `idx`. When bit `idx` of
+// begin_mask is set the begin tensor is ignored and the slice starts at the
+// first element in stride direction (0, or dim - 1 for a negative stride).
+inline int32_t GetBeginValueAtIndex(StridedSliceContext* op_context, int idx) {
+  const int dim = op_context->input->dims->data[idx];
+  const bool pos_stride = GetTensorData(op_context->strides)[idx] > 0;
+  return (op_context->params->begin_mask & (1 << idx))
+             ? (pos_stride ? 0 : dim - 1)
+             : ClampedIndex(GetTensorData(op_context->begin)[idx], dim,
+                            pos_stride);
+}
+
+// Returns the effective (exclusive) end index for dimension `idx`. When bit
+// `idx` of end_mask is set the end tensor is ignored and the slice runs to
+// the last element in stride direction (end = dim, or -1 for a negative
+// stride).
+inline int32_t GetEndValueAtIndex(StridedSliceContext* op_context, int idx) {
+  const int dim = op_context->input->dims->data[idx];
+  const bool pos_stride = GetTensorData(op_context->strides)[idx] > 0;
+  return (op_context->params->end_mask & (1 << idx))
+             ? (pos_stride ? dim : -1)
+             : ClampedIndex(GetTensorData(op_context->end)[idx], dim,
+                            pos_stride);
+}
+
+// Processes the indexing tensors (begin, end and strides) to resize the
+// output tensor. This function is callable from both Prepare() and Eval() as
+// long as the caller ensures the indexing tensors are present.
+TfLiteStatus ResizeOutputTensor(TfLiteContext* context, + StridedSliceContext* op_context) { + std::vector output_shape_vector; + + for (int idx = op_context->dims - 1; idx >= 0; --idx) { + int32_t stride = GetTensorData(op_context->strides)[idx]; + TF_LITE_ENSURE_MSG(context, stride != 0, "stride value has to be non-zero"); + + int32_t begin = GetBeginValueAtIndex(op_context, idx); + int32_t end = GetEndValueAtIndex(op_context, idx); + + // This is valid for both positive and negative strides + int32_t dim_shape = ceil((end - begin) / static_cast(stride)); + dim_shape = dim_shape < 0 ? 0 : dim_shape; + if (!(op_context->params->shrink_axis_mask & (1 << idx))) { + output_shape_vector.push_back(dim_shape); + } + } + + TfLiteIntArray* output_shape = + TfLiteIntArrayCreate(output_shape_vector.size()); + + std::reverse_copy(output_shape_vector.begin(), output_shape_vector.end(), + output_shape->data); + + TF_LITE_ENSURE_STATUS( + context->ResizeTensor(context, op_context->output, output_shape)); + + return kTfLiteOk; +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 4); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + StridedSliceContext op_context(context, node); + + // Ensure validity of input tensor and its dimension + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.begin), 1); + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.end), 1); + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context.strides), 1); + TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); + // Only INT32 begin/end/strides are supported + // TODO(soroosh) add support for INT64 + TF_LITE_ENSURE_EQ(context, op_context.begin->type, kTfLiteInt32); + TF_LITE_ENSURE_EQ(context, op_context.end->type, kTfLiteInt32); + TF_LITE_ENSURE_EQ(context, op_context.strides->type, kTfLiteInt32); + TF_LITE_ENSURE_MSG(context, op_context.dims <= 4, + "StridedSlice op only supports 1D-4D input arrays."); + + // 
TODO(soroosh): add the following missing functionalities + TF_LITE_ENSURE_MSG(context, op_context.params->ellipsis_mask == 0, + "ellipsis_mask is not implemented yet."); + TF_LITE_ENSURE_MSG(context, op_context.params->new_axis_mask == 0, + "new_axis_mask is not implemented yet."); + + // Postpone allocation of output if any of the indexing tensors is not + // constant + if (!(IsConstantTensor(op_context.begin) && + IsConstantTensor(op_context.end) && + IsConstantTensor(op_context.strides))) { + SetTensorToDynamic(op_context.output); + return kTfLiteOk; + } + return ResizeOutputTensor(context, &op_context); +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + StridedSliceContext op_context(context, node); + + if (IsDynamicTensor(op_context.output)) { + TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); + } + + std::vector starts; + std::vector stops; + std::vector strides; + + for (int idx = op_context.dims - 1; idx >= 0; --idx) { + starts.emplace_back(GetBeginValueAtIndex(&op_context, idx)); + stops.emplace_back(GetEndValueAtIndex(&op_context, idx)); + strides.emplace_back(GetTensorData(op_context.strides)[idx]); + } + + for (int i = op_context.dims; i < kMaxDim; i++) { + starts.emplace_back(0); + stops.emplace_back(1); + strides.emplace_back(1); + } + + op_context.params->begin_mask = + ReverseMaskBits(op_context.params->begin_mask, op_context.dims); + op_context.params->end_mask = + ReverseMaskBits(op_context.params->end_mask, op_context.dims); + op_context.params->shrink_axis_mask = + ReverseMaskBits(op_context.params->shrink_axis_mask, op_context.dims); + +#define TF_LITE_STRIDED_SLICE(kernel_type, data_type) \ + kernel_type::StridedSlice( \ + GetTensorData(op_context.input), \ + GetTensorDims(op_context.input), op_context.params->begin_mask, \ + op_context.params->end_mask, op_context.params->shrink_axis_mask, \ + starts, stops, strides, GetTensorData(op_context.output), \ + GetTensorDims(op_context.output)) + + 
switch (op_context.input->type) { + case kTfLiteFloat32: + if (kernel_type == kReference) { + TF_LITE_STRIDED_SLICE(reference_ops, float); + } + break; + case kTfLiteInt32: + if (kernel_type == kReference) { + TF_LITE_STRIDED_SLICE(reference_ops, int32_t); + } + break; + case kTfLiteInt64: + if (kernel_type == kReference) { + TF_LITE_STRIDED_SLICE(reference_ops, int64_t); + } + break; + default: + context->ReportError(context, + "Type is currently not supported " + "by StridedSlice."); + return kTfLiteError; + } +#undef TF_LITE_STRIDED_SLICE + return kTfLiteOk; +} + +} // namespace strided_slice + +TfLiteRegistration* Register_STRIDED_SLICE_REF() { + static TfLiteRegistration r = { + nullptr, nullptr, strided_slice::Prepare, + strided_slice::Eval}; + return &r; +} + +// TODO(soroosh): add optimized +TfLiteRegistration* Register_STRIDED_SLICE() { + return Register_STRIDED_SLICE_REF(); +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/strided_slice_test.cc b/tensorflow/contrib/lite/kernels/strided_slice_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..5cac04b38364958c5b0794c21742e8b592372ae9 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/strided_slice_test.cc @@ -0,0 +1,532 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::int32; +using ::testing::ElementsAreArray; + +class StridedSliceOpModel : public SingleOpModel { + public: + StridedSliceOpModel(std::initializer_list input_shape, + std::initializer_list begin_shape, + std::initializer_list end_shape, + std::initializer_list strides_shape, int begin_mask, + int end_mask, int ellipsis_mask, int new_axis_mask, + int shrink_axis_mask) { + input_ = AddInput(TensorType_FLOAT32); + begin_ = AddInput(TensorType_INT32); + end_ = AddInput(TensorType_INT32); + strides_ = AddInput(TensorType_INT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp( + BuiltinOperator_STRIDED_SLICE, BuiltinOptions_StridedSliceOptions, + CreateStridedSliceOptions(builder_, begin_mask, end_mask, ellipsis_mask, + new_axis_mask, shrink_axis_mask) + .Union()); + BuildInterpreter({input_shape, begin_shape, end_shape, strides_shape}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + void SetBegin(std::initializer_list data) { + PopulateTensor(begin_, data); + } + void SetEnd(std::initializer_list data) { + PopulateTensor(end_, data); + } + void SetStrides(std::initializer_list data) { + PopulateTensor(strides_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + private: + int input_; + int begin_; + int end_; + int strides_; + int output_; +}; + +TEST(StridedSliceOpTest, UnsupportedInputSize) { + EXPECT_DEATH( + StridedSliceOpModel({2, 2, 2, 2, 2}, {5}, {5}, {5}, 0, 0, 0, 0, 0), + "StridedSlice op only supports 1D-4D input arrays."); +} + +TEST(StridedSliceOpTest, UnssupportedArgs) { + 
EXPECT_DEATH(StridedSliceOpModel({3, 2}, {2}, {2}, {2}, 0, 0, 1, 0, 0), + "ellipsis_mask is not implemented yet."); + EXPECT_DEATH(StridedSliceOpModel({3, 2}, {2}, {2}, {2}, 0, 0, 0, 1, 0), + "new_axis_mask is not implemented yet."); +} + +TEST(StridedSliceOpTest, In1D) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({1}); + m.SetEnd({3}); + m.SetStrides({1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3})); +} + +TEST(StridedSliceOpTest, In1D_EmptyOutput) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({10}); + m.SetEnd({3}); + m.SetStrides({1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({0})); +} + +TEST(StridedSliceOpTest, In1D_NegativeBegin) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({-3}); + m.SetEnd({3}); + m.SetStrides({1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3})); +} + +TEST(StridedSliceOpTest, In1D_OutOfRangeBegin) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({-5}); + m.SetEnd({3}); + m.SetStrides({1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3})); +} + +TEST(StridedSliceOpTest, In1D_NegativeEnd) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({1}); + m.SetEnd({-2}); + m.SetStrides({1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2})); +} + +TEST(StridedSliceOpTest, In1D_OutOfRangeEnd) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({-3}); + m.SetEnd({5}); + m.SetStrides({1}); + 
m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3, 4})); +} + +TEST(StridedSliceOpTest, In1D_BeginMask) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 1, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({1}); + m.SetEnd({3}); + m.SetStrides({1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3})); +} + +TEST(StridedSliceOpTest, In1D_NegativeBeginNegativeStride) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({-2}); + m.SetEnd({-3}); + m.SetStrides({-1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3})); +} + +TEST(StridedSliceOpTest, In1D_OutOfRangeBeginNegativeStride) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({5}); + m.SetEnd({2}); + m.SetStrides({-1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({4})); +} + +TEST(StridedSliceOpTest, In1D_NegativeEndNegativeStride) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({2}); + m.SetEnd({-4}); + m.SetStrides({-1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 2})); +} + +TEST(StridedSliceOpTest, In1D_OutOfRangeEndNegativeStride) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({-3}); + m.SetEnd({-5}); + m.SetStrides({-1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 1})); +} + +TEST(StridedSliceOpTest, In1D_EndMask) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 1, 0, 0, 0); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({1}); + m.SetEnd({3}); + 
m.SetStrides({1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2, 3, 4})); +} + +TEST(StridedSliceOpTest, In1D_NegStride) { + StridedSliceOpModel m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3}); + m.SetBegin({-1}); + m.SetEnd({-4}); + m.SetStrides({-1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3, 2, 1})); +} + +TEST(StridedSliceOpTest, In1D_EvenLenStride2) { + StridedSliceOpModel m({2}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2}); + m.SetBegin({0}); + m.SetEnd({2}); + m.SetStrides({2}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1})); +} + +TEST(StridedSliceOpTest, In1D_OddLenStride2) { + StridedSliceOpModel m({3}, {1}, {1}, {1}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3}); + m.SetBegin({0}); + m.SetEnd({3}); + m.SetStrides({2}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3})); +} + +TEST(StridedSliceOpTest, In2D_Identity) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetBegin({0, 0}); + m.SetEnd({2, 3}); + m.SetStrides({1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); +} + +TEST(StridedSliceOpTest, In2D) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetBegin({1, 0}); + m.SetEnd({2, 2}); + m.SetStrides({1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({4, 5})); +} + +TEST(StridedSliceOpTest, In2D_Stride2) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetBegin({0, 0}); + m.SetEnd({2, 3}); + 
m.SetStrides({2, 2}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3})); +} + +TEST(StridedSliceOpTest, In2D_NegStride) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetBegin({1, -1}); + m.SetEnd({2, -4}); + m.SetStrides({2, -1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({6, 5, 4})); +} + +TEST(StridedSliceOpTest, In2D_BeginMask) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 1, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetBegin({1, 0}); + m.SetEnd({2, 2}); + m.SetStrides({1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 4, 5})); +} + +TEST(StridedSliceOpTest, In2D_EndMask) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 2, 0, 0, 0); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetBegin({1, 0}); + m.SetEnd({2, 2}); + m.SetStrides({1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({4, 5, 6})); +} + +TEST(StridedSliceOpTest, In2D_NegStrideBeginMask) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 2, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetBegin({1, -2}); + m.SetEnd({2, -4}); + m.SetStrides({1, -1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({6, 5, 4})); +} + +TEST(StridedSliceOpTest, In2D_NegStrideEndMask) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 2, 0, 0, 0); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetBegin({1, -2}); + m.SetEnd({2, -3}); + m.SetStrides({1, -1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({5, 4})); +} + +TEST(StridedSliceOpTest, In3D_Identity) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, 
{3}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.SetBegin({0, 0, 0}); + m.SetEnd({2, 3, 2}); + m.SetStrides({1, 1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 2})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12})); +} + +TEST(StridedSliceOpTest, In3D_NegStride) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.SetBegin({-1, -1, -1}); + m.SetEnd({-3, -4, -3}); + m.SetStrides({-1, -1, -1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3, 2})); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray({12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1})); +} + +TEST(StridedSliceOpTest, In3D_Strided2) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 0); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.SetBegin({0, 0, 0}); + m.SetEnd({2, 3, 2}); + m.SetStrides({2, 2, 2}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 1})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 5})); +} + +TEST(StridedSliceOpTest, In1D_ShrinkAxisMask1) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({1}); + m.SetEnd({3}); + m.SetStrides({1}); + m.Invoke(); + EXPECT_TRUE(m.GetOutputShape().empty()); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({2})); +} + +TEST(StridedSliceOpTest, In1D_EmptyOutputShrinkAxisMask1) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({2}); + m.SetEnd({1}); + m.SetStrides({1}); + m.Invoke(); + EXPECT_TRUE(m.GetOutputShape().empty()); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3})); +} + +TEST(StridedSliceOpTest, In1D_BeginMaskShrinkAxisMask1) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 1, 0, 0, 0, 1); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({1}); + m.SetEnd({3}); + m.SetStrides({1}); + m.Invoke(); + 
EXPECT_TRUE(m.GetOutputShape().empty()); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1})); +} + +TEST(StridedSliceOpTest, In1D_NegativeBeginNegativeStrideShrinkAxisMask1) { + StridedSliceOpModel m({4}, {1}, {1}, {1}, 0, 0, 0, 0, 1); + m.SetInput({1, 2, 3, 4}); + m.SetBegin({-2}); + m.SetEnd({-3}); + m.SetStrides({-1}); + m.Invoke(); + EXPECT_TRUE(m.GetOutputShape().empty()); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({3})); +} + +TEST(StridedSliceOpTest, In2D_ShrinkAxisMask1) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 1); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetBegin({0, 0}); + m.SetEnd({2, 3}); + m.SetStrides({1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3})); +} + +TEST(StridedSliceOpTest, In2D_ShrinkAxisMask2) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 2); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetBegin({0, 0}); + m.SetEnd({2, 3}); + m.SetStrides({1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 4})); +} + +TEST(StridedSliceOpTest, In2D_ShrinkAxisMask3) { + StridedSliceOpModel m({2, 3}, {2}, {2}, {2}, 0, 0, 0, 0, 3); + m.SetInput({1, 2, 3, 4, 5, 6}); + m.SetBegin({0, 0}); + m.SetEnd({2, 3}); + m.SetStrides({1, 1}); + m.Invoke(); + EXPECT_TRUE(m.GetOutputShape().empty()); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1})); +} + +TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis1) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 1); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.SetBegin({0, 0, 0}); + m.SetEnd({2, 3, 2}); + m.SetStrides({1, 1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6})); +} + +TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis2) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 2); + m.SetInput({1, 2, 
3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.SetBegin({0, 0, 0}); + m.SetEnd({2, 3, 2}); + m.SetStrides({1, 1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 7, 8})); +} + +TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis3) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 3); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.SetBegin({0, 0, 0}); + m.SetEnd({2, 3, 2}); + m.SetStrides({1, 1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2})); +} + +TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis4) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 4); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.SetBegin({0, 0, 0}); + m.SetEnd({2, 3, 2}); + m.SetStrides({1, 1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 5, 7, 9, 11})); +} + +TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis5) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 5); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.SetBegin({0, 0, 0}); + m.SetEnd({2, 3, 2}); + m.SetStrides({1, 1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 3, 5})); +} + +TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis6) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 6); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.SetBegin({0, 0, 0}); + m.SetEnd({2, 3, 2}); + m.SetStrides({1, 1, 1}); + m.Invoke(); + EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 7})); +} + +TEST(StridedSliceOpTest, In3D_IdentityShrinkAxis7) { + StridedSliceOpModel m({2, 3, 2}, {3}, {3}, {3}, 0, 0, 0, 0, 7); + m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + m.SetBegin({0, 0, 
0}); + m.SetEnd({2, 3, 2}); + m.SetStrides({1, 1, 1}); + m.Invoke(); + EXPECT_TRUE(m.GetOutputShape().empty()); + EXPECT_THAT(m.GetOutput(), ElementsAreArray({1})); +} +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/sub.cc b/tensorflow/contrib/lite/kernels/sub.cc new file mode 100644 index 0000000000000000000000000000000000000000..ddaf498d5bac0109429224e7cf66cb3debcabc22 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/sub.cc @@ -0,0 +1,129 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/quantization_util.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace sub { + +// This file has three implementations of Sub. 
+enum KernelType { + kReference, + kGenericOptimized, // Neon-free + kNeonOptimized, +}; + +constexpr int kInputTensor1 = 0; +constexpr int kInputTensor2 = 1; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + TF_LITE_ENSURE_EQ(context, NumDimensions(input1), NumDimensions(input2)); + for (int i = 0; i < NumDimensions(input1); ++i) { + TF_LITE_ENSURE_EQ(context, SizeOfDimension(input1, i), + SizeOfDimension(input2, i)); + } + + TF_LITE_ENSURE_EQ(context, input1->type, output->type); + TF_LITE_ENSURE_EQ(context, input2->type, output->type); + + TfLiteIntArray* output_size = TfLiteIntArrayCopy(input1->dims); + return context->ResizeTensor(context, output, output_size); +} + +template +void EvalSubFloat(TfLiteContext* context, TfLiteNode* node, + TfLiteSubParams* params, TfLiteTensor* input1, + TfLiteTensor* input2, TfLiteTensor* output) { + float output_activation_min, output_activation_max; + CalculateActivationRangeFloat(params->activation, &output_activation_min, + &output_activation_max); +#define TF_LITE_Sub(type) \ + type::Sub(GetTensorData(input1), GetTensorDims(input1), \ + GetTensorData(input2), GetTensorDims(input2), \ + output_activation_min, output_activation_max, \ + GetTensorData(output), GetTensorDims(output)) + if (kernel_type == kReference) { + TF_LITE_Sub(reference_ops); + } else { + TF_LITE_Sub(optimized_ops); + } +#undef TF_LITE_Sub +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + TfLiteTensor* input1 = GetInput(context, node, kInputTensor1); + TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); + 
TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + + if (output->type == kTfLiteFloat32) { + EvalSubFloat(context, node, params, input1, input2, output); + } else { + context->ReportError(context, "Inputs and outputs not all float types."); + return kTfLiteError; + } + + return kTfLiteOk; +} + +} // namespace sub + +TfLiteRegistration* Register_SUB_REF() { + static TfLiteRegistration r = {nullptr, nullptr, sub::Prepare, + sub::Eval}; + return &r; +} + +TfLiteRegistration* Register_SUB_GENERIC_OPT() { + static TfLiteRegistration r = {nullptr, nullptr, sub::Prepare, + sub::Eval}; + return &r; +} + +TfLiteRegistration* Register_SUB_NEON_OPT() { + static TfLiteRegistration r = {nullptr, nullptr, sub::Prepare, + sub::Eval}; + return &r; +} + +TfLiteRegistration* Register_SUB() { +#ifdef USE_NEON + return Register_SUB_NEON_OPT(); +#else + return Register_SUB_GENERIC_OPT(); +#endif +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/svdf.cc b/tensorflow/contrib/lite/kernels/svdf.cc index 72f705fe4242b01c1516c99d3500484e8729fd9a..c69755447d5093e25d408eb6dea80750937465e7 100644 --- a/tensorflow/contrib/lite/kernels/svdf.cc +++ b/tensorflow/contrib/lite/kernels/svdf.cc @@ -15,8 +15,8 @@ limitations under the License. #include #include #include -#include #include +#include #include #include diff --git a/tensorflow/contrib/lite/kernels/svdf_test.cc b/tensorflow/contrib/lite/kernels/svdf_test.cc index d956025e9dfc9b6c03e55657023fb042c8ac485d..0f166dc69b95f3459388135b3a6c4d9b73a31cb4 100644 --- a/tensorflow/contrib/lite/kernels/svdf_test.cc +++ b/tensorflow/contrib/lite/kernels/svdf_test.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ // Unit test for TFLite SVDF op. 
-#include #include +#include #include #include @@ -306,7 +306,7 @@ TEST(SVDFOpTest, BlackBoxTestRank2) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/kernels/test_util.cc b/tensorflow/contrib/lite/kernels/test_util.cc index f716ba8741fd469e7ee405ac300924b53c5c48e5..6f56aa6bf38781e860e33e8ac3b6a0bb8b50bb01 100644 --- a/tensorflow/contrib/lite/kernels/test_util.cc +++ b/tensorflow/contrib/lite/kernels/test_util.cc @@ -49,7 +49,7 @@ std::vector> ArrayFloatNear(const std::vector& values, return matchers; } -int SingleOpModel::AddTensor(TensorData t) { +int SingleOpModel::AddTensor(TensorData t, std::initializer_list data) { int id = tensors_.size(); // This is slightly different depending on whether we are adding a @@ -78,8 +78,23 @@ int SingleOpModel::AddTensor(TensorData t) { builder_.CreateVector({t.zero_point})); } - tensors_.push_back(CreateTensor(builder_, builder_.CreateVector({}), - t.type, /*buffer=*/0, + int buffer_id = 0; + if (data.size()) { + // Initialize buffers list with empty buffer to allow for non-const tensors. + if (buffers_.empty()) { + buffers_.push_back(CreateBuffer(builder_, builder_.CreateVector({}))); + } + + // Add data as a Buffer to buffers list. 
+ buffer_id = buffers_.size(); + auto data_buffer = + builder_.CreateVector(reinterpret_cast(data.begin()), + sizeof(int) * data.size()); + buffers_.push_back(CreateBuffer(builder_, data_buffer)); + } + + tensors_.push_back(CreateTensor(builder_, builder_.CreateVector(t.shape), + t.type, /*buffer=*/buffer_id, /*name=*/0, q_params)); tensor_data_[id] = t; @@ -88,7 +103,15 @@ int SingleOpModel::AddTensor(TensorData t) { } int SingleOpModel::AddInput(const TensorData& t) { - int id = AddTensor(t); + int id = AddTensor(t, {}); + inputs_.push_back(id); + return id; +} + +int SingleOpModel::AddConstInput(TensorType type, + std::initializer_list data, + std::initializer_list shape) { + int id = AddTensor(TensorData{type, shape}, data); inputs_.push_back(id); return id; } @@ -100,7 +123,7 @@ int SingleOpModel::AddNullInput() { } int SingleOpModel::AddOutput(const TensorData& t) { - int id = AddTensor(t); + int id = AddTensor(t, {}); outputs_.push_back(id); return id; } @@ -142,19 +165,21 @@ void SingleOpModel::BuildInterpreter( subgraphs.push_back(subgraph); auto subgraphs_flatbuffer = builder_.CreateVector(subgraphs); - std::vector> buffers_vec; - auto buffers = builder_.CreateVector(buffers_vec); + auto buffers = builder_.CreateVector(buffers_); auto description = builder_.CreateString("programmatic model"); builder_.Finish(CreateModel(builder_, TFLITE_SCHEMA_VERSION, opcodes, subgraphs_flatbuffer, description, buffers)); auto* model = GetModel(builder_.GetBufferPointer()); - ops::builtin::BuiltinOpResolver builtins; - for (const auto& reg : custom_registrations_) { - builtins.AddCustom(reg.first.data(), reg.second()); + if (!resolver_) { + auto resolver = new ops::builtin::BuiltinOpResolver(); + for (const auto& reg : custom_registrations_) { + resolver->AddCustom(reg.first.data(), reg.second()); + } + resolver_ = std::unique_ptr(resolver); } - InterpreterBuilder(model, builtins)(&interpreter_); + InterpreterBuilder(model, *resolver_)(&interpreter_); CHECK(interpreter_ 
!= nullptr); @@ -180,4 +205,17 @@ int32_t SingleOpModel::GetTensorSize(int index) const { return total_size; } +template <> +std::vector SingleOpModel::ExtractVector(int index) { + TfLiteTensor* tensor_ptr = interpreter_->tensor(index); + CHECK(tensor_ptr != nullptr); + const int num_strings = GetStringCount(tensor_ptr); + std::vector result; + result.reserve(num_strings); + for (int i = 0; i < num_strings; ++i) { + const auto str = GetString(tensor_ptr, i); + result.emplace_back(str.str, str.len); + } + return result; +} } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/test_util.h b/tensorflow/contrib/lite/kernels/test_util.h index e68e49466119c50ec123edb84f1b1b6390a15a60..7d476ba1eaffbb24fb77390c0e71c32d60b6411e 100644 --- a/tensorflow/contrib/lite/kernels/test_util.h +++ b/tensorflow/contrib/lite/kernels/test_util.h @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_ +#define TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_ #include @@ -24,16 +24,11 @@ limitations under the License. #include "tensorflow/contrib/lite/kernels/register.h" #include "tensorflow/contrib/lite/model.h" #include "tensorflow/contrib/lite/string_util.h" +#include "tensorflow/contrib/lite/testing/util.h" #include "tensorflow/core/platform/logging.h" namespace tflite { -inline void LogToStderr() { -#ifdef PLATFORM_GOOGLE - FLAGS_logtostderr = true; -#endif -} - // A gmock matcher that check that elements of a float vector match to a given // tolerance. 
std::vector<::testing::Matcher> ArrayFloatNear( @@ -90,6 +85,23 @@ struct TensorData { int32_t zero_point; }; +class SingleOpResolver : public OpResolver { + public: + SingleOpResolver(const BuiltinOperator op, TfLiteRegistration* registration) + : op_(op), registration_(registration) {} + TfLiteRegistration* FindOp(BuiltinOperator op) const override { + if (op == op_) { + return registration_; + } + return nullptr; + } + TfLiteRegistration* FindOp(const char* op) const override { return nullptr; } + + private: + const BuiltinOperator op_; + TfLiteRegistration* registration_; +}; + class SingleOpModel { public: SingleOpModel() {} @@ -103,6 +115,10 @@ class SingleOpModel { int AddInput(TensorType type) { return AddInput(TensorData{type}); } int AddInput(const TensorData& t); + // Add a Tensor containing const data and return the tensor id. + int AddConstInput(TensorType type, std::initializer_list data, + std::initializer_list shape); + // Add a null input tensor (optional input) and return kOptionalTensor. int AddNullInput(); @@ -179,14 +195,19 @@ class SingleOpModel { return result; } + void SetResolver(std::unique_ptr resolver) { + resolver_ = std::move(resolver); + } + protected: int32_t GetTensorSize(int index) const; flatbuffers::FlatBufferBuilder builder_; std::unique_ptr interpreter_; + std::unique_ptr resolver_; private: - int AddTensor(TensorData t); + int AddTensor(TensorData t, std::initializer_list data); std::map tensor_data_; std::vector inputs_; @@ -194,9 +215,43 @@ class SingleOpModel { std::vector> tensors_; std::vector> opcodes_; std::vector> operators_; + std::vector> buffers_; std::map> custom_registrations_; }; +// Base class for single op unit tests. +// The tests are parameterized to test multiple kernels for a single op. +// The parameters are strings like "optimized" and "reference" to have better +// readability in test reports. +// +// To use this class: +// * Define a constant map from strings to TfLiteRegistration. 
+// * Implement a test class that inherits SingleOpTest. +// * Instantiate the test cases with SingleOpTest::GetKernelTags helper +// function. +// * Call GetRegistration to get the TfLiteRegistration to be used before +// building the interpreter. +class SingleOpTest : public ::testing::TestWithParam { + public: + static std::vector GetKernelTags( + const std::map& kernel_map) { + std::vector tags; + for (auto it : kernel_map) { + tags.push_back(it.first); + } + return tags; + } + + protected: + virtual const std::map& GetKernelMap() = 0; + TfLiteRegistration* GetRegistration() { + return GetKernelMap().at(GetParam()); + } +}; + +// Strings have a special implementation that is in test_util.cc +template <> +std::vector SingleOpModel::ExtractVector(int index); } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_ +#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_ diff --git a/tensorflow/contrib/lite/kernels/transpose.cc b/tensorflow/contrib/lite/kernels/transpose.cc new file mode 100644 index 0000000000000000000000000000000000000000..d3c10a9bb7b07404ccd8cfe2636473a622b91787 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/transpose.cc @@ -0,0 +1,159 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace transpose { + +// This file has one implementation of Transpose. +enum KernelType { + kReference, +}; + +struct TransposeContext { + TransposeContext(TfLiteContext* context, TfLiteNode* node) { + input = GetInput(context, node, 0); + perm = GetInput(context, node, 1); + output = GetOutput(context, node, 0); + } + TfLiteTensor* input; + TfLiteTensor* perm; + TfLiteTensor* output; +}; + +TfLiteStatus ResizeOutputTensor(TfLiteContext* context, + TransposeContext* op_context) { + int dims = NumDimensions(op_context->input); + const int* perm_data = GetTensorData(op_context->perm); + + // Ensure validity of the permutations tensor as a 1D tensor. + TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->perm), 1); + TF_LITE_ENSURE_EQ(context, op_context->perm->dims->data[0], dims); + for (int idx = 0; idx < dims; ++idx) { + TF_LITE_ENSURE_MSG(context, (perm_data[idx] >= 0 && perm_data[idx] < dims), + "Transpose op permutations array is out of bounds."); + } + + // Determine size of output tensor. 
+ TfLiteIntArray* input_size = op_context->input->dims; + TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size); + for (int idx = 0; idx < dims; ++idx) { + output_size->data[idx] = input_size->data[perm_data[idx]]; + } + + return context->ResizeTensor(context, op_context->output, output_size); +} + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + TransposeContext op_context(context, node); + + // Ensure validity of input tensor. + TF_LITE_ENSURE_MSG(context, NumDimensions(op_context.input) <= 4, + "Transpose op only supports 1D-4D input arrays."); + TF_LITE_ENSURE_EQ(context, op_context.input->type, op_context.output->type); + + if (!IsConstantTensor(op_context.perm)) { + SetTensorToDynamic(op_context.output); + return kTfLiteOk; + } + return ResizeOutputTensor(context, &op_context); +} + +template +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + TransposeContext op_context(context, node); + + // Resize the output tensor if the output tensor is dynamic. + if (IsDynamicTensor(op_context.output)) { + TF_LITE_ENSURE_OK(context, ResizeOutputTensor(context, &op_context)); + } + + // Reverse the permuted axes and convert to 4D due to the way Dims are + // constructed in GetTensorDims. 
+ const int* perm_data = GetTensorData(op_context.perm); + const int size = op_context.perm->dims->data[0]; + const int kOutputDimensionNum = 4; + int reversed_perm[kOutputDimensionNum]; + + for (int output_k = 0, input_k = size - 1; output_k < size; + ++output_k, --input_k) { + reversed_perm[output_k] = size - perm_data[input_k] - 1; + } + for (int k = size; k < kOutputDimensionNum; ++k) { + reversed_perm[k] = k; + } + +#define TF_LITE_TRANSPOSE(type, scalar) \ + type::Transpose(GetTensorData(op_context.input), \ + GetTensorDims(op_context.input), \ + GetTensorData(op_context.output), \ + GetTensorDims(op_context.output), reversed_perm) + + switch (op_context.input->type) { + case kTfLiteFloat32: + if (kernel_type == kReference) { + TF_LITE_TRANSPOSE(reference_ops, float); + } + break; + case kTfLiteUInt8: + if (kernel_type == kReference) { + TF_LITE_TRANSPOSE(reference_ops, uint8_t); + } + break; + case kTfLiteInt32: + if (kernel_type == kReference) { + TF_LITE_TRANSPOSE(reference_ops, int32_t); + } + break; + case kTfLiteInt64: + if (kernel_type == kReference) { + TF_LITE_TRANSPOSE(reference_ops, int64_t); + } + break; + default: + context->ReportError(context, + "Type is currently not supported by Transpose."); + return kTfLiteError; + } +#undef TF_LITE_TRANSPOSE + + return kTfLiteOk; +} + +} // namespace transpose + +TfLiteRegistration* Register_TRANSPOSE_REF() { + static TfLiteRegistration r = {nullptr, nullptr, transpose::Prepare, + transpose::Eval}; + return &r; +} + +TfLiteRegistration* Register_TRANSPOSE() { return Register_TRANSPOSE_REF(); } + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/transpose_test.cc b/tensorflow/contrib/lite/kernels/transpose_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..337bc144b967392523bf784603cca4c1b968cdf2 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/transpose_test.cc @@ -0,0 +1,347 @@ +/* Copyright 2017 The TensorFlow 
Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +void RunTestPermutation(const std::vector& shape, + const std::vector& perms, + std::vector* input_transposed) { + // Count elements and allocate output. + int count = 1; + for (auto factor : shape) count *= factor; + input_transposed->resize(count); + + // Create the dummy data + std::vector input(count); + for (int i = 0; i < input.size(); i++) { + input[i] = i; + } + + // Create reversed and padded perms. + int reversed_perms[4]; + for (int output_k = 0, input_k = shape.size() - 1; output_k < shape.size(); + output_k++, input_k--) { + reversed_perms[output_k] = shape.size() - perms[input_k] - 1; + } + // Unused dimensions should not be permuted so pad with identity transform + // subset. + for (int k = shape.size(); k < 4; k++) { + reversed_perms[k] = k; + } + + // Make input and output dims (i.e. reversed shape and dest_shape). 
+ Dims<4> input_dims = GetTensorDims(shape); + Dims<4> output_dims; + for (int i = 0; i < 4; i++) { + output_dims.sizes[i] = input_dims.sizes[reversed_perms[i]]; + } + output_dims.strides[0] = 1; + for (int k = 1; k < 4; k++) { + output_dims.strides[k] = + output_dims.strides[k - 1] * output_dims.sizes[k - 1]; + } + + reference_ops::Transpose(input.data(), input_dims, + input_transposed->data(), output_dims, + reversed_perms); +} + +TEST(TransposeTest, TestRefOps1D) { + // Basic 1D identity. + std::vector out; + RunTestPermutation({3}, {0}, &out); + ASSERT_EQ(out, std::vector({0, 1, 2})); +} + +TEST(TransposeTest, TestRefOps2D) { + std::vector out; + // Basic 2D. + RunTestPermutation({3, 2}, {1, 0}, &out); + ASSERT_EQ(out, std::vector({0, 2, 4, 1, 3, 5})); + // Identity. + RunTestPermutation({3, 2}, {0, 1}, &out); + ASSERT_EQ(out, std::vector({0, 1, 2, 3, 4, 5})); +} + +TEST(TransposeTest, TestRefOps3D) { + std::vector out; + // Test 3 dimensional + { + std::vector ref({0, 4, 8, 12, 16, 20, 1, 5, 9, 13, 17, 21, + 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23}); + RunTestPermutation({2, 3, 4}, {2, 0, 1}, &out); + ASSERT_EQ(out, ref); + } + // Test 3 dimensional identity transform + { + RunTestPermutation({2, 3, 4}, {0, 1, 2}, &out); + std::vector ref(out.size()); + for (int k = 0; k < ref.size(); k++) ref[k] = k; + ASSERT_EQ(out, ref); + } +} + +TEST(TransposeTest, TestRefOps4D) { + std::vector out; + // Basic 4d. 
+ RunTestPermutation({2, 3, 4, 5}, {2, 0, 1, 3}, &out); + ASSERT_EQ( + out, + std::vector( + {0, 1, 2, 3, 4, 20, 21, 22, 23, 24, 40, 41, 42, 43, 44, + 60, 61, 62, 63, 64, 80, 81, 82, 83, 84, 100, 101, 102, 103, 104, + 5, 6, 7, 8, 9, 25, 26, 27, 28, 29, 45, 46, 47, 48, 49, + 65, 66, 67, 68, 69, 85, 86, 87, 88, 89, 105, 106, 107, 108, 109, + 10, 11, 12, 13, 14, 30, 31, 32, 33, 34, 50, 51, 52, 53, 54, + 70, 71, 72, 73, 74, 90, 91, 92, 93, 94, 110, 111, 112, 113, 114, + 15, 16, 17, 18, 19, 35, 36, 37, 38, 39, 55, 56, 57, 58, 59, + 75, 76, 77, 78, 79, 95, 96, 97, 98, 99, 115, 116, 117, 118, 119})); + RunTestPermutation({2, 3, 4, 5}, {0, 1, 2, 3}, &out); + // Basic identity. + std::vector ref(out.size()); + for (int k = 0; k < ref.size(); k++) ref[k] = k; + ASSERT_EQ(out, ref); +} + +class TransposeOpModel : public SingleOpModel { + public: + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetPerm(std::initializer_list data) { + PopulateTensor(perm_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + protected: + int input_; + int perm_; + int output_; +}; + +// Tests case where perm is a const tensor. +// +// Example usage is as follows: +// TransposeOpConstModel m(input_shape, perm_shape, perm_data); +// m.SetInput(input_data); +// m.Invoke(); +class TransposeOpConstModel : public TransposeOpModel { + public: + TransposeOpConstModel(std::initializer_list input_shape, + std::initializer_list perm_shape, + std::initializer_list perm) { + input_ = AddInput(TensorType_FLOAT32); + perm_ = AddConstInput(TensorType_INT32, perm, perm_shape); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_TRANSPOSE, BuiltinOptions_TransposeOptions, + CreateTransposeOptions(builder_).Union()); + BuildInterpreter({input_shape}); + } +}; + +// Tests case where perm is a non-const tensor. 
+//
+// Example usage is as follows:
+//    TransposeOpDynamicModel m(input_shape, perm_shape);
+//    m.SetInput(input_data);
+//    m.SetPerm(perm_data);
+//    m.Invoke();
+class TransposeOpDynamicModel : public TransposeOpModel {
+ public:
+  // Both the input and the permutation are regular inputs here, so both
+  // shapes are handed to BuildInterpreter and the permutation values are
+  // supplied later through SetPerm().
+  TransposeOpDynamicModel(std::initializer_list input_shape,
+                          std::initializer_list perm_shape) {
+    input_ = AddInput(TensorType_FLOAT32);
+    perm_ = AddInput(TensorType_INT32);
+    output_ = AddOutput(TensorType_FLOAT32);
+    SetBuiltinOp(BuiltinOperator_TRANSPOSE, BuiltinOptions_TransposeOptions,
+                 CreateTransposeOptions(builder_).Union());
+    BuildInterpreter({input_shape, perm_shape});
+  }
+};
+
+// A permutation with 2 entries for a 4-D input must abort while the model is
+// being built.
+TEST(TransposeTest, TestUnequalPermSize) {
+  EXPECT_DEATH(TransposeOpConstModel({1, 3, 3, 1}, {2}, {2, 2}), "2 != 4");
+}
+
+// Permutation entries that are negative or not smaller than the input rank
+// must abort.
+TEST(TransposeTest, TestPermOutOfBounds) {
+  EXPECT_DEATH(TransposeOpConstModel({1, 3, 3, 1}, {4}, {0, -1, -2, -3}),
+               "Transpose op permutations array is out of bounds.");
+  EXPECT_DEATH(TransposeOpConstModel({1, 3, 3, 1}, {4}, {0, 1, 2, 4}),
+               "Transpose op permutations array is out of bounds.");
+}
+
+// Each case below is exercised twice: once with the permutation baked in as a
+// const tensor and once with the permutation supplied at run time.
+TEST(TransposeTest, Test1DInputConstTensor) {
+  TransposeOpConstModel m({3}, {1}, {0});
+  m.SetInput({1, 2, 3});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3}));
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3}));
+}
+
+TEST(TransposeTest, Test1DInputDynamicTensor) {
+  TransposeOpDynamicModel m({3}, {1});
+  m.SetInput({1, 2, 3});
+  m.SetPerm({0});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3}));
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3}));
+}
+
+TEST(TransposeTest, Test2DInputConstTensor) {
+  TransposeOpConstModel m({3, 2}, {2}, {1, 0});
+  m.SetInput({0, 1, 2, 3, 4, 5});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3}));
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 2, 4, 1, 3, 5}));
+}
+
+TEST(TransposeTest, Test2DInputDynamicTensor) {
+  TransposeOpDynamicModel m({3, 2}, {2});
+  m.SetInput({0, 1, 2, 3, 4, 5});
+  m.SetPerm({1, 0});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({2, 3}));
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({0, 2, 4, 1, 3, 5}));
+}
+
+TEST(TransposeTest, Test3DInputConstTensor) {
+  TransposeOpConstModel m({2, 3, 4}, {3}, {2, 0, 1});
+  m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+              12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 2, 3}));
+  EXPECT_THAT(m.GetOutput(),
+              ElementsAreArray({0, 4, 8, 12, 16, 20, 1, 5, 9, 13, 17, 21,
+                                2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23}));
+}
+
+TEST(TransposeTest, Test3DInputDynamicTensor) {
+  TransposeOpDynamicModel m({2, 3, 4}, {3});
+  m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+              12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23});
+  m.SetPerm({2, 0, 1});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 2, 3}));
+  EXPECT_THAT(m.GetOutput(),
+              ElementsAreArray({0, 4, 8, 12, 16, 20, 1, 5, 9, 13, 17, 21,
+                                2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23}));
+}
+
+// A 5-D input must abort while the model is being built.
+TEST(TransposeTest, Test5DInputTensor) {
+  EXPECT_DEATH(TransposeOpConstModel({1, 2, 3, 4, 5}, {5}, {0, 1, 2, 3, 4}),
+               "Transpose op only supports 1D-4D input arrays.");
+}
+
+TEST(TransposeTest, SimpleTestNoReorderConstTensor) {
+  TransposeOpConstModel m({1, 2, 3, 1}, {4}, {0, 1, 2, 3});
+  m.SetInput({1, 2, 3, 4, 5, 6});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 3, 1}));
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
+}
+
+TEST(TransposeTest, SimpleTestNoReorderDynamicTensor) {
+  TransposeOpDynamicModel m({1, 2, 3, 1}, {4});
+  m.SetInput({1, 2, 3, 4, 5, 6});
+  m.SetPerm({0, 1, 2, 3});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 2, 3, 1}));
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 2, 3, 4, 5, 6}));
+}
+
+TEST(TransposeTest, SimpleTestWithReorderConstTensor) {
+  TransposeOpConstModel m({1, 2, 3, 1}, {4}, {2, 1, 3, 0});
+  m.SetInput({1, 2, 3, 4, 5, 6});
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 2, 1, 1}));
+  EXPECT_THAT(m.GetOutput(), ElementsAreArray({1, 4, 2, 5, 3, 6}));
+}
+
+// 4-D case: the input holds the values 0..119 in order, and the expected
+// output is the same table of values used by the RunTestPermutation check for
+// perm {2, 0, 1, 3}.
+TEST(TransposeTest, ComplexTestWithReorderConstTensor) {
+  TransposeOpConstModel m({2, 3, 4, 5}, {4}, {2, 0, 1, 3});
+  m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+              12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+              24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
+              36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
+              48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
+              60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
+              72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
+              84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95,
+              96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107,
+              108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119});
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 2, 3, 5}));
+  auto result = ElementsAreArray(
+      {0, 1, 2, 3, 4, 20, 21, 22, 23, 24, 40, 41, 42, 43, 44,
+       60, 61, 62, 63, 64, 80, 81, 82, 83, 84, 100, 101, 102, 103, 104,
+       5, 6, 7, 8, 9, 25, 26, 27, 28, 29, 45, 46, 47, 48, 49,
+       65, 66, 67, 68, 69, 85, 86, 87, 88, 89, 105, 106, 107, 108, 109,
+       10, 11, 12, 13, 14, 30, 31, 32, 33, 34, 50, 51, 52, 53, 54,
+       70, 71, 72, 73, 74, 90, 91, 92, 93, 94, 110, 111, 112, 113, 114,
+       15, 16, 17, 18, 19, 35, 36, 37, 38, 39, 55, 56, 57, 58, 59,
+       75, 76, 77, 78, 79, 95, 96, 97, 98, 99, 115, 116, 117, 118, 119});
+  EXPECT_THAT(m.GetOutput(), result);
+}
+
+TEST(TransposeTest, ComplexTestWithReorderDynamicTensor) {
+  TransposeOpDynamicModel m({2, 3, 4, 5}, {4});
+  m.SetInput({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+              12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
+              24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
+              36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47,
+              48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59,
+              60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
+              72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
+              84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95,
+              96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107,
+              108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119});
+  m.SetPerm({2, 0, 1, 3});
+  m.Invoke();
+
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({4, 2, 3, 5}));
+  auto result = ElementsAreArray(
+      {0, 1, 2, 3, 4, 20, 21, 22, 23, 24, 40, 41, 42, 43, 44,
+       60, 61, 62, 63, 64, 80, 81, 82, 83, 84, 100, 101, 102, 103, 104,
+       5, 6, 7, 8, 9, 25, 26, 27, 28, 29, 45, 46, 47, 48, 49,
+       65, 66, 67, 68, 69, 85, 86, 87, 88, 89, 105, 106, 107, 108, 109,
+       10, 11, 12, 13, 14, 30, 31, 32, 33, 34, 50, 51, 52, 53, 54,
+       70, 71, 72, 73, 74, 90, 91, 92, 93, 94, 110, 111, 112, 113, 114,
+       15, 16, 17, 18, 19, 35, 36, 37, 38, 39, 55, 56, 57, 58, 59,
+       75, 76, 77, 78, 79, 95, 96, 97, 98, 99, 115, 116, 117, 118, 119});
+  EXPECT_THAT(m.GetOutput(), result);
+}
+
+}  // namespace
+}  // namespace tflite
+
+// Test entry point: routes TFLite logging to stderr before running all tests.
+int main(int argc, char** argv) {
+  ::tflite::LogToStderr();
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
new file mode 100644
index 0000000000000000000000000000000000000000..9cdb58714edb5fee771fc45f3c53a570f8fb28d1
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm.cc
@@ -0,0 +1,527 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/activation_functor.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor_utils.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace unidirectional_sequence_lstm {
+
+// Positions of the op's 18 input tensors (indices 0..17). Entries marked
+// "Optional" are looked up with GetOptionalInputTensor and may be absent.
+
+// Input tensor of size {max_time, n_batch, n_input}
+constexpr int kInputTensor = 0;
+
+// Input weight tensors of size {n_cell, n_input}
+constexpr int kInputToInputWeightsTensor = 1;  // Optional
+constexpr int kInputToForgetWeightsTensor = 2;
+constexpr int kInputToCellWeightsTensor = 3;
+constexpr int kInputToOutputWeightsTensor = 4;
+
+// Recurrent weight tensors of size {n_cell, n_output}
+constexpr int kRecurrentToInputWeightsTensor = 5;  // Optional
+constexpr int kRecurrentToForgetWeightsTensor = 6;
+constexpr int kRecurrentToCellWeightsTensor = 7;
+constexpr int kRecurrentToOutputWeightsTensor = 8;
+
+// Peephole weight tensors of size {n_cell}, representing a diagonal matrix.
+constexpr int kCellToInputWeightsTensor = 9;    // Optional
+constexpr int kCellToForgetWeightsTensor = 10;  // Optional
+constexpr int kCellToOutputWeightsTensor = 11;  // Optional
+
+// Gate bias tensors of size {n_cell}
+constexpr int kInputGateBiasTensor = 12;  // Optional
+constexpr int kForgetGateBiasTensor = 13;
+constexpr int kCellGateBiasTensor = 14;
+constexpr int kOutputGateBiasTensor = 15;
+
+// Projection weight tensor of size {n_output, n_cell}
+constexpr int kProjectionWeightsTensor = 16;  // Optional
+// Projection bias tensor of size {n_output}
+constexpr int kProjectionBiasTensor = 17;  // Optional
+
+// Output tensors.
+constexpr int kScratchBufferTensor = 0;
+constexpr int kOutputStateTensor = 1;
+constexpr int kCellStateTensor = 2;
+constexpr int kOutputTensor = 3;
+
+// Checks that the dimensions of the weight, peephole, bias and projection
+// tensors agree with each other and with the given sizes.
+//
+// Args:
+//   context:  used by the TF_LITE_ENSURE* macros to report failures.
+//   node:     the op node; its builtin_data supplies cell_clip / proj_clip.
+//   n_input:  expected second dimension of the input-to-gate weights.
+//   n_output: expected second dimension of the recurrent weights and the
+//             size of the projection output.
+//   n_cell:   expected first dimension of every gate weight tensor and the
+//             length of the peephole and gate bias vectors.
+//
+// Returns kTfLiteOk when every check passes; otherwise the first failing
+// TF_LITE_ENSURE* macro returns early with an error status.
+//
+// The input-to-output and recurrent-to-output weights are not looked at here;
+// Prepare() validates them while deriving n_cell and n_output.
+TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
+                                        TfLiteNode* node, int n_input,
+                                        int n_output, int n_cell) {
+  auto* params = reinterpret_cast(node->builtin_data);
+
+  // Making sure clipping parameters have valid values.
+  // == 0 means no clipping
+  //  > 0 means clipping
+  TF_LITE_ENSURE(context, params->cell_clip >= 0);
+  TF_LITE_ENSURE(context, params->proj_clip >= 0);
+
+  TfLiteTensor* input_to_input_weights =
+      GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
+  if (input_to_input_weights) {
+    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->size, 2);
+    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[0], n_cell);
+    TF_LITE_ENSURE_EQ(context, input_to_input_weights->dims->data[1], n_input);
+  }
+
+  TfLiteTensor* input_to_forget_weights =
+      GetInput(context, node, kInputToForgetWeightsTensor);
+  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->size, 2);
+  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[0], n_cell);
+  TF_LITE_ENSURE_EQ(context, input_to_forget_weights->dims->data[1], n_input);
+
+  TfLiteTensor* input_to_cell_weights =
+      GetInput(context, node, kInputToCellWeightsTensor);
+  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->size, 2);
+  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[0], n_cell);
+  TF_LITE_ENSURE_EQ(context, input_to_cell_weights->dims->data[1], n_input);
+
+  TfLiteTensor* recurrent_to_input_weights =
+      GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
+  if (recurrent_to_input_weights) {
+    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->size, 2);
+    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[0],
+                      n_cell);
+    TF_LITE_ENSURE_EQ(context, recurrent_to_input_weights->dims->data[1],
+                      n_output);
+  }
+
+  TfLiteTensor* recurrent_to_forget_weights =
+      GetInput(context, node, kRecurrentToForgetWeightsTensor);
+  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->size, 2);
+  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[0],
+                    n_cell);
+  TF_LITE_ENSURE_EQ(context, recurrent_to_forget_weights->dims->data[1],
+                    n_output);
+
+  TfLiteTensor* recurrent_to_cell_weights =
+      GetInput(context, node, kRecurrentToCellWeightsTensor);
+  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->size, 2);
+  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[0], n_cell);
+  TF_LITE_ENSURE_EQ(context, recurrent_to_cell_weights->dims->data[1],
+                    n_output);
+
+  // We make sure the input-gate's parameters are either both present (regular
+  // LSTM) or not at all (CIFG-LSTM).
+  const bool cifg_weights_all_or_none =
+      ((input_to_input_weights != nullptr) &&
+       (recurrent_to_input_weights != nullptr)) ||
+      ((input_to_input_weights == nullptr) &&
+       (recurrent_to_input_weights == nullptr));
+  TF_LITE_ENSURE(context, cifg_weights_all_or_none == true);
+
+  TfLiteTensor* cell_to_input_weights =
+      GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
+  if (cell_to_input_weights) {
+    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->size, 1);
+    TF_LITE_ENSURE_EQ(context, cell_to_input_weights->dims->data[0], n_cell);
+  }
+
+  TfLiteTensor* cell_to_forget_weights =
+      GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
+  if (cell_to_forget_weights) {
+    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->size, 1);
+    TF_LITE_ENSURE_EQ(context, cell_to_forget_weights->dims->data[0], n_cell);
+  }
+
+  TfLiteTensor* cell_to_output_weights =
+      GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
+  if (cell_to_output_weights) {
+    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->size, 1);
+    TF_LITE_ENSURE_EQ(context, cell_to_output_weights->dims->data[0], n_cell);
+  }
+
+  // Making sure the peephole weights are there all or none. With CIFG the
+  // cell-to-input weights are not required for the "all" case.
+  const bool use_cifg = (input_to_input_weights == nullptr);
+  const bool peephole_weights_all_or_none =
+      ((cell_to_input_weights != nullptr || use_cifg) &&
+       (cell_to_forget_weights != nullptr) &&
+       (cell_to_output_weights != nullptr)) ||
+      ((cell_to_input_weights == nullptr) &&
+       (cell_to_forget_weights == nullptr) &&
+       (cell_to_output_weights == nullptr));
+  TF_LITE_ENSURE(context, peephole_weights_all_or_none == true);
+
+  // Make sure the input gate bias is present only when not a CIFG-LSTM.
+  TfLiteTensor* input_gate_bias =
+      GetOptionalInputTensor(context, node, kInputGateBiasTensor);
+  if (use_cifg) {
+    TF_LITE_ENSURE_EQ(context, input_gate_bias, nullptr);
+  } else {
+    // The bias is fetched as an optional tensor, so it can be missing even
+    // though the input-gate weights are present. Report that as an error
+    // instead of dereferencing a null pointer below.
+    TF_LITE_ENSURE(context, input_gate_bias != nullptr);
+    TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->size, 1);
+    TF_LITE_ENSURE_EQ(context, input_gate_bias->dims->data[0], n_cell);
+  }
+
+  TfLiteTensor* forget_gate_bias =
+      GetInput(context, node, kForgetGateBiasTensor);
+  TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->size, 1);
+  TF_LITE_ENSURE_EQ(context, forget_gate_bias->dims->data[0], n_cell);
+
+  TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
+  TF_LITE_ENSURE_EQ(context, cell_bias->dims->size, 1);
+  TF_LITE_ENSURE_EQ(context, cell_bias->dims->data[0], n_cell);
+
+  TfLiteTensor* output_gate_bias =
+      GetInput(context, node, kOutputGateBiasTensor);
+  TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->size, 1);
+  TF_LITE_ENSURE_EQ(context, output_gate_bias->dims->data[0], n_cell);
+
+  TfLiteTensor* projection_weights =
+      GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
+  if (projection_weights) {
+    TF_LITE_ENSURE_EQ(context, projection_weights->dims->size, 2);
+    TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[0], n_output);
+    TF_LITE_ENSURE_EQ(context, projection_weights->dims->data[1], n_cell);
+  }
+
+  TfLiteTensor* projection_bias =
+      GetOptionalInputTensor(context, node, kProjectionBiasTensor);
+  if (projection_bias) {
+    TF_LITE_ENSURE_EQ(context, projection_bias->dims->size, 1);
+    TF_LITE_ENSURE_EQ(context, projection_bias->dims->data[0], n_output);
+  }
+
+  // Making sure the projection tensors are consistent:
+  // 1) If projection weight is not present, then projection bias should not be
+  // present.
+  // 2) If projection weight is present, then projection bias is optional.
+  // TODO(ghodrat): make sure this is correct.
+  const bool projecton_tensors_consistent =
+      ((projection_weights != nullptr) || (projection_bias == nullptr));
+  TF_LITE_ENSURE(context, projecton_tensors_consistent == true);
+
+  return kTfLiteOk;
+}
+
+// Resizes the output, state and scratch tensors based on the sizes of the
+// input tensors, after checking that the input tensors agree with each other.
+//
+// Sizes are derived as follows:
+//   max_time, n_batch, n_input  <- the three dimensions of the input tensor
+//   n_cell                      <- rows of the input-to-output weights
+//   n_output                    <- columns of the recurrent-to-output weights
+//
+// Returns kTfLiteOk on success, or the error status of the first failed check
+// or failed ResizeTensor call.
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  // Check we have all the inputs and outputs we need.
+  TF_LITE_ENSURE_EQ(context, node->inputs->size, 18);
+  TF_LITE_ENSURE_EQ(context, node->outputs->size, 4);
+
+  // Inferring batch size, number of outputs and sequence length and
+  // number of cells from the input tensors.
+  TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  // All three dimensions are read below, so a rank-2 input must be rejected
+  // here rather than read past the end of the dims array.
+  TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
+  const int max_time = input->dims->data[0];
+  const int n_batch = input->dims->data[1];
+  const int n_input = input->dims->data[2];
+
+  TfLiteTensor* input_to_output_weights =
+      GetInput(context, node, kInputToOutputWeightsTensor);
+  // Verify the rank before reading any dimension of the weights.
+  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->size, 2);
+  const int n_cell = input_to_output_weights->dims->data[0];
+  TF_LITE_ENSURE_EQ(context, input_to_output_weights->dims->data[1], n_input);
+
+  TfLiteTensor* recurrent_to_output_weights =
+      GetInput(context, node, kRecurrentToOutputWeightsTensor);
+  TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->size, 2);
+  TF_LITE_ENSURE_EQ(context, recurrent_to_output_weights->dims->data[0],
+                    n_cell);
+  const int n_output = recurrent_to_output_weights->dims->data[1];
+
+  // Check that input tensor dimensions match each other. The status has to be
+  // propagated: otherwise a failed check would be reported but Prepare would
+  // still go on and succeed.
+  TF_LITE_ENSURE_OK(context, CheckInputTensorDimensions(context, node, n_input,
+                                                        n_output, n_cell));
+
+  // Get the pointer to output, state and scratch buffer tensors.
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
+  TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
+  // TODO(ghodrat): Modify this as soon as we have a finalized method for
+  // scratch buffers.
+  TfLiteTensor* scratch_buffer = GetOutput(context, node, kScratchBufferTensor);
+
+  // Resize the output and output_state tensors.
+  TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
+  output_size->data[0] = max_time;
+  output_size->data[1] = n_batch;
+  output_size->data[2] = n_output;
+  TF_LITE_ENSURE_OK(context,
+                    context->ResizeTensor(context, output, output_size));
+
+  TfLiteIntArray* output_state_size = TfLiteIntArrayCreate(2);
+  output_state_size->data[0] = n_batch;
+  output_state_size->data[1] = n_output;
+  TF_LITE_ENSURE_OK(
+      context, context->ResizeTensor(context, output_state, output_state_size));
+
+  // Resize the cell state tensor.
+  TfLiteIntArray* cell_size = TfLiteIntArrayCreate(2);
+  cell_size->data[0] = n_batch;
+  cell_size->data[1] = n_cell;
+  TF_LITE_ENSURE_OK(context,
+                    context->ResizeTensor(context, cell_state, cell_size));
+
+  // Mark state tensors as persistent tensors.
+  output_state->allocation_type = kTfLiteArenaRwPersistent;
+  cell_state->allocation_type = kTfLiteArenaRwPersistent;
+
+  // Resize the scratch buffer tensor: one {n_batch, n_cell} slab per gate.
+  // A CIFG-LSTM needs Cell, Forget and Output; a regular LSTM additionally
+  // needs Input.
+  TfLiteTensor* input_to_input_weights =
+      GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
+  const bool use_cifg = (input_to_input_weights == nullptr);
+  const int n_gates = use_cifg ? 3 : 4;
+  TfLiteIntArray* scratch_buffer_size = TfLiteIntArrayCreate(2);
+  scratch_buffer_size->data[0] = n_batch;
+  scratch_buffer_size->data[1] = n_cell * n_gates;
+  TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, scratch_buffer,
+                                                   scratch_buffer_size));
+  return kTfLiteOk;
+}
+
+// The LSTM Op engine.
+// Runs the LSTM over all max_time steps of the input. For every step t the
+// gate pre-activations are rebuilt in the scratch buffer, the persistent cell
+// state is updated in place, the step output is written to the output tensor
+// at offset t * n_batch * n_output, and that same output is copied into the
+// output state so that it feeds the recurrent weights at step t + 1.
+// Always returns kTfLiteOk; shapes are expected to have been validated in
+// Prepare().
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  auto* params = reinterpret_cast(node->builtin_data);
+  TfLiteTensor* input = GetInput(context, node, kInputTensor);
+
+  TfLiteTensor* input_to_input_weights =
+      GetOptionalInputTensor(context, node, kInputToInputWeightsTensor);
+  TfLiteTensor* input_to_forget_weights =
+      GetInput(context, node, kInputToForgetWeightsTensor);
+  TfLiteTensor* input_to_cell_weights =
+      GetInput(context, node, kInputToCellWeightsTensor);
+  TfLiteTensor* input_to_output_weights =
+      GetInput(context, node, kInputToOutputWeightsTensor);
+
+  TfLiteTensor* recurrent_to_input_weights =
+      GetOptionalInputTensor(context, node, kRecurrentToInputWeightsTensor);
+  TfLiteTensor* recurrent_to_forget_weights =
+      GetInput(context, node, kRecurrentToForgetWeightsTensor);
+  TfLiteTensor* recurrent_to_cell_weights =
+      GetInput(context, node, kRecurrentToCellWeightsTensor);
+  TfLiteTensor* recurrent_to_output_weights =
+      GetInput(context, node, kRecurrentToOutputWeightsTensor);
+
+  TfLiteTensor* cell_to_input_weights =
+      GetOptionalInputTensor(context, node, kCellToInputWeightsTensor);
+  TfLiteTensor* cell_to_forget_weights =
+      GetOptionalInputTensor(context, node, kCellToForgetWeightsTensor);
+  TfLiteTensor* cell_to_output_weights =
+      GetOptionalInputTensor(context, node, kCellToOutputWeightsTensor);
+
+  TfLiteTensor* input_gate_bias =
+      GetOptionalInputTensor(context, node, kInputGateBiasTensor);
+  TfLiteTensor* forget_gate_bias =
+      GetInput(context, node, kForgetGateBiasTensor);
+  TfLiteTensor* cell_bias = GetInput(context, node, kCellGateBiasTensor);
+  TfLiteTensor* output_gate_bias =
+      GetInput(context, node, kOutputGateBiasTensor);
+
+  TfLiteTensor* projection_weights =
+      GetOptionalInputTensor(context, node, kProjectionWeightsTensor);
+  TfLiteTensor* projection_bias =
+      GetOptionalInputTensor(context, node, kProjectionBiasTensor);
+
+  TfLiteTensor* output_state = GetOutput(context, node, kOutputStateTensor);
+  TfLiteTensor* cell_state = GetOutput(context, node, kCellStateTensor);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  const int max_time = input->dims->data[0];
+  const int n_batch = input->dims->data[1];
+  const int n_input = input->dims->data[2];
+  // n_cell and n_output will be the same size when there is no projection.
+  const int n_cell = input_to_output_weights->dims->data[0];
+  const int n_output = recurrent_to_output_weights->dims->data[1];
+
+  // Since we have already checked that weights are all there or none, we can
+  // check the existence of only one to get the condition.
+  const bool use_cifg = (input_to_input_weights == nullptr);
+  const bool use_peephole = (cell_to_output_weights != nullptr);
+
+  // Index the scratch buffers pointers to the global scratch buffer. Each gate
+  // gets a contiguous slab of n_cell * n_batch floats; with CIFG there is no
+  // input-gate slab and input_gate_scratch stays null.
+  TfLiteTensor* scratch_buffer = GetOutput(context, node, kScratchBufferTensor);
+  float* input_gate_scratch = nullptr;
+  float* cell_scratch = nullptr;
+  float* forget_gate_scratch = nullptr;
+  float* output_gate_scratch = nullptr;
+  if (use_cifg) {
+    cell_scratch = scratch_buffer->data.f;
+    forget_gate_scratch = scratch_buffer->data.f + n_cell * n_batch;
+    output_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+  } else {
+    input_gate_scratch = scratch_buffer->data.f;
+    cell_scratch = scratch_buffer->data.f + n_cell * n_batch;
+    forget_gate_scratch = scratch_buffer->data.f + 2 * n_cell * n_batch;
+    output_gate_scratch = scratch_buffer->data.f + 3 * n_cell * n_batch;
+  }
+
+  for (int t = 0; t < max_time; t++) {
+    // Start of the {n_batch, n_input} slice of the input for time step t.
+    const float* input_ptr_time = input->data.f + t * n_batch * n_input;
+    // Initialize scratch buffers with bias.
+    if (!use_cifg) {
+      tensor_utils::VectorBatchVectorAssign(input_gate_bias->data.f, n_cell,
+                                            n_batch, input_gate_scratch);
+    }
+    tensor_utils::VectorBatchVectorAssign(forget_gate_bias->data.f, n_cell,
+                                          n_batch, forget_gate_scratch);
+    tensor_utils::VectorBatchVectorAssign(cell_bias->data.f, n_cell, n_batch,
+                                          cell_scratch);
+    tensor_utils::VectorBatchVectorAssign(output_gate_bias->data.f, n_cell,
+                                          n_batch, output_gate_scratch);
+
+    // For each batch and cell: compute input_weight * input.
+    if (!use_cifg) {
+      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+          input_to_input_weights->data.f, n_cell, n_input, input_ptr_time,
+          n_batch, input_gate_scratch, /*result_stride=*/1);
+    }
+    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+        input_to_forget_weights->data.f, n_cell, n_input, input_ptr_time,
+        n_batch, forget_gate_scratch, /*result_stride=*/1);
+    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+        input_to_cell_weights->data.f, n_cell, n_input, input_ptr_time, n_batch,
+        cell_scratch, /*result_stride=*/1);
+    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+        input_to_output_weights->data.f, n_cell, n_input, input_ptr_time,
+        n_batch, output_gate_scratch, /*result_stride=*/1);
+
+    // For each batch and cell: compute recurrent_weight * output_state.
+    // output_state still holds the output of the previous time step here.
+    if (!use_cifg) {
+      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+          recurrent_to_input_weights->data.f, n_cell, n_output,
+          output_state->data.f, n_batch, input_gate_scratch,
+          /*result_stride=*/1);
+    }
+    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+        recurrent_to_forget_weights->data.f, n_cell, n_output,
+        output_state->data.f, n_batch, forget_gate_scratch,
+        /*result_stride=*/1);
+    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+        recurrent_to_cell_weights->data.f, n_cell, n_output,
+        output_state->data.f, n_batch, cell_scratch, /*result_stride=*/1);
+    tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+        recurrent_to_output_weights->data.f, n_cell, n_output,
+        output_state->data.f, n_batch, output_gate_scratch,
+        /*result_stride=*/1);
+
+    // For each batch and cell: update input gate.
+    if (!use_cifg) {
+      if (use_peephole) {
+        tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+            cell_to_input_weights->data.f, n_cell, cell_state->data.f, n_batch,
+            input_gate_scratch);
+      }
+      tensor_utils::ApplySigmoidToVector(input_gate_scratch, n_cell * n_batch,
+                                         input_gate_scratch);
+    }
+
+    // For each batch and cell: update forget gate.
+    if (use_peephole) {
+      tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+          cell_to_forget_weights->data.f, n_cell, cell_state->data.f, n_batch,
+          forget_gate_scratch);
+    }
+    tensor_utils::ApplySigmoidToVector(forget_gate_scratch, n_cell * n_batch,
+                                       forget_gate_scratch);
+
+    // For each batch and cell: update the cell. The forget gate is applied to
+    // the old cell state first; the gated candidate is accumulated afterwards.
+    tensor_utils::VectorVectorCwiseProduct(forget_gate_scratch,
+                                           cell_state->data.f, n_batch * n_cell,
+                                           cell_state->data.f);
+    tensor_utils::ApplyActivationToVector(cell_scratch, n_batch * n_cell,
+                                          params->activation, cell_scratch);
+    if (use_cifg) {
+      // With CIFG there is no separate input gate: the forget gate buffer is
+      // transformed in place by Sub1Vector (presumably into 1 - forget gate;
+      // confirm against tensor_utils) and then used as the input gate. This
+      // overwrites the forget gate, which is no longer needed at this point.
+      tensor_utils::Sub1Vector(forget_gate_scratch, n_batch * n_cell,
+                               forget_gate_scratch);
+      tensor_utils::VectorVectorCwiseProductAccumulate(
+          cell_scratch, forget_gate_scratch, n_batch * n_cell,
+          cell_state->data.f);
+    } else {
+      tensor_utils::VectorVectorCwiseProductAccumulate(
+          cell_scratch, input_gate_scratch, n_batch * n_cell,
+          cell_state->data.f);
+    }
+    // A cell_clip of 0 means no clipping.
+    if (params->cell_clip > 0.0) {
+      tensor_utils::ClipVector(cell_state->data.f, n_batch * n_cell,
+                               params->cell_clip, cell_state->data.f);
+    }
+
+    // For each batch and cell: update the output gate. The peephole term uses
+    // the cell state that was just updated above. cell_scratch is reused to
+    // hold the activated cell state.
+    if (use_peephole) {
+      tensor_utils::VectorBatchVectorCwiseProductAccumulate(
+          cell_to_output_weights->data.f, n_cell, cell_state->data.f, n_batch,
+          output_gate_scratch);
+    }
+    tensor_utils::ApplySigmoidToVector(output_gate_scratch, n_batch * n_cell,
+                                       output_gate_scratch);
+    tensor_utils::ApplyActivationToVector(cell_state->data.f, n_batch * n_cell,
+                                          params->activation, cell_scratch);
+    tensor_utils::VectorVectorCwiseProduct(output_gate_scratch, cell_scratch,
+                                           n_batch * n_cell,
+                                           output_gate_scratch);
+
+    // For each batch: update the projection and output_state.
+    const bool use_projection_weight = (projection_weights != nullptr);
+    const bool use_projection_bias = (projection_bias != nullptr);
+    // Start of the {n_batch, n_output} slice of the output for time step t.
+    float* output_ptr_time = output->data.f + t * n_batch * n_output;
+    if (use_projection_weight) {
+      if (use_projection_bias) {
+        tensor_utils::VectorBatchVectorAssign(projection_bias->data.f, n_output,
+                                              n_batch, output_ptr_time);
+      } else {
+        tensor_utils::ZeroVector(output_ptr_time, n_batch * n_output);
+      }
+      tensor_utils::MatrixBatchVectorMultiplyAccumulate(
+          projection_weights->data.f, n_output, n_cell, output_gate_scratch,
+          n_batch, output_ptr_time, /*result_stride=*/1);
+      // A proj_clip of 0 means no clipping.
+      if (params->proj_clip > 0.0) {
+        tensor_utils::ClipVector(output_ptr_time, n_batch * n_output,
+                                 params->proj_clip, output_ptr_time);
+      }
+    } else {
+      // Without projection the gated cell output is the op output. The copy
+      // length is expressed in n_output, which relies on n_cell == n_output
+      // in this case (see the note where n_cell is computed).
+      tensor_utils::CopyVector(output_gate_scratch, n_batch * n_output,
+                               output_ptr_time);
+    }
+    // Carry this step's output over as the recurrent input of the next step.
+    tensor_utils::CopyVector(output_ptr_time, n_batch * n_output,
+                             output_state->data.f);
+  }
+  return kTfLiteOk;
+}
+
+}  // namespace unidirectional_sequence_lstm
+
+// Returns the registration for the op: no init/free callbacks, only the
+// Prepare and Eval functions of the unidirectional_sequence_lstm namespace.
+// The registration object is a function-local static, so every call returns
+// the same pointer.
+TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_LSTM() {
+  static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr,
+                                 unidirectional_sequence_lstm::Prepare,
+                                 unidirectional_sequence_lstm::Eval};
+  return &r;
+}
+
+}  // namespace builtin
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..93b635ae576e99854796d9fa997e5bf355b20534
--- /dev/null
+++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_lstm_test.cc
@@ -0,0 +1,1089 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite Sequential LSTM op. + +#include +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class UnidirectionalLSTMOpModel : public SingleOpModel { + public: + UnidirectionalLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, + int sequence_length, bool use_cifg, + bool use_peephole, bool use_projection_weights, + bool use_projection_bias, float cell_clip, + float proj_clip, + const std::vector>& input_shapes) + : n_batch_(n_batch), + n_input_(n_input), + n_cell_(n_cell), + n_output_(n_output), + sequence_length_(sequence_length) { + input_ = AddInput(TensorType_FLOAT32); + + if (use_cifg) { + input_to_input_weights_ = AddNullInput(); + } else { + input_to_input_weights_ = AddInput(TensorType_FLOAT32); + } + + input_to_forget_weights_ = AddInput(TensorType_FLOAT32); + input_to_cell_weights_ = AddInput(TensorType_FLOAT32); + input_to_output_weights_ = AddInput(TensorType_FLOAT32); + + if (use_cifg) { + recurrent_to_input_weights_ = AddNullInput(); + } else { + recurrent_to_input_weights_ = AddInput(TensorType_FLOAT32); + } + + recurrent_to_forget_weights_ = AddInput(TensorType_FLOAT32); + recurrent_to_cell_weights_ = AddInput(TensorType_FLOAT32); + recurrent_to_output_weights_ = 
AddInput(TensorType_FLOAT32); + + if (use_peephole) { + if (use_cifg) { + cell_to_input_weights_ = AddNullInput(); + } else { + cell_to_input_weights_ = AddInput(TensorType_FLOAT32); + } + cell_to_forget_weights_ = AddInput(TensorType_FLOAT32); + cell_to_output_weights_ = AddInput(TensorType_FLOAT32); + } else { + cell_to_input_weights_ = AddNullInput(); + cell_to_forget_weights_ = AddNullInput(); + cell_to_output_weights_ = AddNullInput(); + } + + if (use_cifg) { + input_gate_bias_ = AddNullInput(); + } else { + input_gate_bias_ = AddInput(TensorType_FLOAT32); + } + forget_gate_bias_ = AddInput(TensorType_FLOAT32); + cell_bias_ = AddInput(TensorType_FLOAT32); + output_gate_bias_ = AddInput(TensorType_FLOAT32); + + if (use_projection_weights) { + projection_weights_ = AddInput(TensorType_FLOAT32); + if (use_projection_bias) { + projection_bias_ = AddInput(TensorType_FLOAT32); + } else { + projection_bias_ = AddNullInput(); + } + } else { + projection_weights_ = AddNullInput(); + projection_bias_ = AddNullInput(); + } + + scratch_buffer_ = AddOutput(TensorType_FLOAT32); + // TODO(ghodrat): Modify these states when we have a permanent solution for + // persistent buffer. 
+ output_state_ = AddOutput(TensorType_FLOAT32); + cell_state_ = AddOutput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + + SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, + BuiltinOptions_LSTMOptions, + CreateLSTMOptions(builder_, ActivationFunctionType_TANH, + cell_clip, proj_clip) + .Union()); + BuildInterpreter(input_shapes); + } + + void SetInputToInputWeights(std::initializer_list f) { + PopulateTensor(input_to_input_weights_, f); + } + + void SetInputToForgetWeights(std::initializer_list f) { + PopulateTensor(input_to_forget_weights_, f); + } + + void SetInputToCellWeights(std::initializer_list f) { + PopulateTensor(input_to_cell_weights_, f); + } + + void SetInputToOutputWeights(std::initializer_list f) { + PopulateTensor(input_to_output_weights_, f); + } + + void SetRecurrentToInputWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_input_weights_, f); + } + + void SetRecurrentToForgetWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_forget_weights_, f); + } + + void SetRecurrentToCellWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_cell_weights_, f); + } + + void SetRecurrentToOutputWeights(std::initializer_list f) { + PopulateTensor(recurrent_to_output_weights_, f); + } + + void SetCellToInputWeights(std::initializer_list f) { + PopulateTensor(cell_to_input_weights_, f); + } + + void SetCellToForgetWeights(std::initializer_list f) { + PopulateTensor(cell_to_forget_weights_, f); + } + + void SetCellToOutputWeights(std::initializer_list f) { + PopulateTensor(cell_to_output_weights_, f); + } + + void SetInputGateBias(std::initializer_list f) { + PopulateTensor(input_gate_bias_, f); + } + + void SetForgetGateBias(std::initializer_list f) { + PopulateTensor(forget_gate_bias_, f); + } + + void SetCellBias(std::initializer_list f) { + PopulateTensor(cell_bias_, f); + } + + void SetOutputGateBias(std::initializer_list f) { + PopulateTensor(output_gate_bias_, f); + } + + void 
SetProjectionWeights(std::initializer_list f) { + PopulateTensor(projection_weights_, f); + } + + void SetProjectionBias(std::initializer_list f) { + PopulateTensor(projection_bias_, f); + } + + void ResetOutputState() { + const int zero_buffer_size = n_cell_ * n_batch_; + std::unique_ptr zero_buffer(new float[zero_buffer_size]); + memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); + PopulateTensor(output_state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + } + + void ResetCellState() { + const int zero_buffer_size = n_cell_ * n_batch_; + std::unique_ptr zero_buffer(new float[zero_buffer_size]); + memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); + PopulateTensor(cell_state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + } + + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + int num_inputs() { return n_input_; } + int num_outputs() { return n_output_; } + int num_cells() { return n_cell_; } + int num_batches() { return n_batch_; } + int sequence_length() { return sequence_length_; } + + private: + int input_; + int input_to_input_weights_; + int input_to_forget_weights_; + int input_to_cell_weights_; + int input_to_output_weights_; + + int recurrent_to_input_weights_; + int recurrent_to_forget_weights_; + int recurrent_to_cell_weights_; + int recurrent_to_output_weights_; + + int cell_to_input_weights_; + int cell_to_forget_weights_; + int cell_to_output_weights_; + + int input_gate_bias_; + int forget_gate_bias_; + int cell_bias_; + int output_gate_bias_; + + int projection_weights_; + int projection_bias_; + + int output_; + int output_state_; + int cell_state_; + int scratch_buffer_; + + int n_batch_; + int n_input_; + int n_cell_; + int n_output_; + int sequence_length_; +}; + +TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { + const int n_batch = 1; + 
const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; + const int sequence_length = 3; + + UnidirectionalLSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false, + /*use_peephole=*/false, /*use_projection_weights=*/false, + /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {sequence_length, n_batch, n_input}, // input tensor + + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {0}, // cell_to_forget_weight tensor + {0}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589, + -0.34550029, 0.04266912, -0.15680569, + -0.34856534, 0.43890524}); + + lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163, + -0.20583314, 0.44344562, 0.22077113, + -0.29909778}); + + lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935, + -0.31343272, -0.40032279, 0.44781327, + 0.01387155, -0.35593212}); + + lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829, + 0.40525138, 0.44272184, 0.03897077, -0.1556896, + 0.19487578}); + + lstm.SetInputGateBias({0., 0., 0., 0.}); + + lstm.SetCellBias({0., 0., 0., 0.}); + + lstm.SetForgetGateBias({1., 1., 1., 1.}); + + lstm.SetOutputGateBias({0., 0., 
0., 0.}); + + lstm.SetRecurrentToInputWeights( + {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324, + -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322, + -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296}); + + lstm.SetRecurrentToCellWeights( + {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841, + -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659, + -0.46367589, 0.26016325, -0.03894562, -0.16368064}); + + lstm.SetRecurrentToForgetWeights( + {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892, + -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436, + 0.28053468, 0.01560611, -0.20127171, -0.01140004}); + + lstm.SetRecurrentToOutputWeights( + {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793, + 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421, + -0.51818722, -0.15390486, 0.0468148, 0.39922136}); + + // Input should have n_input * sequence_length many values. 
+ static float lstm_input[] = {2., 3., 3., 4., 1., 1.}; + static float lstm_golden_output[] = {-0.02973187, 0.1229473, 0.20885126, + -0.15358765, -0.03716109, 0.12507336, + 0.41193449, -0.20860538, -0.15053082, + 0.09120187, 0.24278517, -0.12222792}; + + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); + + float* batch0_start = lstm_input; + float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length(); + + lstm.SetInput(0, batch0_start, batch0_end); + + lstm.Invoke(); + + float* golden_start = lstm_golden_output; + float* golden_end = + golden_start + lstm.num_outputs() * lstm.sequence_length(); + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); +} + +TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. 
+ const int n_cell = 4; + const int n_output = 4; + const int sequence_length = 3; + + UnidirectionalLSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/true, + /*use_peephole=*/true, /*use_projection_weights=*/false, + /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {sequence_length, n_batch, n_input}, // input tensor + + {0, 0}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {0, 0}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {0}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToCellWeights({-0.49770179, -0.27711356, -0.09624726, 0.05100781, + 0.04717243, 0.48944736, -0.38535351, + -0.17212132}); + + lstm.SetInputToForgetWeights({-0.55291498, -0.42866567, 0.13056988, + -0.3633365, -0.22755712, 0.28253698, 0.24407166, + 0.33826375}); + + lstm.SetInputToOutputWeights({0.10725588, -0.02335852, -0.55932593, + -0.09426838, -0.44257352, 0.54939759, + 0.01533556, 0.42751634}); + + lstm.SetCellBias({0., 0., 0., 0.}); + + lstm.SetForgetGateBias({1., 1., 1., 1.}); + + lstm.SetOutputGateBias({0., 0., 0., 0.}); + + lstm.SetRecurrentToCellWeights( + {0.54066205, -0.32668582, -0.43562764, -0.56094903, 0.42957711, + 0.01841056, -0.32764608, -0.33027974, -0.10826075, 0.20675004, + 0.19069612, -0.03026325, -0.54532051, 0.33003211, 0.44901288, + 0.21193194}); + + lstm.SetRecurrentToForgetWeights( + 
{-0.13832897, -0.0515101, -0.2359007, -0.16661474, -0.14340827, + 0.36986142, 0.23414481, 0.55899, 0.10798943, -0.41174671, 0.17751795, + -0.34484994, -0.35874045, -0.11352962, 0.27268326, 0.54058349}); + + lstm.SetRecurrentToOutputWeights( + {0.41613156, 0.42610586, -0.16495961, -0.5663873, 0.30579174, -0.05115908, + -0.33941799, 0.23364776, 0.11178309, 0.09481031, -0.26424935, 0.46261835, + 0.50248802, 0.26114327, -0.43736315, 0.33149987}); + + lstm.SetCellToForgetWeights( + {0.47485286, -0.51955009, -0.24458408, 0.31544167}); + lstm.SetCellToOutputWeights( + {-0.17135078, 0.82760304, 0.85573703, -0.77109635}); + + static float lstm_input[] = {2., 3., 3., 4., 1., 1.}; + static float lstm_golden_output[] = {-0.36444446, -0.00352185, 0.12886585, + -0.05163646, -0.42312205, -0.01218222, + 0.24201041, -0.08124574, -0.358325, + -0.04621704, 0.21641694, -0.06471302}; + + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); + + float* batch0_start = lstm_input; + float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length(); + + lstm.SetInput(0, batch0_start, batch0_end); + + lstm.Invoke(); + + float* golden_start = lstm_golden_output; + float* golden_end = + golden_start + lstm.num_outputs() * lstm.sequence_length(); + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); +} + +TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) { + const int n_batch = 2; + const int n_input = 5; + const int n_cell = 20; + const int n_output = 16; + const int sequence_length = 4; + + UnidirectionalLSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false, + /*use_peephole=*/true, /*use_projection_weights=*/true, + /*use_projection_bias=*/false, + /*cell_clip=*/0.0, /*proj_clip=*/0.0, + { + {sequence_length, n_batch, n_input}, // input tensor + + {n_cell, n_input}, // 
input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {n_cell}, // cell_to_input_weight tensor + {n_cell}, // cell_to_forget_weight tensor + {n_cell}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {n_output, n_cell}, // projection_weight tensor + {0}, // projection_bias tensor + }); + + lstm.SetInputToInputWeights( + {0.021393683, 0.06124551, 0.046905167, -0.014657677, -0.03149463, + 0.09171803, 0.14647801, 0.10797193, -0.0057968358, 0.0019193048, + -0.2726754, 0.10154029, -0.018539885, 0.080349885, -0.10262385, + -0.022599787, -0.09121155, -0.008675967, -0.045206103, -0.0821282, + -0.008045952, 0.015478081, 0.055217247, 0.038719587, 0.044153627, + -0.06453243, 0.05031825, -0.046935108, -0.008164439, 0.014574226, + -0.1671009, -0.15519552, -0.16819797, -0.13971269, -0.11953059, + 0.25005487, -0.22790983, 0.009855087, -0.028140958, -0.11200698, + 0.11295408, -0.0035217577, 0.054485075, 0.05184695, 0.064711206, + 0.10989193, 0.11674786, 0.03490607, 0.07727357, 0.11390585, + -0.1863375, -0.1034451, -0.13945189, -0.049401227, -0.18767063, + 0.042483903, 0.14233552, 0.13832581, 0.18350165, 0.14545603, + -0.028545704, 0.024939531, 0.050929718, 0.0076203286, -0.0029723682, + -0.042484224, -0.11827596, -0.09171104, -0.10808628, -0.16327988, + -0.2273378, -0.0993647, -0.017155107, 0.0023917493, 0.049272764, + 0.0038534778, 0.054764505, 0.089753784, 0.06947234, 0.08014476, + -0.04544234, -0.0497073, -0.07135631, -0.048929106, -0.004042012, + -0.009284026, 0.018042054, 
0.0036860977, -0.07427302, -0.11434604, + -0.018995456, 0.031487543, 0.012834908, 0.019977754, 0.044256654, + -0.39292613, -0.18519334, -0.11651281, -0.06809892, 0.011373677}); + + lstm.SetInputToForgetWeights( + {-0.0018401089, -0.004852237, 0.03698424, 0.014181704, 0.028273236, + -0.016726194, -0.05249759, -0.10204261, 0.00861066, -0.040979505, + -0.009899187, 0.01923892, -0.028177269, -0.08535103, -0.14585495, + 0.10662567, -0.01909731, -0.017883534, -0.0047269356, -0.045103323, + 0.0030784295, 0.076784775, 0.07463696, 0.094531395, 0.0814421, + -0.12257899, -0.033945758, -0.031303465, 0.045630626, 0.06843887, + -0.13492945, -0.012480007, -0.0811829, -0.07224499, -0.09628791, + 0.045100946, 0.0012300825, 0.013964662, 0.099372394, 0.02543059, + 0.06958324, 0.034257296, 0.0482646, 0.06267997, 0.052625068, + 0.12784666, 0.07077897, 0.025725935, 0.04165009, 0.07241905, + 0.018668644, -0.037377294, -0.06277783, -0.08833636, -0.040120605, + -0.011405586, -0.007808335, -0.010301386, -0.005102167, 0.027717464, + 0.05483423, 0.11449111, 0.11289652, 0.10939839, 0.13396506, + -0.08402166, -0.01901462, -0.044678304, -0.07720565, 0.014350063, + -0.11757958, -0.0652038, -0.08185733, -0.076754324, -0.092614375, + 0.10405491, 0.052960336, 0.035755895, 0.035839386, -0.012540553, + 0.036881298, 0.02913376, 0.03420159, 0.05448447, -0.054523353, + 0.02582715, 0.02327355, -0.011857179, -0.0011980024, -0.034641717, + -0.026125094, -0.17582615, -0.15923657, -0.27486774, -0.0006143371, + 0.0001771948, -8.470171e-05, 0.02651807, 0.045790765, 0.06956496}); + + lstm.SetInputToCellWeights( + {-0.04580283, -0.09549462, -0.032418985, -0.06454633, + -0.043528453, 0.043018587, -0.049152344, -0.12418144, + -0.078985475, -0.07596889, 0.019484362, -0.11434962, + -0.0074034138, -0.06314844, -0.092981495, 0.0062155537, + -0.025034338, -0.0028890965, 0.048929527, 0.06235075, + 0.10665918, -0.032036792, -0.08505916, -0.10843358, + -0.13002433, -0.036816437, -0.02130134, -0.016518239, + 0.0047691227, 
-0.0025825808, 0.066017866, 0.029991534, + -0.10652836, -0.1037554, -0.13056071, -0.03266643, + -0.033702414, -0.006473424, -0.04611692, 0.014419339, + -0.025174323, 0.0396852, 0.081777506, 0.06157468, + 0.10210095, -0.009658194, 0.046511717, 0.03603906, + 0.0069369148, 0.015960095, -0.06507666, 0.09551598, + 0.053568836, 0.06408714, 0.12835667, -0.008714329, + -0.20211966, -0.12093674, 0.029450472, 0.2849013, + -0.029227901, 0.1164364, -0.08560263, 0.09941786, + -0.036999565, -0.028842626, -0.0033637602, -0.017012902, + -0.09720865, -0.11193351, -0.029155117, -0.017936034, + -0.009768936, -0.04223324, -0.036159635, 0.06505112, + -0.021742892, -0.023377212, -0.07221364, -0.06430552, + 0.05453865, 0.091149814, 0.06387331, 0.007518393, + 0.055960953, 0.069779344, 0.046411168, 0.10509911, + 0.07463894, 0.0075130584, 0.012850982, 0.04555431, + 0.056955688, 0.06555285, 0.050801456, -0.009862683, + 0.00826772, -0.026555609, -0.0073611983, -0.0014897042}); + + lstm.SetInputToOutputWeights( + {-0.0998932, -0.07201956, -0.052803773, -0.15629593, -0.15001918, + -0.07650751, 0.02359855, -0.075155355, -0.08037709, -0.15093534, + 0.029517552, -0.04751393, 0.010350531, -0.02664851, -0.016839722, + -0.023121163, 0.0077019283, 0.012851257, -0.05040649, -0.0129761, + -0.021737747, -0.038305793, -0.06870586, -0.01481247, -0.001285394, + 0.10124236, 0.083122835, 0.053313006, -0.062235646, -0.075637154, + -0.027833903, 0.029774971, 0.1130802, 0.09218906, 0.09506135, + -0.086665764, -0.037162706, -0.038880914, -0.035832845, -0.014481564, + -0.09825003, -0.12048569, -0.097665586, -0.05287633, -0.0964047, + -0.11366429, 0.035777505, 0.13568819, 0.052451383, 0.050649304, + 0.05798951, -0.021852335, -0.099848844, 0.014740475, -0.078897946, + 0.04974699, 0.014160473, 0.06973932, 0.04964942, 0.033364646, + 0.08190124, 0.025535367, 0.050893165, 0.048514254, 0.06945813, + -0.078907564, -0.06707616, -0.11844508, -0.09986688, -0.07509403, + 0.06263226, 0.14925587, 0.20188436, 0.12098451, 
0.14639415, + 0.0015017595, -0.014267382, -0.03417257, 0.012711468, 0.0028300495, + -0.024758482, -0.05098548, -0.0821182, 0.014225672, 0.021544158, + 0.08949725, 0.07505268, -0.0020780868, 0.04908258, 0.06476295, + -0.022907063, 0.027562456, 0.040185735, 0.019567577, -0.015598739, + -0.049097303, -0.017121866, -0.083368234, -0.02332002, -0.0840956}); + + lstm.SetInputGateBias( + {0.02234832, 0.14757581, 0.18176508, 0.10380666, 0.053110216, + -0.06928846, -0.13942584, -0.11816189, 0.19483899, 0.03652339, + -0.10250295, 0.036714908, -0.18426876, 0.036065217, 0.21810818, + 0.02383196, -0.043370757, 0.08690144, -0.04444982, 0.00030581196}); + + lstm.SetForgetGateBias({0.035185695, -0.042891346, -0.03032477, 0.23027696, + 0.11098921, 0.15378423, 0.09263801, 0.09790885, + 0.09508917, 0.061199076, 0.07665568, -0.015443159, + -0.03499149, 0.046190713, 0.08895977, 0.10899629, + 0.40694186, 0.06030037, 0.012413437, -0.06108739}); + + lstm.SetCellBias({-0.024379363, 0.0055531194, 0.23377132, 0.033463873, + -0.1483596, -0.10639995, -0.091433935, 0.058573797, + -0.06809782, -0.07889636, -0.043246906, -0.09829136, + -0.4279842, 0.034901652, 0.18797937, 0.0075234566, + 0.016178843, 0.1749513, 0.13975595, 0.92058027}); + + lstm.SetOutputGateBias( + {0.046159424, -0.0012809046, 0.03563469, 0.12648113, 0.027195795, + 0.35373217, -0.018957434, 0.008907322, -0.0762701, 0.12018895, + 0.04216877, 0.0022856654, 0.040952638, 0.3147856, 0.08225149, + -0.057416286, -0.14995944, -0.008040261, 0.13208859, 0.029760877}); + + lstm.SetRecurrentToInputWeights( + {-0.001374326, -0.078856036, 0.10672688, 0.029162422, + -0.11585556, 0.02557986, -0.13446963, -0.035785314, + -0.01244275, 0.025961924, -0.02337298, -0.044228926, + -0.055839065, -0.046598054, -0.010546039, -0.06900766, + 0.027239809, 0.022582639, -0.013296484, -0.05459212, + 0.08981, -0.045407712, 0.08682226, -0.06867011, + -0.14390695, -0.02916037, 0.000996957, 0.091420636, + 0.14283475, -0.07390571, -0.06402044, 0.062524505, + 
-0.093129106, 0.04860203, -0.08364217, -0.08119002, + 0.009352075, 0.22920375, 0.0016303885, 0.11583097, + -0.13732095, 0.012405723, -0.07551853, 0.06343048, + 0.12162708, -0.031923793, -0.014335606, 0.01790974, + -0.10650317, -0.0724401, 0.08554849, -0.05727212, + 0.06556731, -0.042729504, -0.043227166, 0.011683251, + -0.013082158, -0.029302018, -0.010899579, -0.062036745, + -0.022509435, -0.00964907, -0.01567329, 0.04260106, + -0.07787477, -0.11576462, 0.017356863, 0.048673786, + -0.017577527, -0.05527947, -0.082487635, -0.040137455, + -0.10820036, -0.04666372, 0.022746278, -0.07851417, + 0.01068115, 0.032956902, 0.022433773, 0.0026891115, + 0.08944216, -0.0685835, 0.010513544, 0.07228705, + 0.02032331, -0.059686817, -0.0005566496, -0.086984694, + 0.040414046, -0.1380399, 0.094208956, -0.05722982, + 0.012092817, -0.04989123, -0.086576, -0.003399834, + -0.04696032, -0.045747425, 0.10091314, 0.048676282, + -0.029037097, 0.031399418, -0.0040285117, 0.047237843, + 0.09504992, 0.041799378, -0.049185462, -0.031518843, + -0.10516937, 0.026374253, 0.10058866, -0.0033195973, + -0.041975245, 0.0073591834, 0.0033782164, -0.004325073, + -0.10167381, 0.042500053, -0.01447153, 0.06464186, + -0.017142897, 0.03312627, 0.009205989, 0.024138335, + -0.011337001, 0.035530265, -0.010912711, 0.0706555, + -0.005894094, 0.051841937, -0.1401738, -0.02351249, + 0.0365468, 0.07590991, 0.08838724, 0.021681072, + -0.10086113, 0.019608743, -0.06195883, 0.077335775, + 0.023646897, -0.095322326, 0.02233014, 0.09756986, + -0.048691444, -0.009579111, 0.07595467, 0.11480546, + -0.09801813, 0.019894179, 0.08502348, 0.004032281, + 0.037211012, 0.068537936, -0.048005626, -0.091520436, + -0.028379958, -0.01556313, 0.06554592, -0.045599163, + -0.01672207, -0.020169014, -0.011877351, -0.20212261, + 0.010889619, 0.0047078193, 0.038385306, 0.08540671, + -0.017140968, -0.0035865551, 0.016678626, 0.005633034, + 0.015963363, 0.00871737, 0.060130805, 0.028611384, + 0.10109069, -0.015060172, -0.07894427, 
0.06401885, + 0.011584063, -0.024466386, 0.0047652307, -0.09041358, + 0.030737216, -0.0046374933, 0.14215417, -0.11823516, + 0.019899689, 0.006106124, -0.027092824, 0.0786356, + 0.05052217, -0.058925, -0.011402121, -0.024987547, + -0.0013661642, -0.06832946, -0.015667673, -0.1083353, + -0.00096863037, -0.06988685, -0.053350925, -0.027275559, + -0.033664223, -0.07978348, -0.025200296, -0.017207067, + -0.058403496, -0.055697463, 0.005798788, 0.12965427, + -0.062582195, 0.0013350133, -0.10482091, 0.0379771, + 0.072521195, -0.0029455067, -0.13797039, -0.03628521, + 0.013806405, -0.017858358, -0.01008298, -0.07700066, + -0.017081132, 0.019358726, 0.0027079724, 0.004635139, + 0.062634714, -0.02338735, -0.039547626, -0.02050681, + 0.03385117, -0.083611414, 0.002862572, -0.09421313, + 0.058618143, -0.08598433, 0.00972939, 0.023867095, + -0.053934585, -0.023203006, 0.07452513, -0.048767887, + -0.07314807, -0.056307215, -0.10433547, -0.06440842, + 0.04328182, 0.04389765, -0.020006588, -0.09076438, + -0.11652589, -0.021705797, 0.03345259, -0.010329105, + -0.025767034, 0.013057034, -0.07316461, -0.10145612, + 0.06358255, 0.18531723, 0.07759293, 0.12006465, + 0.1305557, 0.058638252, -0.03393652, 0.09622831, + -0.16253184, -2.4580743e-06, 0.079869635, -0.070196845, + -0.005644518, 0.06857898, -0.12598175, -0.035084512, + 0.03156317, -0.12794146, -0.031963028, 0.04692781, + 0.030070418, 0.0071660685, -0.095516115, -0.004643372, + 0.040170413, -0.062104587, -0.0037324072, 0.0554317, + 0.08184801, -0.019164372, 0.06791302, 0.034257166, + -0.10307039, 0.021943003, 0.046745934, 0.0790918, + -0.0265588, -0.007824208, 0.042546265, -0.00977924, + -0.0002440307, -0.017384544, -0.017990116, 0.12252321, + -0.014512694, -0.08251313, 0.08861942, 0.13589665, + 0.026351685, 0.012641483, 0.07466548, 0.044301085, + -0.045414884, -0.051112458, 0.03444247, -0.08502782, + -0.04106223, -0.028126027, 0.028473156, 0.10467447}); + + lstm.SetRecurrentToForgetWeights( + {-0.057784554, -0.026057621, 
-0.068447545, -0.022581743, + 0.14811787, 0.10826372, 0.09471067, 0.03987225, + -0.0039523416, 0.00030638507, 0.053185795, 0.10572994, + 0.08414449, -0.022036452, -0.00066928595, -0.09203576, + 0.032950465, -0.10985798, -0.023809856, 0.0021431844, + -0.02196096, -0.00326074, 0.00058621005, -0.074678116, + -0.06193199, 0.055729095, 0.03736828, 0.020123724, + 0.061878487, -0.04729229, 0.034919553, -0.07585433, + -0.04421272, -0.044019096, 0.085488975, 0.04058006, + -0.06890133, -0.030951202, -0.024628663, -0.07672815, + 0.034293607, 0.08556707, -0.05293577, -0.033561368, + -0.04899627, 0.0241671, 0.015736353, -0.095442444, + -0.029564252, 0.016493602, -0.035026584, 0.022337519, + -0.026871363, 0.004780428, 0.0077918363, -0.03601621, + 0.016435321, -0.03263031, -0.09543275, -0.047392778, + 0.013454138, 0.028934088, 0.01685226, -0.086110644, + -0.046250615, -0.01847454, 0.047608484, 0.07339695, + 0.034546845, -0.04881143, 0.009128804, -0.08802852, + 0.03761666, 0.008096139, -0.014454086, 0.014361001, + -0.023502491, -0.0011840804, -0.07607001, 0.001856849, + -0.06509276, -0.006021153, -0.08570962, -0.1451793, + 0.060212336, 0.055259194, 0.06974018, 0.049454916, + -0.027794661, -0.08077226, -0.016179763, 0.1169753, + 0.17213494, -0.0056326236, -0.053934924, -0.0124349, + -0.11520337, 0.05409887, 0.088759385, 0.0019655675, + 0.0042065294, 0.03881498, 0.019844765, 0.041858196, + -0.05695512, 0.047233116, 0.038937137, -0.06542224, + 0.014429736, -0.09719407, 0.13908425, -0.05379757, + 0.012321099, 0.082840554, -0.029899208, 0.044217527, + 0.059855383, 0.07711018, -0.045319796, 0.0948846, + -0.011724666, -0.0033288454, -0.033542685, -0.04764985, + -0.13873616, 0.040668588, 0.034832682, -0.015319203, + -0.018715994, 0.046002675, 0.0599172, -0.043107376, + 0.0294216, -0.002314414, -0.022424703, 0.0030315618, + 0.0014641669, 0.0029166266, -0.11878115, 0.013738511, + 0.12375372, -0.0006038222, 0.029104086, 0.087442465, + 0.052958444, 0.07558703, 0.04817258, 0.044462286, + 
-0.015213451, -0.08783778, -0.0561384, -0.003008196, + 0.047060397, -0.002058388, 0.03429439, -0.018839769, + 0.024734668, 0.024614193, -0.042046934, 0.09597743, + -0.0043254104, 0.04320769, 0.0064070094, -0.0019131786, + -0.02558259, -0.022822596, -0.023273505, -0.02464396, + -0.10991725, -0.006240552, 0.0074488563, 0.024044557, + 0.04383914, -0.046476185, 0.028658995, 0.060410924, + 0.050786525, 0.009452605, -0.0073054377, -0.024810238, + 0.0052906186, 0.0066939713, -0.0020913032, 0.014515517, + 0.015898481, 0.021362653, -0.030262267, 0.016587038, + -0.011442813, 0.041154444, -0.007631438, -0.03423484, + -0.010977775, 0.036152758, 0.0066366293, 0.11915515, + 0.02318443, -0.041350313, 0.021485701, -0.10906167, + -0.028218046, -0.00954771, 0.020531068, -0.11995105, + -0.03672871, 0.024019798, 0.014255957, -0.05221243, + -0.00661567, -0.04630967, 0.033188973, 0.10107534, + -0.014027541, 0.030796422, -0.10270911, -0.035999842, + 0.15443139, 0.07684145, 0.036571592, -0.035900835, + -0.0034699554, 0.06209149, 0.015920248, -0.031122351, + -0.03858649, 0.01849943, 0.13872518, 0.01503974, + 0.069941424, -0.06948533, -0.0088794185, 0.061282158, + -0.047401894, 0.03100163, -0.041533746, -0.10430945, + 0.044574402, -0.01425562, -0.024290353, 0.034563623, + 0.05866852, 0.023947537, -0.09445152, 0.035450947, + 0.02247216, -0.0042998926, 0.061146557, -0.10250651, + 0.020881841, -0.06747029, 0.10062043, -0.0023941975, + 0.03532124, -0.016341697, 0.09685456, -0.016764693, + 0.051808182, 0.05875331, -0.04536488, 0.001626336, + -0.028892258, -0.01048663, -0.009793449, -0.017093895, + 0.010987891, 0.02357273, -0.00010856845, 0.0099760275, + -0.001845119, -0.03551521, 0.0018358806, 0.05763657, + -0.01769146, 0.040995963, 0.02235177, -0.060430344, + 0.11475477, -0.023854522, 0.10071741, 0.0686208, + -0.014250481, 0.034261297, 0.047418304, 0.08562733, + -0.030519066, 0.0060542435, 0.014653856, -0.038836084, + 0.04096551, 0.032249358, -0.08355519, -0.026823482, + 0.056386515, 
-0.010401743, -0.028396193, 0.08507674, + 0.014410365, 0.020995233, 0.17040324, 0.11511526, + 0.02459721, 0.0066619175, 0.025853224, -0.023133837, + -0.081302024, 0.017264642, -0.009585969, 0.09491168, + -0.051313367, 0.054532815, -0.014298593, 0.10657464, + 0.007076659, 0.10964551, 0.0409152, 0.008275321, + -0.07283536, 0.07937492, 0.04192024, -0.1075027}); + + lstm.SetRecurrentToCellWeights( + {-0.037322544, 0.018592842, 0.0056175636, -0.06253426, + 0.055647098, -0.05713207, -0.05626563, 0.005559383, + 0.03375411, -0.025757805, -0.088049285, 0.06017052, + -0.06570978, 0.007384076, 0.035123326, -0.07920549, + 0.053676967, 0.044480428, -0.07663568, 0.0071805613, + 0.08089997, 0.05143358, 0.038261272, 0.03339287, + -0.027673481, 0.044746667, 0.028349208, 0.020090483, + -0.019443132, -0.030755889, -0.0040000007, 0.04465846, + -0.021585021, 0.0031670958, 0.0053199246, -0.056117613, + -0.10893326, 0.076739706, -0.08509834, -0.027997585, + 0.037871376, 0.01449768, -0.09002357, -0.06111149, + -0.046195522, 0.0422062, -0.005683705, -0.1253618, + -0.012925729, -0.04890792, 0.06985068, 0.037654128, + 0.03398274, -0.004781977, 0.007032333, -0.031787455, + 0.010868644, -0.031489216, 0.09525667, 0.013939797, + 0.0058680447, 0.0167067, 0.02668468, -0.04797466, + -0.048885044, -0.12722108, 0.035304096, 0.06554885, + 0.00972396, -0.039238118, -0.05159735, -0.11329045, + 0.1613692, -0.03750952, 0.06529313, -0.071974665, + -0.11769596, 0.015524369, -0.0013754242, -0.12446318, + 0.02786344, -0.014179351, 0.005264273, 0.14376344, + 0.015983658, 0.03406988, -0.06939408, 0.040699873, + 0.02111075, 0.09669095, 0.041345075, -0.08316494, + -0.07684199, -0.045768797, 0.032298047, -0.041805092, + 0.0119405, 0.0061010392, 0.12652606, 0.0064572375, + -0.024950314, 0.11574242, 0.04508852, -0.04335324, + 0.06760663, -0.027437469, 0.07216407, 0.06977076, + -0.05438599, 0.034033038, -0.028602652, 0.05346137, + 0.043184172, -0.037189785, 0.10420091, 0.00882477, + -0.054019816, -0.074273005, 
-0.030617684, -0.0028467078, + 0.024302477, -0.0038869337, 0.005332455, 0.0013399826, + 0.04361412, -0.007001822, 0.09631092, -0.06702025, + -0.042049985, -0.035070654, -0.04103342, -0.10273396, + 0.0544271, 0.037184782, -0.13150354, -0.0058036847, + -0.008264958, 0.042035464, 0.05891794, 0.029673764, + 0.0063542654, 0.044788733, 0.054816857, 0.062257513, + -0.00093483756, 0.048938446, -0.004952862, -0.007730018, + -0.04043371, -0.017094059, 0.07229206, -0.023670016, + -0.052195564, -0.025616996, -0.01520939, 0.045104615, + -0.007376126, 0.003533447, 0.006570588, 0.056037236, + 0.12436656, 0.051817212, 0.028532185, -0.08686856, + 0.11868599, 0.07663395, -0.07323171, 0.03463402, + -0.050708205, -0.04458982, -0.11590894, 0.021273347, + 0.1251325, -0.15313013, -0.12224372, 0.17228661, + 0.023029093, 0.086124025, 0.006445803, -0.03496501, + 0.028332196, 0.04449512, -0.042436164, -0.026587414, + -0.006041347, -0.09292539, -0.05678812, 0.03897832, + 0.09465633, 0.008115513, -0.02171956, 0.08304309, + 0.071401566, 0.019622514, 0.032163795, -0.004167056, + 0.02295182, 0.030739572, 0.056506045, 0.004612461, + 0.06524936, 0.059999723, 0.046395954, -0.0045512207, + -0.1335546, -0.030136576, 0.11584653, -0.014678886, + 0.0020118146, -0.09688814, -0.0790206, 0.039770417, + -0.0329582, 0.07922767, 0.029322514, 0.026405897, + 0.04207835, -0.07073373, 0.063781224, 0.0859677, + -0.10925287, -0.07011058, 0.048005477, 0.03438226, + -0.09606514, -0.006669445, -0.043381985, 0.04240257, + -0.06955775, -0.06769346, 0.043903265, -0.026784198, + -0.017840602, 0.024307009, -0.040079936, -0.019946516, + 0.045318738, -0.12233574, 0.026170589, 0.0074471775, + 0.15978073, 0.10185836, 0.10298046, -0.015476589, + -0.039390966, -0.072174534, 0.0739445, -0.1211869, + -0.0347889, -0.07943156, 0.014809798, -0.12412325, + -0.0030663363, 0.039695457, 0.0647603, -0.08291318, + -0.018529687, -0.004423833, 0.0037507233, 0.084633216, + -0.01514876, -0.056505352, -0.012800942, -0.06994386, + 0.012962922, 
-0.031234352, 0.07029052, 0.016418684, + 0.03618972, 0.055686004, -0.08663945, -0.017404709, + -0.054761406, 0.029065743, 0.052404847, 0.020238016, + 0.0048197987, -0.0214882, 0.07078733, 0.013016777, + 0.06262858, 0.009184685, 0.020785125, -0.043904778, + -0.0270329, -0.03299152, -0.060088247, -0.015162964, + -0.001828936, 0.12642565, -0.056757294, 0.013586685, + 0.09232601, -0.035886683, 0.06000002, 0.05229691, + -0.052580316, -0.082029596, -0.010794592, 0.012947712, + -0.036429964, -0.085508935, -0.13127148, -0.017744139, + 0.031502828, 0.036232427, -0.031581745, 0.023051167, + -0.05325106, -0.03421577, 0.028793324, -0.034633752, + -0.009881397, -0.043551125, -0.018609839, 0.0019097115, + -0.008799762, 0.056595087, 0.0022273948, 0.055752404}); + + lstm.SetRecurrentToOutputWeights({ + 0.025825322, -0.05813119, 0.09495884, -0.045984812, -0.01255415, + -0.0026479573, -0.08196161, -0.054914974, -0.0046604523, -0.029587349, + -0.044576716, -0.07480124, -0.082868785, 0.023254942, 0.027502948, + -0.0039728214, -0.08683098, -0.08116779, -0.014675607, -0.037924774, + -0.023314456, -0.007401714, -0.09255757, 0.029460307, -0.08829125, + -0.005139627, -0.08989442, -0.0555066, 0.13596267, -0.025062224, + -0.048351806, -0.03850004, 0.07266485, -0.022414139, 0.05940088, + 0.075114764, 0.09597592, -0.010211725, -0.0049794707, -0.011523867, + -0.025980417, 0.072999895, 0.11091378, -0.081685916, 0.014416728, + 0.043229222, 0.034178585, -0.07530371, 0.035837382, -0.085607, + -0.007721233, -0.03287832, -0.043848954, -0.06404588, -0.06632928, + -0.073643476, 0.008214239, -0.045984086, 0.039764922, 0.03474462, + 0.060612556, -0.080590084, 0.049127717, 0.04151091, -0.030063879, + 0.008801774, -0.023021035, -0.019558564, 0.05158114, -0.010947698, + -0.011825728, 0.0075720972, 0.0699727, -0.0039981045, 0.069350146, + 0.08799282, 0.016156472, 0.035502106, 0.11695009, 0.006217345, + 0.13392477, -0.037875112, 0.025745004, 0.08940699, -0.00924166, + 0.0046702605, -0.036598757, -0.08811812, 
0.10522024, -0.032441203, + 0.008176899, -0.04454919, 0.07058152, 0.0067963637, 0.039206743, + 0.03259838, 0.03725492, -0.09515802, 0.013326398, -0.052055415, + -0.025676316, 0.03198509, -0.015951829, -0.058556724, 0.036879618, + 0.043357447, 0.028362012, -0.05908629, 0.0059240665, -0.04995891, + -0.019187413, 0.0276265, -0.01628143, 0.0025863599, 0.08800015, + 0.035250366, -0.022165963, -0.07328642, -0.009415526, -0.07455109, + 0.11690406, 0.0363299, 0.07411125, 0.042103454, -0.009660886, + 0.019076364, 0.018299393, -0.046004917, 0.08891175, 0.0431396, + -0.026327137, -0.051502608, 0.08979574, -0.051670972, 0.04940282, + -0.07491107, -0.021240504, 0.022596184, -0.034280192, 0.060163025, + -0.058211457, -0.051837247, -0.01349775, -0.04639988, -0.035936575, + -0.011681591, 0.064818054, 0.0073146066, -0.021745546, -0.043124277, + -0.06471268, -0.07053354, -0.029321948, -0.05330136, 0.016933719, + -0.053782392, 0.13747959, -0.1361751, -0.11569455, 0.0033329215, + 0.05693899, -0.053219706, 0.063698, 0.07977434, -0.07924483, + 0.06936997, 0.0034815092, -0.007305279, -0.037325785, -0.07251102, + -0.033633437, -0.08677009, 0.091591336, -0.14165086, 0.021752775, + 0.019683983, 0.0011612234, -0.058154266, 0.049996935, 0.0288841, + -0.0024567875, -0.14345716, 0.010955264, -0.10234828, 0.1183656, + -0.0010731248, -0.023590032, -0.072285876, -0.0724771, -0.026382286, + -0.0014920527, 0.042667855, 0.0018776858, 0.02986552, 0.009814309, + 0.0733756, 0.12289186, 0.018043943, -0.0458958, 0.049412545, + 0.033632483, 0.05495232, 0.036686596, -0.013781798, -0.010036754, + 0.02576849, -0.08307328, 0.010112348, 0.042521734, -0.05869831, + -0.071689695, 0.03876447, -0.13275425, -0.0352966, -0.023077697, + 0.10285965, 0.084736146, 0.15568255, -0.00040734606, 0.027835453, + -0.10292561, -0.032401145, 0.10053256, -0.026142767, -0.08271222, + -0.0030240538, -0.016368777, 0.1070414, 0.042672627, 0.013456989, + -0.0437609, -0.022309763, 0.11576483, 0.04108048, 0.061026827, + -0.0190714, 
-0.0869359, 0.037901703, 0.0610107, 0.07202949, + 0.01675338, 0.086139716, -0.08795751, -0.014898893, -0.023771819, + -0.01965048, 0.007955471, -0.043740474, 0.03346837, -0.10549954, + 0.090567775, 0.042013682, -0.03176985, 0.12569028, -0.02421228, + -0.029526481, 0.023851605, 0.031539805, 0.05292009, -0.02344001, + -0.07811758, -0.08834428, 0.10094801, 0.16594367, -0.06861939, + -0.021256343, -0.041093912, -0.06669611, 0.035498552, 0.021757556, + -0.09302526, -0.015403468, -0.06614931, -0.051798206, -0.013874718, + 0.03630673, 0.010412845, -0.08077351, 0.046185967, 0.0035662893, + 0.03541868, -0.094149634, -0.034814864, 0.003128424, -0.020674974, + -0.03944324, -0.008110165, -0.11113267, 0.08484226, 0.043586485, + 0.040582247, 0.0968012, -0.065249965, -0.028036479, 0.0050708856, + 0.0017462453, 0.0326779, 0.041296225, 0.09164146, -0.047743853, + -0.015952192, -0.034451712, 0.084197424, -0.05347844, -0.11768019, + 0.085926116, -0.08251791, -0.045081906, 0.0948852, 0.068401024, + 0.024856757, 0.06978981, -0.057309967, -0.012775832, -0.0032452994, + 0.01977615, -0.041040014, -0.024264973, 0.063464895, 0.05431621, + }); + + lstm.SetCellToInputWeights( + {0.040369894, 0.030746894, 0.24704495, 0.018586371, -0.037586458, + -0.15312155, -0.11812848, -0.11465643, 0.20259799, 0.11418174, + -0.10116027, -0.011334949, 0.12411352, -0.076769054, -0.052169047, + 0.21198851, -0.38871562, -0.09061183, -0.09683246, -0.21929175}); + + lstm.SetCellToForgetWeights( + {-0.01998659, -0.15568835, -0.24248174, -0.012770197, 0.041331276, + -0.072311886, -0.052123554, -0.0066330447, -0.043891653, 0.036225766, + -0.047248036, 0.021479502, 0.033189066, 0.11952997, -0.020432774, + 0.64658105, -0.06650122, -0.03467612, 0.095340036, 0.23647355}); + + lstm.SetCellToOutputWeights( + {0.08286371, -0.08261836, -0.51210177, 0.002913762, 0.17764764, + -0.5495371, -0.08460716, -0.24552552, 0.030037103, 0.04123544, + -0.11940523, 0.007358328, 0.1890978, 0.4833202, -0.34441817, + 0.36312827, -0.26375428, 
0.1457655, -0.19724406, 0.15548733}); + + lstm.SetProjectionWeights( + {-0.009802181, 0.09401916, 0.0717386, -0.13895074, 0.09641832, + 0.060420845, 0.08539281, 0.054285463, 0.061395317, 0.034448683, + -0.042991187, 0.019801661, -0.16840284, -0.015726732, -0.23041931, + -0.024478018, -0.10959692, -0.013875541, 0.18600968, -0.061274476, + 0.0138165, -0.08160894, -0.07661644, 0.032372914, 0.16169067, + 0.22465782, -0.03993472, -0.004017731, 0.08633481, -0.28869787, + 0.08682067, 0.17240396, 0.014975425, 0.056431185, 0.031037588, + 0.16702051, 0.0077946745, 0.15140012, 0.29405436, 0.120285, + -0.188994, -0.027265169, 0.043389652, -0.022061434, 0.014777949, + -0.20203483, 0.094781205, 0.19100232, 0.13987629, -0.036132768, + -0.06426278, -0.05108664, 0.13221376, 0.009441198, -0.16715929, + 0.15859416, -0.040437475, 0.050779544, -0.022187516, 0.012166504, + 0.027685808, -0.07675938, -0.0055694645, -0.09444123, 0.0046453946, + 0.050794356, 0.10770313, -0.20790008, -0.07149004, -0.11425117, + 0.008225835, -0.035802525, 0.14374903, 0.15262283, 0.048710253, + 0.1847461, -0.007487823, 0.11000021, -0.09542012, 0.22619456, + -0.029149994, 0.08527916, 0.009043713, 0.0042746216, 0.016261552, + 0.022461696, 0.12689082, -0.043589946, -0.12035478, -0.08361797, + -0.050666027, -0.1248618, -0.1275799, -0.071875185, 0.07377272, + 0.09944291, -0.18897448, -0.1593054, -0.06526116, -0.040107165, + -0.004618631, -0.067624845, -0.007576253, 0.10727444, 0.041546922, + -0.20424393, 0.06907816, 0.050412357, 0.00724631, 0.039827548, + 0.12449835, 0.10747581, 0.13708383, 0.09134148, -0.12617786, + -0.06428341, 0.09956831, 0.1208086, -0.14676677, -0.0727722, + 0.1126304, 0.010139365, 0.015571211, -0.038128063, 0.022913318, + -0.042050496, 0.16842307, -0.060597885, 0.10531834, -0.06411776, + -0.07451711, -0.03410368, -0.13393489, 0.06534304, 0.003620307, + 0.04490757, 0.05970546, 0.05197996, 0.02839995, 0.10434969, + -0.013699693, -0.028353551, -0.07260381, 0.047201227, -0.024575593, + 
-0.036445823, 0.07155557, 0.009672501, -0.02328883, 0.009533515, + -0.03606021, -0.07421458, -0.028082801, -0.2678904, -0.13221288, + 0.18419984, -0.13012612, -0.014588381, -0.035059117, -0.04824723, + 0.07830115, -0.056184657, 0.03277091, 0.025466874, 0.14494097, + -0.12522776, -0.098633975, -0.10766018, -0.08317623, 0.08594209, + 0.07749552, 0.039474737, 0.1776665, -0.07409566, -0.0477268, + 0.29323658, 0.10801441, 0.1154011, 0.013952499, 0.10739139, + 0.10708251, -0.051456142, 0.0074137426, -0.10430189, 0.10034707, + 0.045594677, 0.0635285, -0.0715442, -0.089667566, -0.10811871, + 0.00026344223, 0.08298446, -0.009525053, 0.006585689, -0.24567553, + -0.09450807, 0.09648481, 0.026996298, -0.06419476, -0.04752702, + -0.11063944, -0.23441927, -0.17608605, -0.052156363, 0.067035615, + 0.19271925, -0.0032889997, -0.043264326, 0.09663576, -0.057112187, + -0.10100678, 0.0628376, 0.04447668, 0.017961001, -0.10094388, + -0.10190601, 0.18335468, 0.10494553, -0.052095775, -0.0026118709, + 0.10539724, -0.04383912, -0.042349473, 0.08438151, -0.1947263, + 0.02251204, 0.11216432, -0.10307853, 0.17351969, -0.039091777, + 0.08066188, -0.00561982, 0.12633002, 0.11335965, -0.0088127935, + -0.019777594, 0.06864014, -0.059751723, 0.016233567, -0.06894641, + -0.28651384, -0.004228674, 0.019708522, -0.16305895, -0.07468996, + -0.0855457, 0.099339016, -0.07580735, -0.13775392, 0.08434318, + 0.08330512, -0.12131499, 0.031935584, 0.09180414, -0.08876437, + -0.08049874, 0.008753825, 0.03498998, 0.030215185, 0.03907079, + 0.089751154, 0.029194152, -0.03337423, -0.019092513, 0.04331237, + 0.04299654, -0.036394123, -0.12915532, 0.09793732, 0.07512415, + -0.11319543, -0.032502122, 0.15661901, 0.07671967, -0.005491124, + -0.19379048, -0.218606, 0.21448623, 0.017840758, 0.1416943, + -0.07051762, 0.19488361, 0.02664691, -0.18104725, -0.09334311, + 0.15026465, -0.15493552, -0.057762887, -0.11604192, -0.262013, + -0.01391798, 0.012185008, 0.11156489, -0.07483202, 0.06693364, + -0.26151478, 
0.046425626, 0.036540434, -0.16435726, 0.17338543, + -0.21401681, -0.11385144, -0.08283257, -0.069031075, 0.030635102, + 0.010969227, 0.11109743, 0.010919218, 0.027526086, 0.13519906, + 0.01891392, -0.046839405, -0.040167913, 0.017953383, -0.09700955, + 0.0061885654, -0.07000971, 0.026893595, -0.038844477, 0.14543656}); + + static float lstm_input[][20] = { + {// Batch0: 4 (input_sequence_size) * 5 (n_input) + 0.787926, 0.151646, 0.071352, 0.118426, 0.458058, 0.596268, 0.998386, + 0.568695, 0.864524, 0.571277, 0.073204, 0.296072, 0.743333, 0.069199, + 0.045348, 0.867394, 0.291279, 0.013714, 0.482521, 0.626339}, + + {// Batch1: 4 (input_sequence_size) * 5 (n_input) + 0.295743, 0.544053, 0.690064, 0.858138, 0.497181, 0.642421, 0.524260, + 0.134799, 0.003639, 0.162482, 0.640394, 0.930399, 0.050782, 0.432485, + 0.988078, 0.082922, 0.563329, 0.865614, 0.333232, 0.259916}}; + + static float lstm_golden_output[][64] = { + {// Batch0: 4 (input_sequence_size) * 16 (n_output) + -0.00396806, 0.029352, -0.00279226, 0.0159977, -0.00835576, + -0.0211779, 0.0283512, -0.0114597, 0.00907307, -0.0244004, + -0.0152191, -0.0259063, 0.00914318, 0.00415118, 0.017147, + 0.0134203, -0.0166936, 0.0381209, 0.000889694, 0.0143363, + -0.0328911, -0.0234288, 0.0333051, -0.012229, 0.0110322, + -0.0457725, -0.000832209, -0.0202817, 0.0327257, 0.0121308, + 0.0155969, 0.0312091, -0.0213783, 0.0350169, 0.000324794, + 0.0276012, -0.0263374, -0.0371449, 0.0446149, -0.0205474, + 0.0103729, -0.0576349, -0.0150052, -0.0292043, 0.0376827, + 0.0136115, 0.0243435, 0.0354492, -0.0189322, 0.0464512, + -0.00251373, 0.0225745, -0.0308346, -0.0317124, 0.0460407, + -0.0189395, 0.0149363, -0.0530162, -0.0150767, -0.0340193, + 0.0286833, 0.00824207, 0.0264887, 0.0305169}, + {// Batch1: 4 (input_sequence_size) * 16 (n_output) + -0.013869, 0.0287268, -0.00334693, 0.00733398, -0.0287926, + -0.0186926, 0.0193662, -0.0115437, 0.00422612, -0.0345232, + 0.00223253, -0.00957321, 0.0210624, 0.013331, 0.0150954, + 0.02168, 
-0.0141913, 0.0322082, 0.00227024, 0.0260507, + -0.0188721, -0.0296489, 0.0399134, -0.0160509, 0.0116039, + -0.0447318, -0.0150515, -0.0277406, 0.0316596, 0.0118233, + 0.0214762, 0.0293641, -0.0204549, 0.0450315, -0.00117378, + 0.0167673, -0.0375007, -0.0238314, 0.038784, -0.0174034, + 0.0131743, -0.0506589, -0.0048447, -0.0240239, 0.0325789, + 0.00790065, 0.0220157, 0.0333314, -0.0264787, 0.0387855, + -0.000764675, 0.0217599, -0.037537, -0.0335206, 0.0431679, + -0.0211424, 0.010203, -0.062785, -0.00832363, -0.025181, + 0.0412031, 0.0118723, 0.0239643, 0.0394009}}; + + // Resetting cell_state and output_state + lstm.ResetCellState(); + lstm.ResetOutputState(); + + for (int i = 0; i < lstm.sequence_length(); i++) { + float* batch0_start = lstm_input[0] + i * lstm.num_inputs(); + float* batch0_end = batch0_start + lstm.num_inputs(); + + lstm.SetInput(2 * i * lstm.num_inputs(), batch0_start, batch0_end); + + float* batch1_start = lstm_input[1] + i * lstm.num_inputs(); + float* batch1_end = batch1_start + lstm.num_inputs(); + lstm.SetInput((2 * i + 1) * lstm.num_inputs(), batch1_start, batch1_end); + } + + lstm.Invoke(); + + std::vector expected; + for (int i = 0; i < lstm.sequence_length(); i++) { + float* golden_start_batch0 = lstm_golden_output[0] + i * lstm.num_outputs(); + float* golden_end_batch0 = golden_start_batch0 + lstm.num_outputs(); + float* golden_start_batch1 = lstm_golden_output[1] + i * lstm.num_outputs(); + float* golden_end_batch1 = golden_start_batch1 + lstm.num_outputs(); + expected.insert(expected.end(), golden_start_batch0, golden_end_batch0); + expected.insert(expected.end(), golden_start_batch1, golden_end_batch1); + } + EXPECT_THAT(lstm.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc 
b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc new file mode 100644 index 0000000000000000000000000000000000000000..ac00c37b67dcbe77023a2495a698967ca555b1d5 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn.cc @@ -0,0 +1,168 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/activation_functor.h" +#include "tensorflow/contrib/lite/kernels/internal/kernel_utils.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace unidirectional_sequence_rnn { + +constexpr int kInputTensor = 0; +constexpr int kWeightsTensor = 1; +constexpr int kRecurrentWeightsTensor = 2; +constexpr int kBiasTensor = 3; +constexpr int kHiddenStateTensor = 0; +constexpr int kOutputTensor = 1; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + // Check we have all the inputs and outputs we need. 
+ TF_LITE_ENSURE_EQ(context, node->inputs->size, 4); + TF_LITE_ENSURE_EQ(context, node->outputs->size, 2); + + TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; + TfLiteTensor* input_weights = + &context->tensors[node->inputs->data[kWeightsTensor]]; + TfLiteTensor* recurrent_weights = + &context->tensors[node->inputs->data[kRecurrentWeightsTensor]]; + TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]]; + + // Check all the parameters of tensor match within themselves and match the + // input configuration. + auto* params = reinterpret_cast(node->builtin_data); + const bool time_major = params->time_major; + const int batch_size = + (time_major) ? input->dims->data[1] : input->dims->data[0]; + const int max_time = + (time_major) ? input->dims->data[0] : input->dims->data[1]; + const int num_units = input_weights->dims->data[0]; + TF_LITE_ASSERT_EQ(input->dims->data[2], input_weights->dims->data[1]); + TF_LITE_ASSERT_EQ(input_weights->dims->data[0], bias->dims->data[0]); + TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[0], bias->dims->data[0]); + TF_LITE_ASSERT_EQ(recurrent_weights->dims->data[1], bias->dims->data[0]); + + TfLiteTensor* hidden_state = + &context->tensors[node->outputs->data[kHiddenStateTensor]]; + TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]]; + + // Resize state. + TfLiteIntArray* hidden_state_size_array = TfLiteIntArrayCreate(2); + hidden_state_size_array->data[0] = batch_size; + hidden_state_size_array->data[1] = num_units; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, hidden_state, + hidden_state_size_array)); + + // Mark hidden state as a persistent tensor. + hidden_state->allocation_type = kTfLiteArenaRwPersistent; + + // Resize output. + TfLiteIntArray* output_size_array = TfLiteIntArrayCreate(3); + output_size_array->data[0] = (time_major) ? max_time : batch_size; + output_size_array->data[1] = (time_major) ? 
batch_size : max_time; + output_size_array->data[2] = num_units; + TF_LITE_ENSURE_OK(context, + context->ResizeTensor(context, output, output_size_array)); + + return kTfLiteOk; +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + auto* params = reinterpret_cast(node->builtin_data); + + TfLiteTensor* input = &context->tensors[node->inputs->data[kInputTensor]]; + TfLiteTensor* input_weights = + &context->tensors[node->inputs->data[kWeightsTensor]]; + TfLiteTensor* recurrent_weights = + &context->tensors[node->inputs->data[kRecurrentWeightsTensor]]; + TfLiteTensor* bias = &context->tensors[node->inputs->data[kBiasTensor]]; + TfLiteTensor* hidden_state = + &context->tensors[node->outputs->data[kHiddenStateTensor]]; + TfLiteTensor* output = &context->tensors[node->outputs->data[kOutputTensor]]; + + // Initialize the pointer bias. + const float* bias_ptr = bias->data.f; + + const bool time_major = params->time_major; + const int batch_size = + (time_major) ? input->dims->data[1] : input->dims->data[0]; + const int max_time = + (time_major) ? input->dims->data[0] : input->dims->data[1]; + const int num_units = input_weights->dims->data[0]; + const int input_size = input->dims->data[2]; + + // Initialize input_weights and recurrent_weights. + const float* input_weights_ptr = input_weights->data.f; + const float* recurrent_weights_ptr = recurrent_weights->data.f; + + if (time_major) { + // Initialize the pointer to hidden state. + float* hidden_state_ptr_batch = hidden_state->data.f; + // Unroll the sequence and use batch batch operations for efficiency. + for (int s = 0; s < max_time; s++) { + // Initialize the pointer to input and output. 
+ const float* input_ptr_batch = + input->data.f + s * input_size * batch_size; + float* output_ptr_batch = output->data.f + s * num_units * batch_size; + + kernel_utils::RnnBatchStep(input_ptr_batch, input_weights_ptr, + recurrent_weights_ptr, bias_ptr, input_size, + num_units, batch_size, params->activation, + hidden_state_ptr_batch, output_ptr_batch); + } + } else { + // For each batch + for (int b = 0; b < batch_size; b++) { + // Initialize the pointer to hidden state. + float* hidden_state_ptr_batch = hidden_state->data.f + b * num_units; + for (int s = 0; s < max_time; s++) { + // Initialize the pointer to input and output. + const float* input_ptr_batch = + input->data.f + b * input_size * max_time + s * input_size; + float* output_ptr_batch = + output->data.f + b * num_units * max_time + s * num_units; + + kernel_utils::RnnBatchStep( + input_ptr_batch, input_weights_ptr, recurrent_weights_ptr, bias_ptr, + input_size, num_units, /*batch_size=*/1, params->activation, + hidden_state_ptr_batch, output_ptr_batch); + } + } + } + return kTfLiteOk; +} + +} // namespace unidirectional_sequence_rnn + +TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_RNN() { + static TfLiteRegistration r = {/*init=*/nullptr, /*free=*/nullptr, + unidirectional_sequence_rnn::Prepare, + unidirectional_sequence_rnn::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..7e32969763b59620dc3534708f965750680002d2 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc @@ -0,0 +1,351 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Unit test for TFLite Sequential RNN op. + +#include +#include + +#include +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +static float rnn_input[] = { + 0.23689353, 0.285385, 0.037029743, -0.19858193, -0.27569133, + 0.43773448, 0.60379338, 0.35562468, -0.69424844, -0.93421471, + -0.87287879, 0.37144363, -0.62476718, 0.23791671, 0.40060222, + 0.1356622, -0.99774903, -0.98858172, -0.38952237, -0.47685933, + 0.31073618, 0.71511042, -0.63767755, -0.31729108, 0.33468103, + 0.75801885, 0.30660987, -0.37354088, 0.77002847, -0.62747043, + -0.68572164, 0.0069220066, 0.65791464, 0.35130811, 0.80834007, + -0.61777675, -0.21095741, 0.41213346, 0.73784804, 0.094794154, + 0.47791874, 0.86496925, -0.53376222, 0.85315156, 0.10288584, + 0.86684, -0.011186242, 0.10513687, 0.87825835, 0.59929144, + 0.62827742, 0.18899453, 0.31440187, 0.99059987, 0.87170351, + -0.35091716, 0.74861872, 0.17831337, 0.2755419, 0.51864719, + 0.55084288, 0.58982027, -0.47443086, 0.20875752, -0.058871567, + -0.66609079, 0.59098077, 0.73017097, 0.74604273, 0.32882881, + -0.17503482, 0.22396147, 0.19379807, 0.29120302, 0.077113032, + -0.70331609, 0.15804303, -0.93407321, 0.40182066, 0.036301374, + 0.66521823, 0.0300982, -0.7747041, -0.02038002, 0.020698071, + -0.90300065, 0.62870288, 
-0.23068321, 0.27531278, -0.095755219, + -0.712036, -0.17384434, -0.50593495, -0.18646687, -0.96508682, + 0.43519354, 0.14744234, 0.62589407, 0.1653645, -0.10651493, + -0.045277178, 0.99032974, -0.88255352, -0.85147917, 0.28153265, + 0.19455957, -0.55479527, -0.56042433, 0.26048636, 0.84702539, + 0.47587705, -0.074295521, -0.12287641, 0.70117295, 0.90532446, + 0.89782166, 0.79817224, 0.53402734, -0.33286154, 0.073485017, + -0.56172788, -0.044897556, 0.89964068, -0.067662835, 0.76863563, + 0.93455386, -0.6324693, -0.083922029}; + +static float rnn_golden_output[] = { + 0.496726, 0, 0.965996, 0, 0.0584254, 0, + 0, 0.12315, 0, 0, 0.612266, 0.456601, + 0, 0.52286, 1.16099, 0.0291232, + + 0, 0, 0.524901, 0, 0, 0, + 0, 1.02116, 0, 1.35762, 0, 0.356909, + 0.436415, 0.0355727, 0, 0, + + 0, 0, 0, 0.262335, 0, 0, + 0, 1.33992, 0, 2.9739, 0, 0, + 1.31914, 2.66147, 0, 0, + + 0.942568, 0, 0, 0, 0.025507, 0, + 0, 0, 0.321429, 0.569141, 1.25274, 1.57719, + 0.8158, 1.21805, 0.586239, 0.25427, + + 1.04436, 0, 0.630725, 0, 0.133801, 0.210693, + 0.363026, 0, 0.533426, 0, 1.25926, 0.722707, + 0, 1.22031, 1.30117, 0.495867, + + 0.222187, 0, 0.72725, 0, 0.767003, 0, + 0, 0.147835, 0, 0, 0, 0.608758, + 0.469394, 0.00720298, 0.927537, 0, + + 0.856974, 0.424257, 0, 0, 0.937329, 0, + 0, 0, 0.476425, 0, 0.566017, 0.418462, + 0.141911, 0.996214, 1.13063, 0, + + 0.967899, 0, 0, 0, 0.0831304, 0, + 0, 1.00378, 0, 0, 0, 1.44818, + 1.01768, 0.943891, 0.502745, 0, + + 0.940135, 0, 0, 0, 0, 0, + 0, 2.13243, 0, 0.71208, 0.123918, 1.53907, + 1.30225, 1.59644, 0.70222, 0, + + 0.804329, 0, 0.430576, 0, 0.505872, 0.509603, + 0.343448, 0, 0.107756, 0.614544, 1.44549, 1.52311, + 0.0454298, 0.300267, 0.562784, 0.395095, + + 0.228154, 0, 0.675323, 0, 1.70536, 0.766217, + 0, 0, 0, 0.735363, 0.0759267, 1.91017, + 0.941888, 0, 0, 0, + + 0, 0, 1.5909, 0, 0, 0, + 0, 0.5755, 0, 0.184687, 0, 1.56296, + 0.625285, 0, 0, 0, + + 0, 0, 0.0857888, 0, 0, 0, + 0, 0.488383, 0.252786, 0, 0, 0, + 1.02817, 1.85665, 0, 0, + + 
0.00981836, 0, 1.06371, 0, 0, 0, + 0, 0, 0, 0.290445, 0.316406, 0, + 0.304161, 1.25079, 0.0707152, 0, + + 0.986264, 0.309201, 0, 0, 0, 0, + 0, 1.64896, 0.346248, 0, 0.918175, 0.78884, + 0.524981, 1.92076, 2.07013, 0.333244, + + 0.415153, 0.210318, 0, 0, 0, 0, + 0, 2.02616, 0, 0.728256, 0.84183, 0.0907453, + 0.628881, 3.58099, 1.49974, 0}; + +class UnidirectionalRNNOpModel : public SingleOpModel { + public: + UnidirectionalRNNOpModel(int batches, int sequence_len, int units, int size, + bool time_major) + : batches_(batches), + sequence_len_(sequence_len), + units_(units), + input_size_(size) { + input_ = AddInput(TensorType_FLOAT32); + weights_ = AddInput(TensorType_FLOAT32); + recurrent_weights_ = AddInput(TensorType_FLOAT32); + bias_ = AddInput(TensorType_FLOAT32); + hidden_state_ = AddOutput(TensorType_FLOAT32); + output_ = AddOutput(TensorType_FLOAT32); + SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN, + BuiltinOptions_SequenceRNNOptions, + CreateSequenceRNNOptions(builder_, time_major, + ActivationFunctionType_RELU) + .Union()); + if (time_major) { + BuildInterpreter({{sequence_len_, batches_, input_size_}, + {units_, input_size_}, + {units_, units_}, + {units_}}); + } else { + BuildInterpreter({{batches_, sequence_len_, input_size_}, + {units_, input_size_}, + {units_, units_}, + {units_}}); + } + } + + void SetBias(std::initializer_list f) { PopulateTensor(bias_, f); } + + void SetWeights(std::initializer_list f) { + PopulateTensor(weights_, f); + } + + void SetRecurrentWeights(std::initializer_list f) { + PopulateTensor(recurrent_weights_, f); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInput(int offset, float* begin, float* end) { + PopulateTensor(input_, offset, begin, end); + } + + void ResetHiddenState() { + const int zero_buffer_size = units_ * batches_; + std::unique_ptr zero_buffer(new float[zero_buffer_size]); + memset(zero_buffer.get(), 0, zero_buffer_size * sizeof(float)); + 
PopulateTensor(hidden_state_, 0, zero_buffer.get(), + zero_buffer.get() + zero_buffer_size); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + int input_size() { return input_size_; } + int num_units() { return units_; } + int num_batches() { return batches_; } + int sequence_len() { return sequence_len_; } + + private: + int input_; + int weights_; + int recurrent_weights_; + int bias_; + int hidden_state_; + int output_; + + int batches_; + int sequence_len_; + int units_; + int input_size_; +}; + +// TODO(mirkov): add another test which directly compares to TF once TOCO +// supports the conversion from dynamic_rnn with BasicRNNCell. +TEST(FullyConnectedOpTest, BlackBoxTest) { + UnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, + /*units=*/16, /*size=*/8, /*time_major=*/false); + rnn.SetWeights( + {0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346, + 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399, + 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113, + -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512, + -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188, + -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158, + -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241, + 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183, + 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303, + 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884, + -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726, + 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644, + -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461, + -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158, + 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042, + 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012, + 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345, + -0.511346, -0.901675, -0.81252, 
-0.127006, 0.809865, -0.721884, + 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274, + 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934, + -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077, + 0.277308, 0.415818}); + + rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068, + -0.23566568, -0.389184, 0.47481549, -0.4791103, 0.29931796, + 0.10463274, 0.83918178, 0.37197268, 0.61957061, 0.3956964, + -0.37609905}); + + rnn.SetRecurrentWeights({0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1}); + + rnn.ResetHiddenState(); + const int input_sequence_size = rnn.input_size() * rnn.sequence_len(); + float* batch_start = rnn_input; + float* batch_end = batch_start + input_sequence_size; + rnn.SetInput(0, batch_start, batch_end); + rnn.SetInput(input_sequence_size, batch_start, batch_end); + + rnn.Invoke(); + + float* golden_start = rnn_golden_output; + float* golden_end = golden_start + rnn.num_units() * rnn.sequence_len(); + std::vector expected; + expected.insert(expected.end(), golden_start, golden_end); + expected.insert(expected.end(), golden_start, golden_end); + + EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); +} + 
+TEST(FullyConnectedOpTest, TimeMajorBlackBoxTest) { + UnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, + /*units=*/16, /*size=*/8, /*time_major=*/true); + rnn.SetWeights( + {0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346, + 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399, + 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113, + -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512, + -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188, + -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158, + -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241, + 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183, + 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303, + 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884, + -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726, + 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644, + -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461, + -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158, + 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042, + 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012, + 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345, + -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884, + 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274, + 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934, + -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077, + 0.277308, 0.415818}); + + rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068, + -0.23566568, -0.389184, 0.47481549, -0.4791103, 0.29931796, + 0.10463274, 0.83918178, 0.37197268, 0.61957061, 0.3956964, + -0.37609905}); + + rnn.SetRecurrentWeights({0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 
0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1}); + + rnn.ResetHiddenState(); + for (int i = 0; i < rnn.sequence_len(); i++) { + float* batch_start = rnn_input + i * rnn.input_size(); + float* batch_end = batch_start + rnn.input_size(); + // The two batches are identical. + rnn.SetInput(2 * i * rnn.input_size(), batch_start, batch_end); + rnn.SetInput((2 * i + 1) * rnn.input_size(), batch_start, batch_end); + } + + rnn.Invoke(); + + std::vector expected; + for (int i = 0; i < rnn.sequence_len(); i++) { + float* golden_batch_start = rnn_golden_output + i * rnn.num_units(); + float* golden_batch_end = golden_batch_start + rnn.num_units(); + expected.insert(expected.end(), golden_batch_start, golden_batch_end); + expected.insert(expected.end(), golden_batch_start, golden_batch_end); + } + + EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + // On Linux, add: tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/lib_package/BUILD b/tensorflow/contrib/lite/lib_package/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..3c1b8d3d45f2bb382bbe6b789ec6ac7ec89ebc66 --- /dev/null +++ b/tensorflow/contrib/lite/lib_package/BUILD @@ -0,0 +1,16 @@ 
+package(default_visibility = ["//visibility:private"]) + +# Create the LICENSE file for libraries that are used by TensorFlow Lite +# C library. +genrule( + name = "clicenses_generate", + srcs = [ + "//third_party/eigen3:LICENSE", + "@arm_neon_2_x86_sse//:LICENSE", + "@farmhash_archive//:COPYING", + "@gemmlowp//:LICENSE", + ], + outs = ["LICENSE"], + cmd = "$(location :concat_licenses.sh) $(SRCS) >$@", + tools = [":concat_licenses.sh"], +) diff --git a/tensorflow/contrib/lite/lib_package/concat_licenses.sh b/tensorflow/contrib/lite/lib_package/concat_licenses.sh new file mode 100755 index 0000000000000000000000000000000000000000..2070f64e9fa4384234361556da0ed6f5089319b3 --- /dev/null +++ b/tensorflow/contrib/lite/lib_package/concat_licenses.sh @@ -0,0 +1,28 @@ +#!/usr/bin/env bash +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# +# Script aimed to combining multiple license files into a single one. 
+ +for f in $@ +do + echo "--------------------------------------------------------------------------------" + echo "BEGIN LICENSE FOR $f" + echo "--------------------------------------------------------------------------------" + cat $f + echo "--------------------------------------------------------------------------------" + echo "END LICENSE FOR $f" + echo "--------------------------------------------------------------------------------" +done diff --git a/tensorflow/contrib/lite/lib_package/create_ios_frameworks.sh b/tensorflow/contrib/lite/lib_package/create_ios_frameworks.sh new file mode 100755 index 0000000000000000000000000000000000000000..b58ae266017caf8781c28331f49a8f5bc1550767 --- /dev/null +++ b/tensorflow/contrib/lite/lib_package/create_ios_frameworks.sh @@ -0,0 +1,81 @@ +#!/bin/bash -x +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +set -e + +echo "Starting" +TFLITE_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)/.." + +TMP_DIR=$(mktemp -d) +echo "Package dir: " $TMP_DIR +FW_DIR=$TMP_DIR/tensorflow_lite_ios_frameworks +FW_DIR_TFLITE=$FW_DIR/tensorflow_lite.framework +FW_DIR_TFLITE_HDRS=$FW_DIR_TFLITE/Headers + +echo "Creating target Headers directories" +mkdir -p $FW_DIR_TFLITE_HDRS + +echo "Headers, populating: TensorFlow Lite" +cd $TFLITE_DIR/../../.. 
+ +find tensorflow/contrib/lite -name '*.h' \ + -not -path 'tensorflow/contrib/lite/downloads/*' \ + -not -path 'tensorflow/contrib/lite/examples/*' \ + -not -path 'tensorflow/contrib/lite/gen/*' \ + -not -path 'tensorflow/contrib/lite/toco/*' \ + -not -path 'tensorflow/contrib/lite/nnapi/*' \ + -not -path 'tensorflow/contrib/lite/java/*' \ + | tar -cf $FW_DIR_TFLITE_HDRS/tmp.tar -T - +cd $FW_DIR_TFLITE_HDRS +tar xf tmp.tar +rm -f tmp.tar + +echo "Headers, populating: Flatbuffer" +cd $TFLITE_DIR/downloads/flatbuffers/include/ +find . -name '*.h' | tar -cf $FW_DIR_TFLITE_HDRS/tmp.tar -T - +cd $FW_DIR_TFLITE_HDRS +tar xf tmp.tar +rm -f tmp.tar + +cd $TFLITE_DIR/../../.. +echo "Generate master LICENSE file and copy to target" +bazel build //tensorflow/tools/lib_package:clicenses_generate +cp $TFLITE_DIR/../../../bazel-genfiles/tensorflow/tools/lib_package/include/tensorflow/c/LICENSE \ + $FW_DIR_TFLITE + +echo "Copying static libraries" +cp $TFLITE_DIR/gen/lib/libtensorflow-lite.a \ + $FW_DIR_TFLITE/tensorflow_lite + +# This is required, otherwise they interfere with the documentation of the +# pod at cocoapods.org. +echo "Remove all README files" +cd $FW_DIR_TFLITE_HDRS +find . -type f -name README\* -exec rm -f {} \; +find . -type f -name readme\* -exec rm -f {} \; + +TARGET_GEN_LOCATION="$TFLITE_DIR/gen/ios_frameworks" +echo "Moving results to target: " $TARGET_GEN_LOCATION +cd $FW_DIR +zip -q -r tensorflow_lite.framework.zip tensorflow_lite.framework -x .DS_Store +rm -rf $TARGET_GEN_LOCATION +mkdir -p $TARGET_GEN_LOCATION +cp -r tensorflow_lite.framework.zip $TARGET_GEN_LOCATION + +echo "Cleaning up" +rm -rf $TMP_DIR + +echo "Finished" diff --git a/tensorflow/contrib/lite/memory_planner.h b/tensorflow/contrib/lite/memory_planner.h new file mode 100644 index 0000000000000000000000000000000000000000..5cd6c208500f3ea84ab8146f7f136e8b7851ff03 --- /dev/null +++ b/tensorflow/contrib/lite/memory_planner.h @@ -0,0 +1,45 @@ +/* Copyright 2017 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_LITE_MEMORY_PLANNER_H_ +#define TENSORFLOW_CONTRIB_LITE_MEMORY_PLANNER_H_ + +#include "tensorflow/contrib/lite/context.h" + +namespace tflite { + +// A MemoryPlanner is responsible for planning and executing a number of +// memory-related operations that are necessary in TF Lite. +class MemoryPlanner { + public: + virtual ~MemoryPlanner() {} + + // Plans the necessary memory allocations. This is the MemoryPlanner's + // pre-processing step and is called when the graph structure is known but + // actual size of the tensors is not. + virtual TfLiteStatus PlanAllocations() = 0; + + // Allocates the necessary memory to execute all nodes in the interval + // [first_node, last_node]. + virtual TfLiteStatus ExecuteAllocations(int first_node, int last_node) = 0; + + // Invalidates allocations made earlier. This is called when tensor sizes + // have changed. All planned allocations remain, but can't be used until + // ExecuteAllocations() is called. 
+ virtual TfLiteStatus ResetAllocations() = 0; +}; + +} // namespace tflite + +#endif // TENSORFLOW_CONTRIB_LITE_MEMORY_PLANNER_H_ diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index e2f3560e61baae88a4afaafaa202cde784063efc..2ee0cac11ca8b5c964e04f9baa2471ab27b6972d 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -30,17 +30,6 @@ limitations under the License. namespace tflite { -namespace { -inline const tflite::Model* VerifyAndGetModel(const void* buf, size_t len) { - ::flatbuffers::Verifier verifier(static_cast(buf), len); - if (VerifyModelBuffer(verifier)) { - return ::tflite::GetModel(buf); - } else { - return nullptr; - } -} -} // namespace - const char* kEmptyTensorName = ""; std::unique_ptr FlatBufferModel::BuildFromFile( @@ -60,6 +49,14 @@ std::unique_ptr FlatBufferModel::BuildFromBuffer( return model; } +std::unique_ptr FlatBufferModel::BuildFromModel( + const tflite::Model* model_spec, ErrorReporter* error_reporter) { + std::unique_ptr model; + model.reset(new FlatBufferModel(model_spec, error_reporter)); + if (!model->initialized()) model.reset(); + return model; +} + FlatBufferModel::FlatBufferModel(const char* filename, bool mmap_file, ErrorReporter* error_reporter, bool use_nnapi) : error_reporter_(error_reporter ? 
error_reporter @@ -72,10 +69,9 @@ FlatBufferModel::FlatBufferModel(const char* filename, bool mmap_file, } else { allocation_ = new FileCopyAllocation(filename, error_reporter); } - if (!allocation_->valid()) return; - if (!CheckModelIdentifier()) return; + if (!allocation_->valid() || !CheckModelIdentifier()) return; - model_ = VerifyAndGetModel(allocation_->base(), allocation_->bytes()); + model_ = ::tflite::GetModel(allocation_->base()); } bool FlatBufferModel::CheckModelIdentifier() const { @@ -96,7 +92,14 @@ FlatBufferModel::FlatBufferModel(const char* ptr, size_t num_bytes, allocation_ = new MemoryAllocation(ptr, num_bytes, error_reporter); if (!allocation_->valid()) return; - model_ = VerifyAndGetModel(allocation_->base(), allocation_->bytes()); + model_ = ::tflite::GetModel(allocation_->base()); +} + +FlatBufferModel::FlatBufferModel(const Model* model, + ErrorReporter* error_reporter) + : error_reporter_(error_reporter ? error_reporter + : DefaultErrorReporter()) { + model_ = model; } FlatBufferModel::~FlatBufferModel() { delete allocation_; } @@ -160,6 +163,27 @@ std::vector FlatBufferIntArrayToVector(T* flat_array) { return ret; } +// Copies the contents from the flatbuffer int vector `flat_vector` into the +// int array `buffer`. `flat_vector` and `buffer` represent the same +// configuration for a given operation. 
+void FlatBufferIntVectorToArray(int max_size_of_buffer, + const flatbuffers::Vector* flat_vector, + int* buffer, ErrorReporter* error_reporter) { + if (!flat_vector) { + error_reporter->Report("Input array not provided for operation.\n"); + } else { + int num_dimensions = flat_vector->Length(); + if (num_dimensions > max_size_of_buffer / sizeof(int)) { + error_reporter->Report( + "Found too many dimensions in the operation's input array.\n"); + } else { + for (int i = 0; i < num_dimensions; ++i) { + buffer[i] = flat_vector->Get(i); + } + } + } +} + // Allocate a structure using C malloc, but make sure the structure is a // POD structure that doesn't require constructors to run. The reason we do // this, is that Interpreter's C extension part will take ownership and wants @@ -175,6 +199,9 @@ T* MallocPOD() { // This handles builtin data explicitly as there are flatbuffer schemas. // // Returns memory that must be feed. +// +// TODO(nupurgarg): Pass in void ** and return TfLiteStatus to ensure program +// crashes if error reporter is called. 
void* ParseOpData(const Operator* op, BuiltinOperator op_type, ErrorReporter* error_reporter) { auto parse_padding = [](Padding padding) { @@ -192,7 +219,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type, return kTfLiteActNone; case ActivationFunctionType_RELU: return kTfLiteActRelu; - case ActivationFunctionType_RELU1: + case ActivationFunctionType_RELU_N1_TO_1: return kTfLiteActRelu1; case ActivationFunctionType_RELU6: return kTfLiteActRelu6; @@ -248,9 +275,10 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_TANH: case BuiltinOperator_LOGISTIC: case BuiltinOperator_RELU: - case BuiltinOperator_RELU1: + case BuiltinOperator_RELU_N1_TO_1: case BuiltinOperator_RELU6: case BuiltinOperator_CONCAT_EMBEDDINGS: + case BuiltinOperator_EXP: break; case BuiltinOperator_LSH_PROJECTION: { TfLiteLSHProjectionParams* params = @@ -301,6 +329,18 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type, builtin_data = reinterpret_cast(params); break; } + case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN: + case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: { + TfLiteSequenceRNNParams* params = MallocPOD(); + if (auto* sequence_rnn_params = + op->builtin_options_as_SequenceRNNOptions()) { + params->activation = + parse_activation(sequence_rnn_params->fused_activation_function()); + params->time_major = sequence_rnn_params->time_major(); + } + builtin_data = reinterpret_cast(params); + break; + } case BuiltinOperator_RNN: { TfLiteRNNParams* params = MallocPOD(); if (auto* rnn_params = op->builtin_options_as_RNNOptions()) { @@ -375,6 +415,24 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type, builtin_data = reinterpret_cast(params); break; } + case BuiltinOperator_DIV: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_DivOptions()) { + params->activation = + parse_activation(schema_params->fused_activation_function()); + } + builtin_data = reinterpret_cast(params); + break; + } 
+ case BuiltinOperator_SUB: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_SubOptions()) { + params->activation = + parse_activation(schema_params->fused_activation_function()); + } + builtin_data = reinterpret_cast(params); + break; + } case BuiltinOperator_L2_NORMALIZATION: { auto* params = MallocPOD(); if (auto* schema_params = op->builtin_options_as_L2NormOptions()) { @@ -396,6 +454,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type, builtin_data = reinterpret_cast(params); break; } + case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: case BuiltinOperator_LSTM: { TfLiteLSTMParams* params = MallocPOD(); if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) { @@ -411,29 +470,21 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type, auto* params = MallocPOD(); if (auto* schema_params = op->builtin_options_as_ResizeBilinearOptions()) { - params->new_height = schema_params->new_height(); - params->new_width = schema_params->new_width(); + params->align_corners = schema_params->align_corners(); } builtin_data = reinterpret_cast(params); break; } + case BuiltinOperator_PAD: { + break; + } case BuiltinOperator_RESHAPE: { auto* params = MallocPOD(); if (auto* schema_params = op->builtin_options_as_ReshapeOptions()) { auto* new_shape = schema_params->new_shape(); - if (!new_shape) { - error_reporter->Report("No new_shape provided for Reshape\n"); - } else { - params->num_dimensions = new_shape->Length(); - if (params->num_dimensions > sizeof(params->shape) / sizeof(int)) { - error_reporter->Report( - "Found too many dimensions in Reshape's new_shape\n"); - } else { - for (int i = 0; i < params->num_dimensions; ++i) { - params->shape[i] = new_shape->Get(i); - } - } - } + FlatBufferIntVectorToArray(sizeof(params->shape), new_shape, + params->shape, error_reporter); + params->num_dimensions = new_shape->Length(); } builtin_data = reinterpret_cast(params); break; @@ -456,6 +507,56 @@ void* 
ParseOpData(const Operator* op, BuiltinOperator op_type, builtin_data = reinterpret_cast(params); break; } + case BuiltinOperator_GATHER: { + TfLiteGatherParams* params = MallocPOD(); + params->axis = 0; + if (auto* gather_params = op->builtin_options_as_GatherOptions()) { + params->axis = gather_params->axis(); + } + + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_SPACE_TO_BATCH_ND: { + break; + } + case BuiltinOperator_BATCH_TO_SPACE_ND: { + break; + } + case BuiltinOperator_TRANSPOSE: { + break; + } + case BuiltinOperator_MEAN: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_MeanOptions()) { + params->keep_dims = schema_params->keep_dims(); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_SQUEEZE: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_SqueezeOptions()) { + const auto& squeeze_dims = schema_params->squeeze_dims(); + FlatBufferIntVectorToArray(sizeof(params->squeeze_dims), squeeze_dims, + params->squeeze_dims, error_reporter); + params->num_squeeze_dims = squeeze_dims->Length(); + } + builtin_data = reinterpret_cast(params); + break; + } + case BuiltinOperator_STRIDED_SLICE: { + auto* params = MallocPOD(); + if (auto* schema_params = op->builtin_options_as_StridedSliceOptions()) { + params->begin_mask = schema_params->begin_mask(); + params->end_mask = schema_params->end_mask(); + params->ellipsis_mask = schema_params->ellipsis_mask(); + params->new_axis_mask = schema_params->new_axis_mask(); + params->shrink_axis_mask = schema_params->shrink_axis_mask(); + } + builtin_data = reinterpret_cast(params); + break; + } } return builtin_data; } diff --git a/tensorflow/contrib/lite/model.h b/tensorflow/contrib/lite/model.h index 15659d33f37dfb2f119480ed88d2e1b81f34c145..a467df5bb4eee3f6ce814512cb8b74bf09a6a4e7 100644 --- a/tensorflow/contrib/lite/model.h +++ b/tensorflow/contrib/lite/model.h @@ -31,8 +31,8 @@ limitations under 
the License. // OpResolver must be defined to provide your kernel implementations to the // interpreter. This is environment specific and may consist of just the builtin // ops, or some custom operators you defined to extend tflite. -#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODEL_H_ -#define THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODEL_H_ +#ifndef TENSORFLOW_CONTRIB_LITE_MODEL_H_ +#define TENSORFLOW_CONTRIB_LITE_MODEL_H_ #include #include "tensorflow/contrib/lite/error_reporter.h" @@ -45,18 +45,25 @@ namespace tflite { // or mmapped. This uses flatbuffers as the serialization format. class FlatBufferModel { public: - // Build a model based on a file. Return a nullptr in case of failure. + // Builds a model based on a file. Returns a nullptr in case of failure. static std::unique_ptr BuildFromFile( const char* filename, ErrorReporter* error_reporter = DefaultErrorReporter()); - // Build a model based on a pre-loaded flatbuffer. The caller retains + // Builds a model based on a pre-loaded flatbuffer. The caller retains // ownership of the buffer and should keep it alive until the returned object - // is destroyed. Return a nullptr in case of failure. + // is destroyed. Returns a nullptr in case of failure. static std::unique_ptr BuildFromBuffer( const char* buffer, size_t buffer_size, ErrorReporter* error_reporter = DefaultErrorReporter()); + // Builds a model directly from a flatbuffer pointer. The caller retains + // ownership of the buffer and should keep it alive until the returned object + // is destroyed. Returns a nullptr in case of failure. + static std::unique_ptr BuildFromModel( + const tflite::Model* model_spec, + ErrorReporter* error_reporter = DefaultErrorReporter()); + // Releases memory or unmaps mmaped meory. ~FlatBufferModel(); @@ -75,7 +82,7 @@ class FlatBufferModel { bool CheckModelIdentifier() const; private: - // Load a model from `filename`. If `mmap_file` is true then use mmap, + // Loads a model from `filename`. 
If `mmap_file` is true then use mmap, // otherwise make a copy of the model in a buffer. // // Note, if `error_reporter` is null, then a DefaultErrorReporter() will be @@ -85,8 +92,8 @@ class FlatBufferModel { ErrorReporter* error_reporter = DefaultErrorReporter(), bool use_nnapi = false); - // Load a model from `ptr` and `num_bytes` of the model file. The `ptr` has to - // remain alive and unchanged until the end of this flatbuffermodel's + // Loads a model from `ptr` and `num_bytes` of the model file. The `ptr` has + // to remain alive and unchanged until the end of this flatbuffermodel's // lifetime. // // Note, if `error_reporter` is null, then a DefaultErrorReporter() will be @@ -94,6 +101,10 @@ class FlatBufferModel { FlatBufferModel(const char* ptr, size_t num_bytes, ErrorReporter* error_reporter = DefaultErrorReporter()); + // Loads a model from Model flatbuffer. The `model` has to remain alive and + // unchanged until the end of this flatbuffermodel's lifetime. + FlatBufferModel(const Model* model, ErrorReporter* error_reporter); + // Flatbuffer traverser pointer. (Model* is a pointer that is within the // allocated memory of the data allocated by allocation's internals. const tflite::Model* model_ = nullptr; @@ -106,9 +117,9 @@ class FlatBufferModel { // model are mapped to executable function pointers (TfLiteRegistrations). class OpResolver { public: - // Find the op registration for a builtin operator by enum code. + // Finds the op registration for a builtin operator by enum code. virtual TfLiteRegistration* FindOp(tflite::BuiltinOperator op) const = 0; - // Find the op registration of a custom operator by op name. + // Finds the op registration of a custom operator by op name. 
virtual TfLiteRegistration* FindOp(const char* op) const = 0; virtual ~OpResolver() {} }; @@ -131,7 +142,7 @@ class InterpreterBuilder { public: InterpreterBuilder(const FlatBufferModel& model, const OpResolver& op_resolver); - // Build an interpreter given only the raw flatbuffer Model object (instead + // Builds an interpreter given only the raw flatbuffer Model object (instead // of a FlatBufferModel). Mostly used for testing. // If `error_reporter` is null, then DefaultErrorReporter() is used. InterpreterBuilder(const ::tflite::Model* model, @@ -162,4 +173,4 @@ class InterpreterBuilder { } // namespace tflite -#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_MODEL_H_ +#endif // TENSORFLOW_CONTRIB_LITE_MODEL_H_ diff --git a/tensorflow/contrib/lite/model_test.cc b/tensorflow/contrib/lite/model_test.cc index 61043866420752b552281e353be9a2b41a6aadc8..66f22fd66a9ae0d35553a1f780ef73a5c5994c99 100644 --- a/tensorflow/contrib/lite/model_test.cc +++ b/tensorflow/contrib/lite/model_test.cc @@ -20,12 +20,12 @@ limitations under the License. #include #include #include -#include #include "tensorflow/contrib/lite/model.h" #include #include "tensorflow/contrib/lite/error_reporter.h" +#include "tensorflow/contrib/lite/testing/util.h" // Comparison for TfLiteRegistration. Since TfLiteRegistration is a C object, // we must declare this in global namespace, so argument-dependent operator @@ -246,12 +246,26 @@ TEST(BasicFlatBufferModel, TestNullErrorReporter) { ASSERT_NE(interpreter->Invoke(), kTfLiteOk); } -// Test what happens if we cannot bind any of the ops. -TEST(BasicFlatBufferModel, TestBuildModelFromCorruptedData) { - std::string corrupted_data = "123"; - auto model = FlatBufferModel::BuildFromBuffer(corrupted_data.c_str(), - corrupted_data.length()); - ASSERT_FALSE(model); +// Test that loading model directly from a Model flatbuffer works. 
+TEST(BasicFlatBufferModel, TestBuildFromModel) { + TestErrorReporter reporter; + FileCopyAllocation model_allocation( + "tensorflow/contrib/lite/testdata/test_model.bin", &reporter); + ASSERT_TRUE(model_allocation.valid()); + ::flatbuffers::Verifier verifier( + reinterpret_cast(model_allocation.base()), + model_allocation.bytes()); + ASSERT_TRUE(VerifyModelBuffer(verifier)); + const Model* model_fb = ::tflite::GetModel(model_allocation.base()); + + auto model = FlatBufferModel::BuildFromModel(model_fb); + ASSERT_TRUE(model); + + std::unique_ptr interpreter; + ASSERT_EQ( + InterpreterBuilder(*model, TrivialResolver(&dummy_reg))(&interpreter), + kTfLiteOk); + ASSERT_NE(interpreter, nullptr); } // TODO(aselle): Add tests for serialization of builtin op data types. @@ -261,7 +275,7 @@ TEST(BasicFlatBufferModel, TestBuildModelFromCorruptedData) { } // namespace tflite int main(int argc, char** argv) { - // On Linux, add: tflite::LogToStderr(); + ::tflite::LogToStderr(); ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/tensorflow/contrib/lite/models/BUILD b/tensorflow/contrib/lite/models/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..6a1255b586ef04b80159156a78f0c4569a4661c5 --- /dev/null +++ b/tensorflow/contrib/lite/models/BUILD @@ -0,0 +1,26 @@ +# Model tests +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts") + +exports_files(glob([ + "testdata/*", +])) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/lite/models/smartreply/BUILD b/tensorflow/contrib/lite/models/smartreply/BUILD index fbdf19f2054cf01aec44e3fcb13d0d0a2ff6f914..733c3f4c7fa0605f24a1e6b4c458e34310c079c4 100644 --- 
a/tensorflow/contrib/lite/models/smartreply/BUILD +++ b/tensorflow/contrib/lite/models/smartreply/BUILD @@ -1,7 +1,92 @@ package(default_visibility = ["//visibility:public"]) +load("//tensorflow/contrib/lite:build_def.bzl", "tflite_copts", "gen_selected_ops") + licenses(["notice"]) # Apache 2.0 +gen_selected_ops( + name = "smartreply_ops", + model = "@tflite_smartreply//:smartreply.tflite", +) + +cc_library( + name = "custom_ops", + srcs = [ + "ops/extract_feature.cc", + "ops/normalize.cc", + "ops/predict.cc", + ":smartreply_ops", + ], + copts = tflite_copts(), + deps = [ + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/tools:mutable_op_resolver", + "@com_google_absl//absl/strings", + "@com_googlesource_code_re2//:re2", + "@farmhash_archive//:farmhash", + ], +) + +cc_library( + name = "predictor_lib", + srcs = ["predictor.cc"], + hdrs = ["predictor.h"], + copts = tflite_copts(), + deps = [ + ":custom_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/tools:mutable_op_resolver", + "@com_google_absl//absl/strings", + "@com_googlesource_code_re2//:re2", + ], +) + +cc_test( + name = "extract_feature_op_test", + size = "small", + srcs = ["ops/extract_feature_test.cc"], + deps = [ + ":custom_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + "@farmhash_archive//:farmhash", + ], +) + +cc_test( + name = "normalize_op_test", + size = "small", + srcs = ["ops/normalize_test.cc"], + deps = [ + ":custom_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/kernels:test_util", + 
"@com_google_googletest//:gtest", + ], +) + +cc_test( + name = "predict_op_test", + size = "small", + srcs = ["ops/predict_test.cc"], + deps = [ + ":custom_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite:string_util", + "//tensorflow/contrib/lite/kernels:builtin_ops", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/AndroidManifest.xml b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/AndroidManifest.xml new file mode 100644 index 0000000000000000000000000000000000000000..75ed9432c8fcdfd77a64d3c659e6336c977cdda2 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/AndroidManifest.xml @@ -0,0 +1,38 @@ + + + + + + + + + + + + + + + + diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..f8767b443a2aa64b666c3b6bfb7db30cc0be62ea --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD @@ -0,0 +1,65 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +load( + "//tensorflow/contrib/lite:build_def.bzl", + "tflite_copts", + "tflite_jni_binary", +) + +filegroup( + name = "assets", + srcs = [ + "@tflite_smartreply//:model_files", + ], +) + +android_binary( + name = "SmartReplyDemo", + srcs = glob(["java/**/*.java"]), + assets = [":assets"], + assets_dir = "", + custom_package = "com.example.android.smartreply", + manifest = "AndroidManifest.xml", + nocompress_extensions = [ + ".tflite", + ], + resource_files = glob(["res/**"]), + tags = ["manual"], + deps = [ + ":smartreply_runtime", + "@androidsdk//com.android.support:support-v13-25.2.0", + "@androidsdk//com.android.support:support-v4-25.2.0", + ], +) + +cc_library( 
+ name = "smartreply_runtime", + srcs = ["libsmartreply_jni.so"], + visibility = ["//visibility:public"], +) + +tflite_jni_binary( + name = "libsmartreply_jni.so", + deps = [ + ":smartreply_jni_lib", + ], +) + +cc_library( + name = "smartreply_jni_lib", + srcs = [ + "smartreply_jni.cc", + ], + copts = tflite_copts(), + linkopts = [ + "-lm", + "-ldl", + ], + deps = [ + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/models/smartreply:predictor_lib", + ], + alwayslink = 1, +) diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/BUILD b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..3c882ffc43fde577801428151a43b592e8faaed1 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/BUILD @@ -0,0 +1,15 @@ +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(glob(["*"])) + +filegroup( + name = "assets_files", + srcs = glob( + ["**/*"], + exclude = [ + "BUILD", + ], + ), +) diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/backoff_response.txt b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/backoff_response.txt new file mode 100644 index 0000000000000000000000000000000000000000..a0a5b46b5f8d5fd6a0297c8056bb2fb9b6ad9ada --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/assets/backoff_response.txt @@ -0,0 +1,16 @@ +Ok +Yes +No +👍 +☺ +😟 +❤️ +Lol +Thanks +Got it +Done +Nice +I don't know +What? +Why? +What's up? 
diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/MainActivity.java b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/MainActivity.java new file mode 100644 index 0000000000000000000000000000000000000000..02fec9ae5e971ad756ae6c2b0149a6aacfa27cad --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/MainActivity.java @@ -0,0 +1,99 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.example.android.smartreply; + +import android.app.Activity; +import android.os.Bundle; +import android.os.Handler; +import android.util.Log; +import android.view.View; +import android.widget.Button; +import android.widget.EditText; +import android.widget.TextView; + +/** + * The main (and only) activity of this demo app. Displays a text box which updates as messages are + * received. 
+ */ +public class MainActivity extends Activity { + private static final String TAG = "SmartReplyDemo"; + private SmartReplyClient client; + + private Button sendButton; + private TextView messageTextView; + private EditText messageInput; + + private Handler handler; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + Log.v(TAG, "onCreate"); + setContentView(R.layout.main_activity); + + client = new SmartReplyClient(getApplicationContext()); + handler = new Handler(); + + sendButton = (Button) findViewById(R.id.send_button); + sendButton.setOnClickListener( + (View v) -> { + send(messageInput.getText().toString()); + }); + + messageTextView = (TextView) findViewById(R.id.message_text); + messageInput = (EditText) findViewById(R.id.message_input); + } + + @Override + protected void onStart() { + super.onStart(); + Log.v(TAG, "onStart"); + handler.post( + () -> { + client.loadModel(); + }); + } + + @Override + protected void onStop() { + super.onStop(); + Log.v(TAG, "onStop"); + handler.post( + () -> { + client.unloadModel(); + }); + } + + private void send(final String message) { + handler.post( + () -> { + messageTextView.append("Input: " + message + "\n"); + + SmartReply[] ans = client.predict(new String[] {message}); + for (SmartReply reply : ans) { + appendMessage("Reply: " + reply.getText()); + } + appendMessage("------"); + }); + } + + private void appendMessage(final String message) { + handler.post( + () -> { + messageTextView.append(message + "\n"); + }); + } +} diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReply.java b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReply.java new file mode 100644 index 0000000000000000000000000000000000000000..3357fd17c11f870d1b0998bb26ffa9abf149686b --- /dev/null +++ 
b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReply.java @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +package com.example.android.smartreply; + +import android.support.annotation.Keep; + +/** + * SmartReply contains predicted message, and confidence. + * + *

NOTE: this class used by JNI, class name and constructor should not be obfuscated. + */ +@Keep +public class SmartReply { + + private final String text; + private final float score; + + @Keep + public SmartReply(String text, float score) { + this.text = text; + this.score = score; + } + + public String getText() { + return text; + } + + public float getScore() { + return score; + } +} diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReplyClient.java b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReplyClient.java new file mode 100644 index 0000000000000000000000000000000000000000..d5b1ac0ffbc47283aa0c1bf68c0a85ad6228cdcc --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/java/com/example/android/smartreply/SmartReplyClient.java @@ -0,0 +1,129 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +package com.example.android.smartreply; + +import android.content.Context; +import android.content.res.AssetFileDescriptor; +import android.support.annotation.Keep; +import android.support.annotation.WorkerThread; +import android.util.Log; +import java.io.BufferedReader; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStreamReader; +import java.nio.MappedByteBuffer; +import java.nio.channels.FileChannel; +import java.util.ArrayList; +import java.util.List; + +/** Interface to load TfLite model and provide predictions. */ +public class SmartReplyClient implements AutoCloseable { + private static final String TAG = "SmartReplyDemo"; + private static final String MODEL_PATH = "smartreply.tflite"; + private static final String BACKOFF_PATH = "backoff_response.txt"; + private static final String JNI_LIB = "smartreply_jni"; + + private final Context context; + private long storage; + private MappedByteBuffer model; + + private volatile boolean isLibraryLoaded; + + public SmartReplyClient(Context context) { + this.context = context; + } + + public boolean isLoaded() { + return storage != 0; + } + + @WorkerThread + public synchronized void loadModel() { + if (!isLibraryLoaded) { + System.loadLibrary(JNI_LIB); + isLibraryLoaded = true; + } + + try { + model = loadModelFile(); + String[] backoff = loadBackoffList(); + storage = loadJNI(model, backoff); + } catch (IOException e) { + Log.e(TAG, "Fail to load model", e); + return; + } + } + + @WorkerThread + public synchronized SmartReply[] predict(String[] input) { + if (storage != 0) { + return predictJNI(storage, input); + } else { + return new SmartReply[] {}; + } + } + + @WorkerThread + public synchronized void unloadModel() { + close(); + } + + @Override + public synchronized void close() { + if (storage != 0) { + unloadJNI(storage); + storage = 0; + } + } + + private MappedByteBuffer loadModelFile() 
throws IOException { + AssetFileDescriptor fileDescriptor = context.getAssets().openFd(MODEL_PATH); + FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor()); + try { + FileChannel fileChannel = inputStream.getChannel(); + long startOffset = fileDescriptor.getStartOffset(); + long declaredLength = fileDescriptor.getDeclaredLength(); + return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength); + } finally { + inputStream.close(); + } + } + + private String[] loadBackoffList() throws IOException { + List labelList = new ArrayList(); + BufferedReader reader = + new BufferedReader(new InputStreamReader(context.getAssets().open(BACKOFF_PATH))); + String line; + while ((line = reader.readLine()) != null) { + if (!line.isEmpty()) { + labelList.add(line); + } + } + reader.close(); + String[] ans = new String[labelList.size()]; + labelList.toArray(ans); + return ans; + } + + @Keep + private native long loadJNI(MappedByteBuffer buffer, String[] backoff); + + @Keep + private native SmartReply[] predictJNI(long storage, String[] text); + + @Keep + private native void unloadJNI(long storage); +} diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/res/layout/main_activity.xml b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/res/layout/main_activity.xml new file mode 100644 index 0000000000000000000000000000000000000000..23b4cadc007a4457d33b8c8fecf9b1e7b7436320 --- /dev/null +++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/res/layout/main_activity.xml @@ -0,0 +1,44 @@ + + + + + + + + + + +